In [1]:
# ---------- #
# System
import gc
import os
import sys
import yaml
import glob
import logging
import warnings
import traceback
from pathlib import Path
from argparse import ArgumentParser
import multiprocessing as mp

# ---------- #
# Numerics
import datetime
import pandas as pd
import xarray as xr
import numpy as np

# ---------- #
# AI libs
import torch
import torch.distributed as dist
from torchvision import transforms
# import wandb

# ---------- #
# credit
from credit.data import Predict_Dataset, concat_and_reshape
from credit.models import load_model
from credit.transforms import Normalize_ERA5_and_Forcing, load_transforms
from credit.seed import seed_everything
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper
from credit.models.checkpoint import load_model_state
from credit.solar import TOADataLoader
from credit.output import split_and_reshape, load_metadata, make_xarray, save_netcdf_increment
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler

In [2]:
config_name = '/glade/u/home/ksha/miles-credit/results/fuxi_norm/model_new.yml'
# Read YAML file
with open(config_name, 'r') as stream:
    conf = yaml.safe_load(stream)

In [3]:
rank = 0
world_size = 1

In [4]:
# infer device id from rank
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(rank % torch.cuda.device_count())
else:
    device = torch.device("cpu")

In [5]:
history_len = conf["data"]["history_len"]
transform = load_transforms(conf)

if conf["data"]["scaler_type"] == 'std_new':
    state_transformer = Normalize_ERA5_and_Forcing(conf)
# ----------------------------------------------------------------- #
# parse varnames and save_locs from config
## upper air variables
all_ERA_files = sorted(glob.glob(conf["data"]["save_loc"]))
varname_upper_air = conf['data']['variables']

## surface variables
if "save_loc_surface" in conf["data"]:
    surface_files = sorted(glob.glob(conf["data"]["save_loc_surface"]))
    varname_surface = conf['data']['surface_variables']
else:
    surface_files = None
    varname_surface = None 

## forcing variables
if ('forcing_variables' in conf['data']) and (len(conf['data']['forcing_variables']) > 0):
    forcing_files = conf['data']['save_loc_forcing']
    varname_forcing = conf['data']['forcing_variables']
else:
    forcing_files = None
    varname_forcing = None

## static variables
if ('static_variables' in conf['data']) and (len(conf['data']['static_variables']) > 0):
    static_files = conf['data']['save_loc_static']
    varname_static = conf['data']['static_variables']
else:
    static_files = None
    varname_static = None

In [6]:
# ----------------------------------------------------------------- #\
# get dataset
dataset = Predict_Dataset(
    conf, 
    varname_upper_air,
    varname_surface,
    varname_forcing,
    varname_static,
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    fcst_datetime=load_forecasts(conf),
    history_len=history_len,
    rank=rank,
    world_size=world_size,
    transform=transform,
    rollout_p=0.0,
    which_forecast=None
)
# setup the dataloder
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=False,
)

In [7]:
# load model
model = load_model(conf, load_weights=True).to(device)

if conf["trainer"]["mode"] in ["ddp", "fsdp"]:    
    # A new field needs to be added to predict
    model = distributed_model_wrapper(conf, model, device)
    
    if conf["trainer"]["mode"] == "fsdp":
        # Load model weights (if any), an optimizer, scheduler, and gradient scaler
        model = load_model_state(conf, model, device)

#model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [8]:
# get lat/lons from x-array
latlons = xr.open_dataset(conf["loss"]["latitude_weights"])

meta_data = load_metadata(conf)

# Set up the diffusion and pole filters
if (
    "use_laplace_filter" in conf["predict"]
    and conf["predict"]["use_laplace_filter"]
):
    dpf = Diffusion_and_Pole_Filter(
        nlat=conf["model"]["image_height"],
        nlon=conf["model"]["image_width"],
        device=device,
    )


In [9]:
# Rollout
with torch.no_grad():
    # forecast count = a constant for each run
    forecast_count = 0

    # y_pred allocation
    y_pred = None
    static = None
    results = []

    # model inference loop
    for k, batch in enumerate(data_loader):

        # get the datetime and forecasted hours
        date_time = batch["datetime"].item()
        forecast_hour = batch["forecast_hour"].item()

        # initialization on the first forecast hour
        if forecast_hour == 1:
            # Initialize x and x_surf with the first time step
            #x = model.concat_and_reshape(batch["x"], batch["x_surf"]).to(device)
            #x = concat_and_reshape(batch["x"], batch["x_surf"]).to(device)

            if "x_surf" in batch:
                # combine x and x_surf
                # input: (batch_num, time, var, level, lat, lon), (batch_num, time, var, lat, lon) 
                # output: (batch_num, var, time, lat, lon), 'x' first and then 'x_surf'
                x = concat_and_reshape(batch["x"], batch["x_surf"]).to(device).float()
            else:
                # no x_surf
                x = reshape_only(batch["x"]).to(self.device).float()

            # add forcing and static variables
            if 'x_forcing_static' in batch:
                
                # (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon)
                x_forcing_batch = batch['x_forcing_static'].to(device).permute(0, 2, 1, 3, 4).float()

                # concat on var dimension
                x = torch.cat((x, x_forcing_batch), dim=1)
            
            init_datetime_str = datetime.datetime.utcfromtimestamp(date_time)
            init_datetime_str = init_datetime_str.strftime('%Y-%m-%dT%HZ')
            
        # Predict and convert to real space for laplace filter and metrics
        raise
        y_pred = model(x)
        y_pred = state_transformer.inverse_transform(y_pred.cpu())

        if ("use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]):
            y_pred = (
                dpf.diff_lap2d_filt(y_pred.to(device).squeeze())
                .unsqueeze(0)
                .unsqueeze(2)
                .cpu()
            )

        # Save the current forecast hour data in parallel
        utc_datetime = datetime.datetime.utcfromtimestamp(date_time) + datetime.timedelta(hours=forecast_hour)

        # convert the current step result as x-array
        darray_upper_air, darray_single_level = make_xarray(
            y_pred,
            utc_datetime,
            latlons.latitude.values,
            latlons.longitude.values,
            conf,
        )

        # Save the current forecast hour data in parallel
        result = p.apply_async(
            save_netcdf_increment,
            (darray_upper_air, darray_single_level, init_datetime_str, forecast_hour, meta_data, conf)
        )
        results.append(result)

        # Update the input
        # setup for next iteration, transform to z-space and send to device
        y_pred = state_transformer.transform_array(y_pred).to(device)

        if history_len == 1:
            x = y_pred.detach()
        else:
            # use multiple past forecast steps as inputs
            # static channels will get updated on next pass
            static_dim_size = abs(x.shape[1] - y_pred.shape[1])  
            # if static_dim_size=0 then :0 gives empty range
            x_detach = x[:, :-static_dim_size, 1:].detach() if static_dim_size else x[:, :, 1:].detach()  
            x = torch.cat([x_detach, y_pred.detach()], dim=2)

        # Explicitly release GPU memory
        torch.cuda.empty_cache()
        gc.collect()

        if batch["stop_forecast"][0]:
            # Wait for all processes to finish in order
            for result in results:
                result.get()

            # Now merge all the files into one and delete leftovers
            # merge_netcdf_files(init_datetime_str, conf)

            # forecast count = a constant for each run
            forecast_count += 1

            # update lists
            results = []

            # y_pred allocation
            y_pred = None

            gc.collect()

            if distributed:
                torch.distributed.barrier()

if distributed:
    torch.distributed.barrier()

RuntimeError: No active exception to reraise

In [10]:
x[0, -4, ...]

tensor([[[-0.5151, -0.5148, -0.5146,  ..., -0.5156, -0.5154, -0.5152],
         [-0.4601, -0.4599, -0.4595,  ..., -0.4609, -0.4606, -0.4602],
         [-0.4083, -0.4081, -0.4079,  ..., -0.4088, -0.4086, -0.4085],
         ...,
         [-0.6567, -0.6568, -0.6568,  ..., -0.6565, -0.6566, -0.6567],
         [-0.6584, -0.6584, -0.6584,  ..., -0.6583, -0.6583, -0.6583],
         [-0.6583, -0.6583, -0.6584,  ..., -0.6583, -0.6583, -0.6583]],

        [[-0.5215, -0.5213, -0.5212,  ..., -0.5218, -0.5217, -0.5216],
         [-0.4651, -0.4650, -0.4648,  ..., -0.4655, -0.4653, -0.4651],
         [-0.4081, -0.4081, -0.4080,  ..., -0.4081, -0.4081, -0.4081],
         ...,
         [-0.6479, -0.6481, -0.6482,  ..., -0.6473, -0.6475, -0.6477],
         [-0.6515, -0.6516, -0.6517,  ..., -0.6513, -0.6513, -0.6514],
         [-0.6536, -0.6536, -0.6536,  ..., -0.6535, -0.6535, -0.6536]]],
       device='cuda:0')

### compare old & new rollouts

In [25]:
new = xr.open_dataset('/glade/derecho/scratch/ksha/CREDIT/fuxi_norm_new/2018-06-01T00Z/pred_2018-06-01T00Z_002.nc')

  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)


In [26]:
old = xr.open_dataset('/glade/derecho/scratch/ksha/CREDIT/fuxi_norm_test/2018-06-01T00Z/pred_2018-06-01T00Z_002.nc')

  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)
  var = coder.decode(var, name=name)


In [27]:
np.mean(np.array(new['Z500']) - np.array(old['Z500']))

0.40104607

In [17]:
np.array(new['Z500'])

array([[[51883.26 , 51883.098, 51883.46 , ..., 51877.56 , 51878.54 ,
         51879.797],
        [51890.203, 51890.65 , 51890.28 , ..., 51885.215, 51886.035,
         51887.37 ],
        [51895.32 , 51895.86 , 51895.645, ..., 51890.457, 51891.38 ,
         51891.742],
        ...,
        [50532.99 , 50533.633, 50533.477, ..., 50544.344, 50541.297,
         50540.9  ],
        [50525.34 , 50526.992, 50525.566, ..., 50535.746, 50533.855,
         50532.27 ],
        [50501.258, 50500.633, 50504.504, ..., 50509.96 , 50511.91 ,
         50511.758]]], dtype=float32)

In [18]:
np.array(old['Z500'])

array([[[51883.438, 51883.258, 51883.61 , ..., 51877.86 , 51878.81 ,
         51880.03 ],
        [51890.367, 51890.81 , 51890.42 , ..., 51885.51 , 51886.31 ,
         51887.613],
        [51895.46 , 51895.996, 51895.76 , ..., 51890.734, 51891.625,
         51891.96 ],
        ...,
        [50532.848, 50533.496, 50533.34 , ..., 50544.215, 50541.16 ,
         50540.78 ],
        [50525.2  , 50526.848, 50525.426, ..., 50535.617, 50533.715,
         50532.137],
        [50501.11 , 50500.49 , 50504.367, ..., 50509.824, 50511.777,
         50511.625]]], dtype=float32)

In [None]:
y_pred.shape

In [24]:
from credit.data import Sample
from typing import Dict

In [27]:
logger = logging.getLogger(__name__)

In [29]:
state_transformer = Normalize_ERA5_and_Forcing(conf)

In [30]:
y_pred = state_transformer.inverse_transform(y_pred.cpu())

In [31]:
y_pred.shape

torch.Size([1, 67, 1, 640, 1280])

In [32]:
y_pred

tensor([[[[[ 9.6050e+00,  9.6030e+00,  9.5811e+00,  ...,  9.6841e+00,
             9.6410e+00,  9.6423e+00],
           [ 9.6263e+00,  9.5931e+00,  9.5919e+00,  ...,  9.6613e+00,
             9.6555e+00,  9.6297e+00],
           [ 9.7909e+00,  9.7778e+00,  9.7666e+00,  ...,  9.8450e+00,
             9.8432e+00,  9.8334e+00],
           ...,
           [ 1.7022e-01,  3.5282e-01,  5.4143e-01,  ...,  5.4510e-02,
             1.0519e-01,  1.4594e-01],
           [ 2.9755e-01,  4.7897e-01,  6.9006e-01,  ...,  3.1048e-01,
             2.8333e-01,  2.2607e-01],
           [ 1.0609e-01,  3.1670e-01,  5.3178e-01,  ...,  1.3572e-01,
             1.2455e-01,  5.1676e-02]]],


         [[[ 4.2611e+00,  4.2457e+00,  4.2450e+00,  ...,  4.2849e+00,
             4.2833e+00,  4.2390e+00],
           [ 4.4083e+00,  4.3953e+00,  4.3780e+00,  ...,  4.4362e+00,
             4.4175e+00,  4.3946e+00],
           [ 4.4866e+00,  4.4929e+00,  4.4759e+00,  ...,  4.5344e+00,
             4.5213e+00,  4.5233e+00],