<a href="https://colab.research.google.com/github/vivek-chandan/-TimeSformer-from-scratch/blob/main/TimeSformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from einops import rearrange
from tqdm import tqdm
import timm


In [5]:
device = "cuda" if torch.cuda.is_available else "cpu"

# video data part currently not coding I will do it after transorfmer module

# Transformer Modules

In [11]:
class TimeSformerBlock(nn.Module):
    def __init__(self, dim, heads):
      super().__init__()

      self.temporal_attn = nn.MultiheadAttention(dim, heads, batch_first = True)
      self.spetial_attn = nn.MultiheadAttention(dim, heads, batch_first = True)

      self.norm1 = nn.LayerNorm(dim)
      self.norm2 = nn.LayerNorm(dim)
      self.norm3 = nn.LayerNorm(dim)

      self.mlp = nn.Sequential(
          nn.Linear(dim, dim * 4),
          nn.GELU(),
          nn.Linear(dim * 4, dim)
      )

    def forward(self, x):
      # temporal attention
      B, T, N, D = x.shape # BatchDim, TemporalDim , NumOfPatches, EmbeddingDim
      xt = rearrange(x, "b t n d -> (b  n) t d")
      xt = self.temporal_attn(xt, xt, xt)[0]
      xt = rearrange(xt, "(b n) t d -> b t n d", b = B, n = N)
      x = x + self.norm1(xt)

      # spatial attention

      xs = rearrange(x, "b t n d -> (b t) n d")
      xs = self.spetial_attn(xs, xs, xs)[0]
      xs = rearrange(xs, "(b t) n d -> b t n d", b = B, t = T)
      x = x + self.norm2(xs)

      # mlp

      xl = self.mlp(x)
      x = x + self.norm3(xl)

      return x

In [12]:
import torch.nn as nn # Added import for nn

# TimeSformer

class TimeSformer(nn.Module): # Corrected Modules to Module
  def __init__(self ,
               num_classes,
               num_frames=8,
               img_size = 224,
               patch_size = 16 ,
               embed_dim = 768,
               depths =12,
               heads = 12,):

    super(). __init__()

    self.num_frames = num_frames
    self.patch_size = patch_size
    self.embed_dim = embed_dim
    self.img_size = img_size
    self.num_patches = (img_size // patch_size) ** 2

    self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim) )

    self.time_embed = nn.Parameter(torch.randn(1, self.num_frames + 1, embed_dim ))
    self.space_embed = nn.Parameter(torch.randn(1, self.num_patches , embed_dim ))

    self.blocks = nn.ModuleList([
        TimeSformerBlock(embed_dim, heads) for _ in range(depths) # Corrected TimeSformerBlock to TransformerBlock
        ]
        )
    self.norm = nn.LayerNorm(embed_dim)
    self.head = nn.Linear(embed_dim, num_classes)

  def forward(self, x):
   B, T, C, H, W = x.shape # B-> batch , T -> temporal , C -> chanel, H-> height ,W->width
   x = x.view(B * T, C, H, W)
   x = self.patch_embed(x)
   x= x.flatten(2).transpose(1, 2)
   x = x.view(B, T, -1, self.embed_dim)

   # add position embedding excludding cls token
   x = x + self.time_embed[:,1:T+1, None, :]+self.space_embed[:, None, :, :]

   # add cls token
   cls = self.cls_token.expand(B, -1, self.num_patches, -1)
   cls = cls + self.time_embed[:, :1, None, :]
   # preprend cls token
   x = torch.cat((cls, x), dim=1)

   #transformer blocks
   for block in self.blocks:
     x = block(x)

   cls_out = self.norm(x[:, 0,0])
   out = self.head(cls_out)

   return out

In [13]:
model = TimeSformer(num_classes=101,
                    num_frames= 8,
                    embed_dim= 768,
                    depth = 12,
                    heads =12).to(device)


NameError: name 'train_dataset' is not defined