In [None]:
import torch
import torch.nn as nn
import math
import IPython.display as ipd

from transformers import EncodecModel, AutoProcessor
import librosa

## Code Explanation: Autoregressive Model with Positional Encoding and Cross-Attention Fusion

### 1. **PositionalEncoding Class**
   This class handles the creation and application of sinusoidal positional embeddings. The primary function of positional encodings is to inject information about the relative or absolute position of the tokens in the sequence, as transformers are invariant to sequence order.

   - **`__init__`**: Initializes the positional encoding matrix.
     - `d_model`: Dimension of the embedding space (the number of features each token will have).
     - `max_len`: Maximum sequence length supported (this is set to 5000 by default).
     - We create a matrix `pe` filled with zeros of shape `(max_len, d_model)`.
     - `position`: A tensor representing positions (0, 1, 2, ..., max_len).
     - `div_term`: A term used for scaling based on the dimensions of the model.
     - For even positions, we assign sine values (`sin(position * div_term)`), and for odd positions, we assign cosine values (`cos(position * div_term)`).
     - The positional encoding matrix is registered as a buffer, meaning it will be stored in the model but won’t require gradients.
  
   - **`forward`**: Adds the positional encodings to the input tensor `x` and returns the result. The input tensor is assumed to be of shape `(batch_size, seq_len, d_model)`. The positional encodings are added to the input embeddings to make the model aware of token positions.

### 2. **AutoregressiveBranch Class**
   This class represents a branch of the model that processes the input sequence and produces an output based on short-term or long-term context. It uses the transformer encoder layer for sequence modeling.

   - **`__init__`**: Initializes the components for the branch:
     - `input_dim`: Dimension of the input features (e.g., the size of the video features).
     - `hidden_dim`: Dimension of the hidden representations.
     - `num_layers`: Number of transformer encoder layers.
     - `window_size`: Defines how many past time steps are considered when processing the input.
     - The branch uses the `PositionalEncoding` to inject positional information into the input.
     - `nn.TransformerEncoderLayer`: Defines the transformer encoder layer.
     - `nn.TransformerEncoder`: A stack of transformer encoder layers.
     - `nn.Linear`: A fully connected layer to transform the output from the transformer into the desired `hidden_dim`.

   - **`forward`**: 
     - Takes the input `x`, slices it to use only the last `window_size` tokens (to model short-term memory).
     - Applies positional encoding to the input.
     - Passes the input through the transformer encoder.
     - Uses the last hidden state (the final time step) and applies a fully connected layer to produce the output.

### 3. **CrossAttentionFusion Class**
   This class is responsible for fusing short-term and long-term context using a cross-attention mechanism. Cross-attention allows us to focus on the most relevant parts of the input sequence (from the short-term and long-term outputs) when producing the final representation.

   - **`__init__`**: Initializes a multihead attention mechanism (`nn.MultiheadAttention`) with the specified dimensions and number of heads.
     - The fully connected layer `fc` is used to project the output of the attention mechanism into the desired hidden dimension.
   
   - **`forward`**: 
     - Takes the short-term and long-term outputs, applies multi-head attention (short-term is the query, long-term is the key and value).
     - The output from the attention mechanism is passed through a fully connected layer for further processing.

### 4. **MusicDecoder Class**
   This class handles the final step, where the model decodes the fused features into a sequence of music tokens.

   - **`__init__`**: Initializes the transformer decoder layer and the fully connected output layer.
     - `input_dim`: The dimensionality of the fused input.
     - `output_dim`: The number of possible output tokens (this is equivalent to the size of the music token vocabulary).
     - `num_layers`: The number of transformer decoder layers.
     - `num_heads`: Number of heads in multi-head attention.
     - `hidden_dim`: The dimension of the hidden layers.
   
   - **`forward`**:
     - The input is permuted to shape `(seq_len, batch_size, input_dim)` for compatibility with the transformer decoder.
     - The transformer decoder processes the input sequence and outputs a sequence of decoded tokens.
     - The final output is passed through a fully connected layer to map the decoded sequence to the music token vocabulary size.

### 5. **AutoregressiveModel Class**
   This is the main model, combining all the previously defined components. It takes the video input and processes it through short-term and long-term branches, fuses the results, and decodes it into music tokens.

   - **`__init__`**: Initializes the short-term and long-term branches, the fusion mechanism, and the music decoder.
     - The short-term and long-term branches are responsible for capturing short-term and long-term dependencies from the input video sequence.
     - The fusion mechanism combines the results from these two branches using cross-attention.
     - The music decoder generates a sequence of music tokens from the fused features.
   
   - **`forward`**:
     - Takes the video input and processes it through the short-term and long-term branches.
     - The results are fused using the cross-attention mechanism.
     - The fused features are passed to the music decoder to generate a sequence of music tokens.
     - The final output is a sequence of music tokens with shape `[batch_size, music_seq_len, codebook_vocab]`.

### Final Architecture Overview:
- The model consists of multiple components, each specialized for a specific task:
  - **Positional Encoding**: Adds positional information to the input.
  - **Autoregressive Branch**: Processes the input using transformer encoders to capture temporal dependencies.
  - **Cross-Attention Fusion**: Combines the outputs from short-term and long-term branches using attention.
  - **Music Decoder**: Decodes the fused output into a sequence of music tokens.

This architecture allows the model to capture both short-term and long-term dependencies in the input (video features) and generate a sequence of music tokens in an autoregressive manner.


In [97]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # pe: Shape = (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # position: Shape = (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # div_term: Shape = (d_model/2,)
        pe[:, 0::2] = torch.sin(position * div_term)  # Shape = (max_len, d_model/2)
        pe[:, 1::2] = torch.cos(position * div_term)  # Shape = (max_len, d_model/2)
        self.register_buffer('pe', pe.unsqueeze(0))  # pe: Shape = (1, max_len, d_model)
    
    def forward(self, x):
        # x: Shape = (batch_size, seq_len, d_model)
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]  # Output: Shape = (batch_size, seq_len, d_model)

    

class TransformerEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_layers):
        super(TransformerEncoder, self).__init__()
        
        # Define the transformer encoder layers
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        
        # Linear projection to map input features to transformer input size (hidden_dim)
        self.input_projection = nn.Linear(input_dim, hidden_dim)
        
    def forward(self, x):
        """
        x: Input shape [batch_size, window_size, input_dim]
        The shape needs to be (batch_size, seq_len, feature_dim) for transformer encoder.
        """
        # Project input features to hidden_dim
        x = self.input_projection(x)  # Shape: [batch_size, window_size, hidden_dim]
        
        # Pass through the transformer encoder (transposed to match shape requirements)
        x = x.permute(1, 0, 2)  # Shape: [window_size, batch_size, hidden_dim]
        x = self.transformer_encoder(x)
        x = x.permute(1, 0, 2)  # Revert to [batch_size, window_size, hidden_dim]
        
        return x

    
class SlidingWindowMusicPredictor(nn.Module):
    def __init__(self,
                 vid_input_dim,
                 window_size,         # Length of the sliding window Ls
                 overlap,             # Overlap between windows O
                 hidden_dim,          # Hidden dimension for the transformer
                 num_heads,           # Number of attention heads for the Transformer
                 num_layers):         # Number of layers in the Transformer
        super(SlidingWindowMusicPredictor, self).__init__()
        
        self.window_size = window_size
        self.overlap = overlap
        
        # Transformer Encoder for long-term dependencies
        self.long_term_model = TransformerEncoder(input_dim=1024, hidden_dim=hidden_dim, 
                                                  num_heads=num_heads, num_layers=num_layers)
                
    def forward(self, visual_features):
        """
        visual_features shape: [batch_size, seq_len, feature_dim]
        - batch_size: Number of videos in a batch.
        - seq_len: Number of frames in the video (sequence length).
        - feature_dim: Number of features for each frame (e.g., output dimension of the visual encoder).
        """
        batch_size, seq_len, feature_dim = visual_features.shape
        
        # Initialize t (window start position)
        t = 0
        predictions = []
        
        while t + self.window_size <= seq_len:
            # Step 2: Extract features within the current sliding window [t, t + Ls]
            window_features = visual_features[:, t:t + self.window_size, :]
            
            # Step 3: Capture long-term dependencies using the Transformer Encoder
            # Pass the window features through the long-term Transformer encoder model
            transformer_output = self.long_term_model(window_features)  # Shape: [batch_size, window_size, hidden_dim]
            
            # Optionally, use the last output of the transformer (or the average)
            # To summarize the window output into a single vector, use the last token or pooling
            predictions.append(transformer_output[:, -1, :])  # Using the last token in the window
        
            # Move the window forward by setting t = t + Ls - O
            t = t + self.window_size - self.overlap
        
        # Concatenate all predictions and take the mean or last token
        # The output will have shape [batch_size, hidden_dim]
        all_predictions = torch.stack(predictions, dim=1)  # Shape: [batch_size, num_windows, hidden_dim]
        
        # Optionally, aggregate across all windows (mean, max, or just the last one)
        # Here, we will take the mean across all windows as an example
        final_representation = all_predictions.mean(dim=1)  # Shape: [batch_size, hidden_dim]
        
        return final_representation


class CrossAttentionFusion(nn.Module):
    def __init__(self, vid_input_dim, hidden_dim, num_heads=8):
        super(CrossAttentionFusion, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim)
    
    def forward(self, short_term_output, long_term_output):
        # short_term_output: Shape = (batch_size, hidden_dim)
        # long_term_output: Shape = (batch_size, hidden_dim)
        attn_output, _ = self.attention(short_term_output.unsqueeze(1), long_term_output.unsqueeze(1), long_term_output.unsqueeze(1))
        # attn_output: Shape = (batch_size, 1, hidden_dim)
        fused_output = self.fc(attn_output.squeeze(1))  # fused_output: Shape = (batch_size, hidden_dim)
        return fused_output


class MusicDecoder(nn.Module):
    def __init__(self, input_dim, output_dim, num_layers=2, num_heads=8, hidden_dim=512):
        super(MusicDecoder, self).__init__()
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=input_dim, nhead=num_heads)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        # x: Shape = (batch_size, music_seq_len, input_dim)
        x = x.permute(1, 0, 2)  # Shape = (music_seq_len, batch_size, input_dim) [for transformer]
        decoded_output = self.transformer_decoder(x, x)  # decoded_output: Shape = (music_seq_len, batch_size, input_dim)
        output_tokens = self.fc(decoded_output.permute(1, 0, 2))  # output_tokens: Shape = (batch_size, music_seq_len, output_dim)
        return output_tokens


class AutoregressiveModel(nn.Module):
    def __init__(self, vid_input_dim=1024, music_seq_len=1500, codebook_vocab=512, hidden_dim=512, num_layers=2, short_window=1, long_window=5, num_heads=8):
        super(AutoregressiveModel, self).__init__()
        self.short_term_branch = SlidingWindowMusicPredictor(vid_input_dim, short_window, overlap=0, \
                 hidden_dim=hidden_dim, num_heads=num_heads,num_layers=num_layers)
        
        self.long_term_branch = SlidingWindowMusicPredictor(vid_input_dim, long_window, overlap=0, hidden_dim=hidden_dim, num_heads=num_heads,num_layers=num_layers)
        self.fusion = CrossAttentionFusion(hidden_dim, hidden_dim, num_heads)
        self.music_decoder = MusicDecoder(hidden_dim, codebook_vocab, num_layers, num_heads, hidden_dim)
        self.music_seq_len = music_seq_len
    
    def forward(self, x):
        # x: Shape = (batch_size, seq_len, vid_input_dim)
        short_term_out = self.short_term_branch(x)  # short_term_out: Shape = (batch_size, hidden_dim)
        long_term_out = self.long_term_branch(x)  # long_term_out: Shape = (batch_size, hidden_dim)
        fused_output = self.fusion(short_term_out, long_term_out)  # fused_output: Shape = (batch_size, hidden_dim)
        fused_output = fused_output.unsqueeze(1).repeat(1, self.music_seq_len, 1)  # Shape = (batch_size, music_seq_len, hidden_dim)
        music_output = self.music_decoder(fused_output)  # music_output: Shape = (batch_size, music_seq_len, codebook_vocab)
        return music_output


### Now Testing Encodec - we will try to GoT encoded tokens

In [98]:
# load the model + processor (for pre-processing the audio)
encodec = EncodecModel.from_pretrained("facebook/encodec_24khz")
codebook_size = encodec.config.codebook_size

processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")

file_path = "/Users/scottmerrill/Documents/UNC/MultiModal/VMR/notebooks/9bZkp7q19f0.mp3"
audio_sample, sample_rate = librosa.load(file_path, sr=processor.sampling_rate, duration=10)



In [99]:
# pre-process the audio inputs
inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")

# explicitly encode then decode the audio inputs
encoder_outputs = encodec.encode(inputs["input_values"], inputs["padding_mask"])

# pass the encoder outputs to the decoder to get the compressed waveform
audio_values = encodec.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]

In [20]:
ipd.Audio(audio_sample, rate=processor.sampling_rate)

In [21]:
ipd.Audio(audio_values.detach().numpy()[0][0], rate=processor.sampling_rate)

### These output sequence lengths will be the same if we keep the input size fixed, which will simplify things

In [25]:
encoder_outputs.audio_codes.flatten().shape # so we will have 1502 autoregressive music tokens to predict

torch.Size([1502])

In [100]:
batch_size = 4
video_feature_dim = 1024
video_seq_length = 10
music_seq_len = 1502
codebook_vocab = codebook_size  

model = AutoregressiveModel(vid_input_dim=video_feature_dim, music_seq_len=music_seq_len, codebook_vocab=codebook_vocab)

x = torch.randn(batch_size, video_seq_length, video_feature_dim)

predictions = model(x)
print("Output shape:", output.shape)

Output shape: torch.Size([4, 1502, 1024])


## now need a CE loss function

In [101]:
# simulate 1 example in a batch
targets = torch.stack([encoder_outputs.audio_codes.flatten() for x in range(batch_size)])

targets = torch.stack([torch.nn.functional.one_hot(targets[i], num_classes=codebook_size).float() for i in range(len(targets))])

In [102]:
criterion = nn.CrossEntropyLoss()
loss = criterion(predictions, targets)
loss

tensor(10.7369, grad_fn=<DivBackward1>)