In [None]:
import os
import json
import math
import numpy as np

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

In [None]:
class Encoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers might use a duplicate of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 3, kernel_size=5, padding=1),
            act_fn()
            #nn.Conv2d(3, 3, kernel_size=5, padding=1), # 8x8 => 4x4
            #act_fn(),
        )

    def forward(self, x):
        return self.net(x)

In [None]:
encoder=Encoder(3,32,6)

In [None]:
video=torch.zeros((11,3,32,32))
for k in range(11):
  video[k,:,:,:]=torch.rand((3,32,32))

In [None]:
video.shape

torch.Size([11, 3, 32, 32])

In [None]:
output_first=torch.zeros((11,3,6,6))
for k in range(video.shape[0]):
  output_first[k,:,:,:]=encoder(video[k,:,:,:].unsqueeze(0))

In [None]:
class SelfConnection(nn.Module):
  def __init__(self, units, get_w=True, **kwargs):
        super().__init__(**kwargs)

        # self_connections_metrics = (np.random.random( (units, units) ) > 0.5).astype(dtype=int).astype(dtype=tf.float32)
        self.shift = np.random.randint(int(units*0.3), int(units*0.8))
        self.s = torch.tensor(np.random.random((units)),dtype=torch.float32)

        self.w = torch.tensor(np.random.random((units)),dtype=torch.float32)
        self.b = torch.tensor(np.random.random((units)),dtype=torch.float32)

        self.get_w = get_w
  def forward(self, x, state = 1.0, get_weights = False):
        # weights x some_amount_of (shifted weights)
        print(x.shape)
        w = self.w + ( self.s * torch.roll(self.w, shifts=self.shift) )
        w = w * state
        if self.get_w or get_weights:
            out= (x * w) + self.b, self.w
            print("out shape is ",out[1].shape)
            print(type(out[0]))
            return out[0]
        else:
            print(x.shape,w.shape)
            return (x * w) + self.b

In [None]:
self_conn=SelfConnection(3)

In [217]:
class Encoder_time(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers might use a duplicate of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Linear(36*11*3,36*11*3),
            SelfConnection(36*11*3),
            act_fn(),
            nn.Linear(36*11*3,36*11*3),
            SelfConnection(36*11*3),
            act_fn(),
            nn.Linear(36*11*3,36*10*3),
            SelfConnection(36*10*3),
            act_fn(),
        )

    def forward(self, x):
        return self.net(x)

In [218]:
output_first=torch.tensor(output_first)

  """Entry point for launching an IPython kernel.


In [219]:
encoder_time=Encoder_time(3,6,6)

In [220]:
encoded_video=encoder(video)

In [221]:
encoded_video.shape

torch.Size([11, 3, 6, 6])

In [222]:
encoded_video=encoded_video.flatten()

In [223]:
encoded_video.shape

torch.Size([1188])

In [224]:
output=encoder_time(encoded_video)

torch.Size([1188])
out shape is  torch.Size([1188])
<class 'torch.Tensor'>
torch.Size([1188])
out shape is  torch.Size([1188])
<class 'torch.Tensor'>
torch.Size([1080])
out shape is  torch.Size([1080])
<class 'torch.Tensor'>


In [225]:
output=output.reshape(10,3,6,6)

#second encoder output

In [226]:
output.shape

torch.Size([10, 3, 6, 6])

# Final output

In [230]:
encoded_video=encoded_video.reshape(11,3,6,6)

In [231]:
torch.cat((output,encoded_video[0].unsqueeze(0))).shape

torch.Size([11, 3, 6, 6])