# SETUP

In [None]:
!pip install torch einops numpy timm==0.6.13 scipy gcsfs cdsapi xarray zarr netcdf4 matplotlib pandas

In [None]:
%cd /workspace/aurora_229s
!git pull

In [None]:
import importlib
from pathlib import Path
import datetime
import numpy as np
import torch
import gc
import pandas as pd

In [None]:
from aurora import inference_helper, evaluation_helper, compression
from aurora.model import aurora, swin3d

def reload():
    """Re-import the project modules so edits to their source take effect without a kernel restart.

    ``importlib.reload`` re-executes a module in place. Names that another module
    pulled in with ``from x import y`` keep pointing at the old objects until that
    importing module is itself reloaded afterwards. Modules are therefore reloaded
    dependencies-first.
    NOTE(review): the order assumes ``aurora.model.aurora`` imports from ``swin3d``,
    and that the helpers / ``compression`` build on the model modules — confirm.

    Objects created before the reload (e.g. an already-built ``model``) keep
    their old class; re-create them to pick up the new code.
    """
    for module in (swin3d, aurora, inference_helper, evaluation_helper, compression):
        importlib.reload(module)

In [None]:
def gpu_mem(msg):
    """Print a labelled snapshot of CUDA memory statistics for device 0, in GiB.

    Prints the header ``"<msg>:"``. Then prints the allocated, reserved and
    peak-reserved memory, one tab-indented line each, followed by a blank line.
    """
    bytes_per_gib = 1024 * 1024 * 1024
    readings = (
        ("\ttorch.cuda.memory_allocated: %fGB", torch.cuda.memory_allocated),
        ("\ttorch.cuda.memory_reserved: %fGB", torch.cuda.memory_reserved),
        ("\ttorch.cuda.max_memory_reserved: %fGB", torch.cuda.max_memory_reserved),
    )
    print(f'{msg}:')
    for template, query in readings:
        print(template % (query(0) / bytes_per_gib))
    print()

def print_timestamp():
    """Print the current local date and time as ``YYYY-MM-DD HH:MM:SS``."""
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

In [None]:
# Build the small Aurora model and load the pretrained checkpoint
# "aurora-0.25-small-pretrained.ckpt" from the "microsoft/aurora" repository.
# `model` is the uncompressed reference: later cells evaluate it directly or pass
# it as `original_model=model` to the compressors.
model = aurora.AuroraSmall()
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")
# Put the module in evaluation mode before any inference.
model.eval()
print('loaded')

# Parameters

In [None]:
# Root folder of the downloaded input data; per-day files live under
# `download_path / <day>/` (see the "Check models" cell).
download_path = Path("/workspace/data")

# All evaluation CSVs are written below this directory (`sameday/`, `multiday/`).
save_dir = Path("/workspace/results")
save_dir.mkdir(exist_ok=True, parents=True)

device = 'cuda'

# NOTE(review): judging by the commented-out comprehensions below, these are lists
# of (short_name, long_name, weight) triples — TODO confirm in inference_helper.
surf_vars_names_wts, atmos_vars_names_wts = inference_helper.get_vars_names_wts()
# Passed as `max_n_days` to the multi-day evaluation, and also the number of
# consecutive days each start date is expanded into for `sameday_starts` below.
n_multiday_days = 7
multiday_starts = ['2022-05-01']#, '2022-08-01']

compression_ratios = [0.5]#, 0.25, 0.75]
# Parent directory for the Fisher-based compressors: each task name `lh` is
# passed as `grad_path=base_grad_dir / lh`.
base_grad_dir = Path("/workspace/models/fisher")
# lh_task_names = ['multitask'] + [lh for _,lh,_ in surf_vars_names_wts] + [lh for _,lh,_ in atmos_vars_names_wts]
# `sh_exclude` is only referenced by the commented-out `lh_task_names` definition below.
sh_exclude = ['msl', 'z', 'q']
# lh_task_names = ['multitask_exclude'] + [lh for sh,lh,_ in surf_vars_names_wts if not (sh in sh_exclude)] + [lh for sh,lh,_ in atmos_vars_names_wts if not (sh in sh_exclude)]
# Reduced task list currently in use.
lh_task_names = ['multitask_exclude', '2t', 't']

# Expand every multi-day start date into `n_multiday_days` consecutive days
# (via `inference_helper.increment_day`) for the same-day evaluation.
# Note: the loop variable `day` remains bound at notebook level afterwards.
sameday_starts = []
for day in multiday_starts:
    sameday_starts.append(day)
    for _ in range(n_multiday_days-1):
        day = inference_helper.increment_day(day)
        sameday_starts.append(day)

### Baseline

In [None]:
# Baseline (uncompressed) evaluation.
# Both loops are restricted to the first start date (`[:1]`, marked HACK).
# The multi-day run uses `max_n_days=2` instead of `n_multiday_days`, so this is
# presumably a quick check rather than the full evaluation.
# NOTE(review): this writes `sameday/baseline.csv`. `comp_sameday_inference_loop`
# returns early when a model's CSV already exists, so the later
# `comp_sameday_inference_loop(c_model=model, c_model_name='baseline')` will not
# evaluate the baseline over all of `sameday_starts` unless this file is removed.
model_name = 'baseline'
total_df = None

# Sameday
for day in sameday_starts[:1]: # HACK
    day_results_df = evaluation_helper.same_day_eval(model=model, day=day, download_path=download_path, device=device)

    if total_df is None:
        total_df = day_results_df.copy(deep=True)
    else:
        total_df = pd.concat([total_df, day_results_df], axis=0).reset_index(drop=True)

(save_dir / 'sameday').mkdir(exist_ok=True, parents=True)
total_df.to_csv(save_dir / 'sameday' / f'{model_name}.csv', index=False)
# Drop this section's temporaries (including the loop variable) from the notebook namespace.
del day_results_df, total_df, day

# Multiday
total_df = None
for day in multiday_starts[:1]: # HACK
    md_results_df = evaluation_helper.multi_day_eval(
        model=model, day=day, download_path=download_path,
        max_n_days=2, device=device, verbose=True
    )

    if total_df is None:
        total_df = md_results_df.copy(deep=True)
    else:
        total_df = pd.concat([total_df, md_results_df], axis=0).reset_index(drop=True)

(save_dir / 'multiday').mkdir(exist_ok=True, parents=True)
total_df.to_csv(save_dir / 'multiday' / f'{model_name}.csv', index=False)
del md_results_df, total_df, day

In [None]:
# Multiday
# Baseline multi-day evaluation over *all* `multiday_starts` (the previous cell
# only used the first one); overwrites `multiday/baseline.csv`.
# `model_name` is bound here so the cell does not silently depend on the previous
# cell having been executed (Restart & Run-All, or running this cell on its own).
# NOTE(review): `max_n_days=2` rather than `n_multiday_days` — confirm intended.
model_name = 'baseline'
total_df = None
for day in multiday_starts:
    md_results_df = evaluation_helper.multi_day_eval(
        model=model, day=day, download_path=download_path,
        max_n_days=2, device=device, verbose=True
    )

    if total_df is None:
        total_df = md_results_df.copy(deep=True)
    else:
        total_df = pd.concat([total_df, md_results_df], axis=0).reset_index(drop=True)

(save_dir / 'multiday').mkdir(exist_ok=True, parents=True)
total_df.to_csv(save_dir / 'multiday' / f'{model_name}.csv', index=False)
# Drop the cell's temporaries (including the loop variable) from the notebook namespace.
del md_results_df, total_df, day

# Check models

In [None]:
import numpy as np
import torch
import xarray as xr
from pathlib import Path

from aurora import Batch, Metadata

# Sample date used to build a single test batch for the sanity checks below.
# NOTE: this rebinds the notebook-level `day` and `download_path`.
day = '2022-02-01'
download_path = Path('/workspace/data')

# Static fields, plus the surface-level and atmospheric files for `day`.
static_vars_ds = xr.open_dataset(download_path / "static.nc", engine="netcdf4")
surf_vars_ds = xr.open_dataset(download_path / day / f"{day}-surface-level.nc", engine="netcdf4")
atmos_vars_ds = xr.open_dataset(download_path / day / f"{day}-atmospheric.nc", engine="netcdf4")

i = 1  # Select this time index in the downloaded data (read as a global by `_prepare`).

def _prepare(x: np.ndarray) -> torch.Tensor:
    """Turn a raw variable array into a two-step, batched PyTorch tensor.

    Steps, in order:
    * keep time steps ``i - 1`` and ``i`` (``i`` is the notebook-level time index);
    * prepend a batch axis of size one;
    * reverse the latitude axis (second to last) so latitudes are decreasing;
    * copy, because the reversed view is not contiguous and PyTorch needs contiguous data;
    * wrap the result as a tensor.
    """
    steps = x[[i - 1, i]]
    batched = steps[np.newaxis]
    flipped = batched[..., ::-1, :]
    contiguous = flipped.copy()
    return torch.from_numpy(contiguous)


# Assemble one `Batch` from the opened datasets: time steps `i - 1` and `i`, with a
# leading batch dimension of one. `check_c_model` uses it as its test input.
batch = Batch(
    surf_vars={
        "2t": _prepare(surf_vars_ds["2m_temperature"].values),
        "10u": _prepare(surf_vars_ds["10m_u_component_of_wind"].values),
        "10v": _prepare(surf_vars_ds["10m_v_component_of_wind"].values),
        "msl": _prepare(surf_vars_ds["mean_sea_level_pressure"].values),
    },
    static_vars={
        # The static variables are constant, so we just get them for the first time. They
        # don't need to be flipped along the latitude dimension, because they are from
        # ERA5.
        "z": torch.from_numpy(static_vars_ds["z"].values[0]),
        "slt": torch.from_numpy(static_vars_ds["slt"].values[0]),
        "lsm": torch.from_numpy(static_vars_ds["lsm"].values[0]),
    },
    atmos_vars={
        "t": _prepare(atmos_vars_ds["temperature"].values),
        "u": _prepare(atmos_vars_ds["u_component_of_wind"].values),
        "v": _prepare(atmos_vars_ds["v_component_of_wind"].values),
        "q": _prepare(atmos_vars_ds["specific_humidity"].values),
        "z": _prepare(atmos_vars_ds["geopotential"].values),
    },
    metadata=Metadata(
        # Flip the latitudes to match the flipped variables. The copy is needed because
        # the reversed view has negative strides, which `torch.from_numpy` cannot wrap.
        lat=torch.from_numpy(surf_vars_ds.latitude.values[::-1].copy()),
        lon=torch.from_numpy(surf_vars_ds.longitude.values),
        # Converting to `datetime64[s]` ensures that the output of `tolist()` gives
        # `datetime.datetime`s. Note that this needs to be a tuple of length one:
        # one value for every batch element.
        time=(surf_vars_ds.time.values.astype("datetime64[s]").tolist()[i],),
        # Vertical levels of the atmospheric variables, as plain Python ints.
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.level.values),
    ),
)

In [None]:
def check_c_model(c_model, sample_batch=None, device="cuda"):
    """Sanity-check a (compressed) model: NaN-free weights and NaN-free predictions.

    Asserts that no backbone parameter contains NaN. Then runs one forward pass on
    a sample batch under ``torch.inference_mode`` and asserts that no predicted
    surface or atmospheric variable contains NaN. The model is moved back to the
    CPU afterwards, even if the forward pass raises.

    Args:
        c_model: model to check. It must expose ``backbone`` and its forward pass
            must return an object with ``surf_vars`` / ``atmos_vars`` dicts of tensors.
        sample_batch: input for the forward pass. Defaults to the notebook-level
            ``batch`` built in the "Check models" section.
        device: device the forward pass runs on.

    Raises:
        AssertionError: carrying the name of the offending parameter or variable.
    """
    if sample_batch is None:
        sample_batch = batch

    for name, param in c_model.backbone.named_parameters():
        assert not bool(torch.isnan(param).any()), name

    c_model.eval()
    c_model = c_model.to(device)
    try:
        with torch.inference_mode():
            # Call the module rather than `.forward` so registered hooks still run.
            preds = c_model(sample_batch)
    finally:
        # Free accelerator memory for the next model, even on failure.
        c_model.to("cpu")

    # Check the predicted fields, not attributes of the model itself.
    for sh, v in preds.surf_vars.items():
        assert not bool(torch.isnan(v).any()), sh
    for sh, v in preds.atmos_vars.items():
        assert not bool(torch.isnan(v).any()), sh

    print('all good')

In [None]:
# Sanity check (see `check_c_model`) of SVD-only compression at ratio 0.5.
check_c_model(compression.svd_only_compression(original_model=model, ratio=0.5))

In [None]:
# Sanity check (see `check_c_model`) of base Fisher compression at ratio 0.5 with the
# 'multitask_exclude' gradients; the path equals `base_grad_dir / 'multitask_exclude'`.
check_c_model(compression.fisher_base_compression(original_model=model, ratio=0.5, grad_path=Path("/workspace/models/fisher/multitask_exclude")))

In [None]:
# Sanity check (see `check_c_model`) of improved Fisher compression at ratio 0.5 with the
# 'multitask_exclude' gradients; the path equals `base_grad_dir / 'multitask_exclude'`.
check_c_model(compression.fisher_improved_compression(original_model=model, ratio=0.5, grad_path=Path("/workspace/models/fisher/multitask_exclude")))

### Compression evaluation loops (baseline, SVD, Fisher base, Fisher improved)

In [None]:
def comp_sameday_inference_loop(c_model, c_model_name):
    """Evaluate ``c_model`` on every day in ``sameday_starts`` and save one CSV.

    The per-day frames returned by ``evaluation_helper.same_day_eval`` are
    concatenated and written to ``save_dir / 'sameday' / f'{c_model_name}.csv'``.
    If that CSV already exists, nothing is evaluated. Re-running a cell therefore
    only fills in missing results; delete the file to force a re-run.

    Args:
        c_model: model to evaluate.
        c_model_name: stem of the output CSV file.
    """
    out_dir = save_dir / 'sameday'
    out_path = out_dir / f'{c_model_name}.csv'
    # Sameday
    if out_path.is_file():
        print('Already exists: ', str(out_path))
        return

    print('\t\tsameday')
    day_frames = []
    for day in sameday_starts:
        print(f'\t\t\t{day}')
        day_frames.append(
            evaluation_helper.same_day_eval(model=c_model, day=day, download_path=download_path, device=device)
        )

    if not day_frames:
        # Nothing was evaluated; `pd.concat([])` would raise.
        print('No days to evaluate for ', c_model_name)
        return

    # Concatenate once instead of growing a DataFrame inside the loop.
    total_df = pd.concat(day_frames, axis=0).reset_index(drop=True)
    out_dir.mkdir(exist_ok=True, parents=True)
    total_df.to_csv(out_path, index=False)
    
def comp_multiday_inference_loop(c_model, c_model_name):
    """Run the multi-day evaluation of ``c_model`` from every start in ``multiday_starts``.

    Each start date is evaluated for ``n_multiday_days`` days via
    ``evaluation_helper.multi_day_eval``. The results are concatenated and written
    to ``save_dir / 'multiday' / f'{c_model_name}.csv'``. If that CSV already
    exists, the function returns without evaluating, so cached results are never
    recomputed or overwritten.

    Args:
        c_model: model to evaluate.
        c_model_name: stem of the output CSV file.
    """
    out_dir = save_dir / 'multiday'
    out_path = out_dir / f'{c_model_name}.csv'
    # Multiday
    if out_path.is_file():
        print('Already exists: ', str(out_path))
        return

    print('\t\tmultiday')
    start_frames = []
    for day in multiday_starts:
        print(f'\t\t\t{day}')
        start_frames.append(
            evaluation_helper.multi_day_eval(
                model=c_model, day=day, download_path=download_path,
                max_n_days=n_multiday_days, device=device, verbose=False
            )
        )

    if not start_frames:
        # Nothing was evaluated; `pd.concat([])` would raise.
        print('No start days to evaluate for ', c_model_name)
        return

    # Concatenate once instead of growing a DataFrame inside the loop.
    total_df = pd.concat(start_frames, axis=0).reset_index(drop=True)
    out_dir.mkdir(exist_ok=True, parents=True)
    total_df.to_csv(out_path, index=False)

In [None]:
# Same-day evaluation of the uncompressed model over all of `sameday_starts`.
# NOTE(review): the Baseline cell above already writes `sameday/baseline.csv` (first
# day only). `comp_sameday_inference_loop` returns early when that file exists, so
# delete it first if the full baseline is wanted.
comp_sameday_inference_loop(
    c_model=model,
    c_model_name='baseline'
)

In [None]:
# SVD-only compression at every ratio in `compression_ratios`; results go to
# `sameday/svd_<ratio>.csv`.
# NOTE(review): the compressed model is built before `comp_sameday_inference_loop`
# checks for an existing CSV, so compression still runs when the result is cached.
for ratio in compression_ratios:
    print(ratio)
    comp_sameday_inference_loop(
        c_model=compression.svd_only_compression(original_model=model, ratio=ratio),
        c_model_name=f'svd_{ratio}'
    )
print('DONE!!!')

In [None]:
# Baseline Fisher
# Base Fisher compression for the first two task names and every ratio, using the
# gradients under `base_grad_dir / lh`; results go to `sameday/fisher_base_<lh>_<ratio>.csv`.
# NOTE(review): the model is compressed before the cached-CSV check inside
# `comp_sameday_inference_loop`, so compression still runs for cached results.
for lh in lh_task_names[:2]:
    print(lh)
    for ratio in compression_ratios:
        print('\t', ratio)
        comp_sameday_inference_loop(
            c_model=compression.fisher_base_compression(original_model=model, ratio=ratio, grad_path=base_grad_dir / lh),
            c_model_name=f'fisher_base_{lh}_{ratio}'
        )
print('DONE!!!')

In [None]:
# Improved Fisher
# Improved Fisher compression for the first two task names and every ratio, using the
# gradients under `base_grad_dir / lh`; results go to
# `sameday/fisher_improved_<lh>_<ratio>.csv`.
for lh in lh_task_names[:2]:
    print(lh)
    for ratio in compression_ratios:
        print('\t', ratio)
        comp_sameday_inference_loop(
            c_model=compression.fisher_improved_compression(original_model=model, ratio=ratio, grad_path=base_grad_dir / lh),
            # Distinct prefix: reusing `fisher_base_*` would collide with the previous
            # cell's CSVs, and the cached-file check would then skip this evaluation.
            c_model_name=f'fisher_improved_{lh}_{ratio}'
        )
print('DONE!!!')