In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TemporalAttention(nn.Module):
    """Attention pooling along the sequence axis.

    A small scorer (Linear -> Tanh -> Linear) produces one scalar per time
    step. The scalars are softmax-normalised over the sequence axis and used
    as weights for a sum of the hidden states.
    """

    def __init__(self, hidden_dim, attention_dim):
        super().__init__()
        scorer_layers = [
            nn.Linear(hidden_dim, attention_dim),
            nn.Tanh(),
            nn.Linear(attention_dim, 1),
        ]
        # Attribute name and layer order are kept so state-dict keys stay the same.
        self.attention = nn.Sequential(*scorer_layers)

    def forward(self, hidden_states):
        """Collapse ``hidden_states`` of shape (batch, seq_len, hidden_dim) over seq_len.

        Returns:
            A pair ``(context, weights)`` where ``context`` has shape
            (batch, hidden_dim) and ``weights`` has shape (batch, seq_len, 1)
            and sums to 1 along ``seq_len``.
        """
        scores = self.attention(hidden_states)      # (batch, seq_len, 1)
        weights = scores.softmax(dim=1)             # normalise across time steps
        features_first = hidden_states.transpose(1, 2)  # (batch, hidden_dim, seq_len)
        # (batch, hidden_dim, seq_len) @ (batch, seq_len, 1) -> (batch, hidden_dim, 1)
        context = torch.bmm(features_first, weights).squeeze(-1)
        return context, weights

class TALSTM(nn.Module):
    """Two-branch audio-visual classifier.

    Each modality goes through a small CNN, a bidirectional LSTM and a
    ``TemporalAttention`` pooling layer; the two pooled vectors are
    concatenated, passed through a fusion MLP and projected to the output.

    Args:
        num_classes: Width of the final linear layer. The default of 1 gives a
            single logit per sample (binary classification with
            ``BCEWithLogitsLoss``). Pass the number of classes to obtain
            ``(batch, num_classes)`` logits, which is what ``CrossEntropyLoss``
            and ``argmax(dim=1)`` need.
        lstm_hidden_dim: Hidden size of each direction of both BiLSTMs.
        attention_dim: Hidden size of both attention scorers.
        dropout: Dropout probability used inside the fusion MLP.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """

    def __init__(self, num_classes=1, lstm_hidden_dim=256, attention_dim=128, dropout=0.3):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        # Dimensions for the model
        self.video_feature_dim = 32   # out_channels of the last Conv3d below
        self.audio_feature_dim = 256  # out_channels of the last Conv1d below
        self.lstm_hidden_dim = lstm_hidden_dim
        self.attention_dim = attention_dim
        self.num_classes = num_classes

        # Video Processing Branch: expects 3 input channels, 5-D input (N, C, D, H, W).
        self.video_conv = nn.Sequential(
            nn.Conv3d(3, 16, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d((2, 2, 2)),
            nn.Conv3d(16, self.video_feature_dim, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d((2, 2, 2))
        )

        # Audio Processing Branch: expects 64 input channels, 3-D input (N, C, L).
        self.audio_conv = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(128, self.audio_feature_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        # LSTM and Attention Layers (bidirectional, so outputs are 2 * hidden wide)
        self.video_lstm = nn.LSTM(
            self.video_feature_dim,
            self.lstm_hidden_dim,
            batch_first=True,
            bidirectional=True
        )
        self.video_attention = TemporalAttention(
            self.lstm_hidden_dim * 2,
            self.attention_dim
        )

        self.audio_lstm = nn.LSTM(
            self.audio_feature_dim,
            self.lstm_hidden_dim,
            batch_first=True,
            bidirectional=True
        )
        self.audio_attention = TemporalAttention(
            self.lstm_hidden_dim * 2,
            self.attention_dim
        )

        # Fusion and Output Layers: input is the two pooled BiLSTM vectors side by side.
        self.fusion = nn.Sequential(
            nn.Linear(self.lstm_hidden_dim * 4, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # One logit by default (binary); num_classes logits otherwise.
        self.output_layer = nn.Linear(256, num_classes)

    def forward(self, video, audio):
        """Run both branches and return logits plus per-branch attention weights.

        Args:
            video: Tensor of shape (batch, 3, D, H, W).
                Presumably D is the frame axis - TODO confirm against the data loader.
            audio: Tensor of shape (batch, 64, L).

        Returns:
            ``(output, video_attention, audio_attention)`` where ``output`` has
            shape (batch, num_classes) and each attention tensor has shape
            (batch, seq_len, 1) for that branch's sequence length.
        """
        # Process video through CNN
        video_features = self.video_conv(video)                # (batch, 32, D', H', W')
        # NOTE(review): flattening merges depth, height and width into one axis,
        # so the video "sequence" is every spatio-temporal position, not frames only.
        # flatten(2) keeps the channel axis intact even if the conv width changes.
        video_features = video_features.flatten(start_dim=2)   # (batch, 32, D'*H'*W')
        video_features = video_features.transpose(1, 2)        # (batch, seq_len, features)

        # Process audio through CNN
        audio_features = self.audio_conv(audio)                # (batch, 256, L')
        audio_features = audio_features.transpose(1, 2)        # (batch, seq_len, features)

        # Process through LSTM and Attention
        video_lstm_out, _ = self.video_lstm(video_features)
        video_attended, video_attention = self.video_attention(video_lstm_out)

        audio_lstm_out, _ = self.audio_lstm(audio_features)
        audio_attended, audio_attention = self.audio_attention(audio_lstm_out)

        # Fuse features
        fused_features = torch.cat([video_attended, audio_attended], dim=1)
        fused_features = self.fusion(fused_features)

        # Final output
        output = self.output_layer(fused_features)

        return output, video_attention, audio_attention

# Modified training function to handle attention outputs
def train_talstm(model, train_loader, val_loader, num_epochs=30, device="cpu"):
    """Train a binary model with Adam + BCEWithLogitsLoss and validate each epoch.

    ``model(videos, audios)`` must return ``(logits, video_attention,
    audio_attention)``. Logits with a trailing singleton axis, i.e.
    ``(batch, 1)``, are reduced to ``(batch,)`` before the loss.

    Relies on ``tqdm`` being available in the notebook namespace; it is not
    imported in this cell.

    Args:
        model: Module to train. It is not moved to ``device`` here.
        train_loader: Iterable of ``(videos, audios, labels)`` batches supporting ``len()``.
        val_loader: Iterable of ``(videos, audios, labels)`` batches supporting ``len()``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.

    Returns:
        ``(model, val_attention_video, val_attention_audio)``. The two lists hold
        one NumPy array per validation batch from the LAST epoch only; both are
        empty when ``num_epochs`` is 0.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.BCEWithLogitsLoss()

    # Defined up front so the return statement is valid even when num_epochs == 0.
    val_attention_video = []
    val_attention_audio = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_train_loss = 0.0
        train_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for videos, audios, labels in train_progress:
            videos = videos.to(device)
            audios = audios.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs, _, _ = model(videos, audios)
            # squeeze(-1), not squeeze(): a batch of one must stay (1,), otherwise
            # the loss rejects the 0-d input against the (1,) target.
            outputs = outputs.squeeze(-1)
            # BCEWithLogitsLoss needs floating-point targets; integer labels are cast.
            loss = criterion(outputs, labels.to(outputs.dtype))

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            train_progress.set_postfix({'loss': loss.item()})

        # Validation phase
        model.eval()
        total_val_loss = 0.0
        # Reset every epoch so only the latest epoch's weights are returned.
        val_attention_video = []
        val_attention_audio = []

        with torch.no_grad():
            val_progress = tqdm(val_loader, desc="Validation")
            for videos, audios, labels in val_progress:
                videos = videos.to(device)
                audios = audios.to(device)
                labels = labels.to(device)

                outputs, video_attention, audio_attention = model(videos, audios)
                outputs = outputs.squeeze(-1)
                loss = criterion(outputs, labels.to(outputs.dtype))

                total_val_loss += loss.item()
                val_progress.set_postfix({'loss': loss.item()})

                # Store attention weights for analysis
                val_attention_video.append(video_attention.detach().cpu().numpy())
                val_attention_audio.append(audio_attention.detach().cpu().numpy())

        # max(..., 1) keeps the summary printable when a loader is empty.
        print(f"Epoch {epoch + 1}/{num_epochs}, "
              f"Train Loss: {total_train_loss / max(len(train_loader), 1):.4f}, "
              f"Val Loss: {total_val_loss / max(len(val_loader), 1):.4f}")

    return model, val_attention_video, val_attention_audio

# Usage example
def initialize_and_train():
    """Build a fresh TALSTM, place it on ``device`` and train it for 30 epochs.

    Reads ``device``, ``train_loader`` and ``val_loader`` from the notebook's
    global namespace.

    Returns:
        The ``(trained_model, video_attention, audio_attention)`` tuple
        produced by ``train_talstm``.
    """
    net = TALSTM().to(device)
    trained_model, video_attention, audio_attention = train_talstm(
        net, train_loader, val_loader, num_epochs=30, device=device
    )
    return trained_model, video_attention, audio_attention

KeyboardInterrupt: 

In [None]:
# Training loop
# Multi-class (phoneme) training of `model` with cross-entropy.
# `model`, `train_loader` and `device` are not defined in this cell; they must
# already exist in the kernel.
# NOTE(review): the TALSTM class in this notebook produces a single logit by
# default, while CrossEntropyLoss expects (batch, num_classes) logits - confirm
# that `model` was built with a multi-class output head.
num_epochs = 10
criterion = nn.CrossEntropyLoss()  # For phoneme classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0  # sum of per-batch losses for this epoch
    for video, audio, phoneme_labels in train_loader:  # Load synced video/audio pairs
        video, audio, phoneme_labels = video.to(device), audio.to(device), phoneme_labels.to(device)
        
        optimizer.zero_grad()

        # Forward pass (the two attention outputs are not needed for training)
        outputs, _, _ = model(video, audio)
        
        # Compute loss
        loss = criterion(outputs, phoneme_labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

    # Mean of the per-batch losses over the epoch
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}")


In [None]:
# Testing loop with mismatched pairs
# Compares predictions for each video with its own audio against predictions
# for the same video paired with another sample's audio from the same batch.
# `model`, `test_loader` and `device` must already exist in the kernel.
model.eval()
mismatched_results = []

with torch.no_grad():
    for video, audio, _ in test_loader:
        video, audio = video.to(device), audio.to(device)

        # Create mismatched pairs.
        # A random permutation can map a sample to itself, which would leave some
        # "mismatched" pairs matched. A cyclic shift by 1..B-1 positions guarantees
        # that no video keeps its own audio. A batch of one has no other audio to
        # borrow, so it is left unchanged (shift 0).
        _batch_len = audio.size(0)
        _roll_by = int(torch.randint(1, _batch_len, (1,)).item()) if _batch_len > 1 else 0
        shuffled_audio = torch.roll(audio, shifts=_roll_by, dims=0)

        # Get predictions
        matched_outputs, _, _ = model(video, audio)  # Matched pair
        mismatched_outputs, _, _ = model(video, shuffled_audio)  # Mismatched pair

        # Store predictions (argmax over the class axis, one entry per sample)
        matched_predictions = torch.argmax(matched_outputs, dim=1).cpu().numpy()
        mismatched_predictions = torch.argmax(mismatched_outputs, dim=1).cpu().numpy()

        mismatched_results.append((matched_predictions, mismatched_predictions))

# Analyze results
# Each entry holds the predictions for a whole batch, so it is labelled as a batch.
for idx, (matched, mismatched) in enumerate(mismatched_results):
    print(f"Batch {idx}: Matched: {matched}, Mismatched: {mismatched}")


In [None]:
import matplotlib.pyplot as plt

# Plot attention weights
def plot_attention(attention_weights, title):
    """Show attention weights as a heat map with time on the x-axis.

    Args:
        attention_weights: Tensor or NumPy array of attention weights. Accepted
            layouts are (batch, seq_len, 1), (batch, seq_len), (seq_len, 1) and
            (seq_len,). A trailing singleton axis is dropped, and a single
            sequence is drawn as one row. Because of that, a 2-D input with a
            trailing 1 is read as one sequence, not as a batch of length-1
            sequences.
        title: Figure title.
    """
    # detach(): tensors that require grad cannot be converted with .numpy().
    weights = torch.as_tensor(attention_weights).detach().cpu()
    # Drop the trailing singleton axis of (..., seq_len, 1) attention tensors.
    if weights.dim() >= 2 and weights.size(-1) == 1:
        weights = weights.squeeze(-1)
    # A single sequence becomes one row so the axis labels below stay correct.
    if weights.dim() == 1:
        weights = weights.unsqueeze(0)

    plt.figure(figsize=(10, 5))
    plt.imshow(weights.numpy(), cmap='viridis', aspect='auto')
    plt.colorbar()
    plt.title(title)
    plt.xlabel('Time Steps')
    plt.ylabel('Batch Index')
    plt.show()

# Example for one batch
# `video` and `shuffled_audio` are the values left behind by the last iteration
# of the mismatched-pair test loop; `model` must already exist in the kernel.
# Inference only: without no_grad the returned weights require grad, and such
# tensors cannot be converted with .numpy() for plotting.
with torch.no_grad():
    _, video_attention_weights, audio_attention_weights = model(video, shuffled_audio)
# Index 0 selects the first sample of the batch.
plot_attention(video_attention_weights[0], "Video Attention (Mismatched)")
plot_attention(audio_attention_weights[0], "Audio Attention (Mismatched)")
