# **CVAE Weather Ensemble Model**

This notebook provides examples for working with weather data along with a section to launch online learning.

# Libraries and Setup

In [None]:
import os, json

import papermill as pm
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import tensorflow as tf
import netCDF4
import cartopy

from tensorflow import keras
from keras import layers
from sklearn.model_selection import train_test_split 

# Report the TensorFlow version in use and whether TensorFlow can see a GPU.
print("TF version:", tf.__version__)
gpu_status = "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE"
print("GPU is", gpu_status)

In [None]:
# make needed directories
# os.makedirs creates any missing parent directories (the `mkdir -p`
# behaviour) and, with exist_ok=True, does nothing when the directory already
# exists, so the cell can be re-run safely. Being pure Python it also works
# where no POSIX `mkdir` is available.
for directory in ("gefs_data/converted", "model_dir"):
    os.makedirs(directory, exist_ok=True)

In [None]:
# DVC initialization and storage set up
# `--subdir` lets DVC be initialized in this directory even if it is not the
# root of the git repository.
# NOTE(review): both commands presumably report an error when this cell is
# re-run on an already-configured project — confirm that is acceptable.
!dvc init --subdir
# Register `dvcstorage` as the default (-d) remote.
# NOTE(review): `/aws-dvc-bucket` is an absolute local path rather than an
# s3:// URL — confirm this is the intended storage location.
!dvc remote add -d dvcstorage /aws-dvc-bucket

In [None]:
# initial commit to git
# NOTE(review): `git add .` stages everything under the current directory;
# verify that the ignore rules exclude anything that should not be committed
# (e.g. downloaded data) before running.
!git add .
!git commit -m "loaded dependencies, mkdir -p, DVC init"

# Download and Convert Data
On my [first Google hit for GEFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-ensemble-forecast), I clicked on [AWS Open Data Registry for GEFS](https://registry.opendata.aws/noaa-gefs-pds/) and selected [NOAA GEFS Re-forecast](https://registry.opendata.aws/noaa-gefs-reforecast/), which has no usage restrictions. The [GEFS Re-forecast data documentation](https://noaa-gefs-retrospective.s3.amazonaws.com/Description_of_reforecast_data.pdf) is very clear. The date of the initialization of the re-forecast is in the file name in the format YYYYMMDDHH. The c00, p01, p02, p03, and p04 labels are the control and perturbation ensemble members (5 total).

In [None]:
from scripts.get_data import download_file
from scripts.get_data import convert_file
from scripts.get_data import remove_data

In [None]:
# Directory layout used throughout the notebook.
data_pdir = "./gefs_data"             # parent data directory
data_dir = data_pdir + "/converted/"  # trailing slash kept: file paths are built by string concatenation
model_dir = "./model_dir"

## Examples

In [None]:
# example parameters
# Initialization date of the re-forecast (year, month, day) ...
ex_year, ex_month, ex_day = "2018", "01", "01"
# ... and the ensemble member: c00 is the control, p01-p04 are perturbations.
ex_ensemble = "c00"

In [None]:
# example for getting and converting files
# Step 1: fetch the file for the example date/member into the parent data directory.
download_file(
    ex_year,
    ex_month,
    ex_day,
    ex_ensemble,
    data_pdir,
)
# Step 2: convert it; the result is presumably written to `data_dir`
# (the loading example reads a .nc file from there).
convert_file(
    ex_year,
    ex_month,
    ex_day,
    ex_ensemble,
    data_dir,
)

In [None]:
# example for loading data
# File name pattern: pres_msl_<YYYYMMDD>00_<member>.nc inside the converted-data directory.
example_file = f"pres_msl_{ex_year}{ex_month}{ex_day}00_{ex_ensemble}.nc"
dataset = netCDF4.Dataset(os.path.join(data_dir, example_file))

In [None]:
# example for simple data access
print(dataset)                   # look at data structure
print(dataset.variables.keys())  # names of the variables in the file

# Print the description of every variable in turn.
for variable in dataset.variables.values():
    print(variable)
    # print(variable[:]) # prints actual data

In [None]:
# example for plot
# Contour the 'msl' field of the example file on a Lambert Conformal map.
fig = plt.figure(figsize=(9, 6))
ax = plt.axes(projection=cartopy.crs.LambertConformal())

# Background map features.
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.LAKES, alpha=0.5)
ax.add_feature(cartopy.feature.STATES, edgecolor='grey')
ax.set_extent([-120, -73, 23, 50])  # (x0, x1, y0, y1)

# The coordinates are longitudes/latitudes, hence the PlateCarree transform.
# Index 0 selects the first entry along the leading axis of 'msl'
# (presumably the first forecast time — TODO confirm).
contours = ax.contour(
    dataset.variables['lon'][:],       # longitudes
    dataset.variables['lat'][:],       # latitudes
    dataset.variables['msl'][0, :, :], # actual data
    transform=cartopy.crs.PlateCarree())  #, levels=np.arange(30000,110000,20000))

# Build the title from the example parameters so it always matches the file
# that was actually loaded (it was previously hard-coded to a different date).
ax.set_title(f'GEFSv12 SLP {ex_year} {ex_month} {ex_day} 0000 UTC Cycle')
fig.colorbar(contours, ax=ax)
plt.show()

# Execute Online Training

*TODO: add a note about how the data is randomized, sorted, etc.*

In [None]:
# clean up example data before executing online session
# If the example cells were run, the netCDF file opened there is still held
# open; close it first so the clean-up does not run against an open file
# handle. The lookup through globals() keeps this cell runnable even when the
# example cells were skipped.
_example_dataset = globals().get("dataset")
if _example_dataset is not None and _example_dataset.isopen():
    _example_dataset.close()
remove_data(data_pdir)

In [None]:
# parameters for pm -> change the lists to fit your needs
# Years to process, as 4-digit strings; widen the range to cover more years
# (e.g. range(2000, 2021) for "2000" ... "2020").
years = [str(year) for year in range(2000, 2004)]
# Days of the month to process, as zero-padded strings ("01" ... "31").
days = ["01", "10", "20"]

In [None]:
# Execute the training notebook once for every (year, day) combination.
# NOTE(review): every run writes to the same output path, so only the most
# recently executed notebook is kept — confirm that is intended.
for run_year in years:
    for run_day in days:
        run_parameters = {"year": run_year, "day": run_day}
        pm.execute_notebook(
            'cvae_training.ipynb',
            'cvae_log.ipynb',
            parameters=run_parameters,
            kernel_name='cvae_env',  # this should be changed to whatever you chose <NAME> to be during environment set up
        )

In [None]:
# Design notes: choosing kernel size (K) and stride (S) for the two
# convolution layers. The tables below list the chosen K/S per dimension and
# the resulting output size H_out.
#
# Try different filter sizes
# Aim for large initial filter > 200km scale (about 10)
# Aim for some, but minimal overlap in initial filter.

# Aim for smallish second filter, but still try to reduce dimensionality
# to make dense network tractable later. No overlap (but no good
# reason why this is).

# ----------------------Input: 721 x 1440--------------

# For Lat = 721,
# K = 11 -> K_radius = 5.0 -> S = 9 -> H_out = 79.0

# For Lon = 1440,
# K = 11 -> K_radius = 5.0 -> S = 10 -> H_out = 143.0

# -----------------------Layer1: 79 x 143----------------

# For Lat = 79,
# K = 5 -> K_radius = 2.0 -> S = 5 -> H_out = 15.0

# For Lon = 143,
# K = 9 -> K_radius = 4.0 -> S = 9 -> H_out = 15.0

# The triple-quoted string below is a disabled helper script: for a given
# input size H_in and padding P it prints, for every kernel size / stride
# pair, H_out = (H_in + 2*P - (K - 1)) / S in the same format as the tables
# above. Because it is a bare string it is not executed; as the last
# expression of the cell it is only echoed as the cell output.
'''
H_in = 1440
P = 0
K_list = [3, 5, 7, 9, 11, 13, 15]                    # Kernel size
S_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] # Stride

for K in K_list:
    for S in S_list:
        K_radius = np.floor(np.divide(K, 2))   # Half width of number of points around the central point
        K_diameter = K - 1                     # Number of points around the central point, ASSUMES K = ODD
        # S = K                                # S = K is stride necessary to have non-overlapping filters
        print('K = ' + str(K) + ' -> K_radius = ' + str(K_radius) + ' -> S = ' + str(S) + ' -> H_out = ' + str((H_in + (2 * P) - K_diameter) / S))
    print('')
'''

# Display a grid of sampled digits

In [None]:
# Disabled code: every line in this cell is commented out, so running it does nothing.
# NOTE(review): it refers to `vae`, `latent_dim`, `plot_latent_space` and
# `plot_images`, none of which appear to be defined in this notebook —
# confirm where they come from before re-enabling.

# if latent_dim == 2:
#     plot_latent_space(vae, path = os.path.join(model_dir, 'latent_space.png'))

# # Generating new images
# codings = tf.random.normal(shape = [12, latent_dim])
# images = vae.decoder(codings).numpy()
# plot_images(images, 3, 4, path = os.path.join(model_dir, 'generated.png'))

# # Semantic interpolation
# codings_grid = tf.reshape(codings, [1, 3, 4, latent_dim])
# larger_grid = tf.image.resize(codings_grid, size = [5, 7])
# interpolated_codings = tf.reshape(larger_grid, [-1, latent_dim])
# images = vae.decoder(interpolated_codings).numpy()
# plot_images(images, 5, 7, path = os.path.join(model_dir, 'interpolated.png'))

In [None]:
# plot_latent_space(vae)

# Display how the latent space clusters different digit classes

In [None]:
# Disabled code: every line in this cell is commented out, so running it does nothing.
# NOTE(review): it loads the MNIST digits via keras.datasets.mnist and expects
# a `vae` object, so it looks like a leftover from an MNIST VAE example rather
# than part of the weather workflow — confirm whether it should be adapted.

# def plot_label_clusters(vae, data, labels):
#     # display a 2D plot of the digit classes in the latent space
#     z_mean, _, _ = vae.encoder.predict(data)
#     plt.figure(figsize=(12, 10))
#     plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
#     plt.colorbar()
#     plt.xlabel("z[0]")
#     plt.ylabel("z[1]")
#     plt.show()


# (x_train, y_train), _ = keras.datasets.mnist.load_data()
# x_train = np.expand_dims(x_train, -1).astype("float32") / 255

# plot_label_clusters(vae, x_train, y_train)