In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from dataset import MNIST_Moving

In [2]:

train_set = MNIST_Moving(root='.data/mnist', train=True, download=True)
test_set = MNIST_Moving(root='.data/mnist', train=False, download=True)

batch_size = 100

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)

In [4]:
input = next(iter(train_loader))
print(input.shape)

torch.Size([100, 20, 1, 64, 64])


In [5]:
print('==>>> total trainning batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))

for seq, seq_target in train_loader:
    print('--- Sample')
    print('Input:  ', seq.shape)
    print('Target: ', seq_target.shape)
    break

==>>> total trainning batch number: 90
==>>> total testing batch number: 10


ValueError: too many values to unpack (expected 2)

In [34]:
class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias,mode="zeros"):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.mode = mode

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)
        
       
        

    def forward(self, x, cur_state):
        h_cur, c_cur = cur_state
        x = x.to(device)
        h_cur = h_cur.to(device)
        # print(x.size())
        # print(h_cur.size())
        concat_input_hcur = torch.cat([x, h_cur], dim=1) 
        concat_input_hcur = concat_input_hcur.to(device)

        concat_input_hcur_conv = self.conv(concat_input_hcur)
        concat_input_hcur_conv = concat_input_hcur_conv.to(device)

        cc_input_gate, cc_forget_gate, cc_output_gate, cc_output = torch.split(concat_input_hcur_conv, self.hidden_dim, dim=1)
        # print("cci",cc_input_gate.shape)
        # print("ccf",cc_forget_gate.shape)
        # print("ccog",cc_output_gate.shape)
        # print("cco",cc_output.shape)
        # print("cccur",c_cur.shape)
        # print("wci",self.W_ci.shape)
        
        input_gate = torch.sigmoid(cc_input_gate +  c_cur)

        forget_gate = torch.sigmoid(cc_forget_gate +  c_cur)

        output = torch.tanh(cc_output)

        c_next = forget_gate * c_cur + input_gate * output

        output_gate = torch.sigmoid(cc_output_gate +  c_next)

        h_next = output * torch.tanh(c_next)

        return h_next, c_next

    def init_state(self, batch_size, image_size):
        height, width = image_size
        """ Initializing hidden and cell state """
        if(self.mode == "zeros"):
            h = torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
            c = torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
        elif(self.mode == "random"):
            h = torch.randn(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
            c = torch.randn(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
        elif(self.mode == "learned"):
            h = self.learned_h.repeat(batch_size, 1, height, width, device=self.conv.weight.device)
            c = self.learned_c.repeat(batch_size, 1, height, width, device=self.conv.weight.device)
        
        return h, c

        

class ConvLSTM(nn.Module):
    
    """ 
    Custom LSTM for images. Batches of images are fed to a Conv LSTM
    
    Args:
    -----
    input_dim: integer
        Number of channels of the input.
    hidden_dim: integer
        dimensionality of the states in the cell
    kernel_size: tuple
        size of the kernel for convolutions
    num_layers: integer
        number of stacked LSTMS
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
       
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        conv_lstms  = []
        # iterating over no of layers
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            conv_lstms.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.conv_lstms = nn.ModuleList(conv_lstms)

    def forward(self, x, hidden_state=None):
       

        x=x.unsqueeze(dim=1)
        b, _, _, h, w = x.size()

        if hidden_state is not None:
            raise NotImplementedError()
        else:
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))

        
        cur_layer_input = x
        output_list = []
        x_len = x.size(1)
        

        # iterating over no of layers
        for i in range(self.num_layers):

            h, c = hidden_state[i]
            each_layer_output = []
            # iterating over sequence length

            for t in range(x_len):
                h, c = self.conv_lstms[i](x=cur_layer_input[:, t, :, :, :],cur_state=[h, c])
                each_layer_output.append(h)

            stacked_layer_output = torch.stack(each_layer_output, dim=1)
            cur_layer_input = stacked_layer_output

            output_list.append(stacked_layer_output)

        if not self.return_all_layers:
            output_list = output_list[-1:]

        batch_shape = output_list[-1].shape[0]

        return torch.stack(each_layer_output), (h, c)

    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.conv_lstms[i].init_state(batch_size, image_size))
        return init_states

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

In [35]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [36]:
class FirstLayerEncBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Add BN here 
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3,stride=1,padding=(1, 1))
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3,stride=1,padding=(1, 1))
    
    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))

class FirstEncoder(nn.Module):
    def __init__(self, chs):
        super().__init__()
        self.enc_blocks = nn.ModuleList([FirstLayerEncBlock(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)
    
    def forward(self, x):
        layers = []
        print("Staring encoding process")
        print("Input to the encoder", x.shape)
        for block in self.enc_blocks:
            x = block(x)
            x = self.pool(x)
            print("Output after block")
            print(x.shape)
            layers.append(x)
        return x

In [37]:
class FirstLayerDecBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(in_ch, out_ch, 3,stride=2,padding=(1, 1))
        self.conv2 = nn.ConvTranspose2d(out_ch, out_ch, 2,stride=1)
    
    def forward(self, x):
        return self.conv2((self.conv1(x)))

class FirstDecoder(nn.Module):
    def __init__(self, chs):
        super().__init__()
        self.dec_blocks = nn.ModuleList([FirstLayerDecBlock(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        
    
    def forward(self, x):
        layers = []
        print("Starting decoding step")
        print("Input shape", x.shape)

        for block in self.dec_blocks:
            x = block(x)
            print("Intermediate shape", x.shape)   
        return x

In [38]:
enc_chs=(1,16,64)
e= FirstEncoder(enc_chs)
e.to(device)
x = torch.randn((100,1, 64, 64))
y1= e(input1[:,9:,:,:].float())
print("Correct step")

Staring encoding process
Input to the encoder torch.Size([100, 1, 64, 64])
Output after block
torch.Size([100, 16, 32, 32])
Output after block
torch.Size([100, 64, 16, 16])
Correct step


In [39]:
image_size = (28,28)
output_label_size = 10

In [40]:
conv_model= ConvLSTM(input_dim= 64, hidden_dim = 64, kernel_size = (5,5), num_layers= 1)
if torch.cuda.is_available():
    conv_model.to(device)

In [41]:
print(conv_model)

ConvLSTM(
  (conv_lstms): ModuleList(
    (0): ConvLSTMCell(
      (conv): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
)


In [45]:
c=conv_model(y1)
print(c[1][0].size())
convlstm_out=c[1][0]

torch.Size([100, 64, 16, 16])


In [32]:
dec_chs=(64,16,1)
d= FirstDecoder(dec_chs)
d.to(device)

cv=d(convlstm_out)
print("correct step")

Starting decoding step
Input shape torch.Size([100, 64, 16, 16])
Intermediate shape torch.Size([100, 16, 32, 32])
Intermediate shape torch.Size([100, 1, 64, 64])
correct step


In [33]:
enc_chs1=(64,128)
e1= FirstEncoder(enc_chs1)
e1.to(device)
y2= e1(y1)
print(y2.shape)


Staring encoding process
Input to the encoder torch.Size([100, 64, 16, 16])
Output after block
torch.Size([100, 128, 8, 8])
torch.Size([100, 128, 8, 8])


In [128]:
conv_model1= ConvLSTM(input_dim= 128, hidden_dim = 128, kernel_size = (5,5), num_layers= 1)
if torch.cuda.is_available():
    conv_model1.to(device)

In [130]:
c1=conv_model1(y2)
print(c1[1][0].size())
convlstm_out1=c1[1][0]

torch.Size([100, 128, 8, 8])


In [132]:
dec_chs1=(128,64)
d2= FirstDecoder(dec_chs1)
d2.to(device)

cv1=d2(convlstm_out1)
print(cv1.shape)

hello
torch.Size([100, 128, 8, 8])
torch.Size([100, 64, 16, 16])
torch.Size([100, 64, 16, 16])


In [134]:
enc_chs2=(128,256)
e2= FirstEncoder(enc_chs2)
e2.to(device)
y3= e2(y2)
print(y3.shape)


hello
torch.Size([100, 128, 8, 8])
torch.Size([100, 256, 8, 8])
torch.Size([100, 256, 4, 4])


In [135]:
conv_model2= ConvLSTM(input_dim= 256, hidden_dim = 256, kernel_size = (5,5), num_layers= 1)
if torch.cuda.is_available():
    conv_model2.to(device)

In [137]:
c2=conv_model2(y3)
print(c2[1][0].size())
convlstm_out2=c2[1][0]

torch.Size([100, 256, 4, 4])


In [139]:
dec_chs2=(256,128)
d3= FirstDecoder(dec_chs2)
d3.to(device)

cv2=d3(convlstm_out2)
print(cv2.shape)

hello
torch.Size([100, 256, 4, 4])
torch.Size([100, 128, 8, 8])
torch.Size([100, 128, 8, 8])


In [None]:
class FirstLayerDecBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(in_ch, out_ch, 3,stride=2,padding=(1, 1))
        self.conv2 = nn.ConvTranspose2d(out_ch, out_ch, 2,stride=1)
    
    def forward(self, x):
        return self.conv2((self.conv1(x)))

class FirstDecoder(nn.Module):
    def __init__(self, chs):
        super().__init__()
        self.dec_blocks = nn.ModuleList([FirstLayerDecBlock(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        
    
    def forward(self, x):
        layers = []
        print("hello")
        print(x.shape)
        for block in self.dec_blocks:
            x = block(x)
            print(x.shape)
            
           
        return x

In [None]:
# training
