#### Use this notebook to understand how we add positional embeddings to the frames

In [None]:
import math
import torch
import torch.nn as nn

In [None]:
class TimeEmbedding(nn.Module):
    """
    Convert a time step into an embedding vector.

    The timestep is first mapped to a fixed sinusoidal encoding with ``dim``
    columns (see ``sinusoidal_embedding``) and then passed through a small
    MLP: dim -> 4*dim -> dim.
    """

    def __init__(
        self,
        dim: int = 1280,
        *,
        max_period: float = 10000.0,
        verbose: bool = True,
    ):
        """
        Converts timestep t (integer) → embedding (dim,)

        Args:
            dim: Size of the sinusoidal encoding and of the output embedding.
            max_period: Controls the frequency range of the sinusoids; the
                frequencies are spaced geometrically from 1 down to
                ``1 / max_period``. The default 10000 reproduces the
                original hard-coded value.
            verbose: When True (the default, as this notebook is meant for
                inspecting shapes), print intermediate tensor shapes.
        """
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.verbose = verbose

        # MLP: dim → 4*dim → dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def _log(self, message: str) -> None:
        """Print ``message`` only when the module was built with ``verbose=True``."""
        if self.verbose:
            print(message)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        t: (B,) timestep values (integer or float); a 0-d tensor is treated
           as a batch of one.
        Returns: (B, dim) embeddings
        """
        pos_enc = self.sinusoidal_embedding(t)  # Shape (B, dim)
        self._log(f"\n The positional encodings are of shape \n {pos_enc.shape}")
        return self.mlp(pos_enc)

    def sinusoidal_embedding(self, t: torch.Tensor) -> torch.Tensor:
        """
        t: (B,) → (B, dim) using sin/cos

        With ``half = dim // 2`` and ``i = 0 .. half-1``::

            freq[i] = max_period ** (-i / max(half - 1, 1))
            emb     = [sin(t*freq[0]), ..., sin(t*freq[half-1]),
                       cos(t*freq[0]), ..., cos(t*freq[half-1])]

        All sines come first, followed by all cosines (they are NOT
        interleaved). If ``dim`` is odd, one zero column is appended so the
        result always has exactly ``dim`` columns, matching the MLP input.
        """
        device = t.device
        half_dim = self.dim // 2
        self._log(f"\n The original dimension is {self.dim} and \n the half dimension is {half_dim}")

        # Promote a scalar timestep to a batch of one so the outer product works.
        if t.dim() == 0:
            t = t[None]

        # Frequencies: geometric progression from 1 down to 1 / max_period.
        # max(..., 1) avoids a ZeroDivisionError when half_dim == 1 (dim 2 or 3).
        scale = math.log(self.max_period) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -scale)
        self._log(f"\n The embeddings are of shape \n {emb.shape}")

        # Outer product. Casting t to the frequency dtype keeps integer and
        # float64 timesteps from producing a dtype the MLP weights can't take.
        emb = t.to(emb.dtype)[:, None] * emb[None, :]  # (B, half_dim)
        self._log(f"\n After taking the outer product of timesteps {t.shape}, the embeddings are of shape \n {emb.shape}")

        # Sin and cos
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # (B, 2 * half_dim)
        if self.dim % 2 == 1:
            # Odd dim: pad a single zero column so the output is exactly (B, dim).
            emb = nn.functional.pad(emb, (0, 1))
        self._log(f"\n After taking the sine and cosine of the embeddings, the shape of the embedding tensor is \n {emb.shape}")

        return emb

In [None]:
# Demo: embed a random batch of timesteps and check the output shape.
batch_size, dim = 4, 1280

# Build the module before sampling t: parameter initialisation draws from the
# global RNG, so this ordering determines which timesteps get sampled.
time_emb = TimeEmbedding(dim)
t = torch.randint(low=0, high=1000, size=(batch_size,))  # integer timesteps in [0, 1000)

emb = time_emb(t)

assert tuple(emb.shape) == (batch_size, dim)