# Demo using a trained CVAE model
The goal here is to use a trained CVAE model with new data to create synthetic ensemble members.

# Libraries

In [None]:
import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import os, json
import pickle5
import numpy as np
import pandas as pd
from scipy.ndimage.filters import gaussian_filter as gf

from tensorflow import keras
from tensorflow.keras import layers
import cProfile   # For eager execution, https://www.tensorflow.org/guide/eager
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

import cartopy

from cvae import Sampling, build_encoder, calculate_final_shape, calculate_output_paddings
from cvae import build_decoder, VAE, plot_latent_space, plot_images

In [None]:
from scripts.get_data import download_file
from scripts.get_data import convert_file
from scripts.get_data import subset_file
from scripts.get_data import remove_data # removes all data

# data loading
def load_data(data_dir):
    """Load the ``msl`` variable from every file in ``data_dir`` and scale it.

    Every directory entry is opened as a netCDF dataset. The per-file
    ``msl`` arrays are concatenated along their first axis in sorted
    file-name order (``os.listdir`` alone gives no ordering guarantee), a
    trailing axis of length 1 is appended, and the values are min-max
    scaled as ``(msl - 85000) / (110000 - 85000)``.

    Args:
        data_dir: Directory whose entries are all treated as netCDF files.

    Returns:
        A ``float16`` array of scaled ``msl`` values with one extra
        trailing axis.

    Raises:
        ValueError: If ``data_dir`` contains no entries.
    """
    file_names = sorted(os.listdir(data_dir))
    if not file_names:
        raise ValueError("no files to load in {!r}".format(data_dir))

    # Imported at function scope because nothing at the top of this
    # notebook imports netCDF4, which left the name undefined here.
    import netCDF4

    # Bounds of the min-max scaling applied to the raw msl values.
    msl_min, msl_max = 85000, 110000

    frames = []
    for file_name in file_names:
        # os.path.join works whether or not data_dir ends with a separator;
        # the context manager closes each dataset once its data is read.
        with netCDF4.Dataset(os.path.join(data_dir, file_name)) as dataset:
            frames.append(dataset['msl'][:])

    stacked = np.expand_dims(np.concatenate(frames), -1).astype("float32")
    return ((stacked - msl_min) / (msl_max - msl_min)).astype("float16")

# Load and preprocess the input data
The standard way of manipulating arrays in Conv2D layers in TF is to use arrays in the shape:
`batch_size,  height, width, channels = data.shape`
In our case, the `batch_size` is the number of image frames (i.e. separate samples or rows in a `.csv` file), the `height` and `width` define the size of the image frame in number of pixels, and the `channels` are the number of layers in the frames.  Typically, channels are color layers (e.g. RGB or CMYK) but in our case, we could use different meteorological variables.  However, for this first experiment, **we only need one channel** because we're only going to use sea level pressure (SLP).

The code for loading GEFS `.grib` files and making an initial plot is from [Victor Gensini's example](https://github.com/vgensini/gefs_v12_example/blob/master/GEFS_v12_eample.ipynb) posted on the GEFS Open Data Registry landing page.

In [None]:
# Example parameters: year, month, day and ensemble identifier (all strings)
ex_year, ex_month, ex_day, ex_ensemble = "2018", "01", "01", "c00"

In [None]:
# Example: download one file, convert it, then load everything in data_dir.
# NOTE(review): data_pdir and data_dir are not defined in the visible part of
# this notebook -- confirm they are assigned before this cell runs.
example_id = (ex_year, ex_month, ex_day, ex_ensemble)
download_file(*example_id, data_pdir)
convert_file(*example_id, data_dir)
slp = load_data(data_dir)

In [None]:
# Inspect the dimensions of the loaded array
print(slp.shape)

In [None]:
# Grid point locations, read from plain-text coordinate files
lon_path, lat_path = 'coordinates/lon.x', 'coordinates/lat.y'
lons = np.loadtxt(lon_path)
lats = np.loadtxt(lat_path)

# 2-D coordinate grids for contouring ('xy' is meshgrid's default indexing)
x, y = np.meshgrid(lons, lats, indexing='xy')

x.shape

In [None]:
# Example plot: contour the first loaded frame on a Lambert conformal map.
fig = plt.figure(figsize=(9,6))
ax = plt.axes(projection = cartopy.crs.LambertConformal())
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.LAKES, alpha = 0.5)
ax.add_feature(cartopy.feature.STATES, edgecolor='grey')

# Bounds used by load_data's scaling, (msl - 85000) / (110000 - 85000).
msl_min, msl_max = 85000, 110000

# load_data returns float16, which tops out near 65504: widen to float32
# before undoing the scaling, otherwise the product overflows to inf.
first_frame = np.squeeze(slp[0,:,:,0]).astype('float32')
# Invert the min-max scaling, then divide by 100 so the values line up with
# the contour levels below (presumably Pa -> hPa).
first_frame_hpa = (first_frame * (msl_max - msl_min) + msl_min) / 100

plt.contour(x, y, first_frame_hpa,
            transform = cartopy.crs.PlateCarree(),
            levels=list(range(970, 1065, 5)),  # 970, 975, ..., 1060
            colors='k')
# NOTE(review): the title says 2017 while the example parameters request
# 2018 01 01 -- confirm which date slp[0] actually corresponds to.
plt.title('GEFSv12 MSL 2017 01 01 0000 UTC')
# Alternative, tighter extent: [-120, -73, 23, 50]
ax.set_extent([-150, -60, 20, 65])
plt.show()

# Load ML model

In [None]:
# Key CVAE definition parameters
latent_dim, n_conv_layers = 3, 4
stride, kernel_size = 2, 3
# Frame geometry is taken from the loaded data
batch_size, height, width, channels = slp.shape

# Encoder and decoder are built from the same architecture arguments
arch_args = (latent_dim, height, width, channels, n_conv_layers, kernel_size, stride)
encoder = build_encoder(*arch_args, base_filters = 16)
decoder = build_decoder(*arch_args, base_filters = 16)

# Assemble the model, compile it, and restore the stored weights
vae = VAE(encoder, decoder, height * width)
vae.compile(optimizer='rmsprop')
weights_path = os.path.join('model_dir', 'vae.weights.h5')
vae.load_weights(weights_path)

# Encode, perturb, and decode

In [None]:
# Run the float32 copy of the data through the encoder, then decode z
encoded = vae.encoder(slp.astype('float32'))
z_mean, z_log_var, z = encoded
sample_output_images = vae.decoder(z)

In [None]:
print(tf.executing_eagerly())

# Plot setup
fig = plt.figure(figsize=(9,6))
ax = plt.axes(projection = cartopy.crs.LambertConformal())
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.LAKES, alpha = 0.5)
ax.add_feature(cartopy.feature.STATES, edgecolor='grey')

# Bounds used by load_data's scaling, (msl - 85000) / (110000 - 85000);
# inverting it (and dividing by 100) puts the data on the same scale as the
# contour levels (presumably hPa).
msl_min, msl_max = 85000, 110000
contour_levels = [980,990,1000,1010,1015,1020,1025,1030,1035,1040,1045,1050]

# Plot the decoded version of the first frame only (red); slicing avoids
# walking the whole batch just to use its first element.
for image in sample_output_images[:1]:
    print('Filtering...')
    decoded_hpa = (np.asarray(np.squeeze(image), dtype='float32') * (msl_max - msl_min) + msl_min) / 100
    filtered = gf(decoded_hpa, [3,3], mode='constant')
    print(np.mean(filtered))
    print(np.std(filtered))
    print('Contour plotting...')
    plt.contour(x, y, filtered,
                transform = cartopy.crs.PlateCarree(),
                levels=contour_levels, colors='r', linewidths=1)

# Plot original (black)
print('Filtering...')
# slp is float16 (see load_data): widen to float32 before rescaling so the
# values do not overflow float16's ~65504 maximum on the way into the filter.
original_hpa = (np.squeeze(slp[0,:,:,0]).astype('float32') * (msl_max - msl_min) + msl_min) / 100
filtered = gf(original_hpa, [3,3], mode='constant')
print('Contour plotting...')
plt.contour(x, y, filtered,
            transform = cartopy.crs.PlateCarree(),
            levels=contour_levels, colors='k', linewidths=2)

# Alternative, tighter extent: [-120, -73, 23, 50]
ax.set_extent([-150, -60, 20, 65])
plt.show()

# Perturb current state

In [None]:
# Perturb the latent mean by +/- one standard deviation. z_log_var holds
# the log-variance (going by its name -- confirm against cvae.Sampling), so it
# is converted with exp(0.5 * log_var) rather than added to the mean directly.
z_std = tf.exp(0.5 * z_log_var)
perturbed_images_high = vae.decoder(z_mean + z_std)
perturbed_images_low = vae.decoder(z_mean - z_std)

print(tf.executing_eagerly())

# Plot setup
fig = plt.figure(figsize=(9,6))
ax = plt.axes(projection = cartopy.crs.LambertConformal())
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.LAKES, alpha = 0.5)
ax.add_feature(cartopy.feature.STATES, edgecolor='grey')

# Bounds used by load_data's scaling, (msl - 85000) / (110000 - 85000);
# inverting it (and dividing by 100) puts the data on the same scale as the
# contour levels (presumably hPa).
msl_min, msl_max = 85000, 110000
contour_levels = [980,990,1000,1010,1015,1020,1025,1030,1035,1040,1045,1050]

# Plot the first frame of each perturbed batch: high in red, low in blue
for perturbed_batch, colour in ((perturbed_images_high, 'r'), (perturbed_images_low, 'b')):
    for image in perturbed_batch[:1]:
        print('Filtering...')
        perturbed_hpa = (np.asarray(np.squeeze(image), dtype='float32') * (msl_max - msl_min) + msl_min) / 100
        filtered = gf(perturbed_hpa, [3,3], mode='constant')
        print('Contour plotting...')
        plt.contour(x, y, filtered,
                    transform = cartopy.crs.PlateCarree(),
                    levels=contour_levels, colors=colour, linewidths=1)

# Plot original (black, unfiltered). slp is float16 (see load_data): widen
# to float32 before rescaling so the values do not overflow to inf.
original_hpa = (np.squeeze(slp[0,:,:,0]).astype('float32') * (msl_max - msl_min) + msl_min) / 100
plt.contour(x, y, original_hpa,
            transform = cartopy.crs.PlateCarree(),
            levels=contour_levels, colors='k', linewidths=2)

# NOTE(review): the title mentions 990hPa and 2018 01 10, while the example
# parameters request 2018 01 01 -- confirm the intended wording.
plt.title('GEFSv12 Re-forecast SLP 990hPa 2018 01 10 0000 UTC Cycle')
# Alternative, tighter extent: [-120, -73, 23, 50]
ax.set_extent([-150, -60, 20, 65])
plt.show()

# Generate totally random weather maps

In [None]:
# Draw 12 random latent vectors from a standard normal and decode them
codings = tf.random.normal(shape = [12, latent_dim])
images = vae.decoder(codings).numpy()

# Plot setup
fig = plt.figure(figsize=(9,6))
ax = plt.axes(projection = cartopy.crs.LambertConformal())
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.LAKES, alpha = 0.5)
ax.add_feature(cartopy.feature.STATES, edgecolor='grey')

# Bounds used by load_data's scaling, (msl - 85000) / (110000 - 85000);
# inverting it (and dividing by 100) puts the decoded values on the same
# scale as the contour levels (presumably hPa).
msl_min, msl_max = 85000, 110000
contour_levels = [970,980,990,1000,1010,1015,1020,1025,1030,1040,1050,1060]

# Plot each one
for image in images:
    print('Filtering...')
    decoded_hpa = (np.squeeze(image) * (msl_max - msl_min) + msl_min) / 100
    filtered = gf(decoded_hpa, [5,5], mode='constant')
    plt.contour(x, y, filtered,
                transform = cartopy.crs.PlateCarree(),
                levels=contour_levels, colors='k', linewidths=1)

# Alternative, tighter extent: [-120, -73, 23, 50]
ax.set_extent([-150, -60, 20, 65])
plt.show()

In [None]:
# Show the latent-space dimensionality set in the model-definition cell
print(latent_dim)