In [1]:
# ---------- #
# System
import gc
import os
import sys
import yaml
import glob
import logging
import warnings
import traceback
from pathlib import Path
from argparse import ArgumentParser
import multiprocessing as mp

# ---------- #
# Numerics
import datetime
import pandas as pd
import xarray as xr
import numpy as np

# ---------- #
# AI libs
import torch
import torch.distributed as dist
from torchvision import transforms
# import wandb

# ---------- #
# credit
from credit.data import ERA5Dataset, concat_and_reshape
from credit.models import load_model
from credit.transforms import ToTensor, NormalizeState, NormalizeState_Quantile
from credit.seed import seed_everything
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper
from credit.models.checkpoint import load_model_state
from credit.solar import TOADataLoader
from credit.output import split_and_reshape, load_metadata, make_xarray, save_netcdf_increment
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler

In [2]:
config_name = '/glade/u/home/ksha/miles-credit/results/fuxi_norm/model.yml'
# Read YAML file
with open(config_name, 'r') as stream:
    conf = yaml.safe_load(stream)

In [5]:
test = xr.open_dataset('/glade/derecho/scratch/ksha/CREDIT/wx_former_6h/2020-01-01T00Z/pred_2020-01-01T00Z_006.nc')

In [3]:
rank = 0
world_size = 1

In [4]:
class PredictForecast(torch.utils.data.IterableDataset):
    def __init__(self,
                 filenames,
                 forecasts,
                 history_len,
                 skip_periods,
                 rank,
                 world_size,
                 shuffle=False,
                 transform=None,
                 rollout_p=0.0,
                 which_forecast=None):

        self.dataset = ERA5Dataset(
            filenames=filenames,
            history_len=history_len,
            forecast_len=1,
            skip_periods=skip_periods,
            transform=transform
        )
        self.meta_data_dict = self.dataset.meta_data_dict
        self.all_files = self.dataset.all_fils
        self.history_len = history_len
        self.filenames = filenames
        self.transform = transform
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.skip_periods = skip_periods
        self.current_epoch = 0
        self.rollout_p = rollout_p
        self.forecasts = forecasts
        self.skip_periods = skip_periods if skip_periods is not None else 1
        self.which_forecast = which_forecast

    def find_start_stop_indices(self, index):
        start_time = self.forecasts[index][0]
        date_object = datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S')
        shifted_hours = self.skip_periods * self.history_len
        date_object = date_object - datetime.timedelta(hours=shifted_hours)
        self.forecasts[index][0] = date_object.strftime('%Y-%m-%d %H:%M:%S')

        datetime_objs = [np.datetime64(date) for date in self.forecasts[index]]
        start_time, stop_time = [str(datetime_obj) + '.000000000' for datetime_obj in datetime_objs]
        self.start_time = np.datetime64(start_time).astype(datetime.datetime)
        self.stop_time = np.datetime64(stop_time).astype(datetime.datetime)

        info = {}

        for idx, dataset in enumerate(self.all_files):
            start_time = np.datetime64(dataset['time'].min().values).astype(datetime.datetime)
            stop_time = np.datetime64(dataset['time'].max().values).astype(datetime.datetime)
            track_start = False
            track_stop = False

            if start_time <= self.start_time <= stop_time:
                # Start time is in this file, use start time index
                dataset = np.array([np.datetime64(x.values).astype(datetime.datetime) for x in dataset['time']])
                start_idx = np.searchsorted(dataset, self.start_time)
                start_idx = max(0, min(start_idx, len(dataset)-1))
                track_start = True

            elif start_time < self.stop_time and stop_time > self.start_time:
                # File overlaps time range, use full file
                start_idx = 0
                track_start = True

            if start_time <= self.stop_time <= stop_time:
                # Stop time is in this file, use stop time index
                if isinstance(dataset, np.ndarray):
                    pass
                else:
                    dataset = np.array([np.datetime64(x.values).astype(datetime.datetime) for x in dataset['time']])
                stop_idx = np.searchsorted(dataset, self.stop_time)
                stop_idx = max(0, min(stop_idx, len(dataset)-1))
                track_stop = True

            elif start_time < self.stop_time and stop_time > self.start_time:
                # File overlaps time range, use full file
                stop_idx = len(dataset) - 1
                track_stop = True

            # Only include files that overlap the time range
            if track_start and track_stop:
                info[idx] = ((idx, start_idx), (idx, stop_idx))

        indices = []
        for dataset_idx, (start, stop) in info.items():
            for i in range(start[1], stop[1]+1):
                indices.append((start[0], i))
        return indices

    def __len__(self):
        return len(self.forecasts)

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        sampler = DistributedSampler(self, num_replicas=num_workers*self.world_size, rank=self.rank*num_workers+worker_id, shuffle=self.shuffle)

        for index in sampler:

            data_lookup = self.find_start_stop_indices(index)

            for k, (file_key, time_key) in enumerate(data_lookup):

                if k == 0:
                    concatenated_samples = {'x': [], 'x_surf': []}
                    sliced_x = xr.open_zarr(self.filenames[file_key], consolidated=True).isel(time=slice(time_key, time_key+self.history_len+1))

                    # Check if additional data from the next file is needed
                    if len(sliced_x['time']) < self.history_len + 1:
                        # Load excess data from the next file
                        next_file_idx = self.filenames.index(self.filenames[file_key]) + 1
                        if next_file_idx == len(self.filenames):
                            raise OSError("You have reached the end of the available data. Exiting.")
                        sliced_x_next = xr.open_zarr(
                            self.filenames[next_file_idx],
                            consolidated=True).isel(time=slice(0, self.history_len+1-len(sliced_x['time'])))

                        # Concatenate excess data from the next file with the current data
                        sliced_x = xr.concat([sliced_x, sliced_x_next], dim='time')

                    sample_x = {
                        'x': sliced_x.isel(time=slice(0, self.history_len))
                    }

                    if self.transform:
                        sample_x = self.transform(sample_x)
                        # Add static vars, if any, to the return dictionary
                        if "static" in sample_x:
                            concatenated_samples["static"] = []
                        if "TOA" in sample_x:
                            concatenated_samples["TOA"] = []

                    for key in concatenated_samples.keys():
                        concatenated_samples[key] = sample_x[key].squeeze(0) if self.history_len == 1 else sample_x[key]

                    concatenated_samples['forecast_hour'] = k + 1
                    concatenated_samples['stop_forecast'] = (k == (len(data_lookup)-self.history_len-1))  # Adjust stopping condition
                    concatenated_samples['datetime'] = sliced_x.time.values.astype('datetime64[s]').astype(int)[-1]

                else:
                    concatenated_samples['forecast_hour'] = k + 1
                    concatenated_samples['stop_forecast'] = (k == (len(data_lookup)-self.history_len-1))  # Adjust stopping condition

                yield concatenated_samples

                if concatenated_samples['stop_forecast']:
                    break


def setup(rank, world_size, mode):
    logging.info(f"Running {mode.upper()} on rank {rank} with world_size {world_size}.")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

In [5]:
if conf["trainer"]["mode"] in ["fsdp", "ddp"]:
    setup(rank, world_size, conf["trainer"]["mode"])

# infer device id from rank
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(rank % torch.cuda.device_count())
else:
    device = torch.device("cpu")

# Config settings
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

history_len = conf["data"]["history_len"]
time_step = conf["data"]["time_step"] if "time_step" in conf["data"] else None

# Load paths to all ERA5 data available
all_ERA_files = sorted(glob.glob(conf["data"]["save_loc"]))

# Preprocessing transformations
if conf["data"]["scaler_type"] == "std":
    state_transformer = NormalizeState(conf)
else:
    state_transformer = NormalizeState_Quantile(conf)
transform = transforms.Compose(
    [
        state_transformer,
        ToTensor(conf),
    ]
)

dataset = PredictForecast(
    filenames=all_ERA_files,
    forecasts=load_forecasts(conf),
    history_len=history_len,
    skip_periods=time_step,
    transform=transform,
    rank=rank,
    world_size=world_size,
    shuffle=False,
)

# setup the dataloder for this process
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=False,
)

# load model
model = load_model(conf, load_weights=True).to(device)

# Warning -- see next line
distributed = conf["trainer"]["mode"] in ["ddp", "fsdp"]
if distributed:  # A new field needs to be added to predict
    model = distributed_model_wrapper(conf, model, device)
    if conf["trainer"]["mode"] == "fsdp":
        # Load model weights (if any), an optimizer, scheduler, and gradient scaler
        model = load_model_state(conf, model, device)

model.eval()

# get lat/lons from x-array
latlons = xr.open_dataset(conf["loss"]["latitude_weights"])

meta_data = load_metadata(conf)

# Set up the diffusion and pole filters
if (
    "use_laplace_filter" in conf["predict"]
    and conf["predict"]["use_laplace_filter"]
):
    dpf = Diffusion_and_Pole_Filter(
        nlat=conf["model"]["image_height"],
        nlon=conf["model"]["image_width"],
        device=device,
    )

# Rollout
with torch.no_grad():
    # forecast count = a constant for each run
    forecast_count = 0

    # y_pred allocation
    y_pred = None
    static = None
    results = []

    # model inference loop
    for k, batch in enumerate(data_loader):

        # get the datetime and forecasted hours
        date_time = batch["datetime"].item()
        forecast_hour = batch["forecast_hour"].item()

        # initialization on the first forecast hour
        if forecast_hour == 1:
            # Initialize x and x_surf with the first time step
            #x = model.concat_and_reshape(batch["x"], batch["x_surf"]).to(device)
            x = concat_and_reshape(batch["x"], batch["x_surf"]).to(device)
            
            init_datetime_str = datetime.datetime.utcfromtimestamp(date_time)
            init_datetime_str = init_datetime_str.strftime('%Y-%m-%dT%HZ')

        # Add statics
        if "static" in batch:
            if static is None:
                static = batch["static"].to(device).unsqueeze(2).expand(-1, -1, x.shape[2], -1, -1).float()
            x = torch.cat((x, static.clone()), dim=1)

        # Add solar "statics"
        if "static_variables" in conf["data"] and "tsi" in conf["data"]["static_variables"]:
            if k == 0:
                toaDL = TOADataLoader(conf)
            elapsed_time = pd.Timedelta(hours=k)
            tnow = pd.to_datetime(datetime.datetime.utcfromtimestamp(batch["datetime"]))
            tnow = tnow + elapsed_time
            if history_len == 1:
                current_times = [pd.to_datetime(datetime.datetime.utcfromtimestamp(_t)) + elapsed_time for _t in tnow]
            else:
                current_times = [tnow if hl == 0 else tnow - pd.Timedelta(hours=hl) for hl in range(history_len)]

            toa = torch.cat([toaDL(_t) for _t in current_times], dim=0).to(device)
            toa = toa.squeeze().unsqueeze(0)
            x = torch.cat([x, toa.unsqueeze(1).to(device).float()], dim=1)

        # Predict and convert to real space for laplace filter and metrics
        raise
        y_pred = model(x)
        y_pred = state_transformer.inverse_transform(y_pred.cpu())

        if ("use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]):
            y_pred = (
                dpf.diff_lap2d_filt(y_pred.to(device).squeeze())
                .unsqueeze(0)
                .unsqueeze(2)
                .cpu()
            )

        # Save the current forecast hour data in parallel
        utc_datetime = datetime.datetime.utcfromtimestamp(date_time) + datetime.timedelta(hours=forecast_hour)

        # convert the current step result as x-array
        darray_upper_air, darray_single_level = make_xarray(
            y_pred,
            utc_datetime,
            latlons.latitude.values,
            latlons.longitude.values,
            conf,
        )

        # Save the current forecast hour data in parallel
        result = p.apply_async(
            save_netcdf_increment,
            (darray_upper_air, darray_single_level, init_datetime_str, forecast_hour, meta_data, conf)
        )
        results.append(result)

        # Update the input
        # setup for next iteration, transform to z-space and send to device
        y_pred = state_transformer.transform_array(y_pred).to(device)

        if history_len == 1:
            x = y_pred.detach()
        else:
            # use multiple past forecast steps as inputs
            static_dim_size = abs(x.shape[1] - y_pred.shape[1])  # static channels will get updated on next pass
            x_detach = x[:, :-static_dim_size, 1:].detach() if static_dim_size else x[:, :, 1:].detach()  # if static_dim_size=0 then :0 gives empty range
            x = torch.cat([x_detach, y_pred.detach()], dim=2)

        # Explicitly release GPU memory
        torch.cuda.empty_cache()
        gc.collect()

        if batch["stop_forecast"][0]:
            # Wait for all processes to finish in order
            for result in results:
                result.get()

            # Now merge all the files into one and delete leftovers
            # merge_netcdf_files(init_datetime_str, conf)

            # forecast count = a constant for each run
            forecast_count += 1

            # update lists
            results = []

            # y_pred allocation
            y_pred = None

            gc.collect()

            if distributed:
                torch.distributed.barrier()

if distributed:
    torch.distributed.barrier()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


RuntimeError: No active exception to reraise

In [6]:
x[0, -4, ...]

tensor([[[-0.5151, -0.5148, -0.5146,  ..., -0.5156, -0.5154, -0.5152],
         [-0.4601, -0.4599, -0.4595,  ..., -0.4609, -0.4606, -0.4602],
         [-0.4083, -0.4081, -0.4079,  ..., -0.4088, -0.4086, -0.4085],
         ...,
         [-0.6567, -0.6568, -0.6568,  ..., -0.6565, -0.6566, -0.6567],
         [-0.6584, -0.6584, -0.6584,  ..., -0.6583, -0.6583, -0.6583],
         [-0.6583, -0.6583, -0.6584,  ..., -0.6583, -0.6583, -0.6583]],

        [[-0.5215, -0.5213, -0.5212,  ..., -0.5218, -0.5217, -0.5216],
         [-0.4651, -0.4650, -0.4648,  ..., -0.4655, -0.4653, -0.4651],
         [-0.4081, -0.4081, -0.4080,  ..., -0.4081, -0.4081, -0.4081],
         ...,
         [-0.6479, -0.6481, -0.6482,  ..., -0.6473, -0.6475, -0.6477],
         [-0.6515, -0.6516, -0.6517,  ..., -0.6513, -0.6513, -0.6514],
         [-0.6536, -0.6536, -0.6536,  ..., -0.6535, -0.6535, -0.6536]]],
       device='cuda:0')

### compare old & new rollouts

In [13]:
new = xr.open_dataset('/glade/derecho/scratch/ksha/CREDIT/fuxi_norm_new/2018-06-01T00Z/pred_2018-06-01T00Z_001.nc')

  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)


In [15]:
old = xr.open_dataset('/glade/derecho/scratch/ksha/CREDIT/fuxi_norm_test/2018-06-01T00Z/pred_2018-06-01T00Z_001.nc')

  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)


In [14]:
np.array(new['t2m'])

array([[[-0.34469616, -0.3446092 , -0.3437317 , ..., -0.34757087,
         -0.34641483, -0.34614974],
        [-0.33615398, -0.3353314 , -0.33522156, ..., -0.3368566 ,
         -0.3371188 , -0.33786282],
        [-0.3358565 , -0.33543956, -0.3358222 , ..., -0.33780622,
         -0.337941  , -0.33799165],
        ...,
        [-2.2178302 , -2.2207105 , -2.22325   , ..., -2.2133207 ,
         -2.2163541 , -2.2175353 ],
        [-2.234412  , -2.2354207 , -2.2362342 , ..., -2.2284808 ,
         -2.2312403 , -2.2342846 ],
        [-2.2077699 , -2.2087858 , -2.2077293 , ..., -2.202627  ,
         -2.2074575 , -2.2098649 ]]], dtype=float32)

In [16]:
np.array(old['t2m'])

array([[[-0.344811  , -0.3447332 , -0.34384775, ..., -0.34767446,
         -0.34650114, -0.3462309 ],
        [-0.336249  , -0.3354228 , -0.3353209 , ..., -0.3369275 ,
         -0.33717674, -0.33792216],
        [-0.335923  , -0.3355071 , -0.33589032, ..., -0.33785227,
         -0.33797932, -0.33802745],
        ...,
        [-2.2178102 , -2.2206895 , -2.2232304 , ..., -2.2132993 ,
         -2.2163353 , -2.2175174 ],
        [-2.23439   , -2.2353997 , -2.2362142 , ..., -2.2284608 ,
         -2.2312205 , -2.234268  ],
        [-2.207747  , -2.2087667 , -2.2077122 , ..., -2.202608  ,
         -2.2074409 , -2.209847  ]]], dtype=float32)

In [None]:
y_pred.shape

In [24]:
from credit.data import Sample
from typing import Dict

In [27]:
logger = logging.getLogger(__name__)

In [29]:
state_transformer = Normalize_ERA5_and_Forcing(conf)

In [30]:
y_pred = state_transformer.inverse_transform(y_pred.cpu())

In [31]:
y_pred.shape

torch.Size([1, 67, 1, 640, 1280])

In [32]:
y_pred

tensor([[[[[ 9.6050e+00,  9.6030e+00,  9.5811e+00,  ...,  9.6841e+00,
             9.6410e+00,  9.6423e+00],
           [ 9.6263e+00,  9.5931e+00,  9.5919e+00,  ...,  9.6613e+00,
             9.6555e+00,  9.6297e+00],
           [ 9.7909e+00,  9.7778e+00,  9.7666e+00,  ...,  9.8450e+00,
             9.8432e+00,  9.8334e+00],
           ...,
           [ 1.7022e-01,  3.5282e-01,  5.4143e-01,  ...,  5.4510e-02,
             1.0519e-01,  1.4594e-01],
           [ 2.9755e-01,  4.7897e-01,  6.9006e-01,  ...,  3.1048e-01,
             2.8333e-01,  2.2607e-01],
           [ 1.0609e-01,  3.1670e-01,  5.3178e-01,  ...,  1.3572e-01,
             1.2455e-01,  5.1676e-02]]],


         [[[ 4.2611e+00,  4.2457e+00,  4.2450e+00,  ...,  4.2849e+00,
             4.2833e+00,  4.2390e+00],
           [ 4.4083e+00,  4.3953e+00,  4.3780e+00,  ...,  4.4362e+00,
             4.4175e+00,  4.3946e+00],
           [ 4.4866e+00,  4.4929e+00,  4.4759e+00,  ...,  4.5344e+00,
             4.5213e+00,  4.5233e+00],