In [None]:
import os
import json
import math
import numpy as np

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

In [None]:
class Encoder(nn.Module):
    """Convolutional image encoder producing a small spatial feature map.

    For a 32x32 input the spatial size goes 32 -> 16 -> 16 -> 8 -> 6.
    """

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU,
                 out_channels : int = 3):
        """
        Inputs:
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. The third layer uses twice this value.
            - latent_dim : Currently NOT used by this network (kept so existing calls keep working);
                           the number of output channels is controlled by `out_channels`.
            - act_fn : Activation function class; a fresh instance is created after every convolution
            - out_channels : Number of channels of the output feature map (default 3, the previously hard-coded value)
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),  # 16x16 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            act_fn(),
            # kernel 5 with padding 1 trims 2 pixels per spatial dim: 8x8 => 6x6
            nn.Conv2d(2*c_hid, out_channels, kernel_size=5, padding=1),
            act_fn(),
        )

    def forward(self, x):
        """Encode images.

        x : tensor of shape (N, num_input_channels, H, W), or unbatched (num_input_channels, H, W).
        Returns the activated feature maps; for 32x32 inputs the shape is (N, out_channels, 6, 6).
        """
        return self.net(x)

In [None]:
# Per-frame image encoder: 3 input channels, 32 base channels.
# (latent_dim=6 is passed through, but the Encoder body never reads it.)
encoder = Encoder(num_input_channels=3, base_channel_size=32, latent_dim=6)

In [None]:
# Synthetic "video" of uniformly random frames, laid out as (frames, channels, height, width).
# A dedicated seeded generator makes the clip reproducible across kernel restarts
# without consuming the global RNG that later model initialisations draw from.
NUM_FRAMES, NUM_CHANNELS, FRAME_H, FRAME_W = 11, 3, 32, 32
VIDEO_SEED = 0
video = torch.rand((NUM_FRAMES, NUM_CHANNELS, FRAME_H, FRAME_W),
                   generator=torch.Generator().manual_seed(VIDEO_SEED))

In [None]:
# Inspect the clip layout: (frames, channels, height, width).
video.size()

torch.Size([11, 3, 32, 32])

In [None]:
# Shape of the per-frame encoding (the encoded tensor itself is not kept here).
encoder(video).size()

torch.Size([11, 3, 6, 6])

In [None]:
class Encoder_time(nn.Module):
    """Temporal encoder: stacked 3-D convolutions over a (channels, time, height, width) volume."""

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU,
                 out_channels : int = 3):
        """
        Inputs:
            - num_input_channels : Number of channels of the input volume (dim 1 of an (N, C, T, H, W) batch)
            - base_channel_size : Number of channels of the first two 3-D convolutions. The third layer uses twice this value.
            - latent_dim : Currently NOT used by this network (kept so existing calls keep working);
                           the number of output channels is controlled by `out_channels`.
            - act_fn : Activation function class; a fresh instance is created after every convolution
            - out_channels : Number of channels of the output volume (default 3, the previously hard-coded value)

        No layer is strided, so height and width are preserved (spatial kernel 3, padding 1).
        The last layer's temporal kernel of 4 with padding 1 shortens the time axis by one: T => T-1.
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv3d(num_input_channels, c_hid, kernel_size=3, padding=1),  # (T, H, W) preserved
            act_fn(),
            nn.Conv3d(c_hid, c_hid, kernel_size=3, padding=1),  # (T, H, W) preserved
            act_fn(),
            nn.Conv3d(c_hid, 2*c_hid, kernel_size=3, padding=1),  # (T, H, W) preserved
            act_fn(),
            # temporal kernel 4 + padding 1: T => T-1; spatial kernel 3 + padding 1 keeps H, W
            nn.Conv3d(2*c_hid, out_channels, kernel_size=(4, 3, 3), padding=1),
            act_fn(),
        )

    def forward(self, x):
        """Encode a volume of shape (N, C, T, H, W), or unbatched (C, T, H, W), with T >= 2.

        Returns a tensor of shape (N, out_channels, T-1, H, W)
        (unbatched: (out_channels, T-1, H, W)).
        """
        return self.net(x)

In [None]:
# Temporal encoder over the 3-channel per-frame encodings, 6 base channels.
# (latent_dim=6 is passed through, but the Encoder_time body never reads it.)
encoder_time = Encoder_time(num_input_channels=3, base_channel_size=6, latent_dim=6)

In [None]:
# Encode every frame independently; frames stay on dim 0.
encoded_video = encoder(video)

In [None]:
# Frame-major layout: (frames, channels, height, width).
encoded_video.size()

torch.Size([11, 3, 6, 6])

In [None]:
# Swap the frame and channel axes -> (channels, frames, H, W), the unbatched layout Conv3d takes.
# NOTE(review): this reassigns `encoded_video`, so re-running the cell flips the layout back.
encoded_video = torch.permute(encoded_video, (1, 0, 2, 3))

In [None]:
# Channel-major layout: (channels, frames, height, width).
encoded_video.size()

torch.Size([3, 11, 6, 6])

In [None]:
# 3-D convolutions over (channels, frames, H, W); the last layer drops one frame.
output = encoder_time(encoded_video)

In [None]:
# Expected (channels, frames - 1, H, W), i.e. (3, 10, 6, 6) for the 11-frame clip.
output.size()

torch.Size([3, 10, 6, 6])

In [None]:
# Back to frame-major layout: (frames - 1, channels, H, W).
output = torch.permute(output, (1, 0, 2, 3))

In [None]:
# Restore frame-major layout (frames, channels, H, W) so individual frames index on dim 0.
encoded_video = torch.permute(encoded_video, (1, 0, 2, 3))

# Final output: temporally encoded frames plus the first encoded frame (11 frames again)

In [None]:
# Append the first encoded frame to the temporally encoded frames so the sequence
# is back to the input's frame count: (frames - 1) + 1 = frames.
# The result is kept in `final_output` instead of being discarded after showing its shape.
# NOTE(review): frame 0 is placed *last*; if it is meant to precede the temporal
# outputs, swap the order of the two tensors — confirm intent.
final_output = torch.cat((output, encoded_video[0].unsqueeze(0)), dim=0)
final_output.shape

torch.Size([11, 3, 6, 6])