# Basic ConvLSTM
### Overview
A basic introduction to the ConvLSTM layer implemented by <a href="https://github.com/ndrplz">ndrplz</a> in the <a href="https://github.com/ndrplz/ConvLSTM_pytorch">ConvLSTM_pytorch</a> repository.

### Dependencies
PyTorch 1.5.0 with CUDA 10.1

In [26]:
import os
import sys
from tqdm.notebook import tqdm
import numpy as np
from datetime import datetime

import torch
from torch import nn
from torch.optim import lr_scheduler
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from BasicConvLSTM.convlstm import *
from earlystopping import EarlyStopping
from dataset_utils import GenericDataset, SpatiotemporalDataset

# Release cached, currently-unused GPU memory held by PyTorch's allocator
# (e.g. left over from earlier runs in this kernel).
torch.cuda.empty_cache()

### Example Usage of ConvLSTM Layer
Parameters:

        input_dim: Number of channels in input
        
        hidden_dim: Number of hidden channels
        
        kernel_size: Size of kernel in convolutions
        
        num_layers: Number of LSTM layers stacked on each other
        
        batch_first: Whether dimension 0 of the input is the batch dimension (B, T, ...) rather than time (T, B, ...)
        
        bias: Whether the convolutions include a bias term
        
        return_all_layers: Whether to return the outputs and states of every layer, or only those of the last layer
        
        Note: Uses "same" padding, so the spatial size (H, W) is preserved.

    Input:
    
        A tensor of size B, T, C, H, W (batch_first=True) or T, B, C, H, W (batch_first=False)
        
    Output:
    
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
        
            0 - layer_output_list is the list of lists of length T of each output
            
            1 - last_state_list is the list of last states
            
                    each element of the list is a tuple (h, c) for hidden state and memory
                    
    Example:
    
        >> x = torch.rand((32, 10, 64, 128, 128))
        
        >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        
        >> _, last_states = convlstm(x)
        
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index

In [27]:
# Toy batch laid out as (B, T, C, H, W): 2 sequences, 10 steps, 3 channels, 32x32.
x = torch.rand((2, 10, 3, 32, 32))

# Three stacked layers of 16 hidden channels each. With return_all_layers
# enabled, one output and one final state come back per layer.
convlstm = ConvLSTM(
    input_dim=3,
    hidden_dim=[16, 16, 16],
    kernel_size=(3, 3),
    num_layers=3,
    batch_first=True,
    bias=True,
    return_all_layers=True,
)
layer_output_list, last_state_list = convlstm(x)

n_outputs = len(layer_output_list)
n_states = len(last_state_list)
print("Length of layer output list: {}".format(n_outputs))
print("Length of last state list: {}".format(n_states))

Length of layer output list: 3
Length of last state list: 3


### Training on Moving MNIST with Basic ConvLSTM

In [28]:
N_SAMPLES = 10000
N_CHANNELS = 1
IM_HEIGHT = 64
IM_WIDTH = 64
FORECAST_HORIZON = 10
LAGS = 10
MNIST = True
# Moving MNIST: the first N_TRAIN_SEQUENCES sequences train, the rest validate.
# NOTE(review): this leaves far more validation than training data; consider
# a larger training split.
N_TRAIN_SEQUENCES = 700
# Synthetic data only: leading fraction of X (along axis 0) used for training.
TRAIN_FRACTION = 0.8
SEED = 42

batch_size = 4
lr = 1e-4
# NOTE(review): frames_input / frames_output are not referenced in the cells
# shown; FORECAST_HORIZON and LAGS are what the datasets receive.
frames_input = 10
frames_output = 10
epochs = 500

# Seed numpy (synthetic data) and torch (weight init, loader shuffling) so that
# runs are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Filesystem-safe timestamp: str(datetime.now()) contains spaces and colons,
# and colons are not valid in Windows paths.
TIMESTAMP = datetime.now().strftime("%Y%m%d-%H%M%S")

save_dir = os.path.join('./save_model', TIMESTAMP)

if MNIST:
    # NOTE(review): assumes the file is stored time-major (T, N, H, W).
    # swapaxes makes it (N, T, H, W) and expand_dims adds the channel axis,
    # giving (N, T, 1, H, W).
    X = np.load("data/mnist_test_seq.npy").swapaxes(0, 1)
    X = np.expand_dims(X, 2)
    trainDataset = GenericDataset(X[:N_TRAIN_SEQUENCES], FORECAST_HORIZON, LAGS)
    validDataset = GenericDataset(X[N_TRAIN_SEQUENCES:], FORECAST_HORIZON, LAGS)

else:
    # Split along axis 0 so validation data are disjoint from training data.
    # Previously both datasets wrapped the very same X, so the validation loss
    # measured training fit.
    X = np.random.normal(0, 1, (N_SAMPLES, N_CHANNELS, IM_HEIGHT, IM_WIDTH))
    n_train = int(TRAIN_FRACTION * len(X))
    trainDataset = SpatiotemporalDataset(X[:n_train], FORECAST_HORIZON, LAGS)
    validDataset = SpatiotemporalDataset(X[n_train:], FORECAST_HORIZON, LAGS)

In [29]:
# Reshuffle training samples every epoch so that mini-batch order (and thus the
# gradient sequence) is not identical across epochs. Validation order does not
# affect the averaged validation loss, so it stays fixed.
trainLoader = DataLoader(trainDataset,
                         batch_size=batch_size,
                         shuffle=True)
validLoader = DataLoader(validDataset,
                         batch_size=batch_size,
                         shuffle=False)

In [65]:
# One input channel, two stacked layers (64 hidden channels, then 1) so that the
# last layer emits one channel per frame. It is compared with the targets
# in `train`. Flags: batch_first=True, bias=True, return_all_layers=True.
convlstm = ConvLSTM(1, [64, 1], (3, 3), 2, True, True, True)

# Sanity-check the forward pass on one real training sample. The previous
# version used `inputs`, which only exists as a local inside `train`, so this
# cell raised NameError on a fresh kernel.
_, _, sample_inputs = next(iter(trainLoader))  # batches unpack as (idx, target, input)
with torch.no_grad():
    layer_output_list, last_state_list = convlstm(sample_inputs[:1].float())

In [66]:
# Use the first GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Using {}".format(device))

Using cuda:0


In [67]:
def _train_one_epoch(net, loader, device, lossfunction, optimizer, epoch, clip_value):
    '''Run one optimisation pass over `loader` and return the per-batch losses.'''
    net.train()  # set once per epoch rather than on every batch
    batch_losses = []
    progress = tqdm(loader, leave=False, total=len(loader))
    for _idx, targetVar, inputVar in progress:
        inputs = inputVar.to(device).float()  # B,S,C,H,W
        label = targetVar.to(device).float()  # B,S,C,H,W
        optimizer.zero_grad()
        pred, _ = net(inputs)
        pred = pred[-1]  # output of the last returned layer, B,S,C,H,W
        loss = lossfunction(pred, label)
        loss.backward()
        torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=clip_value)
        optimizer.step()
        # MSELoss already averages over every element (batch included), so the
        # scalar is reported as-is. Dividing by batch_size again under-reports
        # it and is wrong for a smaller final batch.
        batch_losses.append(loss.item())
        progress.set_postfix({
            'trainloss': '{:.6f}'.format(batch_losses[-1]),
            'epoch': '{:02d}'.format(epoch)
        })
    return batch_losses


def _validate_one_epoch(net, loader, device, lossfunction, epoch, max_batches):
    '''Evaluate `net` on `loader` without gradients and return the per-batch losses.

    At most `max_batches` batches are evaluated (all of them when None).
    '''
    net.eval()
    batch_losses = []
    with torch.no_grad():
        progress = tqdm(loader, leave=False, total=len(loader))
        for i, (_idx, targetVar, inputVar) in enumerate(progress):
            if max_batches is not None and i >= max_batches:
                break
            inputs = inputVar.to(device).float()
            label = targetVar.to(device).float()
            pred, _ = net(inputs)
            loss = lossfunction(pred[-1], label)
            batch_losses.append(loss.item())
            progress.set_postfix({
                'validloss': '{:.6f}'.format(batch_losses[-1]),
                'epoch': '{:02d}'.format(epoch)
            })
    return batch_losses


def _mean_or_nan(values):
    '''Mean of `values` as a float, or NaN when the list is empty.'''
    return float(np.mean(values)) if values else float('nan')


def train(net, device, train_loader=None, valid_loader=None, n_epochs=None,
          learning_rate=None, max_valid_batches=3000, clip_value=10.0):
    '''
    Main function to run the training: Adam + MSE with a ReduceLROnPlateau schedule.

    Parameters
    ----------
    net : nn.Module
        Model whose forward returns ``(layer_outputs, states)``. The last entry
        of ``layer_outputs`` is compared against the target batch.
    device : torch.device
        Device to train on. If more than one GPU is visible, the model is
        wrapped in ``nn.DataParallel``.
    train_loader, valid_loader : DataLoader, optional
        Default to the notebook globals ``trainLoader`` / ``validLoader``.
        Each batch must unpack as ``(idx, target, input)``.
    n_epochs : int, optional
        Number of epochs to run. Defaults to the global ``epochs``.
    learning_rate : float, optional
        Adam learning rate. Defaults to the global ``lr``.
    max_valid_batches : int or None
        Stop validation after this many batches (None evaluates all of them).
    clip_value : float
        Threshold for element-wise gradient clipping.

    Returns
    -------
    (avg_train_losses, avg_valid_losses) : tuple of two lists of float
        Mean per-batch loss for each epoch. These are also written, one value
        per line, to ``avg_train_losses.txt`` and ``avg_valid_losses.txt``.
    '''
    train_loader = trainLoader if train_loader is None else train_loader
    valid_loader = validLoader if valid_loader is None else valid_loader
    n_epochs = epochs if n_epochs is None else n_epochs
    learning_rate = lr if learning_rate is None else learning_rate

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # MSELoss has no parameters or buffers, so it needs no device placement.
    # The former `.cuda()` call crashed on CPU-only machines.
    lossfunction = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    avg_train_losses = []
    avg_valid_losses = []
    epoch_len = len(str(n_epochs))
    # Run exactly n_epochs passes, numbered 1..n_epochs. The former
    # range(0, epochs + 1) ran one extra epoch.
    for epoch in range(1, n_epochs + 1):
        train_losses = _train_one_epoch(net, train_loader, device, lossfunction,
                                        optimizer, epoch, clip_value)
        valid_losses = _validate_one_epoch(net, valid_loader, device, lossfunction,
                                           epoch, max_valid_batches)
        torch.cuda.empty_cache()

        train_loss = _mean_or_nan(train_losses)
        valid_loss = _mean_or_nan(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        print(f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] '
              f'train_loss: {train_loss:.6f} '
              f'valid_loss: {valid_loss:.6f}')

        # Only step the plateau scheduler on a real validation measurement.
        if valid_losses:
            pla_lr_scheduler.step(valid_loss)

    for path, values in (("avg_train_losses.txt", avg_train_losses),
                         ("avg_valid_losses.txt", avg_valid_losses)):
        with open(path, 'wt') as f:
            f.writelines(f'{value}\n' for value in values)

    return avg_train_losses, avg_valid_losses

In [68]:
# Train the model built above on the selected device. Per-epoch losses are
# printed and written to avg_train_losses.txt / avg_valid_losses.txt.
train(net=convlstm, device=device)

HBox(children=(FloatProgress(value=0.0, max=175.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=2325.0), HTML(value='')))

KeyboardInterrupt: 