In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from MovingMNIST import MovingMNIST

In [3]:

train_set = MovingMNIST(root='.data/mnist', train=True, download=True)
test_set = MovingMNIST(root='.data/mnist', train=False, download=True)

batch_size = 100

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)

Downloading https://github.com/tychovdo/MovingMNIST/raw/master/mnist_test_seq.npy.gz
Processing...
Done!


In [4]:
print('==>>> total trainning batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))

for seq, seq_target in train_loader:
    print('--- Sample')
    print('Input:  ', seq.shape)
    print('Target: ', seq_target.shape)
    break

==>>> total trainning batch number: 90
==>>> total testing batch number: 10
--- Sample
Input:   torch.Size([100, 10, 64, 64])
Target:  torch.Size([100, 10, 64, 64])


In [84]:

import io
import imageio
from ipywidgets import widgets, HBox

In [85]:
input, _ = next(iter(train_loader))

# Reverse process before displaying
input = input.cpu().numpy() * 255.0     

for video in input.squeeze(1)[:3]:          # Loop over videos
    with io.BytesIO() as gif:
        imageio.mimsave(gif,video.astype(np.uint8),"GIF",fps=5)
        display(HBox([widgets.Image(value=gif.getvalue())]))

ValueError: ignored

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [168]:
class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias,mode="zeros"):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.mode = mode

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)
        
        self.W_ci = nn.Parameter(torch.zeros(1, self.hidden_dim, 16, 16))
        self.W_cf = nn.Parameter(torch.zeros(1, self.hidden_dim,  16, 16))
        self.W_co = nn.Parameter(torch.zeros(1, self.hidden_dim,  16, 16))
        

    def forward(self, x, cur_state):
        h_cur, c_cur = cur_state
        x = x.to(device)
        h_cur = h_cur.to(device)
        # print(x.size())
        # print(h_cur.size())
        concat_input_hcur = torch.cat([x, h_cur], dim=1) 
        concat_input_hcur = concat_input_hcur.to(device)

        concat_input_hcur_conv = self.conv(concat_input_hcur)
        concat_input_hcur_conv = concat_input_hcur_conv.to(device)

        cc_input_gate, cc_forget_gate, cc_output_gate, cc_output = torch.split(concat_input_hcur_conv, self.hidden_dim, dim=1)
        # print("cci",cc_input_gate.shape)
        # print("ccf",cc_forget_gate.shape)
        # print("ccog",cc_output_gate.shape)
        # print("cco",cc_output.shape)
        # print("cccur",c_cur.shape)
        # print("wci",self.W_ci.shape)
        
        input_gate = torch.sigmoid(cc_input_gate + self.W_ci * c_cur)

        forget_gate = torch.sigmoid(cc_forget_gate + self.W_cf * c_cur)

        output = torch.tanh(cc_output)

        c_next = forget_gate * c_cur + input_gate * output

        output_gate = torch.sigmoid(cc_output_gate + self.W_co * c_next)

        h_next = output * torch.tanh(c_next)

        return h_next, c_next

    def init_state(self, batch_size, image_size):
        height, width = image_size
        """ Initializing hidden and cell state """
        if(self.mode == "zeros"):
            h = torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
            c = torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
        elif(self.mode == "random"):
            h = torch.randn(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
            c = torch.randn(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
        elif(self.mode == "learned"):
            h = self.learned_h.repeat(batch_size, 1, height, width, device=self.conv.weight.device)
            c = self.learned_c.repeat(batch_size, 1, height, width, device=self.conv.weight.device)
        
        return h, c

        

class ConvLSTM(nn.Module):


  

    """ 
    Custom LSTM for images. Batches of images are fed to a Conv LSTM
    
    Args:
    -----
    input_dim: integer
        Number of channels of the input.
    hidden_dim: integer
        dimensionality of the states in the cell
    kernel_size: tuple
        size of the kernel for convolutions
    num_layers: integer
        number of stacked LSTMS
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
       
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers
        classifier_in_dim= (self.hidden_dim[0]*16*16)
        classifier_output_dim = output_label_size

        # FC-classifier
        self.classifier = nn.Linear(classifier_in_dim, classifier_output_dim)

        conv_lstms  = []
        # iterating over no of layers
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            conv_lstms.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.conv_lstms = nn.ModuleList(conv_lstms)

    def forward(self, x, hidden_state=None):
       

        x=x.unsqueeze(dim=1)
        b, _, _, h, w = x.size()

        if hidden_state is not None:
            raise NotImplementedError()
        else:
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))

        
        cur_layer_input = x
        output_list = []
        x_len = x.size(1)
        

        # iterating over no of layers
        for i in range(self.num_layers):

            h, c = hidden_state[i]
            each_layer_output = []
            # iterating over sequence length

            for t in range(x_len):
                h, c = self.conv_lstms[i](x=cur_layer_input[:, t, :, :, :],cur_state=[h, c])
                each_layer_output.append(h)

            stacked_layer_output = torch.stack(each_layer_output, dim=1)
            cur_layer_input = stacked_layer_output

            output_list.append(stacked_layer_output)

        if not self.return_all_layers:
            output_list = output_list[-1:]

        batch_shape = output_list[-1].shape[0]

        # classifying
        final_out= self.classifier(output_list[-1].view(batch_shape, - 1)) # feeding only output at last layer

        return final_out

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.conv_lstms[i].init_state(batch_size, image_size))
        return init_states

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

In [154]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 1,stride=1,padding=(3, 3))
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 1,stride=1,padding=(2, 2))
    
    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class Encoder(nn.Module):
    def __init__(self, chs):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)
    
    def forward(self, x):
        layers = []
        print("hello")
        for block in self.enc_blocks:
            x = block(x)
            print(x.shape)
            layers.append(x)
            x = self.pool(x)
        return x


class Decoder(nn.Module):
    def __init__(self, chs):
        super().__init__()
        self.chs         = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)]) 
        
    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            x        = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x        = torch.cat([x, enc_ftrs], dim=1)
            x        = self.dec_blocks[i](x)
        return x
    
    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs   = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs


class EncDecCNN(nn.Module):
    def __init__(self, dec_chs=(16,32,64), enc_chs=(64,32,16), num_class=1, retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim  = retain_dim
        self.out_sz  = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out      = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out

In [169]:
enc_chs=(1,16,32,64)
e= Encoder(enc_chs)
x = torch.randn((100,1, 64, 64))
y1= e(x)
print("y1",y1.size())

hello
torch.Size([100, 16, 74, 74])
torch.Size([100, 32, 47, 47])
torch.Size([100, 64, 33, 33])
y1 torch.Size([100, 64, 16, 16])


In [156]:
image_size = (28,28)
output_label_size = 10

In [170]:
conv_model= ConvLSTM(input_dim= 64, hidden_dim = 64, kernel_size = (5,5), num_layers= 1)
if torch.cuda.is_available():
    conv_model.to(device)

In [171]:
print(conv_model)

ConvLSTM(
  (classifier): Linear(in_features=16384, out_features=10, bias=True)
  (conv_lstms): ModuleList(
    (0): ConvLSTMCell(
      (conv): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
)


In [175]:
c=conv_model(y1)
print(c)

tensor([[-0.0044, -0.0019,  0.0049,  0.0063, -0.0050, -0.0043,  0.0015, -0.0018,
          0.0006,  0.0075],
        [-0.0043, -0.0020,  0.0049,  0.0063, -0.0050, -0.0044,  0.0014, -0.0018,
          0.0006,  0.0075],
        [-0.0043, -0.0020,  0.0049,  0.0063, -0.0050, -0.0043,  0.0014, -0.0018,
          0.0006,  0.0075],
        [-0.0044, -0.0020,  0.0049,  0.0062, -0.0050, -0.0044,  0.0014, -0.0019,
          0.0006,  0.0075],
        [-0.0043, -0.0020,  0.0049,  0.0063, -0.0050, -0.0043,  0.0015, -0.0018,
          0.0006,  0.0075],
        [-0.0043, -0.0020,  0.0049,  0.0063, -0.0050, -0.0044,  0.0014, -0.0018,
          0.0006,  0.0075],
        [-0.0044, -0.0020,  0.0049,  0.0063, -0.0050, -0.0043,  0.0014, -0.0018,
          0.0006,  0.0075],
        [-0.0044, -0.0019,  0.0050,  0.0063, -0.0050, -0.0044,  0.0014, -0.0018,
          0.0006,  0.0075],
        [-0.0043, -0.0020,  0.0049,  0.0062, -0.0050, -0.0044,  0.0014, -0.0018,
          0.0006,  0.0075],
        [-0.0044, -

In [22]:
dec_chs=(64,32,16,1)
d= Decoder(dec_chs)
x = torch.randn((1,64, 16, 16))
y= d(y1[::-1][0], y1[::-1][1:])

IndexError: ignored

In [15]:
EDmodel=EncDecCNN(dec_chs=(64,32,16,1), enc_chs=(1,16,32,64), num_class=10, retain_dim=False)

In [16]:
print(EDmodel)

EncDecCNN(
  (encoder): Encoder(
    (enc_blocks): ModuleList(
      (0): Block(
        (conv1): Conv2d(1, 16, kernel_size=(1, 1), stride=(1, 1))
        (relu): ReLU()
        (conv2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      )
      (1): Block(
        (conv1): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
        (relu): ReLU()
        (conv2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (2): Block(
        (conv1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
        (relu): ReLU()
        (conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Decoder(
    (upconvs): ModuleList(
      (0): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
      (1): ConvTranspose2d(32, 16, kernel_size=(2, 2), stride=(2, 2))
      (2): ConvTranspose2d(16, 1, kernel_size=(2, 2), stride=(2, 2))
    )
    (dec_blocks): ModuleList