<a href="https://colab.research.google.com/github/thenaivekid/Implementation-of-VideoMAE-for-Video-Classification/blob/main/Implementatioin_of_Video_MAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import Optional, Tuple, Union, Callable, List, Dict, Any
import collections.abc
import json # For loading config from string/file
from dataclasses import dataclass

@dataclass
class BaseModelOutput:
    """Container for the outputs of the encoder / base model.

    Besides attribute access, the object supports tuple-style access through
    :meth:`to_tuple` and indexing/slicing (``out[0]``, ``out[1:]``). Fields that
    are ``None`` are skipped, so ``out[1:]`` contains only the optional outputs
    that were actually produced.
    """

    # Final hidden states of the model.
    last_hidden_state: Any
    # Per-layer hidden states when requested, otherwise None.
    hidden_states: Optional[Tuple[Any, ...]] = None
    # Per-layer attention weights when requested, otherwise None.
    attentions: Optional[Tuple[Any, ...]] = None

    def to_tuple(self) -> Tuple[Any, ...]:
        """Return the non-None fields in declaration order."""
        return tuple(
            value
            for value in (self.last_hidden_state, self.hidden_states, self.attentions)
            if value is not None
        )

    def __getitem__(self, index: Union[int, slice]) -> Any:
        """Index or slice the tuple returned by :meth:`to_tuple`."""
        return self.to_tuple()[index]
# --- Configuration Class ---
class VideoMAEConfig:
    """Plain configuration object for the VideoMAE modules in this notebook.

    Known hyper-parameters get defaults when absent from ``kwargs``. Every
    keyword argument is also stored verbatim as an attribute, including keys no
    module here reads (e.g. the ``decoder_*`` entries of a ``config.json``).

    After the keyword arguments are applied:

    * ``id2label`` keys that are integer strings are converted to ``int``
      (JSON object keys are always strings), matching the ``Dict[int, str]``
      annotation. Non-numeric keys are kept as-is.
    * ``num_labels`` is derived from ``len(id2label)`` when it was not given,
      or was given as ``None``.
    """

    def __init__(self, **kwargs):
        # Essential parameters (with defaults for standalone use if not all are in kwargs)
        self.image_size: Union[int, Tuple[int, int]] = kwargs.get("image_size", 224)
        self.patch_size: Union[int, Tuple[int, int]] = kwargs.get("patch_size", 16)
        self.num_channels: int = kwargs.get("num_channels", 3)
        self.hidden_size: int = kwargs.get("hidden_size", 768)
        self.num_frames: int = kwargs.get("num_frames", 16)
        self.tubelet_size: int = kwargs.get("tubelet_size", 2)

        self.num_hidden_layers: int = kwargs.get("num_hidden_layers", 12)
        self.num_attention_heads: int = kwargs.get("num_attention_heads", 12)
        self.intermediate_size: int = kwargs.get("intermediate_size", 3072)

        self.hidden_act: str = kwargs.get("hidden_act", "gelu")
        self.hidden_dropout_prob: float = kwargs.get("hidden_dropout_prob", 0.0)
        self.attention_probs_dropout_prob: float = kwargs.get("attention_probs_dropout_prob", 0.0)
        self.qkv_bias: bool = kwargs.get("qkv_bias", True)
        self.layer_norm_eps: float = kwargs.get("layer_norm_eps", 1e-12)

        self.initializer_range: float = kwargs.get("initializer_range", 0.02)
        self.use_mean_pooling: bool = kwargs.get("use_mean_pooling", True)

        self.id2label: Optional[Dict[int, str]] = kwargs.get("id2label", None)
        self.label2id: Optional[Dict[str, int]] = kwargs.get("label2id", None)
        self.num_labels: Optional[int] = kwargs.get("num_labels", None)
        self.problem_type: Optional[str] = kwargs.get("problem_type", None)

        # Allow any other kwargs to be set as attributes. This must run BEFORE the
        # derived fields below; otherwise a raw string-keyed ``id2label`` or an
        # explicit ``num_labels=None`` would overwrite the normalized values.
        for key, value in kwargs.items():
            setattr(self, key, value)

        if self.id2label is not None:
            self.id2label = {self._int_key(k): v for k, v in self.id2label.items()}
        if self.num_labels is None and self.id2label is not None:
            self.num_labels = len(self.id2label)

    @staticmethod
    def _int_key(key: Any) -> Any:
        """Return ``key`` as an ``int`` when it is an integer-valued string, else unchanged."""
        if isinstance(key, str):
            try:
                return int(key)
            except ValueError:
                return key
        return key

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]):
        """Build a config from a mapping of attribute names to values."""
        return cls(**config_dict)

    @classmethod
    def from_json_file(cls, json_file_path: str):
        """Build a config from a JSON file containing a single object.

        The file is read as UTF-8 so that non-ASCII label names load identically
        on every platform.
        """
        with open(json_file_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls(**config_dict)

# --- Helper: Sinusoid Encoding ---
def get_sinusoid_encoding_table(n_position, d_hid):
    """Build the fixed sinusoidal position-encoding table.

    Entry ``[pos, j]`` is ``pos / 10000 ** (2 * (j // 2) / d_hid)``, passed
    through ``sin`` for even ``j`` and ``cos`` for odd ``j``.

    Args:
        n_position: Number of positions (rows of the table).
        d_hid: Embedding width (columns of the table).

    Returns:
        A float32 tensor of shape ``(1, n_position, d_hid)``.
    """
    # Vectorized with NumPy broadcasting instead of a nested Python loop; with
    # the notebook's config the table has 1568 x 768 entries.
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid
    sinusoid_table = positions / np.power(10000, exponents)[None, :]
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
    return torch.from_numpy(sinusoid_table).float().unsqueeze(0)

# --- Modules ---
class VideoMAEPatchEmbeddings(nn.Module):
    """Turn a video into a sequence of tubelet (spatio-temporal patch) embeddings.

    A ``Conv3d`` whose kernel and stride both equal
    ``(tubelet_size, patch_height, patch_width)`` splits the video into
    non-overlapping tubelets and projects each one to ``hidden_size``.
    """

    def __init__(self, config: "VideoMAEConfig"):
        super().__init__()

        image_size = config.image_size
        patch_size = config.patch_size

        # Accept either a single int (square) or an (H, W) pair, e.g. a JSON list.
        self.image_size = tuple(image_size) if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        self.patch_size = tuple(patch_size) if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        self.tubelet_size = int(config.tubelet_size)
        self.num_frames = config.num_frames
        self.num_channels = config.num_channels
        self.hidden_size = config.hidden_size

        # Floor division matches what the strided Conv3d produces.
        num_patches_per_frame = (self.image_size[0] // self.patch_size[0]) * \
                                (self.image_size[1] // self.patch_size[1])
        num_temporal_patches = self.num_frames // self.tubelet_size
        self.num_patches = num_patches_per_frame * num_temporal_patches

        self.projection = nn.Conv3d(
            in_channels=self.num_channels,
            out_channels=self.hidden_size,
            kernel_size=(self.tubelet_size, self.patch_size[0], self.patch_size[1]),
            stride=(self.tubelet_size, self.patch_size[0], self.patch_size[1]),
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed a video batch.

        Args:
            pixel_values: Tensor of shape ``(batch, frames, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches, hidden_size)``.

        Raises:
            ValueError: if the input is not 5-D, or its channel count, spatial
                size or frame count differs from the configuration.
        """
        if pixel_values.dim() != 5:
            raise ValueError(
                "Expected pixel_values of shape (batch, frames, channels, height, width), "
                f"got {tuple(pixel_values.shape)}"
            )
        _, num_frames, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(f"Channel mismatch: expected {self.num_channels}, got {num_channels}")
        if (height, width) != self.image_size:
            raise ValueError(
                f"Input image size mismatch: expected {self.image_size}, got {(height, width)}"
            )
        if num_frames != self.num_frames:
            raise ValueError(
                f"Input frame count mismatch: expected {self.num_frames}, got {num_frames}"
            )

        # Conv3d wants (batch, channels, frames, height, width).
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
        embeddings = self.projection(pixel_values)
        # (B, hidden, T', H', W') -> (B, T' * H' * W', hidden)
        return embeddings.flatten(2).transpose(1, 2)

class VideoMAEEmbeddings(nn.Module):
    """Tubelet embeddings plus fixed sinusoidal position embeddings.

    Optionally drops masked positions (``bool_masked_pos``), which is only
    needed for MAE-style pre-training.
    """

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.patch_embeddings = VideoMAEPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches
        # Registered as a buffer so ``module.to(device)`` moves it once, instead
        # of copying it to the input's device on every forward pass.
        # ``persistent=False`` keeps it out of ``state_dict``, as before when it
        # was a plain attribute, so checkpoint loading is unaffected.
        self.register_buffer(
            "position_embeddings_table",
            get_sinusoid_encoding_table(self.num_patches, config.hidden_size),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed ``pixel_values`` and add position embeddings.

        Args:
            pixel_values: Video batch, passed straight to ``VideoMAEPatchEmbeddings``.
            bool_masked_pos: Optional boolean mask over patch positions
                (presumably ``(batch, num_patches)`` — confirm with callers).
                ``True`` entries are removed. The reshape back to
                ``(batch, -1, hidden)`` only works when every sample masks the
                same number of positions.

        Returns:
            Tensor of shape ``(batch, kept_positions, hidden_size)``.
        """
        embeddings = self.patch_embeddings(pixel_values)
        # No-op when dtype/device already match (the usual case after ``.to()``);
        # still handles inputs in a different dtype, e.g. under autocast.
        pos_embed = self.position_embeddings_table.to(dtype=embeddings.dtype, device=embeddings.device)
        embeddings = embeddings + pos_embed
        if bool_masked_pos is not None:
            batch_size, _, num_channels_emb = embeddings.shape
            # Keep only the visible (unmasked) tokens.
            embeddings = embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels_emb)
        return embeddings

def eager_attention_forward(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
    attention_mask: Optional[torch.Tensor], scaling: float,
    dropout_p: float = 0.0, training: bool = False,
):
    """Plain (non-fused) scaled dot-product attention.

    The inputs are expected as ``(batch, heads, seq, head_dim)``, which is the
    layout produced by ``VideoMAESelfAttention.transpose_for_scores``.

    Args:
        query, key, value: Projected and head-split tensors.
        attention_mask: Optional multiplicative mask applied to the attention
            probabilities AFTER softmax and dropout (used here as a per-head
            mask). Masked weights are zeroed, not renormalized.
        scaling: Factor applied to the raw ``query @ key^T`` scores.
        dropout_p: Dropout probability on the attention probabilities.
        training: Dropout is only active when this is true.

    Returns:
        ``(context, probs)``. ``context`` has the head axis swapped back next to
        the feature axis, i.e. ``(batch, seq, heads, head_dim)``. ``probs`` is
        the (masked) attention matrix.
    """
    scores = query @ key.transpose(-2, -1)
    scores = scores * scaling
    # Softmax in float32 for stability, then back to the working dtype.
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = F.dropout(probs, p=dropout_p, training=training)
    if attention_mask is not None:
        probs = probs * attention_mask
    context = (probs @ value).transpose(1, 2).contiguous()
    return context, probs

class VideoMAESelfAttention(nn.Module):
    """Multi-head self-attention with the VideoMAE bias layout.

    The query/key/value projections are created without bias. When
    ``config.qkv_bias`` is set, separate learnable ``q_bias`` and ``v_bias``
    vectors are used instead, and the key projection gets a fixed all-zero bias.
    """

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        hidden, heads = config.hidden_size, config.num_attention_heads
        if hidden % heads != 0:
            raise ValueError("Hidden size not multiple of num_attention_heads")

        self.num_attention_heads = heads
        self.attention_head_size = int(hidden / heads)
        self.all_head_size = heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(hidden, self.all_head_size, bias=False)
        self.key = nn.Linear(hidden, self.all_head_size, bias=False)
        self.value = nn.Linear(hidden, self.all_head_size, bias=False)

        if config.qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(self.all_head_size))
            self.v_bias = nn.Parameter(torch.zeros(self.all_head_size))
        else:
            self.q_bias = None
            self.v_bias = None

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into heads: (B, N, H*D) -> (B, H, N, D)."""
        *leading, _ = x.size()
        x = x.view(*leading, self.num_attention_heads, self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Return ``(context,)``, or ``(context, attention_probs)`` when requested.

        ``head_mask`` is forwarded as the multiplicative post-softmax mask of
        ``eager_attention_forward``.
        """
        has_bias = self.q_bias is not None and self.v_bias is not None
        # The key projection never learns a bias; a zero vector is used instead.
        k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if has_bias else None

        projected = (
            F.linear(hidden_states, self.query.weight, self.q_bias),
            F.linear(hidden_states, self.key.weight, k_bias),
            F.linear(hidden_states, self.value.weight, self.v_bias),
        )
        query_layer, key_layer, value_layer = (self.transpose_for_scores(t) for t in projected)

        context_layer, attention_probs = eager_attention_forward(
            query=query_layer, key=key_layer, value=value_layer,
            attention_mask=head_mask, scaling=self.scaling,
            dropout_p=self.dropout_prob, training=self.training,
        )
        # Merge heads back: (B, N, H, D) -> (B, N, H*D).
        leading_dims = context_layer.size()[:-2]
        context_layer = context_layer.reshape(*leading_dims, self.all_head_size)
        if output_attentions:
            return (context_layer, attention_probs)
        return (context_layer,)

class VideoMAESelfOutput(nn.Module):
    """Output projection of the attention sub-block: dense layer followed by dropout.

    No residual is added here; the caller adds it.
    """

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
        # ``input_tensor`` is accepted only for call-site compatibility and is ignored.
        return self.dropout(self.dense(hidden_states))

class VideoMAEAttention(nn.Module):
    """Self-attention followed by its output projection.

    Returns a tuple whose first element is the projected context; the attention
    probabilities follow when ``output_attentions`` is truthy.
    """

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.attention = VideoMAESelfAttention(config)
        self.output = VideoMAESelfOutput(config)

    def forward(
        self, hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        context, *maybe_probs = self.attention(hidden_states, head_mask, output_attentions)
        # hidden_states is passed along only to keep the output module's call signature.
        projected = self.output(context, hidden_states)
        return (projected, *maybe_probs)

class VideoMAEIntermediate(nn.Module):
    """First half of the transformer MLP: expand to ``intermediate_size``, then activate.

    ``config.hidden_act`` may be one of the names in ``_ACTIVATIONS`` or a
    callable (e.g. an ``nn.Module`` instance), which is used as-is.
    """

    # Supported activation names -> factories building the activation module.
    _ACTIVATIONS = {
        "gelu": nn.GELU,  # exact (erf-based) GELU
        "relu": nn.ReLU,
        "gelu_new": lambda: nn.GELU(approximate="tanh"),
        "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "tanh": nn.Tanh,
    }

    def __init__(self, config: "VideoMAEConfig"):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        if callable(act):
            self.intermediate_act_fn = act
        elif isinstance(act, str) and act in self._ACTIVATIONS:
            self.intermediate_act_fn = self._ACTIVATIONS[act]()
        else:
            raise ValueError(f"Unsupported activation: {config.hidden_act}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` to ``intermediate_size`` and apply the activation."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

class VideoMAEOutput(nn.Module):
    """Second half of the transformer MLP: project back to ``hidden_size`` and add the residual."""

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # dense -> dropout, then the residual (input_tensor) is added.
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor

class VideoMAELayer(nn.Module):
    """One pre-norm transformer block.

    The block computes ``x + Attn(LN(x))``, then ``h + MLP(LN(h))``.
    """

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.attention = VideoMAEAttention(config)
        self.intermediate = VideoMAEIntermediate(config)  # MLP expansion + activation
        self.output = VideoMAEOutput(config)  # MLP projection + second residual
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Return ``(layer_output,)``, plus the attention probabilities when requested."""
        residual = hidden_states
        attn_out, *extras = self.attention(
            self.layernorm_before(residual), head_mask, output_attentions=output_attentions,
        )
        hidden_states = attn_out + residual  # first residual

        mlp_hidden = self.intermediate(self.layernorm_after(hidden_states))
        # VideoMAEOutput adds ``hidden_states`` back (second residual).
        layer_output = self.output(mlp_hidden, hidden_states)
        return (layer_output, *extras)

class VideoMAEEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` transformer layers.

    Can optionally collect the hidden states entering each layer (plus the final
    output) and each layer's attention probabilities.
    """

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.layer = nn.ModuleList([VideoMAELayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self, hidden_states: torch.Tensor,
        head_mask: Optional[List[Optional[torch.Tensor]]] = None,
        output_attentions: bool = False, output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run every layer in order.

        ``head_mask[i]`` (when present) is used as layer ``i``'s mask; layers
        past the end of ``head_mask`` get no mask. With ``return_dict`` false,
        returns a tuple of the non-None outputs.
        """
        hidden_trace = [] if output_hidden_states else None
        attention_trace = [] if output_attentions else None

        for idx, block in enumerate(self.layer):
            if hidden_trace is not None:
                hidden_trace.append(hidden_states)
            mask_for_layer = None
            if head_mask is not None and idx < len(head_mask):
                mask_for_layer = head_mask[idx]
            layer_outputs = block(hidden_states, mask_for_layer, output_attentions)
            hidden_states = layer_outputs[0]
            if attention_trace is not None:
                attention_trace.append(layer_outputs[1])

        if hidden_trace is not None:
            hidden_trace.append(hidden_states)

        all_hidden_states = None if hidden_trace is None else tuple(hidden_trace)
        all_self_attentions = None if attention_trace is None else tuple(attention_trace)

        if not return_dict:
            candidates = (hidden_states, all_hidden_states, all_self_attentions)
            return tuple(item for item in candidates if item is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ImageClassifierOutput:
    """Result object returned by ``VideoMAEForVideoClassification.forward``.

    Attributes:
        loss: Training loss, or None when it was not computed.
        logits: Head outputs.
        hidden_states: Per-layer hidden states, or None.
        attentions: Per-layer attention weights, or None.
    """

    def __init__(self, loss=None, logits=None, hidden_states=None, attentions=None):
        self.loss, self.logits = loss, logits
        self.hidden_states, self.attentions = hidden_states, attentions

    def __repr__(self):
        # Summarize instead of dumping full tensors.
        fields = []
        if self.loss is not None:
            loss_value = self.loss.item() if isinstance(self.loss, torch.Tensor) else self.loss
            fields.append(f"loss={loss_value}")
        if self.logits is not None:
            fields.append(f"logits.shape={self.logits.shape}")
        if self.hidden_states is not None:
            fields.append(f"hidden_states.len={len(self.hidden_states)}")
        if self.attentions is not None:
            fields.append(f"attentions.len={len(self.attentions)}")
        return f"ImageClassifierOutput({', '.join(fields)})"


class VideoMAEModel(nn.Module):
    """VideoMAE backbone: tubelet embeddings + transformer encoder (+ optional final LayerNorm).

    When ``config.use_mean_pooling`` is true no final LayerNorm is applied here,
    and ``self.layernorm`` is None. Otherwise ``self.layernorm`` normalizes the
    encoder output.
    """

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.config = config
        self.embeddings = VideoMAEEmbeddings(config)
        self.encoder = VideoMAEEncoder(config)
        if config.use_mean_pooling:
            self.layernorm = None
        else:
            self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def get_input_embeddings(self):
        """Return the tubelet patch-embedding module."""
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """Prune attention heads.

        Args:
            heads_to_prune: ``{layer_index: list_of_head_indices}``.

        Raises:
            NotImplementedError: if a targeted attention module has no
                ``prune_heads`` method. The ``VideoMAEAttention`` port does not
                define one, so this fires instead of an opaque AttributeError.
        """
        for layer, heads in heads_to_prune.items():
            attention = self.encoder.layer[layer].attention
            prune = getattr(attention, "prune_heads", None)
            if prune is None:
                raise NotImplementedError(
                    f"{type(attention).__name__} does not support head pruning"
                )
            prune(heads)

    def _prepare_head_mask(self, head_mask, dtype):
        """Normalize ``head_mask`` into one mask per layer.

        Each layer's mask must broadcast against attention probabilities of
        shape ``(batch, heads, seq, seq)``.

        * 1-D ``[num_heads]``: shared by every layer.
        * 2-D ``[num_hidden_layers, num_heads]``: one row per layer.

        Both become ``[num_hidden_layers, 1, num_heads, 1, 1]``. Tensors of any
        other rank are only cast to ``dtype``. Non-tensors (``None`` or a
        per-layer list) are returned unchanged.
        """
        if not isinstance(head_mask, torch.Tensor):
            return head_mask
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).expand(self.config.num_hidden_layers, -1)
        if head_mask.dim() == 2:
            # Without this reshape head_mask[i] has shape [num_heads] and would
            # broadcast over the last (key) axis instead of the head axis.
            head_mask = head_mask[:, None, :, None, None]
        return head_mask.to(dtype=dtype)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode a video batch.

        Args:
            pixel_values: Video batch for ``VideoMAEEmbeddings``.
            bool_masked_pos: Optional mask of positions to drop (pre-training).
            head_mask: 1.0 keeps a head. Accepts ``[num_heads]``,
                ``[num_hidden_layers, num_heads]``, or a per-layer list
                (see ``_prepare_head_mask``).
            output_attentions / output_hidden_states: When None, fall back to
                the config attribute of the same name, defaulting to False.
            return_dict: When None, falls back to ``config.return_dict``,
                defaulting to True.

        Returns:
            ``BaseModelOutput``, or, with ``return_dict=False``, a tuple
            ``(sequence_output, *non-None optional outputs)``.
        """
        if output_attentions is None:
            output_attentions = getattr(self.config, "output_attentions", False)
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)
        if return_dict is None:
            # Previously None fell through to the tuple path, which crashed.
            return_dict = getattr(self.config, "return_dict", True)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos)
        head_mask = self._prepare_head_mask(head_mask, embedding_output.dtype)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = encoder_outputs.last_hidden_state
        if self.layernorm is not None:
            sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            # Build the tuple from attributes rather than indexing encoder_outputs.
            extras = (encoder_outputs.hidden_states, encoder_outputs.attentions)
            return (sequence_output,) + tuple(v for v in extras if v is not None)

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )



class VideoMAEForVideoClassification(nn.Module):
    """VideoMAE backbone with a linear classification/regression head.

    The head's input is either the mean over all tokens followed by ``fc_norm``
    (``config.use_mean_pooling``) or the first token of the encoder output.
    """

    def __init__(self, config: VideoMAEConfig):
        super().__init__()
        self.config = config  # Store config

        if config.num_labels is None:
            raise ValueError("config.num_labels must be set, e.g., from len(config.id2label).")
        self.num_labels = config.num_labels

        # Backbone. The attribute name ``videomae`` presumably mirrors upstream
        # checkpoint key prefixes — confirm before loading external weights.
        self.videomae = VideoMAEModel(config)

        # num_labels == 0 makes the head a pass-through (pooled features are returned as logits).
        self.classifier = nn.Linear(config.hidden_size, self.num_labels) if self.num_labels > 0 else nn.Identity()
        if config.use_mean_pooling:
            self.fc_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.fc_norm = None
        self._init_weights()

    def _init_weights(self, module: Optional[nn.Module] = None):
        """Initialize ``module`` and all of its descendants (default: the whole model).

        Linear/Conv3d weights are drawn from N(0, config.initializer_range) and
        their biases are zeroed. LayerNorm weights are set to 1 and biases to 0.
        Other parameters (e.g. attention ``q_bias``/``v_bias``) keep their
        constructed values.
        """
        if module is None:
            module = self
        std = self.config.initializer_range
        # ``modules()`` yields ``module`` itself and then its descendants, depth-first.
        for submodule in module.modules():
            if isinstance(submodule, (nn.Linear, nn.Conv3d)):
                submodule.weight.data.normal_(mean=0.0, std=std)
                if submodule.bias is not None:
                    submodule.bias.data.zero_()
            elif isinstance(submodule, nn.LayerNorm):
                # LayerNorms built with elementwise_affine=False have no weight/bias.
                if submodule.bias is not None:
                    submodule.bias.data.zero_()
                if submodule.weight is not None:
                    submodule.weight.data.fill_(1.0)

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the loss for ``labels`` according to ``config.problem_type``.

        When ``problem_type`` is None, it is inferred as follows:

        * ``num_labels == 1`` gives regression;
        * integer labels with more than one class give single-label
          classification;
        * anything else gives multi-label classification.

        Raises:
            ValueError: if ``problem_type`` is set to an unrecognized value
                (previously this silently produced ``loss=None``).
        """
        problem_type = self.config.problem_type
        if problem_type is None:
            if self.num_labels == 1:
                problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            loss_fct = nn.MSELoss()
            if self.num_labels == 1:
                return loss_fct(logits.squeeze(), labels.squeeze().float())
            return loss_fct(logits, labels.float())
        if problem_type == "single_label_classification":
            return nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            return nn.BCEWithLogitsLoss()(logits, labels.float())
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")

    def forward(
        self, pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[List[Optional[torch.Tensor]]] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        """Classify a video batch.

        Returns:
            ``ImageClassifierOutput`` with ``logits``, the loss when ``labels``
            are given, and the backbone's hidden states/attentions if requested.
        """
        outputs = self.videomae(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,  # attribute access below relies on this
        )
        sequence_output = outputs.last_hidden_state

        if self.fc_norm is not None:
            # Mean-pool over the token axis, then normalize.
            features = self.fc_norm(sequence_output.mean(dim=1))
        else:
            # The embeddings in this file prepend no CLS token, so this is the first tubelet token.
            features = sequence_output[:, 0]

        logits = self.classifier(features)
        loss = self._compute_loss(logits, labels) if labels is not None else None

        return ImageClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
        )

from safetensors.torch import load_file
# Hugging Face-style config.json: hidden size 768, 12 layers, 12 heads, 10 action
# labels. The decoder_* and norm_pix_loss keys are stored on the config object
# but not read by the encoder-only classes in this notebook.
config_json_str = """
{
    "architectures": ["VideoMAEForVideoClassification"],
    "attention_probs_dropout_prob": 0.0,
    "decoder_hidden_size": 384,
    "decoder_intermediate_size": 1536,
    "decoder_num_attention_heads": 6,
    "decoder_num_hidden_layers": 4,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 768,
    "id2label": {
        "0": "drink water", "1": "brush teeth", "2": "pick up", "3": "reading", "4": "writing",
        "5": "cheer up", "6": "jump up", "7": "phone call", "8": "taking a selfie", "9": "salute"
    },
    "image_size": 224,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "label2id": {
        "drink water": 0, "brush teeth": 1, "pick up": 2, "reading": 3, "writing": 4,
        "cheer up": 5, "jump up": 6, "phone call": 7, "taking a selfie": 8, "salute": 9
    },
    "layer_norm_eps": 1e-12,
    "model_type": "videomae",
    "norm_pix_loss": false,
    "num_attention_heads": 12,
    "num_channels": 3,
    "num_frames": 16,
    "num_hidden_layers": 12,
    "patch_size": 16,
    "qkv_bias": true,
    "torch_dtype": "float32",
    "transformers_version": "4.21.0.dev0",
    "tubelet_size": 2,
    "use_mean_pooling": true
}
"""
config_dict = json.loads(config_json_str)

# For a quick smoke test, shrink the model by uncommenting these:
# config_dict["num_hidden_layers"] = 2
# config_dict["hidden_size"] = 128
# config_dict["intermediate_size"] = 256
# config_dict["num_attention_heads"] = 4

mae_config = VideoMAEConfig(**config_dict)

# Prefer the GPU when one is visible.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

batch_size = 2
# Model expects (B, NumFrames, NumChannels, H, W). image_size may be an int
# (square frames) or an (H, W) pair.
if isinstance(mae_config.image_size, collections.abc.Iterable):
    frame_height, frame_width = mae_config.image_size
else:
    frame_height = frame_width = mae_config.image_size
video_input = torch.randn(batch_size, mae_config.num_frames, mae_config.num_channels,
                          frame_height, frame_width)
# dummy_labels = torch.randint(0, mae_config.num_labels, (batch_size,))

print(f"Input video shape: {video_input.shape}")
print(f"Number of labels from config: {mae_config.num_labels}")

model = VideoMAEForVideoClassification(mae_config)  # Instantiate the model (random weights)

model = model.to(device)
# Inference: disable dropout and other train-only behavior explicitly.
model.eval()
video_input = video_input.to(device)
with torch.no_grad():
    outputs = model(video_input)

print(f"\nOutput type: {type(outputs)}")
print(outputs)

if outputs.loss is not None:
    print(f"Calculated loss: {outputs.loss.item()}")

if outputs.hidden_states:
    print(f"Number of hidden_states layers: {len(outputs.hidden_states)}")
    print(f"Shape of first hidden_state: {outputs.hidden_states[0].shape}")
    print(f"Shape of last hidden_state (from encoder): {outputs.hidden_states[-1].shape}")
if outputs.attentions:
    print(f"Number of attention layers: {len(outputs.attentions)}")
    # Per-layer attention probabilities are (B, heads, N, N).
    print(f"Shape of first attention weights: {outputs.attentions[0].shape}")


total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
