In [None]:
!uv pip install -r requirements.txt

### --- Import Libraries ---


In [6]:
import io
from datetime import datetime, timedelta

import ee
from google.api_core import exceptions, retry
import google.auth
import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured
import requests

### --- Constants and Earth Engine Initialization ---


In [7]:
SCALE = 10  # meters per pixel
LAND_COVER_DATASET = "GOOGLE/DYNAMICWORLD/V1"  # Dynamic World Land Cover dataset
LAND_COVER_BAND = "Map"  # Land cover classification band
WORLD_POLYGONS = [
    # Americas
    [(-33.0, -7.0), (-55.0, 53.0), (-166.0, 65.0), (-68.0, -56.0)],
    # Africa, Asia, Europe
    [
        (74.0, 71.0),
        (166.0, 55.0),
        (115.0, -11.0),
        (74.0, -4.0),
        (20.0, -38.0),
        (-29.0, 25.0),
    ],
    # Australia
    [(170.0, -47.0), (179.0, -37.0), (167.0, -12.0), (128.0, 17.0), (106.0, -29.0)],
]

### --- Earth Engine Initialization ---


In [8]:
project = "ee-rohitp934"
# Use cli to authenticate
# !earthengine authenticate

# Or use the following code to authenticate
def initialize_ee():
  ee.Authenticate()
  ee.Initialize(project=project, opt_url="https://earthengine-highvolume.googleapis.com")

### --- Data Retreival Functions ---


In [9]:
def get_modis_ndvi(date: datetime) -> ee.Image:
    """Gets MODIS NDVI data for a given date."""
    return (
        ee.ImageCollection("MODIS/006/MOD13A2")
        .filterDate(date, date + timedelta(days=1))
        .select("NDVI")
        .first()
    )

In [10]:
def get_landsat_image(date: datetime) -> ee.Image:
    """Gets a Landsat 8 image for the selected date."""
    return (
        ee.ImageCollection("LANDSAT/LC08/C01/T1_SR")
        .filterDate(date, date + timedelta(days=1))
        .filterBounds(ee.Geometry.Polygon(WORLD_POLYGONS))
        .sort("CLOUD_COVER")
        .first()
    )

In [11]:
def get_landsat_ndvi(image: ee.Image) -> ee.Image:
    """Calculates NDVI from a Landsat 8 image."""
    return image.normalizedDifference(["B5", "B4"]).rename("NDVI")

In [12]:
def get_landsat_lst(image: ee.Image) -> ee.Image:
    """
    Calculates Land Surface Temperature from a Landsat 8 image.
    This function is based on the formula in the following page https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LC08_C02_T1_L2
    """
    return image.select("ST_B10").multiply(0.00341802).add(149.0).rename("LST")

In [45]:
def get_land_cover(date: datetime) -> ee.Image:
    """Gets a Land Cover image for the given date."""
    return (
        ee.ImageCollection(LAND_COVER_DATASET)
        .filterDate(date, date + timedelta(days=1))
        .select("label")
        .first()
        .rename("landcover")
        .unmask(0)  # fill missing values with 0 (water)
        .byte()
    )

### --- Input and Label Image Composition ---


In [15]:
def get_inputs_image(date: datetime) -> ee.Image:
    """Gets an Earth Engine image with all the inputs for the model."""
    # Get MODIS NDVI
    modis_ndvi = get_modis_ndvi(date)

    # Get Landsat data
    landsat_image = get_landsat_image(date)
    landsat_ndvi = get_landsat_ndvi(landsat_image)
    landsat_lst = get_landsat_lst(landsat_image)

    # Combine all input data
    return ee.Image([modis_ndvi, landsat_ndvi, landsat_lst])

In [16]:
def get_labels_image(year: int) -> ee.Image:
    """Gets a Land Cover image for the selected year and preprocesses it."""
    land_cover = get_land_cover(year)
    # Add preprocessing steps if needed (e.g., remapping land cover classes)
    return land_cover

### --- Get input and labels for a given latitude and longitude ---


In [17]:
@retry.Retry(deadline=10 * 60)  # seconds
def get_patch(
    image: ee.Image, lonlat: tuple[float, float], patch_size: int, scale: int
) -> np.ndarray:
    """Fetches a patch of pixels from Earth Engine."""
    point = ee.Geometry.Point(lonlat)
    url = image.getDownloadURL(
        {
            "region": point.buffer(scale * patch_size / 2, 1).bounds(1),
            "dimensions": [patch_size, patch_size],
            "format": "NPY",
        }
    )

    # Retry on "Too Many Requests" errors
    response = requests.get(url)
    if response.status_code == 429:
        raise exceptions.TooManyRequests(response.text)

    # Raise other exceptions
    response.raise_for_status()
    return np.load(io.BytesIO(response.content), allow_pickle=True)

In [18]:
def get_inputs_patch(
    date: datetime, lonlat: tuple[float, float], patch_size: int
) -> np.ndarray:
    """Gets the inputs patch of pixels for the given point and date."""
    image = get_inputs_image(date)
    patch = get_patch(image, lonlat, patch_size, SCALE)
    return structured_to_unstructured(patch)


def get_labels_patch(
    date: datetime, lonlat: tuple[float, float], patch_size: int
) -> np.ndarray:
    """Gets the labels patch of pixels for the given point and year."""
    image = get_labels_image(date)
    patch = get_patch(image, lonlat, patch_size, SCALE)
    return structured_to_unstructured(patch)

## Creating the dataset


### --- Imports ---


In [31]:
import logging
import random
from datetime import datetime, timedelta

import dask
import dask.bag as db
from dask.distributed import Client
import numpy as np
import os
import pandas as pd
import uuid

### --- Configs ---


In [20]:
NUM_SAMPLES = 1000
PATCH_SIZE = 128
PARTITION_SIZE = 10
START_DATE = "2015-07-01"
END_DATE = "2021-12-01"

### --- Sample Points ---


In [21]:
def sample_points(date: datetime) -> tuple:
    """Samples points within the defined polygon for the given year."""
    land_cover = get_land_cover(date)
    points = land_cover.stratifiedSample(
        numPoints=1,
        region=ee.Geometry.Polygon(WORLD_POLYGONS),
        scale=SCALE,
        geometries=True,
    )
    point = points.toList(points.size()).getInfo()[0]
    return (date, point["geometry"]["coordinates"])

### --- Prepare Training Data ---


In [22]:
def get_training_example(date: datetime, point: tuple) -> tuple:
    """Gets an (inputs, labels) training example for land cover change prediction."""
    inputs = get_inputs_patch(date, point, PATCH_SIZE)
    # Get land cover for the next day
    labels = get_labels_patch(date + timedelta(days=1), point, PATCH_SIZE)
    return (inputs, labels)

In [47]:
import dask.distributed


def try_get_example(date: datetime, point: tuple) -> tuple | None:
    """Wrapper to handle errors during training data generation."""
    initialize_ee()
    dask.distributed.print(f"Generating training data for {date} at {point}")
    try:
        return get_training_example(date, point)
    except Exception as e:
        dask.distributed.print(f"Error occurred: {e}")

In [32]:
def random_date(start: datetime, end: datetime):
    """Generate a random datetime between `start` and `end`"""
    return start + timedelta(
        # Get a random amount of seconds between `start` and `end`
        seconds=random.randint(0, int((end - start).total_seconds())),
    )

# --- Dask Workflow for Dataset Creation ---

In [48]:
def write_npz(data: list[tuple[np.ndarray, np.ndarray]], data_path: str) -> str:
    """Writes an (inputs, labels) set of data into a compressed NumPy file.

    Args:
        batch: Batch of (inputs, labels) pairs of NumPy arrays.
        data_path: Directory path to save files to.

    Returns: The filename of the data file.
    """
    initialize_ee()
    dask.distributed.print(f"Writing {len(data)} data points to {data_path}")
    filename = os.path.join(data_path, f"{uuid.uuid4()}.npz")
    with open(filename, "xb") as f:
        inputs = [x for (x, _) in data]
        labels = [y for (_, y) in data]
        np.savez_compressed(f, inputs=inputs, labels=labels)
    logging.info(filename)
    return filename

In [53]:
def run(data_path: str, samples: int = NUM_SAMPLES) -> None:
    """Runs the Dask workflow to generate the dataset."""

    # Generate dates from the start date to the end date
    start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
    end_date = datetime.strptime(END_DATE, "%Y-%m-%d")

    random_dates = [random_date(start_date, end_date) for _ in range(samples)]

    with Client() as client:  # Start a Dask client
        # Authenticate and initialize Earth Engine.
        initialize_ee()
        print(client)

        def wrapper(data, data_path):
            return write_npz(data, data_path)

        bag = db.from_sequence(random_dates, npartitions=PARTITION_SIZE).map(
            sample_points
        )
        training_data = bag.map(try_get_example).filter(lambda x: x is not None)

        training_data.map_partitions(wrapper, data_path=data_path).compute()

### --- Perform Dataset Creation ---

In [54]:
logging.getLogger().setLevel(logging.INFO)
run("data/climate_change/")

<Client: 'tcp://127.0.0.1:64672' processes=4 threads=8, memory=16.00 GiB>


Key:       ('filter-lambda-sample_points-try_get_example-wrapper-b9d7c970f068fadae26323a23cc9a1c7', 9)
Function:  execute_task
args:      ((subgraph_callable-3e9823a451188dfaf9dda3b26adb2f91, (<class 'filter'>, <function run.<locals>.<lambda> at 0x10b939760>, (<function map_chunk at 0x10ba962a0>, <function try_get_example at 0x10b9bb920>, [(<function map_chunk at 0x10ba962a0>, <function sample_points at 0x10bce0860>, [[datetime.datetime(2017, 2, 8, 9, 10, 37), datetime.datetime(2021, 1, 1, 21, 1, 4), datetime.datetime(2021, 9, 4, 21, 56, 18), datetime.datetime(2020, 10, 16, 9, 50, 29), datetime.datetime(2018, 5, 1, 9, 1, 41), datetime.datetime(2015, 10, 29, 16, 45, 31), datetime.datetime(2020, 3, 22, 1, 40, 39), datetime.datetime(2018, 2, 17, 2, 40, 16), datetime.datetime(2020, 11, 20, 9, 36, 25), datetime.datetime(2015, 11, 27, 0, 17, 39), datetime.datetime(2017, 7, 9, 15, 55, 9), datetime.datetime(2020, 8, 19, 15, 7, 54), datetime.datetime(2018, 3, 27, 1, 19, 9), datetime.datetime(20

TypeError: object of type 'filter' has no len()