In [1]:
import torch
import torch.nn as nn
import math

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
class TransformerEncoder3D(nn.Module):
    def __init__(self, emb_size, num_layers, num_heads, dim_feedforward, dropout):
        super(TransformerEncoder3D, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_size, nhead=num_heads, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
    def forward(self, x):
        x = self.transformer_encoder(x)
        return x

In [4]:
class CausalSelfAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super(CausalSelfAttention, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=emb_size, num_heads=num_heads, dropout=dropout)
        self.mask = None

    def forward(self, x):
        seq_len = x.size(0)
        if self.mask is None or self.mask.size(0) != seq_len:
            self.mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
        attn_output, _ = self.attention(x, x, x, attn_mask=self.mask)
        return attn_output

class TemporalTransformerEncoder(nn.Module):
    def __init__(self, emb_size, num_layers, num_heads, dim_feedforward, dropout):
        super(TemporalTransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=emb_size, nhead=num_heads, dim_feedforward=dim_feedforward, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.causal_attention = CausalSelfAttention(emb_size, num_heads, dropout)
        
    def forward(self, x):
        x = x.transpose(0, 1)  # Transformer expects sequence as first dimension
        x = self.causal_attention(x)
        for layer in self.layers:
            x = layer(x)
        x = x.transpose(0, 1)  # Revert back to original dimensions
        return x

In [51]:
class VideoTransformerModel(nn.Module):
    def __init__(self, 
                 video_dimension, 
                 emb_size,
                 d_model,
                 patch_size, 
                 num_layers_spatial, 
                 num_heads_spatial, 
                 dim_feedforward_spatial, 
                 dropout_spatial, 
                 num_layers_temporal, 
                 num_heads_temporal, 
                 dim_feedforward_temporal, 
                 dropout_temporal):
        super(VideoTransformerModel, self).__init__()
        self.patch_embedding = PatchEmbedding(video_dimension, patch_size, emb_size)
        self.spatial_transformer = TransformerEncoder3D(d_model, num_layers_spatial, num_heads_spatial, dim_feedforward_spatial, dropout_spatial)
        self.temporal_transformer = TemporalTransformerEncoder(d_model, num_layers_temporal, num_heads_temporal, dim_feedforward_temporal, dropout_temporal)
        
    def forward(self, x):
        patches, spatial_dims = self.patch_embedding(x)
        print(f"Patches: {patches.shape}")
        spatially_encoded_patches = self.spatial_transformer(patches)
        print(f"Spatial: {spatially_encoded_patches.shape}")
        temporally_encoded_patches = self.temporal_transformer(spatially_encoded_patches)
        print(f"Temporal: {temporally_encoded_patches.shape}")
        batch_size = temporally_encoded_patches.size(0)
        return temporally_encoded_patches

In [52]:
with torch.inference_mode():
    model = VideoTransformerModel(
        video_dimension=(11, 3, 128, 128),
        emb_size=512,
        d_model=512,
        patch_size=(2, 8, 8),
        num_layers_spatial=4,
        num_heads_spatial=8,
        dim_feedforward_spatial=2048,
        dropout_spatial=0.1,
        num_layers_temporal=4, 
        num_heads_temporal=8,
        dim_feedforward_temporal=2048,
        dropout_temporal=0.1
    ).to(device)
    outputs = model(torch.rand(1, 11, 3, 128, 128).to(device))
    print(f"Output shape: {outputs.shape}")

Patches: torch.Size([1, 6, 512])
Spatial: torch.Size([1, 6, 512])
Temporal: torch.Size([1, 6, 512])
Output shape: torch.Size([1, 6, 512])


In [43]:
class PatchEmbedding(nn.Module):
    def __init__(self, video_dimensions, patch_dim, emb_size):
        super(PatchEmbedding, self).__init__()
        self.patch_dim = patch_dim
        self.emb_size = emb_size
        self.channels = video_dimensions[1]
        self.video_height = video_dimensions[2]
        self.video_width = video_dimensions[3]
        
        self.patch_video = nn.Conv3d(in_channels=self.channels,
                                    out_channels=self.channels,
                                    kernel_size=self.patch_dim,
                                    stride=self.patch_dim)
        
        self.patch_first_frame = nn.Conv2d(in_channels=self.channels,
                                           out_channels=self.channels,
                                           kernel_size=self.patch_dim[1:],
                                           stride=self.patch_dim[1:])
        
        self.flatten = nn.Flatten(start_dim=2, end_dim=-1)
        patch_height = self.calculate_3d_conv_output_size(video_dimensions[2], self.patch_dim[1], self.patch_dim[1], 0)
        patch_width = self.calculate_3d_conv_output_size(video_dimensions[3], self.patch_dim[2], self.patch_dim[2], 0)

        flatten_size = self.channels * patch_height * patch_width
        self.linear = nn.Linear(in_features=flatten_size,
                                out_features=emb_size)
    
    def calculate_3d_conv_output_size(self, input_size, kernel_size, stride, padding):
        return math.floor((input_size - kernel_size + 2 * padding) / stride) + 1
    
    def forward(self, x):
        first_frame, frames = torch.split(x, [1, x.size(1)-1], dim=1)
        first_frame = first_frame.transpose(1, 2).squeeze(dim=2)
        first_frame = self.patch_first_frame(first_frame)
        first_frame = first_frame.unsqueeze(dim=2).transpose(1, 2)
        frames = frames.transpose(1, 2)
        frames = self.patch_video(frames)
        frames = frames.transpose(1, 2)
        video_tokens = torch.cat([first_frame, frames], dim=1)
        # print(f"video_tokens shape: {video_tokens.shape}")
        video_tokens = self.flatten(video_tokens)
        # print(f"video_tokens flatten shape: {video_tokens.shape}")
        video_tokens = self.linear(video_tokens)
        # print(f"video_tokens linear shape: {video_tokens.shape}")
        return video_tokens, video_tokens.shape

In [44]:
with torch.inference_mode():
    model = PatchEmbedding((11, 3, 128, 128), (2, 8, 8), 512).to(device)
    outputs = model(torch.rand(1, 11, 3, 128, 128).to(device))
    print(f"PatchEmbedding shape: {outputs}")


PatchEmbedding shape: (tensor([[[ 3.6181e-01,  1.2451e-01,  8.6804e-02,  ...,  1.2611e-01,
           5.1994e-02,  2.2712e-01],
         [ 4.1853e-01,  4.0628e-01,  5.1473e-02,  ..., -4.4006e-02,
          -6.5528e-02,  4.6966e-02],
         [ 3.6156e-01,  2.5480e-01, -6.7917e-02,  ..., -3.1022e-02,
          -1.6191e-01, -9.6736e-02],
         [ 4.7576e-01,  2.5257e-01,  1.9194e-03,  ...,  1.4130e-01,
          -2.1290e-01,  2.3011e-02],
         [ 1.7433e-01,  2.2282e-01, -1.1072e-01,  ...,  8.0207e-02,
          -4.5680e-02,  1.0851e-02],
         [ 2.8652e-01,  2.2029e-01, -8.4725e-02,  ...,  1.7288e-01,
          -2.2659e-05, -1.3890e-01]]], device='cuda:0'), torch.Size([1, 6, 512]))
