In [5]:
from aurora import Aurora, rollout
import torch

In [1]:
import xarray as xr
from datetime import datetime

import torch

from aurora import AuroraSmall, Batch, Metadata, rollout
import matplotlib.pyplot as plt

from pathlib import Path

import cdsapi
import numpy as np
from sklearn.metrics import root_mean_squared_error
import gcsfs

from torch.utils.data import Dataset
from aurora import Batch, Metadata
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import sys

# Make the project's `src/` directory importable. The path is resolved
# relative to the current working directory, so this assumes the notebook is
# launched from a directory that is a sibling of `src/` — TODO confirm.
SRC_DIR = os.path.abspath("../src")
if SRC_DIR not in sys.path:  # keep the cell idempotent when re-run
    sys.path.append(SRC_DIR)

# Project helpers used below (single import statement; previously `rmse_fn`
# and `plot_rmses` were imported twice).
from utils import (
    create_batch,
    create_hrest0_batch,
    get_atmos_feature_target_data,
    get_static_feature_target_data,
    get_surface_feature_target_data,
    plot_rmses,
    predict_fn,
    rmse_fn,
    rmse_weights,
)

In [6]:
# Build the full-size Aurora model and load the pretrained checkpoint
# `aurora-0.25-pretrained.ckpt` from `microsoft/aurora`.
model = Aurora(
    use_lora=False,  # Model was not fine-tuned.
)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")

# Export the weights as a plain state dict so later runs can load them
# locally. torch.save does not create missing directories, so make sure
# ../model exists first.
big_model_weights_path = Path("../model/aurora-0.25-pretrained_big.pth")
big_model_weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), big_model_weights_path)




In [None]:
# NOTE(review): this cell used to be an exact copy of the big-model export
# cell. It is kept as a cache-aware variant: if the exported state dict
# already exists it is loaded locally; otherwise the checkpoint is loaded
# from `microsoft/aurora` and exported, exactly as before.
big_model_weights_path = Path("../model/aurora-0.25-pretrained_big.pth")

model = Aurora(
    use_lora=False,  # Model was not fine-tuned.
)

if big_model_weights_path.exists():
    # The file holds a state dict saved from this same model class, so only
    # tensors need to be unpickled; map to CPU so no GPU is required here.
    model.load_state_dict(
        torch.load(big_model_weights_path, map_location="cpu", weights_only=True)
    )
else:
    model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
    # torch.save does not create missing directories.
    big_model_weights_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), big_model_weights_path)




In [8]:
# Build the small Aurora model, load its pretrained checkpoint
# `aurora-0.25-small-pretrained.ckpt` from `microsoft/aurora`, and export the
# weights as a plain state dict.
model = AuroraSmall(
    use_lora=False,  # Model was not fine-tuned.
)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")

# File name follows the `aurora-0.25-*` convention of the other exports
# (the previous name was misspelled "urora-..."). torch.save does not create
# missing directories, so make sure ../model exists first.
small_model_weights_path = Path("../model/aurora-0.25-small-pretrained1.pth")
small_model_weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), small_model_weights_path)




# Load the saved model weights

In [4]:
# Small Aurora model that will receive previously exported weights:
#   use_lora=False            -> no LoRA adapters (model was not fine-tuned)
#   autocast=True             -> use automatic mixed precision (AMP)
#   stabilise_level_agg=True  -> passed through to AuroraSmall; presumably
#                                must match the weights being loaded — TODO confirm
model = AuroraSmall(use_lora=False, autocast=True, stabilise_level_agg=True)


In [5]:
# Load an exported state dict into `model`.
# weights_only=True restricts unpickling to tensors and plain containers
# (all a state dict needs) instead of arbitrary pickled objects, and
# map_location="cpu" lets the file load even if it was saved from a GPU;
# load_state_dict then copies the values into the model's parameters.
# NOTE(review): none of the export cells in this notebook write this exact
# file name (they write `..._big.pth` and `...-small-pretrained1.pth`);
# confirm this file holds AuroraSmall weights.
model.load_state_dict(
    torch.load('../model/aurora-0.25-pretrained.pth', map_location="cpu", weights_only=True)
)

<All keys matched successfully>

# Get some data

## ERA5

In [3]:
# Anonymous (unauthenticated) access to the WeatherBench 2 GCS bucket.
fs = gcsfs.GCSFileSystem(token="anon")

# ERA5 store; per its name: 1959-2023, 6-hourly, 1440x721 grid, with derived variables.
ERA5_ZARR_PATH = 'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr'
store = fs.get_mapper(ERA5_ZARR_PATH)

# Open lazily through the consolidated Zarr metadata, without dask chunking.
full_era5 = xr.open_zarr(store=store, consolidated=True, chunks=None)

# HRES t0

In [4]:
# HRES t0 store; per its name: 2016-2022, 6-hourly, 1440x721 grid.
HREST0_ZARR_PATH = 'gs://weatherbench2/datasets/hres_t0/2016-2022-6h-1440x721.zarr'
store_hrest0 = fs.get_mapper(HREST0_ZARR_PATH)

# Same open options as the ERA5 store: consolidated metadata, no dask chunking.
full_hrest0 = xr.open_zarr(store=store_hrest0, consolidated=True, chunks=None)

In [5]:
# Inspect the `level` coordinate of the HRES t0 dataset.
full_hrest0["level"]

### Subset data

#### World

In [8]:
# Restrict both global datasets to the same evaluation window. Label-based
# slicing with date strings includes both end dates.
start_time, end_time = '2022-11-02', '2022-11-05'
time_window = slice(start_time, end_time)

sliced_era5_world = full_era5.sel(time=time_window)
sliced_hrest0_world = full_hrest0.sel(time=time_window)


# Retrieve the observation dates for which the model returns NaN values

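# Run the big model

Note: running the notebook top to bottom leaves `model` bound to the `AuroraSmall` instance created above; run the `Aurora` cell first to evaluate the big model.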

In [None]:
# Constants
STATIC_VARIABLES = ["land_sea_mask", "soil_type", "geopotential_at_surface"]
surface_vars_names = ["2t", "10u", "10v", "msl"]
selected_times = sliced_hrest0_world.time

# The RMSE weights depend only on the lat/lon grid, so build them once.
# The first row is dropped — presumably to match the latitude extent of the
# predictions; TODO confirm against rmse_weights / predict_fn.
world_rmse_weights = rmse_weights(
    sliced_hrest0_world.latitude,
    sliced_hrest0_world.longitude,
)[1:, :]

# `metadata.time` of every feature batch whose predictions produced a NaN
# RMSE for at least one surface variable (one entry per offending window).
nan_dates_list = []

############# Model
# Each window spans four consecutive timestamps: the first two are the model
# input, the next two are the targets. Hence the last valid start is len - 4.
n_windows = len(selected_times) - 3
for i in range(n_windows):
    feature_window = slice(selected_times[i], selected_times[i + 1])
    target_window = slice(selected_times[i + 2], selected_times[i + 3])

    world_feature_hrest0_data = sliced_hrest0_world.sel(time=feature_window)
    world_target_hrest0_data = sliced_hrest0_world.sel(time=target_window)
    world_feature_era5_data = sliced_era5_world.sel(time=feature_window)
    world_target_era5_data = sliced_era5_world.sel(time=target_window)

    # Surface and atmospheric fields come from HRES t0; static fields from ERA5.
    world_feature_surface_data, world_target_surface_data = get_surface_feature_target_data(
        world_feature_hrest0_data, world_target_hrest0_data
    )
    world_feature_atmos_data, world_target_atmos_data = get_atmos_feature_target_data(
        world_feature_hrest0_data, world_target_hrest0_data
    )
    world_feature_static_data, world_target_static_data = get_static_feature_target_data(
        world_feature_era5_data, world_target_era5_data, STATIC_VARIABLES
    )

    world_feature_batch = create_hrest0_batch(
        world_feature_surface_data, world_feature_atmos_data, world_feature_static_data
    )
    world_target_batch = create_hrest0_batch(
        world_target_surface_data, world_target_atmos_data, world_target_static_data
    )

    world_predictions = predict_fn(model=model, batch=world_feature_batch)

    # Record the window at most once: stop at the first surface variable
    # whose RMSE contains a NaN.
    for var in surface_vars_names:
        world_rmses, _ = rmse_fn(
            predictions=world_predictions,
            target_batch=world_target_batch,
            var_name=var,
            weigths=world_rmse_weights,  # spelling matches rmse_fn's parameter
        )
        if not np.isnan(world_rmses).any():
            continue
        print(world_rmses)
        nan_dates_list.append(world_feature_batch.metadata.time)
        break