In [1]:
!uv pip install -r requirements.txt

[2mAudited [1m155 packages[0m in 81ms[0m


### --- Import Libraries ---


In [2]:
import io
from datetime import datetime, timedelta
from typing import List, Optional, Tuple

import ee
import numpy as np
import requests
from google.api_core import exceptions, retry
from numpy.lib.recfunctions import structured_to_unstructured

### --- Constants and Earth Engine Initialization ---


In [3]:
# Resolution handed to get_patch() when downloading input and label patches.
SCALE = 5000  # meters per pixel
# Passed as `scale` to stratifiedSample() in sample_points();
# presumably also meters per pixel -- TODO confirm.
WORLD_SCALE = 10_000
LAND_COVER_DATASET = "GOOGLE/DYNAMICWORLD/V1"  # Dynamic World Land Cover dataset
# NOTE(review): not referenced anywhere in the visible notebook; get_land_cover()
# selects the "label" band directly. Confirm whether "Map" is still intended.
LAND_COVER_BAND = "Map"  # Land cover classification band
# Coarse outlines used as the sampling region in sample_points().
# Vertices are presumably (longitude, latitude) pairs -- TODO confirm.
WORLD_POLYGONS = [
    # Americas
    [(-33.0, -7.0), (-55.0, 53.0), (-166.0, 65.0), (-68.0, -56.0)],
    # Africa, Asia, Europe
    [
        (74.0, 71.0),
        (166.0, 55.0),
        (115.0, -11.0),
        (74.0, -4.0),
        (20.0, -38.0),
        (-29.0, 25.0),
    ],
    # Australia
    # NOTE(review): (128.0, 17.0) is the only vertex in this ring whose second
    # coordinate is positive; confirm whether -17.0 was intended.
    [(170.0, -47.0), (179.0, -37.0), (167.0, -12.0), (128.0, 17.0), (106.0, -29.0)],
]
# NOTE(review): POLYGON is not referenced in the visible notebook; sample_points()
# uses WORLD_POLYGONS instead.
POLYGON = [(-140.0, 60.0), (-140.0, -60.0), (-10.0, -60.0), (-10.0, 60.0)]

### --- Earth Engine Initialization ---


In [4]:
project = "ee-rohitp934"
# Option 1: authenticate once from the command line
# !earthengine authenticate

# Option 2: authenticate from Python through the helper below
def initialize_ee():
    """Authenticates with Earth Engine and initializes the client.

    The client is pointed at the high-volume endpoint and billed to the
    module-level ``project``.
    """
    high_volume_url = "https://earthengine-highvolume.googleapis.com"
    ee.Authenticate()
    ee.Initialize(project=project, opt_url=high_volume_url)

In [5]:
# Authenticate and initialize the Earth Engine client for this kernel session.
initialize_ee()

### --- Data Retrieval Functions ---


In [6]:
def get_modis_ndvi(date: datetime) -> ee.Image:
    """Returns the first MODIS (MOD13Q1) NDVI image in the month starting at `date`.

    Args:
        date: Start of the one-month search window.

    Returns:
        The "NDVI" band of the first image found in that window.
    """
    window_start = ee.Date(date)
    window_end = window_start.advance(1, 'month')

    collection = ee.ImageCollection("MODIS/061/MOD13Q1")
    ndvi_in_window = collection.filterDate(window_start, window_end).select("NDVI")
    return ndvi_in_window.first()

In [7]:
def get_landsat_image(date: datetime) -> ee.Image:
    """Builds a Landsat 8 mosaic from the six months starting at `date`.

    Args:
        date: Start of the six-month search window.

    Returns:
        A mosaic of every LANDSAT/LC08/C02/T1_L2 image in that window.
    """
    window_start = ee.Date(date)
    window_end = window_start.advance(6, 'month')

    scenes = ee.ImageCollection("LANDSAT/LC08/C02/T1_L2")
    return scenes.filterDate(window_start, window_end).mosaic()

In [8]:
def get_landsat_ndvi(image: ee.Image) -> ee.Image:
    """Calculates NDVI from a Landsat 8 Collection 2 Level-2 image.

    The surface-reflectance bands of LANDSAT/LC08/C02/T1_L2 are stored as
    scaled integers. Because the documented conversion has a non-zero offset
    (reflectance = DN * 0.0000275 - 0.2), a normalized difference taken on the
    raw values is not the true NDVI, so the bands are converted first.

    Args:
        image: Landsat 8 image exposing the "SR_B5" (NIR) and "SR_B4" (red) bands.

    Returns:
        A single-band image named "NDVI".
    """
    # Scale/offset from the dataset's catalog page (same page get_landsat_lst cites).
    reflectance = image.select(["SR_B5", "SR_B4"]).multiply(0.0000275).add(-0.2)
    # NOTE(review): normalizedDifference is documented to mask pixels where an
    # input is negative; slightly negative reflectances will therefore be masked.
    return reflectance.normalizedDifference(["SR_B5", "SR_B4"]).rename("NDVI")

In [9]:
def get_landsat_lst(image: ee.Image) -> ee.Image:
    """
    Derives Land Surface Temperature from the thermal band of a Landsat 8 image.
    The scale and offset applied to "ST_B10" follow the dataset page:
    https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LC08_C02_T1_L2
    """
    thermal_band = image.select("ST_B10")
    rescaled = thermal_band.multiply(0.00341802)
    return rescaled.add(149.0).rename("LST")

In [10]:
def get_land_cover(date: datetime, advance: int = 6) -> ee.Image:
    """Builds a land-cover label mosaic for a window starting at `date`.

    Args:
        date: Start of the search window.
        advance: Length of the window in months.

    Returns:
        A byte image of the "label" band, with unobserved pixels filled with 0.
    """
    window_start = ee.Date(date)
    window_end = window_start.advance(advance, 'month')

    label_images = (
        ee.ImageCollection(LAND_COVER_DATASET)
        .filterDate(window_start, window_end)
        .select("label")
    )
    composite = label_images.mosaic()
    # Pixels with no observation are filled with class 0 (water).
    filled = composite.unmask(0)
    return filled.byte()

### --- Input and Label Image Composition ---


In [11]:
def get_inputs_image(date: datetime) -> ee.Image:
    """Assembles the model's input image for the window starting at `date`.

    The bands are, in order: Landsat NDVI, MODIS NDVI, Landsat LST.
    """
    # Landsat-derived layers share one mosaic.
    landsat = get_landsat_image(date)
    ndvi_landsat = get_landsat_ndvi(landsat)
    lst_landsat = get_landsat_lst(landsat)

    # MODIS contributes a second NDVI layer.
    ndvi_modis = get_modis_ndvi(date)

    ndvi_stack = ee.Image.cat([ndvi_landsat, ndvi_modis])
    return ee.Image([ndvi_stack, lst_landsat])

In [12]:
def get_labels_image(year: datetime) -> ee.Image:
    """Gets the land-cover label image for the window starting at the given date.

    Args:
        year: Despite its name, this is a ``datetime`` (get_labels_patch passes
            one). It is forwarded unchanged to get_land_cover(), which builds
            an ``ee.Date`` from it. A bare integer year would presumably not be
            read as a calendar year by ``ee.Date`` -- confirm before passing one.

    Returns:
        The land-cover image produced by get_land_cover().
    """
    # Hook for future preprocessing (e.g. remapping land cover classes).
    return get_land_cover(year)

### --- Get input and labels for a given latitude and longitude ---


In [13]:
@retry.Retry(deadline=10 * 60)  # seconds
def get_patch(
    image: ee.Image, lonlat: Tuple[float, float], patch_size: int, scale: int
) -> np.ndarray:
    """Downloads a square patch of pixels centered on a point.

    Args:
        image: Earth Engine image to download from.
        lonlat: Center of the patch.
        patch_size: Width and height of the patch in pixels.
        scale: Pixel size; together with `patch_size` it sets the region extent.

    Returns:
        The array decoded from the "NPY" download.

    Raises:
        exceptions.TooManyRequests: On HTTP 429, so the retry decorator can retry.
        requests.HTTPError: On any other unsuccessful HTTP status.
    """
    # The region is the bounding box of a buffer whose radius is half the patch
    # width. No explicit "scale" is sent; resolution follows from region and
    # dimensions.
    half_extent = scale * patch_size / 2
    region = ee.Geometry.Point(lonlat).buffer(half_extent, 1).bounds(1)
    download_params = {
        "region": region,
        "dimensions": [patch_size, patch_size],
        "format": "NPY",
    }
    download_url = image.getDownloadURL(download_params)

    reply = requests.get(download_url)
    # Surface rate limiting as a google.api_core exception so it is retried.
    if reply.status_code == 429:
        raise exceptions.TooManyRequests(reply.text)

    # Any other HTTP failure propagates to the caller.
    reply.raise_for_status()
    payload = io.BytesIO(reply.content)
    # NOTE(review): allow_pickle=True deserializes pickled objects; acceptable
    # only because the payload comes from the Earth Engine download URL.
    return np.load(payload, allow_pickle=True)

In [14]:
def get_inputs_patch(
    date: datetime, lonlat: Tuple[float, float], patch_size: int
) -> np.ndarray:
    """Downloads the model-input patch around `lonlat` for `date`.

    The patch is fetched at the module-level SCALE and converted from a
    structured array to a plain ndarray.
    """
    structured = get_patch(get_inputs_image(date), lonlat, patch_size, SCALE)
    return structured_to_unstructured(structured)


def get_labels_patch(
    date: datetime, lonlat: Tuple[float, float], patch_size: int
) -> np.ndarray:
    """Downloads the land-cover label patch around `lonlat` for `date`.

    The patch is fetched at the module-level SCALE and converted from a
    structured array to a plain ndarray.
    """
    structured = get_patch(get_labels_image(date), lonlat, patch_size, SCALE)
    return structured_to_unstructured(structured)

## Creating the dataset


### --- Imports ---


In [15]:
import logging
import random
from datetime import datetime

import numpy as np
import os
import uuid
from typing import Tuple, List

### --- Configs ---


In [16]:
NUM_SAMPLES = 250
# NOTE(review): used as a divisor below, so it acts as a partition *count*;
# the pipeline cell seeds `range(4)` separately -- presumably meant to match.
PARTITION_SIZE = 4
# Integer division (250 // 4 == 62); passed as `numPoints` to stratifiedSample()
# in sample_points().
NUM_SAMPLES_PER_PARTITION = NUM_SAMPLES // PARTITION_SIZE
# Width and height, in pixels, of each downloaded patch.
PATCH_SIZE = 128
START_DATE = "2015-01-01"
END_DATE = "2021-01-01"
start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
end_date = datetime.strptime(END_DATE, "%Y-%m-%d")

# One January 1st per year, from the start year through the end year inclusive.
# `dates` is read as a module global by sample_points().
years = range(start_date.year, end_date.year + 1)
dates = [datetime(year, 1, 1) for year in years]

### --- Sample Points ---


In [17]:
def random_date(start: datetime, end: datetime):
    """Pick a random datetime between `start` and `end`, both inclusive.

    The offset from `start` is a whole number of seconds.
    """
    span_seconds = int((end - start).total_seconds())
    offset_seconds = random.randint(0, span_seconds)
    return start + timedelta(seconds=offset_seconds)

In [18]:
def sample_points(seed: int) -> List[Tuple[datetime, Tuple[float, float]]]:
    """Samples (date, point) pairs stratified by land-cover class.

    Every entry of the module-level ``dates`` list is shifted by ``30 * seed``
    days. For each shifted date a land-cover image is built and sampled over
    WORLD_POLYGONS with ``stratifiedSample``. Each date gets up to two
    attempts; after an error the land-cover window shrinks from 6 to 3 months.

    Args:
        seed: Used both as the sampling seed and as the date offset.

    Returns:
        A list of ``(date, coordinates)`` pairs, where ``coordinates`` is the
        point geometry's coordinate list as returned by Earth Engine. Dates
        that never produce points contribute nothing.
    """
    initialize_ee()
    updated_dates: List[datetime] = [d + timedelta(days=(30 * seed)) for d in dates]
    results = []
    for date in updated_dates:
        tries = 2
        advance = 6
        while tries > 0:
            try:
                land_cover = get_land_cover(date, advance)
                points = land_cover.stratifiedSample(
                    numPoints=NUM_SAMPLES_PER_PARTITION,
                    region=ee.Geometry.MultiPolygon(WORLD_POLYGONS),
                    scale=WORLD_SCALE,
                    geometries=True,
                    seed=seed,
                    classBand="label",
                    tileScale=4,
                )
                # One round trip: fetch the features and test emptiness locally,
                # so the sample is not evaluated twice.
                features = points.getInfo()["features"]
                if features:
                    results.extend(
                        (date, feature["geometry"]["coordinates"])
                        for feature in features
                    )
                    break
                tries -= 1
            except Exception as e:
                tries -= 1
                # Retry with a shorter land-cover window.
                advance = 3
                print(f"Error occurred for {date}: {e} ... Retrying {tries} times")

    return results

### --- Prepare Training Data ---


In [19]:
def get_training_example(
    date: datetime, point: Tuple[float, float]
) -> Tuple[np.ndarray, np.ndarray]:
    """Builds one (inputs, labels) example for land cover change prediction.

    Inputs come from the window starting at `date`; labels come from the
    land-cover window starting one day later.
    """
    label_date = date + timedelta(days=1)
    inputs = get_inputs_patch(date, point, PATCH_SIZE)
    labels = get_labels_patch(label_date, point, PATCH_SIZE)
    return (inputs, labels)

In [20]:
def try_get_example(
    sample_point: Tuple[datetime, Tuple[float, float]],
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Best-effort wrapper around get_training_example().

    Args:
        sample_point: A ``(date, point)`` pair as produced by sample_points().

    Returns:
        The ``(inputs, labels)`` arrays, or ``None`` when fetching the example
        raised; the error is printed rather than propagated.
    """
    # Same endpoint as initialize_ee(), so re-initializing here does not switch
    # the client back to the default endpoint.
    ee.Initialize(
        project=project, opt_url="https://earthengine-highvolume.googleapis.com"
    )
    date, point = sample_point
    try:
        return get_training_example(date, point)
    except Exception as e:
        print(f"Error occurred: {e}")
        # Explicit: callers receive None for a failed sample.
        return None

# --- Apache Beam Workflow for Dataset Creation ---

In [21]:
def write_npz(
    data: List[Optional[Tuple[np.ndarray, np.ndarray]]], data_path: str
) -> str:
    """Writes a batch of (inputs, labels) pairs into a compressed NumPy file.

    Args:
        data: Batch of (inputs, labels) pairs of NumPy arrays. ``None`` entries
            are skipped, since try_get_example() yields ``None`` for samples
            that failed.
        data_path: Directory to save the file in; created if it does not exist.

    Returns: The filename of the data file.

    Raises:
        FileExistsError: If the generated filename already exists (the file is
            opened in exclusive-create mode).
    """
    examples = [example for example in data if example is not None]
    inputs = [x for (x, _) in examples]
    labels = [y for (_, y) in examples]

    # No Earth Engine call is needed here: this function only touches the disk.
    os.makedirs(data_path, exist_ok=True)
    filename = os.path.join(data_path, f"{uuid.uuid4()}.npz")
    with open(filename, "xb") as f:
        np.savez_compressed(f, inputs=inputs, labels=labels)
    logging.info(filename)
    return filename

In [22]:
# Rebuild the yearly sample dates, this time leaving out 2020.
# NOTE(review): the pipeline cell further down reassigns `dates` without this
# filter, so the exclusion made here does not carry through to that cell.
start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
end_date = datetime.strptime(END_DATE, "%Y-%m-%d")

years = range(start_date.year, end_date.year + 1)
dates = [datetime(year, 1, 1) for year in years if year != 2020]

# Unused alternative: draw random dates instead of one fixed date per year.
# random_dates = [random_date(start_date, end_date) for _ in range(samples)]

print(dates)

[datetime.datetime(2015, 1, 1, 0, 0), datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2017, 1, 1, 0, 0), datetime.datetime(2018, 1, 1, 0, 0), datetime.datetime(2019, 1, 1, 0, 0), datetime.datetime(2021, 1, 1, 0, 0)]


In [23]:
# Placeholder; the pipeline cell rebinds `points` to the sampled PCollection.
points = None

### --- Perform Dataset Creation ---

In [25]:
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

# Rebuild `dates` for every year in the configured range (no year is excluded
# here). sample_points() reads `dates` as a module global, so this assignment
# decides which dates the pipeline samples.
start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
end_date = datetime.strptime(END_DATE, "%Y-%m-%d")

years = range(start_date.year, end_date.year + 1)
dates = [datetime(year, 1, 1) for year in years]
beam_options = PipelineOptions(
    [],
    # Presumably pickles the notebook's globals so the functions defined in
    # earlier cells are available to workers -- TODO confirm.
    save_main_session=True,
    # setup_file="./setup.py",
    # max_num_workers=max_requests,
    direct_running_mode="multi_threading",  # direct runner: run workers as threads
    direct_num_workers=0,  # direct runner; presumably 0 = one per core -- confirm
    # disk_size_gb=50,
)
with beam.Pipeline(options=beam_options) as pipeline:
    # Stage 1: one seed per partition; each seed expands into many
    # (date, point) samples via sample_points().
    points = (
        pipeline
        | "🌱 Make seeds" >> beam.Create(range(4))
        | "📌 Sample points" >> beam.FlatMap(sample_points)
        
    )

    # Stage 2: fetch an example per sample, group them, and write each group
    # to an .npz file. try_get_example() returns None for failed samples, so
    # None entries can reach the batches.
    (
        points
        | "🃏 Reshuffle" >> beam.Reshuffle()
        | "📑 Get examples" >> beam.Map(try_get_example)
        | "🗂️ Batch examples" >> beam.BatchElements(10)
        | "📝 Write NPZ Files" >> beam.Map(write_npz, data_path="data/climate_change")
        # | "Print" >> beam.Map(print)
    )



Error occurred for 2018-03-02 00:00:00: User memory limit exceeded. ... Retrying 1 times
Error occurred for 2018-01-31 00:00:00: User memory limit exceeded. ... Retrying 1 times
Error occurred for 2018-04-01 00:00:00: User memory limit exceeded. ... Retrying 1 times
Error occurred for 2019-04-01 00:00:00: User memory limit exceeded. ... Retrying 1 times
Error occurred for 2019-01-01 00:00:00: User memory limit exceeded. ... Retrying 1 times
Error occurred for 2019-03-02 00:00:00: User memory limit exceeded. ... Retrying 1 times
Error occurred for 2020-03-31 00:00:00: User memory limit exceeded. ... Retrying 1 times
Error occurred for 2020-01-01 00:00:00: User memory limit exceeded. ... Retrying 1 times
Error occurred for 2021-01-01 00:00:00: User memory limit exceeded. ... Retrying 1 times




Error occurred for 2021-04-01 00:00:00: User memory limit exceeded. ... Retrying 1 times
Error occurred for 2021-03-02 00:00:00: User memory limit exceeded. ... Retrying 1 times


