# Advanced ConvLSTM
### Overview
Larger, more complex Convolutional Long Short-Term Memory Neural Network (ConvLSTM) using convolutions inside LSTM cells for more effective spatiotemporal forecasting. This architecture is based on the ConvLSTM from <a href="https://arxiv.org/abs/1506.04214">*Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting* by Shi et al. (2015)</a>.

*See <a href="https://paperswithcode.com/method/convlstm">here</a> for latest papers using ConvLSTMs.*

### Dependencies
PyTorch 1.5.0 with CUDA 10.1

In [1]:
import os
import sys
from tqdm.notebook import tqdm
import numpy as np
from datetime import datetime

import torch
from torch import nn
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch.utils.data import DataLoader

import logging
from AdvancedConvRNN.generate_model import *
from dataset_utils import GenericDataset, SpatiotemporalDataset

In [2]:
N_SAMPLES = 10000
N_CHANNELS = 1
IM_HEIGHT = 64
IM_WIDTH = 64
FORECAST_HORIZON = 10
LAGS = 10
MNIST = True

if MNIST:
    
    X = np.load("data/mnist_test_seq.npy").swapaxes(0,1)
    X = np.expand_dims(X, 2)
    trainDataset = GenericDataset(X[0:700], FORECAST_HORIZON, LAGS)
    validDataset = GenericDataset(X[700:], FORECAST_HORIZON, LAGS)
    
else:
    
    X = np.random.normal(0, 1, (N_SAMPLES, N_CHANNELS, IM_HEIGHT, IM_WIDTH))
    trainDataset = SpatiotemporalDataset(X, FORECAST_HORIZON, LAGS)
    validDataset = SpatiotemporalDataset(X, FORECAST_HORIZON, LAGS)

##### Set hyperparameters and random seeds

In [10]:
TIMESTAMP = str(datetime.now())

batch_size = 1
lr = 1e-4
frames_input = 10
frames_output = 10
epochs = 500

random_seed = 1996
np.random.seed(random_seed)
torch.manual_seed(random_seed)

if torch.cuda.device_count() > 1:
    torch.cuda.manual_seed_all(random_seed)
else:
    torch.cuda.manual_seed(random_seed)
    
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

save_dir = './save_model/' + TIMESTAMP

##### Create DataLoaders
These will be used to feed batches of data to the neural network.

In [11]:
trainLoader = torch.utils.data.DataLoader(trainDataset,
                                          batch_size=batch_size,
                                          shuffle=False)
validLoader = torch.utils.data.DataLoader(validDataset,
                                          batch_size=batch_size,
                                          shuffle=False)

##### Get examples of the contents of the DataLoaders

In [12]:
(idx, targetVar, inputVar) = next(iter(trainLoader))

print("Input shape: {}".format(np.shape(inputVar)))
print("Target shape: {}".format(np.shape(targetVar)))

Input shape: torch.Size([1, 10, 1, 64, 64])
Target shape: torch.Size([1, 10, 1, 64, 64])


##### Create the ConvLSTM

In [13]:
net = generate_convlstm().float()

In [15]:
device = "cpu" #torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using {}".format(device))

Using cpu


##### Define the training and validation function

In [16]:
def train(net, device):
    '''
    main function to run the training
    '''

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    cur_epoch = 0
    lossfunction = nn.MSELoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf
    for epoch in range(cur_epoch, epochs + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar) in enumerate(t):
            inputs = inputVar.to(device).float()  # B,S,C,H,W
            label = targetVar.to(device).float()  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar) in enumerate(t):
                if i == 3000:
                    break
                inputs = inputVar.to(device).float()
                label = targetVar.to(device).float()
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                #print ("validloss: {:.6f},  epoch : {:02d}".format(loss_aver,epoch),end = '\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')

        print(print_msg)
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)

##### Train!

In [17]:
train(net, device)

HBox(children=(FloatProgress(value=0.0, max=700.0), HTML(value='')))

RuntimeError: All input tensors must be on the same device. Received cpu and cuda:0