In [None]:
!pip install -q mamba-ssm causal-conv1d>=1.2.0
!pip install -q vllm>=0.5.5
!pip install -q accelerate
!pip install -q transformers mamba-ssm
!pip install -q convokit

In [None]:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch
from torch.utils.data import Dataset
import random
import numpy as np
import os
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from convokit import Corpus, download

**MAMBA MENTALITY**

In addition to an MLP head that predicts what is actually said next, the model also predicts whether the current speaker keeps talking, who speaks next, and whether the scene/movie ends.

In [None]:
# Load models
# The "-hf" checkpoints are the Transformers-format conversions of the Mamba
# weights. The plain "state-spaces/mamba-370m" repository targets the
# `mamba_ssm` loader and its config carries no `model_type`, so
# AutoModelForCausalLM cannot resolve it.
MAMBA_CHECKPOINT = "state-spaces/mamba-370m-hf"

# Two independent copies: one for the full script, one for a single character.
model_full = AutoModelForCausalLM.from_pretrained(MAMBA_CHECKPOINT).to("cuda")
model_character = AutoModelForCausalLM.from_pretrained(MAMBA_CHECKPOINT).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Padded tokenisation needs a pad token; fall back to EOS when none is set.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Combined model with speaker transition and end-of-movie classifier
class CombinedModel(nn.Module):
    """Fuse two causal language models and predict tokens plus dialogue flow.

    The logits of ``model_full`` and the (masked) logits of ``model_character``
    are concatenated, projected, and passed through a shared hidden layer that
    feeds four heads:

    * token head   - one prediction per position, ``(batch, seq, vocab)``
    * turn head    - one prediction per sequence, ``(batch, 2)``
    * speaker head - one prediction per sequence, ``(batch, 3)``
    * end head     - one prediction per sequence, ``(batch, 2)``

    Args:
        model_full: module whose call returns a mapping with a ``"logits"`` entry.
        model_character: same contract as ``model_full``.
        embedding_dim: the projected feature width is ``embedding_dim * 2``.
        hidden_dim: width of the shared hidden layer.
        combined_logits_dim: last-dimension size of the concatenated logits.
            When omitted it is read from ``config.vocab_size`` of both models.

    Raises:
        ValueError: if ``combined_logits_dim`` is omitted and a model exposes
            no ``config.vocab_size``.
    """

    def __init__(self, model_full, model_character, embedding_dim=768, hidden_dim=512,
                 combined_logits_dim=None):
        super().__init__()
        self.model_full = model_full
        self.model_character = model_character

        if combined_logits_dim is None:
            combined_logits_dim = (
                self._logits_width(model_full) + self._logits_width(model_character)
            )

        # The inputs here are float logits, so this has to be a learned linear
        # projection: nn.Embedding is an integer-index lookup table and raises
        # on floating point input.
        self.embedding = nn.Linear(combined_logits_dim, embedding_dim * 2)

        # MLP layers for text prediction
        self.fc1 = nn.Linear(embedding_dim * 2, hidden_dim)
        # Output layer for next token prediction (`tokenizer` is the notebook global).
        # NOTE(review): `vocab_size` may not count tokens added on top of the base
        # vocabulary - confirm target ids never exceed it.
        self.fc2 = nn.Linear(hidden_dim, tokenizer.vocab_size)

        # Turn-taking prediction (binary classification: 0 = continue, 1 = stop)
        self.turn_fc = nn.Linear(hidden_dim, 2)

        # Speaker transition prediction (categorical classification for speaker_id)
        self.speaker_fc = nn.Linear(hidden_dim, 3)  # Assume 3 possible speakers as an example

        # End-of-movie prediction (binary classification: 0 = not end, 1 = end)
        self.end_fc = nn.Linear(hidden_dim, 2)  # 2 classes: continue or end movie

    @staticmethod
    def _logits_width(model):
        """Return ``model.config.vocab_size`` or raise if it is unavailable."""
        width = getattr(getattr(model, "config", None), "vocab_size", None)
        if width is None:
            raise ValueError(
                "Cannot infer the logits width from the model config; "
                "pass combined_logits_dim explicitly."
            )
        return width

    @staticmethod
    def _pool_sequence(features, position_weights):
        """Weighted mean of ``features`` over the sequence dimension.

        Rows whose weights sum to zero fall back to a plain mean so the
        division never produces NaN.
        """
        weights = position_weights.expand(features.size(0), features.size(1), 1)
        totals = weights.sum(dim=1, keepdim=True)
        weights = torch.where(totals > 0, weights, torch.ones_like(weights))
        return (features * weights).sum(dim=1) / weights.sum(dim=1)

    def forward(self, input_full, input_character, character_mask=None):
        """Run both language models and all four heads.

        Args:
            input_full: token ids for ``model_full``.
            input_character: token ids for ``model_character``; the two logits
                tensors must share batch and sequence dimensions because they
                are concatenated along the last one.
            character_mask: optional ``(batch, seq)`` or ``(batch, 1)`` tensor;
                positions with value 0 have the character logits zeroed.
                ``None`` keeps every position.

        Returns:
            Tuple ``(token_predictions, turn_predictions, speaker_predictions,
            end_predictions)`` with the shapes listed on the class.
        """
        # Pass through models
        full_output = self.model_full(input_full)["logits"]
        character_output = self.model_character(input_character)["logits"]

        # Apply character mask to the character logits; the trailing singleton
        # dimension broadcasts the mask across the vocabulary axis.
        if character_mask is None:
            position_weights = torch.ones_like(character_output[..., :1])
        else:
            position_weights = character_mask.unsqueeze(-1).to(character_output.dtype)
        masked_character_output = character_output * position_weights  # zeros out other tokens

        # Combine the outputs (concatenate)
        combined_output = torch.cat((full_output, masked_character_output), dim=-1)

        # Project the concatenated logits down to the MLP input width
        embedded = self.embedding(combined_output)

        # Pass through MLP for token prediction
        x = torch.relu(self.fc1(embedded))
        token_predictions = self.fc2(x)

        # Turn / speaker / end labels exist once per dialogue, so their heads
        # read one pooled feature vector per sequence, averaged over the
        # positions the mask selects.
        pooled = self._pool_sequence(x, position_weights)

        # Predict speech turn-taking: whether character stops talking
        turn_predictions = self.turn_fc(pooled)

        # Predict speaker transition: who starts speaking next
        speaker_predictions = self.speaker_fc(pooled)

        # Predict end of movie: whether the conversation ends
        end_predictions = self.end_fc(pooled)

        return token_predictions, turn_predictions, speaker_predictions, end_predictions


def combined_loss(token_predictions, turn_predictions, speaker_predictions, end_predictions, target_tokens, target_turns, target_speakers, target_end):
    """Sum the cross-entropy losses of the four prediction heads.

    Each prediction tensor is flattened to ``(-1, num_classes)`` and each
    target tensor to ``(-1,)`` before the loss is computed. The turn and end
    heads are flattened with a fixed class count of 2; the token and speaker
    heads use the size of their own last dimension.

    Returns:
        Scalar tensor: token + turn + speaker + end loss.
    """
    cross_entropy = nn.CrossEntropyLoss()

    # (logits, number of classes, labels) for every head, in summation order.
    heads = (
        (token_predictions, token_predictions.size(-1), target_tokens),
        (turn_predictions, 2, target_turns),
        (speaker_predictions, speaker_predictions.size(-1), target_speakers),
        (end_predictions, 2, target_end),
    )
    per_head = [
        cross_entropy(logits.view(-1, num_classes), labels.view(-1))
        for logits, num_classes, labels in heads
    ]

    return per_head[0] + per_head[1] + per_head[2] + per_head[3]


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

class MovieDataset(Dataset):
    """Dialogue lines with per-line speaker, turn and end-of-movie labels.

    Args:
        dialogues: sequence of utterance strings.
        speakers: integer speaker id for each utterance.
        turns: turn-taking label for each utterance.
        end: end-of-movie label for each utterance.
        tokenizer: callable invoked as ``tokenizer(text, return_tensors="pt",
            truncation=True, padding="max_length", max_length=...)`` that
            returns a mapping with ``'input_ids'`` and ``'attention_mask'``
            tensors carrying a leading batch dimension of 1.
        max_length: every item is truncated/padded to exactly this many tokens.

    Raises:
        ValueError: if the four parallel sequences differ in length.
    """

    def __init__(self, dialogues, speakers, turns, end, tokenizer, max_length=512):
        sizes = {len(dialogues), len(speakers), len(turns), len(end)}
        if len(sizes) != 1:
            raise ValueError(
                "dialogues, speakers, turns and end must all have the same length"
            )
        self.dialogues = dialogues
        self.speakers = speakers
        self.turns = turns
        self.end = end
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of dialogue lines."""
        return len(self.dialogues)

    def __getitem__(self, idx):
        """Return the tokenised utterance at ``idx`` together with its labels."""
        dialogue = self.dialogues[idx]
        speaker = self.speakers[idx]
        turn = self.turns[idx]
        end = self.end[idx]

        # Pad to a fixed length. `padding=True` pads to the longest sequence in
        # the call, which for a single string pads nothing, so items would have
        # different lengths and the default DataLoader collate could not stack them.
        inputs = self.tokenizer(
            dialogue,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

        input_ids = inputs['input_ids'].squeeze(0)  # Remove the batch dimension
        attention_mask = inputs['attention_mask'].squeeze(0)  # Same size as input_ids

        # Each utterance has a single speaker, so every token carries that id.
        # If needed, this can be refined by tokenizing per sentence and then
        # mapping speaker labels.
        token_speaker_ids = torch.full_like(input_ids, speaker)

        # Tokens spoken by the target character; padding positions are excluded
        # so they are never attributed to the character.
        character_mask = (token_speaker_ids == speaker).long() * attention_mask.long()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'speaker': speaker,
            'turn': turn,
            'end': end,
            'character_mask': character_mask  # Add the character mask to the batch
        }

# Example dataset and dataloader
dialogues = ["Hello, how are you?", "I'm good, thanks! How about you?"]
speakers = [0, 1]  # Assume character 0 speaks first, then character 1
turns = [0, 1]  # Speaker 0 continues, then speaker 1 starts
end = [0, 1]  # Movie continues after first dialogue, ends after second dialogue

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# The dataset asks the tokenizer to pad, which needs a pad token; fall back to
# EOS when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

dataset = MovieDataset(dialogues, speakers, turns, end, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# The training cell iterates `train_dataloader` and `val_dataloader`; define
# them here so the notebook runs top to bottom on a fresh kernel.
# TODO: this toy example reuses the same two lines for both splits - replace
# with a real train/validation split before reading anything into the metrics.
train_dataloader = dataloader
val_dataloader = DataLoader(dataset, batch_size=2, shuffle=False)



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Initialize model
model = CombinedModel(model_full, model_character).to("cuda")

# Optimizer and losses
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Token targets at padded positions are replaced by this value and skipped.
IGNORE_INDEX = -100
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
# The classification heads get their own criterion: reusing one that ignores the
# pad token id would silently drop every class label equal to that id.
label_criterion = nn.CrossEntropyLoss()


def compute_batch_losses(batch):
    """Run the model on one batch and return ``(total_loss, token_loss)``."""
    # Move to GPU
    input_ids = batch['input_ids'].to("cuda")
    attention_mask = batch['attention_mask'].to("cuda")
    speaker_labels = batch['speaker'].to("cuda")
    turn_labels = batch['turn'].to("cuda")
    end_labels = batch['end'].to("cuda")

    # Token-level mask prepared by the dataset (not the speaker id itself)
    character_mask = batch['character_mask'].to("cuda")

    # Forward pass
    token_preds, turn_preds, speaker_preds, end_preds = model(input_ids, input_ids, character_mask)

    # Next-token objective: the prediction at position t is scored against the
    # token at t + 1, and padded target positions are ignored.
    next_tokens = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, IGNORE_INDEX)
    token_loss = criterion(
        token_preds[:, :-1].reshape(-1, token_preds.size(-1)),
        next_tokens.reshape(-1),
    )
    turn_loss = label_criterion(turn_preds.view(-1, 2), turn_labels.view(-1))

    # The next speaker is only scored where the turn actually changes.
    turn_mask = (turn_labels == 1)
    if turn_mask.sum() > 0:
        speaker_loss = label_criterion(
            speaker_preds[turn_mask].view(-1, speaker_preds.size(-1)),
            speaker_labels[turn_mask].view(-1)
        )
    else:
        speaker_loss = 0.0

    end_loss = label_criterion(end_preds.view(-1, 2), end_labels.view(-1))

    # Total loss
    return token_loss + turn_loss + speaker_loss + end_loss, token_loss


def run_epoch(loader, training):
    """Iterate ``loader`` once and return ``(avg_loss, perplexity)``.

    Parameters are updated only when ``training`` is true. Perplexity is the
    exponential of the mean token loss over the epoch; averaging per-batch
    perplexities instead would overstate it.
    """
    model.train(training)
    total_loss = 0.0
    total_token_loss = 0.0

    with torch.set_grad_enabled(training):
        for batch in loader:
            loss, token_loss = compute_batch_losses(batch)

            if training:
                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            total_token_loss += token_loss.item()

    num_batches = max(len(loader), 1)  # guard against an empty loader
    avg_loss = total_loss / num_batches
    avg_token_loss = total_token_loss / num_batches
    return avg_loss, torch.exp(torch.tensor(avg_token_loss)).item()


# To track metrics
train_losses, train_ppls = [], []
val_losses, val_ppls = [], []

epochs = 3  # or more
for epoch in range(epochs):
    avg_train_loss, avg_train_ppl = run_epoch(train_dataloader, training=True)
    train_losses.append(avg_train_loss)
    train_ppls.append(avg_train_ppl)

    # Validation Perplexity scores
    avg_val_loss, avg_val_ppl = run_epoch(val_dataloader, training=False)
    val_losses.append(avg_val_loss)
    val_ppls.append(avg_val_ppl)

    print(f"Epoch {epoch+1}/{epochs} | "
          f"Train Loss: {avg_train_loss:.4f} | Train PPL: {avg_train_ppl:.2f} || "
          f"Val Loss: {avg_val_loss:.4f} | Val PPL: {avg_val_ppl:.2f}")

In [None]:
# Perplexity per epoch for the training and validation splits
fig, ax = plt.subplots(figsize=(10, 5))
for series, series_label in ((train_ppls, "Train Perplexity"), (val_ppls, "Val Perplexity")):
    ax.plot(series, label=series_label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Perplexity")
ax.set_title("Train vs Validation Perplexity")
ax.legend()
ax.grid(True)
plt.show()