In [None]:
import os
import math
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import geopandas as gpd
import torch
import rasterio
import ee

from shapely.geometry import Point
from rasterio.warp import transform

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Fill missing data of WorldClim, SoilGrids, and human footprint with nearest non-missing value
def find_nearest_non_missing(data, row, col, no_data_value, max_radius=100):
    """Return the value of the non-missing cell closest to ``(row, col)``.

    A cell counts as missing when ``np.isclose(cell, no_data_value, atol=0)``
    holds. Closeness is measured with the Euclidean distance between cell
    indices, so a directly adjacent cell always wins over a diagonal one.

    Parameters
    ----------
    data : 2-D array
        Raster band to search in.
    row, col : int
        Index of the query cell. It may lie outside ``data``; out-of-bounds
        cells are simply skipped.
    no_data_value : number
        Marker used by the raster for missing cells.
    max_radius : int, default 100
        Largest square (Chebyshev) radius, in cells, that is searched around
        the query cell.

    Returns
    -------
    Python scalar or None
        ``data[r, c].item()`` of the nearest valid cell, or ``None`` when no
        valid cell exists within ``max_radius``. If the query cell itself is
        valid its own value is returned.
    """
    rows, cols = data.shape

    def is_valid(r, c):
        # In-bounds and not equal (within relative tolerance) to the no-data marker.
        return 0 <= r < rows and 0 <= c < cols and not np.isclose(data[r, c], no_data_value, atol=0)

    # Distance zero: nothing can be nearer than the query cell itself.
    if is_valid(row, col):
        return data[row, col].item()

    best_value, best_dist2 = None, None
    limit = max_radius
    radius = 1
    while radius <= limit:
        # Only the perimeter of the square is new at this radius; the interior
        # was already examined at smaller radii.
        ring = [(dy, dx) for dy in (-radius, radius) for dx in range(-radius, radius + 1)]
        ring += [(dy, dx) for dy in range(-radius + 1, radius) for dx in (-radius, radius)]
        for dy, dx in ring:
            if not is_valid(row + dy, col + dx):
                continue
            dist2 = dy * dy + dx * dx
            if best_dist2 is None or dist2 < best_dist2:
                best_value, best_dist2 = data[row + dy, col + dx], dist2
                # A strictly closer cell must lie within a square radius of
                # floor(sqrt(dist2)), so no ring beyond that needs scanning.
                limit = min(limit, int(math.sqrt(dist2)))
        radius += 1

    return None if best_value is None else best_value.item()

## Location

In [None]:
# Area of interest: bounding box limits and grid spacing used to build the point grid.

# Europe
xmin = -10
xmax = 31
ymin = 36
ymax = 56
# NOTE(review): 10 / 111 looks like ~10 km expressed in degrees (about 111 km per degree) - confirm.
resolution = 10 / 111

# World (uncomment to switch region)
#xmin, xmax = -180, 180
#ymin, ymax = -60, 85
#resolution = 50 / 111

In [None]:
# Natural Earth 1:110m country polygons, read straight from the remote zip
# (requires network access). `world` is used by generate_grid as the land mask.
url = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
world = gpd.read_file(url)

def generate_grid(xmin, xmax, ymin, ymax, resolution):
    """Return the coordinates of regular-grid points that fall on land.

    A grid with spacing ``resolution`` is laid over the half-open ranges
    ``[xmin, xmax)`` and ``[ymin, ymax)``. Only points lying within the union
    of all polygons of the module-level ``world`` GeoDataFrame are kept.

    Parameters
    ----------
    xmin, xmax, ymin, ymax : float
        Bounding box limits, in the same coordinate system as ``world``.
    resolution : float
        Grid spacing along both axes.

    Returns
    -------
    list of (x, y) tuples
        Coordinates of the retained points, ordered by x first, then y.
    """
    # Generate a grid of points with corresponding resolution
    x_coords = np.arange(xmin, xmax, resolution)
    y_coords = np.arange(ymin, ymax, resolution)
    grid_points = [Point(x, y) for x in x_coords for y in y_coords]

    grid_gdf = gpd.GeoDataFrame(geometry=grid_points)

    # Dissolve all country polygons into one land geometry. `unary_union` is
    # deprecated in recent GeoPandas in favour of `union_all()`; fall back to
    # the old attribute on versions that do not provide the method yet.
    countries = world.geometry
    if hasattr(countries, "union_all"):
        land = countries.union_all()
    else:
        land = countries.unary_union

    # Filter grid points that are on land
    land_gdf = grid_gdf[grid_gdf.within(land)]

    # Vectorised coordinate extraction instead of a per-point Python lambda.
    return list(zip(land_gdf.geometry.x, land_gdf.geometry.y))

In [None]:
# Build the land-point grid for the configured bounding box, persist it as a
# .npy file and report how many points were kept.
land_points_coordinates = generate_grid(xmin, xmax, ymin, ymax, resolution)
np.save("europe_map_coordinates.npy", land_points_coordinates)
# World variant of the output file:
# np.save("world_map_coordinates.npy", land_points_coordinates)
print("Total number of land points:", len(land_points_coordinates))

In [None]:
# Reload the saved grid coordinates; `locations` is the array of (x, y) rows
# that every sampling cell below iterates over.
locations = np.load("europe_map_coordinates.npy")
# World variant of the input file:
# locations = np.load("world_map_coordinates.npy")

## WorldClim

In [None]:
# Sample every WorldClim bioclimatic raster (bio_1 ... bio_19) at each location.
# Result: one row per location, one column per variable.
worldclim_variables = ['bio_' + str(i+1) for i in range(19)]
worldclim_data = np.zeros((len(locations), len(worldclim_variables)), dtype="float32")

# Value treated as "missing" in these rasters (compared with a relative tolerance).
no_data_value = -3.4e+38

for j, wv in enumerate(worldclim_variables):
    print(f"Processing {wv}")
    with rasterio.open(f"worldclim/wc2.1_30s_{wv}.tif") as src:

        # Full first band, needed for the neighbourhood search on missing cells.
        data = src.read(1)
        for i, sample in enumerate(src.sample(locations)):
            # `sample` holds one value per band; take the first band explicitly
            # (matching `src.read(1)`) instead of assigning a 1-element array
            # into a scalar slot, which NumPy has deprecated.
            val = sample[0]
            if np.isclose(val, no_data_value, atol=0):
                x, y = locations[i]
                row, col = src.index(x, y)
                val = find_nearest_non_missing(data, row, col, no_data_value)
                if val is None:
                    # No valid neighbour within the search radius: store NaN
                    # explicitly rather than relying on implicit None coercion.
                    val = np.nan
            worldclim_data[i, j] = val

In [None]:
# Persist the sampled WorldClim matrix (rows follow `locations`, columns follow `worldclim_variables`).
np.save("europe_map_worldclim.npy", worldclim_data)
# World variant of the output file:
# np.save("world_map_worldclim.npy", worldclim_data)

## SoilGrids

In [None]:
# Sample every raster found in soilgrids250/ at each location.
# Files are processed in sorted name order so the column order of
# `soilgrid_data` (recorded in `soil_variables`) is reproducible;
# a bare os.listdir() returns entries in arbitrary order.
soil_files = sorted(os.listdir("soilgrids250"))
soilgrid_data = np.zeros((len(locations), len(soil_files)))
soil_variables = []

for j, soil_file in enumerate(soil_files):
    # The variable code is the first six characters of the file name.
    soil_variable = soil_file[:6]
    soil_variables.append(soil_variable)
    print(f"Processing {soil_variable}")
    with rasterio.open(f"soilgrids250/{soil_file}") as src:
        # The missing-value marker depends on the variable.
        if soil_variable in ["ORCDRC", "CECSOL", "BDTICM", "BLDFIE"]:
            no_data_value = -32768.0
        elif soil_variable in ["PHIHOX", "CLYPPT", "SLTPPT", "SNDPPT"]:
            no_data_value = 255
        else:
            raise ValueError(f"Unknown missing value for {soil_variable}")
        # Full first band, needed for the neighbourhood search on missing cells.
        data = src.read(1)
        for i, sample in enumerate(src.sample(locations)):
            # Take the first band explicitly (matching `src.read(1)`) rather
            # than comparing/assigning a 1-element array as if it were a scalar.
            val = sample[0]
            if val == no_data_value:
                x, y = locations[i]
                row, col = src.index(x, y)
                val = find_nearest_non_missing(data, row, col, no_data_value)
                if val is None:
                    # No valid neighbour within the search radius: store NaN explicitly.
                    val = np.nan
            soilgrid_data[i, j] = val

In [None]:
# Persist the sampled SoilGrids matrix (rows follow `locations`, columns follow `soil_variables`).
np.save("europe_map_soilgrids.npy", soilgrid_data)
# World variant of the output file:
# np.save("world_map_soilgrids.npy", soilgrid_data)

## Topographic

In [None]:
# Authenticate with Google Earth Engine using the notebook auth mode, then
# initialise the client library for the given project.
# NOTE(review): "TOFILL" looks like a placeholder - replace it with a real project ID before running.
ee.Authenticate(auth_mode="notebook")
ee.Initialize(project="TOFILL")

In [None]:
# Sample elevation, slope and aspect at every location through Earth Engine.
# Locations are sent in batches of `batch_size` points per request.
batch_size = 1000
num_batches = math.ceil(len(locations) / batch_size)

# The terrain image is the same for every batch, so build it once up front.
# Load SRTM DEM dataset and compute slope and aspect
dataset = ee.Image('CGIAR/SRTM90_V4')
elevation = dataset.select('elevation')
slope = ee.Terrain.slope(elevation)
aspect = ee.Terrain.aspect(elevation)

# Combine elevation, slope, and aspect into a single image
terrain_image = elevation.addBands(slope).addBands(aspect).rename(['elevation', 'slope', 'aspect'])

all_values = []

print(f"#batches: {num_batches}")

for b in range(num_batches):

    if b % 25 == 0:
        print(f"Batch: {b}")

    batch_locations = locations[b*batch_size:(b+1)*batch_size]

    # One point feature per (lon, lat) row; plain Python floats are passed to the client.
    point_list = [
        ee.Feature(ee.Geometry.Point(float(lon), float(lat)))
        for lon, lat in batch_locations
    ]
    feature_collection = ee.FeatureCollection(point_list)

    # Sample the image at the feature locations
    sampled_values = terrain_image.reduceRegions(
        collection=feature_collection,
        reducer=ee.Reducer.first(),
        scale=90  # SRTM has a resolution of 90m
    )

    # getInfo() performs the request and returns the result as plain Python data.
    all_values.append(sampled_values.getInfo())

# Flatten the per-batch responses into one [elevation, slope, aspect] row per feature.
# A property absent from a feature is read as None and becomes NaN in the float array.
all_results = [
    [
        feature["properties"].get('elevation'),
        feature["properties"].get('slope'),
        feature["properties"].get('aspect'),
    ]
    for values in all_values
    for feature in values["features"]
]

all_results = np.array(all_results, dtype="float64")

In [None]:
# Persist the topography matrix (columns: elevation, slope, aspect) as float32.
np.save("europe_map_topography.npy", all_results.astype("float32"))

## Human influence

In [None]:
# Sample the nine human-footprint rasters at every location.
# Output: one row per location, one column per raster (same order as `human_variables`).
human_data = np.zeros((len(locations), 9))

human_variables = ["HFP2009", "Built2009", "Croplands2005", "Lights2009", "Navwater2009", "Pasture2009", "Popdensity2010", "Railways", "Roads"]

for j, human_file in enumerate(["HFP2009.tif", "Built2009.tif", "croplands2005.tif", "Lights2009.tif", "Navwater2009.tif", "Pasture2009.tif", "Popdensity2010.tif", "Railways.tif", "Roads.tif"]):
    with rasterio.open(f"human_footprint_venter/Dryadv2/Maps/{human_file}") as src:
        print(f"Processing {human_file}")

        raster_crs = src.crs

        # Reproject the lon/lat columns of `locations` from EPSG:4326 into the raster's own CRS.
        longitudes, latitudes = locations[:, 0], locations[:, 1]
        x_coords, y_coords = transform('EPSG:4326', raster_crs, longitudes, latitudes)

        # First band in full, used when a sampled cell has to be filled from its neighbourhood.
        data = src.read(1)
        for i, val in enumerate(src.sample(zip(x_coords, y_coords))):
            # Common case first: the sampled cell holds a real value.
            if val[0] != src.nodata:
                human_data[i, j] = val[0]
                continue

            # Missing cell: fall back to a nearby non-missing value.
            row, col = src.index(x_coords[i], y_coords[i])
            val = find_nearest_non_missing(data, row, col, src.nodata)
            human_data[i, j] = val

In [None]:
# Persist the human-footprint matrix (rows follow `locations`, columns follow `human_variables`).
np.save("europe_map_human.npy", human_data)

## Satclip

In [None]:
# Compute SatCLIP location embeddings for every (lon, lat) row of `locations`.
from satclip.satclip.load import get_satclip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = get_satclip('satclip-vit16-l40.ckpt', device=device) # Only loads location encoder by default
model.eval()
with torch.no_grad():
    # torch.Tensor(...) would first downcast the coordinates to float32 and
    # only then widen them again with .double(); as_tensor converts straight
    # to float64 so no coordinate precision is lost.
    emb = model(torch.as_tensor(locations, dtype=torch.double).to(device)).detach().cpu()

In [None]:
# Persist the SatCLIP embeddings as a NumPy array (one row per location).
np.save('europe_satclip_embeddings.npy', emb.numpy())

#