In [6]:
import sys
sys.path.insert(1,"/home/ankitesh/miniconda3/envs/CbrainCustomLayer/lib/python3.6/site-packages") #work around for h5py
from cbrain.imports import *
from cbrain.cam_constants import *
from cbrain.utils import *
from cbrain.layers import *
from cbrain.data_generator import DataGenerator
import tensorflow as tf
from tensorflow import math as tfm
# import tensorflow_probability as tfp
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import xarray as xr
import numpy as np
from cbrain.model_diagnostics import ModelDiagnostics
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.image as imag
import scipy.integrate as sin
# import cartopy.crs as ccrs
import matplotlib.ticker as mticker
# from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import pickle

Description - this notbook is a try to replace the norm.nc file by adding batch norm between the layers.  
AIM: Generate and compare three models(vanilla and Input_RH)  
1 - Without any input scaling(least priority, just for get a deeper insights on the learning)  
2 - With norm.nc file(This is alread trained $SCRATCH/models)  
3 - With adding batch norm layer between the networks  

In [4]:
np.random.seed(0)

### Constants 

In [4]:
TRAINDIR = '/oasis/scratch/comet/ankitesh/temp_project/PrepData/'
path = '/home/ankitesh/CBrain_project/CBRAIN-CAM/cbrain/'
path_hyam = 'hyam_hybm.pkl'
BATCH_SIZE = 1024


In [5]:
hf = open(path+path_hyam,'rb')
hyam,hybm = pickle.load(hf)


### BatchNorm Experiment

Note: We are initially going to use the in-built batch norm layer that scales using the mean and the std.    

### Modification such that no input normalization is performed.

changes in  
1 cbrain/normalization.py  
2 data_generator.py  
Added normalize_flag(defaults to True) - pass this flag as False to the Data Generator object to turn off rescaling

### Data Generators

In [7]:
TRAINFILE = 'CI_SP_M4K_train_shuffle.nc'
NORMFILE = 'CI_SP_M4K_NORM_norm.nc'
VALIDFILE = 'CI_SP_M4K_valid.nc'
import xarray as xr
ds = xr.open_dataset(TRAINDIR+TRAINFILE)

In [8]:
scale_dict = load_pickle('/home/ankitesh/CBrain_project/CBRAIN-CAM/nn_config/scale_dicts/009_Wm2_scaling.pkl')
in_vars = ['QBP','TBP','PS', 'SOLIN', 'SHFLX', 'LHFLX']
out_vars = ['PHQ','TPHYSTND','FSNT', 'FSNS', 'FLNT', 'FLNS']

In [9]:
train_gen = DataGenerator(
    data_fn = TRAINDIR+TRAINFILE,
    input_vars = in_vars,
    output_vars = out_vars,
    norm_fn = TRAINDIR+NORMFILE,
    input_transform = ('mean', 'maxrs'),
    output_transform = scale_dict,
    batch_size=1024,
    shuffle=True,
    normalize_flag=False
)

In [10]:
valid_gen = DataGenerator(
    data_fn = TRAINDIR+VALIDFILE,
    input_vars = in_vars,
    output_vars = out_vars,
    norm_fn = TRAINDIR+NORMFILE,
    input_transform = ('mean', 'maxrs'),
    output_transform = scale_dict,
    batch_size=1024,
    shuffle=True,
    normalize_flag=False
)

### Generators using RH Inputs

In [11]:
scale_dict_RH = load_pickle('/home/ankitesh/CBrain_project/CBRAIN-CAM/nn_config/scale_dicts/009_Wm2_scaling.pkl')

In [12]:
scale_dict_RH['RH'] = 0.01*L_S/G, # Arbitrary 0.1 factor as specific humidity is generally below 2%

In [13]:
in_vars_RH = ['RH','TBP','PS', 'SOLIN', 'SHFLX', 'LHFLX']
out_vars_RH = ['PHQ','TPHYSTND','FSNT', 'FSNS', 'FLNT', 'FLNS']

In [14]:
TRAINFILE_RH = 'CI_RH_M4K_NORM_train_shuffle.nc'
NORMFILE_RH = 'CI_RH_M4K_NORM_norm.nc'
VALIDFILE_RH = 'CI_RH_M4K_NORM_valid.nc'

In [15]:
train_gen_RH = DataGenerator(
    data_fn = TRAINDIR+TRAINFILE_RH,
    input_vars = in_vars_RH,
    output_vars = out_vars_RH,
    norm_fn = TRAINDIR+NORMFILE_RH,
    input_transform = ('mean', 'maxrs'),
    output_transform = scale_dict_RH,
    batch_size=1024,
    shuffle=True,
    normalize_flag=False
)

### Neural Networks

### Vanilla Network

### Without any scaling

In [16]:
inp = Input(shape=(64,))
densout = Dense(128, activation='linear')(inp)
densout = LeakyReLU(alpha=0.3)(densout)
for i in range (6):
    densout = Dense(128, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)
out = Dense(64, activation='linear')(densout)
Brute_force_NO_BN = tf.keras.models.Model(inp, out)

In [17]:
Brute_force_NO_BN.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 64)]              0         
_________________________________________________________________
dense (Dense)                (None, 128)               8320      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 128)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)               16512     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 128)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               16512     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 128)               0     

### With Batch Norm

In [18]:
inp = Input(shape=(64,))
batch_norm_1 = BatchNormalization()(inp)
densout = Dense(128, activation='linear')(batch_norm_1)
densout = LeakyReLU(alpha=0.3)(densout)
for i in range (3):
    batch_norm_i = BatchNormalization()(densout)
    densout = Dense(128, activation='linear')(batch_norm_i)
    densout = LeakyReLU(alpha=0.3)(densout)
batch_norm_out = BatchNormalization()(densout)
out = Dense(64, activation='linear')(batch_norm_out)
Brute_force = tf.keras.models.Model(inp, out)

In [19]:
Brute_force.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 64)]              0         
_________________________________________________________________
batch_normalization (BatchNo (None, 64)                256       
_________________________________________________________________
dense_8 (Dense)              (None, 128)               8320      
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 128)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 128)               512       
_________________________________________________________________
dense_9 (Dense)              (None, 128)               16512     
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 128)               0   

#### Traning the model

In [17]:
path_HDF5 = '/oasis/scratch/comet/ankitesh/temp_project/models/'
earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
mcp_save_BF_BN = ModelCheckpoint(path_HDF5+'CI01_BF_BN.hdf5',save_best_only=True, monitor='val_loss', mode='min')
mcp_save_BF_NO_BN = ModelCheckpoint(path_HDF5+'CI01_BF_NO_BN.hdf5',save_best_only=True, monitor='val_loss', mode='min')


In [18]:
Brute_force.compile(tf.keras.optimizers.Adam(), loss=mse)
# Brute_force_NO_BN.compile(tf.keras.optimizers.Adam(), loss=mse)

In [None]:
Nep = 10-2
Brute_force.load_weights(path_HDF5+'CI01_BF_BN.hdf5')
Brute_force.fit_generator(train_gen, epochs=Nep, validation_data=valid_gen,\
              callbacks=[earlyStopping, mcp_save_BF_BN])

Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
 3197/41376 [=>............................] - ETA: 29:40 - loss: 356.6758

In [None]:
Nep = 6
Brute_force_NO_BN.fit_generator(train_gen, epochs=Nep, validation_data=valid_gen,\
              callbacks=[earlyStopping, mcp_save_BF_NO_BN])

Epoch 1/6
Epoch 2/6

### Q-RH Network

In [20]:
from climate_invariant import *

In [21]:
inp = Input(shape=(64,))
batch_norm_1 = BatchNormalization()(inp)
inpRH = QV2RH(inp_subQ=train_gen.input_transform.sub, 
              inp_divQ=train_gen.input_transform.div, 
              inp_subRH=train_gen_RH.input_transform.sub, 
              inp_divRH=train_gen_RH.input_transform.div, 
              hyam=hyam, hybm=hybm)(batch_norm_1)
batch_norm_2 = BatchNormalization()(inpRH)
densout = Dense(128, activation='linear')(inpRH)
densout = LeakyReLU(alpha=0.3)(densout)
for i in range (3):
    batch_norm_i = BatchNormalization()(densout)
    densout = Dense(128, activation='linear')(batch_norm_i)
    densout = LeakyReLU(alpha=0.3)(densout)
batch_norm_out = BatchNormalization()(densout)
out = Dense(64, activation='linear')(batch_norm_out)
Input_RH = tf.keras.models.Model(inp, out)



In [22]:
Input_RH.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 64)]              0         
_________________________________________________________________
batch_normalization_5 (Batch (None, 64)                256       
_________________________________________________________________
q_v2rh (QV2RH)               (None, 64)                0         
_________________________________________________________________
dense_13 (Dense)             (None, 128)               8320      
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 128)               0         
_________________________________________________________________
batch_normalization_7 (Batch (None, 128)               512       
_________________________________________________________________
dense_14 (Dense)             (None, 128)               1651

In [23]:
path_HDF5 = '/oasis/scratch/comet/ankitesh/temp_project/models/'
earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
mcp_save_RH = ModelCheckpoint(path_HDF5+'CI01_RH_BN.hdf5',save_best_only=True, monitor='val_loss', mode='min')

In [24]:
Input_RH.compile(tf.keras.optimizers.Adam(), loss=mse)

In [None]:
Nep = 10
Input_RH.fit_generator(train_gen, epochs=Nep, validation_data=valid_gen,\
              callbacks=[earlyStopping, mcp_save_RH])

Epoch 1/10

### Build model diagnostics objects

In [7]:
from climate_invariant import *

In [8]:
path_HDF5 = '/oasis/scratch/comet/ankitesh/temp_project/models/' # Path to NN weights
config_file = 'CI_SP_M4K_CONFIG.yml' # Configuration file
data_file = ['CI_SP_M4K_valid.nc','CI_SP_P4K_valid.nc'] # Validation/test data sets
#BFS
NNarray = ['CI01_BF.hdf5','CI01_RH.hdf5','CI01_BF_BN.hdf5','CI01_RH_BN.hdf5'] # NN to evaluate
NNNorm = [True,True,False,False]
NNname = ['BF', 'RH', 'BF BN', 'RH BN'] # Name of NNs for plotting
dict_lay = {'SurRadLayer':SurRadLayer,'MassConsLayer':MassConsLayer,'EntConsLayer':EntConsLayer, 
            'T2TmTNS':T2TmTNS,'QV2RH':QV2RH,'eliq':eliq,'eice':eice,'esat':esat,'qv':qv,'RH':RH}

In [9]:
# Indices of different variables
PHQ_idx = slice(0,30)
TPHYSTND_idx = slice(30,60)

In [10]:
NN = {}; md = {};
# %cd $TRAINDIR/HDF5_DATA
for i,NNs in enumerate(NNarray):
    print('NN name is ',NNs)
    path = path_HDF5+NNs
    NN[NNs] = load_model(path,custom_objects=dict_lay)
    md[NNs] = {}
    for j,data in enumerate(data_file):
        print('data name is ',data)
        md[NNs][data[6:-3]] = ModelDiagnostics(NN[NNs],
                                                '/home/ankitesh/CBrain_project/PrepData/'+config_file,
                                                TRAINDIR+data,normalize_flag=NNNorm[i])

NN name is  CI01_BF.hdf5
data name is  CI_SP_M4K_valid.nc


  config = yaml.load(f)


data name is  CI_SP_P4K_valid.nc


  config = yaml.load(f)


NN name is  CI01_RH.hdf5
data name is  CI_SP_M4K_valid.nc
data name is  CI_SP_P4K_valid.nc
NN name is  CI01_BF_BN.hdf5


  config = yaml.load(f)
  config = yaml.load(f)


data name is  CI_SP_M4K_valid.nc
data name is  CI_SP_P4K_valid.nc
NN name is  CI01_RH_BN.hdf5


  config = yaml.load(f)
  config = yaml.load(f)


data name is  CI_SP_M4K_valid.nc
data name is  CI_SP_P4K_valid.nc


  config = yaml.load(f)
  config = yaml.load(f)


In [11]:
lat_ind = np.arange(26,40)
iini = 1000 # Initial time to sample
iend = iini+47 # One day later

diagno = {} # Diagnostics structure
diagno['truth'] = {} # Diagnostics structure for the truth
for i,NNs in enumerate(NNarray):
    print('i=',i,'& NNs=',NNs,'         ')
    diagno[NNs] = {} # Diagnostics structure for each NN
    for j,data in enumerate(data_file):
        diagno[NNs][data[6:-3]] = {} # Diagnostics structure for each data file
        if i==0: diagno['truth'][data[6:-3]] = {}
        print('j=',j,'& data=',data,'         ',end='\r')
        for itime in tqdm(np.arange(iini,iend)):
            # Get input, prediction and truth from NN
            inp, p, truth = md[NNs][data[6:-3]].get_inp_pred_truth(itime)  # [lat, lon, var, lev]
            # Get convective heating and moistening for each NN
            p = p.numpy()
 
            
            if itime==iini:
                if i==0:
                    
                    diagno['truth'][data[6:-3]]['PHQ'] = md[NNs][data[6:-3]].reshape_ngeo(truth[:,PHQ_idx])[lat_ind,:,:,np.newaxis]
                    diagno['truth'][data[6:-3]]['TPHYSTND'] = md[NNs][data[6:-3]].reshape_ngeo(truth[:,TPHYSTND_idx])[lat_ind,:,:,np.newaxis]
                diagno[NNs][data[6:-3]]['PHQ'] = md[NNs][data[6:-3]].reshape_ngeo(p[:,PHQ_idx])[lat_ind,:,:,np.newaxis]
                diagno[NNs][data[6:-3]]['TPHYSTND'] = md[NNs][data[6:-3]].reshape_ngeo(p[:,TPHYSTND_idx])[lat_ind,:,:,np.newaxis]
            else:
                for istr,field in enumerate(['PHQ','TPHYSTND']):
                    if field=='PHQ': ind_field = PHQ_idx
                    elif field=='TPHYSTND': ind_field = TPHYSTND_idx
                    diagno[NNs][data[6:-3]][field] = np.concatenate((diagno[NNs][data[6:-3]][field],
                                                         md[NNs][data[6:-3]].\
                                                         reshape_ngeo(p[:,ind_field])[lat_ind,:,:,np.newaxis]),
                                                        axis=3)
                    if i==0:
                        diagno['truth'][data[6:-3]][field] = np.concatenate((diagno['truth'][data[6:-3]][field],
                                                                 md[NNs][data[6:-3]].\
                                                                 reshape_ngeo(truth[:,ind_field])[lat_ind,:,:,np.newaxis]),
                                                                axis=3)


i= 0 & NNs= CI01_BF.hdf5          
j= 0 & data= CI_SP_M4K_valid.nc          

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))


j= 1 & data= CI_SP_P4K_valid.nc          

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))


i= 1 & NNs= CI01_RH.hdf5          
j= 0 & data= CI_SP_M4K_valid.nc          

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))


j= 1 & data= CI_SP_P4K_valid.nc          

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))


i= 2 & NNs= CI01_BF_BN.hdf5          
j= 0 & data= CI_SP_M4K_valid.nc          

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))


j= 1 & data= CI_SP_P4K_valid.nc          

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))


i= 3 & NNs= CI01_RH_BN.hdf5          
j= 0 & data= CI_SP_M4K_valid.nc          

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))


j= 1 & data= CI_SP_P4K_valid.nc          

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))




In [12]:
# Load coordinates (just pick any file from the climate model run)
coor = xr.open_dataset("/oasis/scratch/comet/ankitesh/temp_project/data/sp8fbp_minus4k.cam2.h1.0000-01-01-00000.nc",\
                    decode_times=False)
lat = coor.lat; lon = coor.lon; lev = coor.lev;
coor.close();

In [13]:
coor.lat[26:40] # A tropical latitude range that we can use for testing


In [14]:
coor.lat[13:20] # A mid-latitude range that we can use for testing


In [15]:
coor.lat[0:7] # A polar range that we can use for testing


In [16]:
# Plot characteristics
fz = 20
lw = 4
plt.rc('text', usetex=True)
plt.rc('font',size=fz)
#plt.rc('font',**{'family':'serif','serif':['Computer Modern Roman']}, size=fz)
mpl.rcParams['lines.linewidth'] = lw
plt.close('all')

In [17]:
diagno['truth'].keys()


dict_keys(['M4K_valid', 'P4K_valid'])

Validation on the (-4K) dataset

In [19]:
from matplotlib import rc
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
## for Palatino and other serif fonts use:
rc('font',**{'family':'serif','serif':['Palatino']})
rc('text', usetex=True)
os.environ['PATH'] = os.environ['PATH'] + ':/home1/apps/texlive/2017/bin/x86_64-linux/latex'
%matplotlib inline

In [20]:
data = 'M4K_valid'

plt.figure(figsize=(10,5))

plt.subplot(1,2,1)
plt.axvline(x=0,c='lightgray')
for iNN,NNs in enumerate(NNarray[:]):
    plt.plot(np.mean(diagno[NNs][data]['PHQ'],axis=(0,1,3)),lev,label=NNname[iNN])
plt.plot(np.mean(diagno['truth'][data]['PHQ'],axis=(0,1,3)),lev,label='Truth',color='k')
plt.legend()
plt.xlabel('Convective moistening ($\mathrm{W\ m^{-2}}$)')
plt.ylabel('Pressure (hPa)')
plt.gca().invert_yaxis()
#plt.xlim((-20,20))

plt.subplot(1,2,2)
plt.axvline(x=0,c='lightgray')
for iNN,NNs in enumerate(NNarray[:]):
    plt.plot(np.mean(diagno[NNs][data]['TPHYSTND'],axis=(0,1,3)),lev,label=NNname[iNN])
plt.plot(np.mean(diagno['truth'][data]['TPHYSTND'],axis=(0,1,3)),lev,label='Truth',color='k')
plt.legend()
plt.xlabel('Convective heating ($\mathrm{W\ m^{-2}}$)')
plt.gca().invert_yaxis()
plt.xlim((-9,9))


(-9, 9)

RuntimeError: latex was not able to process the following string:
b'lp'

Here is the full report generated by latex:
This is pdfTeX, Version 3.1415926-2.5-1.40.14 (TeX Live 2013)
 restricted \write18 enabled.
entering extended mode

(/home/ankitesh/.cache/matplotlib/tex.cache/72237feb031b451a2857f7daf88d3d03.te
x
LaTeX2e <2011/06/27>
Babel <v3.8m> and hyphenation patterns for english, dumylang, nohyphenation, lo
aded.
(/usr/share/texlive/texmf-dist/tex/latex/base/article.cls
Document Class: article 2007/10/19 v1.4h Standard LaTeX document class
(/usr/share/texlive/texmf-dist/tex/latex/base/size10.clo))

! LaTeX Error: File `type1cm.sty' not found.

Type X to quit or <RETURN> to proceed,
or enter new name. (Default extension: sty)

Enter file name: 
! Emergency stop.
<read *> 
         
l.4 \usepackage
               {mathpazo}^^M
No pages of output.
Transcript written on 72237feb031b451a2857f7daf88d3d03.log.




<Figure size 720x360 with 2 Axes>

### Cutom BatchNorm

In [11]:
from tensorflow.keras import initializers
from tensorflow.keras import layers

In [12]:
class CustomBatchNormalization(layers.Layer):
    def __init__(self, momentum=0.99, epsilon=1e-3,beta_initializer='zeros',
                 gamma_initializer='ones', moving_mean_initializer='zeros',
                 moving_range_initializer='ones',**kwargs):
        self.momentum = momentum
        self.epsilon = epsilon
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_range_initializer = (
            initializers.get(moving_range_initializer))
        
        super().__init__(**kwargs)
    
    def build(self,input_shape):
        dim = input_shape[-1]
        shape = (dim,)
        self.gamma = self.add_weight(shape=shape,
                             name='gamma',
                             initializer=self.gamma_initializer,trainable=True)
        self.beta = self.add_weight(shape=shape,
                            name='beta',
                            initializer=self.beta_initializer,
                                   trainable=True)
        
        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)
        
        self.moving_range = self.add_weight(
            shape=shape,
            name='moving_range',
            initializer=self.moving_range_initializer,
            trainable=False)

    def call(self, inputs, training=None):
        input_shape = inputs.shape
        
        if not training:
            scaled = (inputs-self.moving_mean)/(self.moving_range+self.epsilon)
            return self.gamma*scaled + self.beta
        
        mean = tf.math.reduce_mean(inputs,axis=0)
        maxr = tf.math.reduce_max(inputs,axis=0)
        minr = tf.math.reduce_min(inputs,axis=0)
        
        range_diff = tf.math.subtract(maxr,minr)
        self.moving_mean = tf.math.add(self.momentum*self.moving_mean, (1-self.momentum)*mean)
        self.moving_range = tf.math.add(self.momentum*self.moving_range,(1-self.momentum)*range_diff)
        scaled = tf.math.divide(tf.math.subtract(inputs,mean),(range_diff+self.epsilon))
        return tf.math.add(tf.math.multiply(self.gamma,scaled),self.beta)
    
    def get_config(self):
        config = {
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer':
                initializers.serialize(self.moving_mean_initializer),
            'moving_range_initializer':
                initializers.serialize(self.moving_range_initializer)
        }
        base_config = super(CustomBatchNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    
    def compute_output_shape(self, input_shape):
        return input_shape

In [13]:
inp = Input(shape=(64,))
batch_norm_1 = CustomBatchNormalization(dynamic=True)(inp)
densout = Dense(128, activation='linear')(batch_norm_1)
densout = LeakyReLU(alpha=0.3)(densout)
for i in range (6):
    batch_norm_i = CustomBatchNormalization(dynamic=True)(densout)
    densout = Dense(128, activation='linear')(batch_norm_i)
    densout = LeakyReLU(alpha=0.3)(densout)
batch_norm_out = CustomBatchNormalization(dynamic=True)(densout)
out = Dense(64, activation='linear')(batch_norm_out)
Inp_RH_CBN = tf.keras.models.Model(inp, out)

In [14]:
path_HDF5 = '/oasis/scratch/comet/ankitesh/temp_project/models/'
earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
mcp_save_RH = ModelCheckpoint(path_HDF5+'CI01_RH_CBN.hdf5',save_best_only=True, monitor='val_loss', mode='min')

In [15]:
Inp_RH_CBN.compile(tf.keras.optimizers.Adam(), loss=mse,experimental_run_tf_function=False)

In [None]:
Nep = 10
Inp_RH_CBN.fit_generator(train_gen, epochs=Nep, validation_data=valid_gen,\
              callbacks=[earlyStopping, mcp_save_RH])

Epoch 1/10
 8879/41376 [=====>........................] - ETA: 54:14 - loss: 929.7906