In [1]:
import xarray as xr
from datetime import datetime

import torch

from aurora import AuroraSmall, Batch, Metadata, rollout
import matplotlib.pyplot as plt

from pathlib import Path

import cdsapi
import numpy as np
from sklearn.metrics import root_mean_squared_error
import gcsfs

from torch.utils.data import Dataset
from aurora import Batch, Metadata
import os

  from .autonotebook import tqdm as notebook_tqdm


# Download the pretrained model and save its weights

In [2]:
from aurora import Aurora

# Build the small Aurora variant and fetch the published pretrained checkpoint.
model = AuroraSmall(
    use_lora=False,  # Model was not fine-tuned.
    autocast=True,  # Use AMP.
    stabilise_level_agg=True,
)
# strict=False is forwarded to the checkpoint loader so key mismatches do not raise.
# NOTE(review): presumably required because stabilise_level_agg=True adds parameters
# that the published checkpoint does not contain — confirm against the Aurora docs.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt", strict=False)

# torch.save does not create missing parent directories; create the target
# directory first so this cell also works on a fresh checkout.
os.makedirs("../model", exist_ok=True)
torch.save(model.state_dict(), "../model/aurora-0.25-small-pretrained.pth")

In [3]:
import sys
sys.path.append(os.path.abspath("../src"))
from utils import get_surface_feature_target_data, get_atmos_feature_target_data
from utils import get_static_feature_target_data, create_batch, predict_fn, rmse_weights
from utils import rmse_fn, plot_rmses, custom_rmse

# Load Model

In [4]:
# Rebuild the architecture with the same constructor flags that were used when
# the state dict was saved by the download cell above (including
# stabilise_level_agg=True); load_state_dict is strict by default, so the
# module layout has to match the saved keys exactly.
model = AuroraSmall(
    use_lora=False,  # Model was not fine-tuned.
    autocast=True,  # Use AMP.
    stabilise_level_agg=True,  # Must match the configuration used when saving.
)
# Load the exact file written by the download cell so that a fresh
# "Restart Kernel -> Run All" does not depend on a file created outside
# this notebook.
model.load_state_dict(torch.load("../model/aurora-0.25-small-pretrained.pth"))

<All keys matched successfully>

# Freeze all weights

In [5]:
# Turn off gradient tracking for every parameter currently registered on the
# model; parameters added to the model afterwards are not affected.
# The trailing semicolon suppresses the module repr that the call returns.
model.requires_grad_(False);

In [6]:
# Print the full module tree of the loaded model
# (presumably to locate the layers that are wrapped with LoRA in later cells).
print(model)

Aurora(
  (encoder): Perceiver3DEncoder(
    (surf_mlp): MLP(
      (net): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=256, bias=True)
        (3): Dropout(p=0.0, inplace=False)
      )
    )
    (surf_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (pos_embed): Linear(in_features=256, out_features=256, bias=True)
    (scale_embed): Linear(in_features=256, out_features=256, bias=True)
    (lead_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (absolute_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (atmos_levels_embed): Linear(in_features=256, out_features=256, bias=True)
    (surf_token_embeds): LevelPatchEmbed(
      (weights): ParameterDict(
          (10u): Parameter containing: [torch.FloatTensor of size 256x1x2x4x4]
          (10v): Parameter containing: [torch.FloatTensor of size 256x1x2x4x4]
    

In [8]:
# List the sub-modules of the backbone's time MLP, one per line, in order.
for sublayer in model.backbone.time_mlp:
    print(sublayer)

Linear(in_features=256, out_features=256, bias=True)
SiLU()
Linear(in_features=256, out_features=256, bias=True)


# LoRA class

In [10]:
class LoRALayer(torch.nn.Module):
    """Low-rank additive update computing ``alpha * (x @ A @ B)``.

    Args:
        in_dim: Size of the last dimension of the input (rows of ``A``).
        out_dim: Size of the last dimension of the output (columns of ``B``).
        rank: Inner dimension shared by ``A`` and ``B``.
        alpha: Scalar multiplier applied to the low-rank product.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # A is drawn from a standard normal and scaled by 1/sqrt(rank).
        init_scale = 1 / torch.tensor(rank).float().sqrt()
        self.A = torch.nn.Parameter(init_scale * torch.randn(in_dim, rank))
        # B starts at zero, so the layer outputs zeros until B is updated.
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        """Return the scaled low-rank projection of ``x``."""
        low_rank = torch.matmul(torch.matmul(x, self.A), self.B)
        return self.alpha * low_rank

In [11]:
class LinearWithLoRA(torch.nn.Module):
    """Wrap a linear layer and add a LoRA update to its output.

    Args:
        linear: Module exposing ``in_features`` and ``out_features``; it is
            kept as ``self.linear`` and called unchanged in ``forward``.
        rank: Rank forwarded to the internal ``LoRALayer``.
        alpha: Scaling factor forwarded to the internal ``LoRALayer``.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        # The adapter maps the same input/output sizes as the wrapped layer.
        self.lora = LoRALayer(
            in_dim=linear.in_features,
            out_dim=linear.out_features,
            rank=rank,
            alpha=alpha,
        )

    def forward(self, x):
        """Return ``linear(x) + lora(x)``."""
        base_out = self.linear(x)
        return base_out + self.lora(x)

In [12]:
from functools import partial
# Default LoRA hyperparameter choices.
lora_r = 8  # rank passed to LinearWithLoRA
lora_alpha = 16  # alpha (scaling factor) passed to LinearWithLoRA
# Placeholder options below are disabled and not read by any visible cell.
# lora_dropout = 0.05
# lora_query = True
# lora_key = False
# lora_value = True
# lora_projection = False
# lora_mlp = False
# lora_head = False

# layers = []

# Factory with the hyperparameters pre-bound at this point:
# assign_lora(linear) == LinearWithLoRA(linear, rank=lora_r, alpha=lora_alpha).
assign_lora = partial(LinearWithLoRA, rank=lora_r, alpha=lora_alpha)

# Add LoRA to selected layers

# Backbone MLP

In [16]:
# Wrap the modules at positions 0 and 2 of the backbone time MLP (the two
# Linear layers in the listing printed earlier) with LoRA adapters.
time_mlp = model.backbone.time_mlp
for position in (0, 2):
    time_mlp[position] = assign_lora(time_mlp[position])

# Encoder

In [20]:
# Wrap the module at index 1 of norm1's ln_modulation in the first block of
# the first backbone encoder layer.
first_block = model.backbone.encoder_layers[0].blocks[0]
norm1_modulation = first_block.norm1.ln_modulation
norm1_modulation[1] = assign_lora(norm1_modulation[1])



In [24]:
# Wrap the attention input (qkv) and output (proj) projections of the first
# block in the first backbone encoder layer, in that order.
attn = model.backbone.encoder_layers[0].blocks[0].attn
attn.qkv = assign_lora(attn.qkv)
attn.proj = assign_lora(attn.proj)

In [None]:
# Idempotent wrapping of the attention projections of the first block in the
# first backbone encoder layer: only plain torch.nn.Linear layers are wrapped.
# A layer that is already a LinearWithLoRA is left untouched, because
# re-wrapping it would fail (LinearWithLoRA does not define the
# in_features / out_features attributes that the wrapper reads).
attn = model.backbone.encoder_layers[0].blocks[0].attn
for attr_name in ("qkv", "proj"):
    layer = getattr(attn, attr_name)
    if isinstance(layer, torch.nn.Linear):
        setattr(attn, attr_name, assign_lora(layer))

In [28]:
# Wrap the module at index 1 of norm2's ln_modulation in the first block of
# the first backbone encoder layer.
first_block = model.backbone.encoder_layers[0].blocks[0]
norm2_modulation = first_block.norm2.ln_modulation
norm2_modulation[1] = assign_lora(norm2_modulation[1])

Linear(in_features=256, out_features=512, bias=True)

In [29]:
# Display the `fc1` layer of the MLP in the first block of the first backbone
# encoder layer.
# NOTE(review): the recorded output shows the whole parent MLP module, so it
# appears stale relative to this expression — re-run to confirm.
model.backbone.encoder_layers[0].blocks[0].mlp.fc1

MLP(
  (fc1): Linear(in_features=256, out_features=1024, bias=True)
  (act): GELU(approximate='none')
  (fc2): Linear(in_features=1024, out_features=256, bias=True)
  (drop): Dropout(p=0.0, inplace=False)
)