In [1]:
import sys
sys.path.append('../src')


In [2]:
from lora import LoRA


In [3]:
import xarray as xr
from datetime import datetime

import torch

from aurora import AuroraSmall, Batch, Metadata, rollout
import matplotlib.pyplot as plt

from pathlib import Path

import cdsapi
import numpy as np
from sklearn.metrics import root_mean_squared_error
import gcsfs

from torch.utils.data import Dataset
from aurora import Batch, Metadata
import os

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import sys
sys.path.append(os.path.abspath("../src"))
from utils import get_surface_feature_target_data, get_atmos_feature_target_data
from utils import get_static_feature_target_data, create_batch, predict_fn, rmse_weights
from utils import rmse_fn, plot_rmses, custom_rmse

In [5]:
from batch import ERA5ZarrDataset

In [6]:
fs = gcsfs.GCSFileSystem(token="anon")

store = fs.get_mapper('gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr')
full_era5 = xr.open_zarr(store=store, consolidated=True, chunks=None)

In [7]:
start_time, end_time = '2022-12-01', '2023-01-31'

In [8]:

lat_max = -22.00 
lat_min = -37.75  

lon_min = 15.25   
lon_max = 35.00   
sliced_era5 = (
    full_era5
    .sel(
        time=slice(start_time, end_time),
        latitude=slice(lat_max, lat_min),
        longitude=slice(lon_min, lon_max)  
    )
)



In [9]:
training_data = ERA5ZarrDataset(sliced_era5)

In [10]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=8, pin_memory=True, shuffle=False)

In [11]:
data = iter(train_dataloader)

In [None]:
train_inputs, train_labels = next(data)


In [None]:
train_inputs

# Load model

In [5]:
# from aurora import AuroraSmall

# model = AuroraSmall(
#     use_lora=False,  # Model was not fine-tuned.
#     autocast=True,  # Use AMP.
# )
# # model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")


In [6]:
# torch.save(model.state_dict(), "../model/aurora-pretrained.pth")

In [7]:
model = AuroraSmall(
    use_lora=False,  # Model was not fine-tuned.
    autocast=True,  # Use AMP.
)
model.load_state_dict(torch.load('../model/aurora-pretrained.pth'))

<All keys matched successfully>

In [8]:
def print_trainable_parameters(model):
    parameters, trainable = 0, 0
    
    for _, p in model.named_parameters():
        parameters += p.numel()
        trainable += p.numel() if p.requires_grad else 0
    print(trainable)
    print(f"trainable parameters: {trainable:,}/{parameters:,} ({100 * trainable / parameters:.2f}%)")

In [9]:
print_trainable_parameters(model)

112797584
trainable parameters: 112,797,584/112,797,584 (100.00%)


In [10]:
for param in model.parameters():
    param.requires_grad = False


In [11]:
print_trainable_parameters(model)

0
trainable parameters: 0/112,797,584 (0.00%)


In [10]:
for param in model.backbone.time_mlp.parameters():
    param.requires_grad = True


In [11]:
print_trainable_parameters(model)

131584
trainable parameters: 131,584/112,797,584 (0.12%)


In [14]:
for param in model.backbone.encoder_layers[0].blocks[0].norm1.ln_modulation.parameters():
    param.requires_grad = True

In [15]:
print_trainable_parameters(model)

263168
trainable parameters: 263,168/112,797,584 (0.23%)


# Get south africa Data

In [12]:
fs = gcsfs.GCSFileSystem(token="anon")

store = fs.get_mapper('gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr')
full_era5 = xr.open_zarr(store=store, consolidated=True, chunks=None)

In [13]:
start_time, end_time = '2022-12-01', '2023-01-31'


In [14]:


atmostpheric_variables = ["temperature", "u_component_of_wind", "v_component_of_wind", "specific_humidity", "geopotential"]
surface_vars = ['2m_temperature', '10m_u_component_of_wind', '10m_v_component_of_wind', 'mean_sea_level_pressure']
static_variables = ["land_sea_mask", "soil_type", "geopotential_at_surface"]


In [15]:

lat_max = -22.00 
lat_min = -37.75  

lon_min = 15.25   
lon_max = 35.00   
sliced_era5_SA = (
    full_era5
    .sel(
        time=slice(start_time, end_time),
        latitude=slice(lat_max, lat_min),
        longitude=slice(lon_min, lon_max)  
    )
    .isel(time=slice(None, -2))
)

target_sliced_era5_SA = (
    full_era5
    .sel(
        time=slice(start_time, end_time),
        latitude=slice(lat_max, lat_min),
        longitude=slice(lon_min, lon_max)  
    )  
    .isel(time=slice(2, None))  # Skip the first two time steps
)

In [16]:
# surf_vars_ds_SA = sliced_era5_SA[surface_vars]

# target_surf_vars_ds_SA = target_sliced_era5_SA[surface_vars]

# atmos_vars_ds_SA = sliced_era5_SA[atmostpheric_variables]

# target_atmos_vars_ds_SA = target_sliced_era5_SA[atmostpheric_variables]

# static_vars_ds_SA = sliced_era5_SA[static_variables]

# target_static_vars_ds_SA = target_sliced_era5_SA[static_variables]

In [17]:

# class ERA5ZarrDataset(Dataset):
#     def __init__(self, surf_vars_ds, atmos_vars_ds, static_vars_ds, sequence_length):
#         self.surf_vars_ds = surf_vars_ds
#         self.atmos_vars_ds = atmos_vars_ds
#         self.static_vars_ds = static_vars_ds
#         self.sequence_length = sequence_length
#         self.time_indices = range(sequence_length, len(surf_vars_ds.time))

#     def __len__(self):
#         return len(self.time_indices)

#     def __getitem__(self, idx):
#         i = self.time_indices[idx]

#         surf_vars = {
#             "2t": torch.from_numpy(self.surf_vars_ds["2m_temperature"].values[[i - 1, i]][None]),
#             "10u": torch.from_numpy(self.surf_vars_ds["10m_u_component_of_wind"].values[[i - 1, i]][None]),
#             "10v": torch.from_numpy(self.surf_vars_ds["10m_v_component_of_wind"].values[[i - 1, i]][None]),
#             "msl": torch.from_numpy(self.surf_vars_ds["mean_sea_level_pressure"].values[[i - 1, i]][None]),
#         }

#         static_vars = {
#             "z": torch.from_numpy(self.static_vars_ds["geopotential_at_surface"].values),
#             "slt": torch.from_numpy(self.static_vars_ds["soil_type"].values),
#             "lsm": torch.from_numpy(self.static_vars_ds["land_sea_mask"].values),
#         }

#         atmos_vars = {
#             "t": torch.from_numpy(self.atmos_vars_ds["temperature"].values[[i - 1, i]][None]),
#             "u": torch.from_numpy(self.atmos_vars_ds["u_component_of_wind"].values[[i - 1, i]][None]),
#             "v": torch.from_numpy(self.atmos_vars_ds["v_component_of_wind"].values[[i - 1, i]][None]),
#             "q": torch.from_numpy(self.atmos_vars_ds["specific_humidity"].values[[i - 1, i]][None]),
#             "z": torch.from_numpy(self.atmos_vars_ds["geopotential"].values[[i - 1, i]][None]),
#         }

#         metadata=Metadata(
#         lat=torch.from_numpy(self.surf_vars_ds.latitude.values),
#         lon=torch.from_numpy(self.surf_vars_ds.longitude.values),
#         time=(self.surf_vars_ds.time.values.astype("datetime64[s]").tolist()[i],),
#         atmos_levels=tuple(int(level) for level in self.atmos_vars_ds.level.values)
#     )


#         return Batch(surf_vars=surf_vars, static_vars=static_vars, atmos_vars=atmos_vars, metadata=metadata)


In [18]:
# SA_batches = ERA5ZarrDataset(surf_vars_ds_SA, atmos_vars_ds_SA, static_vars_ds_SA,1)
# target_SA_batches = ERA5ZarrDataset(target_surf_vars_ds_SA, target_atmos_vars_ds_SA, target_static_vars_ds_SA,1)

# Fine tuning process

In [16]:
import torch.optim as optim
import torch.nn as nn

In [17]:
class LoRA(nn.Module):
    def __init__(self, original_layer, rank=4):
        super(LoRA, self).__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.lora_A = nn.Parameter(torch.randn(rank, original_layer.in_features))
        self.lora_B = nn.Parameter(torch.randn(original_layer.out_features, rank))

    def forward(self, x):
        return self.original_layer(x) + (x @ self.lora_A.T) @ self.lora_B.T


In [18]:
def apply_lora_to_model(model, rank=4):
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, LoRA(module, rank))
        else:
            apply_lora_to_model(module, rank)
    return model


model = apply_lora_to_model(model, rank=4)


In [19]:
model = apply_lora_to_model(model, rank=4)

In [20]:
print_trainable_parameters(model)

2153600
trainable parameters: 2,153,600/114,951,184 (1.87%)


In [21]:
# import torch.optim as optim
# import torch.nn as nn

# optimizer = optim.Adam(model.parameters(), lr=1e-4)
# criterion = nn.MSELoss()

# model.train()
# for epoch in range(1):
#     print("start")
#     for inputs, targets in zip(SA_batches, target_SA_batches):
#         print("OK")
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, targets)
#         print(loss)
#         loss.backward()
#         optimizer.step()


In [22]:
# torch.save(model.state_dict(), "aurora_lora_finetuned.pth")


In [23]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# criterion = nn.MSELoss()

In [24]:
# sa_latitudes, sa_longitudes = sliced_era5_SA.latitude, sliced_era5_SA.longitude

# # Compute RMSE weights
# sa_rmse_weights = rmse_weights(sa_latitudes, sa_longitudes)

In [25]:
device = "cuda"

# HERE

In [26]:
from loss import AuroraLoss

In [27]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# criterion = nn.MSELoss()

In [28]:
criterion = AuroraLoss()

In [29]:

# def training(model, criterion, num_epochs,
#              dataset=None,
#              dataset_name="ERA5", 
#              accumulation_steps=8
#              ):
#     selected_times = dataset.time
#     loss_list=[]
#     for epoch in range(num_epochs):
#         model.train()
#         for i in range(0, len(selected_times)-3):
#             # get current and previous time step data

#             sa_feature_data =  (
#                     dataset
#                     .sel(time=slice(selected_times[i], selected_times[i+1]))
#                 )

#             sa_target_data =  (
#                     dataset
#                     .sel(time=slice(selected_times[i+2], selected_times[i+3]))
#                 )
            
#             # get each type of data(surface, static atmosphere)

#             sa_feature_surface_data, sa_target_surface_data = get_surface_feature_target_data(sa_feature_data, sa_target_data)
#             sa_feature_atmos_data, sa_target_atmos_data = get_atmos_feature_target_data(sa_feature_data, sa_target_data)
#             sa_feature_static_data, sa_target_static_data = get_static_feature_target_data(sa_feature_data, sa_target_data)
            
#             # create batch for each of them

#             input =  create_batch(sa_feature_surface_data, sa_feature_atmos_data, sa_feature_static_data)
#             target = create_batch(sa_target_surface_data, sa_target_atmos_data, sa_target_static_data)
            
#             print("Start")
                        
#             # Forward pass
#             outputs = model(input)
#             loss = criterion(outputs, target, dataset_name)
#             loss = loss / accumulation_steps  # Normalize loss
#             print(loss.detach().numpy())
            
#             # Backward pass
#             loss.backward()
            
#             # Update weights and reset gradients every accumulation_steps
#             if (i + 1) % accumulation_steps == 0:
#                 optimizer.step()
#                 optimizer.zero_grad()
        
#         # Handle remaining gradients if dataset size is not divisible by accumulation_steps
#         if (i + 1) % accumulation_steps != 0:
#             optimizer.step()
#             optimizer.zero_grad()
        
#         loss_list.append(loss.detach().numpy())
        
        
        
#         print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
        
#     return model, loss_list


    
    
    
    
#     # optimizer.zero_grad()
#     # outputs = model(sa_feature_bacth)
#     # print("Prediction done")
#     # loss = criterion(outputs, sa_target_bacth, "ERA5")
#     # print(loss)
#     # loss.backward()
#     # optimizer.step()
    


In [30]:
# training(model, criterion, num_epochs=1,
#              dataset=sliced_era5_SA)

In [32]:
from train import training

training(model, criterion,
             num_epochs=1,  optimizer=optimizer, dataset = sliced_era5_SA)

Start
86885710.0
Start
95813010.0
Start
86252984.0
Start
92127160.0
Start
97635850.0
Start
102172870.0
Start
102170990.0
Start
95612216.0
Start
99822640.0
Start
99889336.0
Start
91303560.0
Start
100477960.0


KeyboardInterrupt: 

# Evaluation

In [None]:
model.eval()
val_running_loss = 0.0


with torch.no_grad():
        for i in range(0, len(target_sliced_era5_SA.time)-3):
        # get current and previous time step data

            sa_feature_data =  (
                    sliced_era5_SA
                    .sel(time=slice(selected_times[i], selected_times[i+1]))
                )

            sa_target_data =  (
                    target_sliced_era5_SA
                    .sel(time=slice(selected_times[i+2], selected_times[i+3]))
                )
            
            # get each type of data(surface, static atmosphere)

            sa_feature_surface_data, sa_target_surface_data = get_surface_feature_target_data(sa_feature_data, sa_target_data)
            sa_feature_atmos_data, sa_target_atmos_data = get_atmos_feature_target_data(sa_feature_data, sa_target_data)
            sa_feature_static_data, sa_target_static_data = get_static_feature_target_data(sa_feature_data, sa_target_data)
            
            # create batch for each of them

            input =  create_batch(sa_feature_surface_data, sa_feature_atmos_data, sa_feature_static_data)
            target = create_batch(sa_target_surface_data, sa_target_atmos_data, sa_target_static_data)
            
            
            
            print("Start")
            
        

            
            # inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(input)
            loss = criterion(outputs, target, "ERA5")
            loss = loss / accumulation_steps  # Normalize loss
            loss_list.append(loss.detach().numpy())
            print(loss)


print(f'Validation Loss: {val_epoch_loss:.4f}, Validation Accuracy: {val_epoch_accuracy:.4f}')


In [30]:

sa_latitudes = target_sliced_era5_SA.latitude
sa_longitudes = target_sliced_era5_SA.longitude

selected_times =  target_sliced_era5_SA.time
sa_rmses_list=[]
for i in range(0, len(target_sliced_era5_SA.time)-3):
    # get current and previous time step data
    world_feature_data =  (
            target_sliced_era5_SA
            .sel(time=slice(selected_times[i], selected_times[i+1]))
        )
    sa_feature_data =  (
            sliced_era5_SA
            .sel(time=slice(selected_times[i], selected_times[i+1]))
        )

    sa_target_data =  (
            target_sliced_era5_SA
            .sel(time=slice(selected_times[i+2], selected_times[i+3]))
        )
    
    # get each type of data(surface, static atmosphere)

    sa_feature_surface_data, sa_target_surface_data = get_surface_feature_target_data(sa_feature_data, sa_target_data)
    sa_feature_atmos_data, sa_target_atmos_data = get_atmos_feature_target_data(sa_feature_data, sa_target_data)
    sa_feature_static_data, sa_target_static_data = get_static_feature_target_data(sa_feature_data, sa_target_data)
    
    # create batch for each of them

    sa_feature_bacth =  create_batch(sa_feature_surface_data, sa_feature_atmos_data, sa_feature_static_data)
    sa_target_bacth = create_batch(sa_target_surface_data, sa_target_atmos_data, sa_target_static_data)
    
    
    target_tensor = sa_target_bacth.surf_vars["2t"].squeeze()[1,:,:]
    
    print("Start")
    
    optimizer.zero_grad()
    outputs = model(sa_feature_bacth)
    output_tensor =  outputs.surf_vars["2t"][0, 0]
    print("Prediction done")
    loss = custom_rmse(target_tensor, output_tensor, sa_rmse_weights)
    print(loss)
    loss.backward()
    optimizer.step()
    


Start
Prediction done
tensor(80327112., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(73870288., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(71279416., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(66979388., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(76102832., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(67231240., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(58614704., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(55923936., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(53904288., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(52293720., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(53105628., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(49706236., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(47634212., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(41055844., grad_fn=<SqrtBackward0>)
Start
Prediction done
tensor(38299132., grad_fn=<SqrtBackward0>)
Start
Prediction done
ten

KeyboardInterrupt: 

In [48]:

sa_latitudes = target_sliced_era5_SA.latitude
sa_longitudes = target_sliced_era5_SA.longitude
sa_rmse_weights = rmse_weights(sa_latitudes, sa_longitudes)
selected_times =  target_sliced_era5_SA.time
sa_rmses_list=[]
for i in range(0, len(target_sliced_era5_SA.time)-3):
    # get current and previous time step data
    world_feature_data =  (
            target_sliced_era5_SA
            .sel(time=slice(selected_times[i], selected_times[i+1]))
        )
    sa_feature_data =  (
            sliced_era5_SA
            .sel(time=slice(selected_times[i], selected_times[i+1]))
        )

    sa_target_data =  (
            target_sliced_era5_SA
            .sel(time=slice(selected_times[i+2], selected_times[i+3]))
        )
    
    # get each type of data(surface, static atmosphere)

    sa_feature_surface_data, sa_target_surface_data = get_surface_feature_target_data(sa_feature_data, sa_target_data)
    sa_feature_atmos_data, sa_target_atmos_data = get_atmos_feature_target_data(sa_feature_data, sa_target_data)
    sa_feature_static_data, sa_target_static_data = get_static_feature_target_data(sa_feature_data, sa_target_data)
    
    # create batch for each of them

    sa_feature_bacth =  create_batch(sa_feature_surface_data, sa_feature_atmos_data, sa_feature_static_data)
    sa_target_bacth = create_batch(sa_target_surface_data, sa_target_atmos_data, sa_target_static_data)
    # get prediction
    # sa_predictions = predict_fn(batch=sa_feature_bacth)
    # # compute the rmse
    
    # sa_rmses, sa_pred_dates = rmse_fn(predictions=sa_predictions, 
    #         target_batch=sa_target_bacth, var_name="2t",
    #         weigths=sa_rmse_weights, area="sa")
    # # append result to the list
    # world_rmses_list.append(world_rmses); pred_dates_list.append(world_pred_dates)
    # sa_rmses_list.append(sa_rmses)
    print("Start")
    


    model = model.cuda()
    model.train()
    model.configure_activation_checkpointing()

    pred = model.forward(sa_feature_bacth)
    loss , _ = rmse_fn(predictions=[pred], 
             target_batch=sa_target_bacth, var_name="2t",
             weigths=sa_rmse_weights, area="sa")
    loss.backward()

Start


AttributeError: 'list' object has no attribute 'backward'

In [None]:
sa_rmses, _ = rmse_fn(predictions=pred, 
             target_batch=sa_target_bacth, var_name="2t",
             weigths=sa_rmse_weights, area="sa")