# SETUP

In [None]:
!pip install torch einops numpy timm==0.6.13 scipy gcsfs cdsapi xarray zarr netcdf4 matplotlib

In [None]:
# Move the kernel into the project checkout (presumably so that `import aurora`
# below resolves to this repo) and pull the latest code before importing it.
%cd /workspace/aurora_229s
!git pull

In [None]:
import importlib
from pathlib import Path
import datetime
import numpy as np
import torch
import gc

In [None]:
from aurora import inference_helper
from aurora.model import aurora, swin3d

def reload():
    """Re-import the project modules after editing them (e.g. after ``git pull``).

    ``importlib.reload`` re-executes a module in place. Names that another
    module bound via ``from x import y`` keep pointing at the *old* objects
    until that importing module is itself reloaded. The modules are therefore
    reloaded dependencies-first: ``swin3d``, then ``aurora``, then
    ``inference_helper``. Reloading independent modules is order-insensitive,
    so this order is never worse than the reverse.

    Objects that already exist (e.g. ``model``) are not rebuilt. Re-create them
    after calling this to pick up new code.
    """
    # NOTE(review): assumes aurora.model.aurora imports from swin3d and that
    # inference_helper imports from aurora.model -- confirm against the repo.
    for module in (swin3d, aurora, inference_helper):
        importlib.reload(module)

In [None]:
def gpu_mem(msg):
    """Print CUDA memory counters for device 0, in GB, under the heading ``msg``.

    Prints the currently allocated and reserved memory plus the peak reserved
    memory, then a blank line.
    """
    bytes_per_gb = 1024 * 1024 * 1024
    counters = (
        ("memory_allocated", torch.cuda.memory_allocated),
        ("memory_reserved", torch.cuda.memory_reserved),
        ("max_memory_reserved", torch.cuda.max_memory_reserved),
    )
    print(f'{msg}:')
    for label, read_counter in counters:
        print("\ttorch.cuda.%s: %fGB" % (label, read_counter(0) / bytes_per_gb))
    print()

def print_timestamp():
    """Print the current local wall-clock time as ``YYYY-MM-DD HH:MM:SS``."""
    print(f"{datetime.datetime.now():%Y-%m-%d %H:%M:%S}")

In [None]:
# Variable names as (short_name, long_name) pairs.
# The short name is the key into Batch.surf_vars / Batch.atmos_vars
# (e.g. preds.surf_vars['2t']). The long name is used as the per-task output
# directory name under the Fisher save directory.
surf_vars_names = list({
    '2t': '2m_temperature',
    '10u': '10m_u_component_of_wind',
    '10v': '10m_v_component_of_wind',
    'msl': 'mean_sea_level_pressure',
}.items())

# Static variables use the same string for both names.
static_vars_names = [(name, name) for name in ('z', 'slt', 'lsm')]

atmos_vars_names = list({
    't': 'temperature',
    'u': 'u_component_of_wind',
    'v': 'v_component_of_wind',
    'q': 'specific_humidity',
    'z': 'geopotential',
}.items())

# Candidate prediction targets. Static variables are not included.
# Note that 'z' appears both as a static and as an atmospheric short name.
all_vars_names = surf_vars_names + atmos_vars_names

In [None]:
# Build Aurora and load the fine-tuned 0.25-degree checkpoint from the
# "microsoft/aurora" checkpoint repository.
checkpoint_repo, checkpoint_file = "microsoft/aurora", "aurora-0.25-finetuned.ckpt"
model = aurora.Aurora()
model.load_checkpoint(checkpoint_repo, checkpoint_file)
# Offline alternative, loading a local copy of the same checkpoint:
# model.load_checkpoint_local(
#     "/workspace/models/hf_ckpt/aurora-0.25-finetuned.ckpt"
# )

# Presumably trades recomputation for lower activation memory during backward.
model.configure_activation_checkpointing()

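# Inference loop: diagonal Fisher information estimates

For each objective, run the model over the training dates and backpropagate the MAE loss. The squared gradients of the backbone encoder/decoder parameters are averaged over all batches, giving a diagonal empirical Fisher estimate. Results are written under `/workspace/models/fisher`.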

### Individual tasks

In [None]:
from aurora.download_data import download_for_day

# Dates the batcher starts from. May and August are held out for testing.
base_date_list = ["2022-02-01", "2022-04-01", "2022-07-01", "2022-11-01"]
# base_date_list = ["2022-02-01", "2022-04-01"]#, "2022-07-01", "2022-11-01"]

base_save_dir = Path("/workspace/models/fisher")
base_save_dir.mkdir(exist_ok=True, parents=True)
# device = xm.xla_device()
device = 'cuda'

# Tasks (short_name, long_name) to estimate. `all_vars_names` is the full set.
# This slice skips 2t, 10u and t (presumably resuming a partial run).
tasks = surf_vars_names[2:] + atmos_vars_names[1:]
# tasks = all_vars_names
# tasks = atmos_vars_names

# The backbone encoder/decoder layers are wrapped so they run with autograd
# enabled. Their parameters are the ones whose squared gradients we collect.
bel_model = inference_helper.BackboneEncoderLayers(model.backbone.encoder_layers).to(device)
bdl_model = inference_helper.BackboneDecoderLayers(
    model.backbone.decoder_layers, model.backbone.num_decoder_layers
).to(device)
decoder = model.decoder.to(device)
decoder.eval()

# MAJOR LOOP -- one Fisher estimate per task
for sh_var, lh_var in tasks:
    print(sh_var, lh_var)
    print_timestamp()
    print('\n')

    # Surface targets live in Batch.surf_vars, all others in Batch.atmos_vars.
    var_group = 'surf_vars' if (sh_var, lh_var) in surf_vars_names else 'atmos_vars'

    cnt = 0
    mae_losses = []
    # Running sums of squared gradients, kept on CPU: {module_key: {param_name: tensor}}.
    grads = {'backbone_encoder': {}, 'backbone_decoder': {}}
    download_for_day(day=base_date_list[0], download_path=Path("/workspace/data"))
    batcher = inference_helper.InferenceBatcher(
        base_date_list=base_date_list,
        data_path=Path("/workspace/data"),
        max_n_days=14,
    )

    while True:
        # Clear gradients from the previous batch. The decoder runs with
        # autograd enabled, so its parameters also receive gradients from
        # loss.backward(). Without zeroing, they would accumulate across batches.
        bel_model.zero_grad()
        bdl_model.zero_grad()
        decoder.zero_grad()
        try:
            batch, labels = batcher.get_batch()
        except Exception as e:
            # Best-effort: any batcher failure ends this task, keeping the
            # batches processed so far.
            print('\n', e, '\n')
            break
        if batch is None or labels is None:
            break

        rollout_step = batch.metadata.rollout_step
        batch = inference_helper.preprocess_batch(model=model, batch=batch, device=device)
        torch.cuda.empty_cache()

        # Match the labels to the model's dtype and to its patch-aligned grid.
        labels = labels.type(next(model.parameters()).dtype).crop(model.patch_size)

        # Encoder and backbone preparation are frozen: no autograd graph is recorded.
        with torch.no_grad():
            x, patch_res = inference_helper.encoder_forward(model=model, batch=batch, device=device)
        batch = batch.to('cpu')
        torch.cuda.empty_cache()

        with torch.no_grad():
            c, all_enc_res, padded_outs = inference_helper.backbone_prep(
                model=model, x=x, patch_res=patch_res, device=device
            )
        torch.cuda.empty_cache()

        # Backbone encoder/decoder layers run with autograd enabled.
        x, skips = bel_model.forward(x=x, c=c, all_enc_res=all_enc_res, rollout_step=rollout_step)
        torch.cuda.empty_cache()

        x = bdl_model.forward(
            x=x, skips=skips, c=c, all_enc_res=all_enc_res,
            padded_outs=padded_outs, rollout_step=rollout_step,
        )
        del skips, c, all_enc_res, padded_outs
        torch.cuda.empty_cache()

        preds = inference_helper.decoder_forward(
            decoder=decoder, x=x, batch=batch, patch_res=patch_res, surf_stats=model.surf_stats,
        )
        del x, batch, patch_res

        # Single-task objective: mean absolute error (as in the paper) on index
        # [0, 0] of this task's variable.
        task_pred = getattr(preds, var_group)[sh_var][0, 0]
        ref = getattr(labels, var_group)[sh_var][0, 0].to(device)
        loss = torch.mean(torch.abs(task_pred - ref))
        loss.backward()

        # Accumulate squared gradients on the CPU.
        for key, module in (('backbone_decoder', bdl_model), ('backbone_encoder', bel_model)):
            for name, param in module.named_parameters():
                if param.grad is None:
                    # No gradient for this parameter: it contributes zero.
                    continue
                sq_grad = param.grad.detach().square().to('cpu')
                if name in grads[key]:
                    grads[key][name] += sq_grad
                else:
                    grads[key][name] = sq_grad

        cnt += 1
        mae_losses.append(loss.detach().to('cpu').numpy())
        del preds, task_pred, ref, loss
        gc.collect()
        torch.cuda.empty_cache()

    # Average the sums over the number of batches and save one file per parameter.
    if cnt == 0:
        print(f'WARNING: no batches processed for {lh_var}; no gradient files written')
    task_root = base_save_dir / lh_var
    for key in ['backbone_encoder', 'backbone_decoder']:
        task_dir = task_root / key
        task_dir.mkdir(parents=True, exist_ok=True)
        for name, sq_grad_sum in grads[key].items():
            torch.save(sq_grad_sum / float(cnt), task_dir / f'{name}.pt')
    # Per-batch losses are saved at the task root, next to the backbone_* dirs
    # (same layout as the multitask cell).
    np.save(task_root / 'LOSSES.npy', np.array(mae_losses))

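### Multitask

A single Fisher estimate whose loss is a weighted sum of the per-variable MAE losses over all surface and atmospheric variables.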

In [None]:
from aurora.download_data import download_for_day

# Repeat with multi-task objective: weighted sum of per-variable MAE losses.
# Entries are (short_name, long_name, loss_weight).
# NOTE(review): the weights appear to match the per-variable loss weights in
# the Aurora paper -- confirm.
surf_vars_names_wts = [
    ('2t', '2m_temperature', 3.0),
    ('10u', '10m_u_component_of_wind', 0.77),
    ('10v', '10m_v_component_of_wind', 0.66),
    ('msl', 'mean_sea_level_pressure', 1.5),
]
atmos_vars_names_wts = [
    ('t', 'temperature', 1.7),
    ('u', 'u_component_of_wind', 0.87),
    ('v', 'v_component_of_wind', 0.6),
    ('q', 'specific_humidity', 0.78),
    ('z', 'geopotential', 2.8)
]

# Dates the batcher starts from. May and August are held out for testing.
base_date_list = ["2022-02-01", "2022-04-01", "2022-07-01", "2022-11-01"]
# base_date_list = ["2022-02-01", "2022-04-01"]#, "2022-07-01", "2022-11-01"]

base_save_dir = Path("/workspace/models/fisher")
base_save_dir.mkdir(exist_ok=True, parents=True)
# device = xm.xla_device()
device = 'cuda'

# The backbone encoder/decoder layers are wrapped so they run with autograd
# enabled. Their parameters are the ones whose squared gradients we collect.
bel_model = inference_helper.BackboneEncoderLayers(model.backbone.encoder_layers).to(device)
bdl_model = inference_helper.BackboneDecoderLayers(
    model.backbone.decoder_layers, model.backbone.num_decoder_layers
).to(device)
decoder = model.decoder.to(device)
decoder.eval()

# MAJOR LOOP -- MTL
cnt = 0
mae_losses = []
# Running sums of squared gradients, kept on CPU: {module_key: {param_name: tensor}}.
grads = {'backbone_encoder': {}, 'backbone_decoder': {}}
download_for_day(day=base_date_list[0], download_path=Path("/workspace/data"))
batcher = inference_helper.InferenceBatcher(
    base_date_list=base_date_list,
    data_path=Path("/workspace/data"),
    max_n_days=14,
)

while True:
    # Clear gradients from the previous batch. The decoder runs with autograd
    # enabled, so its parameters also receive gradients from loss.backward().
    # Without zeroing, they would accumulate across batches.
    bel_model.zero_grad()
    bdl_model.zero_grad()
    decoder.zero_grad()
    try:
        batch, labels = batcher.get_batch()
    except Exception as e:
        # Best-effort: any batcher failure ends the run, keeping the batches
        # processed so far.
        print('\n', e, '\n')
        break
    if batch is None or labels is None:
        break

    rollout_step = batch.metadata.rollout_step
    batch = inference_helper.preprocess_batch(model=model, batch=batch, device=device)
    torch.cuda.empty_cache()

    # Match the labels to the model's dtype and to its patch-aligned grid.
    labels = labels.type(next(model.parameters()).dtype).crop(model.patch_size)

    # Encoder and backbone preparation are frozen: no autograd graph is recorded.
    with torch.no_grad():
        x, patch_res = inference_helper.encoder_forward(model=model, batch=batch, device=device)
    batch = batch.to('cpu')
    torch.cuda.empty_cache()

    with torch.no_grad():
        c, all_enc_res, padded_outs = inference_helper.backbone_prep(
            model=model, x=x, patch_res=patch_res, device=device
        )
    torch.cuda.empty_cache()

    # Backbone encoder/decoder layers run with autograd enabled.
    x, skips = bel_model.forward(x=x, c=c, all_enc_res=all_enc_res, rollout_step=rollout_step)
    torch.cuda.empty_cache()

    x = bdl_model.forward(
        x=x, skips=skips, c=c, all_enc_res=all_enc_res,
        padded_outs=padded_outs, rollout_step=rollout_step,
    )
    del skips, c, all_enc_res, padded_outs
    torch.cuda.empty_cache()

    preds = inference_helper.decoder_forward(
        decoder=decoder, x=x, batch=batch, patch_res=patch_res, surf_stats=model.surf_stats,
    )
    del x, batch, patch_res

    # Weighted multi-task MAE over every listed variable (surface first, then
    # atmospheric). Terms are summed as tensors so they stay attached to the
    # autograd graph. Packing them into torch.tensor([...]) would detach them.
    loss = None
    for var_group, names_wts in (('surf_vars', surf_vars_names_wts),
                                 ('atmos_vars', atmos_vars_names_wts)):
        for var, _, wt in names_wts:
            pred = getattr(preds, var_group)[var][0, 0]
            ref = getattr(labels, var_group)[var][0, 0].to(device)
            term = wt * torch.mean(torch.abs(pred - ref))
            loss = term if loss is None else loss + term
    loss.backward()

    # Accumulate squared gradients on the CPU.
    for key, module in (('backbone_decoder', bdl_model), ('backbone_encoder', bel_model)):
        for name, param in module.named_parameters():
            if param.grad is None:
                # No gradient for this parameter: it contributes zero.
                continue
            sq_grad = param.grad.detach().square().to('cpu')
            if name in grads[key]:
                grads[key][name] += sq_grad
            else:
                grads[key][name] = sq_grad

    cnt += 1
    mae_losses.append(loss.detach().to('cpu').numpy())
    del preds, pred, ref, term, loss
    gc.collect()
    torch.cuda.empty_cache()

# Average the sums over the number of batches and save one file per parameter.
if cnt == 0:
    print('WARNING: no batches processed for multitask; no gradient files written')
task_root = base_save_dir / 'multitask'
for key in ['backbone_encoder', 'backbone_decoder']:
    task_dir = task_root / key
    task_dir.mkdir(parents=True, exist_ok=True)
    for name, sq_grad_sum in grads[key].items():
        torch.save(sq_grad_sum / float(cnt), task_dir / f'{name}.pt')
# Per-batch losses are saved at the task root, next to the backbone_* dirs.
np.save(task_root / 'LOSSES.npy', np.array(mae_losses))