In [2]:
import sys
sys.path.append('../src')


In [3]:
from lora import LoRA


In [4]:
import xarray as xr
from datetime import datetime

import torch

from aurora import AuroraSmall, Batch, Metadata, rollout
import matplotlib.pyplot as plt

from pathlib import Path

import cdsapi
import numpy as np
from sklearn.metrics import root_mean_squared_error
import gcsfs

from torch.utils.data import Dataset
from aurora import Batch, Metadata
import os

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import sys
sys.path.append(os.path.abspath("../src"))
from utils import get_surface_feature_target_data, get_atmos_feature_target_data
from utils import get_static_feature_target_data, create_batch, predict_fn, rmse_weights
from utils import rmse_fn, plot_rmses

# Load model

In [6]:
from aurora import AuroraSmall

# Instantiate the small Aurora architecture (built-in LoRA off, AMP on).
# No pretrained weights are loaded in this cell: the checkpoint call below is
# commented out.
model = AuroraSmall(autocast=True, use_lora=False)
# model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")


In [7]:
# torch.save(model.state_dict(), "../model/aurora-pretrained.pth")

In [34]:
# Rebuild the architecture and restore the pretrained weights exported earlier.
# weights_only=True restricts torch.load to tensors/plain containers instead of
# unpickling arbitrary objects; map_location="cpu" lets a checkpoint saved from
# a GPU session load on a CPU-only machine.
model = AuroraSmall(
    use_lora=False,  # Model was not fine-tuned.
    autocast=True,  # Use AMP.
)
model.load_state_dict(
    torch.load("../model/aurora-pretrained.pth", map_location="cpu", weights_only=True)
)

<All keys matched successfully>

In [35]:
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters will receive gradient updates.

    Prints one line of the form
    ``trainable parameters: <trainable>/<total> (<percent>%)``.

    Args:
        model: any ``torch.nn.Module``.
    """
    total, trainable = 0, 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    # A module with no parameters (e.g. an activation) would otherwise raise
    # ZeroDivisionError here.
    percent = 100 * trainable / total if total else 0.0
    print(f"trainable parameters: {trainable:,}/{total:,} ({percent:.2f}%)")

In [36]:
# Report trainable vs. total parameters before anything is frozen.
print_trainable_parameters(model)

112797584
trainable parameters: 112,797,584/112,797,584 (100.00%)


In [37]:
# Freeze every parameter of the model; Module.requires_grad_(False) walks
# model.parameters() and clears requires_grad on each one.
# The trailing ';' keeps the notebook from displaying the returned module.
model.requires_grad_(False);


In [38]:
# After the freeze above, this should report 0 trainable parameters.
print_trainable_parameters(model)

0
trainable parameters: 0/112,797,584 (0.00%)


In [23]:
# Re-enable gradients for every parameter under model.backbone.time_mlp.
# The trailing ';' keeps the notebook from displaying the returned module.
model.backbone.time_mlp.requires_grad_(True);


In [24]:
# Report the parameter counts after unfreezing time_mlp.
# NOTE(review): the saved output below (114,951,184 total, 100% trainable) is
# stale. That total only appears after LoRA is applied further down
# (compare the LoRA cell's output), so this cell last ran out of order.
# Re-run top to bottom to refresh it.
print_trainable_parameters(model)

114951184
trainable parameters: 114,951,184/114,951,184 (100.00%)


In [14]:
# Re-enable gradients for the ln_modulation sub-module of the first block of
# the first encoder layer.
# The trailing ';' keeps the notebook from displaying the returned module.
model.backbone.encoder_layers[0].blocks[0].norm1.ln_modulation.requires_grad_(True);

In [15]:
# Report the parameter counts after unfreezing ln_modulation.
# NOTE(review): the saved output (263,168 trainable) does not include the
# time_mlp parameters unfrozen above. The execution counts (In [14]/[15] here
# vs In [23]/[24] for the time_mlp cells) suggest this ran before that cell.
print_trainable_parameters(model)

263168
trainable parameters: 263,168/112,797,584 (0.23%)


# Get South Africa data

In [11]:
# Open the public WeatherBench2 ERA5 store with anonymous (credential-free)
# access. chunks=None opens the variables without dask.
fs = gcsfs.GCSFileSystem(token="anon")
store = fs.get_mapper(
    'gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr'
)
full_era5 = xr.open_zarr(store, consolidated=True, chunks=None)

In [12]:
# Period used for fine-tuning; both values are used below as label-based
# bounds of a time slice.
start_time = '2022-12-01'
end_time = '2023-01-31'


In [13]:


# ERA5 (WeatherBench2) variable names for each group of Aurora inputs.
# NOTE: "atmostpheric" is a misspelling kept on purpose because later cells
# reference this exact name.
atmostpheric_variables = [
    "temperature",
    "u_component_of_wind",
    "v_component_of_wind",
    "specific_humidity",
    "geopotential",
]
surface_vars = [
    "2m_temperature",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "mean_sea_level_pressure",
]
static_variables = [
    "land_sea_mask",
    "soil_type",
    "geopotential_at_surface",
]


In [14]:

# South Africa bounding box in degrees. Latitude bounds are listed north ->
# south because the latitude slice below runs from lat_max to lat_min.
lat_max = -22.00
lat_min = -37.75

lon_min = 15.25
lon_max = 35.00

# Lead between a feature sample and its target, in time steps of the store
# (6-hourly according to the dataset name, so 2 steps = 12 h).
LEAD_STEPS = 2

# Select the region and period once, then derive two aligned views of it:
# position k of sliced_era5_SA pairs with position k of target_sliced_era5_SA,
# which lies LEAD_STEPS steps later. The stop index is computed explicitly so
# the feature view stays correct even for LEAD_STEPS = 0 (slice(None, -0) is
# empty).
era5_SA_window = full_era5.sel(
    time=slice(start_time, end_time),
    # NOTE(review): slicing lat_max -> lat_min assumes latitude is stored in
    # descending order; with an ascending coordinate this selection is empty.
    latitude=slice(lat_max, lat_min),
    longitude=slice(lon_min, lon_max),
)
n_window_times = era5_SA_window.sizes["time"]
sliced_era5_SA = era5_SA_window.isel(time=slice(0, n_window_times - LEAD_STEPS))
target_sliced_era5_SA = era5_SA_window.isel(time=slice(LEAD_STEPS, None))  # skip the first LEAD_STEPS steps

In [15]:
# Split the feature view and the target view into Aurora's three variable
# groups (surface, atmospheric, static).
surf_vars_ds_SA, target_surf_vars_ds_SA = (
    sliced_era5_SA[surface_vars],
    target_sliced_era5_SA[surface_vars],
)
atmos_vars_ds_SA, target_atmos_vars_ds_SA = (
    sliced_era5_SA[atmostpheric_variables],
    target_sliced_era5_SA[atmostpheric_variables],
)
static_vars_ds_SA, target_static_vars_ds_SA = (
    sliced_era5_SA[static_variables],
    target_sliced_era5_SA[static_variables],
)

In [16]:

class ERA5ZarrDataset(Dataset):
    """Map-style dataset that turns xarray ERA5 views into Aurora ``Batch``es.

    Item ``idx`` uses time positions ``[i - 1, i]``, with
    ``i = sequence_length + idx``, for the surface and atmospheric variables.
    It adds the static fields as stored and stamps the batch with the time of
    position ``i``.

    Args:
        surf_vars_ds: dataset holding the variables in ``SURF_VARS`` plus
            ``time``, ``latitude`` and ``longitude`` coordinates.
        atmos_vars_ds: dataset holding the variables in ``ATMOS_VARS`` plus a
            ``level`` coordinate.
        static_vars_ds: dataset holding the variables in ``STATIC_VARS``.
        sequence_length: index of the first usable time position. The history
            window is always the two positions ``[i - 1, i]``, whatever this
            value is.
    """

    # Aurora short variable name -> ERA5/WeatherBench2 variable name.
    SURF_VARS = {
        "2t": "2m_temperature",
        "10u": "10m_u_component_of_wind",
        "10v": "10m_v_component_of_wind",
        "msl": "mean_sea_level_pressure",
    }
    STATIC_VARS = {
        "z": "geopotential_at_surface",
        "slt": "soil_type",
        "lsm": "land_sea_mask",
    }
    ATMOS_VARS = {
        "t": "temperature",
        "u": "u_component_of_wind",
        "v": "v_component_of_wind",
        "q": "specific_humidity",
        "z": "geopotential",
    }

    def __init__(self, surf_vars_ds, atmos_vars_ds, static_vars_ds, sequence_length):
        self.surf_vars_ds = surf_vars_ds
        self.atmos_vars_ds = atmos_vars_ds
        self.static_vars_ds = static_vars_ds
        self.sequence_length = sequence_length
        self.time_indices = range(sequence_length, len(surf_vars_ds.time))

    def __len__(self):
        return len(self.time_indices)

    @staticmethod
    def _window_tensors(ds, names, window):
        """Return ``{short_name: tensor}`` for ``ds`` at the given time positions.

        Each tensor gains a leading batch axis of size 1.
        """
        # Select the time positions *before* calling .values. Indexing .values
        # first would materialise the whole time axis of every variable on each
        # call, which is a full read of the remote store per sample.
        selected = ds.isel(time=window)
        return {
            short: torch.from_numpy(selected[name].values[None])
            for short, name in names.items()
        }

    def __getitem__(self, idx):
        # An out-of-range idx raises IndexError from the range object; that is
        # also what ends plain `for ... in dataset` iteration.
        i = self.time_indices[idx]
        window = [i - 1, i]

        surf_vars = self._window_tensors(self.surf_vars_ds, self.SURF_VARS, window)
        static_vars = {
            short: torch.from_numpy(self.static_vars_ds[name].values)
            for short, name in self.STATIC_VARS.items()
        }
        atmos_vars = self._window_tensors(self.atmos_vars_ds, self.ATMOS_VARS, window)

        metadata = Metadata(
            lat=torch.from_numpy(self.surf_vars_ds.latitude.values),
            lon=torch.from_numpy(self.surf_vars_ds.longitude.values),
            # Convert only position i (not the whole time axis) to a datetime.
            time=(self.surf_vars_ds.time.values[i].astype("datetime64[s]").item(),),
            atmos_levels=tuple(int(level) for level in self.atmos_vars_ds.level.values),
        )

        return Batch(surf_vars=surf_vars, static_vars=static_vars, atmos_vars=atmos_vars, metadata=metadata)


In [17]:
# Feature and target datasets are built the same way. With sequence_length=1,
# sample k reads time positions [k, k + 1] of its source.
SA_batches, target_SA_batches = (
    ERA5ZarrDataset(surf_vars_ds_SA, atmos_vars_ds_SA, static_vars_ds_SA, sequence_length=1),
    ERA5ZarrDataset(target_surf_vars_ds_SA, target_atmos_vars_ds_SA, target_static_vars_ds_SA, sequence_length=1),
)

# Fine-tuning process

In [29]:
import torch.optim as optim
import torch.nn as nn

In [39]:
class LoRA(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank update (LoRA).

    Computes ``original_layer(x) + scaling * (x @ A.T) @ B.T``, where ``A`` has
    shape ``(rank, in_features)`` and ``B`` has shape ``(out_features, rank)``.
    ``B`` starts at zero, so the wrapped layer initially reproduces
    ``original_layer`` exactly and fine-tuning starts from the pretrained
    behaviour.

    Args:
        original_layer: the ``nn.Linear`` to adapt. Its own parameters are not
            touched here; freeze them separately if desired.
        rank: inner dimension of the low-rank update.
        alpha: optional LoRA scaling numerator; the update is multiplied by
            ``alpha / rank``. ``None`` (the default) applies no scaling.
    """

    def __init__(self, original_layer, rank=4, alpha=None):
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.scaling = 1.0 if alpha is None else alpha / rank
        # Create the adapter on the wrapped layer's device and dtype so the
        # matmuls in forward() do not mix devices or precisions.
        weight = original_layer.weight
        factory = {"device": weight.device, "dtype": weight.dtype}
        # A: random init with the same scheme nn.Linear uses for its weight.
        self.lora_A = nn.Parameter(torch.empty(rank, original_layer.in_features, **factory))
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        # B: zeros, so the update B @ A is exactly zero at initialisation.
        self.lora_B = nn.Parameter(torch.zeros(original_layer.out_features, rank, **factory))

    def forward(self, x):
        update = (x @ self.lora_A.T) @ self.lora_B.T
        return self.original_layer(x) + self.scaling * update


In [40]:
def apply_lora_to_model(model, rank=4):
    """Recursively replace every ``nn.Linear`` inside ``model`` with ``LoRA``.

    Children are replaced in place (``setattr`` on their parent), and ``model``
    is returned for convenience. Modules that are already ``LoRA`` wrappers are
    skipped. Calling this twice therefore neither nests adapters
    (``LoRA(LoRA(Linear))``) nor adds a second set of trainable parameters.

    Args:
        model: root ``nn.Module`` to modify.
        rank: rank passed to each new ``LoRA`` wrapper.

    Returns:
        The same ``model`` object.
    """
    for name, child in model.named_children():
        if isinstance(child, LoRA):
            continue  # already adapted; don't wrap its inner Linear again
        if isinstance(child, nn.Linear):
            setattr(model, name, LoRA(child, rank))
        else:
            apply_lora_to_model(child, rank)
    return model


# Wrap the Linear layers only if no LoRA adapter is present yet, so re-running
# this cell does not nest LoRA(LoRA(Linear)) wrappers.
if not any(isinstance(module, LoRA) for module in model.modules()):
    model = apply_lora_to_model(model, rank=4)


In [41]:
# The previous cell already applies LoRA. Guard so this repeated call is a
# no-op instead of wrapping each adapter's inner Linear a second time.
if not any(isinstance(module, LoRA) for module in model.modules()):
    model = apply_lora_to_model(model, rank=4)

In [42]:
print_trainable_parameters(model)

2153600
trainable parameters: 2,153,600/114,951,184 (1.87%)


In [None]:
import torch.optim as optim
import torch.nn as nn

# First fine-tuning attempt: walk the feature and target datasets in lockstep
# and take one Adam step per sample pair.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

model.train()
for epoch in range(1):
    print("start")
    # zip stops at the shorter dataset; plain iteration over a Dataset calls
    # __getitem__ with 0, 1, 2, ... until IndexError.
    # NOTE(review): the saved output shows only "start", i.e. the first sample
    # never finished loading in that run.
    for inputs, targets in zip(SA_batches, target_SA_batches):
        print("OK")
        optimizer.zero_grad()
        outputs = model(inputs)
        # FIXME(review): ERA5ZarrDataset yields aurora Batch objects, so both
        # arguments here are Batches, not tensors; nn.MSELoss presumably fails.
        # Compute the loss on tensors (e.g. outputs.surf_vars[...] vs
        # targets.surf_vars[...]) after confirming shapes and time alignment.
        loss = criterion(outputs, targets)
        print(loss)
        loss.backward()
        optimizer.step()


start


In [None]:
# Persist the full state dict: the wrapped base weights plus every LoRA
# lora_A / lora_B parameter.
torch.save(model.state_dict(), "aurora_lora_finetuned.pth")


In [45]:
# Fresh loss function and optimizer for the manual training loops below.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:

# Second fine-tuning attempt: build feature/target batches for each time window
# with the utils helpers and take one Adam step per window.
# NOTE(review): windows are chosen by time *label* from target_sliced_era5_SA,
# so the 2-step offset between sliced_era5_SA and target_sliced_era5_SA does
# not shift anything here. The feature/target lead comes only from the
# indices i vs i + 2.
sa_latitudes = target_sliced_era5_SA.latitude
sa_longitudes = target_sliced_era5_SA.longitude

selected_times =  target_sliced_era5_SA.time
sa_rmses_list=[]
# Stop 3 short so selected_times[i+3] stays in range.
for i in range(0, len(target_sliced_era5_SA.time)-3):
    # get current and previous time step data
    # NOTE(review): world_feature_data is never used below.
    world_feature_data =  (
            target_sliced_era5_SA
            .sel(time=slice(selected_times[i], selected_times[i+1]))
        )
    sa_feature_data =  (
            sliced_era5_SA
            .sel(time=slice(selected_times[i], selected_times[i+1]))
        )

    sa_target_data =  (
            target_sliced_era5_SA
            .sel(time=slice(selected_times[i+2], selected_times[i+3]))
        )
    
    # get each type of data(surface, static atmosphere)

    sa_feature_surface_data, sa_target_surface_data = get_surface_feature_target_data(sa_feature_data, sa_target_data)
    sa_feature_atmos_data, sa_target_atmos_data = get_atmos_feature_target_data(sa_feature_data, sa_target_data)
    sa_feature_static_data, sa_target_static_data = get_static_feature_target_data(sa_feature_data, sa_target_data)
    
    # create batch for each of them

    sa_feature_bacth =  create_batch(sa_feature_surface_data, sa_feature_atmos_data, sa_feature_static_data)
    sa_target_bacth = create_batch(sa_target_surface_data, sa_target_atmos_data, sa_target_static_data)
    # get prediction
    # sa_predictions = predict_fn(batch=sa_feature_bacth)
    # # compute the rmse
    
    # sa_rmses, sa_pred_dates = rmse_fn(predictions=sa_predictions, 
    #         target_batch=sa_target_bacth, var_name="2t",
    #         weigths=sa_rmse_weights, area="sa")
    # # append result to the list
    # world_rmses_list.append(world_rmses); pred_dates_list.append(world_pred_dates)
    # sa_rmses_list.append(sa_rmses)
    print("Start")
    
    optimizer.zero_grad()
    outputs = model(sa_feature_bacth)
    # FIXME(review): outputs and sa_target_bacth are presumably aurora Batch
    # objects (create_batch output), not tensors, so nn.MSELoss is unlikely to
    # work. Compare tensors (e.g. from .surf_vars["2t"]) once their shapes and
    # time alignment are confirmed.
    loss = criterion(outputs, sa_target_bacth)
    print(loss)
    loss.backward()
    optimizer.step()
    
# model.configure_activation_checkpointing()

# pred = model.forward(batch)
# loss = ...
# loss.backward()

In [48]:

# Third fine-tuning attempt: same per-window batches as above, but the loss is
# the latitude-weighted RMSE of 2 m temperature from utils.rmse_fn.
sa_latitudes = target_sliced_era5_SA.latitude
sa_longitudes = target_sliced_era5_SA.longitude
sa_rmse_weights = rmse_weights(sa_latitudes, sa_longitudes)
selected_times =  target_sliced_era5_SA.time
sa_rmses_list=[]
# Stop 3 short so selected_times[i+3] stays in range.
for i in range(0, len(target_sliced_era5_SA.time)-3):
    # get current and previous time step data
    # NOTE(review): world_feature_data is never used below.
    world_feature_data =  (
            target_sliced_era5_SA
            .sel(time=slice(selected_times[i], selected_times[i+1]))
        )
    sa_feature_data =  (
            sliced_era5_SA
            .sel(time=slice(selected_times[i], selected_times[i+1]))
        )

    sa_target_data =  (
            target_sliced_era5_SA
            .sel(time=slice(selected_times[i+2], selected_times[i+3]))
        )
    
    # get each type of data(surface, static atmosphere)

    sa_feature_surface_data, sa_target_surface_data = get_surface_feature_target_data(sa_feature_data, sa_target_data)
    sa_feature_atmos_data, sa_target_atmos_data = get_atmos_feature_target_data(sa_feature_data, sa_target_data)
    sa_feature_static_data, sa_target_static_data = get_static_feature_target_data(sa_feature_data, sa_target_data)
    
    # create batch for each of them

    sa_feature_bacth =  create_batch(sa_feature_surface_data, sa_feature_atmos_data, sa_feature_static_data)
    sa_target_bacth = create_batch(sa_target_surface_data, sa_target_atmos_data, sa_target_static_data)
    # get prediction
    # sa_predictions = predict_fn(batch=sa_feature_bacth)
    # # compute the rmse
    
    # sa_rmses, sa_pred_dates = rmse_fn(predictions=sa_predictions, 
    #         target_batch=sa_target_bacth, var_name="2t",
    #         weigths=sa_rmse_weights, area="sa")
    # # append result to the list
    # world_rmses_list.append(world_rmses); pred_dates_list.append(world_pred_dates)
    # sa_rmses_list.append(sa_rmses)
    print("Start")
    


    # NOTE(review): these three setup calls are loop-invariant but run on every
    # iteration; move them above the loop. Calling
    # configure_activation_checkpointing() repeatedly may re-wrap blocks that
    # are already checkpointed. TODO confirm.
    # NOTE(review): the model is moved to the GPU but sa_feature_bacth is not.
    # Confirm that create_batch/Aurora handle device placement.
    model = model.cuda()
    model.train()
    model.configure_activation_checkpointing()

    pred = model.forward(sa_feature_bacth)
    # FIXME: the first value returned by rmse_fn is a list (saved traceback:
    # "'list' object has no attribute 'backward'"), so loss.backward() fails.
    # Reduce it to a scalar tensor, and make sure rmse_fn keeps the autograd
    # graph, before backpropagating.
    # NOTE(review): there is no optimizer.zero_grad()/optimizer.step() in this
    # loop, so gradients would accumulate and the weights never update.
    loss , _ = rmse_fn(predictions=[pred], 
             target_batch=sa_target_bacth, var_name="2t",
             weigths=sa_rmse_weights, area="sa")
    loss.backward()

Start


AttributeError: 'list' object has no attribute 'backward'

In [None]:
# RMSE of 2 m temperature ("2t") over South Africa for the last `pred` produced
# by the loop above.
# NOTE(review): the loop passes predictions=[pred] (a list) while this cell
# passes the bare pred. Confirm which form rmse_fn expects.
sa_rmses, _ = rmse_fn(predictions=pred, 
             target_batch=sa_target_bacth, var_name="2t",
             weigths=sa_rmse_weights, area="sa")