# Download and Process Satellite Imagery from Google Earth Engine
## Download Data

In [None]:
import ee
import random
import math
import json
import os

# Authenticate with Google Earth Engine and initialise the client library
# before any ee.* objects are created in the cells below.
# NOTE(review): newer earthengine-api releases may require a Cloud project,
# e.g. ee.Initialize(project="..."); confirm for your environment.
ee.Authenticate()
ee.Initialize()

In [None]:
# 1. Parameters
tile_size = 256  # px
tile_scale = 200  # m/px

crs = "EPSG:3857"
tile_width = tile_scale * tile_size  # tile edge length in metres (before buffering)
time_range = ("2015-01-01", "2025-01-01")
cloud_threshold = 10  # %
water_max_ratio = 0.9

json_dir = "training-data/satellite"
os.makedirs(json_dir, exist_ok=True)

# Resume numbering after the last existing tile. Only per-tile metadata files
# ("tile_<index>_met.json") are considered, so other JSON files in the folder
# cannot break the index parsing; the 5-digit zero padding keeps the
# lexicographic sort in numeric order.
json_files = sorted(
    f for f in os.listdir(json_dir) if f.startswith("tile_") and f.endswith("_met.json")
)
print(len(json_files), "existing JSON files in", json_dir)
last_json = json_files[-1] if json_files else None
i_tile = int(last_json.split("_")[1]) + 1 if last_json else 0
print("Starting with tile index", i_tile)

land = ee.FeatureCollection("USDOS/LSIB_SIMPLE/2017")

# Sentinel-2 image IDs that have already been exported, so reruns never sample
# the same scene twice. The file may hold either one ID per line or a JSON list
# (as produced by json.dump). Splitting a JSON list into lines would yield
# strings like '  "ID",' that never match a real ID, so that format is detected
# and parsed as JSON instead.
used_ids_file = os.path.join(json_dir, "used_ids.txt")

used_ids = set()
if os.path.exists(used_ids_file):
    with open(used_ids_file, "r") as f:
        content = f.read()
    if content.lstrip().startswith("["):
        used_ids = set(json.loads(content))
    else:
        used_ids = {line.strip() for line in content.splitlines() if line.strip()}
else:
    # Create the file up front so it exists even before the first export
    with open(used_ids_file, "w"):
        pass

drive_folder = ""  # Drive folder for exports (folder= is currently commented out)

In [None]:
def random_lat_uniform_area(min_lat_deg, max_lat_deg):
    """Draw a latitude in degrees so that sampled points are uniform in area.

    On a sphere, the surface area between two latitudes is proportional to
    the difference of their sines. Sampling sin(latitude) uniformly and
    inverting therefore gives an area-uniform latitude in
    [min_lat_deg, max_lat_deg].
    """
    lo, hi = (math.sin(math.radians(deg)) for deg in (min_lat_deg, max_lat_deg))
    return math.degrees(math.asin(random.uniform(lo, hi)))


def _start_drive_export(image, i_tile_str, suffix, region):
    """Start a Google Drive export task for one layer of a tile.

    The task description and file name prefix are both
    ``tile_<i_tile_str>_<suffix>``. Projection and scale come from the
    module-level ``crs`` and ``tile_scale``.
    """
    name = f"tile_{i_tile_str}_{suffix}"
    ee.batch.Export.image.toDrive(
        image=image,
        description=name,
        # folder=drive_folder,
        fileNamePrefix=name,
        region=region,
        scale=tile_scale,
        crs=crs,
        maxPixels=1e9,
    ).start()


def export_tile(i_tile):
    """Sample a random land location and start Drive exports for one tile.

    The function draws a random, area-uniform point between 60°S and 80°N and
    gives up if it is not inside ``land``. Otherwise it picks at random a
    Sentinel-2 SR image that covers the point, has CLOUDY_PIXEL_PERCENTAGE
    below ``cloud_threshold``, falls within ``time_range`` and is not in
    ``used_ids``.

    A square region of ``tile_width`` metres plus a 1% margin per side,
    centred on that image's footprint, is exported as four layers:

    - RGB (``_rgb``)
    - QA60 cloud mask (``_cld``)
    - Copernicus GLO-30 DEM (``_dem``)
    - remapped ESA WorldCover classes (``_lct``)

    Solar angles, acquisition month and tile centre are written to
    ``<json_dir>/tile_<i_tile:05d>_met.json``.

    NOTE(review): ``scale`` is expressed in EPSG:3857 units, whose true ground
    size shrinks with cos(latitude), so ground pixel size is not constant
    across tiles; confirm this is intended.

    Args:
        i_tile: integer tile index used in all output file names.

    Returns:
        ``(json_data, img_id)`` when the export tasks were started, otherwise
        ``(None, None)``. The latter happens when the point is not on land,
        no unused matching image exists, or the image lacks solar-angle
        metadata.
    """
    lon = random.uniform(-180, 180)
    lat = random_lat_uniform_area(-60, 80)
    point = ee.Geometry.Point([lon, lat])

    # Only keep points on land
    if land.filterBounds(point).size().getInfo() == 0:
        return None, None

    s2_collection = (
        ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
        .filterDate(time_range[0], time_range[1])
        .filterBounds(point)
        .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", cloud_threshold))
    )
    # Only consider images that have not been exported before
    image_ids = [
        img_id
        for img_id in s2_collection.aggregate_array("system:index").getInfo()
        if img_id not in used_ids
    ]
    if not image_ids:
        return None, None
    img_id = random.choice(image_ids)
    s2 = ee.Image(f"COPERNICUS/S2_SR_HARMONIZED/{img_id}")

    # Centre the tile on the chosen image's footprint rather than the random point
    centre = s2.geometry().centroid()
    region = centre.buffer(tile_width * 1.02 / 2).bounds()  # 1% buffer on each side

    # Fetch all client-side metadata BEFORE starting any export task. If one
    # of these calls fails, the calling loop retries the same tile index.
    # Any tasks that had already started would then leave duplicate files
    # in Drive.
    solar_zenith = s2.get("MEAN_SOLAR_ZENITH_ANGLE").getInfo()
    solar_azimuth = s2.get("MEAN_SOLAR_AZIMUTH_ANGLE").getInfo()
    if solar_zenith is None or solar_azimuth is None:
        return None, None  # round(None) below would raise
    month = ee.Date(s2.get("system:time_start")).get("month").getInfo()
    centre_lon, centre_lat = centre.coordinates().getInfo()

    i_tile_str = f"{i_tile:05d}"

    # RGB: bands B4/B3/B2 capped at 3000 and linearly scaled to 0-255
    rgb = (
        s2.select(["B4", "B3", "B2"])
        .clip(region)
        .reproject(crs=crs, scale=tile_scale)
        .min(3000)
        .divide(3000)
        .multiply(255)
        .uint8()
    )
    _start_drive_export(rgb, i_tile_str, "rgb", region)

    # Cloud mask: QA60 bit 10 (opaque) or bit 11 (cirrus) set
    qa = s2.select("QA60")
    opaque = qa.bitwiseAnd(1 << 10).gt(0)
    cirrus = qa.bitwiseAnd(1 << 11).gt(0)
    cld = (
        opaque.Or(cirrus)
        .rename("cloud_mask")
        .clip(region)
        .reproject(crs=crs, scale=tile_scale)
    )
    _start_drive_export(cld, i_tile_str, "cld", region)

    # DEM (already clipped to the region, so no second clip is needed)
    dem = (
        ee.ImageCollection("COPERNICUS/DEM/GLO30")
        .mosaic()
        .select("DEM")
        .clip(region)
        .reproject(crs=crs, scale=tile_scale)
        .int16()
    )
    _start_drive_export(dem, i_tile_str, "dem", region)

    # Land cover: ESA WorldCover class codes remapped to values 1-10
    landcover = (
        ee.Image("ESA/WorldCover/v100/2020")
        .clip(region)
        .reproject(crs=crs, scale=tile_scale)
        .toUint8()
    )
    esa_original = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 100]
    esa_remapped = [4, 5, 6, 8, 9, 3, 2, 1, 7, 7, 10]
    landcover = landcover.remap(esa_original, esa_remapped)
    _start_drive_export(landcover, i_tile_str, "lct", region)

    json_data = {
        "solar_zenith": round(solar_zenith, 2),
        "solar_azimuth": round(solar_azimuth, 2),
        "month": month,
        "tile_size": tile_size,
        "tile_scale": tile_scale,
        "centre_lon": round(centre_lon, 4),
        "centre_lat": round(centre_lat, 4),
        "tiles": {
            "rgb": f"tile_{i_tile_str}_rgb.tif",
            "dem": f"tile_{i_tile_str}_dem.tif",
            "lct": f"tile_{i_tile_str}_lct.tif",
            "cld": f"tile_{i_tile_str}_cld.tif",
        },
    }

    with open(f"{json_dir}/tile_{i_tile_str}_met.json", "w") as f:
        json.dump(json_data, f, indent=2)
    return json_data, img_id

In [None]:
import time

# n_tiles = 18000
n_tiles = i_tile + 300  # export this many new tiles per run
retry_delay_s = 60 * 3  # back-off after a failed export attempt

while i_tile < n_tiles:
    try:
        json_data, img_id = export_tile(i_tile)
    except Exception as e:
        # Typically a transient Earth Engine error or quota limit: wait and retry
        print(
            f"Tile {i_tile} failed with error: {e}, "
            f"retrying in {retry_delay_s // 60} minutes..."
        )
        time.sleep(retry_delay_s)
        continue
    if json_data is None:
        continue  # sampled point unusable (sea / no unused image); draw again
    print(f"Tile {i_tile} processed.")
    i_tile += 1
    used_ids.add(img_id)
    # Persist one ID per line so the file can be read back with splitlines()
    with open(used_ids_file, "w") as f:
        f.writelines(f"{used_id}\n" for used_id in sorted(used_ids))

## Quality check and cleaning
### Refactoring metadata tile keys (no longer needed)

In [None]:
import json


def refactor_metadata(json_path, save=False):
    """Migrate legacy flat ``tile_<layer>`` keys into a nested ``tiles`` dict.

    Metadata that already has a ``tiles`` entry is returned unchanged.
    Otherwise the keys ``tile_rgb``, ``tile_dem``, ``tile_lct`` and
    ``tile_cld`` that are present are moved into ``meta["tiles"]`` under
    ``rgb``/``dem``/``lct``/``cld``. The file on disk is rewritten only when
    ``save`` is true.

    Returns:
        The (possibly migrated) metadata dict.
    """
    with open(json_path, "r") as f:
        meta = json.load(f)

    if "tiles" not in meta:
        legacy = [
            name
            for name in ("tile_rgb", "tile_dem", "tile_lct", "tile_cld")
            if name in meta
        ]
        # Popping while building the new dict removes the old keys in one pass
        meta["tiles"] = {name.replace("tile_", ""): meta.pop(name) for name in legacy}

    if save:
        with open(json_path, "w") as f:
            json.dump(meta, f, indent=2)

    return meta

### Purging duplicate images

In [None]:
import glob
import json
import os

input_dir = "training-data/satellite/"
json_files = sorted(glob.glob(os.path.join(input_dir, "tile_*_met.json")))

# Tiles sharing centre, sun geometry and scale are treated as duplicates (most
# likely the same source scene exported twice). The first occurrence in sorted
# order is kept; later ones are deleted together with their rasters.
seen_centers = {}
duplicate_count = 0

for json_path in json_files:
    # Read and close the file before deleting it (removing an open file fails on Windows)
    with open(json_path, "r") as f:
        meta = json.load(f)
    centre = (
        meta["centre_lon"],
        meta["centre_lat"],
        meta["solar_azimuth"],
        meta["solar_zenith"],
        meta["tile_scale"],
    )

    if centre in seen_centers:
        for tile_name in meta["tiles"].values():
            tile_path = os.path.join(input_dir, tile_name)
            # Incomplete tiles may lack some rasters; don't abort the whole purge
            if os.path.exists(tile_path):
                os.remove(tile_path)
        os.remove(json_path)
        duplicate_count += 1
    else:
        seen_centers[centre] = json_path

print(f"Purged {duplicate_count} duplicate centers.")

### Remove incomplete/invalid tiles

In [None]:
input_dir = "training-data/satellite/"
all_files = sorted(glob.glob(os.path.join(input_dir, "tile_*")))
img_files = sorted(glob.glob(os.path.join(input_dir, "*.tif")))

# Find files with parentheses in the name, e.g. "tile_00012_rgb (1).tif"
# (presumably produced when Drive receives a repeated export with the same
# name). Every file belonging to such a tile is removed.
img_files = [f for f in img_files if "(" in os.path.basename(f)]
bad_tile_ids = sorted({os.path.basename(f).split("_")[1] for f in img_files})

for tile_id in bad_tile_ids:
    # Match the "tile_<id>_" basename prefix so only this tile's files are hit
    prefix = f"tile_{tile_id}_"
    for invalid_file in all_files:
        if not os.path.basename(invalid_file).startswith(prefix):
            continue
        if not os.path.exists(invalid_file):
            continue
        os.remove(invalid_file)
        print(f"Removed {invalid_file}")

In [None]:
input_dir = "training-data/satellite/"
json_files = sorted(glob.glob(os.path.join(input_dir, "tile_*_met.json")))

# Report every metadata file whose referenced rasters are not all on disk.
for json_path in json_files:
    with open(json_path, "r") as f:
        tile_names = list(json.load(f)["tiles"].values())
    tile_paths = (os.path.join(input_dir, name) for name in tile_names)
    if any(not os.path.exists(path) for path in tile_paths):
        print(f"Missing tiles for {json_path}")

### Checking and cleaning data

In [None]:
import glob
import rasterio
import json
import os
import numpy as np


def process_from_json(
    json_path,
    input_dir,
    output_dir,
    tile_size=256,
    water_max_ratio=0.9,
    max_black_ratio=0.2,
    max_white_ratio=0.7,
    rotate=False,
    be_quiet=False,
):
    """Validate one exported tile and write cropped (optionally rotated) copies.

    Reads the metadata JSON at ``json_path``. The tile is skipped when:

    - its ``tile_size`` differs from the requested one;
    - any referenced raster is missing from ``input_dir``;
    - more than ``water_max_ratio`` of DEM pixels are 0 (treated as water);
    - more than ``max_black_ratio`` of RGB pixels have mean brightness < 5;
    - more than ``max_white_ratio`` of RGB pixels have mean brightness > 250.

    Accepted tiles are cropped to ``tile_size`` x ``tile_size`` via
    ``truncate_tile`` and written to ``output_dir`` as
    ``<base>_r<rotation>_<key>.tif``. A matching
    ``<base>_r<rotation>_met.json`` records the rotation as ``north_dir``, a
    fraction of a full turn. With ``rotate`` the tile is written at 0, 90,
    180 and 270 degrees.

    Returns:
        True when the tile was written, False when it was skipped.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        if not be_quiet:
            print(f"Created output folder {output_dir}")

    tile_base_name = os.path.basename(json_path).replace("_met.json", "")

    with open(json_path, "r") as f:
        meta = json.load(f)
    # Compare by value: `is` on ints only works by accident for CPython's
    # small-int cache (<= 256) and rejected every tile of a larger size.
    if meta.get("tile_size") != tile_size:
        return False

    tiles = meta.get("tiles", {})
    sources = []  # (layer key, input path) pairs, in metadata order

    for key, input_file in tiles.items():
        input_path = os.path.join(input_dir, input_file)

        if not os.path.exists(input_path):
            if not be_quiet:
                print(f"File {input_path} not found, skipping tile")
            return False

        if key == "dem":
            with rasterio.open(input_path, "r") as src:
                dem = src.read(1)
                dem[np.isnan(dem)] = 0
            # skip if too large a share of pixels are water (elevation 0)
            if np.mean(dem == 0) > water_max_ratio:
                if not be_quiet:
                    print(f"Skipping tile {input_file} due to too much water")
                return False

        if key == "rgb":
            with rasterio.open(input_path, "r") as src:
                rgb = src.read()
                rgb_brightness_approx = np.mean(rgb, axis=0)

            # skip if too black
            if np.mean(rgb_brightness_approx < 5) > max_black_ratio:
                if not be_quiet:
                    print(f"Skipping tile {input_file} due to too black")
                return False

            # skip if too white
            if np.mean(rgb_brightness_approx > 250) > max_white_ratio:
                if not be_quiet:
                    print(f"Skipping tile {input_file} due to too white")
                return False

        sources.append((key, input_path))

    rotate_range = [0, 90, 180, 270] if rotate else [0]
    for rotation in rotate_range:
        meta["north_dir"] = rotation / 360  # in normalised azimuth
        tile_rot_base_name = tile_base_name + f"_r{rotation:03d}"
        for key, input_path in sources:
            output_name = tile_rot_base_name + f"_{key}"
            truncate_tile(
                input_path,
                size=[tile_size, tile_size],
                output_path=os.path.join(output_dir, output_name + ".tif"),
                rotation=rotation,
            )
            tiles[key] = output_name + ".tif"

        json_output = os.path.join(output_dir, tile_rot_base_name + "_met.json")
        with open(json_output, "w") as f:
            json.dump(meta, f, indent=2)

    return True


def truncate_tile(input_path, size=(256, 256), output_path=None, rotation=0):
    """Crop a raster from its top-left corner, optionally rotate it, and save it.

    Args:
        input_path: path of a raster readable by rasterio.
        size: ``(height, width)`` of the crop in pixels.
        output_path: destination GeoTIFF; defaults to overwriting ``input_path``.
        rotation: clockwise rotation in degrees (a multiple of 90).

    Raises:
        ValueError: if ``input_path`` does not exist or is smaller than ``size``.

    The CRS and transform are copied unchanged from the input, so for rotated
    outputs the georeferencing no longer matches the pixel grid.
    """
    if output_path is None:
        output_path = input_path

    # Create the output folder if needed ("" means the current directory)
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output folder {output_dir}")
    if not os.path.exists(input_path):
        raise ValueError(f"File {input_path} not found")

    height, width = size
    with rasterio.open(input_path, "r") as input_img:
        img = input_img.read()  # (bands, rows, cols)
        # Capture the georeferencing while the dataset is still open
        band_count = input_img.count
        src_crs = input_img.crs
        src_transform = input_img.transform

    if img.shape[1] < height or img.shape[2] < width:
        raise ValueError(
            f"Error cropping image {input_path} to {height}x{width}: "
            f"image is only {img.shape[1]}x{img.shape[2]}"
        )
    # Truncate from the top-left corner
    img_cropped = img[:, :height, :width]

    if rotation != 0:
        # Negative k rotates clockwise in the (rows, cols) plane
        img_cropped = np.rot90(img_cropped, k=-rotation // 90, axes=(1, 2))

    # Take dimensions from the final array: a 90/270 rotation swaps them
    with rasterio.open(
        output_path,
        "w",
        driver="GTiff",
        height=img_cropped.shape[1],
        width=img_cropped.shape[2],
        count=band_count,
        dtype=img.dtype,
        crs=src_crs,
        transform=src_transform,
    ) as output_img:
        output_img.write(img_cropped)

In [None]:
from tqdm import tqdm

json_files = sorted(glob.glob("training-data/satellite/tile_*_met.json"))
# json_files = json_files[:10]  # for testing

print(f"Found {len(json_files)} JSON files to process")

tile_size = 256
rotate = False
output_dir = f"training-data/satellite-T{tile_size:04d}-R{int(rotate)}"

success_count = 0

for json_path in tqdm(json_files, desc="Processing tiles"):
    is_success = process_from_json(
        json_path,
        "training-data/satellite/",
        output_dir,
        tile_size=tile_size,
        rotate=rotate,
        water_max_ratio=0.9,
        max_black_ratio=0.1,
        max_white_ratio=0.6,
        be_quiet=True,
    )
    if is_success:
        success_count += 1

# Guard against an empty input folder, which would otherwise divide by zero
success_pct = 100 * success_count / len(json_files) if json_files else 0.0
print(f"Processed {success_count} ({success_pct:.1f}%) tiles successfully.")