# **CVAE_training**

In [None]:
import os, json

import papermill as pm
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import tensorflow as tf
import netCDF4
import cartopy

from tensorflow import keras
from keras import layers
from sklearn.model_selection import train_test_split 

# Report the TensorFlow version and whether TensorFlow can see any GPU device.
print("TF version:", tf.__version__)
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

# Download and Convert Data
On my [first Google hit for GEFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-ensemble-forecast), I clicked on [AWS Open Data Registry for GEFS](https://registry.opendata.aws/noaa-gefs-pds/) and selected [NOAA GEFS Re-forecast](https://registry.opendata.aws/noaa-gefs-reforecast/), which has no usage restrictions.  The [GEFS Re-forecast data documentation](https://noaa-gefs-retrospective.s3.amazonaws.com/Description_of_reforecast_data.pdf) is very clear and we're going to download two files, 57 MB each.  The date of the initialization of the re-forecast is in the file name in the format YYYYMMDDHH.  The c00, p01, p02, p03, p04 are the control and perturbation ensemble members (5 total).

In [None]:
# Directory that get_data() downloads the raw GRIB2 files into.
data_prefix = "./gefs_data"
# Directory that load_data() reads the converted netCDF (.nc) files from.
data_dir = "./gefs_data/converted/" # change to match your own directory

# data download
def get_data(year, month, day, ensembles):
    num_files = 0
    
    for ensemble in ensembles:
        if f'pres_msl_{year}{month}{day}00_{ensemble}.grib2' not in os.listdir(data_prefix):
            !wget -q -P {data_prefix} https://noaa-gefs-retrospective.s3.amazonaws.com/GEFSv12/reforecast/{year}/{year}{month}{day}00/{ensemble}/Days%3A1-10/pres_msl_{year}{month}{day}00_{ensemble}.grib2
        
        if f'pres_msl_{year}{month}{day}00_{ensemble}.grib2' in os.listdir(data_prefix):
            num_files += 1
            
    return num_files

In [None]:
# delete all files
def remove_data():
    """Delete every regular file under ``data_prefix``; directories are kept."""
    # -type f limits -delete to regular files, so the directory tree itself
    # stays in place.
    !find {data_prefix} -type f -delete

# Neural Network Design

We need to get to a small latent space. Conv2D networks are good because they help reduce the number of connections in a network in a meaningful way.  I'm using terms as defined in [this definition of conv2D](https://towardsdatascience.com/conv2d-to-finally-understand-what-happens-in-the-forward-pass-1bbaafb0b148).

**Definitions:**
K -> kernel size;
P -> padding;
S -> stride;
D -> Dilation;
G -> Groups

**Filter options:**
Longitude is easy because it is large and even, so as long as you have an even stride, you get integer results when dividing.
e.g. lon 9: stride 4, lat 7: stride 5

- Latitude - whole numbers occur for P = 2 & K = 3 or K = 11.
- 11 grid points * 0.25 deg * 100 km/deg = 275 km filter window (a good scale for weather)
- 9 grid points * 0.25 deg * 100 km/deg = 225 km
- Longitude - whole numbers occur for P = 0 & K = 11 (nice match with Latitude), P = 1 & K = 3 or 13, P = 2 & K = 5.

For a 5 x 7 filter with 3 stride (no overlap) and no padding:
- lat: (721 - 4) / 3 = 239 possible steps (good whole number!)
- lon: (1440 - 4) / 3 = 478.6666 possible steps

## Load and Preprocess Training Data:
The standard way of manipulating arrays in Conv2D layers in TF is to use arrays in the shape:
`batch_size,  height, width, channels = data.shape`
In our case, the `batch_size` is the number of image frames (i.e. separate samples or rows in a `.csv` file), the `height` and `width` define the size of the image frame in number of pixels, and the `channels` are the number of layers in the frames.  Typically, channels are color layers (e.g. RGB or CMYK) but in our case, we could use different meteorological variables.  However, for this first experiment, **we only need one channel** because we're only going to use mean sea level pressure (msl).

## Build the Encoder:
GFS grids I have available here are at 0.25 degree resolution.  I'm doing this as a "worst case" scenario since there are also 0.5 and 1.0 degree grids with lower resolution but I can't find that data quickly and don't know what's available.

These 0.25 degree grids are 721 x 1440.
Each forecast file is 3 hourly for 10 days = 8 steps/forecast * 10 days = 80 "frames"
This demo is only using two forecasts from the control ensemble
(one launched Jan 01, 2019 and one launched Jan 02, 2019) -> this is only 
a small subset of the variability possible in the model.

This particular data set spans 2000-2019 and there are 5 ensemble members.

## Build the Decoder:
With the 11 x 11 and 5 x 9 filters applied here (strides of 9 x 10 and 5 x 9, respectively), the encoder's final "image" size is 15 x 15 with 64 channels, which is the shape the decoder starts from.

In [None]:
class Sampling(layers.Layer):
    """Draw a latent vector z from the Gaussian described by (z_mean, z_log_var)."""

    def call(self, inputs):
        """Return ``z_mean + exp(0.5 * z_log_var) * epsilon``.

        ``inputs`` is the pair ``(z_mean, z_log_var)``; ``epsilon`` is random
        normal noise with the same (batch, dim) shape as ``z_mean``.
        """
        mean, log_var = inputs
        # Build the noise shape from the runtime shape of the mean tensor.
        noise_shape = (tf.shape(mean)[0], tf.shape(mean)[1])
        noise = tf.keras.backend.random_normal(shape=noise_shape)
        # exp(0.5 * log-variance) is the standard deviation.
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise

def build_encoder(latent_dim):
    """Build the convolutional encoder for 721 x 1440 single-channel grids.

    Args:
        latent_dim: Size of the latent vector.

    Returns:
        A Keras model named "encoder" that maps a (721, 1440, 1) input to
        ``[z_mean, z_log_var, z]``.
    """
    encoder_inputs = keras.Input(shape=(721, 1440, 1))

    # 11 x 11 kernel, stride (9, 10), no padding: 721 x 1440 -> 79 x 143.
    x = layers.Conv2D(32, 11, activation = "relu", strides = [9, 10], padding = "valid")(encoder_inputs)
    # 5 x 9 kernel with an equal stride (non-overlapping): 79 x 143 -> 15 x 15.
    x = layers.Conv2D(64, [5,9], activation = "relu", strides = [5, 9], padding = "valid")(x)
    x = layers.Flatten()(x)
    x = layers.Dense(16, activation="relu")(x)

    # Two parallel heads parameterize the latent Gaussian; Sampling draws z.
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])

    encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name = "encoder")

    # summary() prints the table itself and returns None, so wrapping it in
    # print() only added a stray "None" line to the output.
    encoder.summary()
    return encoder

def build_decoder(latent_dim):
    """Build the decoder that maps a latent vector back to a 721 x 1440 x 1 grid.

    Args:
        latent_dim: Size of the latent vector fed to the decoder.

    Returns:
        A Keras model named "decoder" whose output uses a sigmoid activation.
    """
    latent_inputs = keras.Input(shape=(latent_dim,))
    # Expand the latent vector into a 15 x 15 x 64 feature map.
    x = layers.Dense(15 * 15 * 64, activation="relu")(latent_inputs)
    x = layers.Reshape((15, 15, 64))(x)
    # Each transposed convolution mirrors a strided "valid" Conv2D whose output
    # size was rounded down. The rows/columns dropped by that rounding are
    #   (79 - 5) % 5 = 4 and (143 - 9) % 9 = 8      for the 5 x 9 layer,
    #   (721 - 11) % 9 = 8 and (1440 - 11) % 10 = 9 for the 11 x 11 layer,
    # and output_padding restores exactly those amounts
    # (15 x 15 -> 79 x 143 -> 721 x 1440). That is the pattern the earlier
    # FIXME noticed. Recompute these values if the grid size, kernels, or
    # strides change.
    x = layers.Conv2DTranspose(64, [5, 9], activation = "relu", strides = [5,9], padding = "valid", output_padding = [4, 8])(x)
    x = layers.Conv2DTranspose(32, 11, activation = "relu", strides = [9,10], padding = "valid", output_padding = [8, 9])(x)
    # Final 3 x 3 layer collapses to one channel without changing the size.
    decoder_outputs = layers.Conv2DTranspose(1, 3, activation = "sigmoid", padding = "same")(x)
    decoder = keras.Model(latent_inputs, decoder_outputs, name = "decoder")

    # summary() prints the table itself and returns None, so wrapping it in
    # print() only added a stray "None" line to the output.
    decoder.summary()
    return decoder

class VAE(keras.Model):
    """Variational autoencoder that trains an encoder/decoder pair jointly.

    The total loss is a binary cross-entropy reconstruction term plus the KL
    term computed from (z_mean, z_log_var); both are divided by ``n_features``.

    Args:
        encoder: Model returning ``(z_mean, z_log_var, z)`` for a batch.
        decoder: Model mapping ``z`` back to the input shape.
        n_features: Divisor applied to both loss terms. The default of
            ``28 * 28`` is the value that was previously hard-coded.
            NOTE(review): it does not match the 721 x 1440 input grid, so
            pass the real feature count once the loss scale is revisited.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, *, n_features=28 * 28, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.n_features = n_features
        self.total_loss_tracker = keras.metrics.Mean(name = "total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name = "reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name = "kl_loss")

    @property
    def metrics(self):
        """Trackers exposed to Keras as this model's metrics."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @staticmethod
    def _inputs_only(data):
        """Return the input batch, dropping any targets/weights packed with it."""
        # Taking element 0 works for tuples of any length; the previous
        # ``data, _ = data`` only worked for exactly two elements.
        if isinstance(data, tuple):
            return data[0]
        return data

    def _vae_losses(self, data):
        """Return ``(total, reconstruction, kl)`` losses for one input batch."""
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        # Cross-entropy is summed over the two spatial axes, then averaged
        # over the batch.
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis = (1, 2)
            )
        ) / self.n_features
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis = 1)) / self.n_features
        total_loss = reconstruction_loss + kl_loss
        return total_loss, reconstruction_loss, kl_loss

    def _record_losses(self, total_loss, reconstruction_loss, kl_loss):
        """Update the three trackers and return their current results by name."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        """Run one optimization step on ``data`` and return the tracked losses."""
        data = self._inputs_only(data)

        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._vae_losses(data)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return self._record_losses(total_loss, reconstruction_loss, kl_loss)

    # Needed to validate (validation loss) and to evaluate
    def test_step(self, data):
        """Compute the losses on ``data`` without updating any weights."""
        data = self._inputs_only(data)
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(data)

        return self._record_losses(total_loss, reconstruction_loss, kl_loss)

In [None]:
def train_model(X_train, X_test, X_valid, date, vae):
    """Fit ``vae``, export it, and write the training history and test losses.

    Every output path is built from the module-level ``model_dir``.

    Args:
        X_train: Samples passed to ``vae.fit``.
        X_test: Held-out samples passed to ``vae.evaluate``.
        X_valid: Samples used as validation data during fitting.
        date: String embedded in the history and test-loss file names.
        vae: Compiled model to train.
    """
    export_dir = model_dir + '/model'

    # If a previously saved model is found, load it in place of ``vae``.
    # NOTE(review): tf.saved_model.save usually writes `saved_model.pb`, so
    # this exact path may never exist. Also, an object returned by
    # tf.saved_model.load may not provide `.fit`. Confirm before relying on
    # resumed training (a weights-file alternative was sketched here earlier).
    if os.path.exists(model_dir + '/model/saved_model'):
        vae = tf.saved_model.load(export_dir)

    # Stop early when the validation loss stops improving (patience of 5
    # epochs) and restore the best weights seen.
    stopper = keras.callbacks.EarlyStopping(patience = 5, restore_best_weights = True)

    fit_options = dict(
        epochs = 50,
        batch_size = 40,
        callbacks = [stopper],
        validation_data = (X_valid,),
    )
    history = vae.fit(X_train, **fit_options)

    # Export the trained model.
    tf.saved_model.save(vae, export_dir)

    # One CSV of per-epoch metrics per training date.
    history_path = os.path.join(model_dir, f'history_{date}.csv')
    pd.DataFrame(history.history).to_csv(history_path, index = False)

    # evaluate() results are paired positionally with these names.
    metric_names = ["loss", "reconstruction_loss", "kl_loss"]
    test_loss = dict(zip(metric_names, vae.evaluate(X_test)))

    print('Test loss:', test_loss)

    loss_path = os.path.join(model_dir, f'test_loss_{date}.json')
    with open(loss_path, 'w') as json_file:
        json.dump(test_loss, json_file, indent = 4)

# used MNIST data preproc as template for this definition
def load_data():
    """Load the 'msl' variable from every converted netCDF file in ``data_dir``.

    Returns:
        A float32 array: the per-file 'msl' arrays concatenated along the
        first axis, with a trailing axis of length 1 appended and all values
        divided by 110000.

    Raises:
        ValueError: If ``data_dir`` contains no file whose name includes '.nc'.
    """
    # Sort so the frame order does not depend on os.listdir()'s arbitrary order.
    nc_files = sorted(f for f in os.listdir(data_dir) if '.nc' in f)
    if not nc_files:
        raise ValueError(f"no .nc files found in {data_dir!r}")

    fields = []
    for nc_file in nc_files:
        # Slicing with [:] copies the variable into memory, so the dataset
        # can be closed right away instead of leaving every file open.
        with netCDF4.Dataset(os.path.join(data_dir, nc_file)) as dataset:
            fields.append(dataset['msl'][:])

    # NOTE(review): 110000 is presumably an upper bound on MSL pressure in Pa,
    # scaling inputs into [0, 1] to match the sigmoid decoder output; confirm.
    return np.expand_dims(np.concatenate(fields), -1).astype("float32") / 110000

def run_train(num_files, date, vae):
    """Load the converted data, split it, train ``vae``, then delete the data files.

    Args:
        num_files: Number of forecast files; 80 frames are expected per file.
        date: String passed through to ``train_model`` for output file names.
        vae: Compiled model passed through to ``train_model``.
    """
    slp = load_data() # load data
    print("shape:", np.shape(slp)) # verify data shape

    # NOTE(review): the slice end is exclusive, so the "- 1" leaves out the
    # last of the expected num_files * 80 frames. Confirm that is intended.
    n_frames = num_files * 80 - 1
    frames = slp[0:n_frames, :, :, :]
    # Index array stands in for labels; it is only needed to drive the split.
    frame_ids = np.arange(0, n_frames)

    # First split: 80 % fit / 20 % test.
    X_fit, X_test, ids_fit, _ = train_test_split(
        frames, frame_ids, test_size = 0.2, random_state = 1
    )
    # Second split: 0.25 of the remaining 80 % becomes validation (0.25 x 0.8 = 0.2).
    X_train, X_valid, _, _ = train_test_split(
        X_fit, ids_fit, test_size = 0.25, random_state = 1
    )

    train_model(X_train, X_test, X_valid, date, vae)
    remove_data()

# Train the VAE model

In [None]:
latent_dim = 2
model_dir = './model_dir'

os.makedirs(model_dir, exist_ok = True)
os.makedirs(model_dir + '/model', exist_ok = True)


def _print_gpu_memory(stage):
    """Print GPU:0 memory usage after ``stage``, or a notice when there is no GPU."""
    label = f"Memory usage after {stage}:"
    # get_memory_info raises ValueError when the named device does not exist,
    # which would abort the notebook on CPU-only machines.
    if tf.config.list_physical_devices('GPU'):
        print(label, tf.config.experimental.get_memory_info('GPU:0'))
    else:
        print(label, "n/a (no GPU available)")


# build encoder
encoder = build_encoder(latent_dim)
_print_gpu_memory("building encoder")

# build decoder
decoder = build_decoder(latent_dim)
_print_gpu_memory("building decoder")

# build VAE (variational autoencoder)
vae = VAE(encoder, decoder)
vae.compile(optimizer = 'rmsprop')
# vae.compile(optimizer = keras.optimizers.Adam())
_print_gpu_memory("building VAE")

In [None]:
# parameter cell for pm 
# Initialization date of the re-forecast to fetch, kept as strings because the
# values are inserted verbatim into the download URL and file names.
year = "2018"
month = "01"
day = "01"
# Ensemble members to download (control run plus four perturbations).
ensembles = ["c00", "p01", "p02", "p03", "p04"]

In [None]:
# run training
num_files = get_data(year, month, day, ensembles)
!csh batch_grib2nc.csh

date = year + month + day
run_train(num_files, date, vae)
    
print("Memory usage after training:", tf.config.experimental.get_memory_info('GPU:0'))