# SETUP

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install torch einops numpy timm==0.6.13 scipy gcsfs cdsapi xarray zarr netcdf4 matplotlib

In [None]:
# Work from the project checkout and pull its latest commits before the `aurora` imports below.
%cd /workspace/aurora_229s
!git pull

In [None]:
import importlib
from pathlib import Path
import datetime
import numpy as np
import torch
import gc

In [None]:
from aurora import inference_helper
from aurora.model import aurora, swin3d

def reload():
    """Re-import the project modules after their source files have been edited.

    Modules are reloaded dependency-first (swin3d -> aurora -> inference_helper) so
    that each module reloaded later re-binds the freshly reloaded objects of the
    modules it imports, instead of the stale pre-reload ones.
    NOTE(review): this ordering assumes aurora imports from swin3d and
    inference_helper imports from the model modules -- confirm if imports change.

    Objects created before the reload (e.g. an existing ``model`` instance) keep
    using the old class definitions; rebuild them after calling this.
    """
    importlib.reload(swin3d)
    importlib.reload(aurora)
    importlib.reload(inference_helper)

In [None]:
def gpu_mem(msg):
    """Print CUDA memory counters for device 0, in GB, under the heading ``msg``.

    Reports currently allocated memory, currently reserved memory and the peak
    reserved memory, then prints a blank separator line.
    """
    readings = (
        ("\ttorch.cuda.memory_allocated: %fGB", torch.cuda.memory_allocated),
        ("\ttorch.cuda.memory_reserved: %fGB", torch.cuda.memory_reserved),
        ("\ttorch.cuda.max_memory_reserved: %fGB", torch.cuda.max_memory_reserved),
    )
    print(f'{msg}:')
    for template, counter in readings:
        print(template % (counter(0) / 1024 / 1024 / 1024))
    print()

def print_timestamp():
    """Print the current local date and time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(stamp)

In [None]:
# Variable names as (short name, long name) pairs.
# Surface variables: short name -> long name (dict order is preserved).
surf_vars_names = list({
    '2t': '2m_temperature',
    '10u': '10m_u_component_of_wind',
    '10v': '10m_v_component_of_wind',
    'msl': 'mean_sea_level_pressure',
}.items())
# Static fields: the short and long names are identical.
static_vars_names = [(name, name) for name in ('z', 'slt', 'lsm')]
# Atmospheric variables: short name -> long name.
atmos_vars_names = list({
    't': 'temperature',
    'u': 'u_component_of_wind',
    'v': 'v_component_of_wind',
    'q': 'specific_humidity',
    'z': 'geopotential',
}.items())

# Surface followed by atmospheric pairs (static fields are not included).
all_vars_names = surf_vars_names + atmos_vars_names

In [None]:
# Instantiate the small Aurora variant and load the pretrained weights
# "aurora-0.25-small-pretrained.ckpt" from the "microsoft/aurora" checkpoint repo.
model = aurora.AuroraSmall()
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")
# Alternative: load a fine-tuned checkpoint from local disk instead.
# model.load_checkpoint_local(
#     "/workspace/models/hf_ckpt/aurora-0.25-finetuned.ckpt"
# )
# Enable activation checkpointing (presumably to cut activation memory for the
# backward passes below -- confirm in aurora.model.aurora).
model.configure_activation_checkpointing()

In [None]:
from pathlib import Path
import datetime
import numpy as np
import torch
import gc
from aurora.download_data import download_for_day
import dataclasses
import contextlib


In [None]:
from aurora import evaluation_helper
# Clean up the shared download directory before the runs below. What exactly is
# removed is defined by evaluation_helper.cleanup_download_dir (presumably stale data files).
evaluation_helper.cleanup_download_dir(download_path=Path("/workspace/data"))

# Inference loop: accumulate squared backbone gradients (Fisher estimates)

### Individual tasks (one objective per variable)

In [None]:
import contextlib
from datetime import timedelta
from copy import deepcopy

# Per-variable objective: estimate the backbone's mean squared gradients (Fisher
# statistics) separately for the loss of each selected variable.
surf_vars_names_wts, atmos_vars_names_wts = inference_helper.get_vars_names_wts()

# May and August are held out for testing (they are not in this list).
base_date_list = ["2022-02-01", "2022-04-01", "2022-07-01", "2022-11-01"]
# base_date_list = ["2022-02-01", "2022-04-01"]#, "2022-07-01", "2022-11-01"]

base_save_dir = Path("/workspace/models/fisher")
base_save_dir.mkdir(exist_ok=True, parents=True)
# device = xm.xla_device()
device = 'cuda'

# Short variable names that are skipped entirely.
sh_exclude = ['msl', 'z', 'q']

# Work on copies so the original `model` parameters and their .grad stay untouched.
backbone = deepcopy(model.backbone).to(device)
decoder = deepcopy(model.decoder).to(device)
# decoder.eval()

# MAJOR LOOP -- one task (variable) at a time
surf_sh_list = [sh for sh,_,_ in surf_vars_names_wts]
# NOTE(review): fixed index ranges into the helper's lists; this relies on the
# ordering returned by get_vars_names_wts -- confirm if that helper changes.
all_vars_list = [surf_vars_names_wts[0], atmos_vars_names_wts[0]] + surf_vars_names_wts[1:3] + atmos_vars_names_wts[1:3]
for (sh,lh,wt) in all_vars_list:
    if sh in sh_exclude:
        continue

    cnt = 0          # number of batches whose gradients were accumulated
    mae_losses = []  # per-batch weighted MAE for this task
    grads = {}       # parameter name -> running sum of squared gradients (on CPU)
    download_for_day(day=base_date_list[0], download_path=Path("/workspace/data"))
    batcher = inference_helper.InferenceBatcher(
        base_date_list=base_date_list,
        data_path=Path("/workspace/data"),
        max_n_days=7#was 14
    )

    while True:
        # Clear the previous batch's gradients so each squared gradient is per-batch.
        backbone.zero_grad()
        decoder.zero_grad()
        try:
            batch, labels = batcher.get_batch()
        except Exception as e:
            # Any batcher error ends this task's pass over the data.
            print('\n', e, '\n')
            break
        if batch is None or labels is None:
            break

        rollout_step = batch.metadata.rollout_step
        batch = inference_helper.preprocess_batch(model=model, batch=batch, device=device, norm=True)
        torch.cuda.empty_cache()

        # Match the labels' dtype to the model parameters and crop to the model patch size.
        p = next(model.parameters())
        labels = labels.type(p.dtype)
        labels = labels.crop(model.patch_size)

        # 2. Encoder forward (no gradients needed).
        with torch.no_grad():
            x, patch_res = inference_helper.encoder_forward(model=model, batch=batch, device=device)
        torch.cuda.empty_cache()

        # 3. Backbone preparation (no gradients); its outputs are not used below.
        with torch.no_grad():
            c, all_enc_res, padded_outs = inference_helper.backbone_prep(backbone=backbone, x=x, patch_res=patch_res, device=device)
        torch.cuda.empty_cache()

        # 4. Backbone forward with autograd enabled so its gradients can be collected.
        with torch.autocast(device_type="cuda") if model.autocast else contextlib.nullcontext():
            x = backbone(
                x,
                lead_time=timedelta(hours=6),
                patch_res=patch_res,
                rollout_step=batch.metadata.rollout_step,
            )

        # 5. Decoder forward pass.
        preds = inference_helper.decoder_forward(
            decoder=decoder, x=x, batch=batch, patch_res=patch_res, surf_stats=model.surf_stats,
        )
        del x, batch, patch_res

        # Weighted MAE of this task's variable on element [0, 0] of the leading two
        # dims (presumably batch and time -- confirm).
        if sh in surf_sh_list:
            loss = wt * torch.mean(torch.abs(preds.surf_vars[sh][0,0] - labels.surf_vars[sh][0,0].to(device)))
        else:
            loss = wt * torch.mean(torch.abs(preds.atmos_vars[sh][0,0] - labels.atmos_vars[sh][0,0].to(device)))
        if cnt < 5:
            print(loss)
        if bool(torch.any(torch.isnan(loss))):
            # Skip NaN batches, but still release this batch's graph and activations.
            print(f'[{sh}] NaN loss; skipping batch')
            del loss, preds
            gc.collect()
            torch.cuda.empty_cache()
            continue
        loss.backward()

        # Accumulate squared gradients on CPU. A parameter that received no gradient
        # (grad is None) contributes zero for this batch instead of crashing.
        for name, param in backbone.named_parameters():
            if param.grad is None:
                continue
            sq_grad = torch.square(param.grad.detach().to('cpu'))
            if name in grads:
                grads[name] += sq_grad
            else:
                grads[name] = sq_grad

        cnt += 1
        mae_losses.append(loss.detach().to('cpu').numpy())
        del preds
        gc.collect()
        torch.cuda.empty_cache()

    # Save the per-parameter mean squared gradient and the per-batch losses.
    task_dir = base_save_dir / lh
    task_dir.mkdir(parents=True, exist_ok=True)

    if cnt == 0:
        print(f'[{sh}] no batches were processed; no gradient statistics to save')
    for name,param in grads.items():
        torch.save(param / float(cnt), task_dir / f'{name}.pt')
    np.save(task_dir / 'LOSSES.npy', np.array(mae_losses))

### Multitask (weighted sum over all non-excluded variables)

In [None]:
import contextlib
from datetime import timedelta
from copy import deepcopy

# Repeat with multi-task objective: one mean-squared-gradient (Fisher) estimate for
# the weighted sum of every non-excluded variable's loss.
surf_vars_names_wts, atmos_vars_names_wts = inference_helper.get_vars_names_wts()

# May and August are held out for testing (they are not in this list).
base_date_list = ["2022-02-01", "2022-04-01", "2022-07-01", "2022-11-01"]
# base_date_list = ["2022-02-01", "2022-04-01"]#, "2022-07-01", "2022-11-01"]

base_save_dir = Path("/workspace/models/fisher")
base_save_dir.mkdir(exist_ok=True, parents=True)
# device = xm.xla_device()
device = 'cuda'

# Short variable names left out of the multitask loss.
sh_exclude = ['msl', 'z', 'q']

# (short name, weight) of every variable that contributes to the loss. Every
# variable, including the first surface variable, is filtered through sh_exclude.
active_surf = [(sh_var, wt) for sh_var, _, wt in surf_vars_names_wts if sh_var not in sh_exclude]
active_atmos = [(sh_var, wt) for sh_var, _, wt in atmos_vars_names_wts if sh_var not in sh_exclude]
if not (active_surf or active_atmos):
    raise ValueError('sh_exclude removes every variable; the multitask loss would be empty')

# Work on copies so the original `model` parameters and their .grad stay untouched.
backbone = deepcopy(model.backbone).to(device)
decoder = deepcopy(model.decoder).to(device)
# decoder.eval()

# MAJOR LOOP -- MTL
cnt = 0          # number of batches whose gradients were accumulated
mae_losses = []  # per-batch weighted multitask MAE
grads = {}       # parameter name -> running sum of squared gradients (on CPU)
download_for_day(day=base_date_list[0], download_path=Path("/workspace/data"))
batcher = inference_helper.InferenceBatcher(
    base_date_list=base_date_list,
    data_path=Path("/workspace/data"),
    max_n_days=7
)

while True:
    # Clear the previous batch's gradients so each squared gradient is per-batch.
    backbone.zero_grad()
    decoder.zero_grad()
    try:
        batch, labels = batcher.get_batch()
    except Exception as e:
        # Any batcher error ends the pass over the data.
        print('\n', e, '\n')
        break
    if batch is None or labels is None:
        break

    rollout_step = batch.metadata.rollout_step
    batch = inference_helper.preprocess_batch(model=model, batch=batch, device=device, norm=True)
    torch.cuda.empty_cache()

    # Match the labels' dtype to the model parameters and crop to the model patch size.
    p = next(model.parameters())
    labels = labels.type(p.dtype)
    labels = labels.crop(model.patch_size)

    # 2. Encoder forward (no gradients needed).
    with torch.no_grad():
        x, patch_res = inference_helper.encoder_forward(model=model, batch=batch, device=device)
    torch.cuda.empty_cache()

    # 3. Backbone preparation (no gradients); its outputs are not used below.
    with torch.no_grad():
        c, all_enc_res, padded_outs = inference_helper.backbone_prep(backbone=backbone, x=x, patch_res=patch_res, device=device)
    torch.cuda.empty_cache()

    # 4. Backbone forward with autograd enabled so its gradients can be collected.
    with torch.autocast(device_type="cuda") if model.autocast else contextlib.nullcontext():
        x = backbone(
            x,
            lead_time=timedelta(hours=6),
            patch_res=patch_res,
            rollout_step=batch.metadata.rollout_step,
        )

    # 5. Decoder forward pass.
    preds = inference_helper.decoder_forward(
        decoder=decoder, x=x, batch=batch, patch_res=patch_res, surf_stats=model.surf_stats,
    )
    del x, batch, patch_res

    # Weighted MAE summed over all active variables, each on element [0, 0] of the
    # leading two dims (presumably batch and time -- confirm). Terms are added in
    # list order: surface variables first, then atmospheric ones.
    loss = None
    for sh_var, wt in active_surf:
        term = wt * torch.mean(torch.abs(preds.surf_vars[sh_var][0,0] - labels.surf_vars[sh_var][0,0].to(device)))
        loss = term if loss is None else loss + term
    for sh_var, wt in active_atmos:
        term = wt * torch.mean(torch.abs(preds.atmos_vars[sh_var][0,0] - labels.atmos_vars[sh_var][0,0].to(device)))
        loss = term if loss is None else loss + term
    if cnt < 5:
        print(loss)
    if bool(torch.any(torch.isnan(loss))):
        # Skip NaN batches, but still release this batch's graph and activations.
        print('NaN multitask loss; skipping batch')
        del loss, preds
        gc.collect()
        torch.cuda.empty_cache()
        continue

    loss.backward()

    # Accumulate squared gradients on CPU. A parameter that received no gradient
    # (grad is None) contributes zero for this batch instead of crashing.
    for name, param in backbone.named_parameters():
        if param.grad is None:
            continue
        sq_grad = torch.square(param.grad.detach().to('cpu'))
        if name in grads:
            grads[name] += sq_grad
        else:
            grads[name] = sq_grad

    cnt += 1
    mae_losses.append(loss.detach().to('cpu').numpy())
    del preds
    gc.collect()
    torch.cuda.empty_cache()

# Save the per-parameter mean squared gradient and the per-batch losses.
task_dir = base_save_dir / 'multitask_exclude'
task_dir.mkdir(parents=True, exist_ok=True)

if cnt == 0:
    print('No batches were processed; no gradient statistics to save')
for name,param in grads.items():
    torch.save(param / float(cnt), task_dir / f'{name}.pt')
np.save(task_dir / 'LOSSES.npy', np.array(mae_losses))