In [7]:
%cd /localhome/prateiksinha/atmos-arena/atmos_arena
import torch
import numpy as np
from stormer_arch import Stormer
from s2s.mini_stormer_arch import MiniStormer

from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from s2s.window_module import WindowForecastingModule
from s2s.window_datamodule import WindowDataModule
from s2s.window_datamodule import ERA5WindowDataset

/localhome/prateiksinha/atmos-arena/atmos_arena


In [2]:
device = 'cuda:8'
variables = [
    "2m_temperature",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    # "geopotential_50",
    # "geopotential_250",
    # "geopotential_500",
    # "geopotential_600",
    # "geopotential_700",
    # "geopotential_850",
    # "geopotential_925",
    # "u_component_of_wind_50",
    # "u_component_of_wind_250",
    # "u_component_of_wind_500",
    # "u_component_of_wind_600",
    # "u_component_of_wind_700",
    # "u_component_of_wind_850",
    # "u_component_of_wind_925",
    # "v_component_of_wind_50",
    # "v_component_of_wind_250",
    # "v_component_of_wind_500",
    # "v_component_of_wind_600",
    # "v_component_of_wind_700",
    # "v_component_of_wind_850",
    # "v_component_of_wind_925",
    # "temperature_50",
    # "temperature_250",
    # "temperature_500",
    # "temperature_600",
    # "temperature_700",
    # "temperature_850",
    # "temperature_925",
    # "specific_humidity_50",
    # "specific_humidity_250",
    # "specific_humidity_500",
    # "specific_humidity_600",
    # "specific_humidity_700",
    # "specific_humidity_850",
    # "specific_humidity_925",
]

In [3]:
num_files_per_day = 4
num_weeks = 6
datamodule = WindowDataModule(
    root_dir = '/localhome/data/datasets/climate/wb2/1.40625deg_6hr_h5df',
    in_variables = variables,
    out_variables = variables,
    num_steps_in_output = 7 * num_weeks * num_files_per_day,
    lead_time = 6
)
datamodule.setup()
test_dataloader = datamodule.test_dataloader()

model = DirectForecastingModule(
    net = Stormer(
        in_img_size = [128,256],
        in_variables = variables,
    ).to(device),
    pretrained_path = "",
    # pretrained_path = "/localhome/data/ckpts/tungnd/stormer/6_12_24_climax_large_2_True_delta_8/checkpoints/epoch_015.ckpt",
)

model.set_transforms(datamodule.in_transforms, datamodule.out_transforms)
model.set_lat_lon(*datamodule.get_lat_lon())
model.set_lead_time(datamodule.hparams.lead_time)

In [4]:
def check(dataloader):
    sample_batch = next(iter(dataloader))
    for i in sample_batch:
        t = type(i)
        l = len(i) if t == list else i.shape
        print(f'{t} | {l}')

print('Checking batch contents:')
check(test_dataloader)

Checking batch contents:
<class 'torch.Tensor'> | torch.Size([1, 3, 128, 256])
<class 'torch.Tensor'> | torch.Size([1, 168, 3, 128, 256])
<class 'torch.Tensor'> | torch.Size([1])
<class 'list'> | 3
<class 'list'> | 3


In [None]:
# week 1: 0,1,2,3 | 4,5,6,7 | ... | 24,25,26,27
# week 2: 28  ... 
# week 3: 56  ... 
# week 4: 84  ... 
# week 5: 112 ... 
#
# window containing weeks x thru y = (28*(x-1), 28*y) = files_per_day x (7*(x-1), 7*y)

with torch.no_grad():
    model.net.eval()
    for index, batch in enumerate(test_dataloader):
        x = model.test_step_window(
            batch, 
            windows=[
                (7*4*2, 7*4*4), # weeks 3-4
                (7*4*4, 7*4*6)  # weeks 5-6
            ], 
            device=device
        )
        print(f'batch number {index} complete')

In [6]:
x

{56: {'w_mse_2m_temperature': tensor(5.4079e+08, device='cuda:8'),
  'w_mse_10m_u_component_of_wind': tensor(34.7983, device='cuda:8'),
  'w_mse_10m_v_component_of_wind': tensor(275.4200, device='cuda:8'),
  'w_mse': 180263860.0,
  'w_rmse_2m_temperature': tensor(23254.9180, device='cuda:8'),
  'w_rmse_10m_u_component_of_wind': tensor(5.8990, device='cuda:8'),
  'w_rmse_10m_v_component_of_wind': tensor(16.5958, device='cuda:8'),
  'w_rmse': 7759.137},
 112: {'w_mse_2m_temperature': tensor(1.5094e+09, device='cuda:8'),
  'w_mse_10m_u_component_of_wind': tensor(66.1080, device='cuda:8'),
  'w_mse_10m_v_component_of_wind': tensor(728.8865, device='cuda:8'),
  'w_mse': 503141660.0,
  'w_rmse_2m_temperature': tensor(38851.3086, device='cuda:8'),
  'w_rmse_10m_u_component_of_wind': tensor(8.1307, device='cuda:8'),
  'w_rmse_10m_v_component_of_wind': tensor(26.9979, device='cuda:8'),
  'w_rmse': 12962.145}}

- add logging
- make it a script that you can leave to run

X
- add new metrics