In [None]:
!uv pip install -r requirements.txt

### --- Import Libraries ---


In [1]:
import io
from datetime import datetime, timedelta

import ee
from google.api_core import exceptions, retry
import google.auth
import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured
import requests

### --- Constants and Earth Engine Initialization ---


In [45]:
SCALE = 5000  # meters per pixel
WORLD_SCALE = 10_000
LAND_COVER_DATASET = "GOOGLE/DYNAMICWORLD/V1"  # Dynamic World Land Cover dataset
LAND_COVER_BAND = "Map"  # Land cover classification band
WORLD_POLYGONS = [
    # Americas
    [(-33.0, -7.0), (-55.0, 53.0), (-166.0, 65.0), (-68.0, -56.0)],
    # Africa, Asia, Europe
    [
        (74.0, 71.0),
        (166.0, 55.0),
        (115.0, -11.0),
        (74.0, -4.0),
        (20.0, -38.0),
        (-29.0, 25.0),
    ],
    # Australia
    [(170.0, -47.0), (179.0, -37.0), (167.0, -12.0), (128.0, 17.0), (106.0, -29.0)],
]
POLYGON = [(-140.0, 60.0), (-140.0, -60.0), (-10.0, -60.0), (-10.0, 60.0)]

### --- Earth Engine Initialization ---


In [3]:
project = "ee-rohitp934"
# Use cli to authenticate
# !earthengine authenticate

# Or use the following code to authenticate
def initialize_ee():
  ee.Authenticate()
  ee.Initialize(project=project, opt_url="https://earthengine-highvolume.googleapis.com")

In [22]:
initialize_ee()

### --- Data Retreival Functions ---


In [4]:
def get_modis_ndvi(date: datetime) -> ee.Image:
    """Gets MODIS NDVI data for a given date."""
    return (
        ee.ImageCollection("MODIS/006/MOD13A2")
        .filterDate(date, date + timedelta(days=1))
        .select("NDVI")
        .first()
    )

In [5]:
def get_landsat_image(date: datetime) -> ee.Image:
    """Gets a Landsat 8 image for the selected date."""
    return (
        ee.ImageCollection("LANDSAT/8/C01/T1_SR")
        .filterDate(date, date  + timedelta(days=1))
        .mosaic()
    )

In [6]:
def get_landsat_ndvi(image: ee.Image) -> ee.Image:
    """Calculates NDVI from a Landsat 8 image."""
    return image.normalizedDifference(["B5", "B4"]).rename("NDVI")

In [7]:
def get_landsat_lst(image: ee.Image) -> ee.Image:
    """
    Calculates Land Surface Temperature from a Landsat 8 image.
    This function is based on the formula in the following page https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LC08_C02_T1_L2
    """
    return image.select("ST_B10").multiply(0.00341802).add(149.0).rename("LST")

In [56]:
def get_land_cover(date: datetime) -> ee.Image:
    """Gets a Land Cover image for the given date."""
    return (
        ee.ImageCollection(LAND_COVER_DATASET)
        .filterDate(date, date + timedelta(days=30))
        .select("label")
        .first()
        # .rename("landcover")
        .unmask(0)  # fill missing values with 0 (water)
        .byte()
    )

### --- Input and Label Image Composition ---


In [9]:
def get_inputs_image(date: datetime) -> ee.Image:
    """Gets an Earth Engine image with all the inputs for the model."""
    # Get MODIS NDVI
    modis_ndvi = get_modis_ndvi(date)

    # Get Landsat data
    landsat_image = get_landsat_image(date)
    landsat_ndvi = get_landsat_ndvi(landsat_image)
    landsat_lst = get_landsat_lst(landsat_image)

    # Combine all input data
    return ee.Image([modis_ndvi, landsat_ndvi, landsat_lst])

In [10]:
def get_labels_image(year: int) -> ee.Image:
    """Gets a Land Cover image for the selected year and preprocesses it."""
    land_cover = get_land_cover(year)
    # Add preprocessing steps if needed (e.g., remapping land cover classes)
    return land_cover

### --- Get input and labels for a given latitude and longitude ---


In [None]:
@retry.Retry(deadline=10 * 60)  # seconds
def get_patch(
    image: ee.Image, lonlat: tuple[float, float], patch_size: int, scale: int
) -> np.ndarray:
    """Fetches a patch of pixels from Earth Engine."""
    point = ee.Geometry.Point(lonlat)
    url = image.getDownloadURL(
        {
            "region": point.buffer(scale * patch_size / 2, 1).bounds(1),
            "dimensions": [patch_size, patch_size],
            "scale": SCALE,
            "format": "NPY",
        }
    )

    # Retry on "Too Many Requests" errors
    response = requests.get(url)
    if response.status_code == 429:
        raise exceptions.TooManyRequests(response.text)

    # Raise other exceptions
    response.raise_for_status()
    return np.load(io.BytesIO(response.content), allow_pickle=True)

In [12]:
def get_inputs_patch(
    date: datetime, lonlat: tuple[float, float], patch_size: int
) -> np.ndarray:
    """Gets the inputs patch of pixels for the given point and date."""
    image = get_inputs_image(date)
    patch = get_patch(image, lonlat, patch_size, SCALE)
    return structured_to_unstructured(patch)


def get_labels_patch(
    date: datetime, lonlat: tuple[float, float], patch_size: int
) -> np.ndarray:
    """Gets the labels patch of pixels for the given point and year."""
    image = get_labels_image(date)
    patch = get_patch(image, lonlat, patch_size, SCALE)
    return structured_to_unstructured(patch)

## Creating the dataset


### --- Imports ---


In [13]:
import logging
import random
from datetime import datetime, timedelta

import dask.bag as db
import dask
from dask.bag.core import Bag
from dask.distributed import Client
import numpy as np
import os
import pandas as pd
import uuid

### --- Configs ---


In [14]:
NUM_SAMPLES = 1000
PATCH_SIZE = 128
PARTITION_SIZE = 10
START_DATE = "2015-07-01"
END_DATE = "2021-12-01"

### --- Sample Points ---


In [60]:
def sample_points(date: datetime) -> tuple:
    """Samples points within the defined polygon for the given year."""
    initialize_ee()
    dask.distributed.print(f"Getting land cover for {date}")
    land_cover = get_land_cover(date)
    snapshot_date = land_cover.date().format().getInfo()
    snapshot_date = datetime.strptime(snapshot_date, "%Y-%m-%dT%H:%M:%S")
    dask.distributed.print(f"Land cover snapshot date: {snapshot_date}")
    points = land_cover.stratifiedSample(
        numPoints=1,
        region=ee.Geometry.MultiPolygon(WORLD_POLYGONS),
        scale=WORLD_SCALE,
        geometries=True,
    )
    dask.distributed.print(f"Found for date {date} {points.size().getInfo()} points")
    point = points.toList(points.size()).getInfo()[0]
    return (snapshot_date, point["geometry"]["coordinates"])

### --- Prepare Training Data ---


In [16]:
def get_training_example(date: datetime, point: tuple) -> tuple:
    """Gets an (inputs, labels) training example for land cover change prediction."""
    inputs = get_inputs_patch(date, point, PATCH_SIZE)
    # Get land cover for the next day
    labels = get_labels_patch(date + timedelta(days=1), point, PATCH_SIZE)
    return (inputs, labels)

In [26]:
import dask.distributed


def try_get_example(date: datetime, point: tuple) -> tuple | None:
    """Wrapper to handle errors during training data generation."""
    ee.Initialize(project=project)
    dask.distributed.print(f"Generating training data for {date} at {point}")
    try:
        return get_training_example(date, point)
    except Exception as e:
        dask.distributed.print(f"Error occurred: {e}")

In [18]:
def random_date(start: datetime, end: datetime):
    """Generate a random datetime between `start` and `end`"""
    return start + timedelta(
        # Get a random amount of seconds between `start` and `end`
        seconds=random.randint(0, int((end - start).total_seconds())),
    )

# --- Dask Workflow for Dataset Creation ---

In [19]:
def write_npz(data: Bag, data_path: str) -> str:
    """Writes an (inputs, labels) set of data into a compressed NumPy file.

    Args:
        batch: Batch of (inputs, labels) pairs of NumPy arrays.
        data_path: Directory path to save files to.

    Returns: The filename of the data file.
    """
    initialize_ee()
    data.compute()
    dask.distributed.print(f"Writing {len(data)} data points to {data_path}")
    filename = os.path.join(data_path, f"{uuid.uuid4()}.npz")
    with open(filename, "xb") as f:
        inputs = [x for (x, _) in data]
        labels = [y for (_, y) in data]
        np.savez_compressed(f, inputs=inputs, labels=labels)
    logging.info(filename)
    return filename

In [32]:
client = Client()
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 8,Total memory: 16.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:50690,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 16.00 GiB

0,1
Comm: tcp://127.0.0.1:50703,Total threads: 2
Dashboard: http://127.0.0.1:50706/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:50693,
Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-ha6qnvma,Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-ha6qnvma

0,1
Comm: tcp://127.0.0.1:50704,Total threads: 2
Dashboard: http://127.0.0.1:50705/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:50695,
Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-j7edok6n,Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-j7edok6n

0,1
Comm: tcp://127.0.0.1:50702,Total threads: 2
Dashboard: http://127.0.0.1:50708/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:50697,
Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-hs_4r563,Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-hs_4r563

0,1
Comm: tcp://127.0.0.1:50701,Total threads: 2
Dashboard: http://127.0.0.1:50707/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:50699,
Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-jmfhklvc,Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-jmfhklvc


Found {'type': 'FeatureCollection', 'columns': {'label': 'Short<0, 255>'}, 'properties': {'band_order': ['label']}, 'features': []}
Found {'type': 'FeatureCollection', 'columns': {'label': 'Short<0, 255>'}, 'properties': {'band_order': ['label']}, 'features': [{'type': 'Feature', 'geometry': {'geodesic': False, 'type': 'Point', 'coordinates': [-165.20373842447157, 53.23057531563824]}, 'id': '0', 'properties': {'label': 0}}]}
Found {'type': 'FeatureCollection', 'columns': {'label': 'Short<0, 255>'}, 'properties': {'band_order': ['label']}, 'features': []}
Getting land cover for 2019-11-21 04:37:21
Land cover snapshot date: ee.Date({
  "functionInvocationValue": {
    "functionName": "Image.date",
    "arguments": {
      "image": {
        "functionInvocationValue": {
          "functionName": "Image.byte",
          "arguments": {
            "value": {
              "functionInvocationValue": {
                "functionName": "Image.unmask",
                "arguments": {
            

In [48]:
def run(data_path: str, samples: int = NUM_SAMPLES) -> None:
    """Runs the Dask workflow to generate the dataset."""

    # Generate dates from the start date to the end date
    start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
    end_date = datetime.strptime(END_DATE, "%Y-%m-%d")

    random_dates = [random_date(start_date, end_date) for _ in range(samples)]

    # Authenticate and initialize Earth Engine.
    initialize_ee()
    print(client)

    def wrapper(data, data_path):
        return write_npz(data, data_path)

    points = (
        db.from_sequence(random_dates, npartitions=PARTITION_SIZE)
        .map(sample_points)
        .compute()
    )
    training_data = list(
        db.from_sequence(points, npartitions=PARTITION_SIZE)
        .map(try_get_example)
        .filter(lambda x: x is not None)
    )
    # training_data = list(bag.map(try_get_example).filter(lambda x: x is not None))

    db.from_sequence(training_data, npartitions=PARTITION_SIZE).map_partitions(
        wrapper, data_path=data_path
    ).compute()
    # training_data.map_partitions(wrapper, data_path=data_path).compute()

### --- Perform Dataset Creation ---

In [61]:
logging.getLogger().setLevel(logging.INFO)
run("data/climate_change/")

<Client: 'tcp://127.0.0.1:50690' processes=4 threads=8, memory=16.00 GiB>
Getting land cover for 2019-06-09 05:08:32
Getting land cover for 2016-03-22 10:29:51
Getting land cover for 2019-01-14 23:26:05
Getting land cover for 2021-01-20 22:43:43
Getting land cover for 2019-04-19 10:32:41
Getting land cover for 2018-08-13 06:29:24
Getting land cover for 2021-09-12 11:11:08
Land cover snapshot date: 2019-06-09 05:33:03
Getting land cover for 2020-11-02 01:35:47
Land cover snapshot date: 2016-03-22 10:50:38
Land cover snapshot date: 2021-01-20 22:56:35
Land cover snapshot date: 2019-01-14 23:41:48
Land cover snapshot date: 2021-09-12 11:37:22
Land cover snapshot date: 2019-04-19 10:36:45
Land cover snapshot date: 2020-11-02 01:36:19
Land cover snapshot date: 2018-08-13 06:30:17
Found for date 2021-09-12 11:11:08 0 points


Key:       ('sample_points-2bc64a07a647649ce1c0e9f9e9584cc4', 9)
Function:  execute_task
args:      ((<function reify at 0x1095f6840>, (<function map_chunk at 0x1095f6ca0>, <function sample_points at 0x12d693920>, [[datetime.datetime(2021, 9, 12, 11, 11, 8), datetime.datetime(2017, 12, 6, 14, 47, 29), datetime.datetime(2016, 4, 2, 19, 18, 48), datetime.datetime(2016, 2, 2, 5, 46, 36), datetime.datetime(2017, 3, 1, 5, 31, 16), datetime.datetime(2020, 4, 14, 22, 56, 17), datetime.datetime(2020, 1, 1, 8, 25, 9), datetime.datetime(2020, 3, 20, 4, 14, 3), datetime.datetime(2021, 9, 10, 16, 42, 29), datetime.datetime(2018, 11, 4, 7, 52, 48), datetime.datetime(2018, 9, 15, 1, 47, 10), datetime.datetime(2015, 11, 14, 19, 38, 11), datetime.datetime(2018, 7, 21, 10, 26, 55), datetime.datetime(2018, 3, 16, 1, 17, 8), datetime.datetime(2017, 6, 27, 18, 2, 9), datetime.datetime(2019, 11, 28, 8, 14, 18), datetime.datetime(2020, 10, 30, 20, 0, 55), datetime.datetime(2021, 11, 23, 2, 54, 21), datetime

EEException: Collection.toList: The value of 'count' must be positive. Got: 0.

Found for date 2019-04-19 10:32:41 1 points
Found for date 2019-01-14 23:26:05 2 points
Getting land cover for 2017-03-30 00:18:29
Land cover snapshot date: 2017-03-30 01:00:09
Found for date 2019-06-09 05:08:32 4 points
Getting land cover for 2016-12-31 12:59:58
Land cover snapshot date: 2016-12-31 14:27:24
Found for date 2021-01-20 22:43:43 0 points


Key:       ('sample_points-2bc64a07a647649ce1c0e9f9e9584cc4', 5)
Function:  execute_task
args:      ((<function reify at 0x109766980>, (<function map_chunk at 0x109766de0>, <function sample_points at 0x10df98860>, [[datetime.datetime(2021, 1, 20, 22, 43, 43), datetime.datetime(2016, 3, 24, 19, 53, 56), datetime.datetime(2016, 6, 16, 6, 38, 58), datetime.datetime(2016, 9, 2, 6, 38, 26), datetime.datetime(2016, 2, 29, 1, 22, 6), datetime.datetime(2018, 11, 19, 1, 23, 13), datetime.datetime(2020, 11, 13, 12, 19, 32), datetime.datetime(2019, 3, 11, 11, 56, 21), datetime.datetime(2017, 4, 27, 15, 10, 12), datetime.datetime(2016, 12, 2, 10, 38, 27), datetime.datetime(2016, 6, 15, 3, 21, 49), datetime.datetime(2018, 9, 24, 10, 43), datetime.datetime(2017, 8, 10, 18, 22, 12), datetime.datetime(2015, 9, 19, 22, 5, 14), datetime.datetime(2018, 12, 12, 17, 29, 37), datetime.datetime(2016, 7, 6, 1, 56, 58), datetime.datetime(2018, 5, 11, 10, 18, 8), datetime.datetime(2020, 8, 21, 21, 20, 1), datet

Found for date 2018-08-13 06:29:24 1 points
Found for date 2016-03-22 10:29:51 2 points
Getting land cover for 2016-11-23 20:49:20
Land cover snapshot date: 2016-11-23 21:18:11
Getting land cover for 2015-08-04 21:08:37
Land cover snapshot date: 2015-08-05 01:15:29
Found for date 2020-11-02 01:35:47 0 points


Key:       ('sample_points-2bc64a07a647649ce1c0e9f9e9584cc4', 6)
Function:  execute_task
args:      ((<function reify at 0x1091828e0>, (<function map_chunk at 0x109182d40>, <function sample_points at 0x118449620>, [[datetime.datetime(2020, 11, 2, 1, 35, 47), datetime.datetime(2019, 3, 23, 23, 12, 13), datetime.datetime(2020, 7, 27, 10, 39, 25), datetime.datetime(2020, 2, 29, 1, 52, 7), datetime.datetime(2018, 8, 16, 2, 19, 6), datetime.datetime(2019, 5, 29, 0, 30, 36), datetime.datetime(2020, 3, 15, 8, 19, 8), datetime.datetime(2018, 7, 22, 14, 11, 32), datetime.datetime(2020, 6, 25, 0, 46, 50), datetime.datetime(2015, 8, 26, 8, 44, 44), datetime.datetime(2021, 4, 22, 9, 54, 42), datetime.datetime(2015, 11, 3, 16, 15, 53), datetime.datetime(2016, 2, 10, 1, 56, 28), datetime.datetime(2019, 1, 17, 1, 21, 41), datetime.datetime(2019, 6, 1, 22, 56, 36), datetime.datetime(2017, 9, 19, 18, 56, 2), datetime.datetime(2019, 12, 26, 5, 39, 56), datetime.datetime(2021, 9, 9, 4, 48, 25), datetime.

Found for date 2016-11-23 20:49:20 0 points


Key:       ('sample_points-2bc64a07a647649ce1c0e9f9e9584cc4', 4)
Function:  execute_task
args:      ((<function reify at 0x1091828e0>, (<function map_chunk at 0x109182d40>, <function sample_points at 0x118449ee0>, [[datetime.datetime(2016, 3, 22, 10, 29, 51), datetime.datetime(2016, 11, 23, 20, 49, 20), datetime.datetime(2021, 2, 17, 23, 55, 15), datetime.datetime(2017, 4, 27, 3, 31, 28), datetime.datetime(2017, 9, 1, 14, 59, 46), datetime.datetime(2017, 8, 8, 23, 39, 48), datetime.datetime(2019, 9, 2, 14, 52, 25), datetime.datetime(2016, 7, 13, 0, 4, 18), datetime.datetime(2018, 6, 1, 16, 54, 10), datetime.datetime(2018, 1, 31, 11, 41, 4), datetime.datetime(2018, 11, 1, 17, 58, 30), datetime.datetime(2017, 9, 28, 6, 41, 27), datetime.datetime(2020, 5, 6, 14, 1), datetime.datetime(2018, 9, 24, 5, 3, 7), datetime.datetime(2020, 3, 9, 3, 37, 40), datetime.datetime(2017, 11, 3, 11, 8, 11), datetime.datetime(2019, 11, 2, 19, 4, 35), datetime.datetime(2020, 11, 4, 9, 7, 44), datetime.dateti

Found for date 2016-12-31 12:59:58 1 points
Found for date 2015-08-04 21:08:37 0 points


Key:       ('sample_points-2bc64a07a647649ce1c0e9f9e9584cc4', 3)
Function:  execute_task
args:      ((<function reify at 0x1095f6840>, (<function map_chunk at 0x1095f6ca0>, <function sample_points at 0x12d692de0>, [[datetime.datetime(2019, 4, 19, 10, 32, 41), datetime.datetime(2015, 8, 4, 21, 8, 37), datetime.datetime(2021, 3, 2, 13, 38, 55), datetime.datetime(2016, 8, 8, 16, 32, 42), datetime.datetime(2018, 7, 30, 19, 17, 31), datetime.datetime(2015, 11, 17, 20, 49, 54), datetime.datetime(2020, 11, 1, 17, 39, 50), datetime.datetime(2017, 6, 17, 15, 13, 7), datetime.datetime(2016, 10, 16, 14, 25, 39), datetime.datetime(2015, 11, 3, 16, 40, 50), datetime.datetime(2020, 3, 18, 14, 57, 58), datetime.datetime(2016, 6, 3, 2, 48, 14), datetime.datetime(2020, 4, 7, 10, 54, 9), datetime.datetime(2021, 7, 25, 5, 25, 13), datetime.datetime(2018, 11, 23, 14, 11, 39), datetime.datetime(2021, 9, 10, 7, 28, 42), datetime.datetime(2020, 7, 16, 3, 28, 31), datetime.datetime(2017, 11, 23, 13, 43, 11), 

Getting land cover for 2017-10-18 17:38:18
Land cover snapshot date: 2017-10-18 17:39:45
Found for date 2017-10-18 17:38:18 0 points
