## Inference
Check that the saved models can be loaded and run on test data.

In [None]:
import sys
import xarray as xr
import torch

sys.path.append('../src')
from Inference import *

from torch.utils.data.dataloader import DataLoader

Set up the run configuration (year range, device, directories)

In [None]:
# Years to load for this inference check
year_start, year_end = 2020, 2021

# Compute device. Pinned to CPU; swap in the commented expression to use a GPU when one is present.
device = 'cpu'  # 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Locations of the input data and of the directory holding Model_chpt/
data_dir = "../data/"
model_dir = "../"


Load the saved checkpoints for both the U-Net and the diffusion model

In [None]:
# Load both trained models from their saved state dicts.
# `map_location=device` lets a checkpoint that was saved on a GPU be restored
# on this (CPU-pinned) device. `.eval()` switches any training-only layers
# (e.g. dropout / batch-norm, if present) to inference behaviour.
# U-Net
model_unet = UNet((256, 128), 5, 3, label_dim=2, use_diffuse=False).to(device).eval()
model_unet.load_state_dict(
    torch.load(f"{model_dir}/Model_chpt/unet.pt", map_location=device))
# Diffusion
model_diff = EDMPrecond((256, 128), 8, 3).to(device).eval()
model_diff.load_state_dict(
    torch.load(f"{model_dir}/Model_chpt/diffusion.pt", map_location=device))

Open the test dataset

In [None]:
# Build the test dataset over the selected years.
# Reuse `data_dir` from the configuration cell so the data path is defined in
# exactly one place. `datadir` is kept as an alias in case later cells use it.
datadir = data_dir

dataset_test = UpscaleDataset(data_dir, year_start=year_start, year_end=year_end,
                              constant_variables=["lsm", "z"])

# Coordinate arrays, used as the x/y axes when plotting
lat = dataset_test.lat
lon = dataset_test.lon

nlat, nlon = len(lat), len(lon)

# One sample per batch. shuffle=False keeps dataset order, so the first batch
# drawn below is reproducible across runs.
BATCH_SIZE = 1
dataloader = DataLoader(dataset_test,
                        batch_size=BATCH_SIZE,
                        shuffle=False)


Run both models on a single test batch (batch size 1)

In [None]:
# Index of the sample/time step to plot. With BATCH_SIZE = 1 this is the only
# sample in the batch (assumes outputs are laid out as (batch, variable, ...);
# confirm with the shape print below).
t = 0
test_batch = next(iter(dataloader))

# Run both models on the same batch. This is inference only, so autograd is
# disabled; otherwise a graph is kept alive across every sampling step.
with torch.no_grad():
    coarse, fine, predicted_unet = sample_unet(test_batch, model_unet,
                                               device, dataset_test)
    # The EDS sampler must receive the diffusion model, not the U-Net.
    _, _, predicted_diff = sample_model_EDS(test_batch, model_diff,
                                            device, dataset_test, num_steps=40)

# Move tensors to host memory before converting: `.numpy()` fails on CUDA tensors.
coarse = coarse.detach().cpu().numpy()
fine = fine.detach().cpu().numpy()
predicted_unet = predicted_unet.detach().cpu().numpy()
predicted_diff = predicted_diff.detach().cpu().numpy()


In [None]:
# Sanity-check the array shapes returned by both samplers
print(*(arr.shape for arr in (coarse, fine, predicted_unet, predicted_diff)))

## Plot
Plot all three variables for this single timestep: coarse input, ground truth, and both model predictions.

In [None]:
# Per-variable plot settings. Entry k of every list below belongs to
# varnames[k], and each variable gets its own subplot column.
varnames = ["VAR_2T", "VAR_10U", "VAR_10V"]
cmaps = ["rainbow", "BrBG_r", "BrBG_r"]

# Colour-scale limits
vmin = [250, -10, -10]
vmax = [300, 10, 10]
vmax_stds = [3, 1, 1]  # not used by the plot cell below; kept for other uses

# Human-readable titles and colourbar units
plot_varnames = ["Temperature", "Zonal wind", "Meridional wind"]
plot_var_labels = ["K", "m/s", "m/s"]

plt.rcParams["font.size"] = 18

In [None]:
# Map grid: one row per data source, one column per variable.
# Each row shares the same styling; only its label and data differ.
rows = [
    ("Coarse", coarse),
    ("Truth", fine),
    ("U-Net", predicted_unet),
    ("Diffusion", predicted_diff),
]

fig, axs = plt.subplots(len(rows), len(varnames), figsize=(16, 10.2),
                        subplot_kw={'projection': ccrs.PlateCarree()},
                        gridspec_kw={'wspace': 0.1,
                                     'hspace': 0.1})
last_row = len(rows) - 1
for i, _ in enumerate(varnames):
    for r, (row_label, field) in enumerate(rows):
        ax = axs[r, i]
        ax.coastlines()
        ax.add_feature(cartopy.feature.LAKES, edgecolor='black', facecolor='none')
        # Select sample `t` and variable `i` so pcolormesh receives a 2-D
        # field. `field[:, i]` would keep the leading batch axis.
        pcm = ax.pcolormesh(lon, lat, field[t, i],
                            vmin=vmin[i], vmax=vmax[i],
                            shading='nearest',
                            cmap=cmaps[i])
        if r == 0:
            # Variable name only on the top row
            ax.set_title(plot_varnames[i])
        if i == 0:
            # Row label written vertically just left of the first column
            ax.text(lon[0]-2, lat[len(lat) // 2], row_label, transform=ccrs.PlateCarree(),
                    rotation='vertical', ha='right', va='center', zorder=10)
        if r == last_row:
            # One horizontal colourbar per column, placed under the bottom map.
            # All rows in a column share vmin/vmax, so one colourbar serves them all.
            cax = ax.inset_axes([0., -0.25, 1, 0.1])
            plt.colorbar(pcm, cax=cax, orientation="horizontal", label=plot_var_labels[i])

plt.show()