In [1]:
import json
from datetime import datetime
from typing import Dict, TypedDict, Union

import boto3
import numpy as np
import rasterio
from langchain.chains import LLMChain
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from pydantic import BaseModel
from pystac_client import Client
from rasterio.mask import mask
from rasterio.session import AWSSession
from shapely.geometry import shape
from stackstac import stack

NL_classification_values = {
    2: "natural forests",
    3: "natural short vegetation",
    4: "natural water",
    5: "mangroves",
    6: "bare",
    7: "snow",
    8: "wet natural forests",
    9: "natural peat forests",
    10: "wet natural short vegetation",
    11: "natural peat short vegetation",
    12: "crop",
    13: "built",
    14: "non-natural tree cover",
    15: "non-natural short vegetation",
    16: "non-natural water",
    17: "wet non-natural tree cover",
    18: "non-natural peat tree cover",
    19: "wet non-natural short vegetation",
    20: "non-natural peat short vegetation",
    21: "non-natural bare"
}

DEFAULT_COG_PATH = "s3://gfw-data-lake/umd_glad_dist_alerts/v20250329/raster/epsg-4326/cog/default.tif"
INTENSITY_COG_PATH = "s3://gfw-data-lake/umd_glad_dist_alerts/v20250329/raster/epsg-4326/cog/intensity.tif"
STAC_API_URL = "https://eoapi.zeno-staging.ds.io/stac"
AWS_PROFILE = "zeno_internal_sso"
REFERENCE_DATE = datetime(2015, 1, 1)

def convert_date_range_to_days(date_range):
    start_date = datetime.strptime(str(date_range[0])[:10], "%Y-%m-%d")
    end_date = datetime.strptime(str(date_range[1])[:10], "%Y-%m-%d")
    return (start_date - REFERENCE_DATE).days, (end_date - REFERENCE_DATE).days, start_date, end_date

def load_alert_data(geometry):
    session = boto3.Session(profile_name=AWS_PROFILE)
    with rasterio.Env(AWSSession(session), AWS_REQUEST_PAYER="requester"):
        with rasterio.open(DEFAULT_COG_PATH) as src1:
            default_data, _ = mask(src1, geometry, crop=True)
        with rasterio.open(INTENSITY_COG_PATH) as src2:
            intensity_data, _ = mask(src2, geometry, crop=True)
    return default_data[0], intensity_data[0]

def load_natural_lands_mosaic(aoi_geom, start_date, end_date):
    stac = Client.open(STAC_API_URL)
    search = stac.search(
        collections=["natural-lands-map-v1-1"],
        intersects=aoi_geom,
        datetime=f"{start_date.date().isoformat()}/{end_date.date().isoformat()}",
        max_items=50,
    )
    items = list(search.get_items())
    if not items:
        raise ValueError("No Natural Lands items found for AOI.")

    print(f"Found {len(items)} items")

    da = stack(
        items,
        bounds_latlon=aoi_geom.bounds,
        snap_bounds=True,
        epsg=4326
    )

    # Apply chunking
    da = da.chunk({"x": 1024, "y": 1024})


    # Collapse time dimension if multiple timestamps
    if "time" in da.dims:
        da = da.astype("int16").max("time")

    return da.squeeze()


#@tool
def filter_and_count_pixels(state: Dict) -> Dict:
    """Filters pixels from COGs using confidence, intensity, and date thresholds, then counts land cover class occurrences."""
    parsed = state["parsed_params"]

    print("Parsed parameters:", parsed)

    start_days, end_days, start_date, end_date = convert_date_range_to_days(
        (parsed["start_date"], parsed["end_date"])
    )
    geometry = [shape(parsed["aoi"]["geometry"])]

    # Load and process data
    encoded, intensities = load_alert_data(geometry)
    confidence = encoded // 10000
    days_since_2015 = encoded % 10000

    land_cover = load_natural_lands_mosaic(shape(parsed["aoi"]["geometry"]), start_date, end_date)
    land_cover = land_cover.squeeze()

    if land_cover.shape != confidence.shape:
        raise ValueError(f"Shape mismatch: {land_cover.shape} vs {confidence.shape}")

    max_days = (datetime.today() - REFERENCE_DATE).days
    days_since_2015[days_since_2015 == 9999] = days_since_2015[days_since_2015 < 9999].max()

    valid_mask = (
        (confidence >= parsed["confidence_threshold"]) &
        (days_since_2015 >= start_days) &
        (days_since_2015 <= end_days) &
        (days_since_2015 < max_days) &
        (intensities >= parsed["intensity_threshold"])
    )

    results = {}
    valid_lc = land_cover.values[valid_mask]
    unique_classes, counts = np.unique(valid_lc, return_counts=True)
    for cls, count in zip(unique_classes, counts):
        results[int(cls)] = {
            "class": NL_classification_values.get(int(cls), "Unknown"),
            "count": int(count)
        }

    result = {
        "summary_json": {
            "start_date": parsed["start_date"],
            "end_date": parsed["end_date"],
            "current_max_days_possible_based_on_today": max_days,
            "current_min_days_since_2015_in_aoi": days_since_2015.min(),
            "current_max_days_since_2015_in_aoi": days_since_2015.max(),
            "confidence_threshold": parsed["confidence_threshold"],
            "intensity_threshold": parsed["intensity_threshold"],
            "results": results
        }
    }

    print(result)

    return result


class ParsedParams(BaseModel):
    start_date: str
    end_date: str
    confidence_threshold: float
    intensity_threshold: int
    aoi: Union[str, Dict]  # named region or GeoJSON


class GraphState(TypedDict, total=False):
    input: str
    parsed_params: Dict
    results: Dict
    summary: str

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

class ParseParamsNode:
    def __init__(self):
        self.llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
        self.prompt = PromptTemplate(
            input_variables=["input_text"],
            template="""
Extract the following parameters from the input text:
- start_date (YYYY-MM-DD)
- end_date (YYYY-MM-DD)
- confidence_threshold (float)
- intensity_threshold (int)
- aoi (string or GeoJSON)

Input:
{input_text}

Return a JSON object only with these keys and values.
Please output ONLY valid JSON with all braces closed. Do not add any extra text or explanations.
Make sure the JSON is valid and complete.
"""
        )
        self.chain = LLMChain(llm=self.llm, prompt=self.prompt)

    def run(self, input_text: str) -> ParsedParams:
        response = self.chain.run(input_text=input_text)

        # Parse JSON output
        try:
            params_dict = json.loads(response)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse JSON from LLM output: {response}") from e

        # Return Pydantic model
        return ParsedParams(**params_dict)


parser_node = ParseParamsNode()

def parse_message(state: GraphState) -> GraphState:
    parsed = parser_node.run(state["input"])
    return {**state, "parsed_params": parsed}

def filter_and_count_node(state: GraphState) -> GraphState:
    result = filter_and_count_pixels({"parsed_params": state["parsed_params"].dict()})
    return {**state, "results": result["summary_json"]["results"]}

summary_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant summarizing geospatial alert statistics."),
    ("human", "Here are the results:\n{summary_json}")
])
summarize_chain = summary_prompt | llm | RunnableLambda(lambda msg: {"summary": msg.content})

def summarize_results(state: GraphState) -> GraphState:
    summary_text = summarize_chain.invoke({"summary_json": state["results"]})["summary"]
    return {**state, "summary": summary_text}

builder = StateGraph(GraphState)

builder.add_node("parse_message", parse_message)
builder.add_node("filter_and_count", filter_and_count_node)
builder.add_node("summarize_results", summarize_results)

builder.set_entry_point("parse_message")
builder.add_edge("parse_message", "filter_and_count")
builder.add_edge("filter_and_count", "summarize_results")
builder.add_edge("summarize_results", END)

graph = builder.compile()


  self.chain = LLMChain(llm=self.llm, prompt=self.prompt)


In [2]:
geojson_str = """
{
  "type": "Feature",
  "geometry": {
    "type": "Polygon",
    "coordinates": [
      [
        [-75.0, -3.0],
        [-70.0, -3.0],
        [-70.0, -6.0],
        [-75.0, -6.0],
        [-75.0, -3.0]
      ]
    ]
  }
}
"""

initial_state = {
    "input": (
        "Find alerts from January 1, 2019 to December 31, 2020, with confidence of 3 and intensity above 50 "
        "in this AOI (in GeoJSON format): ```json\n" +
        geojson_str.strip() +
        "\n```"
    )
}


final_state = graph.invoke(initial_state)


  response = self.chain.run(input_text=input_text)
/var/folders/3l/ltxhyhhn7jn3xwtvypy7f9ym0000gn/T/ipykernel_41987/2415962420.py:221: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  result = filter_and_count_pixels({"parsed_params": state["parsed_params"].dict()})


Parsed parameters: {'start_date': '2019-01-01', 'end_date': '2020-12-31', 'confidence_threshold': 3.0, 'intensity_threshold': 50, 'aoi': {'type': 'Feature', 'geometry': {'type': 'Polygon', 'coordinates': [[[-75.0, -3.0], [-70.0, -3.0], [-70.0, -6.0], [-75.0, -6.0], [-75.0, -3.0]]]}}}




Found 2 items
{'summary_json': {'start_date': '2019-01-01', 'end_date': '2020-12-31', 'current_max_days_possible_based_on_today': 3799, 'current_min_days_since_2015_in_aoi': 880, 'current_max_days_since_2015_in_aoi': 1543, 'confidence_threshold': 3.0, 'intensity_threshold': 50, 'results': {2: {'class': 'natural forests', 'count': 659}, 4: {'class': 'natural water', 'count': 555}, 6: {'class': 'bare', 'count': 336}, 8: {'class': 'wet natural forests', 'count': 271}, 9: {'class': 'natural peat forests', 'count': 213}, 10: {'class': 'wet natural short vegetation', 'count': 115}, 11: {'class': 'natural peat short vegetation', 'count': 188}, 12: {'class': 'crop', 'count': 3096}, 13: {'class': 'built', 'count': 20}, 15: {'class': 'non-natural short vegetation', 'count': 9}}}}


In [3]:
final_state

{'input': 'Find alerts from January 1, 2019 to December 31, 2020, with confidence of 3 and intensity above 50 in this AOI (in GeoJSON format): ```json\n{\n  "type": "Feature",\n  "geometry": {\n    "type": "Polygon",\n    "coordinates": [\n      [\n        [-75.0, -3.0],\n        [-70.0, -3.0],\n        [-70.0, -6.0],\n        [-75.0, -6.0],\n        [-75.0, -3.0]\n      ]\n    ]\n  }\n}\n```',
 'parsed_params': ParsedParams(start_date='2019-01-01', end_date='2020-12-31', confidence_threshold=3.0, intensity_threshold=50, aoi={'type': 'Feature', 'geometry': {'type': 'Polygon', 'coordinates': [[[-75.0, -3.0], [-70.0, -3.0], [-70.0, -6.0], [-75.0, -6.0], [-75.0, -3.0]]]}}),
 'results': {2: {'class': 'natural forests', 'count': 659},
  4: {'class': 'natural water', 'count': 555},
  6: {'class': 'bare', 'count': 336},
  8: {'class': 'wet natural forests', 'count': 271},
  9: {'class': 'natural peat forests', 'count': 213},
  10: {'class': 'wet natural short vegetation', 'count': 115},
  11:

In [4]:
print("Final Summary:", final_state["summary"])

Final Summary: Here is a summary of the geospatial alert statistics:

- Natural forests: 659 alerts
- Natural water: 555 alerts
- Bare land: 336 alerts
- Wet natural forests: 271 alerts
- Natural peat forests: 213 alerts
- Wet natural short vegetation: 115 alerts
- Natural peat short vegetation: 188 alerts
- Crop fields: 3096 alerts
- Built-up areas: 20 alerts
- Non-natural short vegetation: 9 alerts
