In [2]:
import sys
sys.path.append('/iopsstor/scratch/cscs/stefschu/DSM500/github/modulus-a5275d8')

import torch
import numpy as np
import hydra
import h5py
import json
from omegaconf import DictConfig
from collections import defaultdict
from pathlib import Path
from hydra.utils import to_absolute_path
from modulus.distributed import DistributedManager

from modulus.launch.logging import PythonLogger
from modulus.datapipes.climate.era5_hdf5_newest import fix_latitude_alignment

# FIXIT: Move it under modulus
from inference import Inference

import os

# Describe a single-process "distributed" run through environment variables
# (rank 0, world size 1, rendezvous on localhost). They are set before
# DistributedManager.initialize() is called below; presumably that call is
# what reads them -- confirm against the modulus distributed docs.
os.environ.update(
    {
        "MODULUS_DISTRIBUTED_INITIALIZATION_METHOD": "ENV",
        # Rank-like variables: this process is the only (first) one.
        "SLURM_PROCID": "0",
        "RANK": "0",
        "SLURM_LOCALID": "0",
        # Size-like variables: exactly one process in total.
        "SLURM_NTASKS": "1",
        "SLURM_NPROCS": "1",
        "WORLD_SIZE": "1",
        # Rendezvous endpoint.
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "23456",
    }
)

import hydra

# Compose the Hydra config "config_new", resolved relative to the parent directory.
with hydra.initialize(config_path="..", version_base="1.3"):
    cfg = hydra.compose(config_name="config_new")

# Logger named "main" with file logging switched on.
logger = PythonLogger("main")
logger.file_logging()

DistributedManager.initialize()

# Inference helper used later to obtain the model and the datapipe.
inference = Inference(cfg, logger)

class DataGenerator:
    """Load evaluation samples, roll the model out autoregressively and score it.

    The methods are meant to be called in this order, each at most once:
    ``load_samples()`` -> ``compute_forecasts()`` -> ``compute_metrics()``.
    The first two enforce the order with assertions on the state flags.

    Tensor layout used throughout (taken from the shapes recorded in
    ``load_samples``): ``[sample, step, channel, lat, lon]``.
    """

    def __init__(self, inference, cfg):
        """Keep the collaborators and fetch the model and datapipe.

        Args:
            inference: Object exposing ``load_model_and_datapipe()``,
                ``dtype`` and ``device``.
            cfg: Configuration object. This class reads
                ``cfg.inference.samples``, ``cfg.inference.rollout_steps``,
                ``cfg.output_channels`` and ``cfg.dataset.base_path``.
        """
        self.inference = inference
        self.cfg = cfg

        self.model, self.datapipe = inference.load_model_and_datapipe()

        # Data containers
        self.initial_conditions = None
        self.outputs = None
        self.reanalysis = None
        self.generated_channels = None
        self.forecasts = None
        self.global_sample_ids = None
        self.timestamps = None

        # Climatology
        self.climatology = None

        # Metrics
        self.mse = None
        self.acc = None

        # Flags
        self.samples_loaded = False
        self.forecasts_computed = False

    def load_samples(self):
        """Read samples from the datapipe and the matching climatology slices.

        Populates ``initial_conditions``, ``reanalysis``,
        ``generated_channels``, ``global_sample_ids``, ``timestamps``, a
        zero-filled ``forecasts`` buffer and ``climatology``. Everything but
        the climatology, sample ids and timestamps is cast to
        ``self.inference.dtype`` and moved to ``self.inference.device``.

        Raises:
            AssertionError: If samples were already loaded.
            RuntimeError: If the datapipe yields no data at all.
        """
        assert self.samples_loaded is False, "Samples already loaded"

        requested = self.cfg.inference.samples

        # Collect the batches and join them once at the end; concatenating
        # inside the loop would re-copy everything loaded so far each time.
        inputs, outputs, sample_ids, timestamps = [], [], [], []
        loaded = 0
        for data in self.datapipe:
            # dict_keys(['epoch_idx', 'idx_in_epoch', 'global_sample_id', 'output', 'timestamps', 'input'])
            # data['input'].shape, data["output"].shape, data["global_sample_id"].shape, data["timestamps"].shape
            # (torch.Size([1, 1, 31, 721, 1440]),
            #  torch.Size([1, 24, 31, 721, 1440]),
            #  torch.Size([1]),
            #  torch.Size([1, 25]))
            # type(data['input']), type(data["output"]), type(data["global_sample_id"]), type(data["timestamps"])
            # (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor)
            data = data[0]

            inputs.append(data["input"])
            outputs.append(data["output"])
            sample_ids.append(data["global_sample_id"])
            timestamps.append(data["timestamps"])

            loaded += data["input"].shape[0]
            # ">=" rather than "==": a batch with more than one sample could
            # step over the requested count and drain the whole datapipe.
            if loaded >= requested:
                break

        if not inputs:
            raise RuntimeError("Datapipe yielded no samples")

        self.initial_conditions = torch.concatenate(inputs, dim=0)
        all_outputs = torch.concatenate(outputs, dim=0)
        self.global_sample_ids = torch.concatenate(sample_ids, dim=0)
        self.timestamps = torch.concatenate(timestamps, dim=0)

        if loaded > requested:
            # The last batch overshot; keep exactly the requested number.
            self.initial_conditions = self.initial_conditions[:requested]
            all_outputs = all_outputs[:requested]
            self.global_sample_ids = self.global_sample_ids[:requested]
            self.timestamps = self.timestamps[:requested]

        # torch.Size([3, 24, 21, 721, 1440]), torch.Size([3, 24, 10, 721, 1440]) <- outputs torch.Size([3, 24, 31, 721, 1440])
        channels_count_reanalysis = self.cfg.output_channels  # supposedly 21
        # supposedly 10, if all options enabled
        channels_count_generated = all_outputs.shape[2] - channels_count_reanalysis
        self.reanalysis, self.generated_channels = torch.split(
            all_outputs, [channels_count_reanalysis, channels_count_generated], dim=2
        )
        self.outputs = None

        # cast and move as needed
        dtype = self.inference.dtype
        device = self.inference.device
        self.initial_conditions = self.initial_conditions.to(device=device, dtype=dtype)
        self.reanalysis = self.reanalysis.to(device=device, dtype=dtype)
        self.generated_channels = self.generated_channels.to(device=device, dtype=dtype)
        self.forecasts = torch.zeros(self.reanalysis.shape, dtype=dtype, device=device)

        # Load climatology for samples
        self.climatology = torch.zeros(self.reanalysis.shape, device="cpu")
        climatology_path = Path(self.cfg.dataset.base_path) / "climatology.h5"
        with h5py.File(climatology_path, "r") as f:
            climatology_ds = f["climatology"]
            for sample_i, global_sample_id in enumerate(self.global_sample_ids):
                # Plain ints for the HDF5 slice bounds. The +1 offset
                # presumably skips the initial-condition step so the slice
                # lines up with the outputs -- TODO confirm.
                idx_start = int(global_sample_id) + 1
                idx_end = idx_start + self.cfg.inference.rollout_steps
                self.climatology[sample_i, :] = torch.tensor(climatology_ds[idx_start:idx_end])
        self.climatology = fix_latitude_alignment(self.climatology)

        self.samples_loaded = True

    def compute_forecasts(self):
        """Roll the model out for every loaded sample.

        For each sample the first model input is its initial condition. Every
        later input is the previous forecast concatenated with the generated
        channels of that step along the channel axis. Results are written to
        ``self.forecasts``; afterwards all sample tensors are converted to
        ``float32`` and moved to the CPU.

        Raises:
            AssertionError: If samples are not loaded or forecasts were
                already computed.
        """
        assert self.samples_loaded is True, "Samples not loaded"
        assert self.forecasts_computed is False, "Forecasts already computed"

        # Iterate over what was actually loaded; the datapipe may have held
        # fewer samples than cfg.inference.samples asks for.
        sample_count = self.initial_conditions.shape[0]
        rollout_steps = self.cfg.inference.rollout_steps

        with torch.no_grad():
            for sample_i in range(sample_count):
                # Step 0 is the initial condition
                next_input = self.initial_conditions[sample_i].unsqueeze(0)
                for step in range(rollout_steps):
                    self.forecasts[sample_i, step] = self.model(next_input)
                    # Feed the forecast back in, re-attaching the generated
                    # channels and the two leading singleton dimensions.
                    next_input = torch.concatenate([
                        self.forecasts[sample_i, step],
                        self.generated_channels[sample_i, step]
                    ], dim=0).unsqueeze(0).unsqueeze(0)

        # Now move everything off the GPU
        self.initial_conditions = self.initial_conditions.to(torch.float32).cpu()
        self.reanalysis = self.reanalysis.to(torch.float32).cpu()
        self.generated_channels = self.generated_channels.to(torch.float32).cpu()
        self.forecasts = self.forecasts.to(torch.float32).cpu()

        self.forecasts_computed = True

    def compute_metrics(self):
        """Compute MSE and anomaly correlation between forecasts and reanalysis.

        Both metrics reduce over the last two (spatial) dimensions, so
        ``self.mse`` and ``self.acc`` keep the leading dimensions of the
        inputs. The reductions are plain means/sums without any spatial
        weighting, and the ACC denominator is not guarded against zero.

        Raises:
            AssertionError: If forecasts have not been computed.
        """
        assert self.forecasts_computed, "Forecasts not computed"

        spatial_dims = (-2, -1)

        # MSE
        self.mse = torch.square(self.reanalysis - self.forecasts).mean(dim=spatial_dims)

        # ACC
        forecast_anomalies = self.forecasts - self.climatology
        reanalysis_anomalies = self.reanalysis - self.climatology
        forecast_anomalies_std = torch.sqrt(torch.sum(torch.square(forecast_anomalies), dim=spatial_dims))
        reanalysis_anomalies_std = torch.sqrt(torch.sum(torch.square(reanalysis_anomalies), dim=spatial_dims))
        numerator = torch.sum(forecast_anomalies * reanalysis_anomalies, dim=spatial_dims)
        self.acc = numerator / (forecast_anomalies_std * reanalysis_anomalies_std)

# Run the whole pipeline once: load samples, roll the model out, score the forecasts.
dg = DataGenerator(inference, cfg)
dg.load_samples()  # must come first: compute_forecasts asserts that samples are loaded
dg.compute_forecasts()  # fills dg.forecasts and moves the sample tensors to the CPU
dg.compute_metrics()  # sets dg.mse and dg.acc

  warn("Distributed manager is already intialized")


In [3]:
# ACC of the first sample and first channel at every rollout step
# (acc keeps the leading [sample, step, channel] dims; the two spatial dims are reduced).
dg.acc[0,:, 0]

tensor([0.9987, 0.9978, 0.9963, 0.9947, 0.9924, 0.9905])

In [None]:
from inference import Inference



torch.Size([2, 10])