In [1]:
from climsim_utils.data_utils import *

In [2]:
# Paths (relative to this notebook) to the low-res grid description and to the
# directory holding the normalisation statistics.
# NOTE(review): `xr`, `np` and `data_utils` are presumably provided by the star
# import from climsim_utils.data_utils above — confirm.
grid_path = '../grid_info/ClimSim_low-res_grid-info.nc'
norm_path = '../preprocessing/normalizations/'

# Grid info plus per-variable normalisation statistics, all cast to float32.
grid_info = xr.open_dataset(grid_path)
input_mean = xr.open_dataset(norm_path + 'inputs/input_mean.nc').astype(np.float32)
input_max = xr.open_dataset(norm_path + 'inputs/input_max.nc').astype(np.float32)
input_min = xr.open_dataset(norm_path + 'inputs/input_min.nc').astype(np.float32)
output_scale = xr.open_dataset(norm_path + 'outputs/output_scale.nc').astype(np.float32)

# Helper object bundling the grid and the normalisation datasets; later cells
# read data.normalize, data.input_* and data.grid_info from it.
ml_backend = 'pytorch'
input_abbrev = 'mlexpand'
output_abbrev = 'mlo'
data = data_utils(grid_info = grid_info, 
                  input_mean = input_mean, 
                  input_max = input_max, 
                  input_min = input_min, 
                  output_scale = output_scale,
                  ml_backend = ml_backend,
                  normalize = True,
                  input_abbrev = input_abbrev,
                  output_abbrev = output_abbrev,
                  save_h5=True,
                  save_npy=False,
                  )
# Select the "v2" variable set (semantics defined in climsim_utils, not visible here).
data.set_to_v2_vars()


In [3]:
# Preprocessed training subset (HDF5): file name, containing directory
# (kept with its trailing slash) and the full path used by later cells.
data_fname = "train_first4months.h5"
data_dir = "/network/group/aopp/predict/HMC009_UKKONEN_CLIMSIM/ClimSim_data/ClimSim_low-res-expanded/train/preprocessed/"

data_path = f"{data_dir}{data_fname}"

In [4]:
# Inspect the layout of the preprocessed training file. The context manager
# closes the handle even if one of the lookups below raises.
with h5py.File(data_path, 'r') as hf:
    print(hf.keys())
    # <KeysViewHDF5 ['input_lev', 'input_sca', 'output_lev', 'output_sca']>
    print(hf.attrs.keys())
    print(hf['input_lev'].attrs.keys())
    # future training data should have a "varnames" attribute for each dataset type 

#2D Input variables: ['state_t', 'state_q0001', 'state_q0002', 'state_q0003', 'state_u', 'state_v', 
# 'pbuf_ozone', 'pbuf_CH4', 'pbuf_N2O']
# We need pressure!

#1D (scalar) Input variables: ['state_ps', 'pbuf_SOLIN', 'pbuf_LHFLX', 'pbuf_SHFLX', 'pbuf_TAUX', 
# 'pbuf_TAUY', 'pbuf_COSZRS', 'cam_in_ALDIF', 'cam_in_ALDIR', 'cam_in_ASDIF', 'cam_in_ASDIR', 
# 'cam_in_LWUP', 'cam_in_ICEFRAC', 'cam_in_LANDFRAC', 'cam_in_OCNFRAC', 'cam_in_SNOWHICE', 
# 'cam_in_SNOWHLAND', 'lat', 'lon']

#2D Output variables: ['ptend_t', 'ptend_q0001', 'ptend_q0002', 'ptend_q0003', 'ptend_u', 'ptend_v']

#1D (scalar) Output variables: ['cam_out_NETSW', 'cam_out_FLWDS', 'cam_out_PRECSC', 
#'cam_out_PRECC', 'cam_out_SOLS', 'cam_out_SOLL', 'cam_out_SOLSD', 'cam_out_SOLLD']

<KeysViewHDF5 ['input_lev', 'input_sca', 'output_lev', 'output_sca']>
<KeysViewHDF5 []>
<KeysViewHDF5 []>


In [10]:
# Read all four datasets fully into memory as NumPy arrays; `[:]` forces the
# read while the file is open, and the context manager guarantees the handle is
# released even if one of the reads fails (e.g. out of memory).
with h5py.File(data_path, 'r') as hf:
    x_lay = hf['input_lev'][:]
    x_sfc = hf['input_sca'][:]
    y_lay = hf['output_lev'][:]
    y_sfc = hf['output_sca'][:]

In [11]:
# Shape and value range of the per-level inputs loaded above.
print(x_lay.shape, x_lay.min(), x_lay.max())

(3316608, 60, 9) -0.9369526 1.0


In [19]:
# Shape and value range of the scalar inputs and of both output arrays.
print(x_sfc.shape, x_sfc.min(), x_sfc.max())
print(y_lay.shape, y_lay.min(), y_lay.max())
print(y_sfc.shape, y_sfc.min(), y_sfc.max())


(3316608, 17) -0.8866354 0.96252674
(3316608, 60, 6) -2.848058 3.7311456
(3316608, 8) 0.0 2.592345


In [5]:
import os 
import gc
import sys
import torch
import torch.nn as nn
import torch.nn.parameter as Parameter

class MyRNN(nn.Module):
    """Stack of 2 or 3 RNNs (LSTM or GRU) scanning a column in alternating directions.

    Column inputs are shaped (B, L, nx) with top-of-atmosphere (TOA) at index 0
    along L; if that is not the case the flip operations need to be moved.
    The first RNN runs surface -> TOA, the second TOA -> surface, and the
    optional third surface -> TOA again. Optional surface inputs (B, nx_sfc)
    are passed through a tanh MLP to initialise the state of the first RNN.

    Args:
        RNN_type: 'LSTM' or 'GRU'.
        nx: number of per-level input features.
        nx_sfc: number of surface (scalar) input features; 0 disables the
            surface-initialisation MLPs.
        ny: number of per-level outputs.
        ny_sfc: number of surface (scalar) outputs; 0 disables the surface head.
        nneur: hidden size of each RNN; its length (2 or 3) sets the depth.
        outputs_one_longer: if True, the initial surface state is prepended to
            the first RNN's output so that a length-L input yields L+1 output
            levels (e.g. predicting fluxes at level interfaces).
        concat: if True, the per-level head sees the concatenated outputs of
            all RNNs instead of only the last one.

    Raises:
        ValueError: if ``len(nneur)`` is not 2 or 3.
        NotImplementedError: if ``RNN_type`` is neither 'LSTM' nor 'GRU'.
    """

    def __init__(self, RNN_type='LSTM', 
                 nx = 9, nx_sfc=17, 
                 ny = 8, ny_sfc=8, 
                 nneur=(64,64), 
                 outputs_one_longer=False, # if True, inputs are a sequence
                 # of N and outputs a sequence of N+1 (e.g. predicting fluxes)
                 concat=False):
        super(MyRNN, self).__init__()
        self.nx = nx
        self.ny = ny
        self.nx_sfc = nx_sfc
        self.ny_sfc = ny_sfc
        self.nneur = nneur
        self.outputs_one_longer = outputs_one_longer
        # rnn1 and rnn2 are always built, so at least two sizes are required.
        if len(nneur) < 2 or len(nneur) > 3:
            raise ValueError("Number of RNN layers and length of nneur should be 2 or 3")

        self.RNN_type = RNN_type
        if self.RNN_type == 'LSTM':
            RNN_model = nn.LSTM
        elif self.RNN_type == 'GRU':
            RNN_model = nn.GRU
        else:
            raise NotImplementedError()

        self.concat = concat

        if self.nx_sfc > 0:
            # Surface inputs -> initial hidden state (and cell state for LSTM).
            self.mlp_surface1 = nn.Linear(nx_sfc, self.nneur[0])
            if self.RNN_type == "LSTM":
                self.mlp_surface2 = nn.Linear(nx_sfc, self.nneur[0])

        self.rnn1 = RNN_model(nx,            self.nneur[0], batch_first=True) # (input_size, hidden_size, num_layers=1
        self.rnn2 = RNN_model(self.nneur[0], self.nneur[1], batch_first=True)
        if len(self.nneur) == 3:
            self.rnn3 = RNN_model(self.nneur[1], self.nneur[2], batch_first=True)

        # The final hidden variable is either the output from the last RNN, or
        # the concatenated outputs from all RNNs
        if concat:
            nh_rnn = sum(nneur)
        else:
            nh_rnn = nneur[-1]

        self.mlp_output = nn.Linear(nh_rnn, self.ny)
        if self.ny_sfc > 0:
            # The surface head reads the final hidden state of rnn2 (see
            # forward), so its input width is nneur[1], not nneur[-1].
            self.mlp_surface_output = nn.Linear(nneur[1], self.ny_sfc)

    def forward(self, inputs_main, inputs_sfc=None):
        """Run the stack on a batch of columns.

        Args:
            inputs_main: column inputs, (B, L, nx), TOA first.
            inputs_sfc: optional surface inputs, (B, nx_sfc).

        Returns:
            ``out`` of shape (B, L, ny) (or (B, L+1, ny) when
            ``outputs_one_longer``), TOA first; when ``ny_sfc > 0`` the tuple
            ``(out, out_sfc)`` with ``out_sfc`` of shape (B, ny_sfc).

        Raises:
            ValueError: if ``outputs_one_longer`` is set but no surface inputs
                are given (there is no initial state to prepend).
        """
        sfc1 = None
        if inputs_sfc is not None:
            sfc1 = torch.tanh(self.mlp_surface1(inputs_sfc))

            if self.RNN_type == "LSTM":
                sfc2 = torch.tanh(self.mlp_surface2(inputs_sfc))
                hidden = (sfc1.view(1,-1,self.nneur[0]), sfc2.view(1,-1,self.nneur[0])) # (h0, c0)
            else:
                hidden = sfc1.view(1,-1,self.nneur[0])
        else:
            hidden = None

        # TOA is first in memory, so we need to flip the axis to start the
        # first RNN at the surface
        inputs_main = torch.flip(inputs_main, [1])

        out, hidden = self.rnn1(inputs_main, hidden)

        if self.outputs_one_longer:
            if sfc1 is None:
                raise ValueError("outputs_one_longer=True requires inputs_sfc")
            # Prepend the initial surface state as an extra level: (B, H) -> (B, 1, H)
            out = torch.cat((sfc1.unsqueeze(1), out), dim=1)

        out = torch.flip(out, [1]) # the surface was processed first, but for
        # the second RNN (and the final output) we want TOA first

        out2, hidden2 = self.rnn2(out)

        # nn.LSTM returns (h_n, c_n); nn.GRU returns h_n only.
        last_h = hidden2[0] if self.RNN_type == "LSTM" else hidden2

        if len(self.nneur) == 3:
            rnn3_input = torch.flip(out2, [1])

            out3, hidden3 = self.rnn3(rnn3_input)

            out3 = torch.flip(out3, [1])

            if self.concat:
                rnnout = torch.cat((out3, out2, out), dim=2)
            else:
                rnnout = out3
        else:
            if self.concat:
                rnnout = torch.cat((out2, out), dim=2)
            else:
                rnnout = out2

        out = self.mlp_output(rnnout)

        if self.ny_sfc > 0:
            # use cell state or hidden state?
            # last_h is (1, B, H); drop only the layer axis so that a batch of
            # size one keeps its batch dimension.
            out_sfc = self.mlp_surface_output(last_h.squeeze(0))
            return out, out_sfc
        else:
            return out

In [49]:
# Take the first 10*nb samples of each in-memory array as a small set of
# tensors for smoke-testing the model.
nb = 1024 

x_lay0 = torch.from_numpy(x_lay[0:10*nb])
x_sfc0 = torch.from_numpy(x_sfc[0:10*nb])
y_lay0 = torch.from_numpy(y_lay[0:10*nb])
y_sfc0 = torch.from_numpy(y_sfc[0:10*nb])

print(x_lay0.shape, x_sfc0.shape)
print(y_lay0.shape, y_sfc0.shape)

# Sample count, number of levels and feature counts taken from the tensor shapes.
ns, nlay, nx = x_lay0.shape
_, nx_sfc    = x_sfc0.shape
_, _, ny     = y_lay0.shape
_, ny_sfc    = y_sfc0.shape


torch.Size([10240, 60, 9]) torch.Size([10240, 17])
torch.Size([10240, 60, 6]) torch.Size([10240, 8])


In [24]:
# Feature counts; these match the trailing dimensions printed for the loaded
# arrays (9 per-level inputs, 17 scalar inputs, 6 per-level and 8 scalar outputs).
nx = 9
nx_sfc = 17
ny = 6
ny_sfc = 8

# When add_refpres is set, generator_xy appends one extra per-level feature
# (normalised reference pressure), so the model needs one more input channel.
add_refpres = True
if add_refpres:
    nx = nx + 1

# Two-layer LSTM with 64 hidden units per layer.
model = MyRNN(RNN_type='LSTM', 
             nx = nx, nx_sfc=nx_sfc, 
             ny = ny, ny_sfc=ny_sfc, 
             nneur=(64,64))

In [54]:
# Forward-pass smoke test on the preview tensors.
# NOTE(review): `model` is built with nx = 10 when add_refpres is True, whereas
# x_lay0 has 9 per-level features — confirm this call still works after a
# fresh Restart & Run All (the execution counts are out of order).
out, out_sfc = model(x_lay0, x_sfc0)
print(out.shape, out_sfc.shape)

torch.Size([10240, 60, 6]) torch.Size([10240, 8])


In [62]:
# Column 0 of the scalar inputs is 'state_ps' (first entry of the scalar input
# variable list noted in the file-inspection cell).
state_ps = x_sfc0[:,0:1]

# Undo the input scaling to recover surface pressure in physical units.
# NOTE(review): presumably inputs were normalised as (x - mean) / (max - min);
# confirm against data_utils.
if data.normalize:
    state_ps = state_ps*(data.input_max['state_ps'].values - data.input_min['state_ps'].values) + data.input_mean['state_ps'].values

print("shape ps", state_ps.shape, "min", state_ps.min(), "max", state_ps.max())


shape ps torch.Size([10240, 1]) min tensor(75208.2969) max tensor(102864.2344)


  state_ps = state_ps*(data.input_max['state_ps'].values - data.input_min['state_ps'].values) + data.input_mean['state_ps'].values


In [74]:
# Pressure on the hybrid interface levels for every sample:
#   p = P0 * hyai + hybi * ps
# The first term is the same for all samples (shape (1, nlev_i)); the second
# scales with each sample's surface pressure (state_ps is (N, 1)).
#pressure_grid_p1 = np.array(data.grid_info['P0']*data.grid_info['hyai'])[:,np.newaxis,np.newaxis]
pressure_grid_p1 = torch.from_numpy(np.array(data.grid_info['P0']*data.grid_info['hyai'])[np.newaxis,:])
#print(pressure_grid_p1.shape)
pressure_grid_p2 = torch.from_numpy(data.grid_info['hybi'].values[np.newaxis, :]) * state_ps
#print(pressure_grid_p2.shape)
pressure_grid = pressure_grid_p1 + pressure_grid_p2
#print(pressure_grid.shape, pressure_grid.min(), pressure_grid.max())
# Layer thickness: difference between consecutive interface pressures
# (assumes 61 interface levels bounding 60 layers — TODO confirm).
dp     = pressure_grid[:,1:61] - pressure_grid[:,0:60]

In [98]:
# Reference pressure at layer mid-points: the hybrid formula evaluated with the
# surface pressure replaced by P0, i.e. P0 * hyam + hybm * P0.
p1 = np.array(data.grid_info['P0']*data.grid_info['hyam'])[np.newaxis,:] 
p2 = data.grid_info['hybm'].values[np.newaxis, :] * data.grid_info['P0'].values

In [99]:
# Sum of both terms; printed divided by 100 (presumably Pa -> hPa — confirm).
# These printed values are the ones hard-coded as `refpres` in generator_xy.
pref = p1 + p2 
print(pref/100)

[[7.83478113e-02 1.41108318e-01 2.52923297e-01 4.49250635e-01
  7.86346161e-01 1.34735576e+00 2.24477729e+00 3.61643148e+00
  5.61583643e+00 8.40325322e+00 1.21444894e+01 1.70168280e+01
  2.32107981e+01 3.09143463e+01 4.02775807e+01 5.13746323e+01
  6.41892284e+01 7.86396576e+01 9.46300920e+01 1.12091274e+02
  1.30977804e+02 1.51221318e+02 1.72673905e+02 1.95087710e+02
  2.18155935e+02 2.41600379e+02 2.65258515e+02 2.89122322e+02
  3.13312087e+02 3.38006999e+02 3.63373492e+02 3.89523338e+02
  4.16507922e+02 4.44331412e+02 4.72957206e+02 5.02291917e+02
  5.32152273e+02 5.62239392e+02 5.92149276e+02 6.21432841e+02
  6.49689897e+02 6.76656485e+02 7.02242188e+02 7.26498589e+02
  7.49537645e+02 7.71445217e+02 7.92234260e+02 8.11856675e+02
  8.30259643e+02 8.47450653e+02 8.63535902e+02 8.78715875e+02
  8.93246018e+02 9.07385213e+02 9.21354397e+02 9.35316717e+02
  9.49378056e+02 9.63599599e+02 9.78013432e+02 9.92635544e+02]]


In [25]:
class generator_xy(torch.utils.data.Dataset):
    """Dataset reading (inputs, targets) from a preprocessed HDF5 file.

    The file must contain the datasets 'input_lev', 'input_sca', 'output_lev'
    and 'output_sca', all indexed by sample along their first axis. The file
    is opened on every access and closed again before returning, so no handle
    is held between calls.

    Args:
        filepath: path of the HDF5 file. If it contains the substring 'train'
            the dataset is flagged as training data, otherwise as validation.
        nloc: number of locations (columns) per time step; used to derive the
            number of time steps from the sample count.
        nlev: number of vertical levels (stored; the hard-coded reference
            pressure profile below has 60 entries).
        add_refpres: if True, append the normalised reference pressure as an
            extra per-level input feature.
        cuda: stored on the instance; not used inside this class.
    """

    def __init__(self, filepath, nloc=384, nlev=60, add_refpres=True, cuda=False):
        self.filepath = filepath
        self.nloc = nloc
        self.nlev = nlev
        # Reference pressure per level, one value per model level (TOA first).
        self.refpres = np.array([7.83478113e-02,1.41108318e-01,2.52923297e-01,4.49250635e-01,
                    7.86346161e-01,1.34735576e+00,2.24477729e+00,3.61643148e+00,
                    5.61583643e+00,8.40325322e+00,1.21444894e+01,1.70168280e+01,
                    2.32107981e+01,3.09143463e+01,4.02775807e+01,5.13746323e+01,
                    6.41892284e+01,7.86396576e+01,9.46300920e+01,1.12091274e+02,
                    1.30977804e+02,1.51221318e+02,1.72673905e+02,1.95087710e+02,
                    2.18155935e+02,2.41600379e+02,2.65258515e+02,2.89122322e+02,
                    3.13312087e+02,3.38006999e+02,3.63373492e+02,3.89523338e+02,
                    4.16507922e+02,4.44331412e+02,4.72957206e+02,5.02291917e+02,
                    5.32152273e+02,5.62239392e+02,5.92149276e+02,6.21432841e+02,
                    6.49689897e+02,6.76656485e+02,7.02242188e+02,7.26498589e+02,
                    7.49537645e+02,7.71445217e+02,7.92234260e+02,8.11856675e+02,
                    8.30259643e+02,8.47450653e+02,8.63535902e+02,8.78715875e+02,
                    8.93246018e+02,9.07385213e+02,9.21354397e+02,9.35316717e+02,
                    9.49378056e+02,9.63599599e+02,9.78013432e+02,9.92635544e+02],dtype=np.float32)
        # Min-max scale the profile to [-1, 1].
        self.refpres_norm = (self.refpres-self.refpres.min())/(self.refpres.max()-self.refpres.min())*2 - 1

        if 'train' in self.filepath:
            self.is_validation = False
            print("Training dataset, path is: {}".format(self.filepath))
        else:
            self.is_validation = True
            print("Validation dataset, path is: {}".format(self.filepath))
        self.cuda = cuda

        self.add_refpres = add_refpres
        # Only the sample count is needed here; close the file right away.
        with h5py.File(self.filepath, 'r') as hdf:
            self.ntimesteps = hdf['input_lev'].shape[0]//self.nloc
        print("Number of locations {}; time steps {}".format(self.nloc, self.ntimesteps))

    def __len__(self):
        """Number of samples covered by whole time steps (ntimesteps * nloc)."""
        return self.ntimesteps*self.nloc

    def __getitem__(self, batch_indices):
        """Load the samples selected by ``batch_indices`` along the first axis.

        ``batch_indices`` is used directly as an HDF5 selection; the training
        cells pass a whole chunk of sample indices at once.

        Returns:
            Tuple of tensors ``(x_lay, x_sfc, y_lay, y_sfc)``. When
            ``add_refpres`` is set, ``x_lay`` has one extra trailing feature
            holding the normalised reference pressure.
        """
        # The context manager releases the handle even if a read fails.
        with h5py.File(self.filepath, 'r') as hdf:
            x_lay_b = hdf['input_lev'][batch_indices,:]
            x_sfc_b = hdf['input_sca'][batch_indices,:]
            y_lay_b = hdf['output_lev'][batch_indices,:]
            y_sfc_b = hdf['output_sca'][batch_indices,:]

        if self.add_refpres:
            n_samples = x_lay_b.shape[0]
            # (nlev,) -> (1, nlev, 1), repeated for every sample in the chunk
            refpres_norm = np.repeat(self.refpres_norm.reshape((1,-1,1)), n_samples, axis=0)
            x_lay_b = np.concatenate((x_lay_b, refpres_norm), axis=2)
            del refpres_norm

        x_lay_b = torch.from_numpy(x_lay_b)
        x_sfc_b = torch.from_numpy(x_sfc_b)
        y_lay_b = torch.from_numpy(y_lay_b)
        y_sfc_b = torch.from_numpy(y_sfc_b)

        gc.collect()

        return x_lay_b, x_sfc_b, y_lay_b, y_sfc_b

def chunkize(filelist, chunk_size, shuffle_before_chunking=False, shuffle_after_chunking=True):
    """Split a list into consecutive chunks of ``chunk_size`` items.

    There is no concept of batches here; ``chunk_size`` is given in number of
    samples. The last chunk may be shorter. The input list is never modified.

    Args:
        filelist: list of items (here: sample indices) to split.
        chunk_size: positive number of items per chunk.
        shuffle_before_chunking: shuffle the items before splitting. Each
            resulting chunk is then sorted, because the indices are used to
            index into the first dimension of an H5 file.
        shuffle_after_chunking: shuffle the order of the chunks (the order
            inside each chunk is kept).

    Returns:
        List of chunks, each a list of items.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    import random

    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # Work on a copy so the caller's list keeps its order.
    items = list(filelist)
    if shuffle_before_chunking:
        random.shuffle(items)

    chunks = [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]

    if shuffle_before_chunking:
        # Sorting has to happen per chunk, i.e. after the split.
        chunks = [sorted(chunk) for chunk in chunks]

    if shuffle_after_chunking:
        random.shuffle(chunks)
    return chunks


class BatchSampler(torch.utils.data.Sampler):
    """Sampler that yields whole chunks of consecutive sample indices.

    One item is one chunk (a list of ``num_samples_per_chunk`` indices, the
    last one possibly shorter). With ``shuffle=True`` the order of the chunks
    is randomised once at construction; indices inside a chunk stay sorted.

    Args:
        num_samples_per_chunk: number of sample indices per yielded chunk.
        num_samples: total number of samples to cover (indices 0..num_samples-1).
        shuffle: whether to shuffle the chunk order.
    """

    def __init__(self, num_samples_per_chunk, num_samples, shuffle=False):
        self.num_samples_per_chunk = num_samples_per_chunk
        self.num_samples = num_samples
        indices_all = list(range(self.num_samples))
        print("Shuffling the indices: {}".format(shuffle))
        self.indices_chunked = chunkize(indices_all,self.num_samples_per_chunk,
                                        shuffle_before_chunking=False,
                                        shuffle_after_chunking=shuffle)

    def __len__(self):
        """Number of items produced by one pass over the sampler (= chunks)."""
        return len(self.indices_chunked)

    def __iter__(self):
        return iter(self.indices_chunked)

In [63]:
shuffle_data = True 

# One batch corresponds to all locations of a single time step.
train_locs = nloc = 384
batch_size = train_locs 

# To improve IO, which is a bottleneck, increase the batch size by a factor of chunk_factor and 
# load this many batches at once. These chunks then need to be manually split into batches 
# within the data iteration loop    

# chunk size in number of batches
#chunk_size = 72 # one day (3 time steps in an hour, 72 in a day)
#chunk_size = 360     
chunk_size = 720 # 10 days
# chunk size in number of elements
num_samples_per_chunk = chunk_size*batch_size

# Dataset over the training file; add_refpres comes from the model-setup cell.
train_data = generator_xy(data_path, nloc=train_locs, add_refpres=add_refpres)

# Sampler yielding one chunk of sample indices per item, in shuffled chunk order.
batch_sampler_tr = BatchSampler(num_samples_per_chunk,
                                num_samples=train_data.__len__(), shuffle=shuffle_data)

Training dataset, path is: /network/group/aopp/predict/HMC009_UKKONEN_CLIMSIM/ClimSim_data/ClimSim_low-res-expanded/train/preprocessed/train_first4months.h5
Number of locations 384; time steps 8637
Shuffling the indices: True


In [64]:
# Scratch check: samples in a one-day chunk (72 time steps x 384 locations).
72*384

27648

In [65]:
from torch.utils.data import DataLoader


# Validation is switched off; val_data / batch_sampler_val are not defined in
# the visible cells and would be needed before enabling it.
use_val = False 

num_workers = 2
prefetch_factor = 1
pin = False
persistent=False

# batch_size=None disables automatic batching: every item produced by the
# sampler (a whole chunk of indices) is handed to the dataset's __getitem__
# in one call, and the returned tensors are passed through unchanged.
train_loader = DataLoader(dataset=train_data, num_workers=num_workers, sampler=batch_sampler_tr, 
                          batch_size=None,batch_sampler=None,prefetch_factor=prefetch_factor, 
                          pin_memory=pin, persistent_workers=persistent)

if use_val:
    val_loader = DataLoader(dataset=val_data, num_workers=num_workers,sampler=batch_sampler_val,
                            batch_size=None,batch_sampler=None,prefetch_factor=prefetch_factor, 
                            pin_memory=pin, persistent_workers=persistent)


In [66]:
import gc
# Force a garbage-collection pass; the cell output is the number of
# unreachable objects found.
gc.collect()

1679

In [67]:
# Pick the GPU when available and move the model there.
cuda = torch.cuda.is_available() 
device = torch.device("cuda" if cuda else "cpu")
print(device)
model = model.to(device)

# Mixed precision is used only on GPU: float16 autocast together with a
# gradient scaler. Note that `dtype` is only defined in the CUDA branch; it is
# read by the training code only when mp_autocast is True.
if cuda:
    mp_autocast = True 
    # if torch.cuda.is_bf16_supported(): 
    #     dtype=torch.bfloat16 
    #     use_scaler = False
    # else:
    #     dtype=torch.float16
    #     use_scaler = True 
    dtype=torch.float16
    use_scaler = True 
else:
    mp_autocast = False
    use_scaler = False
    
    
if use_scaler:
    # scaler = torch.amp.GradScaler(autocast = True)
    scaler = torch.amp.GradScaler(device.type)

cuda


In [68]:
# Query bfloat16 support on the current GPU (the bf16 branch in the AMP setup
# cell is currently commented out in favour of float16).
torch.cuda.is_bf16_supported()

True

In [69]:

def my_mse(y_true_lay, y_true_sfc, y_pred_lay, y_pred_sfc):
    """Average of the per-level MSE and the surface MSE.

    Each term is the mean squared error over all elements of its own tensor
    pair; the two terms are weighted equally regardless of tensor size.
    """
    err_lay = y_pred_lay - y_true_lay
    err_sfc = y_pred_sfc - y_true_sfc
    return (err_lay.pow(2).mean() + err_sfc.pow(2).mean()) / 2

loss_fn = my_mse

In [70]:
# Adam over all model parameters with a fixed learning rate.
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [71]:
from torchmetrics.regression import R2Score
import time
num_epochs = 2

# Autoregressive training accumulates predictions over `timewindow` batches
# before each loss evaluation; it is disabled here, so every batch gets its own.
autoregressive = False

if not autoregressive:
    timewindow = 1
    timestep_scheduling=False
else:
    timewindow = 3
    timestep_scheduling=True
    # One window length per epoch, initialised to `timewindow`.
    timestep_schedule = np.arange(num_epochs)
    timestep_schedule[:] = timewindow

    # Ramp the window up over the epochs; later slice assignments overwrite
    # the tail written by earlier ones.
    if timestep_scheduling:
        timestep_schedule[0:3] = 1
        timestep_schedule[3:4] = timewindow-1
        timestep_schedule[4:] = timewindow
        timestep_schedule[5:] = timewindow+1
        timestep_schedule[6:] = timewindow+2

    
use_wandb = False

class model_train_eval:
    """Runs one pass (training or evaluation) of a model over a chunked loader.

    The loader yields chunks ``(inputs_lay, inputs_sfc, targets_lay,
    targets_sfc)`` holding many batches; each chunk is split into batches of
    ``batch_size`` samples which are processed one by one.

    Besides its constructor arguments the class relies on notebook-level names
    defined in earlier cells: ``device``, ``batch_size``, ``mp_autocast``,
    ``dtype``, ``use_scaler``, ``scaler``, ``optimizer``, ``loss_fn``, ``ny``
    and ``cuda``.

    Args:
        dataloader: iterable yielding the 4-tuples described above.
        model: module called as ``model(inputs_lay, inputs_sfc)`` and returning
            ``(pred_lay, pred_sfc)``. In autoregressive mode it must also
            provide ``reset_states()`` and ``detach_states()``.
        autoregressive: gather predictions over ``timewindow`` batches before
            each loss evaluation.
        train: if True, back-propagate and step the optimizer; if False, only
            evaluate (gradient tracking is switched off).
    """

    def __init__(self, dataloader, model, autoregressive=True, train=True):
        super().__init__()
        self.loader = dataloader
        self.train = train
        self.report_freq = 800  # print running statistics every this many batches
        self.model = model
        self.autoregressive = autoregressive
        if self.autoregressive:
            self.model.reset_states()
        self.metric_R2 = R2Score().to(device)
        self.metrics = {'loss': 0, 'mean_squared_error': 0,  # the latter is just MSE
                        'mean_absolute_error': 0, 'R2': 0}

    def eval_one_epoch(self, epoch, timewindow=1):
        """Process the whole loader once and store the results in ``self.metrics``.

        Args:
            epoch: zero-based epoch number, used only in the printed messages.
            timewindow: in autoregressive mode, number of consecutive batches
                gathered before each loss evaluation.

        ``metrics['loss']``, ``['mean_squared_error']`` and ``['R2']`` are
        averaged over the loss evaluations made after the first
        ``loss_update_start_index`` batches; they are NaN if there were none.
        """
        report_freq = self.report_freq
        running_loss = 0.0
        epoch_loss = 0.0
        t_comp = 0
        if self.autoregressive:
            preds_lay = []; preds_sfc = []
            targets_lay = []; targets_sfc = []
        t0_it = time.time()
        j = 0  # batches processed so far
        k = 0  # loss evaluations accumulated into the epoch metrics
        # Batches with j <= this index do not contribute to the epoch metrics
        # (in the non-autoregressive case that is just the very first batch).
        loss_update_start_index = 60 if self.autoregressive else 0

        # No autograd graph is needed when only evaluating.
        with torch.set_grad_enabled(self.train):
            for chunk in self.loader:
                inputs_lay_chunks, inputs_sfc_chunks, targets_lay_chunks, targets_sfc_chunks = chunk

                # to speed-up IO, we loaded chunks=many batches, which now need to be divided into batches
                inputs_lay_chunks  = torch.split(inputs_lay_chunks.to(device), batch_size)
                inputs_sfc_chunks  = torch.split(inputs_sfc_chunks.to(device), batch_size)
                targets_sfc_chunks = torch.split(targets_sfc_chunks.to(device), batch_size)
                targets_lay_chunks = torch.split(targets_lay_chunks.to(device), batch_size)

                for inputs_lay, inputs_sfc, target_lay, target_sfc in zip(
                        inputs_lay_chunks, inputs_sfc_chunks,
                        targets_lay_chunks, targets_sfc_chunks):

                    tcomp0 = time.time()

                    if mp_autocast:
                        with torch.autocast(device_type=device.type, dtype=dtype):
                            pred_lay, pred_sfc = self.model(inputs_lay, inputs_sfc)
                    else:
                        pred_lay, pred_sfc = self.model(inputs_lay, inputs_sfc)

                    if self.autoregressive:
                        # In the autoregressive training case we are gathering many time steps before computing loss
                        preds_lay.append(pred_lay)
                        preds_sfc.append(pred_sfc)
                        targets_lay.append(target_lay)
                        targets_sfc.append(target_sfc)
                    else:
                        preds_lay = pred_lay
                        preds_sfc = pred_sfc
                        targets_lay = target_lay
                        targets_sfc = target_sfc

                    if (not self.autoregressive) or (j+1) % timewindow == 0:

                        if self.autoregressive:
                            preds_lay   = torch.stack(preds_lay)
                            preds_sfc   = torch.stack(preds_sfc)
                            targets_lay = torch.stack(targets_lay)
                            targets_sfc = torch.stack(targets_sfc)

                        if mp_autocast:
                            with torch.autocast(device_type=device.type, dtype=dtype):
                                loss = loss_fn(targets_lay, targets_sfc, preds_lay, preds_sfc)
                        else:
                            loss = loss_fn(targets_lay, targets_sfc, preds_lay, preds_sfc)

                        if self.train:
                            if use_scaler:
                                scaler.scale(loss).backward()
                                scaler.step(optimizer)
                                scaler.update()
                            else:
                                loss.backward()
                                optimizer.step()

                            optimizer.zero_grad()

                        loss_value = loss.item()
                        running_loss += loss_value
                        if j > loss_update_start_index:
                            with torch.no_grad():
                                epoch_loss += loss_value
                                # -------------- TO-DO:  DE-NORM OUTPUT --------------
                                yto, ypo = targets_lay, preds_lay
                                self.metric_R2.update(ypo.reshape((-1,ny)), yto.reshape((-1,ny)))
                                k += 1
                        if self.autoregressive:
                            preds_lay = []; preds_sfc = []
                            targets_lay = []; targets_sfc = []
                            # Use the wrapped model, not a notebook-level global.
                            self.model.detach_states()

                    t_comp += time.time() - tcomp0
                    # print statistics every report_freq batches
                    if j % report_freq == (report_freq-1):
                        elaps = time.time() - t0_it
                        running_loss = running_loss / (report_freq/timewindow)
                        r2raw = self.metric_R2.compute()
                        print("[{:d}, {:d}] Loss: {:.2e}  runningR2: {:.2f}, elapsed {:.1f}s (compute {:.1f})" .format(epoch + 1, 
                                                        j+1, running_loss, r2raw, elaps, t_comp))
                        running_loss = 0.0
                        t0_it = time.time()
                        t_comp = 0
                    j += 1

        if k > 0:
            self.metrics['loss'] = epoch_loss / k
            self.metrics['R2'] = self.metric_R2.compute()
        else:
            # Nothing was accumulated (loader shorter than the warm-up).
            self.metrics['loss'] = float('nan')
            self.metrics['R2'] = float('nan')
        self.metrics['mean_squared_error'] = self.metrics['loss']

        self.metric_R2.reset()
        if self.autoregressive:
            self.model.reset_states()

        datatype = "TRAIN" if self.train else "VAL"
        print('Epoch {} {} loss: {:.2e}  MSE: {:.2e}  R2: {:.2f}'.format(epoch+1, datatype, self.metrics['loss'], 
                                                                        self.metrics['mean_squared_error'], 
                                                                        self.metrics['R2']))

        # Release cached memory at the end of every epoch.
        if cuda: torch.cuda.empty_cache()
        gc.collect()


In [72]:
# Runner bound to the training loader; one eval_one_epoch() call per epoch below.
train_runner = model_train_eval(train_loader, model, autoregressive, train=True)



for epoch in range(num_epochs):
    t0 = time.time()
    
    # Window length for this epoch: scheduled per epoch, or the fixed value.
    if timestep_scheduling:
        timewindoww=timestep_schedule[epoch]            
    else:
        timewindoww=timewindow
        
    print("Epoch {} Training rollout timesteps: {} ".format(epoch+1, timewindoww))
    train_runner.eval_one_epoch(epoch, timewindoww)
    
    if use_wandb: wandb.log(train_runner.metrics)
    
    # Validation branch: runs on odd epoch indices only. It needs val_runner,
    # save_model, best_val_loss and SAVE_PATH, none of which are defined in the
    # visible cells, so it only works once those are provided.
    if use_val:
        if epoch%2:
            print("VALIDATION..")
            val_runner.eval_one_epoch(epoch, timewindoww)

            losses_val = {"val_"+k: v for k, v in val_runner.metrics.items()}
            if use_wandb: wandb.log(losses_val)

            val_loss = losses_val["val_loss"]

            # MODEL CHECKPOINT IF VALIDATION LOSS IMPROVED
            if save_model and val_loss < best_val_loss:
              torch.save({
                          'epoch': epoch,
                          'model_state_dict': model.state_dict(),
                          'optimizer_state_dict': optimizer.state_dict(),
                          'val_loss': val_loss,
                          }, SAVE_PATH)  
              best_val_loss = val_loss 
              
    print('Epoch {}/{} complete, took {:.2f} seconds, autoreg window was {}'.format(epoch+1,num_epochs,time.time() - t0,timewindoww))

Epoch 1 Training rollout timesteps: 1 
[1, 800] Loss: 2.79e-03  runningR2: 0.17, elapsed 13.7s (compute 4.0)
[1, 1600] Loss: 2.71e-03  runningR2: 0.18, elapsed 6.2s (compute 4.0)
[1, 2400] Loss: 2.69e-03  runningR2: 0.18, elapsed 4.3s (compute 4.0)
[1, 3200] Loss: 2.74e-03  runningR2: 0.18, elapsed 6.1s (compute 4.0)
[1, 4000] Loss: 2.68e-03  runningR2: 0.18, elapsed 4.3s (compute 4.0)
[1, 4800] Loss: 2.75e-03  runningR2: 0.18, elapsed 4.8s (compute 3.8)
[1, 5600] Loss: 2.72e-03  runningR2: 0.18, elapsed 4.3s (compute 3.9)
[1, 6400] Loss: 2.72e-03  runningR2: 0.18, elapsed 5.6s (compute 4.0)
[1, 7200] Loss: 2.66e-03  runningR2: 0.18, elapsed 5.5s (compute 4.0)
[1, 8000] Loss: 2.65e-03  runningR2: 0.19, elapsed 4.3s (compute 4.0)
Epoch 1 TRAIN loss: 2.71e-03  MSE: 2.71e-03  R2: 0.19
Epoch 1/2 complete, took 62.49 seconds, autoreg window was 1
Epoch 2 Training rollout timesteps: 1 
[2, 800] Loss: 2.63e-03  runningR2: 0.19, elapsed 12.9s (compute 3.8)
[2, 1600] Loss: 2.54e-03  runningR2: 