## PyTorch version of ConvLSTM

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import os
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from multiprocessing import Process

In [2]:
import imp
import sys

sys.path.insert(0, '../../src')
from utils import df_to_xarray,read_xarray, get_point_prediction, custom_rmse

sys.path.insert(0, '../../src/preprocess')
from data_preprocess import preprocess_image_reduced,preprocess_images_nfp, inverse_scale_frame

  import imp
2022-05-31 15:09:38.528396: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [28]:
dir_name = "../../data/data1"
val_dir_name = "../../data/data2"

data, pco2 = preprocess_images_nfp(dir_name)
# NOTE(review): data_socat / pco2_socat are not read by any later cell in this notebook.
data_socat, pco2_socat = preprocess_images_nfp(dir_name, socat=True)

# TODO: validation split, currently disabled.
# val_data,val_pco2 = preprocess_images_nfp(val_dir_name,"035")
# val_data_socat,val_pco2_socat = preprocess_images_nfp(val_dir_name,"035",socat=True)

# Build overlapping windows of WINDOW consecutive frames over the first N_FRAMES entries of
# axis 0. Each input window covers frames [t, t + WINDOW) and its target is the window shifted
# by one frame, [t + 1, t + 1 + WINDOW) -- hence y drops the first window and X the last.
N_FRAMES = 421  # leading frames used to build windows (TODO confirm equals data.shape[0])
WINDOW = 3      # frames per input / target sequence
if data.shape[0] < N_FRAMES or pco2.shape[0] < N_FRAMES:
    raise ValueError(
        "expected at least {} frames along axis 0, got data={} pco2={}".format(
            N_FRAMES, data.shape[0], pco2.shape[0]))
X_index = np.lib.stride_tricks.sliding_window_view(range(N_FRAMES), WINDOW)

y = np.expand_dims(pco2[X_index][1:], axis=4)
X = data[X_index][:-1]

# val_y=np.expand_dims(val_pco2[X_index][1:],axis=4)
# val_X=val_data[X_index][:-1]


print(X.shape, y.shape)

INPUT_SHAPE = X[0].shape

(418, 3, 180, 360, 6) (418, 3, 180, 360, 1)


In [4]:
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    One 2-D convolution over the channel-wise concatenation of the input and the previous
    hidden state produces all four gate pre-activations. Unlike the classic ConvLSTM, the
    candidate activation and the cell-output non-linearity are ELU instead of tanh.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Build the gate convolution.

        Parameters
        ----------
        input_dim: int
            Channel count of the input tensor.
        hidden_dim: int
            Channel count of the hidden (and cell) state.
        kernel_size: (int, int)
            Height and width of the convolution kernel; a padding of kernel_size // 2 is
            applied on each spatial axis.
        bias: bool
            Whether the convolution has a learnable bias.
        """
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        pad_h, pad_w = kernel_size[0] // 2, kernel_size[1] // 2
        self.padding = (pad_h, pad_w)
        self.bias = bias

        # Single convolution emitting the 4 gate pre-activations stacked along channels.
        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """
        Advance the cell by one step.

        Parameters
        ----------
        input_tensor: Tensor
            Input of shape (batch, input_dim, height, width).
        cur_state: (Tensor, Tensor)
            Previous (hidden, cell) states, each (batch, hidden_dim, height, width).

        Returns
        -------
        (Tensor, Tensor)
            Next (hidden, cell) states.
        """
        prev_hidden, prev_cell = cur_state

        stacked = torch.cat((input_tensor, prev_hidden), dim=1)  # channel axis
        gates = self.conv(stacked)
        in_gate, forget_gate, out_gate, candidate = torch.split(gates, self.hidden_dim, dim=1)

        next_cell = (torch.sigmoid(forget_gate) * prev_cell
                     + torch.sigmoid(in_gate) * nn.functional.elu(candidate))
        next_hidden = torch.sigmoid(out_gate) * nn.functional.elu(next_cell)
        return next_hidden, next_cell

    def init_hidden(self, batch_size, image_size):
        """Return all-zero (hidden, cell) states on the same device as the conv weights."""
        height, width = image_size
        device = self.conv.weight.device
        state_shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(state_shape, device=device),
                torch.zeros(state_shape, device=device))



In [5]:
class EncoderDecoderConvLSTM(nn.Module):
    """Encoder-decoder ConvLSTM for frame-sequence regression.

    Architecture
    ------------
    * Encoder: two stacked ConvLSTM cells read the input sequence step by step.
    * Encoder vector: the final hidden state of the second encoder cell.
    * Decoder: two stacked ConvLSTM cells unrolled for ``future_seq`` steps. The first
      step is fed the encoder vector; each later step is fed the previous decoder output.
    * Head: a Conv3d with kernel (1, 3, 3) mapping ``nf`` channels to 3 channels,
      followed by ELU.

    NOTE(review): with ``future_seq == 1`` the head's 3 output channels are what the
    training code compares (after ``squeeze``) with the 3 target frames of each window,
    so the channel axis effectively carries the prediction's time axis.

    ``batch_norm`` is built but never used in the forward pass. It is kept so that
    existing checkpoints, whose state_dict contains its parameters and buffers, still
    load with ``strict=True``.
    """

    def __init__(self, nf, in_chan):
        """
        Parameters
        ----------
        nf: int
            Number of hidden channels of every ConvLSTM cell.
        in_chan: int
            Number of channels of each input frame.
        """
        super(EncoderDecoderConvLSTM, self).__init__()

        self.encoder_1_convlstm = ConvLSTMCell(input_dim=in_chan,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.encoder_2_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        # Unused in forward(); kept for state_dict compatibility (see class docstring).
        self.batch_norm = torch.nn.BatchNorm2d(nf)

        self.decoder_1_convlstm = ConvLSTMCell(input_dim=nf,  # nf + 1
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.decoder_2_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)

        self.decoder_CNN = nn.Conv3d(in_channels=nf,
                                     out_channels=3,
                                     kernel_size=(1, 3, 3),
                                     padding=(0, 1, 1))

    def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):
        """
        Run the encoder over ``seq_len`` input steps, then the decoder for ``future_step`` steps.

        Parameters
        ----------
        x: Tensor
            Input of shape (b, t, c, h, w).
        seq_len: int
            Number of leading time steps of ``x`` to encode.
        future_step: int
            Number of decoder steps; must be >= 1.
        h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4: Tensor
            Initial (hidden, cell) states of encoder 1, encoder 2, decoder 1 and decoder 2.

        Returns
        -------
        Tensor of shape (b, 3, future_step, h, w).

        Raises
        ------
        ValueError
            If ``future_step < 1`` (there would be no decoder output to stack).
        """
        if future_step < 1:
            raise ValueError(
                "future_step must be >= 1 to produce a prediction, got {}".format(future_step))

        # Encoder: feed each input frame through both stacked cells.
        for t in range(seq_len):
            h_t, c_t = self.encoder_1_convlstm(input_tensor=x[:, t, :, :],
                                               cur_state=[h_t, c_t])
            h_t2, c_t2 = self.encoder_2_convlstm(input_tensor=h_t,
                                                 cur_state=[h_t2, c_t2])

        # Decoder: start from the encoder's last hidden state, then feed back its own output.
        decoder_input = h_t2
        outputs = []
        for _ in range(future_step):
            h_t3, c_t3 = self.decoder_1_convlstm(input_tensor=decoder_input,
                                                 cur_state=[h_t3, c_t3])
            h_t4, c_t4 = self.decoder_2_convlstm(input_tensor=h_t3,
                                                 cur_state=[h_t4, c_t4])
            decoder_input = h_t4
            outputs.append(h_t4)

        # (b, T, nf, h, w) -> (b, nf, T, h, w) so Conv3d treats T as its depth axis.
        outputs = torch.stack(outputs, 1).permute(0, 2, 1, 3, 4)
        return nn.functional.elu(self.decoder_CNN(outputs))

    def forward(self, x, future_seq=0, hidden_state=None):
        """
        Parameters
        ----------
        x: Tensor
            5-D tensor of shape (b, t, c, h, w): batch, time, channel, height, width.
        future_seq: int
            Number of decoder steps to unroll. Must be >= 1; the default of 0 raises
            ValueError.
        hidden_state: ignored
            Accepted for API compatibility; all states are zero-initialised.

        Returns
        -------
        Tensor of shape (b, 3, future_seq, h, w).
        """
        b, seq_len, _, h, w = x.size()

        # Fresh zero states for every cell, on the same device as the cell weights.
        h_t, c_t = self.encoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t2, c_t2 = self.encoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t3, c_t3 = self.decoder_1_convlstm.init_hidden(batch_size=b, image_size=(h, w))
        h_t4, c_t4 = self.decoder_2_convlstm.init_hidden(batch_size=b, image_size=(h, w))

        return self.autoencoder(x, seq_len, future_seq, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4)

In [6]:
from torch.utils.data import TensorDataset
from torchvision import models
from torchsummary import summary
from pytorch_lightning.callbacks import ModelCheckpoint


# Keep only the single best checkpoint, ranked by the (training) "loss" value, lowest first.
checkpoint_callback = ModelCheckpoint(
    "../../models",
    monitor="loss",
    mode="min",
    save_top_k=1,
    verbose=True,
)

# Training hyper-parameters.
# NOTE(review): `epochs` is not passed to the Trainer, which uses max_epochs=35 instead.
batch_size, lr, epochs = 16, 0.0003, 1



In [7]:
from pytorch_lightning import Trainer
from multiprocessing import Process
import pytorch_lightning as pl

class CESMLighting(pl.LightningModule):
    """LightningModule wrapping a sequence-to-sequence network trained with a masked RMSE.

    Cells where the target is exactly zero are excluded from the loss (presumably land /
    missing values -- confirm against the preprocessing).
    """

    def __init__(self, hparams=None, model=None, X=None, y=None, lr=0.0003, batch_size=16):
        """
        Parameters
        ----------
        hparams: unused
            Accepted for compatibility; not stored.
        model: nn.Module
            Network called as ``model(x, future_seq=...)``.
        X, y: array-like
            Inputs shaped (n, t, h, w, c) and targets shaped (n, t, h, w, 1), as built in the
            data-preparation cell; both dataloaders serve them in order.
        lr: float
            Adam learning rate (defaults to the notebook's hyper-parameter value).
        batch_size: int
            Batch size of both dataloaders.
        """
        super(CESMLighting, self).__init__()
        # default config
        self.normalize = False
        self.model = model

        # logging config
        self.log_images = True

        # Training config
        self.batch_size = batch_size
        self.lr = lr
        self.epochs = 100  # NOTE(review): not read here; the Trainer's max_epochs governs training
        self.n_steps_past = 1
        self.n_steps_ahead = 1
        self.X = X
        self.y = y

    def customRmse(self, y_true, y_pred):
        """
        Root-mean-square error restricted to cells where ``y_true`` is non-zero.

        ``y_true`` is moved to ``y_pred``'s device and dtype, so this works on CPU and GPU.
        If no cell is non-zero the mean is taken over an empty tensor and the result is NaN.
        """
        ocean = y_true != 0
        y_pred = y_pred[ocean.to(y_pred.device)]
        y_true = y_true[ocean].to(device=y_pred.device, dtype=y_pred.dtype)
        return torch.sqrt(torch.mean(torch.square(y_pred - y_true)))

    def forward(self, x):
        """Predict ``n_steps_ahead`` decoder steps; ``x`` is moved to the model's device."""
        device = next(self.model.parameters()).device
        return self.model(x.to(device), future_seq=self.n_steps_ahead)

    def _batch_loss(self, batch):
        """Masked RMSE of the model on one (x, y) batch from the dataloaders."""
        x, y = batch
        # Dataloader yields (b, t, h, w, c); the ConvLSTM expects (b, t, c, h, w).
        x = x.permute(0, 1, 4, 2, 3)
        # squeeze() drops the singleton decoder-time axis of the prediction and the singleton
        # channel axis of the target (and the batch axis as well when b == 1).
        y_hat = self.forward(x).squeeze()
        y = y.squeeze().to(y_hat.device)
        return self.customRmse(y, y_hat)

    def training_step(self, batch, batch_idx):
        loss = self._batch_loss(batch)

        # save learning_rate
        lr_saved = self.trainer.optimizers[0].param_groups[-1]['lr']
        lr_saved = torch.scalar_tensor(lr_saved, device=loss.device)

        tensorboard_logs = {'train_rmse_loss': loss,
                            'learning_rate': lr_saved}
        if self.global_step % 10 == 0:
            print("Current epoch {} loss: {}".format(self.current_epoch, loss.item()))

        return {'loss': loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        # Same masked RMSE as training (the original referenced an undefined self.criterion).
        return {'test_loss': self._batch_loss(batch)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def _make_loader(self):
        """Unshuffled DataLoader over (X, y) converted to float tensors."""
        dataset = TensorDataset(torch.Tensor(self.X), torch.Tensor(self.y))
        # NOTE(review): shuffle=False also for training; consider shuffling.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def train_dataloader(self):
        return self._make_loader()

    def test_dataloader(self):
        # NOTE(review): evaluates on the training data.
        return self._make_loader()




In [8]:
# Release cached GPU memory left over from earlier runs before building a fresh model.
torch.cuda.empty_cache()

conv_lstm_model = EncoderDecoderConvLSTM(nf=32, in_chan=6)
model = CESMLighting(model=conv_lstm_model, X=X, y=y)

# NOTE(review): the Trainer is not given a GPU (its log reports "used: False"); the module
# is moved to the GPU manually here instead.
if torch.cuda.is_available():
    model.cuda()

trainer = Trainer(
    max_epochs=35,
    checkpoint_callback=checkpoint_callback,
)
trainer.fit(model)


GPU available: True, used: False
TPU available: False, using: 0 TPU cores
Set SLURM handle signals.

  | Name  | Type                   | Params
-------------------------------------------------
0 | model | EncoderDecoderConvLSTM | 266 K 


Training: 0it [00:00, ?it/s]

Current epoch 0 loss: 77.02783966064453
Current epoch 0 loss: 79.63282775878906
Current epoch 0 loss: 79.74320983886719
Current epoch 0 loss: 57.55128860473633
Current epoch 0 loss: 24.480430603027344
Current epoch 0 loss: 16.296785354614258



Epoch 00000: loss reached 15.91725 (best 15.91725), saving model to ../../models/epoch=0.ckpt as top 2


Current epoch 1 loss: 19.806032180786133
Current epoch 1 loss: 13.998576164245605
Current epoch 1 loss: 12.664587020874023
Current epoch 1 loss: 12.021626472473145
Current epoch 1 loss: 11.473539352416992



Epoch 00001: loss reached 12.49705 (best 12.49705), saving model to ../../models/epoch=1.ckpt as top 2


Current epoch 2 loss: 21.973155975341797
Current epoch 2 loss: 12.41552734375
Current epoch 2 loss: 11.047698020935059
Current epoch 2 loss: 12.335343360900879
Current epoch 2 loss: 11.971892356872559



Epoch 00002: loss reached 11.06072 (best 11.06072), saving model to ../../models/epoch=2.ckpt as top 2


Current epoch 2 loss: 11.060720443725586
Current epoch 3 loss: 12.800847053527832
Current epoch 3 loss: 11.979772567749023
Current epoch 3 loss: 11.38894271850586
Current epoch 3 loss: 10.439363479614258
Current epoch 3 loss: 12.33543586730957



Epoch 00003: loss reached 10.50767 (best 10.50767), saving model to ../../models/epoch=3.ckpt as top 2


Current epoch 4 loss: 15.611891746520996
Current epoch 4 loss: 12.915993690490723
Current epoch 4 loss: 11.577649116516113
Current epoch 4 loss: 9.777384757995605
Current epoch 4 loss: 10.01101016998291



Epoch 00004: loss reached 9.91624 (best 9.91624), saving model to ../../models/epoch=4.ckpt as top 2


Current epoch 5 loss: 13.112822532653809
Current epoch 5 loss: 10.136364936828613
Current epoch 5 loss: 8.221946716308594
Current epoch 5 loss: 9.056938171386719
Current epoch 5 loss: 8.926743507385254
Current epoch 5 loss: 9.070084571838379



Epoch 00005: loss reached 9.24798 (best 9.24798), saving model to ../../models/epoch=5.ckpt as top 2


Current epoch 6 loss: 8.369848251342773
Current epoch 6 loss: 9.246298789978027
Current epoch 6 loss: 9.577816009521484
Current epoch 6 loss: 8.127857208251953
Current epoch 6 loss: 8.659805297851562



Epoch 00006: loss reached 8.45080 (best 8.45080), saving model to ../../models/epoch=6.ckpt as top 2


Current epoch 7 loss: 10.993430137634277
Current epoch 7 loss: 10.898250579833984
Current epoch 7 loss: 10.066965103149414
Current epoch 7 loss: 7.9469804763793945
Current epoch 7 loss: 9.129315376281738



Epoch 00007: loss reached 8.63018 (best 8.45080), saving model to ../../models/epoch=7.ckpt as top 2


Current epoch 7 loss: 8.630182266235352
Current epoch 8 loss: 7.314732074737549
Current epoch 8 loss: 7.682033538818359
Current epoch 8 loss: 7.190796852111816
Current epoch 8 loss: 7.248157024383545
Current epoch 8 loss: 7.866811275482178



Epoch 00008: loss reached 7.71010 (best 7.71010), saving model to ../../models/epoch=8.ckpt as top 2


Current epoch 9 loss: 7.413439750671387
Current epoch 9 loss: 6.698433876037598
Current epoch 9 loss: 7.198547840118408
Current epoch 9 loss: 7.007115364074707
Current epoch 9 loss: 6.971301078796387



Epoch 00009: loss reached 7.65654 (best 7.65654), saving model to ../../models/epoch=9.ckpt as top 2


Current epoch 10 loss: 7.354216575622559
Current epoch 10 loss: 6.863980770111084
Current epoch 10 loss: 6.6054582595825195
Current epoch 10 loss: 7.089925765991211
Current epoch 10 loss: 6.842605113983154
Current epoch 10 loss: 6.895432949066162



Epoch 00010: loss reached 7.39292 (best 7.39292), saving model to ../../models/epoch=10.ckpt as top 2


Current epoch 11 loss: 8.349703788757324
Current epoch 11 loss: 6.62127685546875
Current epoch 11 loss: 6.462490081787109
Current epoch 11 loss: 7.434760093688965
Current epoch 11 loss: 7.1342620849609375



Epoch 00011: loss  was not in top 2


Current epoch 12 loss: 6.7259907722473145
Current epoch 12 loss: 6.808413505554199
Current epoch 12 loss: 6.531578540802002
Current epoch 12 loss: 6.721225738525391
Current epoch 12 loss: 7.474700927734375



Epoch 00012: loss reached 7.21664 (best 7.21664), saving model to ../../models/epoch=12.ckpt as top 2


Current epoch 12 loss: 7.216640472412109
Current epoch 13 loss: 6.307109355926514
Current epoch 13 loss: 6.680541038513184
Current epoch 13 loss: 6.426871299743652
Current epoch 13 loss: 6.414636611938477
Current epoch 13 loss: 7.087423801422119



Epoch 00013: loss  was not in top 2


Current epoch 14 loss: 6.437928199768066
Current epoch 14 loss: 6.0503034591674805
Current epoch 14 loss: 6.535053730010986
Current epoch 14 loss: 6.315428256988525
Current epoch 14 loss: 6.3279547691345215



Epoch 00014: loss reached 6.86223 (best 6.86223), saving model to ../../models/epoch=14.ckpt as top 2


Current epoch 15 loss: 9.83947467803955
Current epoch 15 loss: 6.6800689697265625
Current epoch 15 loss: 6.407829761505127
Current epoch 15 loss: 6.543723106384277
Current epoch 15 loss: 6.655492305755615
Current epoch 15 loss: 6.479568958282471



Epoch 00015: loss  was not in top 2


Current epoch 16 loss: 6.495143413543701
Current epoch 16 loss: 6.256110191345215
Current epoch 16 loss: 6.121934413909912
Current epoch 16 loss: 6.617790699005127
Current epoch 16 loss: 6.258050441741943



Epoch 00016: loss reached 6.67916 (best 6.67916), saving model to ../../models/epoch=16.ckpt as top 2


Current epoch 17 loss: 5.936398506164551
Current epoch 17 loss: 6.232260227203369
Current epoch 17 loss: 5.890419960021973
Current epoch 17 loss: 5.833022594451904
Current epoch 17 loss: 6.564389705657959



Epoch 00017: loss reached 6.38574 (best 6.38574), saving model to ../../models/epoch=17.ckpt as top 2


Current epoch 17 loss: 6.385737895965576
Current epoch 18 loss: 5.38316535949707
Current epoch 18 loss: 5.833261966705322
Current epoch 18 loss: 5.439970016479492
Current epoch 18 loss: 5.968562126159668
Current epoch 18 loss: 6.652337551116943



Epoch 00018: loss  was not in top 2


Current epoch 19 loss: 7.481805324554443
Current epoch 19 loss: 5.355412006378174
Current epoch 19 loss: 5.7155351638793945
Current epoch 19 loss: 5.573824882507324
Current epoch 19 loss: 5.631882667541504



Epoch 00019: loss reached 5.59751 (best 5.59751), saving model to ../../models/epoch=19.ckpt as top 2


Current epoch 20 loss: 9.748916625976562
Current epoch 20 loss: 6.480207920074463
Current epoch 20 loss: 6.205139636993408
Current epoch 20 loss: 6.2918901443481445
Current epoch 20 loss: 5.404109001159668
Current epoch 20 loss: 5.244962215423584



Epoch 00020: loss reached 5.68052 (best 5.59751), saving model to ../../models/epoch=20.ckpt as top 2


Current epoch 21 loss: 8.617737770080566
Current epoch 21 loss: 5.186371326446533
Current epoch 21 loss: 4.982203960418701
Current epoch 21 loss: 5.077139854431152
Current epoch 21 loss: 5.630277633666992



Epoch 00021: loss reached 5.19990 (best 5.19990), saving model to ../../models/epoch=21.ckpt as top 2


Current epoch 22 loss: 4.783489227294922
Current epoch 22 loss: 6.319832801818848
Current epoch 22 loss: 5.917200088500977
Current epoch 22 loss: 5.520890235900879
Current epoch 22 loss: 5.173210620880127



Epoch 00022: loss  was not in top 2


Current epoch 22 loss: 5.684325695037842
Current epoch 23 loss: 9.59655475616455
Current epoch 23 loss: 5.119865417480469
Current epoch 23 loss: 7.115817070007324
Current epoch 23 loss: 6.105560779571533
Current epoch 23 loss: 5.560757637023926



Epoch 00023: loss reached 5.51364 (best 5.19990), saving model to ../../models/epoch=23.ckpt as top 2


Current epoch 24 loss: 5.084414482116699
Current epoch 24 loss: 5.826329708099365
Current epoch 24 loss: 5.375465393066406
Current epoch 24 loss: 5.209454536437988
Current epoch 24 loss: 4.528156757354736



Epoch 00024: loss reached 4.98946 (best 4.98946), saving model to ../../models/epoch=24.ckpt as top 2


Current epoch 25 loss: 7.0077409744262695
Current epoch 25 loss: 5.300832748413086
Current epoch 25 loss: 4.884790897369385
Current epoch 25 loss: 4.76303243637085
Current epoch 25 loss: 4.73054838180542
Current epoch 25 loss: 4.7788004875183105



Epoch 00025: loss reached 5.01986 (best 4.98946), saving model to ../../models/epoch=25.ckpt as top 2


Current epoch 26 loss: 6.848832607269287
Current epoch 26 loss: 5.382123947143555
Current epoch 26 loss: 4.258955001831055
Current epoch 26 loss: 4.698692798614502
Current epoch 26 loss: 4.581749439239502



Epoch 00026: loss  was not in top 2


Current epoch 27 loss: 5.365386486053467
Current epoch 27 loss: 5.861317157745361
Current epoch 27 loss: 4.721887111663818
Current epoch 27 loss: 4.674215793609619
Current epoch 27 loss: 4.785220146179199



Epoch 00027: loss reached 4.81152 (best 4.81152), saving model to ../../models/epoch=27.ckpt as top 2


Current epoch 27 loss: 4.811515808105469
Current epoch 28 loss: 4.124114513397217
Current epoch 28 loss: 4.217257022857666
Current epoch 28 loss: 3.9831297397613525
Current epoch 28 loss: 4.225827693939209
Current epoch 28 loss: 4.642294883728027



Epoch 00028: loss reached 4.63099 (best 4.63099), saving model to ../../models/epoch=28_v0.ckpt as top 2


Current epoch 29 loss: 4.638250827789307
Current epoch 29 loss: 4.3092217445373535
Current epoch 29 loss: 4.108780384063721
Current epoch 29 loss: 4.1348772048950195
Current epoch 29 loss: 4.0517754554748535



Epoch 00029: loss reached 4.47527 (best 4.47527), saving model to ../../models/epoch=29_v0.ckpt as top 2


Current epoch 30 loss: 6.928339004516602
Current epoch 30 loss: 5.3869171142578125
Current epoch 30 loss: 4.092748165130615
Current epoch 30 loss: 4.67457389831543
Current epoch 30 loss: 4.5343804359436035
Current epoch 30 loss: 4.555964469909668



Epoch 00030: loss reached 4.60383 (best 4.47527), saving model to ../../models/epoch=30.ckpt as top 2


Current epoch 31 loss: 6.018604278564453
Current epoch 31 loss: 4.7108988761901855
Current epoch 31 loss: 4.130916118621826
Current epoch 31 loss: 4.5749711990356445
Current epoch 31 loss: 3.9530558586120605



Epoch 00031: loss reached 4.50015 (best 4.47527), saving model to ../../models/epoch=31_v0.ckpt as top 2


Current epoch 32 loss: 4.54709005355835
Current epoch 32 loss: 5.141644477844238
Current epoch 32 loss: 4.771138668060303
Current epoch 32 loss: 4.258011817932129
Current epoch 32 loss: 4.168031215667725



Epoch 00032: loss reached 4.31644 (best 4.31644), saving model to ../../models/epoch=32_v0.ckpt as top 2


Current epoch 32 loss: 4.316435813903809
Current epoch 33 loss: 3.9482688903808594
Current epoch 33 loss: 3.87135910987854
Current epoch 33 loss: 3.902240514755249
Current epoch 33 loss: 4.003685474395752
Current epoch 33 loss: 4.30327033996582



Epoch 00033: loss reached 4.31168 (best 4.31168), saving model to ../../models/epoch=33.ckpt as top 2


Current epoch 34 loss: 4.266963481903076
Current epoch 34 loss: 4.3668365478515625
Current epoch 34 loss: 4.030972003936768
Current epoch 34 loss: 4.083408832550049
Current epoch 34 loss: 3.8343653678894043



Epoch 00034: loss  was not in top 2


AttributeError: 'ellipsis' object has no attribute '__module__'

In [11]:
# Persist the trained weights. The state_dict is taken from the LightningModule wrapper,
# so its keys are prefixed with "model.".
PATH = '../../models/pytorch_convlstm'
trained_state = model.state_dict()
torch.save(trained_state, PATH)

In [12]:
# Rebuild the same architecture and restore the weights saved above.
conv_lstm_model = EncoderDecoderConvLSTM(nf=32, in_chan=6)
model = CESMLighting(model=conv_lstm_model, X=X, y=y)
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine; the model is
# moved to the GPU afterwards when one is available. torch.load unpickles: trusted files only.
model.load_state_dict(torch.load(PATH, map_location='cpu'))
if torch.cuda.is_available():
    model.cuda()

model.eval()

CESMLighting(
  (model): EncoderDecoderConvLSTM(
    (encoder_1_convlstm): ConvLSTMCell(
      (conv): Conv2d(38, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (encoder_2_convlstm): ConvLSTMCell(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (decoder_1_convlstm): ConvLSTMCell(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (decoder_2_convlstm): ConvLSTMCell(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (decoder_CNN): Conv3d(32, 3, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
  )
)

In [56]:
# Predict every window of X with the trained model, one sample at a time.
# Inputs are moved to the model's device per batch instead of copying the whole array to
# the GPU up front, and no_grad() avoids building an autograd graph during inference.
device = next(model.parameters()).device
data_loader = DataLoader(TensorDataset(torch.Tensor(X)), batch_size=1, shuffle=False)

y_pred = np.zeros(shape=y.shape)  # same layout as the targets
start = 0
with torch.no_grad():
    for (s,) in data_loader:
        s = s.to(device).permute(0, 1, 4, 2, 3)              # (b, t, h, w, c) -> (b, t, c, h, w)
        yp = model(s).permute(0, 1, 3, 4, 2).cpu().numpy()   # (b, 3, 1, h, w) -> (b, 3, h, w, 1)
        y_pred[start:start + yp.shape[0]] = yp
        start += yp.shape[0]

In [60]:
# Zero the prediction wherever the target is exactly zero (the cells ignored by the loss),
# then keep a handle on these masked, still-scaled predictions for the inverse transform.
zero_target = (y == 0)
y_pred[zero_target] = 0.0
pred = y_pred

In [61]:
# Reload the raw fields; only the pCO2 dataset is needed, presumably as the reference for
# undoing the scaling applied during preprocessing.
chl, mld, sss, sst, u10, xco2, icefrac, patm, pco2t2 = read_xarray(dir_name)
pco2_reference = pco2t2.pCO2.data
y_true, y_pred = inverse_scale_frame(pred, pco2_reference)

In [62]:
# Ground truth at SOCAT locations, windowed exactly like the targets (X_index[1:]).
# NaNs become 0, and the prediction is zeroed wherever that SOCAT truth is 0.
y_pred_socat = np.copy(y_pred)
socat_windows = np.expand_dims(pco2t2.pCO2_socat.data[X_index][1:], axis=4)
y_true_socat = np.nan_to_num(socat_windows)
y_pred_socat[y_true_socat == 0] = 0.0

In [64]:
# Both scores use only the first predicted frame of each window (index 0 on axis 1).
first_frame = (slice(None), slice(None, 1))

print("Full RMSE score:")
a = custom_rmse(y_pred[first_frame], y_true[first_frame])
print(a.numpy())

print("SOCAT RMSE score:")
b = custom_rmse(y_pred_socat[first_frame], y_true_socat[first_frame])
print(b.numpy())

Full RMSE score:
8.116834158885355
SOCAT RMSE score:
9.008042187833174
