In [1]:
# Install the notebook's dependencies listed in requirements.txt via the `uv` CLI.
!uv pip install -r requirements.txt

[2mAudited [1m155 packages[0m in 66ms[0m


### --- Import Libraries ---


In [2]:
import io
from datetime import datetime, timedelta

import ee
from google.api_core import exceptions, retry
import google.auth
import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured
import requests

### --- Constants and Earth Engine Initialization ---


In [3]:
# Resolution passed to get_patch by get_inputs_patch and get_labels_patch.
SCALE = 5000  # meters per pixel
# Resolution passed as `scale` to stratifiedSample in sample_points.
WORLD_SCALE = 10_000  # presumably meters per pixel, like SCALE — TODO confirm
LAND_COVER_DATASET = "GOOGLE/DYNAMICWORLD/V1"  # Dynamic World Land Cover dataset
# NOTE(review): get_land_cover selects the "label" band directly and does not
# read this constant — confirm whether it is still needed.
LAND_COVER_BAND = "Map"  # Land cover classification band
# Coarse polygons used together as the MultiPolygon sampling region in
# sample_points. Vertices are presumably (longitude, latitude) pairs — TODO confirm.
WORLD_POLYGONS = [
    # Americas
    [(-33.0, -7.0), (-55.0, 53.0), (-166.0, 65.0), (-68.0, -56.0)],
    # Africa, Asia, Europe
    [
        (74.0, 71.0),
        (166.0, 55.0),
        (115.0, -11.0),
        (74.0, -4.0),
        (20.0, -38.0),
        (-29.0, 25.0),
    ],
    # Australia
    [(170.0, -47.0), (179.0, -37.0), (167.0, -12.0), (128.0, 17.0), (106.0, -29.0)],
]
# NOTE(review): not referenced by any code visible in this notebook.
POLYGON = [(-140.0, 60.0), (-140.0, -60.0), (-10.0, -60.0), (-10.0, 60.0)]

### --- Earth Engine Initialization ---


In [4]:
project = "ee-rohitp934"

# Option 1: authenticate once from the command line
# !earthengine authenticate


# Option 2: authenticate from Python with the helper below
def initialize_ee():
    """Authenticates with Earth Engine, then initializes the client library.

    Initialization uses the module-level ``project`` and the Earth Engine
    high-volume endpoint URL.
    """
    high_volume_url = "https://earthengine-highvolume.googleapis.com"
    ee.Authenticate()
    ee.Initialize(project=project, opt_url=high_volume_url)

In [5]:
# Authenticate and initialize Earth Engine for this kernel session.
initialize_ee()

### --- Data Retrieval Functions ---


In [103]:
def get_modis_ndvi(date: datetime) -> ee.Image:
    """Returns the first MODIS NDVI image in the month starting at `date`.

    The "MODIS/061/MOD13Q1" collection is filtered to the window from `date`
    to one month later, reduced to its "NDVI" band, and its first image is
    returned.
    """
    window_start = ee.Date(date)
    window_end = window_start.advance(1, 'month')
    ndvi_collection = (
        ee.ImageCollection("MODIS/061/MOD13Q1")
        .filterDate(window_start, window_end)
        .select("NDVI")
    )
    return ndvi_collection.first()

In [104]:
def get_landsat_image(date: datetime) -> ee.Image:
    """Builds a Landsat 8 mosaic for the six months starting at `date`.

    All "LANDSAT/LC08/C02/T1_L2" images in the window from `date` to six
    months later are combined with `mosaic()`.
    """
    window_start = ee.Date(date)
    window_end = window_start.advance(6, 'month')
    scenes = ee.ImageCollection("LANDSAT/LC08/C02/T1_L2").filterDate(
        window_start, window_end
    )
    return scenes.mosaic()

In [8]:
def get_landsat_ndvi(image: ee.Image) -> ee.Image:
    """Derives an "NDVI" band as the normalized difference of SR_B5 and SR_B4."""
    band_pair = ["SR_B5", "SR_B4"]
    ndvi = image.normalizedDifference(band_pair)
    return ndvi.rename("NDVI")

In [9]:
def get_landsat_lst(image: ee.Image) -> ee.Image:
    """
    Derives a Land Surface Temperature band ("LST") from a Landsat 8 image.

    The "ST_B10" band is multiplied by 0.00341802 and offset by 149.0,
    following the formula given on the dataset page:
    https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LC08_C02_T1_L2
    """
    thermal_band = image.select("ST_B10")
    rescaled = thermal_band.multiply(0.00341802)
    return rescaled.add(149.0).rename("LST")

In [60]:
def get_land_cover(date: datetime, advance: int = 6) -> ee.Image:
    """Builds a land cover label mosaic for the window starting at `date`.

    Args:
        date: Start of the time window.
        advance: Length of the window, in months.

    Returns: The mosaicked "label" band of LAND_COVER_DATASET over the
        window, with masked pixels set to 0 and values cast to byte.
    """
    window_start = ee.Date(date)
    window_end = window_start.advance(advance, 'month')
    label_collection = (
        ee.ImageCollection(LAND_COVER_DATASET)
        .filterDate(window_start, window_end)
        .select("label")
    )
    label_mosaic = label_collection.mosaic()
    # fill missing values with 0 (water)
    return label_mosaic.unmask(0).byte()

### --- Input and Label Image Composition ---


In [11]:
def get_inputs_image(date: datetime) -> ee.Image:
    """Builds the Earth Engine image holding every model input for `date`.

    The result stacks, in order: Landsat NDVI, MODIS NDVI, Landsat LST.
    """
    # Landsat mosaic shared by the NDVI and LST derivations
    landsat = get_landsat_image(date)

    # Both NDVI sources, Landsat first
    ndvi_bands = ee.Image.cat([get_landsat_ndvi(landsat), get_modis_ndvi(date)])

    # Append the surface temperature band
    return ee.Image([ndvi_bands, get_landsat_lst(landsat)])

In [12]:
def get_labels_image(year: datetime) -> ee.Image:
    """Gets the Land Cover label image for the window starting at a date.

    Args:
        year: Start date of the land cover window. Despite its name (kept so
            existing callers keep working), this is a date, not a calendar
            year: it is forwarded unchanged to `get_land_cover`, which wraps
            it in `ee.Date`, and `get_labels_patch` passes a `datetime`.

    Returns: The image produced by `get_land_cover` with its default window
        length.
    """
    # Add preprocessing steps here if needed (e.g., remapping land cover classes)
    return get_land_cover(year)

### --- Get input and labels for a given latitude and longitude ---


In [13]:
@retry.Retry(deadline=10 * 60)  # seconds
def get_patch(
    image: ee.Image, lonlat: tuple[float, float], patch_size: int, scale: int
) -> np.ndarray:
    """Downloads a square patch of pixels around a point from Earth Engine.

    Args:
        image: Image to download pixels from.
        lonlat: Coordinates of the patch center, passed to `ee.Geometry.Point`.
        patch_size: Width and height of the requested patch, in pixels.
        scale: Used with `patch_size` to size the buffered region around the point.

    Returns: The array loaded from the NPY download.

    Raises:
        exceptions.TooManyRequests: When the download responds with HTTP 429.
        requests.HTTPError: For any other error status (via `raise_for_status`).
    """
    center = ee.Geometry.Point(lonlat)
    half_side = scale * patch_size / 2
    download_params = {
        "region": center.buffer(half_side, 1).bounds(1),
        "dimensions": [patch_size, patch_size],
        "format": "NPY",
    }
    url = image.getDownloadURL(download_params)

    response = requests.get(url)
    # HTTP 429 is re-raised as TooManyRequests so it is retried
    if response.status_code == 429:
        raise exceptions.TooManyRequests(response.text)
    # Any other error status is raised as-is
    response.raise_for_status()

    # NOTE(review): allow_pickle=True unpickles data received over the network;
    # confirm it is actually required for the NPY payload.
    payload = io.BytesIO(response.content)
    return np.load(payload, allow_pickle=True)

In [15]:
def get_inputs_patch(
    date: datetime, lonlat: tuple[float, float], patch_size: int
) -> np.ndarray:
    """Downloads the model-inputs patch for a point and date.

    The patch is fetched at SCALE resolution and converted with
    `structured_to_unstructured` before being returned.
    """
    inputs_image = get_inputs_image(date)
    structured_patch = get_patch(inputs_image, lonlat, patch_size, SCALE)
    return structured_to_unstructured(structured_patch)


def get_labels_patch(
    date: datetime, lonlat: tuple[float, float], patch_size: int
) -> np.ndarray:
    """Downloads the labels patch for a point and date.

    The patch is fetched at SCALE resolution and converted with
    `structured_to_unstructured` before being returned.
    """
    labels_image = get_labels_image(date)
    structured_patch = get_patch(labels_image, lonlat, patch_size, SCALE)
    return structured_to_unstructured(structured_patch)

## Creating the dataset


### --- Imports ---


In [16]:
import logging
import random
from datetime import datetime, timedelta

import dask.bag as db
import dask
from dask.bag.core import Bag
from dask.distributed import Client
import numpy as np
import os
import pandas as pd
import uuid

### --- Configs ---


In [57]:
NUM_SAMPLES = 1  # default value of run()'s `samples` argument
PARTITION_SIZE = 1  # passed as `npartitions` to db.from_sequence
# Passed as `numPoints` to stratifiedSample in sample_points.
NUM_SAMPLES_PER_PARTITION = NUM_SAMPLES // PARTITION_SIZE
PATCH_SIZE = 128  # patch width and height in pixels (get_patch "dimensions")
START_DATE = "2015-01-01"
END_DATE = "2021-01-01"
start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
end_date = datetime.strptime(END_DATE, "%Y-%m-%d")

# One datetime per year (January 1st), including both endpoint years.
years = range(start_date.year, end_date.year + 1)
# `dates` is read as a module-level global by sample_points.
dates = [datetime(year, 1, 1) for year in years]

### --- Sample Points ---


In [18]:
def random_date(start: datetime, end: datetime):
    """Picks a random datetime between `start` and `end` (both inclusive).

    The result is `start` plus a whole number of seconds drawn with
    `random.randint`.
    """
    span_seconds = int((end - start).total_seconds())
    offset = timedelta(seconds=random.randint(0, span_seconds))
    return start + offset

In [96]:
def sample_points(seed: int) -> list:
    """Samples stratified land cover points for a seed-shifted date.

    The third entry of the module-level `dates` list is shifted forward by
    `30 * seed` days. A land cover image is built for that date and sampled
    with `stratifiedSample` over WORLD_POLYGONS. Sampling is attempted at most
    twice; after an error the land cover window shrinks from 6 to 3 months.

    Args:
        seed: Integer used both as the random seed for `stratifiedSample` and
            as the number of 30-day steps to shift the date by.

    Returns: A list of `(date, coordinates)` tuples, one per sampled feature.
        Empty when nothing could be sampled.
    """
    initialize_ee()

    # Use a local list. Rebinding the global `dates` here would leave it with a
    # single element, so every later call would see an empty `dates[2:3]`
    # slice and silently return no points.
    sample_dates = [d + timedelta(days=(30 * seed)) for d in dates[2:3]]
    results = []
    for date in sample_dates:
        tries = 2
        advance = 6
        while tries > 0:
            try:
                land_cover = get_land_cover(date, advance)
                points = land_cover.stratifiedSample(
                    numPoints=NUM_SAMPLES_PER_PARTITION,
                    region=ee.Geometry.MultiPolygon(WORLD_POLYGONS),
                    scale=WORLD_SCALE,
                    geometries=True,
                    seed=seed,
                    classBand="label",
                    tileScale=4,
                )
                # One round trip: an empty feature list means no points were found.
                features = points.getInfo()["features"]
                found = [
                    (date, feature["geometry"]["coordinates"])
                    for feature in features
                ]
            except Exception as e:
                tries -= 1
                # Retry with a shorter land cover window.
                advance = 3
                print(f"Error occurred for {date}: {e} ... Retrying {tries} times")
                continue

            if found:
                results.extend(found)
                break
            tries -= 1

    return results

### --- Prepare Training Data ---


In [108]:
def get_training_example(date: datetime, point: tuple) -> tuple:
    """Builds one (inputs, labels) pair for land cover change prediction.

    Inputs are fetched for `date`; labels are fetched for the day after it.
    Both patches are PATCH_SIZE pixels wide and centered on `point`.
    """
    inputs = get_inputs_patch(date, point, PATCH_SIZE)
    # Labels come from the land cover starting on the following day
    label_date = date + timedelta(days=1)
    labels = get_labels_patch(label_date, point, PATCH_SIZE)
    return inputs, labels

In [118]:
from typing import Iterator


def try_get_example(sample_point: tuple[datetime, tuple]) -> tuple:
    """Builds a training example, reporting failures instead of raising.

    Earth Engine is initialized for `project` on every call. Returns the
    (inputs, labels) pair from `get_training_example`, or None when an
    exception was caught and printed.
    """
    ee.Initialize(project=project)
    date, point = sample_point
    try:
        example = get_training_example(date, point)
    except Exception as e:
        dask.distributed.print(f"Error occurred: {e}")
        return None
    return example

# --- Dask Workflow for Dataset Creation ---

In [22]:
def write_npz(data: list, data_path: str) -> str:
    """Writes a batch of (inputs, labels) pairs into a compressed NumPy file.

    Args:
        data: Batch of (inputs, labels) pairs of NumPy arrays. Any iterable is
            accepted, including one-shot iterators. `None` entries (examples
            that failed to download) are skipped.
        data_path: Directory path to save files to. Created if it is missing.

    Returns: The filename of the data file.

    Raises:
        FileExistsError: If the generated filename already exists (the file is
            opened in exclusive-creation mode).
    """
    # Materialize the batch once so a one-shot iterator is not exhausted by the
    # first comprehension, and drop failed examples before unpacking pairs.
    examples = [example for example in data if example is not None]
    inputs = [x for (x, _) in examples]
    labels = [y for (_, y) in examples]

    # No Earth Engine call happens here, so no authentication is needed;
    # only the output directory has to exist.
    if data_path:
        os.makedirs(data_path, exist_ok=True)
    filename = os.path.join(data_path, f"{uuid.uuid4()}.npz")
    with open(filename, "xb") as f:
        np.savez_compressed(f, inputs=inputs, labels=labels)
    logging.info(filename)
    return filename

In [22]:
from dask.distributed import Variable

# Create a Dask client with default settings (the captured output below shows
# it backed by a distributed.LocalCluster).
client = Client()

# Distributed variables used as counters, both initialized to 0.
numFailures = Variable("numFailures")
numFailures.set(0)
numSuccesses = Variable("numSuccesses")
numSuccesses.set(0)

# Display the client summary as the cell output.
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 8,Total memory: 16.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:57093,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 16.00 GiB

0,1
Comm: tcp://127.0.0.1:57104,Total threads: 2
Dashboard: http://127.0.0.1:57109/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:57096,
Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-fhl7js1n,Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-fhl7js1n

0,1
Comm: tcp://127.0.0.1:57105,Total threads: 2
Dashboard: http://127.0.0.1:57108/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:57098,
Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-9ofqrktu,Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-9ofqrktu

0,1
Comm: tcp://127.0.0.1:57106,Total threads: 2
Dashboard: http://127.0.0.1:57112/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:57100,
Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-4v5xfcar,Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-4v5xfcar

0,1
Comm: tcp://127.0.0.1:57107,Total threads: 2
Dashboard: http://127.0.0.1:57114/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:57102,
Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-adhy7rw1,Local directory: /var/folders/j1/mgpcsdr54rv4c_hcgbz4n98w0000gn/T/dask-scratch-space/worker-adhy7rw1


In [24]:
start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
end_date = datetime.strptime(END_DATE, "%Y-%m-%d")

years = range(start_date.year, end_date.year + 1)
# Rebuild the module-level `dates` list (January 1st of each year), leaving out 2020.
dates = [datetime(year, 1, 1) for year in years if year != 2020]

# random_dates = [random_date(start_date, end_date) for _ in range(samples)]

# Show the resulting list.
print(dates)

[datetime.datetime(2015, 1, 1, 0, 0), datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2017, 1, 1, 0, 0), datetime.datetime(2018, 1, 1, 0, 0), datetime.datetime(2019, 1, 1, 0, 0), datetime.datetime(2021, 1, 1, 0, 0)]


In [25]:
# Module-level holder for the sampled points; run() assigns it via `global points`.
points = None

In [25]:
def run(data_path: str, samples: int = NUM_SAMPLES) -> None:
    """Runs the Dask workflow that samples the points for the dataset.

    `sample_points` is mapped over the integer seeds `0 .. samples - 1`, and
    the flattened list of `(date, coordinates)` pairs is stored in the
    module-level `points` variable.

    Args:
        data_path: Directory the dataset is meant to be written to. Not used
            here: writing the examples is not part of this function.
        samples: Number of seeds to sample with.
    """
    global points

    if samples <= 0:
        points = []
        return

    # sample_points takes an integer seed (it computes `30 * seed`), so it is
    # mapped over seeds rather than over datetime objects.
    seeds = list(range(samples))
    per_seed_results = (
        db.from_sequence(seeds, npartitions=PARTITION_SIZE)
        .map(sample_points)
        .compute()
    )
    dask.distributed.print(numFailures.get())
    dask.distributed.print(numSuccesses.get())
    dask.distributed.print("Starting training data generation")

    # compute() returns a plain Python list holding one result list per seed,
    # so flatten it with a comprehension (a list has no `.filter` method).
    points = [
        sample
        for batch in per_seed_results
        if batch is not None
        for sample in batch
        if sample is not None
    ]

In [28]:
start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
end_date = datetime.strptime(END_DATE, "%Y-%m-%d")

years = range(start_date.year, end_date.year + 1)
# Rebuild the module-level `dates` list: January 1st of every year in range.
dates = [datetime(year, 1, 1) for year in years]

# NOTE(review): sample_points is annotated to take an integer seed and computes
# `30 * seed`, but this call passes a datetime (dates[6]). The call and the
# captured output below look stale relative to the current definition —
# confirm the intended argument.
points = sample_points(dates[6])
points

Retrieved image
Found for date 2021-01-01 00:00:00 9 points


(datetime.datetime(2021, 1, 1, 0, 0),
 {'type': 'FeatureCollection',
  'columns': {'label': 'Short<0, 255>'},
  'properties': {'band_order': ['label']},
  'features': [{'type': 'Feature',
    'geometry': {'geodesic': False,
     'type': 'Point',
     'coordinates': [-89.24762347727446, -3.9975030143318704]},
    'id': '0',
    'properties': {'label': 0}},
   {'type': 'Feature',
    'geometry': {'geodesic': False,
     'type': 'Point',
     'coordinates': [-60.59136591386172, -27.443531929851382]},
    'id': '1',
    'properties': {'label': 1}},
   {'type': 'Feature',
    'geometry': {'geodesic': False,
     'type': 'Point',
     'coordinates': [-86.01368845444418, 43.25388093035496]},
    'id': '2',
    'properties': {'label': 2}},
   {'type': 'Feature',
    'geometry': {'geodesic': False,
     'type': 'Point',
     'coordinates': [-51.06922390219479, -26.54521664573186]},
    'id': '3',
    'properties': {'label': 3}},
   {'type': 'Feature',
    'geometry': {'geodesic': False,
     't

### --- Perform Dataset Creation ---

In [None]:
# Let INFO-level records through the root logger (write_npz logs each filename
# with logging.info), then run the workflow.
logging.getLogger().setLevel(logging.INFO)
run("data/climate_change/")

In [None]:
# Directory the NPZ files are written to (the same path passed to run() above).
data_path = "data/climate_change/"

# Build a training example for every sampled point, dropping the ones that
# failed (try_get_example returns None on error).
training_data = (
    db.from_sequence(points, npartitions=PARTITION_SIZE)
    .map(try_get_example)
    .filter(lambda x: x is not None)
)
# Iterating the bag computes it, giving a plain list of examples.
training_data = list(training_data)

# Write each partition of examples to its own compressed NPZ file.
db.from_sequence(training_data, npartitions=PARTITION_SIZE).map_partitions(
    write_npz, data_path=data_path
).compute()

In [120]:
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

# Rebuild the module-level `dates` list read by sample_points.
start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
end_date = datetime.strptime(END_DATE, "%Y-%m-%d")

years = range(start_date.year, end_date.year + 1)
dates = [datetime(year, 1, 1) for year in years]
beam_options = PipelineOptions(
    [],
    save_main_session=True,
    # setup_file="./setup.py",
    # max_num_workers=max_requests,  # distributed runners
    direct_num_workers=2,  # direct runner
    # disk_size_gb=50,
)
with beam.Pipeline(options=beam_options) as pipeline:
    # Each seed expands into zero or more (date, coordinates) samples.
    points = (
        pipeline
        | "🌱 Make seeds" >> beam.Create([5])
        | "📌 Sample points" >> beam.FlatMap(sample_points)
    )

    (
        points
        | "🃏 Reshuffle" >> beam.Reshuffle()
        | "📑 Get examples" >> beam.Map(try_get_example)
        # try_get_example returns None when an example fails; drop those so
        # they never reach batching and write_npz.
        | "🧹 Drop failed examples" >> beam.Filter(lambda example: example is not None)
        | "🗂️ Batch examples" >> beam.BatchElements(10)
        | "📝 Write NPZ Files" >> beam.Map(write_npz, data_path="data/climate_change")
    )

