In [1]:
import re
import json
from datetime import datetime
from typing import Dict, TypedDict, Any, Tuple
import numpy as np
import boto3
import rasterio
from rasterio.mask import mask
from rasterio.session import AWSSession
from shapely.geometry import shape
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableLambda
from langchain.tools import tool


def parse_user_message(state: dict) -> dict:
    """
    Parse the free-text request in `state["message"]` into aoi, date_range, and thresholds.

    The message is expected to contain:
      * a JSON object (taken from the first "{" to the last "}") describing the AOI,
      * a line of the form ``date range: YYYY-MM-DD to YYYY-MM-DD``,
      * a line of the form ``confidence: <number>``,
      * a line of the form ``intensity: <number>``.

    Args:
        state: Graph state; only the "message" key is read.

    Returns:
        A dict with the keys
          "aoi"        -- the parsed JSON object,
          "date_range" -- tuple of the two date strings (start, end),
          "thresholds" -- {"confidence": float, "intensity": int}; the intensity
                          is truncated towards zero (e.g. "50.9" -> 50).

    Raises:
        KeyError: if `state` has no "message" key.
        ValueError: if any of the four pieces is missing or malformed.
    """
    message = state["message"]

    def _number_after(label: str) -> float:
        """Return the number written after ``<label>:`` in the message."""
        found = re.search(rf'{label}:\s*([0-9.]+)', message)
        if found is None:
            raise ValueError(f"Message is missing a '{label}: <number>' entry.")
        try:
            return float(found.group(1))
        except ValueError as exc:
            # The character class also matches things like "." or "1.2.3".
            raise ValueError(
                f"Invalid {label} value {found.group(1)!r}; expected a number."
            ) from exc

    # Greedy match with DOTALL: spans from the first "{" to the last "}" so that
    # nested objects and multi-line JSON are captured as one string.
    geojson_match = re.search(r'\{.*\}', message, re.DOTALL)
    if geojson_match is None:
        raise ValueError("Message does not contain a GeoJSON object.")
    try:
        geojson = json.loads(geojson_match.group(0))
    except json.JSONDecodeError as exc:
        raise ValueError(f"Message contains invalid GeoJSON: {exc}") from exc

    date_match = re.search(r'date range:\s*(\d{4}-\d{2}-\d{2})\s*to\s*(\d{4}-\d{2}-\d{2})', message)
    if date_match is None:
        raise ValueError(
            "Message is missing a 'date range: YYYY-MM-DD to YYYY-MM-DD' entry."
        )
    start_date, end_date = date_match.groups()
    # The regex only checks the digit layout; reject impossible dates such as
    # 2020-13-45 here instead of failing later in the pipeline.
    for date_text in (start_date, end_date):
        try:
            datetime.strptime(date_text, "%Y-%m-%d")
        except ValueError as exc:
            raise ValueError(f"Invalid date {date_text!r} in date range.") from exc

    return {
        "aoi": geojson,
        "date_range": (start_date, end_date),
        "thresholds": {
            "confidence": _number_after("confidence"),
            "intensity": int(_number_after("intensity"))
        }
    }


In [2]:
def filter_and_count_pixels(inputs: Dict) -> Dict:
    """Count raster pixels inside the AOI that pass the date and threshold filters.

    Two rasters (``default.tif`` and ``intensity.tif``) are opened from S3 with
    the ``zeno_internal_sso`` AWS profile (requester-pays) and cropped/masked to
    the AOI. Each value of the first band of ``default.tif`` is decoded as
    ``confidence * 10000 + days since 2015-01-01``; the first band of
    ``intensity.tif`` supplies the intensity.

    Args:
        inputs: Mapping with the keys
            "aoi"        -- GeoJSON Feature (its "geometry" member is used) or a
                            bare GeoJSON geometry mapping,
            "date_range" -- (start, end) as "YYYY-MM-DD" strings, both inclusive,
            "thresholds" -- {"confidence": number, "intensity": number}; pixels
                            must be greater than or equal to both.

    Returns:
        ``{"result": n}`` where ``n`` is the number of pixels matching every filter.

    Raises:
        ValueError: if a date string is malformed, or if the two masked rasters
            do not have the same shape (a pixel-wise comparison would be wrong).
    """
    aoi = inputs["aoi"]
    date_range = inputs["date_range"]
    thresholds = inputs["thresholds"]

    # Day offsets in the raster are counted from this date.
    reference_date = datetime(2015, 1, 1)
    start_days = (datetime.strptime(date_range[0], "%Y-%m-%d") - reference_date).days
    end_days = (datetime.strptime(date_range[1], "%Y-%m-%d") - reference_date).days

    default_path = "s3://gfw-data-lake/umd_glad_dist_alerts/v20250329/raster/epsg-4326/cog/default.tif"
    intensity_path = "s3://gfw-data-lake/umd_glad_dist_alerts/v20250329/raster/epsg-4326/cog/intensity.tif"

    session = boto3.Session(profile_name="zeno_internal_sso")
    # Accept a Feature (use its "geometry" member) as well as a bare geometry.
    geometry = [shape(aoi.get("geometry", aoi)).__geo_interface__]

    with rasterio.Env(AWSSession(session), AWS_REQUEST_PAYER="requester"):
        with rasterio.open(default_path) as src1:
            default_data, _ = mask(src1, geometry, crop=True)
        with rasterio.open(intensity_path) as src2:
            intensity_data, _ = mask(src2, geometry, crop=True)

    # Only the first band of each raster is used.
    encoded = default_data[0]
    intensities = intensity_data[0]
    if encoded.shape != intensities.shape:
        # Without this check NumPy could silently broadcast mismatched grids
        # (e.g. a 1-row array against an N-row array) and miscount.
        raise ValueError(
            f"Masked rasters differ in shape: default {encoded.shape} vs "
            f"intensity {intensities.shape}"
        )

    confidence = encoded // 10000
    days_since_2015 = encoded % 10000

    print("requested start and end days since 2015: ", start_days, end_days)

    # Day offsets at or beyond today cannot be real observations
    # (presumably nodata/fill values -- TODO confirm); they are excluded below.
    max_days_since_2015 = (datetime.today() - reference_date).days
    plausible_days = days_since_2015[days_since_2015 < max_days_since_2015]
    if plausible_days.size:
        print(f"days since 2015 in this area of interest — min: {np.min(days_since_2015)}, max: {np.max(plausible_days)}")
    else:
        # np.min/np.max raise on empty input; the diagnostics must not abort the count.
        print("days since 2015 in this area of interest — no pixels with a plausible date")

    valid_mask = (
        (confidence >= thresholds["confidence"]) &
        (days_since_2015 >= start_days) &
        (days_since_2015 <= end_days) &
        (days_since_2015 < max_days_since_2015) &
        (intensities >= thresholds["intensity"])
    )

    return {"result": int(np.count_nonzero(valid_mask))}

In [3]:
class GraphState(TypedDict):
    """State passed between the graph nodes.

    NOTE: a later cell defines ``GraphState`` again; when the notebook is run
    top to bottom that later definition is the one handed to ``StateGraph``.
    This definition is kept consistent with it.
    """
    message: str                  # raw request text read by parse_user_message
    aoi: dict                     # parsed GeoJSON
    date_range: Tuple[str, str]   # (start, end) as "YYYY-MM-DD" strings
    thresholds: Dict[str, float]  # {"confidence": float, "intensity": int}
    result: int                   # final pixel count


In [4]:
class GraphState(TypedDict):
    """State passed between the graph nodes.

    ``message`` is supplied by the caller of ``graph.invoke``; the
    "parse_message" node fills ``aoi``, ``date_range`` and ``thresholds``;
    the "filter_and_count" node fills ``result``.
    """
    message: str                  # raw request text
    aoi: dict                     # parsed GeoJSON
    date_range: Tuple[str, str]   # (start, end) as "YYYY-MM-DD" strings
    thresholds: Dict[str, float]  # {"confidence": float, "intensity": int}
    result: int                   # pixel count returned by filter_and_count_pixels

builder = StateGraph(GraphState)

# Register the two pipeline steps as nodes, each wrapped in a RunnableLambda.
for node_name, node_fn in (
    ("parse_message", parse_user_message),
    ("filter_and_count", filter_and_count_pixels),
):
    builder.add_node(node_name, RunnableLambda(node_fn))

# Linear flow: entry -> parse_message -> filter_and_count -> END.
builder.set_entry_point("parse_message")
for edge_start, edge_end in (
    ("parse_message", "filter_and_count"),
    ("filter_and_count", END),
):
    builder.add_edge(edge_start, edge_end)

graph = builder.compile()


In [5]:
# Confidence code (leading digit of the encoded value): 2 = low confidence, 3 = high confidence.
# A code of 1 carries no meaning, so all encoded values should be >= 20,000 and < 40,000.

In [6]:
# Test AOI over the Peruvian Amazon (Loreto and Ucayali regions)

# Free-text request in the layout parse_user_message expects: a GeoJSON Feature
# followed by "date range:", "confidence:" and "intensity:" lines.
message = """
{
  "type": "Feature",
  "geometry": {
    "type": "Polygon",
    "coordinates": [
      [
        [-75.0, -3.0],
        [-70.0, -3.0],
        [-70.0, -6.0],
        [-75.0, -6.0],
        [-75.0, -3.0]
      ]
    ]
  }
}
date range: 2016-01-01 to 2020-12-31
confidence: 3
intensity: 50
"""

# Run the compiled graph end to end. The returned dict is the final state: the
# input message, the parsed aoi/date_range/thresholds, and the pixel count
# under "result". The filter step reads rasters from S3, so this cell needs
# working AWS credentials for the profile used in filter_and_count_pixels.
result = graph.invoke({"message": message})
print(result)



requested start and end days since 2015:  365 2191
days since 2015 in this area of interest — min: 880, max: 1543
{'message': '\n{\n  "type": "Feature",\n  "geometry": {\n    "type": "Polygon",\n    "coordinates": [\n      [\n        [-75.0, -3.0],\n        [-70.0, -3.0],\n        [-70.0, -6.0],\n        [-75.0, -6.0],\n        [-75.0, -3.0]\n      ]\n    ]\n  }\n}\ndate range: 2016-01-01 to 2020-12-31\nconfidence: 3\nintensity: 50\n', 'aoi': {'type': 'Feature', 'geometry': {'type': 'Polygon', 'coordinates': [[[-75.0, -3.0], [-70.0, -3.0], [-70.0, -6.0], [-75.0, -6.0], [-75.0, -3.0]]]}}, 'date_range': ('2016-01-01', '2020-12-31'), 'thresholds': {'confidence': 3.0, 'intensity': 50}, 'result': 495567}
