In [1]:
from network_def import CNN
from data_utils import (
    DATA_DIR,
    load_labels,
    aug_crossentropy_RI_Dataset,
    AddGaussianNoise,
    crossentropy_RI_Dataset,
    find_file,
)
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms.v2 as tvtf
import numpy as np
import torchinfo
import xarray as xr

In [2]:
device = "cpu"

In [27]:
train_labels, _ = load_labels(DATA_DIR + "/valid_labels.json")
valid_labels, _ = load_labels(DATA_DIR + "/valid_labels.json")
cnn_train_ds = aug_crossentropy_RI_Dataset(train_labels)
cnn_valid_ds = aug_crossentropy_RI_Dataset(valid_labels)
batch_size = 64
cnn_train_dataloader = DataLoader(cnn_train_ds, num_workers=2, batch_size=batch_size)
cnn_valid_dataloader = DataLoader(cnn_valid_ds, num_workers=2, batch_size=batch_size)

In [4]:
cnn_model = CNN(dropout_rate=0.5)
cnn_model.load_state_dict(
    torch.load(
        "./saved_models/crps_cnn.pt", map_location=torch.device(device)
    )
)
cnn_model.eval()

CNN(
  (conv_layers): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
      (1): LeakyReLU(negative_slope=0.2)
      (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
      (4): LeakyReLU(negative_slope=0.2)
      (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
      (1): LeakyReLU(negative_slope=0.2)
      (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
      (4): LeakyReLU(negative_slope=0.2)
      

In [5]:
torchinfo.summary(cnn_model, (batch_size, 1, 380, 540))

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      [64, 100]                 --
├─Sequential: 1-1                        [64, 256, 5, 8]           --
│    └─Sequential: 2-1                   [64, 8, 190, 270]         --
│    │    └─Conv2d: 3-1                  [64, 8, 380, 540]         80
│    │    └─LeakyReLU: 3-2               [64, 8, 380, 540]         --
│    │    └─BatchNorm2d: 3-3             [64, 8, 380, 540]         16
│    │    └─Conv2d: 3-4                  [64, 8, 380, 540]         584
│    │    └─LeakyReLU: 3-5               [64, 8, 380, 540]         --
│    │    └─BatchNorm2d: 3-6             [64, 8, 380, 540]         16
│    │    └─MaxPool2d: 3-7               [64, 8, 190, 270]         --
│    └─Sequential: 2-2                   [64, 16, 95, 135]         --
│    │    └─Conv2d: 3-8                  [64, 16, 190, 270]        1,168
│    │    └─LeakyReLU: 3-9               [64, 16, 190, 270]        --
│    │    └

In [28]:
test_iterator = iter(cnn_train_dataloader)

In [30]:
test_vals = next(test_iterator)
test_input = torch.reshape(test_vals[0], (-1, 1, 380, 540))
test_target = torch.reshape(test_vals[1], (-1,))
print(test_input.shape)
print(test_target.shape)
with torch.no_grad():
    test_output = cnn_model(test_input)
print(test_output.shape)
vars = torch.var(test_output, dim=1)
means = test_output.mean(dim=1)
for i in range(batch_size):
    print(f"Target: {test_target[i]}, Mean: {means[i]:.3f}, Var: {vars[i]:.3f}")

torch.Size([64, 1, 380, 540])
torch.Size([64])
torch.Size([64, 100])
Target: 0, Mean: -0.006, Var: 0.010
Target: 0, Mean: -0.006, Var: 0.010
Target: 0, Mean: -0.006, Var: 0.010
Target: 0, Mean: -0.005, Var: 0.010
Target: 0, Mean: -0.006, Var: 0.010
Target: 0, Mean: -0.006, Var: 0.009
Target: 0, Mean: -0.005, Var: 0.009
Target: 0, Mean: -0.007, Var: 0.010
Target: 0, Mean: -0.004, Var: 0.009
Target: 0, Mean: 0.001, Var: 0.008
Target: 0, Mean: -0.002, Var: 0.009
Target: 0, Mean: -0.005, Var: 0.009
Target: 0, Mean: -0.006, Var: 0.009
Target: 0, Mean: -0.006, Var: 0.010
Target: 0, Mean: -0.004, Var: 0.009
Target: 0, Mean: 0.001, Var: 0.009
Target: 0, Mean: -0.008, Var: 0.010
Target: 0, Mean: -0.008, Var: 0.010
Target: 0, Mean: -0.008, Var: 0.010
Target: 0, Mean: -0.006, Var: 0.010
Target: 0, Mean: -0.006, Var: 0.010
Target: 0, Mean: 0.000, Var: 0.009
Target: 0, Mean: -0.005, Var: 0.010
Target: 0, Mean: -0.005, Var: 0.010
Target: 0, Mean: -0.004, Var: 0.010
Target: 0, Mean: -0.004, Var: 0.00

In [None]:
from network_def import crps_loss
loss = crps_loss()