In [43]:
from datetime import datetime
from dateutil.relativedelta import relativedelta
import ee
import requests
import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured
import io
from typing import Tuple
import geemap
import jax
import numpy as np
import asyncio
import ml_collections
from plotly.graph_objects import Image
from plotly.subplots import make_subplots
import orbax.checkpoint as ocp
import jax.numpy as jnp
import flaxtrainer as ft
from flaxtrainer import CNN_LandCover, CNN_LST

In [2]:
# Side length, in pixels, of every patch requested from Earth Engine.
PATCH_SIZE = 128
# Multiplied by the patch size in get_patch to size the region buffer around the point.
SCALE = 5000
# True-colour (red, green, blue) surface-reflectance bands. Collection 2 Level-2
# names them "SR_B*"; the bare "B4"/"B3"/"B2" names do not exist in LANDSAT_DATASET.
LANDSAT_BANDS = ["SR_B4", "SR_B3", "SR_B2"]
LAND_COVER_DATASET = "GOOGLE/DYNAMICWORLD/V1"  # Dynamic World Land Cover dataset
LANDSAT_VIS_PARAMS = {"bands": LANDSAT_BANDS, "min": 0, "max": 0.5}
# (longitude, latitude) of the point of interest.
LONLAT = (-73.9974, 44.2823)
LANDSAT_DATASET = "LANDSAT/LC08/C02/T1_L2"

In [3]:
# Google Cloud project passed to ee.Initialize.
project = "bigdata-ahhcash"


def initialize_ee():
    """Authenticate with Earth Engine, then initialise the client.

    The client is initialised for ``project`` against the high-volume
    Earth Engine endpoint.
    """
    high_volume_url = "https://earthengine-highvolume.googleapis.com"
    ee.Authenticate()
    ee.Initialize(project=project, opt_url=high_volume_url)

In [4]:
# Authenticate and initialise the Earth Engine client before any `ee` call below.
initialize_ee()

In [5]:
def mask_landsat8_clouds(image: ee.Image) -> ee.Image:
    """Mask cloud and cloud-shadow pixels in a Landsat 8 Collection 2 image.

    Uses the ``QA_PIXEL`` band, where bit 3 flags cloud and bit 4 flags
    cloud shadow. (Bit 5 is snow in Collection 2 and is deliberately not
    masked here.)

    Args:
        image: Landsat 8 image that contains a ``QA_PIXEL`` band.

    Returns:
        The same image with its mask updated so that only pixels with both
        the cloud and cloud-shadow bits cleared remain valid.
    """
    CLOUD_BIT = 3
    CLOUD_SHADOW_BIT = 4
    qa = image.select('QA_PIXEL')
    # A pixel is kept only when neither flag is set, i.e. the AND with the
    # combined bit pattern is zero.
    unwanted_bits = (1 << CLOUD_BIT) | (1 << CLOUD_SHADOW_BIT)
    clear_mask = qa.bitwiseAnd(unwanted_bits).eq(0)
    return image.updateMask(clear_mask)

In [6]:
def apply_scale_factors(image):
    """Rescale the optical and thermal bands of a Landsat 8 image.

    Bands matching ``SR_B.`` are multiplied by 0.0000275 and offset by -0.2;
    bands matching ``ST_B.*`` are multiplied by 0.00341802 and offset by
    149.0. The rescaled bands overwrite the bands of the same name.

    Args:
        image: Image exposing ``select`` and ``addBands`` (an ``ee.Image``).

    Returns:
        The image with the rescaled bands written over the originals.
    """
    scaled_optical = image.select('SR_B.').multiply(0.0000275).add(-0.2)
    scaled_thermal = image.select('ST_B.*').multiply(0.00341802).add(149.0)
    # Third argument True = overwrite the existing bands of the same name.
    with_optical = image.addBands(scaled_optical, None, True)
    return with_optical.addBands(scaled_thermal, None, True)

In [7]:
def get_landsat_image(datetime: "datetime | None" = None, default_value: float = 0.0) -> ee.Image:
    """Build a cloud-masked, rescaled median Landsat image for a 6-month window.

    Args:
        datetime: Start of the window. When omitted, defaults to two years
            before the moment of the call (evaluated per call, not once when
            the cell defining this function is run).
        default_value: Value written into pixels that are still masked after
            the median is taken.

    Returns:
        A float image: the per-pixel median of all LANDSAT_DATASET images in
        ``[datetime, datetime + 6 months)`` after cloud masking and scale
        factors have been applied.
    """
    if datetime is None:
        # The parameter shadows the imported ``datetime`` class inside this
        # function, so the class is re-imported under a different name.
        from datetime import datetime as _datetime_cls

        datetime = _datetime_cls.now() - relativedelta(years=2)

    start = ee.Date(datetime)
    end = ee.Date(datetime).advance(6, "month")
    return (
        ee.ImageCollection(LANDSAT_DATASET)
        .filterDate(start, end)
        .map(mask_landsat8_clouds)
        .map(apply_scale_factors)
        .median()
        .unmask(default_value)
        .float()
    )

In [8]:
def get_landsat_ndvi(image: ee.Image) -> ee.Image:
    """Compute NDVI for a Landsat 8 image.

    Args:
        image: Image containing the ``SR_B5`` and ``SR_B4`` bands.

    Returns:
        A single-band image named ``NDVI`` holding the normalized difference
        of ``SR_B5`` and ``SR_B4``.
    """
    ndvi_band_pair = ["SR_B5", "SR_B4"]
    ndvi = image.normalizedDifference(ndvi_band_pair)
    return ndvi.rename("NDVI")

In [9]:
def get_land_cover(date: "datetime | None" = None, advance: int = 6) -> ee.Image:
    """Gets a Land Cover image for the window starting at the given date.

    NOTE: a later cell defines ``get_land_cover`` again (with a required
    ``date``); when the notebook is run top to bottom that later definition
    replaces this one. This version accepts the same ``advance`` argument so
    either definition can serve the same callers.

    Args:
        date: Start of the window. When omitted, defaults to five years
            before the moment of the call (evaluated per call rather than
            once at definition time).
        advance: Length of the window in months.

    Returns:
        A byte image of the ``label`` band mosaicked over the window.
    """
    if date is None:
        date = datetime.now() - relativedelta(years=5)

    start_date = ee.Date(date)
    end_date = ee.Date(date).advance(advance, 'month')
    return (
        ee.ImageCollection(LAND_COVER_DATASET)
        .filterDate(start_date, end_date)
        .select("label")
        .mosaic()
        .unmask(0)  # fill missing values with 0 (water)
        .byte()
    )

In [10]:
def get_landsat_lst(image: ee.Image) -> ee.Image:
    """
    Extracts Land Surface Temperature from an already-scaled Landsat 8 image.

    Every caller in this notebook passes the output of ``get_landsat_image``,
    which has already run ``apply_scale_factors`` (``ST_B.*`` multiplied by
    0.00341802 plus 149.0, per
    https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LC08_C02_T1_L2).
    Applying that formula here a second time would scale the band twice, so
    this function only selects and renames the band.

    NOTE(review): any model trained on the previous, doubly-scaled values will
    see a different LST range — confirm whether checkpoints need retraining.

    Args:
        image: Landsat 8 image whose ``ST_B10`` band is already scaled.

    Returns:
        A single-band image named ``LST``.
    """
    return image.select("ST_B10").rename("LST")

In [11]:
def get_land_cover(date: datetime, advance: int = 6) -> ee.Image:
    """Build a land-cover label image for a window starting at ``date``.

    Args:
        date: Start of the window.
        advance: Length of the window in months.

    Returns:
        A byte image of the ``label`` band of LAND_COVER_DATASET, mosaicked
        over the window, with masked pixels filled with 0.
    """
    window_start = ee.Date(date)
    window_end = ee.Date(date).advance(advance, 'month')

    labels = ee.ImageCollection(LAND_COVER_DATASET).filterDate(window_start, window_end).select("label")
    mosaic = labels.mosaic()
    # fill missing values with 0 (water)
    filled = mosaic.unmask(0)
    return filled.byte()

In [12]:
def get_lst(date: datetime) -> ee.Image:
    """Return the Land Surface Temperature image for the window starting at ``date``.

    The Landsat image is built by ``get_landsat_image`` and the LST band is
    derived from it by ``get_landsat_lst``.
    """
    return get_landsat_lst(get_landsat_image(date))

In [13]:
def get_modis_ndvi(date: datetime) -> ee.Image:
    """Return the first MODIS NDVI image in the month starting at ``date``.

    Args:
        date: Start of the one-month search window.

    Returns:
        The ``NDVI`` band of the first ``MODIS/061/MOD13Q1`` image in the
        window. No scale factor is applied to the band here.
    """
    window_start = ee.Date(date)
    window_end = ee.Date(date).advance(1, 'month')

    modis = ee.ImageCollection("MODIS/061/MOD13Q1")
    ndvi_in_window = modis.filterDate(window_start, window_end).select("NDVI")
    return ndvi_in_window.first()

In [14]:
def get_inputs_image(date: datetime) -> ee.Image:
    """Assemble the model input image for ``date``.

    The result stacks, in order: Landsat NDVI, MODIS NDVI, and Landsat LST.
    """
    ndvi_modis = get_modis_ndvi(date)

    # Both Landsat-derived bands come from the same Landsat image.
    landsat = get_landsat_image(date)
    ndvi_stack = ee.Image.cat([get_landsat_ndvi(landsat), ndvi_modis])
    lst = get_landsat_lst(landsat)

    return ee.Image([ndvi_stack, lst])

In [15]:
def get_patch(
    image: ee.Image,
    lonlat: Tuple[float, float],
    patch_size: int,
    scale: int,
    max_retries: int = 5,
    timeout: float = 300.0,
) -> np.ndarray:
    """Fetches a patch of pixels from Earth Engine.

    The region is the bounding box of a buffer of ``scale * patch_size / 2``
    around ``lonlat``, resampled to ``patch_size`` x ``patch_size`` pixels.

    Args:
        image: Image to download from.
        lonlat: (longitude, latitude) of the patch centre.
        patch_size: Width and height of the patch in pixels.
        scale: Multiplied by ``patch_size`` to size the region buffer.
        max_retries: How many times to retry after an HTTP 429 response
            before giving up. ``0`` reproduces the previous single-attempt
            behaviour.
        timeout: Timeout in seconds passed to ``requests.get`` so a stalled
            connection cannot hang the notebook indefinitely.

    Returns:
        The array loaded from the NPY payload of the response.

    Raises:
        Exception: "Too Many Requests" if the server still answers 429 after
            all retries.
        requests.HTTPError: For any other unsuccessful HTTP status.
    """
    import time

    point = ee.Geometry.Point(lonlat)
    url = image.getDownloadURL(
        {
            "region": point.buffer(scale * patch_size / 2, 1).bounds(1),
            "dimensions": [patch_size, patch_size],
            # "scale": SCALE,
            "format": "NPY",
        }
    )

    # Retry on "Too Many Requests" errors, backing off exponentially (capped).
    attempts = max(0, max_retries) + 1
    delay = 1.0
    for attempt in range(attempts):
        response = requests.get(url, timeout=timeout)
        if response.status_code != 429:
            break
        if attempt == attempts - 1:
            raise Exception("Too Many Requests")
        time.sleep(delay)
        delay = min(delay * 2, 60.0)

    # Raise other exceptions
    response.raise_for_status()
    # SECURITY NOTE(review): allow_pickle=True lets np.load unpickle objects
    # from the downloaded bytes. Kept for compatibility; confirm whether the
    # NPY payload actually needs it and switch to allow_pickle=False if not.
    return np.load(io.BytesIO(response.content), allow_pickle=True)

In [17]:
def get_inputs_patch(
    date: datetime, lonlat: Tuple[float, float], patch_size: int
) -> np.ndarray:
    """Download the model-input patch for a point and date.

    The patch is fetched from ``get_inputs_image(date)`` using the
    module-level ``SCALE`` and converted from a structured array to a plain
    (unstructured) array.
    """
    structured_patch = get_patch(get_inputs_image(date), lonlat, patch_size, SCALE)
    return structured_to_unstructured(structured_patch)


def get_labels_landcover_patch(
    date: datetime, lonlat: Tuple[float, float], patch_size: int
) -> np.ndarray:
    """Download the land-cover label patch for a point.

    Labels are taken from the land-cover image that starts 24 weeks after
    ``date``; the patch is fetched with the module-level ``SCALE`` and
    returned as an unstructured array.
    """
    label_date = date + relativedelta(weeks=24)
    structured_patch = get_patch(get_land_cover(label_date), lonlat, patch_size, SCALE)
    return structured_to_unstructured(structured_patch)


def get_labels_lst_patch(
    date: datetime, lonlat: Tuple[float, float], patch_size: int
) -> np.ndarray:
    """Download the LST label patch for a point.

    Labels are taken from the LST image that starts 24 weeks after ``date``;
    the patch is fetched with the module-level ``SCALE`` and returned as an
    unstructured array.
    """
    label_date = date + relativedelta(weeks=24)
    structured_patch = get_patch(get_lst(label_date), lonlat, patch_size, SCALE)
    return structured_to_unstructured(structured_patch)

In [18]:
def render_landsat(image: ee.Image) -> geemap.Map:
    """Renders a Landsat image on a map centred on LONLAT.

    Args:
        image: Landsat image to display; it is added to the map as a layer
            named "Landsat" using LANDSAT_VIS_PARAMS.

    Returns:
        A map at zoom level 9 with scroll-wheel zoom and dragging disabled,
        the image layer added, and a layer control.
    """
    landsat_map = geemap.Map()
    landsat_map.set_center(LONLAT[0], LONLAT[1], 9)
    landsat_map.scroll_wheel_zoom = False
    landsat_map.dragging = False

    # Collection 2 Level-2 names its reflectance bands "SR_B*"; normalise any
    # bare "B*" names so the layer can be rendered either way.
    vis_params = dict(LANDSAT_VIS_PARAMS)
    vis_params["bands"] = [
        band if band.startswith("SR_") else f"SR_{band}"
        for band in vis_params["bands"]
    ]
    landsat_map.addLayer(image, vis_params, "Landsat")

    landsat_map.addLayerControl()
    return landsat_map

In [19]:
# Display the map for a Landsat image built with get_landsat_image's default date.
render_landsat(get_landsat_image())

Map(center=[44.2823, -73.9974], controls=(WidgetControl(options=['position', 'transparent_bg'], widget=SearchD…

In [22]:
# Compute the reference date once so the inputs and both label patches are
# derived from exactly the same timestamp.
reference_date = datetime.now() - relativedelta(years=2)

inputs = get_inputs_patch(reference_date, LONLAT, PATCH_SIZE)
labels_lst = get_labels_lst_patch(reference_date, LONLAT, PATCH_SIZE)
labels_lc = get_labels_landcover_patch(reference_date, LONLAT, PATCH_SIZE)

inputs.shape, labels_lst.shape, labels_lc.shape

((128, 128, 3), (128, 128, 1), (128, 128, 1))

In [28]:
import shutil
import os

ModuleNotFoundError: No module named 'zarr'

In [44]:
model_lc = CNN_LandCover()
model_lst = CNN_LST()

# Hyperparameters handed to ft.create_train_state.
config = ml_collections.ConfigDict()
config.learning_rate = 0.0002
config.batch_size = 32
config.num_epochs = 100
config.train_test_split = 0.9

rng = jax.random.key(0)
rng, init_rng = jax.random.split(rng)
# JAX PRNG keys should not be reused: derive one independent key per model
# instead of passing the same init_rng to both initialisations.
lc_init_rng, lst_init_rng = jax.random.split(init_rng)
init_lc_state = ft.create_train_state(lc_init_rng, config, "lc")
init_lst_state = ft.create_train_state(lst_init_rng, config, "lst")

In [45]:
# Land-cover checkpoint: directory, emptied path, and checkpointer.
# NOTE: erase_and_create_empty removes whatever the directory held before.
lc_ckpt_dir = "../inference/restore/lc"
lc_path = ocp.test_utils.erase_and_create_empty(lc_ckpt_dir)

# Land-surface-temperature checkpoint: same setup.
lst_ckpt_dir = "../inference/restore/lst"
lst_path = ocp.test_utils.erase_and_create_empty(lst_ckpt_dir)

lc_ckptr = ocp.StandardCheckpointer()
lst_ckptr = ocp.StandardCheckpointer()

In [47]:
# The target directory was already created by erase_and_create_empty, and the
# checkpointer refuses to write into an existing directory unless force=True.
lc_ckptr.save(
    os.path.abspath(lc_ckpt_dir),
    args=ocp.args.StandardSave(init_lc_state),
    force=True,
)



In [48]:
# The target directory was already created by erase_and_create_empty, and the
# checkpointer refuses to write into an existing directory unless force=True.
lst_ckptr.save(
    os.path.abspath(lst_ckpt_dir),
    args=ocp.args.StandardSave(init_lst_state),
    force=True,
)

In [40]:
# Copy the contents of ../inference/models/flax/lst into ../inference/restore/lst,
# merging into the destination if it already exists (dirs_exist_ok=True).
# NOTE(review): this cell's execution count (40) is lower than its neighbours',
# so it was last run out of order — confirm where it belongs relative to the
# save and restore cells.
shutil.copytree(os.path.abspath("../inference/models/flax/lst/"), os.path.abspath("../inference/restore/lst/"), dirs_exist_ok=True)

'/Users/aakashshankar/Repos/big_data_project/inference/restore/lst'

In [50]:
# Restore the land-cover checkpoint stored under ../inference/restore/lc.
state_lc = lc_ckptr.restore(os.path.abspath("../inference/restore/lc"))



In [52]:
# Restore the LST checkpoint stored under ../inference/restore/lst.
state_lst = lst_ckptr.restore(os.path.abspath("../inference/restore/lst"))



In [53]:
# Attach the restored land-cover state to the model instance as attribute `param`.
# NOTE(review): whether CNN_LandCover reads a `param` attribute is not visible
# in this notebook — confirm against flaxtrainer.
model_lc.param = state_lc