In [None]:
import torch
import pytorch_lightning as pl
from ens_transformer.transforms import to_tensor, Normalizer
from ens_transformer.data_module import IFSERADataModule

from tqdm.notebook import tqdm

In [None]:
# Seed the random number generators through Lightning so that repeated runs
# of this notebook see the same data order and produce the same statistics.
pl.seed_everything(42)

In [None]:
# Build the data module restricted to the three input variables for which
# normalisation statistics are estimated below, then prepare its datasets.
data_module = IFSERADataModule(
    include_vars=['t2m', 't_850', 'gh_500'],
)
data_module.setup()

In [None]:
# Running accumulators, filled batch by batch in the loop below:
#   rolling_sum          -> per-dataset sum of per-sample means
#   rolling_squared_sum  -> per-dataset sum of per-sample means of squares
#   rolling_elems        -> number of samples accumulated so far
rolling_sum = {}
rolling_squared_sum = {}
rolling_elems = 0

In [None]:
# Accumulate first and second moments of the training data, one batch at a
# time, so the full dataset never has to be held in memory.
#
# The accumulators are (re-)initialised here so that re-running this cell
# starts from a clean state instead of adding a second pass on top of the
# first one (which would silently double-count every sample).
rolling_sum = {}
rolling_squared_sum = {}
rolling_elems = 0

# Build the loader once instead of constructing a second one just to read
# its length for the progress bar.
train_loader = data_module.train_dataloader()
for ifs_data, era_data in tqdm(train_loader, total=len(train_loader)):
    # Per-sample mean over dims (1, 3, 4) for IFS and (1, 2, 3) for ERA,
    # then summed over the batch dim 0. keepdim=True keeps the reduced axes
    # as size-1 dims so the result broadcasts against the input later.
    # NOTE(review): presumably IFS is (batch, ensemble, var, lat, lon) and
    # ERA is (batch, var, lat, lon) -- confirm against the data module.
    ifs_mean = ifs_data.mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
    ifs_squared_mean = ifs_data.pow(2).mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
    era_mean = era_data.mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
    era_squared_mean = era_data.pow(2).mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)

    # `.get(key, 0)` handles the first batch explicitly. The previous bare
    # `except:` also swallowed real errors (shape mismatches, keyboard
    # interrupts) and reset the running sums to the current batch.
    rolling_sum['ifs'] = rolling_sum.get('ifs', 0) + ifs_mean
    rolling_squared_sum['ifs'] = rolling_squared_sum.get('ifs', 0) + ifs_squared_mean
    rolling_sum['era'] = rolling_sum.get('era', 0) + era_mean
    rolling_squared_sum['era'] = rolling_squared_sum.get('era', 0) + era_squared_mean

    # Count samples (batch dim) so the sums can be turned into means later.
    rolling_elems += ifs_data.shape[0]

In [None]:
# Turn the accumulated sums into per-dataset mean and standard deviation.
if rolling_elems == 0:
    # Without this guard the divisions below would produce inf/NaN tensors
    # (or empty dicts) and the failure would only surface further down.
    raise ValueError('No samples were accumulated; run the accumulation loop first.')

mean_values = {k: v / rolling_elems for k, v in rolling_sum.items()}

# Var[x] = E[x^2] - E[x]^2. Floating-point rounding can push this slightly
# below zero for (nearly) constant channels, which would make sqrt() return
# NaN and poison the saved normalizer -- clamp at 0 before the square root.
stddev_values = {
    k: (v / rolling_elems - mean_values[k].pow(2)).clamp(min=0).sqrt()
    for k, v in rolling_squared_sum.items()
}

In [None]:
# One Normalizer per dataset, each built from that dataset's own statistics.
normalizers = {
    name: Normalizer(mean=mean_values[name], std=stddev_values[name])
    for name in ('ifs', 'era')
}

In [None]:
# Persist both normalizers in a single file so later stages can load them
# instead of recomputing the statistics.
torch.save(normalizers, '../data/interim/normalizers.pt')