## Imports

In [1]:
import tensorflow as tf
from glob import glob
from os.path import join
from tensorflow.keras import layers, models
import xarray as xr
import pandas as pd
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
import imports.GDL_layers as GDL_layers

from time import time

from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import StandardScaler

## Load data

In [2]:
# Directories that hold the netCDF patch files and the track-step CSV tables.
patch_path = "/glade/scratch/lverhoef/WRF_all/track_data_hrrr_3km_nc_refl/"
csv_path = "/glade/scratch/lverhoef/WRF_all/track_data_hrrr_3km_csv_refl/"

# Glob patterns for the two kinds of data file.
patch_pattern = join(patch_path, "*.nc")
csv_pattern = join(csv_path, "track_step_*.csv")

# Sort by filename so both listings come back in a repeatable order.
patch_files = sorted(glob(patch_pattern))
csv_files = sorted(glob(csv_pattern))

In [3]:
# Open all patch files as one dataset, concatenating the files along the 'p' dimension.
ds = xr.open_mfdataset(
    patch_files,
    combine='nested',
    concat_dim='p',
    combine_attrs='drop_conflicts',
)

# Label every sample with the part of its timestamp string that precedes the "T"
# (the calendar date); the label is used as the grouping key when splitting the data.
date_labels = [str(stamp).split("T")[0] for stamp in ds['time'].values]
ds['date'] = ('p', date_labels)

In [4]:
# Fixed seed so the train/test partition is the same on every run.
flat_seed = 1000

# Single grouped shuffle split: samples sharing a 'date' label always land on the
# same side, so no date contributes to both the training and the test data.
gsp = GroupShuffleSplit(train_size=.9, n_splits=1, random_state=flat_seed)
splits = list(gsp.split(ds["REFL_COM_curr"], groups=ds["date"]))
train_index, test_index = splits[0]

# Positional selection along 'p' using the indices produced by the split.
input_train_full = ds.isel(p=train_index)
input_test = ds.isel(p=test_index)

In [5]:
# Columns taken from every CSV file.
csv_variables = ["major_axis_length", "minor_axis_length"]

# Convert each file's selected columns to an xarray Dataset whose row dimension is 'p'.
csv_data_list = []
for csv_file in csv_files:
    track_table = pd.read_csv(csv_file)
    selected = track_table[csv_variables].to_xarray()
    csv_data_list.append(selected.rename({'index': 'p'}))

# Join the per-file datasets along 'p', stack the variables into one array and
# reverse the dimension order with transpose().
stacked = xr.concat(csv_data_list, dim='p')
csv_ds = stacked.to_array().transpose()

In [6]:
# The targets are paired with the image patches purely by position along 'p', using
# the indices computed on `ds`. Fail loudly if the two sources differ in length,
# because a mismatch would otherwise silently attach targets to the wrong patches.
assert csv_ds.sizes['p'] == ds.sizes['p'], (
    f"CSV targets have {csv_ds.sizes['p']} rows but the patch dataset has "
    f"{ds.sizes['p']} samples; they must match one-to-one"
)
output_train_full, output_test = csv_ds.isel(p=train_index), csv_ds.isel(p=test_index)

In [7]:
class ImageScaler(object):
    """Standardize an array using one global mean and one global standard deviation.

    Unlike a per-column scaler, the statistics are computed over every element of
    the array passed to ``fit`` (``X.mean()`` / ``X.std()`` with no axis).

    Parameters
    ----------
    with_mean : bool, default True
        Subtract the fitted mean in ``transform`` (and add it back in
        ``inverse_transform``).
    with_std : bool, default True
        Divide by the fitted standard deviation in ``transform`` (and multiply
        by it in ``inverse_transform``).
    copy : bool, default True
        Default for the per-call ``copy`` argument. When True the input is left
        untouched and a new array is returned; when False a floating-point
        ndarray input is modified in place and returned. Pass ``copy=False``
        for very large arrays to avoid allocating a duplicate.

    Attributes
    ----------
    mean_, std_ :
        Statistics stored by ``fit``.
    """

    def __init__(self, with_mean=True, with_std=True, copy=True):
        self.with_mean = with_mean
        self.with_std = with_std
        self.copy = copy

    def _check_is_fitted(self):
        """Raise AttributeError if ``fit`` has not been called yet."""
        if not (hasattr(self, 'mean_') and hasattr(self, 'std_')):
            raise AttributeError('Must run .fit method first')

    def _prepare(self, X, copy):
        """Return the array to operate on, copying according to ``copy``."""
        if copy is None:
            copy = self.copy
        X = np.asarray(X)
        if not np.issubdtype(X.dtype, np.inexact):
            # In-place float arithmetic is not possible on integer/bool arrays,
            # so these are always converted (which necessarily copies).
            return X.astype(np.float64)
        return X.copy() if copy else X

    def _scale(self):
        """Fitted standard deviation, with 0 replaced by 1 to avoid dividing by zero."""
        return self.std_ if self.std_ != 0 else 1.0

    def fit(self, X, y=None):
        """Compute the global mean and standard deviation of ``X``.

        ``y`` is accepted for API compatibility and ignored. Returns ``self``
        so calls can be chained.
        """
        X = np.asarray(X)
        self.mean_ = X.mean()
        self.std_ = X.std()
        return self

    def transform(self, X, copy=None):
        """Return ``(X - mean_) / std_`` (each step subject to its flag).

        ``copy=None`` falls back to the value given to the constructor.
        Raises AttributeError if the scaler has not been fitted.
        """
        self._check_is_fitted()
        X = self._prepare(X, copy)
        if self.with_mean:
            X -= self.mean_
        if self.with_std:
            X /= self._scale()
        return X

    def fit_transform(self, X, y=None):
        """Fit on ``X`` and return the transformed ``X``."""
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, X, y=None, copy=None):
        """Undo ``transform``: return ``X * std_ + mean_`` (each step subject to its flag).

        ``y`` is ignored; ``copy=None`` falls back to the constructor value.
        Raises AttributeError if the scaler has not been fitted.
        """
        self._check_is_fitted()
        X = self._prepare(X, copy)
        if self.with_std:
            X *= self._scale()
        if self.with_mean:
            X += self.mean_
        return X

In [8]:
# Carve a validation set out of the training data, again grouping by date so that
# no date is shared between the two subsets. The split is seeded with the same
# `flat_seed` as the train/test split; without a random_state the validation set
# (and every metric derived from it) would differ on each run.
# Note: `train_index` is reassigned here and now indexes `input_train_full`, not `ds`.
gsp = GroupShuffleSplit(n_splits=1, train_size=0.885, random_state=flat_seed)
splits = list(gsp.split(input_train_full["REFL_COM_curr"], groups=input_train_full['date']))
train_index, val_index = splits[0]
input_train, input_val = input_train_full.isel(p=train_index), input_train_full.isel(p=val_index)
output_train, output_val = output_train_full.isel(p=train_index), output_train_full.isel(p=val_index)

# Fit both scalers on the training subset only, then apply the fitted statistics
# to the validation and test subsets so no information leaks from held-out data.
x_scaler, y_scaler = ImageScaler(), StandardScaler()
input_train_norm = x_scaler.fit_transform(input_train["REFL_COM_curr"].values)
input_val_norm = x_scaler.transform(input_val["REFL_COM_curr"].values)
input_test_norm = x_scaler.transform(input_test["REFL_COM_curr"].values)

output_train_norm = y_scaler.fit_transform(output_train.values)
output_val_norm = y_scaler.transform(output_val.values)
output_test_norm = y_scaler.transform(output_test.values)

## Train Models

In [11]:
# Build six networks that differ only in where the RotInvPool layer is placed.
# Every network starts with a RotEquivConv2D/RotEquivPool2D pair, followed by four
# more conv/pool pairs (one per entry of `filters`): pairs before index `inv_layer`
# use the GDL_layers rotation-equivariant layers, the rest use ordinary Keras layers.
# NOTE(review): for inv_layer == 5 neither RotInvPool condition below is ever met,
# so that network goes straight from the equivariant layers to Flatten - confirm
# this is the intended configuration.
ricnn_models = [models.Sequential() for i in range(6)]
filters = [32, 64, 64, 128]
for inv_layer, net in enumerate(ricnn_models):
    # Input stage: always rotation-equivariant.
    net.add(GDL_layers.RotEquivConv2D(filters[0], (3, 3), rot_axis=False, input_shape=(144, 144, 1)))
    net.add(GDL_layers.RotEquivPool2D((2, 2)))
    if inv_layer == 0:
        net.add(GDL_layers.RotInvPool())

    for layer, filts in enumerate(filters):
        if layer >= inv_layer:
            # At or past the chosen position: standard convolution and pooling.
            net.add(layers.Conv2D(filts, (3, 3), activation='relu'))
            net.add(layers.MaxPooling2D((2, 2)))
            continue
        net.add(GDL_layers.RotEquivConv2D(filts, (3, 3)))
        net.add(GDL_layers.RotEquivPool2D((2, 2)))
        if layer + 1 == inv_layer:
            # Last equivariant pair for this network: add the RotInvPool layer here.
            net.add(GDL_layers.RotInvPool())

    # Dense head producing the two outputs.
    net.add(layers.Flatten())
    net.add(layers.Dense(32, activation='relu'))
    net.add(layers.Dense(2))
    net.compile(optimizer='nadam', loss='mse')

In [None]:
# Train each model in turn, recording its fit history, wall-clock training time
# and test loss, then save it to disk under a directory named after its index.
ricnn_history = []
ricnn_test = []
time_history = []
for inv_layer, model in enumerate(ricnn_models):
    model.summary()

    # Time only the call to fit().
    start = time()
    fit_record = model.fit(
        x=input_train_norm,
        y=output_train_norm,
        epochs=10,
        validation_data=(input_val_norm, output_val_norm),
    )
    ricnn_history.append(fit_record)
    time_history.append(time() - start)
    print(f'Time elapsed: {time_history[-1]:.3f}')

    # Loss on the held-out test set (the models are compiled with loss only).
    test_loss = model.evaluate(x=input_test_norm, y=output_test_norm)
    ricnn_test.append(test_loss)

    model.save(f"saved_models/rot_inv_hyper/inv_layer_{inv_layer}")
    K.clear_session()

Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_equiv_conv2d_1 (RotEqui  (None, 142, 142, 4, 32)  320       
 vConv2D)                                                        
                                                                 
 rot_equiv_pool2d_1 (RotEqui  (None, 71, 71, 4, 32)    0         
 vPool2D)                                                        
                                                                 
 rot_inv_pool (RotInvPool)   (None, 71, 71, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 69, 69, 32)        9248      
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 34, 34, 32)       0         
 2D)                                                             
                                                      

2022-12-19 08:54:03.356219: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8100


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Time elapsed: 571.598




INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_0/assets


INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_0/assets


Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_equiv_conv2d_2 (RotEqui  (None, 142, 142, 4, 32)  320       
 vConv2D)                                                        
                                                                 
 rot_equiv_pool2d_2 (RotEqui  (None, 71, 71, 4, 32)    0         
 vPool2D)                                                        
                                                                 
 rot_equiv_conv2d_3 (RotEqui  (None, 69, 69, 4, 32)    9248      
 vConv2D)                                                        
                                                                 
 rot_equiv_pool2d_3 (RotEqui  (None, 34, 34, 4, 32)    0         
 vPool2D)                                                        
                                                                 
 rot_inv_pool_1 (RotInvPool)  (None, 34, 34, 32)      



INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_1/assets


INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_1/assets


Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_equiv_conv2d_4 (RotEqui  (None, 142, 142, 4, 32)  320       
 vConv2D)                                                        
                                                                 
 rot_equiv_pool2d_4 (RotEqui  (None, 71, 71, 4, 32)    0         
 vPool2D)                                                        
                                                                 
 rot_equiv_conv2d_5 (RotEqui  (None, 69, 69, 4, 32)    9248      
 vConv2D)                                                        
                                                                 
 rot_equiv_pool2d_5 (RotEqui  (None, 34, 34, 4, 32)    0         
 vPool2D)                                                        
                                                                 
 rot_equiv_conv2d_6 (RotEqui  (None, 32, 32, 4, 64)   



INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_2/assets


INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_2/assets


Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_equiv_conv2d_7 (RotEqui  (None, 142, 142, 4, 32)  320       
 vConv2D)                                                        
                                                                 
 rot_equiv_pool2d_7 (RotEqui  (None, 71, 71, 4, 32)    0         
 vPool2D)                                                        
                                                                 
 rot_equiv_conv2d_8 (RotEqui  (None, 69, 69, 4, 32)    9248      
 vConv2D)                                                        
                                                                 
 rot_equiv_pool2d_8 (RotEqui  (None, 34, 34, 4, 32)    0         
 vPool2D)                                                        
                                                                 
 rot_equiv_conv2d_9 (RotEqui  (None, 32, 32, 4, 64)   



INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_3/assets


INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_3/assets


Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_equiv_conv2d_11 (RotEqu  (None, 142, 142, 4, 32)  320       
 ivConv2D)                                                       
                                                                 
 rot_equiv_pool2d_11 (RotEqu  (None, 71, 71, 4, 32)    0         
 ivPool2D)                                                       
                                                                 
 rot_equiv_conv2d_12 (RotEqu  (None, 69, 69, 4, 32)    9248      
 ivConv2D)                                                       
                                                                 
 rot_equiv_pool2d_12 (RotEqu  (None, 34, 34, 4, 32)    0         
 ivPool2D)                                                       
                                                                 
 rot_equiv_conv2d_13 (RotEqu  (None, 32, 32, 4, 64)  

INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_4/assets


Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rot_equiv_conv2d_16 (RotEqu  (None, 142, 142, 4, 32)  320       
 ivConv2D)                                                       
                                                                 
 rot_equiv_pool2d_16 (RotEqu  (None, 71, 71, 4, 32)    0         
 ivPool2D)                                                       
                                                                 
 rot_equiv_conv2d_17 (RotEqu  (None, 69, 69, 4, 32)    9248      
 ivConv2D)                                                       
                                                                 
 rot_equiv_pool2d_17 (RotEqu  (None, 34, 34, 4, 32)    0         
 ivPool2D)                                                       
                                                                 
 rot_equiv_conv2d_18 (RotEqu  (None, 32, 32, 4, 64)  

INFO:tensorflow:Assets written to: saved_models/rot_inv_hyper/inv_layer_5/assets


In [19]:
# Show the per-model test losses, then the per-model training times in seconds.
for recorded in (ricnn_test, time_history):
    print(recorded)

[0.08970838040113449, 0.0711064487695694, 0.07478220015764236, 0.06666864454746246, 0.06033242121338844, 0.05551617220044136]
[571.5980935096741, 707.718359708786, 765.6176037788391, 791.5497725009918, 808.8878953456879, 807.6326174736023]


In [None]:
# Overlay the training (solid) and validation (dashed) loss curves of all six
# models on one axis, using one colour per model.
F, ax = plt.subplots(1, 1, figsize=(15, 10))
colors = ['orange', 'red', 'blue', 'green', 'gold', 'purple']
for i in range(6):
    curves = ricnn_history[i].history
    shade = colors[i]
    ax.plot(curves['loss'], color=shade, label = f'Layer {i} inv train')
    ax.plot(curves['val_loss'], color=shade, label = f'Layer {i} inv val', linestyle='dashed')
plt.legend()
ax.set_title("History")
# Anchor the y-axis at zero so the loss magnitudes are directly comparable.
ax.set_ylim(bottom=0)