# A summary of all models

This notebook summarizes all the neural network components of this research:

* VQ-VAE
* ViT-based bias correction model
* Latent Diffusion Model

In [4]:
import os
import sys
import time
import h5py
import numpy as np
from glob import glob

# ------------------------------------------------------- #
# Turn off warnings
import logging
import warnings

# Silence every Python-level warning for the rest of the session.
warnings.filterwarnings("ignore")
# Set before the TensorFlow import below so the variable is already in the
# environment when TensorFlow is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
logging.getLogger("tensorflow").setLevel(logging.ERROR)

# ------------------------------------------------------- #
# Turn off TensorFlow-specific warnings
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

tf.autograph.set_verbosity(0)
tf.get_logger().setLevel('ERROR')

# ------------------------------------------------------- #
# Import customized modules and settings
# The project root and its libs/ folder are placed at the front of sys.path
# so that the project-local modules imported below can be resolved.
sys.path.insert(0, '/glade/u/home/ksha/GAN_proj/')
sys.path.insert(0, '/glade/u/home/ksha/GAN_proj/libs/')

# NOTE(review): the wildcard import presumably supplies names such as
# model_dir and camp_dir that later cells use -- confirm against namelist.py.
from namelist import *
import data_utils as du
import model_utils as mu
import verif_utils as vu

## VQ-VAE

In [2]:
# ---------- VQ-VAE configuration ---------- #

# Tensor shapes
latent_dim = 4                       # number of latent feature channels
input_size = (224, 464, 1)           # size of the MRMS input
latent_size = (14, 29, latent_dim)   # size of the compressed latent features

# Network design
filter_nums = [64, 128]   # convolution kernels per down-/upsampling layer
activation = 'gelu'       # activation function
num_embeddings = 128      # number of VQ codes

# Flags handed to the encoder / decoder builders
drop_encode = True
drop_decode = True

# Weights to load into the encoder and the decoder
model_name_encoder_load = model_dir + 'models/VQ_VAE_encoder_stack1_tune0'
model_name_decoder_load = model_dir + 'models/VQ_VAE_decoder_stack1_tune0'

### Build model

In [6]:
# Encoder: build the network, then overwrite its weights with the stored ones
encoder = mu.VQ_VAE_encoder_x4(
    input_size, filter_nums, latent_dim, num_embeddings, activation, drop_encode
)
encoder.set_weights(mu.dummy_loader(model_name_encoder_load))

# Decoder: same two steps, starting from the latent tensor size
decoder = mu.VQ_VAE_decoder_x4(
    latent_size, filter_nums, activation, drop_decode
)
decoder.set_weights(mu.dummy_loader(model_name_decoder_load))

In [7]:
# Chain encoder -> decoder into one end-to-end Keras model
X = keras.Input(shape=input_size)
X_encode = encoder(X)
X_decode = decoder(X_encode)
VQ_VAE = keras.Model(inputs=X, outputs=X_decode)

### Try on the validation set

In [8]:
BATCH_dir = camp_dir + 'BATCH_CCPA_full/'

# Gather the sample files in sorted order and leave out the final 3648 of them
filenames = sorted(glob(BATCH_dir + '*.npy'))
filenames = filenames[:-3648]
L = len(filenames)

# Validation subset: every 8th file, capped at 500 samples
# (alternative cap kept for reference: 2627)
filename_valid = filenames[::8][:500]
L_valid = len(filename_valid)

# NaN-initialised container, one (224, 464, 1) slot per validation file
Y_valid = np.full((L_valid, 224, 464, 1), np.nan)

for ind, path in enumerate(filename_valid):
    Y_valid[ind, ..., 0] = np.load(path)

In [9]:
# Run the validation samples through the encoder-decoder model
Y_pred = VQ_VAE.predict(Y_valid)

# Replace negative outputs with zero
negative_mask = Y_pred < 0
Y_pred[negative_mask] = 0

# Mean absolute error between the inputs and their reconstructions
record = du.mean_absolute_error(Y_valid, Y_pred)
print('MAE: {}'.format(record))

MAE: 0.0054497166918077156


## Bias-correction ViT

In [10]:
# Tensor size of the embedded CCPA and GEFS ensemble mean
latent_size = (14, 29, 4)
# Input of the 48h model: a leading axis of length 8 in front of the latent size
input_size = (8,) + latent_size

# Patch extents, ordered (time, space, space)
patch_size = (1, 1, 1)

# Transformer settings
N_layers = 8
N_heads = 4
project_dim = 128

# Path of the stored weights; the file name encodes the settings above
name_template = 'baseline/ViT3d_0_48_depth{}_patch{}{}{}_dim{}_heads{}_tune'
model_name_load_48h = model_dir + name_template.format(
    N_layers, *patch_size, project_dim, N_heads)

### Build model

In [11]:
# Build the 48h ViT corrector from the settings defined above
model_48h = mu.ViT3d_corrector(
    input_size, patch_size, project_dim, N_layers, N_heads
)

# Overwrite the fresh weights with the stored ones
model_48h.set_weights(mu.dummy_loader(model_name_load_48h))

### Try on the validation set

In [12]:
BATCH_dir = camp_dir+'BATCH_ViT/'
filenames = sorted(glob(BATCH_dir+'*npy'))

# Take at most 500 files, then size everything from what was actually found.
# With a hard-coded count, a directory holding fewer files would leave the
# trailing rows of the np.empty arrays uninitialized, and that garbage would
# silently enter the prediction and the error metric.
filenames_valid = filenames[:500]
L_valid = len(filenames_valid)
if L_valid == 0:
    raise FileNotFoundError('No *npy batch files found in {}'.format(BATCH_dir))

# One (8,) + latent_size slot per validation file; every row is filled below
valid_GEFS = np.empty((L_valid, 8,)+latent_size)
valid_CCPA = np.empty((L_valid, 8,)+latent_size)

for i, name_ in enumerate(filenames_valid):
    # Each file stores a pickled dict inside a 0-d object array; [()] unwraps it.
    # NOTE(review): allow_pickle=True executes pickle on load -- only use it on
    # trusted files.
    temp_data = np.load(name_, allow_pickle=True)[()]
    # Keep the first 8 entries along the leading axis of each embedding
    valid_GEFS[i, ...] = temp_data['GEFS_embed'][:8, ...]
    valid_CCPA[i, ...] = temp_data['CCPA_embed'][:8, ...]

In [13]:
# Correct the GEFS embeddings with the 48h ViT (first 8 entries of axis 1)
Y_pred_48 = model_48h.predict(valid_GEFS[:, :8, ...])

# Mean absolute error against the CCPA embeddings over the same slice
record = du.mean_absolute_error(valid_CCPA[:, :8, ...], Y_pred_48)
print(record)

0.03936765624405502


## Diffusion model

In [14]:
# Tensor size of the embedded CCPA and GEFS ensemble mean
latent_size = (14, 29, 4)
# The LDM generates all 06-144 lead times at once, i.e. a leading axis of 24
input_size = (24,) + latent_size

# Model design
widths = [32, 64, 96, 128]
block_depth = 2
embedding_dims = 32

# Diffusion settings
diffusion_steps = 20
min_signal_rate = 0.02
max_signal_rate = 0.95
ema = 0.999

# Location of the previous weights
model_name_load = model_dir + 'models/LDM_3d_tune4/'

### Build model

In [15]:
# Build the latent diffusion model; input_size is supplied for each of the
# first three positional arguments.
LDM = mu.DiffusionModel(
    input_size, input_size, input_size,
    diffusion_steps, min_signal_rate, max_signal_rate,
    embedding_dims, widths, block_depth, ema,
)

# Restore the previous weights. Kept as the final expression of the cell so
# that the returned load status is displayed as the cell output.
LDM.load_weights(model_name_load)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x1526e5c7c9d0>

### Try on the validation set

In [16]:
# Lower and upper bounds used as the default rescaling range
ccpa_min = -2.0606
ccpa_max = 1.6031


def rescale(x, min_val=ccpa_min, max_val=ccpa_max):
    """Linearly map values so that [min_val, max_val] lands on [-1, 1].

    Parameters
    ----------
    x : scalar or array-like supporting arithmetic
        Values to rescale.
    min_val : float, optional
        Value mapped to -1. Defaults to ``ccpa_min`` as it was when this
        function was defined.
    max_val : float, optional
        Value mapped to +1. Defaults to ``ccpa_max`` as it was when this
        function was defined.

    Returns
    -------
    Same kind as ``x``
        The rescaled values; inputs outside the range fall outside [-1, 1].
    """
    span = max_val - min_val
    unit = (x - min_val) / span  # position within the range, 0 at min_val
    return unit * 2 - 1

In [17]:
BATCH_dir = camp_dir+'BATCH_LDM/'
filenames = sorted(glob(BATCH_dir+'*npy'))

# Every 40th file, at most 100 of them; the sample count is then taken from
# the files actually selected. With a hard-coded count, a directory holding
# fewer files would leave trailing rows of the np.empty arrays uninitialized,
# and L_valid (also passed to LDM.generate later) would not match the data.
filenames_valid = filenames[::40][:100]
L_valid = len(filenames_valid)
if L_valid == 0:
    raise FileNotFoundError('No *npy batch files found in {}'.format(BATCH_dir))

# One (input_size[0],) + latent_size slot per file; every row is filled below
valid_CCPA = np.empty((L_valid, input_size[0],)+latent_size)
valid_ViT = np.empty((L_valid, input_size[0],)+latent_size)

for i, name_ in enumerate(filenames_valid):
    # Each file stores a pickled dict inside a 0-d object array; [()] unwraps it.
    # NOTE(review): allow_pickle=True executes pickle on load -- only use it on
    # trusted files.
    temp_data = np.load(name_, allow_pickle=True)[()]
    # Keep the first input_size[0] entries along the leading axis
    valid_CCPA[i, ...] = temp_data['CCPA_embed'][:input_size[0], ...]
    valid_ViT[i, ...] = temp_data['ViT_embed'][:input_size[0], ...]

# Map the CCPA embeddings with rescale() using its default bounds
valid_CCPA = rescale(valid_CCPA)

In [18]:
# Map the ViT embeddings with rescale() using its default bounds.
# NOTE(review): this overwrites valid_ViT with its own rescaled version, so
# re-running this cell on its own rescales the data a second time; re-run the
# loading cell first.
valid_ViT = rescale(valid_ViT)

In [19]:
# Call the LDM generator with the sample count and the rescaled ViT
# embeddings, converting its result to a NumPy array
y_pred = np.array(LDM.generate(L_valid, valid_ViT))

# Mean absolute error against the rescaled CCPA embeddings
record = du.mean_absolute_error(y_pred, valid_CCPA)
print(record)

0.028370507129221484
