# Spatiotemporal Forecasting Models with Moving MNIST
Moving MNIST is a video prediction task consisting of short, 20 frame videos of moving numbers. The task is to use the first 10 frames to predict the next 10 frames. 

Models currently implemented are:

 * ConvLSTM (neural network code has been taken from <a href="https://github.com/jhhuang96">jhhuang96's</a> ConvLSTM <a href="https://github.com/jhhuang96/ConvLSTM-PyTorch">GitHub repository</a>)
 
Models to be implemented:

 * Physical Dynamics Net (https://arxiv.org/abs/2003.01460)
 * Scalable Gradients for Stochastic Differential Equations (https://arxiv.org/abs/2001.01328)
 * ODE2VAE: Deep generative second order ODEs with Bayesian neural networks (https://arxiv.org/abs/1905.10994)
 * Disentangling Multiple Features in Video Sequences using Gaussian Processes in Variational Autoencoders (https://arxiv.org/abs/2001.02408)

In [1]:
import os
#from data.mm import MovingMNIST
import sys
from tqdm.notebook import tqdm
import numpy as np

import torch
from torch import nn
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import logging
from ConvRNN.generate_model import *
from earlystopping import EarlyStopping

In [2]:
class SpatiotemporalDataset(Dataset):

    def __init__(self, X, forecast_horizon, lags):
        self.X = X
        self.forecast_horizon = forecast_horizon
        self.lags = lags

    def __len__(self):
        length = len(self.X) - self.forecast_horizon - self.lags + 1
        return length

    def __getitem__(self, idx):
        x = self.X[idx:idx+self.lags]
        y = self.X[idx+self.lags:idx+self.lags+self.forecast_horizon]

        return idx, x, y, 0, 0

In [3]:
N_SAMPLES = 1000
N_CHANNELS = 1
IM_HEIGHT = 64
IM_WIDTH = 64
FORECAST_HORIZON = 10
LAGS = 10

X = np.random.normal(0, 1, (N_SAMPLES, N_CHANNELS, IM_HEIGHT, IM_WIDTH))

##### Set hyperparameters and random seeds

In [4]:
TIMESTAMP = "2020-03-09T00-00-00"

batch_size = 4
lr = 1e-4
frames_input = 10
frames_output = 10
epochs = 500

random_seed = 1996
np.random.seed(random_seed)
torch.manual_seed(random_seed)

if torch.cuda.device_count() > 1:
    torch.cuda.manual_seed_all(random_seed)
else:
    torch.cuda.manual_seed(random_seed)
    
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

save_dir = './save_model/' + TIMESTAMP

##### Create DataLoaders
These will be used to feed batches of data to the neural network.

In [5]:
trainDataset = SpatiotemporalDataset(X, FORECAST_HORIZON, LAGS)
validDataset = SpatiotemporalDataset(X, FORECAST_HORIZON, LAGS)

trainLoader = torch.utils.data.DataLoader(trainDataset,
                                          batch_size=batch_size,
                                          shuffle=False)
validLoader = torch.utils.data.DataLoader(validDataset,
                                          batch_size=batch_size,
                                          shuffle=False)

##### Get examples of the contents of the DataLoaders

In [6]:
(idx, targetVar, inputVar, _, _) = next(iter(trainLoader))

print("Input shape: {}".format(np.shape(inputVar)))
print("Target shape: {}".format(np.shape(targetVar)))

Input shape: torch.Size([4, 10, 1, 64, 64])
Target shape: torch.Size([4, 10, 1, 64, 64])


##### Create the ConvLSTM

In [7]:
net = generate_convlstm().float()

Using cuda:0


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using {}".format(device))

##### Define the training and validation function

In [8]:
def train(net, device):
    '''
    main function to run the training
    '''
    run_dir = './runs/' + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    tb = SummaryWriter(run_dir)
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(os.path.join(save_dir, 'checkpoint.pth.tar')):
        # load existing model
        print('==> loading existing model')
        model_info = torch.load(os.path.join(save_dir, 'checkpoin.pth.tar'))
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
    lossfunction = nn.MSELoss().cuda()
    optimizer = optim.Adam(net.parameters(), lr=lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf
    for epoch in range(cur_epoch, epochs + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device).float()  # B,S,C,H,W
            label = targetVar.to(device).float()  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', loss_aver, epoch)
        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                if i == 3000:
                    break
                inputs = inputVar.to(device).float()
                label = targetVar.to(device).float()
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                #print ("validloss: {:.6f},  epoch : {:02d}".format(loss_aver,epoch),end = '\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')

        print(print_msg)
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)

##### Train!

In [9]:
train(net, device)

HBox(children=(FloatProgress(value=0.0, max=246.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=246.0), HTML(value='')))

[  0/500] train_loss: 0.250081 valid_loss: 0.250058
Validation loss decreased (inf --> 0.250058).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=246.0), HTML(value='')))

KeyboardInterrupt: 

##### Analyse training metrics
To view training loss in Tensorboard, enter ```tensorboard --logdir=<path_to_runs_dir>``` into the command line and navigate to ```localhost:6066/``` in your browser.