# SETUP

In [None]:
!pip install torch einops numpy timm==0.6.13 scipy gcsfs cdsapi xarray zarr netcdf4 matplotlib

In [None]:
!cd /workspace/aurora_229s
!git pull

In [None]:
import importlib
from pathlib import Path
import datetime
import numpy as np
import torch
import gc

In [None]:
from aurora import inference_helper
from aurora.model import aurora, swin3d

def reload():
    """Re-import the project modules so on-disk edits take effect without a kernel restart.

    Modules are reloaded in the order: inference_helper, aurora, swin3d.
    """
    # NOTE(review): if `aurora` binds names imported from `swin3d`, reloading
    # `swin3d` first may be required for those names to refresh -- confirm.
    for module in (inference_helper, aurora, swin3d):
        importlib.reload(module)

In [None]:
def gpu_mem(msg):
    """Print a labelled snapshot of CUDA memory statistics for device 0.

    Args:
        msg: Label printed (followed by a colon) above the statistics.

    The byte counts reported by torch are divided by 1024**3 before printing,
    and the output is terminated by a blank line.
    """
    bytes_per_unit = 1024 ** 3
    report = (
        ("\ttorch.cuda.memory_allocated: %fGB", torch.cuda.memory_allocated),
        ("\ttorch.cuda.memory_reserved: %fGB", torch.cuda.memory_reserved),
        ("\ttorch.cuda.max_memory_reserved: %fGB", torch.cuda.max_memory_reserved),
    )

    print(f'{msg}:')
    for template, query in report:
        print(template % (query(0) / bytes_per_unit))
    print()

def print_timestamp():
    """Print the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

In [None]:
# Variable names, stored as (short_name, long_name) pairs.
# The main loop uses the short name to index predictions/labels and the long
# name as the output directory name.

# Surface-level variables.
surf_vars_names = [
    ('2t', '2m_temperature'),
    ('10u', '10m_u_component_of_wind'),
    ('10v', '10m_v_component_of_wind'),
    ('msl', 'mean_sea_level_pressure'),
]

# Static variables: short and long names are identical.
static_vars_names = [(name, name) for name in ('z', 'slt', 'lsm')]

# Atmospheric variables.
atmos_vars_names = [
    ('t', 'temperature'),
    ('u', 'u_component_of_wind'),
    ('v', 'v_component_of_wind'),
    ('q', 'specific_humidity'),
    ('z', 'geopotential'),
]

# Every variable that gets its own pass in the main loop: surface first, then
# atmospheric. Static variables are not included.
all_vars_names = [*surf_vars_names, *atmos_vars_names]

In [None]:
# Instantiate the model and restore its weights from a checkpoint on local disk.
model = aurora.Aurora()

# Alternative source for the same checkpoint name (kept for reference):
# model.load_checkpoint("microsoft/aurora", "aurora-0.25-finetuned.ckpt")
model.load_checkpoint_local("/workspace/models/hf_ckpt/aurora-0.25-finetuned.ckpt")

# Enabled before the gradient computation below.
# NOTE(review): presumably trades recomputation for lower activation memory -- confirm in aurora.
model.configure_activation_checkpointing()

# Inference loop

In [None]:
# Base dates handed to the batcher. 2022-05-01 is intentionally left out so
# that May stays available for testing.
base_date_list = [
    "2022-02-01",
    "2022-04-01",
    "2022-07-01",
    "2022-11-01",
]

# Root directory for the per-variable outputs; created up front if missing.
base_save_dir = Path("/workspace/models/fisher")
base_save_dir.mkdir(parents=True, exist_ok=True)

device = 'cuda'

# Put the model in training mode, then move its weights to the target device.
model.train()
model = model.to(device)

# MAJOR LOOP -- one pass per target variable (surface AND atmospheric).
# For each variable, the squared gradient of the MAE loss w.r.t. every model
# parameter is summed over all batches, then the per-parameter mean is saved
# under base_save_dir/<long_name>/ together with the per-batch losses.
# NOTE(review): given the output directory name this is presumably a diagonal
# empirical-Fisher estimate -- confirm.
for sh_var, lh_var in all_vars_names:
    print(sh_var, lh_var)
    print_timestamp()
    print('\n')

    # Loop-invariant: decides whether predictions/labels are read from
    # `.surf_vars` or `.atmos_vars` for this variable.
    is_surf_var = (sh_var, lh_var) in surf_vars_names

    cnt = 0          # number of batches processed for this variable
    mae_losses = []  # per-batch loss values
    grads = {}       # parameter name -> running sum of squared gradients (on CPU)
    batcher = inference_helper.InferenceBatcher(
        base_date_list=base_date_list,
        data_path=Path("/workspace/data"),
        max_n_days=14,
    )

    while True:
        model.zero_grad() # Critical to zero-out gradients
        batch, labels = batcher.get_batch()
        if batch is None or labels is None:
            # The batcher signals exhaustion by returning None.
            break
        print(batcher.day, batcher.time_idx - 1)

        # Match the batch/labels to the model's dtype and patch size.
        p = next(model.parameters())
        batch = batch.type(p.dtype)
        batch = batch.crop(model.patch_size)
        batch = batch.to(p.device)

        labels = labels.type(p.dtype)
        labels = labels.crop(model.patch_size)

        # preds = model.forward(batch)
        preds = torch.utils.checkpoint.checkpoint(model.forward, batch, use_reentrant=False)

        if is_surf_var:
            task_pred = preds.surf_vars[sh_var][0, 0]
            ref = labels.surf_vars[sh_var][0, 0].to(p.device)
        else:
            task_pred = preds.atmos_vars[sh_var][0, 0]
            ref = labels.atmos_vars[sh_var][0, 0].to(p.device)

        # Paper uses mean absolute error
        loss = torch.mean(torch.abs(task_pred - ref))
        loss.backward()

        for name, param in model.named_parameters():
            if param.grad is None:
                # No gradient reached this parameter for this loss (e.g. it is
                # frozen or unused); it contributes zero to the running sum.
                continue
            # Move to CPU before squaring so no extra GPU memory is allocated.
            sq_grad = torch.square(param.grad.detach().to('cpu'))
            # Key on presence in `grads` rather than on `cnt`, so a parameter
            # that first receives a gradient in a later batch is still handled.
            if name in grads:
                grads[name] += sq_grad
            else:
                grads[name] = sq_grad

        cnt += 1
        mae_losses.append(loss.detach().to('cpu').numpy())
        # Drop every reference to this batch's tensors before asking the
        # allocator to release cached GPU memory.
        del preds, task_pred, ref, batch, labels, loss
        gc.collect()
        torch.cuda.empty_cache()

    # finished with loop
    task_dir = base_save_dir / lh_var
    task_dir.mkdir(parents=True, exist_ok=True)

    # Mean of squared gradients over the processed batches. If no batch was
    # produced, `grads` is empty and nothing is divided or written here.
    for name, sq_grad_sum in grads.items():
        torch.save(sq_grad_sum / float(cnt), task_dir / f'{name}.pt')
    np.save(task_dir / 'LOSSES.npy', np.array(mae_losses))