In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '/iopsstor/scratch/cscs/stefschu/DSM500_FPR/modulus-baseline')

from modulus.utils.distributed_manager import DistributedManager as DM
from modulus.utils.caching import Cache
from omegaconf import OmegaConf
from modulus.models.graph_cast_net import GraphCastNet

cfg = OmegaConf.load('/iopsstor/scratch/cscs/stefschu/DSM500_FPR/modulus-baseline/conf/config.yaml')

DM.destroy()
DM.initialize()
Cache.initialize(dir=cfg.cache.dir)

  warn("DistributedManager: running in single process mode!")


In [2]:
from modulus.models.utils.loss import GraphCastLossFunction
from modulus.datapipes.era5_hdf5 import ERA5HDF5Datapipe

# Single-sample sanity check of the loss for one configuration:
# SST channel on, new variable weights, unfixed inverse-variance statistics.
cfg.toggles.data.include_sst_channel = True
cfg.toggles.loss.use_original_variable_weights = False
cfg.toggles.loss.fix_inverse_variance_data = False

model = GraphCastNet(cfg)

# Deterministic shuffled iteration starting at the very first sample.
iterator_cfg = dict(
    shuffle=True,
    shuffle_seed=0,
    initial_epoch_idx=0,
    initial_sample_idx=0,
)

datapipe = ERA5HDF5Datapipe(
    cfg=cfg,
    dataset_folder='train',
    num_output_steps=1,
    latitudes=model.latitudes,
    longitudes=model.longitudes,
    map_grid_to_latlon=model.map_grid_to_latlon,
    dtype=model.dtype(),
    iterator=iterator_cfg,
)

sample = next(iter(datapipe))
data = sample["data"]

# Presumably data[0] is the input state and data[1] the single target step
# (num_output_steps=1) — TODO confirm against the datapipe.
invar = data[0]
outvar = data[1]

forecast = model(invar)

# Keep only the leading channels that the model actually forecasts
# (21 of them in the recorded run).
outvar = outvar[:forecast.shape[0]]
print(f"{invar.shape=} {forecast.shape=} {outvar.shape=}")

criteria = GraphCastLossFunction(
    cfg=cfg,
    area=model.area,
    channels_metadata=model.metadata,
)

criteria(outvar, forecast)

invar.shape=torch.Size([31, 721, 1440]) forecast.shape=torch.Size([21, 721, 1440]) outvar.shape=torch.Size([21, 721, 1440])


tensor(2.0797, device='cuda:0', grad_fn=<MeanBackward0>)

In [7]:
# dtypes of the old vs. new inverse-variance weights held by the loss module.
# NOTE(review): `criteria` is whatever instance the most recently executed cell
# left behind (per its execution count this cell ran after In [3]); re-create it
# here if the comparison must not depend on execution order.
tuple(
    getattr(criteria, attr_name).dtype
    for attr_name in ("inverse_variance_weights_old", "inverse_variance_weights_new")
)

(torch.float32, torch.bfloat16)

In [3]:
import itertools

# Number of leading data channels fed to the loss with and without SST. SST is
# assumed to be the last of the 21 state channels (index 20), consistent with
# the stats comparisons further down, where channel 20 is the one the SST fix
# changes drastically. TODO confirm against the datapipe's channel list.
N_CHANNELS_WITH_SST = 21
N_CHANNELS_WITHOUT_SST = 20

# Pin the datapipe configuration explicitly. The loop below mutates `cfg`, so
# without this a re-run of the cell would build the datapipe with whatever
# toggle the last iteration left behind (include_sst_channel=False), and the
# channel slicing below would no longer line up with the data.
# ERA5HDF5Datapipe receives `cfg`, so its channel layout may depend on this
# toggle (TODO confirm).
cfg.toggles.data.include_sst_channel = True

# Fresh model for the grid/dtype arguments instead of relying on a `model`
# leaked from an earlier cell.
model = GraphCastNet(cfg)

datapipe = ERA5HDF5Datapipe(
    cfg=cfg,
    dataset_folder='train',
    num_output_steps=1,

    latitudes=model.latitudes,
    longitudes=model.longitudes,
    map_grid_to_latlon=model.map_grid_to_latlon,
    dtype=model.dtype(),

    iterator={
        "shuffle": True,
        "shuffle_seed": 0,
        "initial_epoch_idx": 0,
        "initial_sample_idx": 0
    }
)

sample = next(iter(datapipe))
data = sample["data"]

# itertools.product yields combinations in the same order as three nested
# `for ... in [True, False]` loops.
for include_sst_channel, use_original_variable_weights, fix_inverse_variance_data in itertools.product(
    [True, False], repeat=3
):
    cfg.toggles.data.include_sst_channel = include_sst_channel
    cfg.toggles.loss.use_original_variable_weights = use_original_variable_weights
    cfg.toggles.loss.fix_inverse_variance_data = fix_inverse_variance_data

    # Rebuilt per configuration: the `area`/`metadata` handed to the loss may
    # depend on the toggles just set.
    model = GraphCastNet(cfg)

    n_channels = N_CHANNELS_WITH_SST if include_sst_channel else N_CHANNELS_WITHOUT_SST
    invar, outvar = data[0][:n_channels], data[1][:n_channels]

    criteria = GraphCastLossFunction(
        cfg=cfg,
        area=model.area,
        channels_metadata=model.metadata
    )

    # Loss between the input state and the next ground-truth state. There is no
    # model forward pass here, so this is presumably a persistence-baseline loss.
    # NOTE(review): in the recorded run the combination (include_sst_channel=True,
    # use_original_variable_weights=False, fix_inverse_variance_data=True) returned
    # a bfloat16 loss (the new inverse-variance weights are bfloat16) — TODO
    # confirm whether the loss should upcast.
    loss = criteria(invar, outvar)

    print(f"{include_sst_channel=} {use_original_variable_weights=} {fix_inverse_variance_data=}")
    print(f"loss: {loss=}")

include_sst_channel=True use_original_variable_weights=True fix_inverse_variance_data=True
loss: loss=tensor(0.0091, device='cuda:0')
include_sst_channel=True use_original_variable_weights=True fix_inverse_variance_data=False
loss: loss=tensor(0.0512, device='cuda:0')
include_sst_channel=True use_original_variable_weights=False fix_inverse_variance_data=True
loss: loss=tensor(0.0080, device='cuda:0', dtype=torch.bfloat16)
include_sst_channel=True use_original_variable_weights=False fix_inverse_variance_data=False
loss: loss=tensor(0.0217, device='cuda:0')
include_sst_channel=False use_original_variable_weights=True fix_inverse_variance_data=True
loss: loss=tensor(0.0029, device='cuda:0')
include_sst_channel=False use_original_variable_weights=True fix_inverse_variance_data=False
loss: loss=tensor(0.0029, device='cuda:0')
include_sst_channel=False use_original_variable_weights=False fix_inverse_variance_data=True
loss: loss=tensor(0.0026, device='cuda:0')
include_sst_channel=False use_o

In [20]:
import os

import numpy as np

# Compare the per-channel time-difference std before and after the SST fix.
STATS_DIR = '/iopsstor/scratch/cscs/stefschu/DSM500_FPR/data/FCN_ERA5_data_v0/stats'

old = np.load(os.path.join(STATS_DIR, 'time_diff_std.npy'))
new = np.load(os.path.join(STATS_DIR, 'time_diff_std_with_sst_fix.npy'))

print(old.shape, old.dtype, new.shape, new.dtype)
# Check shapes before squeezing so a mismatch is reported rather than masked.
assert old.shape == new.shape, f"shape mismatch: {old.shape} vs {new.shape}"
old = old.squeeze()
new = new.squeeze()

# Iterate over the actual channel count instead of a hard-coded 21, which would
# raise IndexError or silently skip channels if the layout changed.
for c, (old_c, new_c) in enumerate(zip(old, new)):
    print(f"{c:2} {old_c:9.2f} {new_c:9.2f}")

(1, 21) (1, 21)
 0      2.31      2.30
 1      2.64      2.61
 2      3.27      3.17
 3    352.00    254.49
 4    364.00    269.10
 5      1.88      1.67
 6      2.55      2.53
 7      2.92      2.89
 8    221.00    219.14
 9      3.19      3.15
10      3.58      3.52
11    199.00    195.47
12      3.94      3.94
13      4.81      4.77
14    266.00    235.60
15      1.57      1.35
16    478.00    236.01
17     22.62     22.69
18     18.50     18.36
19      3.48      3.47
20      0.08      3.09


In [17]:
import os

import numpy as np

# Compare the per-channel global means before and after the SST fix.
# In the recorded output only channel 20 (SST) differs materially
# (-10915.60 -> 280.50). Note that the old file is float64 and the new one
# is float32.
STATS_DIR = '/iopsstor/scratch/cscs/stefschu/DSM500_FPR/data/FCN_ERA5_data_v0/stats'

old = np.load(os.path.join(STATS_DIR, 'global_means.npy'))
new = np.load(os.path.join(STATS_DIR, 'global_means_with_sst_fix.npy'))

print(old.shape, old.dtype, new.shape, new.dtype)
# Check shapes before squeezing so a mismatch is reported rather than masked.
assert old.shape == new.shape, f"shape mismatch: {old.shape} vs {new.shape}"
old = old.squeeze()
new = new.squeeze()

# Iterate over the actual channel count instead of a hard-coded 21.
for c, (old_c, new_c) in enumerate(zip(old, new)):
    print(f"{c:2} {old_c:9.2f} {new_c:9.2f}")

(1, 21, 1, 1) float64 (1, 21, 1, 1) float32
 0     -0.05     -0.05
 1      0.19      0.19
 2    278.45    278.44
 3  96650.39  96650.43
 4 100957.49 100957.75
 5    274.53    274.52
 6     -0.03     -0.03
 7      0.19      0.19
 8    737.07    737.18
 9      1.42      1.42
10      0.14      0.14
11  13747.95  13747.48
12      6.55      6.56
13     -0.02     -0.02
14  54110.10  54108.05
15    252.93    252.92
16 199361.21 199359.03
17     50.42     50.43
18     69.13     69.13
19     18.30     18.29
20 -10915.60    280.50


In [16]:
import os

import numpy as np

# Compare the per-channel global stds before and after the SST fix.
# In the recorded output only channel 20 (SST) differs materially
# (15645.77 -> 20.43). Note that the old file is float64 and the new one
# is float32.
STATS_DIR = '/iopsstor/scratch/cscs/stefschu/DSM500_FPR/data/FCN_ERA5_data_v0/stats'

old = np.load(os.path.join(STATS_DIR, 'global_stds.npy'))
new = np.load(os.path.join(STATS_DIR, 'global_stds_with_sst_fix.npy'))

print(old.shape, old.dtype, new.shape, new.dtype)
# Check shapes before squeezing so a mismatch is reported rather than masked.
assert old.shape == new.shape, f"shape mismatch: {old.shape} vs {new.shape}"
old = old.squeeze()
new = new.squeeze()

# Iterate over the actual channel count instead of a hard-coded 21.
for c, (old_c, new_c) in enumerate(zip(old, new)):
    print(f"{c:2} {old_c:9.2f} {new_c:9.2f}")

(1, 21, 1, 1) float64 (1, 21, 1, 1) float32
 0      5.54      5.54
 1      4.76      4.76
 2     21.29     21.30
 3   9587.90   9588.23
 4   1332.69   1332.07
 5     15.63     15.63
 6      6.14      6.14
 7      5.30      5.30
 8   1072.83   1072.37
 9      8.19      8.18
10      6.26      6.26
11   1471.42   1471.12
12     11.98     11.98
13      9.18      9.18
14   3357.15   3357.24
15     13.07     13.08
16   5895.93   5894.05
17     33.58     33.58
18     26.41     26.42
19     16.39     16.39
20  15645.77     20.43


In [28]:
import math

import torch

# Size in bytes of a float32 tensor of this shape, computed from the shape and
# the element size instead of allocating ~515 MB of zeros just to read `.nbytes`.
# The shape presumably means (batch, steps, channels, lat, lon); 31 channels and
# the 721x1440 grid match the `invar` shape above. TODO confirm the meaning of the 4.
BUFFER_SHAPE = (1, 4, 31, 721, 1440)
buffer_nbytes = math.prod(BUFFER_SHAPE) * torch.empty((), dtype=torch.float32).element_size()
buffer_nbytes

514967040