# Demo using a trained CVAE model
The goal here is to use a trained CVAE model with new data to create synthetic ensemble members.

# Libraries

In [None]:
# imports
import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import os, json
import netCDF4
import cartopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage.filters import gaussian_filter as gf

from tensorflow import keras
from keras import layers
import cProfile   # For eager execution, https://www.tensorflow.org/guide/eager
from sklearn.model_selection import train_test_split

from scripts.cvae import Sampling, build_encoder, calculate_final_shape, calculate_output_paddings
from scripts.cvae import build_decoder, VAE, plot_latent_space, plot_images

from scripts.get_data import download_file
from scripts.get_data import convert_file
from scripts.get_data import subset_file
from scripts.get_data import remove_data # removes all data

In [None]:
# data loading
def load_data(data_dir):
    """Load and normalize the 'msl' variable from every file in ``data_dir``.

    Every directory entry is opened as a netCDF dataset and its ``'msl'``
    variable is read in full. The per-file arrays are concatenated along the
    first axis, a trailing axis of length 1 is appended, and the values are
    rescaled with ``(x - 85000) / (110000 - 85000)``.

    Parameters
    ----------
    data_dir : str
        Directory whose entries are all netCDF files exposing an ``'msl'``
        variable. A trailing path separator is optional.

    Returns
    -------
    numpy.ndarray
        ``float16`` array holding the rescaled values, with one extra trailing
        axis of length 1.

    Raises
    ------
    ValueError
        If ``data_dir`` contains no entries.
    """
    # Bounds of the min-max rescaling (presumably Pa -- TODO confirm units).
    msl_min = 85000
    msl_max = 110000

    # os.listdir returns entries in arbitrary order; sorting makes the record
    # order (and therefore slp[0], slp[0:39:2], ...) reproducible between runs.
    files = sorted(os.listdir(data_dir))  # if ('subset' in f and 'tmp' not in f)]
    if not files:
        raise ValueError(f"no data files found in {data_dir!r}")

    fields = []
    for converted_file in files:
        # os.path.join works with or without a trailing separator on data_dir;
        # the context manager closes each dataset once its values are read.
        with netCDF4.Dataset(os.path.join(data_dir, converted_file)) as dataset:
            fields.append(dataset['msl'][:])

    stacked = np.expand_dims(np.concatenate(fields), -1).astype("float32")
    all_data = ((stacked - msl_min) / (msl_max - msl_min)).astype("float16")

    return all_data

In [None]:
# Report the TensorFlow version, then whether any GPU device is visible to it
print("TF version:", tf.__version__)
gpu_devices = tf.config.list_physical_devices('GPU')
gpu_status = "available" if gpu_devices else "NOT AVAILABLE"
print("GPU is", gpu_status)

# Load and preprocess the input data

In [None]:
# Locations used throughout the notebook: parent data directory (passed to
# download_file), converted-file directory (passed to convert_file/load_data),
# and the model directory.
data_pdir = "./gefs_data"
data_dir = data_pdir + "/converted/"
model_dir = './model_dir'

In [None]:
# example parameters: date components and ensemble label handed to
# download_file / convert_file below
ex_year, ex_month, ex_day = "2018", "01", "01"
ex_ensemble = "c00"

In [None]:
# example for getting and converting files, then loading everything found
# in the converted-data directory
example_request = (ex_year, ex_month, ex_day, ex_ensemble)
download_file(*example_request, data_pdir)
convert_file(*example_request, data_dir)
slp = load_data(data_dir)

In [None]:
# inspect the dimensions of the loaded array
print(slp.shape)

In [None]:
# grid point locations, read from plain-text coordinate files
lons = np.loadtxt('coordinates/lon.x')
lats = np.loadtxt('coordinates/lat.y')

x, y = np.meshgrid(lons, lats)

# Undo the load_data rescaling for the first record. slp is stored as float16,
# which cannot represent values above 65504; upcast first so that adding 85000
# cannot overflow to inf when NumPy keeps the array dtype for Python scalars.
points = np.squeeze(slp[0, :, :, 0]).astype("float32") * (110000 - 85000) + 85000

In [None]:
# number of contour levels
num_levels = 20

# range of the field to be contoured
data_min, data_max = np.min(points), np.max(points)

# evenly spaced contour levels spanning [data_min, data_max]
levels = np.linspace(data_min, data_max, num_levels)

In [None]:
# example for plot: contour the de-normalized first record on a Lambert conformal map
fig = plt.figure(figsize=(9, 6))
ax = plt.axes(projection=cartopy.crs.LambertConformal())
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.LAKES, alpha=0.5)
ax.add_feature(cartopy.feature.STATES, edgecolor='grey')

# x/y are longitude/latitude grids, hence the PlateCarree source transform;
# `levels` are the evenly spaced values computed from the data range above.
ax.contour(x, y, points,
           transform=cartopy.crs.PlateCarree(),
           levels=levels)

# Build the title from the example parameters so it cannot drift from the data
# request (assumes record 0 belongs to the example file -- TODO confirm when
# the data directory holds more than one file).
ax.set_title(f'GEFSv12 MSL {ex_year} {ex_month} {ex_day} 0000 UTC')
ax.set_extent([-150, -60, 20, 65])
plt.show()

# Load ML model

In [None]:
# key CVAE definition parameters
latent_dim = 2
n_conv_layers = 4
stride = 2
kernel_size = 3
# slp is 4-D; these names are unpacked for reference and are not used further
# in this cell.
batch_size, height, width, channels = slp.shape

# Rebuild the architecture, then restore the weights into it.
encoder = build_encoder(latent_dim)
decoder = build_decoder(latent_dim)
vae = VAE(encoder, decoder)
vae.compile(optimizer='rmsprop')
# Use the configured model_dir rather than a second hard-coded copy of the
# path, so changing the configuration cell actually takes effect here.
vae.load_weights(os.path.join(model_dir, 'vae.weights.h5'))

In [None]:
# Encode every second record among the first 39 (indices 0, 2, ..., 38), then
# decode z. slp already carries a trailing channel axis (added in load_data and
# unpacked as `channels` above), so it is passed as-is: appending another axis
# would hand the encoder a 5-D array.
z_mean, z_log_var, z = vae.encoder(slp[0:39:2, :, :, :])
sample_output_images = vae.decoder(z)

In [None]:
# Plot the encoded records in the first two latent coordinates, joined in
# record order.
fig, ax = plt.subplots(1, 1)
ax.plot(z[:, 0], z[:, 1], "b-")
ax.set_xlabel("latent dimension 0")
ax.set_ylabel("latent dimension 1")
ax.set_title("Encoded records in latent space")
plt.show()

# Encode, perturb, and decode

In [None]:
print(tf.executing_eagerly())

# plot setup
fig = plt.figure(figsize=(9, 6))
ax = plt.axes(projection=cartopy.crs.LambertConformal())
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.LAKES, alpha=0.5)
ax.add_feature(cartopy.feature.STATES, edgecolor='grey')

# contour levels (hPa) shared by the decoded and the original field
hpa_levels = [980, 990, 1000, 1010, 1015, 1020, 1025, 1030, 1035, 1040, 1045, 1050]

# plot each decoded state (right now only doing the first one)
for image in sample_output_images[:1]:
    print('Filtering...')
    # Invert the load_data rescaling, (x - 85000) / (110000 - 85000), and divide
    # by 100 to express the field in hPa. Assumes the decoder emits values in the
    # same normalized units as its input -- TODO confirm against the training setup.
    decoded = np.squeeze(image).astype("float32") * (110000 - 85000) + 85000
    filtered = gf(decoded / 100, [3, 3], mode='constant')
    print(np.mean(filtered))
    print(np.std(filtered))
    print('Contour plotting...')
    ax.contour(x, y, filtered,
               transform=cartopy.crs.PlateCarree(),
               levels=hpa_levels,
               colors='r',
               linewidths=1)

# plot original
print('Filtering...')
# slp is float16 (max 65504): upcast before restoring the physical range so the
# offset cannot overflow to inf.
original = np.squeeze(slp[0, :, :, 0]).astype("float32") * (110000 - 85000) + 85000
filtered = gf(original / 100, [3, 3], mode='constant')
print('Contour plotting...')
ax.contour(x, y, filtered,
           transform=cartopy.crs.PlateCarree(),
           levels=hpa_levels,
           colors='k',
           linewidths=2)

ax.set_extent([-150, -60, 20, 65])
plt.show()

# Perturb current state

In [None]:
# Shift the latent mean by one standard deviation in each direction. The encoder's
# second output is taken to be the log-variance (as its name says -- confirm
# against scripts.cvae), so the standard deviation is exp(0.5 * log_var); adding
# the log-variance itself would not be a shift in latent units.
z_std = tf.exp(0.5 * z_log_var)
perturbed_images_high = vae.decoder(z_mean + z_std)
perturbed_images_low = vae.decoder(z_mean - z_std)

print(tf.executing_eagerly())

# Plot setup
fig = plt.figure(figsize=(9, 6))
ax = plt.axes(projection=cartopy.crs.LambertConformal())
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.LAKES, alpha=0.5)
ax.add_feature(cartopy.feature.STATES, edgecolor='grey')

# contour levels (hPa) shared by all three fields
hpa_levels = [980, 990, 1000, 1010, 1015, 1020, 1025, 1030, 1035, 1040, 1045, 1050]

# Plot the first perturbed state of each sign: red = mean + std, blue = mean - std
for batch, colour in ((perturbed_images_high, 'r'), (perturbed_images_low, 'b')):
    for image in batch[:1]:
        print('Filtering...')
        # Invert the load_data rescaling and convert to hPa. Assumes the decoder
        # output uses the same normalized units as its input -- TODO confirm.
        decoded = np.squeeze(image).astype("float32") * (110000 - 85000) + 85000
        filtered = gf(decoded / 100, [3, 3], mode='constant')
        print('Contour plotting...')
        ax.contour(x, y, filtered,
                   transform=cartopy.crs.PlateCarree(),
                   levels=hpa_levels, colors=colour, linewidths=1)

# Plot original (unfiltered). slp is float16 (max 65504), so upcast before
# restoring the physical range.
original = np.squeeze(slp[0, :, :, 0]).astype("float32") * (110000 - 85000) + 85000
ax.contour(x, y, original / 100,
           transform=cartopy.crs.PlateCarree(),
           levels=hpa_levels, colors='k', linewidths=2)

# NOTE(review): the date in this title differs from the example date configured
# earlier in the notebook -- confirm which one is intended.
plt.title('GEFSv12 Re-forecast SLP 990hPa 2018 01 10 0000 UTC Cycle')
ax.set_extent([-150, -60, 20, 65])
plt.show()

# Generate random weather maps by sampling the latent space

In [None]:
# Draw 12 latent vectors from a standard normal and decode each into a field.
# No seed is set, so every run produces different maps.
codings = tf.random.normal(shape=[12, latent_dim])
images = vae.decoder(codings).numpy()

# Plot setup
fig = plt.figure(figsize=(9, 6))
ax = plt.axes(projection=cartopy.crs.LambertConformal())
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.LAKES, alpha=0.5)
ax.add_feature(cartopy.feature.STATES, edgecolor='grey')

# contour levels (hPa)
hpa_levels = [970, 980, 990, 1000, 1010, 1015, 1020, 1025, 1030, 1040, 1050, 1060]

# Plot each one
for image in images:
    print('Filtering...')
    # Invert the load_data rescaling, (x - 85000) / (110000 - 85000), and divide
    # by 100 for hPa. Assumes the decoder output uses those normalized units --
    # TODO confirm against the training setup.
    decoded = np.squeeze(image).astype("float32") * (110000 - 85000) + 85000
    filtered = gf(decoded / 100, [5, 5], mode='constant')
    ax.contour(x, y, filtered,
               transform=cartopy.crs.PlateCarree(),
               levels=hpa_levels, colors='k', linewidths=1)

ax.set_extent([-150, -60, 20, 65])
plt.show()

In [None]:
# Echo the latent dimensionality that was passed to build_encoder / build_decoder
print(latent_dim)