In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import numpy as np

import io
import imageio
from ipywidgets import widgets, HBox

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Original ConvLSTM cell as proposed by Shi et al.
class ConvLSTMCell(nn.Module):

    def __init__(self, in_channels, out_channels, 
    kernel_size, padding, activation, frame_size):

        super(ConvLSTMCell, self).__init__()  

        if activation == "tanh":
            self.activation = torch.tanh 
        elif activation == "relu":
            self.activation = torch.relu
        
        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels, 
            out_channels=4 * out_channels, 
            kernel_size=kernel_size, 
            padding=padding)           

        # Initialize weights for Hadamard Products
        self.W_ci = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.Tensor(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev )
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev )

        # Current Cell output
        C = forget_gate*C_prev + input_gate * self.activation(C_conv)

        output_gate = torch.sigmoid(o_conv + self.W_co * C )

        # Current Hidden State
        H = output_gate * self.activation(C)

        return H, C

In [3]:
class ConvLSTM(nn.Module):

    def __init__(self, in_channels, out_channels, 
    kernel_size, padding, activation, frame_size):

        super(ConvLSTM, self).__init__()

        self.out_channels = out_channels

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels, 
        kernel_size, padding, activation, frame_size)

    def forward(self, X):

        # X is a frame sequence (batch_size, num_channels, seq_len, height, width)

        # Get the dimensions
        batch_size, _, seq_len, height, width = X.size()

        # Initialize output
        output = torch.zeros(batch_size, self.out_channels, seq_len, 
        height, width, device=device)
        
        # Initialize Hidden State
        H = torch.zeros(batch_size, self.out_channels, 
        height, width, device=device)

        # Initialize Cell Input
        C = torch.zeros(batch_size,self.out_channels, 
        height, width, device=device)

        # Unroll over time steps
        for time_step in range(seq_len):

            H, C = self.convLSTMcell(X[:,:,time_step], H, C)

            output[:,:,time_step] = H

        return output

In [4]:
class Seq2Seq(nn.Module):

    def __init__(self, num_channels, num_kernels, kernel_size, padding, 
    activation, frame_size, num_layers):

        super(Seq2Seq, self).__init__()

        self.sequential = nn.Sequential()

        # Add First layer (Different in_channels than the rest)
        self.sequential.add_module(
            "convlstm1", ConvLSTM(
                in_channels=num_channels, out_channels=num_kernels,
                kernel_size=kernel_size, padding=padding, 
                activation=activation, frame_size=frame_size)
        )

        self.sequential.add_module(
            "batchnorm1", nn.BatchNorm3d(num_features=num_kernels)
        ) 

        # Add rest of the layers
        for l in range(2, num_layers+1):

            self.sequential.add_module(
                f"convlstm{l}", ConvLSTM(
                    in_channels=num_kernels, out_channels=num_kernels,
                    kernel_size=kernel_size, padding=padding, 
                    activation=activation, frame_size=frame_size)
                )
                
            self.sequential.add_module(
                f"batchnorm{l}", nn.BatchNorm3d(num_features=num_kernels)
                ) 

        # Add Convolutional Layer to predict output frame
        self.conv = nn.Conv2d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=kernel_size, padding=padding)

    def forward(self, X):

        # Forward propagation through all the layers
        output = self.sequential(X)

        # Return only the last output frame
        output = self.conv(output[:,:,-1])
        
        return nn.Sigmoid()(output)

In [5]:
# Load Data as Numpy Array
# MovingMNIST = np.load('/kaggle/input/mnist-moving/mnist_test_seq.npy')
MovingMNIST = np.load('/kaggle/input/our-dummy-dataset/merged.npy')
print(MovingMNIST.shape)

# MovingMNIST = MovingMNIST.transpose(1, 0, 2, 3)
# print(MovingMNIST.shape)

# Shuffle Data
np.random.shuffle(MovingMNIST)

# Train, Test, Validation splits
# train_data = MovingMNIST[:6000]         
# val_data = MovingMNIST[8000:9000]       
# test_data = MovingMNIST[9000:10000]

train_data = MovingMNIST[:15]         
val_data = MovingMNIST[15:17]       
test_data = MovingMNIST[17:20]     

def collate(batch):

    # Add channel dim, scale pixels between 0 and 1, send to GPU
    batch = torch.tensor(batch)
#     batch = batch / 255.0                        
    batch = batch.to(device)   

    # Randomly pick 10 frames as input, 11th frame is target
    rand = np.random.randint(10,15)                     
    return batch[:,:,rand-10:rand], batch[:,:,rand]     


# Training Data Loader
train_loader = DataLoader(train_data, shuffle=True, 
                        batch_size=2, collate_fn=collate)

# Validation Data Loader
val_loader = DataLoader(val_data, shuffle=True, 
                        batch_size=2, collate_fn=collate)

(20, 2, 18, 64, 64)


In [6]:
# # Get a batch
# input, _ = next(iter(val_loader))
# print(input.shape)

# # Reverse process before displaying
# input = input.cpu().numpy() * 255.0     

# for video in input.squeeze(1)[:3]:          # Loop over videos
#     with io.BytesIO() as gif:
#         imageio.mimsave(gif,video.astype(np.uint8),"GIF",fps=5)
#         display(HBox([widgets.Image(value=gif.getvalue())]))

In [7]:
# Get a batch
input, _ = next(iter(val_loader))
print(input.shape)
print(_.shape)

# Reverse process before displaying
input = input.cpu().numpy() * 255.0

input = np.transpose(input, (0, 2, 3, 4, 1))

for video in input[:2]:  # Loop over first 3 videos
    frames = []
    for frame in video:
        # Take only the first channel
        frame = frame[:,:,0]
        frames.append(frame.astype(np.uint8))
    
    with io.BytesIO() as gif:
        imageio.mimsave(gif, frames, format="GIF", fps=1)
        display(HBox([widgets.Image(value=gif.getvalue())]))

  batch = torch.tensor(batch)


torch.Size([2, 2, 10, 64, 64])
torch.Size([2, 2, 64, 64])


HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x00\x00\x00%%%&&&\'\'\'((()))***+++,,,---...///00011…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x10\x10\x10\x11\x11\x11\x12\x12\x12\x13\x13\x13\x15\…

In [8]:
# The input video frames are grayscale, thus single channel
model = Seq2Seq(num_channels=2, num_kernels=64, 
kernel_size=(3, 3), padding=(1, 1), activation="relu", 
frame_size=(64, 64), num_layers=6).to(device)

optim = Adam(model.parameters(), lr=1e-4)

# Binary Cross Entropy, target pixel values either 0 or 1
criterion = nn.BCEWithLogitsLoss(reduction='sum')

In [9]:
from tqdm import tqdm

min_val_loss = 100000

num_epochs = 10
for epoch in tqdm(range(1, num_epochs+1), desc="Epochs"):
    
    train_loss = 0
    model.train()
    for batch_num, (input, target) in enumerate(train_loader, 1):
        output = model(input)
        loss = criterion(output.flatten(), target.flatten())
        loss.backward()
        optim.step()
        optim.zero_grad()
        train_loss += loss.item()
    train_loss /= len(train_loader.dataset)
    
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for input, target in val_loader:
            output = model(input)
            loss = criterion(output.flatten(), target.flatten())
            val_loss += loss.item()
    val_loss /= len(val_loader.dataset)
    
    # save the model
    if(val_loss < min_val_loss):
        min_val_loss = val_loss
        torch.save(model.state_dict(), "/kaggle/working/model.pth")
    
    print(f"Epoch:{epoch} Training Loss:{train_loss:.2f} Validation Loss:{val_loss:.2f}\n")

Epochs:  10%|█         | 1/10 [00:03<00:35,  3.92s/it]

Epoch:1 Training Loss:6334.54 Validation Loss:6885.02



Epochs:  20%|██        | 2/10 [00:06<00:24,  3.11s/it]

Epoch:2 Training Loss:5791.56 Validation Loss:6881.19



Epochs:  30%|███       | 3/10 [00:08<00:19,  2.85s/it]

Epoch:3 Training Loss:5719.89 Validation Loss:6786.67



Epochs:  40%|████      | 4/10 [00:11<00:16,  2.75s/it]

Epoch:4 Training Loss:5697.52 Validation Loss:6281.96



Epochs:  50%|█████     | 5/10 [00:14<00:13,  2.70s/it]

Epoch:5 Training Loss:5687.28 Validation Loss:5867.89



Epochs:  60%|██████    | 6/10 [00:16<00:10,  2.64s/it]

Epoch:6 Training Loss:5683.39 Validation Loss:5779.40



Epochs:  70%|███████   | 7/10 [00:19<00:07,  2.59s/it]

Epoch:7 Training Loss:5682.21 Validation Loss:5780.77



Epochs:  80%|████████  | 8/10 [00:21<00:05,  2.56s/it]

Epoch:8 Training Loss:5681.52 Validation Loss:5783.33



Epochs:  90%|█████████ | 9/10 [00:24<00:02,  2.54s/it]

Epoch:9 Training Loss:5681.10 Validation Loss:5846.59



Epochs: 100%|██████████| 10/10 [00:26<00:00,  2.67s/it]

Epoch:10 Training Loss:5682.50 Validation Loss:5822.47






In [10]:
def collate_test(batch):

    # Last 10 frames are target
    target = np.array(batch)[:,:,10:]       
    
    # Add channel dim, scale pixels between 0 and 1, send to GPU
#     batch = torch.tensor(batch).unsqueeze(1)          
    batch = torch.tensor(batch) 
#     batch = batch / 255.0                             
    batch = batch.to(device)                          
    return batch, target

# Test Data Loader
test_loader = DataLoader(test_data,shuffle=True, 
                         batch_size=2, collate_fn=collate_test)

# Get a batch
batch, target = next(iter(test_loader))
print(batch.shape)
print(target.shape)

# Initialize output sequence
output = np.zeros(target.shape, dtype=np.uint8)

# Loop over timesteps
with torch.no_grad():
    for timestep in range(target.shape[1]):
      input = batch[:,:,timestep:timestep+10]   
      output[:,:,timestep] = (model(input).detach().cpu() * 255.0).numpy()


torch.Size([2, 2, 18, 64, 64])
(2, 2, 8, 64, 64)


In [11]:
target = np.transpose(target, (0, 2, 3, 4, 1))
output = np.transpose(output, (0, 2, 3, 4, 1))

for tgt, out in zip(target, output):       # Loop over samples
    # Write target video as gif
    
    tgt = tgt * 255.0
    
    frames = []
    for frame in tgt:
        # Take only the first channel
        frame = frame[:,:,0]
        frames.append(frame.astype(np.uint8))
        
    with io.BytesIO() as gif:
        imageio.mimsave(gif, frames, "GIF", fps = 1)    
        target_gif = gif.getvalue()

    # Write output video as gif
    frames = []
    for frame in out:
        # Take only the first channel
        frame = frame[:,:,0]
        frames.append(frame.astype(np.uint8))
        
    with io.BytesIO() as gif:
        imageio.mimsave(gif, frames, "GIF", fps = 1)    
        output_gif = gif.getvalue()

    display(HBox([widgets.Image(value=target_gif), 
                  widgets.Image(value=output_gif)]))

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x02\x02\x02\x03\x03\x03\x04\x04\x04\x05\x05\x05\x06\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\x0f\x0f\x0f\x11\x11\x11\x12\x12\x12\x13\x13\x13\x14\…