In [1]:
import xarray as xr
from datetime import datetime

import torch

from aurora import AuroraSmall, Batch, Metadata, rollout
import matplotlib.pyplot as plt

from pathlib import Path

import cdsapi
import numpy as np
from sklearn.metrics import root_mean_squared_error
import gcsfs

from torch.utils.data import Dataset
from aurora import Batch, Metadata
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np

In [3]:
import sys

# Make the project's helper modules in ../src importable. The path is resolved
# relative to the kernel's working directory. The membership check keeps this
# cell idempotent: re-running it no longer appends a duplicate sys.path entry.
SRC_DIR = os.path.abspath("../src")
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

from utils import (
    create_batch,
    create_hrest0_batch,
    get_atmos_feature_target_data,
    get_static_feature_target_data,
    get_surface_feature_target_data,
    plot_rmses,
    predict_fn,
    rmse_fn,
    rmse_weights,
)

In [4]:
from evaluation import evaluation
from lora import create_custom_model


# Data

In [5]:

# Public WeatherBench2 bucket; anonymous access is sufficient.
fs = gcsfs.GCSFileSystem(token="anon")

# ERA5 (the notebook takes the static fields from here).
store = fs.get_mapper('gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr')
full_era5 = xr.open_zarr(store=store, consolidated=True, chunks=None)

# Evaluation window. xarray label slicing includes both endpoints.
# Other windows tried previously: '2022-11-01'..'2023-01-31' and '2023-01-08'..'2023-01-31'.
start_time, end_time = '2022-01-01', '2022-01-03'

# South Africa bounding box, in degrees.
lat_max = -22.00
lat_min = -37.75
lon_min = 15.25
lon_max = 35.00


def _lat_slice(ds, south, north):
    """Return a latitude slice from `south` to `north` that follows `ds`'s ordering.

    xarray label slices must go in the same direction as the coordinate; a slice
    in the wrong direction silently selects nothing. The ERA5 and HRES T0 stores
    used here are ordered in opposite directions, so the direction is read from
    the data instead of being hard-coded per store.
    """
    lats = ds["latitude"].values
    descending = lats.size > 1 and lats[0] > lats[-1]
    return slice(north, south) if descending else slice(south, north)


sliced_era5_SA = full_era5.sel(
    time=slice(start_time, end_time),
    latitude=_lat_slice(full_era5, lat_min, lat_max),
    longitude=slice(lon_min, lon_max),
)

# HRES T0 (the surface and atmospheric fields are built from it further down).
# NOTE(review): its latitude order differs from ERA5's; confirm create_hrest0_batch
# reorders so the ERA5 static fields line up with the HRES fields.
store_hrest0 = fs.get_mapper('gs://weatherbench2/datasets/hres_t0/2016-2022-6h-1440x721.zarr')
full_hrest0 = xr.open_zarr(store=store_hrest0, consolidated=True, chunks=None)
sliced_hrest0_sa = full_hrest0.sel(
    time=slice(start_time, end_time),
    latitude=_lat_slice(full_hrest0, lat_min, lat_max),
    longitude=slice(lon_min, lon_max),
)

assert sliced_era5_SA.sizes["latitude"] > 0, "empty ERA5 latitude selection"
assert sliced_hrest0_sa.sizes["latitude"] > 0, "empty HRES T0 latitude selection"



# Models

In [6]:
# Baseline model: a plain AuroraSmall (no LoRA adapters added) restored from
# ../model/aurora.pth. The cell's last expression displays the load report.
model_initial = AuroraSmall()
model_initial.load_state_dict(
    torch.load(f='../model/aurora.pth')
)

<All keys matched successfully>

In [7]:
# Fine-tuned model: AuroraSmall passed through create_custom_model with
# lora_r=8 / lora_alpha=16, then restored from the training checkpoint.
# NOTE(review): these LoRA settings presumably have to match the ones used in
# training, otherwise load_state_dict reports missing/unexpected keys.
fine_tuned_model = AuroraSmall(
    use_lora=False,  # Aurora's built-in LoRA is off; adapters presumably come from create_custom_model.
    autocast=True,  # Use AMP.
)
fine_tuned_model = create_custom_model(fine_tuned_model, lora_r=8, lora_alpha=16)
checkpoint = torch.load('../model/training/hrest0/checkpoint_epoch_1.pth')

fine_tuned_model.load_state_dict(checkpoint['model_state_dict'])


<All keys matched successfully>

In [8]:

# Evaluate the fine-tuned model against the non-fine-tuned baseline on the
# sliced ERA5 and HRES T0 data.
# NOTE(review): in the recorded output this cell raised
# KeyError: 'surface_rmses_fine_tuned_model' during evaluation(), so the
# `results` read in the next cell was not produced by this run. Fix the key
# lookup in the evaluation module and re-run from a fresh kernel. The
# torch.Size lines in the output are debug prints emitted during this call.
results = evaluation(
    fine_tuned_model,
    model_initial,
    sliced_era5_SA,
    sliced_hrest0_sa,
)


torch.Size([2, 64, 80])
torch.Size([2, 64, 80])
torch.Size([2, 64, 80])
torch.Size([2, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Size([2, 13, 64, 80])
torch.Si

KeyError: 'surface_rmses_fine_tuned_model'

In [9]:
# Unpack the dict returned by evaluation() into one name per entry.
counter = results["counter"]

# Fine-tuned model.
surface_rmses_fine_tuned = results["surface_rmses_fine_tuned"]
atmospheric_rmses_fine_tuned = results["atmospheric_rmses_fine_tuned"]

# Baseline (non-fine-tuned) model.
surface_rmses_non_fine_tuned = results["surface_rmses_non_fine_tuned"]
atmospheric_rmses_non_fine_tuned = results["atmospheric_rmses_non_fine_tuned"]

In [18]:
# Only a cell's last expression is displayed, so a bare `surface_rmses_fine_tuned`
# line before this one would show nothing. Display the z table only.
# NOTE(review): in the recorded output these fine-tuned z RMSEs (~60-160) are far
# larger than the baseline's (~2-38); investigate before reporting results.
atmospheric_rmses_fine_tuned["z"]

array([[112.88531494, 104.86317444, 108.99510193,  86.60694885,
        129.3555603 ,  86.82122803, 121.90444946,  97.22524261],
       [137.08757019, 126.01667023, 122.29566193,  96.03895569,
        130.85035706,  94.33375549, 131.88903809,  98.2884903 ],
       [150.93162537, 138.71955872, 129.84275818, 107.51730347,
        146.92102051, 109.06533051, 144.121521  , 108.76496887],
       [159.08731079, 152.52082825, 136.38058472, 110.89865112,
        155.52893066, 115.36819458, 150.75010681, 113.94164276],
       [148.19334412, 138.41098022, 130.07852173, 106.17796326,
        142.08538818, 103.62815094, 141.52178955, 105.71464539],
       [137.09613037, 124.01824951, 115.9083252 ,  90.1552887 ,
        127.59199524,  95.03741455, 125.45706177,  93.01181793],
       [110.51914978, 102.21260071,  92.29176331,  77.02348328,
        103.15515137,  79.1570282 , 101.21658325,  79.46543884],
       [ 91.01585388,  87.10150146,  79.53807068,  61.61869049,
         84.9493866 ,  60.7066421

In [17]:
# Only a cell's last expression is displayed, so a bare
# `surface_rmses_non_fine_tuned` line before this one would show nothing.
# Display the baseline's z table only.
atmospheric_rmses_non_fine_tuned["z"]


array([[ 5.27385998,  8.407691  , 12.59131432, 19.2553196 , 25.17331886,
        27.77643967, 30.57715034, 37.59294128],
       [ 4.25438213,  7.63996601, 11.8037281 , 11.32155609, 11.43088341,
        13.26637459, 13.44604301, 11.35198975],
       [ 5.00291348,  9.86188698, 13.97521973, 14.91846561, 15.71817589,
        18.06719971, 18.30081367, 16.84731293],
       [ 4.19200993,  7.62254715, 10.44234848, 11.44184685, 12.36076546,
        13.44333649, 14.7951889 , 16.20301437],
       [ 3.32055855,  5.4545908 ,  6.74199867,  7.02580118,  8.40130806,
         9.5207653 , 11.23363304, 12.91157055],
       [ 2.77650118,  4.35646343,  5.2814703 ,  5.03923798,  6.32980728,
         8.0811367 ,  9.89044189, 10.81741524],
       [ 2.26897621,  2.65574074,  3.70426846,  3.28074837,  4.52630758,
         6.55645561,  8.41282654,  8.32696533],
       [ 2.09883285,  2.25853634,  3.69951558,  2.47040367,  3.35075045,
         5.22627449,  6.9094882 ,  5.97267437],
       [ 2.05786943,  2.43028498

In [11]:
# Build one input/target batch pair by hand for inspection.
# NOTE(review): `selected_times` is not defined anywhere above this cell; it only
# existed in stale kernel state, so Restart & Run All raised NameError here.
# It is assumed to be the HRES T0 timestamps of the selected window; confirm this
# matches the previous definition.
selected_times = sliced_hrest0_sa.time.values

start_idx = 0  # position in selected_times of the first input step
last_target_offset = 9  # targets span start_idx + 2 .. start_idx + 9, inclusive
assert len(selected_times) > start_idx + last_target_offset, (
    f"need at least {start_idx + last_target_offset + 1} timestamps, "
    f"got {len(selected_times)}"
)

t0 = selected_times[start_idx]  # first input step
t1 = selected_times[start_idx + 1]  # second input step
t2 = selected_times[start_idx + 2]  # first target step
t3 = selected_times[start_idx + last_target_offset]  # last target step

# Load the required time slices once.
sa_feature_hrest0_data = sliced_hrest0_sa.sel(time=slice(t0, t1))
sa_feature_era5_data = sliced_era5_SA.sel(time=slice(t0, t1))
# Same window as the features: ERA5 only feeds get_static_feature_target_data below.
sa_target_era5_data = sliced_era5_SA.sel(time=slice(t0, t1))

sa_targets_hrest0_data = sliced_hrest0_sa.sel(time=slice(t2, t3))


# Extract features and targets.
sa_feature_surface_data, sa_target_surface_data = get_surface_feature_target_data(sa_feature_hrest0_data, sa_targets_hrest0_data)
sa_feature_atmos_data, sa_target_atmos_data = get_atmos_feature_target_data(sa_feature_hrest0_data, sa_targets_hrest0_data)
sa_feature_static_data, sa_target_static_data = get_static_feature_target_data(sa_feature_era5_data, sa_target_era5_data)

# Create input and target batches.
input_batch = create_hrest0_batch(sa_feature_surface_data, sa_feature_atmos_data, sa_feature_static_data).to("cuda")
target_batch = create_hrest0_batch(sa_target_surface_data, sa_target_atmos_data, sa_target_static_data).to("cuda")


In [12]:
target_batch= target_batch.surf_vars["2t"]

In [13]:
target_batch.shape

torch.Size([1, 2, 64, 80])