In [2]:
import requests

import io
import itertools
from retry import retry
import multiprocessing as mp

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import ee

import torch
import torchvision
from torchvision import transforms

import shapely.geometry
from shapely.geometry import Point

import geopandas as gpd

from model.windtopocali import WindTopoCali

# Coordinate reference systems: lon/lat for Earth Engine geometries, and a
# projected CRS whose units are used for the buffers and grid spacing below.
ee_crs = 'EPSG:4326'
meter_crs = 'EPSG:5070'

# Pixel spacing for the terrain grid and for the RTMA weather grid.
gsd = 100
rtma_gsd = 2500

# Per-channel normalization statistics (7 weather bands, 4 terrain bands).
# With mean 0 / std 1 the Normalize transforms are currently identity ops.
# NOTE(review): presumably these should be the statistics the pretrained
# model was trained with — confirm before trusting predictions.
weather_means = [0] * 7
weather_std = [1] * 7
topo_means = [0] * 4
topo_std = [1] * 4

# Authenticate, then initialize Earth Engine against the high-volume endpoint,
# because the download step below issues many requests in parallel.
ee.Authenticate()
ee.Initialize(
    project='eofm-benchmark',
    opt_url='https://earthengine-highvolume.googleapis.com',
)

---

## Load the pretrained model and define the input transforms

---

In [3]:
# Build the network with the input sizes produced by the transforms defined
# below (weather resized to 64 pixels, terrain resized to 128 pixels).
model = WindTopoCali(
    n_weather_channels=7,
    wind_in_size=64,
    n_topo_channels=4,
    topo_in_size=128
)

# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; inference in this notebook runs on CPU tensors
# (torch.from_numpy inputs, .numpy() on the output).
model.load_state_dict(
    torch.load('pretrained_model.pth', map_location='cpu', weights_only=True)
)
model.eval()

WindTopoCali(
  (lr_weather_net): ResNet(
    (conv1): Conv2d(7, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): ResidualBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer2): Sequential(
      (0): ResidualBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(

In [20]:
def _crop_normalize_resize(crop_size, means, stds, out_size):
    """Return a transform that center-crops, normalizes per channel, then resizes.

    Args:
        crop_size: side length in pixels of the center crop.
        means: per-channel means passed to ``transforms.Normalize``.
        stds: per-channel standard deviations passed to ``transforms.Normalize``.
        out_size: target size passed to ``transforms.Resize``.
    """
    return torch.nn.Sequential(
        transforms.CenterCrop(crop_size),
        transforms.Normalize(mean=means, std=stds),
        transforms.Resize(out_size),
    )


# NOTE(review): each "hr" crop covers a larger extent than its "lr" partner
# (16 vs 8 weather pixels, 400 vs 100 terrain pixels) before both are resized
# to the same size — confirm this naming matches how the model was trained.

# Weather (RTMA grid): both crops are resized to the model's 64-pixel input.
weather_hr_transform = _crop_normalize_resize(16, weather_means, weather_std, 64)
weather_lr_transform = _crop_normalize_resize(8, weather_means, weather_std, 64)

# Terrain (100 m grid): both crops are resized to the model's 128-pixel input.
topo_hr_transform = _crop_normalize_resize(400, topo_means, topo_std, 128)
topo_lr_transform = _crop_normalize_resize(100, topo_means, topo_std, 128)

---

## Define the sample region and build the grid of query points

---

In [49]:
# Center of the sample region (lon/lat), the timestamp whose RTMA analysis is
# queried, and the buffer radius (meter_crs units) whose bounding box
# defines the square query grid.
center_lon = -123.14
center_lat = 41.12
time = '2024-01-01T00:00:00'
roi_size = 1000

In [50]:
# Build a regular grid of query points around the center, spaced `gsd` apart
# in the projected meter_crs, then convert the points back to lon/lat for EE.
point = Point([center_lon, center_lat])

# Buffering by roi_size and taking the bounding box gives a square of side
# 2 * roi_size centered on the point.
pp = gpd.GeoSeries([point], crs=ee_crs).to_crs(meter_crs).buffer(roi_size)

# total_bounds is ordered (minx, miny, maxx, maxy).
minx, miny, maxx, maxy = pp.total_bounds

# NOTE(review): np.arange excludes its stop value, so the grid starts on the
# lower-left corner and stops one step short of the right/top edges; it is not
# symmetric about the center point.
gdf_grid = gpd.GeoDataFrame(
    geometry=[
        Point(x, y)
        for x in np.arange(minx, maxx, gsd)
        for y in np.arange(miny, maxy, gsd)
    ],
    crs=meter_crs,
).to_crs(ee_crs)

In [51]:
# Earth Engine objects shared by every download request: the query time, and
# the output pixel grids in meter_crs at RTMA and terrain resolution.
dt = ee.Date(time)

rtma_proj = ee.Projection(crs=meter_crs).atScale(meters=rtma_gsd)
topo_proj = ee.Projection(crs=meter_crs).atScale(meters=gsd)


In [52]:
def _download_npy(url, timeout):
    """Download ``url`` and decode the response body as a NumPy ``.npy`` payload.

    Raises ``requests.HTTPError`` on 4xx/5xx responses and ``requests.Timeout``
    when the server stalls longer than ``timeout`` seconds, so the ``@retry``
    wrapper on ``build_request`` can retry the whole request.
    """
    response = requests.get(url, timeout=timeout)
    # raise_for_status() raises on error statuses and returns None otherwise,
    # so it is called directly instead of being used as the operand of `raise`.
    response.raise_for_status()
    # np.load keeps its default allow_pickle=False, so the downloaded bytes
    # cannot trigger code execution during deserialization.
    return np.load(io.BytesIO(response.content))


@retry(tries=100, delay=3, backoff=2, max_delay=120)
def build_request(idx, point, date=dt, proj_rtma=rtma_proj, proj_topo=topo_proj, timeout=300):
    """Download the RTMA weather and terrain rasters around one query point.

    Any exception triggers a retry (up to 100 attempts, exponential backoff
    starting at 3 s and capped at 120 s between attempts).

    Args:
        idx: identifier echoed back in the result and in progress messages.
        point: shapely Point in lon/lat (``ee_crs``).
        date: ``ee.Date``; the first RTMA image in ``[date, date + 1 hour)``
            is used.
        proj_rtma: ``ee.Projection`` defining the RTMA output pixel grid.
        proj_topo: ``ee.Projection`` defining the terrain output pixel grid.
        timeout: seconds to wait on each HTTP download before it fails.

    Returns:
        ``(idx, data)`` where ``data`` has keys ``'coordinates'``
        (``[lon, lat]``), ``'rtma'`` and ``'topo'`` (arrays decoded from
        Earth Engine NPY downloads, indexed by band name downstream).
    """
    lon, lat = point.x, point.y

    # CRS string and affine transform of each output grid.
    rtma_projection = proj_rtma.getInfo()
    topo_projection = proj_topo.getInfo()

    # Request regions: a 25000-unit buffer around the point, reduced to its
    # bounding box in each output projection.
    rtma_point = ee.Geometry.Point(coords=[lon, lat]).transform(proj_rtma)
    rtma_box = rtma_point.buffer(distance=25000).bounds(proj=proj_rtma, maxError=0.1)

    topo_point = ee.Geometry.Point(coords=[lon, lat]).transform(proj_topo)
    topo_box = topo_point.buffer(distance=25000).bounds(proj=proj_topo, maxError=0.1)

    # Terrain: SRTM elevation, slope and aspect derived from it, and mTPI.
    dem = ee.Image('USGS/SRTMGL1_003').clip(topo_box).select('elevation')
    slope = ee.Terrain.slope(dem).rename('slope')
    aspect = ee.Terrain.aspect(dem).rename('aspect')
    mtpi = ee.Image('CSP/ERGo/1_0/Global/SRTM_mTPI').clip(topo_box).select('elevation').rename('mtpi')
    topo_frame = dem.addBands(slope).addBands(aspect).addBands(mtpi)

    topo_url = topo_frame.getDownloadURL(
        {
            'bands': ['elevation', 'slope', 'aspect', 'mtpi'],
            'region': topo_box,
            'scale': 100,
            'crs': topo_projection['crs'],
            'crsTransform': topo_projection['transform'],
            'format': 'NPY'
        }
    )

    # Weather: first RTMA image in [date, date + 1 hour) covering the region.
    # NOTE(review): if no image matches, .first() is presumably null and the
    # request fails with an Earth Engine error that @retry keeps retrying —
    # confirm RTMA availability for `date`.
    rtma_frame = (
        ee.ImageCollection('NOAA/NWS/RTMA')
        .filterDate(date, date.advance(1, 'hour'))
        .filterBounds(rtma_box)
        .first()
        .clip(rtma_box)
        .select(['TMP', 'UGRD', 'VGRD', 'SPFH', 'ACPC01', 'HGT', 'GUST'])
    )

    rtma_url = rtma_frame.getDownloadURL(
        {
            'bands': ['TMP', 'UGRD', 'VGRD', 'SPFH', 'ACPC01', 'HGT', 'GUST'],
            'region': rtma_box,
            'scale': 2500,
            'crs': rtma_projection['crs'],
            'crsTransform': rtma_projection['transform'],
            'format': 'NPY'
        }
    )

    rtma_data = _download_npy(rtma_url, timeout)
    print(f'ID: {idx} RTMA loaded')

    topo_data = _download_npy(topo_url, timeout)
    print(f'ID: {idx} Topo Loaded')

    data = {
        'coordinates': [lon, lat],
        'rtma': rtma_data,
        'topo': topo_data
    }
    return idx, data

In [53]:
# Download inputs for every grid point in parallel. starmap blocks until all
# results are in and returns them in input order; the context manager shuts
# the workers down even if a request ultimately fails.
n_workers = 25
with mp.Pool(n_workers) as pool:
    results = pool.starmap(build_request, enumerate(gdf_grid.geometry))

ID: 25 RTMA loaded
ID: 75 RTMA loaded
ID: 0 RTMA loaded
ID: 50 RTMA loaded
ID: 0 Topo Loaded
ID: 50 Topo Loaded
ID: 75 Topo Loaded
ID: 76 RTMA loadedID: 1 RTMA loadedID: 51 RTMA loaded


ID: 25 Topo Loaded
ID: 1 Topo Loaded
ID: 76 Topo Loaded
ID: 51 Topo Loaded
ID: 26 RTMA loaded
ID: 26 Topo Loaded
ID: 52 RTMA loaded
ID: 77 RTMA loaded
ID: 2 RTMA loaded
ID: 52 Topo Loaded
ID: 27 RTMA loaded
ID: 2 Topo Loaded
ID: 77 Topo Loaded
ID: 27 Topo Loaded
ID: 53 RTMA loaded
ID: 53 Topo Loaded
ID: 78 RTMA loaded
ID: 3 RTMA loaded
ID: 28 RTMA loaded
ID: 28 Topo Loaded
ID: 78 Topo Loaded
ID: 54 RTMA loaded
ID: 3 Topo Loaded
ID: 29 RTMA loaded
ID: 54 Topo Loaded
ID: 79 RTMA loaded
ID: 29 Topo Loaded
ID: 55 RTMA loaded
ID: 79 Topo Loaded
ID: 4 RTMA loaded
ID: 80 RTMA loaded
ID: 55 Topo Loaded
ID: 4 Topo Loaded
ID: 80 Topo Loaded
ID: 56 RTMA loaded
ID: 5 RTMA loaded
ID: 56 Topo Loaded
ID: 81 RTMA loaded
ID: 5 Topo Loaded
ID: 81 Topo Loaded
ID: 57 RTMA loaded
ID: 6 RTMA loaded
ID: 82 RTMA loaded
ID: 6 

In [54]:
# Index each downloaded payload by its grid-point id.
results_dict = {idx: payload for idx, payload in results}

In [55]:
# Channel order used when stacking the downloaded bands; matches the bands
# requested in build_request.
weather_bands = 'TMP UGRD VGRD SPFH ACPC01 HGT GUST'.split()
topo_bands = 'elevation slope aspect mtpi'.split()

In [None]:
# Pull each point's band fields out of the downloaded arrays and stack them
# channel-first, in the order given by weather_bands / topo_bands.
weather_ims = [np.stack([payload['rtma'][band] for band in weather_bands]) for payload in results_dict.values()]
topo_ims = [np.stack([payload['topo'][band] for band in topo_bands]) for payload in results_dict.values()]
# build_request stores the query location under the 'coordinates' key.
im_source_coords = [payload['coordinates'] for payload in results_dict.values()]

# zip() below would silently drop unmatched samples; fail loudly instead.
assert len(weather_ims) == len(topo_ims), 'weather/topo sample counts differ'

# Apply the crop/normalize/resize transforms per sample, then batch them.
weather_lr_tensors = []
weather_hr_tensors = []
topo_lr_tensors = []
topo_hr_tensors = []

for weather_im, topo_im in zip(weather_ims, topo_ims):
    weather_tensor = torch.from_numpy(weather_im).float()
    topo_tensor = torch.from_numpy(topo_im).float()

    weather_lr_tensors.append(weather_lr_transform(weather_tensor))
    weather_hr_tensors.append(weather_hr_transform(weather_tensor))
    topo_lr_tensors.append(topo_lr_transform(topo_tensor))
    topo_hr_tensors.append(topo_hr_transform(topo_tensor))

weather_lr_tensors = torch.stack(weather_lr_tensors)
weather_hr_tensors = torch.stack(weather_hr_tensors)

topo_lr_tensors = torch.stack(topo_lr_tensors)
topo_hr_tensors = torch.stack(topo_hr_tensors)

In [None]:
# Summarize rather than dumping every array: sample count and distinct shapes.
len(weather_ims), sorted({im.shape for im in weather_ims})

In [63]:
# Model inputs keyed by name.
# NOTE(review): keys assumed to match what WindTopoCali's forward expects —
# its definition is not visible here.
batch = {
    'rtma': weather_hr_tensors,
    'rtma_lr': weather_lr_tensors,
    'topo': topo_hr_tensors,
    'topo_lr': topo_lr_tensors
}

# Inference only: disable autograd so no graph is recorded for the batch.
with torch.no_grad():
    predictions = model(batch).numpy()

# Backward-compatible alias for the original misspelled variable name.
predicitions = predictions