##### Use this notebook to understand how the Attention layers work in the video generation module

In [None]:
# Necessary imports
import torch
import torch.nn as nn
from einops import rearrange

In [None]:
# Spatial Attention
# Spatial attention applies the attention map across the height and width
# dimensions: every (h, w) position attends to every other position in the
# same feature map.
class SpatialAttention(nn.Module):
    """
    Spatial Attention

    Self-attention over the spatial grid of a single feature map. Each
    (h, w) location is treated as one token of size C, and no attention
    mask is used, so every location can attend to every other location.

    Args:
        channels: Token embedding size C (the channel dimension of the input).
            It must be divisible by ``num_heads``.
        num_heads: Number of attention heads (default 8).
    """

    def __init__(self, channels, num_heads: int = 8):
        super().__init__()
        # The embedding is split evenly across the heads, so C must divide cleanly.
        assert channels % num_heads == 0
        mha_config = dict(embed_dim=channels, num_heads=num_heads, batch_first=True)
        self.mha = nn.MultiheadAttention(**mha_config)

    def forward(self, x: torch.Tensor):
        """
        Run self-attention across all H*W positions, printing shapes along the way.

        x: (B, C, H, W)
        Returns a tensor of shape (B, C, H, W).
        """
        _, _, height, width = x.shape
        print(f"\n The inputs are of shape \n {x.shape}")
        # Flatten the grid into a token sequence: (B, C, H, W) -> (B, H*W, C)
        tokens = rearrange(x, "b c h w -> b (h w) c")
        print(f"\n The inputs to the attention are \n {tokens.shape}")
        attended, weights = self.mha(tokens, tokens, tokens)
        print(f"\n The output of the attention network is \n {attended.shape}")
        print(f"\n The attention weights are of shape \n {weights.shape}")
        # Restore the grid layout: (B, H*W, C) -> (B, C, H, W)
        restored = rearrange(attended, "b (h w) c -> b c h w", h=height, w=width)
        print(f"\n The final shape after Spatial attention is \n {restored.shape}")
        return restored

# Test the spatial attention code.
# A shape check needs no gradients, so run under no_grad. Otherwise the global
# `out` would keep an autograd graph (including the large H*W x H*W attention
# scores) alive for the rest of the notebook session.
B, C, H, W = 16, 512, 32, 32
x = torch.randn(B, C, H, W)
spatial_attn = SpatialAttention(C)
with torch.no_grad():
    out = spatial_attn(x)
# Spatial attention must return the same (B, C, H, W) layout it was given.
assert out.shape == (B, C, H, W)

In [None]:
class TemporalAttention(nn.Module):
    """
    Module for temporal attention.

    Self-attention along the time axis only. Every (batch, h, w) location
    becomes its own sequence of T tokens of size C. A position in one frame
    therefore attends only to the same position in the other frames.

    Input: (B, C, T, H, W)
    Output: (B, C, T, H, W)
    """

    def __init__(self, channels: int, num_heads: int = 8):
        """
        Temporal self-attention over T dimension.

        Args:
            channels: Token embedding size C (the channel dimension of the input).
            num_heads: Number of attention heads. nn.MultiheadAttention requires
                ``channels`` to be divisible by it.
        """
        super().__init__()
        mha_config = dict(embed_dim=channels, num_heads=num_heads, batch_first=True)
        self.mha = nn.MultiheadAttention(**mha_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply attention across frames, independently at every spatial location.

        x: (B, C, T, H, W)
        Returns a tensor of shape (B, C, T, H, W).
        """
        batch, _, _, height, _ = x.shape
        print(f"\n The shape of the inputs going into the temporal attention module is \n {x.shape}")
        # Fold batch and space together so each pixel is its own sequence:
        # (B, C, T, H, W) -> (B*H*W, T, C)
        sequences = rearrange(x, "b c t h w -> (b h w) t c")
        print(f"\n Prior to actual application of the attention mechanism, the shape of the inputs is \n {sequences.shape}")
        attended, weights = self.mha(sequences, sequences, sequences)
        print(f"\n Right after the attention mechanism, the outputs are of shape \n {attended.shape}")
        print(f"\n The attention weights are of shape \n {weights.shape}")
        # Unfold back to the video layout. einops infers w from what remains of
        # the (b h w) axis once b and h are given.
        restored = rearrange(attended, "(b h w) t c -> b c t h w", b=batch, h=height)
        print(f"\n The outputs are reshaped to \n {restored.shape}")
        return restored
# Test the temporal attention code.
# Use standard-normal features, as in the spatial test, rather than raw 0-255
# integers. Unnormalized inputs that large push the attention logits to very
# large magnitudes, so softmax collapses to near one-hot weights and the demo
# stops being representative. No gradients are needed for a shape check,
# so skip building an autograd graph.
B, C, T, H, W = 2, 512, 16, 32, 32
x = torch.randn(B, C, T, H, W)
temp_attention_module = TemporalAttention(C)
with torch.no_grad():
    x_temp = temp_attention_module(x)
assert x_temp.shape == (B, C, T, H, W)