In [1]:
import numpy as np
import xarray as xr
import torch
import torch.utils.data
import yaml

from torchvision import transforms
from credit.transforms404 import NormalizeState, ToTensor
from credit.data import CONUS404Dataset

  warn(


In [15]:
# Path to the experiment YAML, given relative to the notebook's working directory.
config = "../config/conus404.yml"

In [16]:
# Parse the YAML experiment config; `conf` is indexed by section below (e.g. conf["data"]).
# NOTE(review): yaml.load with FullLoader is acceptable for a trusted local file;
# yaml.safe_load is the stricter choice if this file could ever come from elsewhere.
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [4]:
conf["data"]

{'variables': ['PSFC',
  'Q500',
  'Q850',
  'T2',
  'T500',
  'T850',
  'totalVap',
  'U10',
  'U1000',
  'U250',
  'U500',
  'U850',
  'V10',
  'V1000',
  'V250',
  'V500',
  'V850',
  'Z050',
  'Z1000',
  'Z500',
  'Z850'],
 'static_variables': [],
 'scaler_type': 'std',
 'save_loc': '/glade/derecho/scratch/mcginnis/???*',
 'mean_path': '/glade/derecho/scratch/mcginnis/conus404/stats/all.avg.C404.nc',
 'std_path': '/glade/derecho/scratch/mcginnis/conus404/stats/all.std.C404.nc',
 'history_len': 2,
 'forecast_len': 1,
 'valid_history_len': 2,
 'valid_forecast_len': 1,
 'time_step': 1}

In [6]:
import torch
from credit.data import Sample


class ToTensor:
    """Turn the xarray Datasets of a CONUS404 sample into stacked torch tensors.

    For the ``"x"`` and ``"y"`` entries of a sample, every configured variable
    is read from the Dataset, cropped to a fixed 512 x 512 window and stacked
    along a new leading axis. In this notebook that yields tensors shaped
    (variable, time, 512, 512).

    Args:
        conf: Parsed configuration. Reads ``history_len``, ``forecast_len``,
            ``variables`` and ``static_variables`` from ``conf["data"]``.
    """

    def __init__(self, conf):
        self.conf = conf
        # Stored for reference only; __call__ does not use the two lengths.
        self.hist_len = int(conf["data"]["history_len"])
        self.for_len = int(conf["data"]["forecast_len"])
        self.variables = conf["data"]["variables"]
        self.static_variables = conf["data"]["static_variables"]
        # Fixed crop window applied to the two trailing axes of every variable.
        # (The original carried commented-out 1016 / 1638 values, presumably the
        # full-domain size -- TODO confirm.)
        self.slice_x = slice(120, 632, None)
        self.slice_y = slice(300, 812, None)

    def __call__(self, sample: Sample) -> Sample:
        """Convert the ``"x"`` / ``"y"`` entries of ``sample`` to tensors.

        Args:
            sample: Mapping of key -> value. ``"x"`` and ``"y"`` must be
                ``xr.Dataset`` objects containing every name in
                ``self.variables``.

        Returns:
            dict with an ``"x"`` and/or ``"y"`` tensor, one per matching key
            present in ``sample``. Other keys are not copied over.

        Raises:
            TypeError: if an ``"x"`` / ``"y"`` value is not an ``xr.Dataset``.
                Previously this either failed with ``UnboundLocalError`` or
                silently reused the array stacked for the preceding key.
        """
        return_dict = {}

        for key, value in sample.items():
            if key == "historical_ERA5_images" or key == "x":
                # Keep the time coordinates of the latest input on the
                # transform instance, as the original implementation did.
                self.datetime = value["Time"]
                self.doy = value["Time.dayofyear"]
                self.hod = value["Time.hour"]

            if key not in ("x", "y"):
                # Nothing else is emitted, so avoid loading its data at all.
                continue

            if not isinstance(value, xr.Dataset):
                raise TypeError(
                    f"ToTensor expected an xarray.Dataset for sample['{key}'], "
                    f"got {type(value).__name__}"
                )

            return_dict[key] = torch.as_tensor(self._stack_variables(value))

        # Static variables are read from the config but not handled yet.
        return return_dict

    def _stack_variables(self, dataset):
        """Crop each configured variable and stack them along a new first axis."""
        cropped = []
        for name in self.variables:
            field = dataset[name].values
            if field.ndim == 4:  # some seem to have extra single dimensions
                field = field.squeeze(1)
            cropped.append(field[:, self.slice_x, self.slice_y])
        # One stack replaces the former np.array -> expand_dims -> vstack round
        # trip, which produced the same array.
        return np.stack(cropped, axis=0)

In [8]:
# Per-sample pipeline: NormalizeState runs first, then ToTensor.
# NOTE: `ToTensor` here resolves to the class defined in this notebook, which
# shadows the `ToTensor` imported from credit.transforms404 in the first cell.
transform = transforms.Compose([NormalizeState(conf), ToTensor(conf)])

In [9]:
# Build the dataset, taking the variable list and window lengths from the same
# config section the transforms read.
# NOTE(review): the zarr location is a hard-coded absolute path; consider
# moving it into the config file.
dataset = CONUS404Dataset(
    zarrpath="/glade/campaign/ral/risc/DATA/conus404/zarr",
    varnames=conf["data"]["variables"],
    history_len=conf["data"]["history_len"],
    forecast_len=conf["data"]["forecast_len"],
    transform=transform,
)

In [10]:
# Fetch the first sample through normal indexing rather than calling the
# dunder method directly; the transform pipeline runs as part of this lookup.
result = dataset[0]

In [11]:
result["x"].shape  # (C, T, L, W)

torch.Size([21, 2, 512, 512])

In [12]:
result["y"].shape  # (C, T, L, W)

torch.Size([21, 1, 512, 512])

In [13]:
result["x"]

tensor([[[[ 0.6357,  0.6352,  0.6347,  ...,  0.6846,  0.6844,  0.6843],
          [ 0.6356,  0.6351,  0.6345,  ...,  0.6846,  0.6845,  0.6844],
          [ 0.6355,  0.6350,  0.6344,  ...,  0.6847,  0.6845,  0.6844],
          ...,
          [-1.8032, -1.8101, -1.7504,  ...,  0.2187,  0.2043,  0.1972],
          [-1.7807, -1.7904, -1.7503,  ...,  0.2104,  0.2349,  0.2361],
          [-1.7590, -1.7249, -1.7385,  ...,  0.1774,  0.1987,  0.1943]],

         [[ 0.6324,  0.6319,  0.6314,  ...,  0.6883,  0.6882,  0.6881],
          [ 0.6322,  0.6317,  0.6311,  ...,  0.6883,  0.6882,  0.6881],
          [ 0.6320,  0.6314,  0.6309,  ...,  0.6884,  0.6883,  0.6882],
          ...,
          [-1.8039, -1.8105, -1.7507,  ...,  0.2146,  0.2000,  0.1930],
          [-1.7814, -1.7909, -1.7509,  ...,  0.2064,  0.2308,  0.2319],
          [-1.7596, -1.7256, -1.7393,  ...,  0.1736,  0.1949,  0.1904]]],


        [[[-0.7455, -0.7411, -0.7367,  ..., -0.8105, -0.8054, -0.8001],
          [-0.7454, -0.7409,

In [33]:
from credit.models.unet import load_premade_encoder_model
import torch.nn.functional as F


class SegmentationModel(torch.nn.Module):
    """Adapter that lets a pre-made 2-D encoder/decoder consume history tensors.

    The time axis (dim 2) of the input is averaged away before the wrapped
    network is called, and a singleton time axis is re-inserted on the output.

    Args:
        conf: Parsed configuration. Reads ``conf["data"]["variables"]``, the
            optional ``conf["data"]["static_variables"]``,
            ``conf["model"]["frames"]`` and ``conf["model"]["architecture"]``.

    Note:
        ``conf["model"]["architecture"]`` is updated in place with the derived
        ``in_channels`` / ``classes`` values (and ``decoder_attention_type``
        when the architecture name is ``"unet"``), exactly as before.
    """

    def __init__(self, conf):
        super().__init__()

        self.variables = conf["data"]["variables"]
        self.frames = conf["model"]["frames"]
        # Accept both a missing key and an explicit null entry in the config.
        self.static_variables = conf["data"].get("static_variables") or []

        in_channels = len(self.variables) + len(self.static_variables)
        out_channels = len(self.variables)

        architecture = conf["model"]["architecture"]
        if architecture["name"] == "unet":
            architecture["decoder_attention_type"] = "scse"
        architecture["in_channels"] = in_channels
        architecture["classes"] = out_channels

        self.model = load_premade_encoder_model(architecture)

    def forward(self, x):
        """Average over time, run the wrapped model, restore a time axis.

        Args:
            x: Tensor whose dim 2 is the time axis; the notebook passes
                (batch, channel, time, H, W).

        Returns:
            The wrapped model's output with a singleton axis inserted at dim 2.
        """
        n_frames = x.shape[2]
        if n_frames > 1:
            # Pool across the whole time axis. The previous fixed kernel of 2
            # dropped the last of 3 frames and left the axis un-collapsed for
            # 4 or more, so a 5-D tensor reached the 2-D network.
            x = F.avg_pool3d(x, kernel_size=(n_frames, 1, 1))
        x = x.squeeze(2)  # squeeze time dim
        x = self.model(x)
        return x.unsqueeze(2)

In [34]:
# Needs a "model" section in the config (frames, architecture). The constructor
# also writes the derived channel counts back into conf["model"]["architecture"].
model = SegmentationModel(conf)

In [35]:
# unsqueeze(0) adds a leading batch axis: (21, 2, 512, 512) -> (1, 21, 2, 512, 512).
y_pred = model(result["x"].unsqueeze(0))

In [37]:
y_pred

tensor([[[[[ 1.2298e-01,  2.1172e-01,  3.1994e-01,  ...,  1.1723e+00,
             1.0137e+00,  4.8237e-01],
           [ 7.5705e-01,  3.8058e-01, -2.6989e-01,  ...,  1.6883e+00,
             7.7407e-01,  8.7452e-03],
           [-3.4797e-01, -5.0708e-01, -4.8845e-01,  ...,  9.0493e-01,
            -1.5792e-01,  1.9497e-01],
           ...,
           [ 4.1544e-01,  5.6858e-01,  2.6262e-01,  ...,  6.1462e-01,
             1.1958e+00,  1.9587e-01],
           [ 3.8555e-01,  4.8237e-01,  1.3258e-01,  ...,  2.3232e-01,
             4.9614e-01,  1.7373e-02],
           [ 4.0762e-01,  5.2569e-01,  9.4668e-02,  ...,  9.6332e-02,
             1.6551e-01, -2.2042e-01]]],


         [[[-5.5071e-02,  3.0240e-01, -2.6451e-01,  ...,  5.8349e-01,
             4.6458e-01,  7.6008e-02],
           [ 9.4891e-01,  4.2858e-01, -2.9804e-02,  ...,  1.3773e+00,
             2.5978e-01, -3.5476e-01],
           [ 3.7230e-01, -1.4736e-01, -3.7474e-01,  ...,  6.9030e-01,
            -5.3386e-01, -8.3791e-01],