In [None]:
import ee
import random
import math
import json
import os

# Initialise the Earth Engine client once per kernel; the ee.* objects built in
# later cells (collections, geometries, export tasks) are created after this call.
ee.Initialize()

In [None]:
# 1. Parameters
tile_size = 256  # px
tile_scale = 100  # m/px

crs = "EPSG:3857"
tile_width = tile_scale * tile_size  # tile edge length: px * m/px
time_range = ("2015-01-01", "2025-01-01")
cloud_threshold = 10  # %
water_max_ratio = 0.9
n_tiles = 9000  # NOTE: the export-loop cell reassigns n_tiles before using it

# Metadata JSON files double as the progress record for resuming a run.
json_dir = "training-data/satellite-raw"
os.makedirs(json_dir, exist_ok=True)

json_files = sorted(f for f in os.listdir(json_dir) if f.endswith(".json"))
print(len(json_files), "existing JSON files in", json_dir)
last_json = json_files[-1] if json_files else None

# Resume after the highest tile index found on disk. File names look like
# "tile_<index>_met.json"; names that do not follow that pattern are ignored
# instead of aborting the cell.
tile_indices = []
for name in json_files:
    try:
        tile_indices.append(int(name.split("_")[1]))
    except (IndexError, ValueError):
        continue

if tile_indices:
    last_index = max(tile_indices)
    i_tile = last_index + 1
else:
    i_tile = 0
print("Starting with tile index", i_tile)

In [None]:
def random_lat_uniform_area(min_lat_deg, max_lat_deg):
    """Draw a random latitude (degrees) between the two given bounds.

    The draw is uniform in sin(latitude) rather than in latitude itself, and
    the sampled sine is converted back to degrees with asin.
    """
    sin_low, sin_high = (
        math.sin(math.radians(bound)) for bound in (min_lat_deg, max_lat_deg)
    )
    return math.degrees(math.asin(random.uniform(sin_low, sin_high)))


# Feature collection used by export_tile to reject sampled points that do not
# intersect any of its features (the "on land" test).
land = ee.FeatureCollection("USDOS/LSIB_SIMPLE/2017")
# Metadata dicts returned by successful export_tile calls, appended by the export loop.
valid_tiles = []


def _start_drive_export(image, name, region):
    """Create and start one Earth Engine image export to Google Drive.

    `name` is used as both the task description and the output file prefix.
    The export uses the notebook-level `tile_scale` and `crs` settings.
    """
    ee.batch.Export.image.toDrive(
        image=image,
        description=name,
        folder="earthengine",
        fileNamePrefix=name,
        region=region,
        scale=tile_scale,
        crs=crs,
        maxPixels=1e9,
    ).start()


def export_tile(i_tile):
    """Pick a random land location and launch the exports for tile `i_tile`.

    Steps:
      1. Sample a random point (longitude uniform, latitude area-uniform in
         [-60, 80]) and reject it if it does not intersect `land`.
      2. Take the earliest Sentinel-2 SR image covering the point within
         `time_range` whose CLOUDY_PIXEL_PERCENTAGE is below `cloud_threshold`.
      3. Centre the tile on that image's footprint centroid (not on the
         sampled point) and start four Drive exports: rgb, cld, dem, lct.
      4. Write `tile_<index>_met.json` into `json_dir`.

    Args:
        i_tile: Integer tile index; zero-padded to five digits in file names.

    Returns:
        The metadata dict that was written to disk, or None when the sampled
        point is not on land or no image matches the filters. The caller is
        expected to call again with the same index in that case.

    Raises:
        Errors from the Earth Engine client (requests, task start) and from
        writing the JSON file propagate to the caller.
    """
    lon = random.uniform(-180, 180)
    lat = random_lat_uniform_area(-60, 80)
    sample_point = ee.Geometry.Point([lon, lat])

    # Only keep points on land
    if land.filterBounds(sample_point).size().getInfo() == 0:
        return None

    candidates = (
        ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
        .filterDate(time_range[0], time_range[1])
        .filterBounds(sample_point)
        .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", cloud_threshold))
    )

    # An empty collection would make `.first()` a null image, which only fails
    # later, after export tasks have been queued. Reject the point up front.
    if candidates.limit(1).size().getInfo() == 0:
        return None

    s2 = candidates.sort("system:time_start").first()

    point = s2.geometry().centroid()
    region = point.buffer(tile_width * 1.02 / 2).bounds()  # 1% buffer on each side

    # Fetch all auxiliary metadata in one round trip, before any task is
    # started, so a failure here cannot leave orphaned exports for this index.
    info = ee.Dictionary(
        {
            "solar_zenith": s2.get("MEAN_SOLAR_ZENITH_ANGLE"),
            "solar_azimuth": s2.get("MEAN_SOLAR_AZIMUTH_ANGLE"),
            "month": ee.Date(s2.get("system:time_start")).get("month"),
            "centre": point.coordinates(),
        }
    ).getInfo()
    centre_lon, centre_lat = info["centre"]

    prefix = f"tile_{i_tile:05d}"

    # NOTE(review): if one of the .start() calls below raises after earlier
    # ones succeeded, the caller's retry re-queues the same tile index for a
    # different location, so Drive may receive same-named files - TODO confirm.

    # Export RGB: clamp band values at 3000, then rescale to 0-255 bytes.
    rgb = (
        s2.select(["B4", "B3", "B2"])
        .clip(region)
        .reproject(crs=crs, scale=tile_scale)
        .min(3000)
        .divide(3000)
        .multiply(255)
        .uint8()
    )
    _start_drive_export(rgb, f"{prefix}_rgb", region)

    # Cloud mask: a pixel is flagged when QA60 bit 10 (opaque) or bit 11 (cirrus) is set.
    qa = s2.select("QA60")
    opaque = qa.bitwiseAnd(1 << 10).gt(0)
    cirrus = qa.bitwiseAnd(1 << 11).gt(0)
    cld = (
        opaque.Or(cirrus)
        .rename("cloud_mask")
        .clip(region)
        .reproject(crs=crs, scale=tile_scale)
    )
    _start_drive_export(cld, f"{prefix}_cld", region)

    # Export DEM
    dem = (
        ee.ImageCollection("COPERNICUS/DEM/GLO30")
        .mosaic()
        .select("DEM")
        .clip(region)
        .reproject(crs=crs, scale=tile_scale)
        .int16()
    )
    _start_drive_export(dem, f"{prefix}_dem", region)

    # Export Landcover, remapping the source class codes to compact ids
    # (codes 90 and 95 both map to 7).
    esa_original = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 100]
    esa_remapped = [4, 5, 6, 8, 9, 3, 2, 1, 7, 7, 10]
    landcover = (
        ee.Image("ESA/WorldCover/v100/2020")
        .clip(region)
        .reproject(crs=crs, scale=tile_scale)
        .toUint8()
        .remap(esa_original, esa_remapped)
    )
    _start_drive_export(landcover, f"{prefix}_lct", region)

    json_data = {
        "solar_zenith": round(info["solar_zenith"], 2),
        "solar_azimuth": round(info["solar_azimuth"], 2),
        "month": info["month"],
        "tile_size": tile_size,
        "tile_scale": tile_scale,
        "centre_lon": round(centre_lon, 4),
        "centre_lat": round(centre_lat, 4),
        "tiles": {
            "rgb": f"{prefix}_rgb.tif",
            "dem": f"{prefix}_dem.tif",
            "lct": f"{prefix}_lct.tif",
            "cld": f"{prefix}_cld.tif",
        },
    }

    # Written only after all four tasks started; the file also marks this
    # index as done for the resume logic in the parameters cell.
    with open(f"{json_dir}/{prefix}_met.json", "w") as f:
        json.dump(json_data, f, indent=2)
    return json_data

In [None]:
import time

n_tiles = 15000

# Keep launching tiles until the target index is reached. The index only
# advances when export_tile returns metadata; a None result (rejected sample)
# simply triggers another attempt for the same index.
while i_tile < n_tiles:
    try:
        json_data = export_tile(i_tile)
        if json_data is None:
            continue
        valid_tiles.append(json_data)
        print(f"Tile {i_tile} processed.")
        i_tile += 1
    except Exception as e:
        # Any error: report it, pause for a minute, then try this index again.
        print(f"Tile {i_tile} failed with error: {e}, retrying in 1 minute...")
        time.sleep(60)

print(f"Export tasks launched for {i_tile} tiles.")

In [None]:
# `glob` is imported here because this cell runs before the notebook's later
# import cell on a fresh kernel.
# NOTE(review): `process_from_json` is defined in a later cell; that cell must
# have been executed before this one. This cell also reads rasters from
# ".../satellite-raw/raw", whereas the final processing cell reads from
# ".../satellite-raw/" - confirm which location is current.
import glob

# Sorted so tiles are processed in a deterministic (index) order.
json_files = sorted(glob.glob("training-data/satellite-raw/tile_*_met.json"))

print(f"Found {len(json_files)} JSON files to process")
for json_path in json_files:
    process_from_json(json_path, "training-data/satellite-raw/raw", "training-data/satellite")

In [None]:
import json


def refactor_metadata(json_path, save=False):
    """Load a tile metadata JSON and move flat tile keys under "tiles".

    If the metadata has no "tiles" entry, any of the flat keys `tile_rgb`,
    `tile_dem`, `tile_lct`, `tile_cld` that are present are moved into a
    new `meta["tiles"]` dict under the short names `rgb`, `dem`, `lct`, `cld`.
    Metadata that already has "tiles" is returned unchanged.

    Args:
        json_path: Path of the metadata JSON file.
        save: When true, the (possibly updated) metadata is written back to
            `json_path`.

    Returns:
        The metadata dict.
    """
    legacy_keys = ("tile_rgb", "tile_dem", "tile_lct", "tile_cld")

    with open(json_path, "r") as f:
        meta = json.load(f)

    # Check if tile keys are already reformatted
    if "tiles" not in meta:
        # pop() moves each legacy entry out of the top level in one step.
        meta["tiles"] = {
            legacy.replace("tile_", ""): meta.pop(legacy)
            for legacy in legacy_keys
            if legacy in meta
        }

    if save:
        with open(json_path, "w") as f:
            json.dump(meta, f, indent=2)

    return meta

In [1]:
import glob
import rasterio
import json
import os
import numpy as np


def _tile_rejection_reason(
    key, path, water_max_ratio, max_black_ratio, max_white_ratio
):
    """Return why the raster at `path` should be rejected, or None to keep it.

    Only the "dem" and "rgb" rasters are inspected; other keys always pass.
    """
    if key == "dem":
        with rasterio.open(path, "r") as src:
            dem = src.read(1)
        dem[np.isnan(dem)] = 0
        # Pixels with elevation exactly 0 are counted as water.
        if np.sum(dem == 0) / dem.size > water_max_ratio:
            return "too much water"
    elif key == "rgb":
        with rasterio.open(path, "r") as src:
            rgb = src.read()
        # Per-pixel mean over the bands as a cheap brightness estimate.
        rgb_brightness_approx = np.mean(rgb, axis=0)
        n_pixels = rgb_brightness_approx.size
        if np.sum(rgb_brightness_approx < 5) / n_pixels > max_black_ratio:
            return "too black"
        if np.sum(rgb_brightness_approx > 250) / n_pixels > max_white_ratio:
            return "too white"
    return None


def process_from_json(
    json_path,
    input_dir,
    output_dir,
    water_max_ratio=0.9,
    max_black_ratio=0.3,
    max_white_ratio=0.9,
):
    """Validate one exported tile and copy its cropped rasters to `output_dir`.

    The metadata JSON at `json_path` lists the tile's raster file names under
    "tiles". Every listed raster must exist in `input_dir` and pass the
    quality checks before anything is written. Each raster is then cropped
    with `truncate_tile` to `tile_size` x `tile_size` (taken from the metadata,
    default 256) and the metadata file is copied alongside.

    Args:
        json_path: Path of the tile's metadata JSON.
        input_dir: Directory containing the raw rasters.
        output_dir: Directory receiving cropped rasters and the metadata copy;
            created if missing.
        water_max_ratio: Reject when more than this fraction of DEM pixels is 0.
        max_black_ratio: Reject when more than this fraction of RGB pixels has
            mean brightness below 5.
        max_white_ratio: Reject when more than this fraction of RGB pixels has
            mean brightness above 250.

    Returns:
        True if the tile was written, False if it was skipped (no rasters
        listed, a raster missing, or a quality check failed).
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output folder {output_dir}")

    with open(json_path, "r") as f:
        meta = json.load(f)

    tile_size = meta.get("tile_size", 256)

    # Metadata without a "tiles" mapping has nothing to process. Reporting it
    # as a success would copy the JSON with no rasters next to it.
    tiles = meta.get("tiles") or {}
    if not tiles:
        print(f"No tiles listed in {json_path}, skipping tile")
        return False

    # First pass: validate everything, so a rejected tile leaves no partial output.
    crop_jobs = []
    for key, input_file in tiles.items():
        input_path = os.path.join(input_dir, input_file)

        if not os.path.exists(input_path):
            print(f"File {input_path} not found, skipping tile")
            return False

        reason = _tile_rejection_reason(
            key, input_path, water_max_ratio, max_black_ratio, max_white_ratio
        )
        if reason is not None:
            print(f"Skipping tile {input_file} due to {reason}")
            return False

        crop_jobs.append((input_path, os.path.join(output_dir, input_file)))

    # Second pass: write the cropped rasters, then the metadata copy.
    for input_path, output_path in crop_jobs:
        truncate_tile(input_path, size=[tile_size, tile_size], output_path=output_path)

    json_output = os.path.join(output_dir, os.path.basename(json_path))
    with open(json_output, "w") as f:
        json.dump(meta, f, indent=2)

    return True


def truncate_tile(input_path, size=(256, 256), output_path=None):
    """Crop a raster to `size` from its top-left corner and write it as GeoTIFF.

    Args:
        input_path: Path of the source raster.
        size: (height, width) of the output in pixels.
        output_path: Destination path; defaults to overwriting `input_path`.

    Raises:
        ValueError: If `input_path` does not exist, or if the source raster is
            smaller than `size` in either dimension. Both are checked before
            the output is opened, so an in-place call cannot destroy the input.
    """
    if output_path is None:
        output_path = input_path

    # check input file exists
    if not os.path.exists(input_path):
        raise ValueError(f"File {input_path} not found")

    height, width = size[0], size[1]

    # Read the pixels and everything needed from the dataset while it is open.
    with rasterio.open(input_path, "r") as input_img:
        img = input_img.read()
        band_count = input_img.count
        src_crs = input_img.crs
        src_transform = input_img.transform

    if img.shape[1] < height or img.shape[2] < width:
        raise ValueError(
            f"Error cropping image {input_path} to {height}x{width}: "
            f"source is only {img.shape[1]}x{img.shape[2]}"
        )

    # Cropping from the top-left keeps the raster origin, so the source
    # transform is still valid for the output.
    img_cropped = img[:, :height, :width]

    # check output folder exists (a bare file name has no directory part)
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output folder {output_dir}")

    with rasterio.open(
        output_path,
        "w",
        driver="GTiff",
        height=height,
        width=width,
        count=band_count,
        dtype=img.dtype,
        crs=src_crs,
        transform=src_transform,
    ) as output_img:
        output_img.write(img_cropped)

In [2]:
# Run the quality filter and crop over every tile that has a metadata file.
json_files = sorted(glob.glob("training-data/satellite-raw/tile_*_met.json"))

print(f"Found {len(json_files)} JSON files to process")

success_count = 0

for json_path in json_files:
    is_success = process_from_json(
        json_path, "training-data/satellite-raw/", "training-data/satellite"
    )
    if is_success:
        success_count += 1

# Guard the percentage: with no metadata files the ratio is undefined.
if json_files:
    print(
        f"Processed {success_count} ({success_count/len(json_files)*100}%) tiles successfully."
    )
else:
    print("No metadata files found, nothing to process.")

Found 13563 JSON files to process
Skipping tile tile_00005_rgb.tif due to too white
Skipping tile tile_00018_rgb.tif due to too white
Skipping tile tile_00019_rgb.tif due to too white
Skipping tile tile_00024_rgb.tif due to too white
Skipping tile tile_00025_rgb.tif due to too white
Skipping tile tile_00039_rgb.tif due to too white
Skipping tile tile_00042_rgb.tif due to too white
Skipping tile tile_00049_rgb.tif due to too white
Skipping tile tile_00060_rgb.tif due to too black
Skipping tile tile_00068_rgb.tif due to too white
Skipping tile tile_00076_dem.tif due to too much water
Skipping tile tile_00084_rgb.tif due to too white
Skipping tile tile_00086_rgb.tif due to too white
Skipping tile tile_00133_rgb.tif due to too white
Skipping tile tile_00137_rgb.tif due to too white
Skipping tile tile_00144_rgb.tif due to too white
Skipping tile tile_00147_rgb.tif due to too white
Skipping tile tile_00153_rgb.tif due to too white
Skipping tile tile_00157_dem.tif due to too much water
Skippi