# In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
from scipy import io as sio
import sklearn.metrics as sm
from sklearn.svm import LinearSVC
import matplotlib.pyplot as plt
from collections import OrderedDict
import json

# TCN Model Implementations (inspired by Lea et al. research)

class TemporalConv1D(nn.Module):
    """1-D temporal convolution that keeps the sequence length at stride 1.

    The wrapped ``nn.Conv1d`` is built with ``padding=0``; zero padding is
    added explicitly in ``forward``:

    * causal: all ``(kernel_size - 1) * dilation`` zeros go on the left, so
      output step ``t`` only sees input steps ``<= t``;
    * acausal: the same total is split across both sides. When the total is
      odd (even kernel with odd dilation) the extra zero goes on the right,
      so the output length still equals the input length.

    Args:
        in_channels: Channels of the input signal.
        out_channels: Channels produced by the convolution.
        kernel_size: Temporal width of the filter.
        stride: Stride forwarded to ``nn.Conv1d``.
        dilation: Spacing between filter taps.
        causal: Pad only on the left when True, on both sides otherwise.

    Shape:
        Input ``(batch, in_channels, length)``; output
        ``(batch, out_channels, length)`` when ``stride == 1``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, causal=True):
        super(TemporalConv1D, self).__init__()
        self.causal = causal

        # Total number of zeros needed to preserve the length at stride 1.
        total_padding = (kernel_size - 1) * dilation
        if causal:
            self._pad_left, self._pad_right = total_padding, 0
        else:
            # Floor on the left, remainder on the right: padding both sides
            # by ``total // 2`` would lose one step whenever the total is odd.
            self._pad_left = total_padding // 2
            self._pad_right = total_padding - self._pad_left

        # Kept with its historical meaning (left-side padding) for any
        # external code that reads the attribute.
        self.padding = self._pad_left

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                             stride=stride, dilation=dilation, padding=0)

    def forward(self, x):
        """Pad ``x`` along its last (time) axis and apply the convolution."""
        return self.conv(F.pad(x, (self._pad_left, self._pad_right)))

class DilatedTCN(nn.Module):
    """WaveNet-style stack of dilated temporal convolutions (after Lea et al.).

    Layer ``i`` uses dilation ``2 ** i``. Each layer contributes a 1x1 "skip"
    projection; the skips are summed and passed through two 1x1 output
    convolutions to produce per-step class scores.

    Args:
        n_feat: Size of the per-step input feature vector.
        n_classes: Number of output channels per step.
        n_channels: Width of every hidden layer.
        n_layers: Number of dilated layers.
        kernel_size: Temporal width of each dilated filter.
        dropout: Dropout probability.
        causal: Forwarded to every ``TemporalConv1D``.
        activation: ``'tanh'`` selects ``nn.Tanh``; any other value
            (including ``'relu'``) selects ``nn.ReLU``.

    Shape:
        Input ``(batch, time, n_feat)``; output ``(batch, time, n_classes)``.
    """

    def __init__(self, n_feat, n_classes, n_channels=64, n_layers=8, kernel_size=3,
                 dropout=0.2, causal=True, activation='relu'):
        super().__init__()

        self.n_layers = n_layers
        self.causal = causal

        # 1x1 projection from the feature space to the working width.
        self.input_conv = nn.Conv1d(n_feat, n_channels, 1)

        # One entry per dilated layer in each of these lists.
        self.dilated_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.bn_layers = nn.ModuleList()

        for level in range(n_layers):
            self.dilated_convs.append(
                TemporalConv1D(n_channels, n_channels, kernel_size,
                               dilation=2 ** level, causal=causal)
            )
            self.residual_convs.append(nn.Conv1d(n_channels, n_channels, 1))
            self.skip_convs.append(nn.Conv1d(n_channels, n_channels, 1))
            self.bn_layers.append(nn.BatchNorm1d(n_channels))

        # Two-stage 1x1 head applied to the summed skip signal.
        self.output_conv1 = nn.Conv1d(n_channels, n_channels, 1)
        self.output_conv2 = nn.Conv1d(n_channels, n_classes, 1)
        self.dropout = nn.Dropout(dropout)

        # Anything other than 'tanh' falls back to ReLU.
        self.activation = nn.Tanh() if activation == 'tanh' else nn.ReLU()

    def forward(self, x):
        """Map ``(batch, time, n_feat)`` to ``(batch, time, n_classes)``."""
        # Conv1d wants channels before time.
        hidden = self.input_conv(x.transpose(1, 2))

        skip_total = 0
        stages = zip(self.dilated_convs, self.bn_layers,
                     self.skip_convs, self.residual_convs)
        for dilated, norm, to_skip, to_residual in stages:
            branch = self.dropout(self.activation(norm(dilated(hidden))))
            skip_total = skip_total + to_skip(branch)
            # Residual path: projected branch plus the layer's own input.
            hidden = to_residual(branch) + hidden

        out = self.activation(skip_total)
        out = self.dropout(self.activation(self.output_conv1(out)))

        # Back to time-major layout for the caller.
        return self.output_conv2(out).transpose(1, 2)

class ED_TCN(nn.Module):
    """Encoder-decoder temporal convolutional network (after Lea et al.).

    The encoder applies one ``TemporalConv1D`` per entry of ``n_nodes``. The
    decoder mirrors it: decoder layer ``i`` outputs ``n_nodes[-1 - i]``
    channels and, for ``i > 0``, consumes the previous decoder output
    concatenated with the output of encoder layer ``n_layers - 1 - i``.

    Args:
        n_nodes: Channel width of each encoder layer, shallow to deep.
        kernel_size: Temporal width of every filter.
        n_classes: Number of output channels per step.
        n_feat: Size of the per-step input feature vector.
        causal: Forwarded to every ``TemporalConv1D``.
        activation: ``'tanh'`` selects ``nn.Tanh``; any other value selects
            ``nn.ReLU``.
        dropout: Dropout probability.

    Shape:
        Input ``(batch, time, n_feat)``; output ``(batch, time, n_classes)``.
    """
    def __init__(self, n_nodes, kernel_size, n_classes, n_feat, causal=True,
                 activation='relu', dropout=0.2):
        super(ED_TCN, self).__init__()

        self.n_layers = len(n_nodes)
        self.causal = causal

        # Encoder
        self.encoder_convs = nn.ModuleList()
        self.encoder_bn = nn.ModuleList()

        in_channels = n_feat
        for i, out_channels in enumerate(n_nodes):
            conv = TemporalConv1D(in_channels, out_channels, kernel_size, causal=causal)
            self.encoder_convs.append(conv)
            self.encoder_bn.append(nn.BatchNorm1d(out_channels))
            in_channels = out_channels

        # Decoder
        self.decoder_convs = nn.ModuleList()
        self.decoder_bn = nn.ModuleList()

        for i in range(self.n_layers):
            out_channels = n_nodes[-1 - i]
            if i == 0:
                # First decoder layer reads the deepest encoder output only.
                in_channels = n_nodes[-1]
            else:
                # forward() concatenates the previous decoder output
                # (n_nodes[-i] channels) with encoder layer n_layers-1-i
                # (n_nodes[-1-i] channels). The input width must be their
                # sum; doubling n_nodes[-i] is only right for uniform widths.
                in_channels = n_nodes[-i] + n_nodes[-1 - i]

            conv = TemporalConv1D(in_channels, out_channels, kernel_size, causal=causal)
            self.decoder_convs.append(conv)
            self.decoder_bn.append(nn.BatchNorm1d(out_channels))

        # Output layer: the last decoder layer emits n_nodes[0] channels.
        self.output_conv = nn.Conv1d(n_nodes[0], n_classes, 1)

        # Activation and dropout
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            # Unknown names fall back to ReLU.
            self.activation = nn.ReLU()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``(batch, time, n_feat)`` to ``(batch, time, n_classes)``."""
        # Conv1d wants (batch, channels, time).
        x = x.transpose(1, 2)

        # Encoder: keep every layer's output for the decoder skips.
        encoder_outputs = []
        for i in range(self.n_layers):
            x = self.encoder_convs[i](x)
            x = self.encoder_bn[i](x)
            x = self.activation(x)
            x = self.dropout(x)
            encoder_outputs.append(x)

        # Decoder with skip connections
        for i in range(self.n_layers):
            if i > 0:
                # Pair decoder layer i with the mirrored encoder layer.
                skip_idx = self.n_layers - 1 - i
                x = torch.cat([x, encoder_outputs[skip_idx]], dim=1)

            x = self.decoder_convs[i](x)
            x = self.decoder_bn[i](x)
            x = self.activation(x)
            x = self.dropout(x)

        # Output layer
        x = self.output_conv(x)

        # Back to (batch, time, n_classes).
        x = x.transpose(1, 2)

        return x

class tCNN(nn.Module):
    """Plain stack of temporal convolutions (Lea et al., ECCV 2016).

    Each stage is ``TemporalConv1D -> BatchNorm1d -> ReLU -> Dropout``; a
    final 1x1 convolution maps the last stage to ``n_classes`` channels.

    Args:
        n_nodes: Output channels of each stage, in order.
        kernel_size: Temporal width of every filter.
        n_classes: Number of output channels per step.
        n_feat: Size of the per-step input feature vector.
        causal: Forwarded to every ``TemporalConv1D``.
        dropout: Dropout probability.

    Shape:
        Input ``(batch, time, n_feat)``; output ``(batch, time, n_classes)``.
    """

    def __init__(self, n_nodes, kernel_size, n_classes, n_feat, causal=True, dropout=0.2):
        super().__init__()

        self.convs = nn.ModuleList()
        self.bn_layers = nn.ModuleList()

        # Each stage consumes the width produced by the one before it.
        widths_in = [n_feat, *n_nodes]
        for width_in, width_out in zip(widths_in, n_nodes):
            self.convs.append(
                TemporalConv1D(width_in, width_out, kernel_size, causal=causal)
            )
            self.bn_layers.append(nn.BatchNorm1d(width_out))

        # 1x1 classifier on top of the last stage.
        self.output_conv = nn.Conv1d(n_nodes[-1], n_classes, 1)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map ``(batch, time, n_feat)`` to ``(batch, time, n_classes)``."""
        # Channels-first layout for Conv1d.
        hidden = x.transpose(1, 2)

        for stage, norm in zip(self.convs, self.bn_layers):
            hidden = self.dropout(self.activation(norm(stage(hidden))))

        # Classify, then restore the time-major layout.
        return self.output_conv(hidden).transpose(1, 2)

# ============================================================================
# Video Feature Extraction
# ============================================================================

class VideoFeatureExtractor(nn.Module):
    """Per-frame spatial CNN that turns a clip into a feature sequence.

    Every frame goes through four strided convolution stages, global average
    pooling and a linear projection to ``feature_dim``.

    Args:
        feature_dim: Size of the feature vector produced for each frame.

    Shape:
        Input ``(batch, time, 3, height, width)``; output
        ``(batch, time, feature_dim)``.
    """
    def __init__(self, feature_dim=512):
        super(VideoFeatureExtractor, self).__init__()

        # Spatial CNN for frame-level features
        self.spatial_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            # Collapse whatever spatial extent is left to 1x1.
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.fc = nn.Linear(512, feature_dim)

    def forward(self, x):
        """Extract one feature vector per frame.

        Args:
            x: Tensor of shape ``(batch, time, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, time, feature_dim)``.
        """
        batch_size, seq_len, c, h, w = x.shape

        # Fold time into the batch axis so the 2-D CNN sees single frames.
        # reshape (not view) so inputs with non-contiguous batch/time axes,
        # e.g. after a transpose, are accepted instead of raising.
        x = x.reshape(batch_size * seq_len, c, h, w)
        features = self.spatial_cnn(x)
        features = features.flatten(start_dim=1)
        features = self.fc(features)

        # Unfold back to one feature sequence per clip.
        return features.reshape(batch_size, seq_len, -1)

# ============================================================================
# Complete Model
# ============================================================================

class TCNEmotionClassifier(nn.Module):
    """Frame CNN plus temporal model producing per-step multi-label scores.

    Args:
        model_type: ``"DilatedTCN"``, ``"ED-TCN"`` or ``"tCNN"``.
        n_feat: Size of the per-frame feature vector.
        n_classes: Number of labels scored at every step.
        n_nodes: Channel widths for the temporal model. ``DilatedTCN`` uses
            only the first entry as its hidden width.
        kernel_size: Temporal filter width.
        causal: Whether the temporal convolutions are causal.
        dropout: Dropout probability inside the temporal model.
        max_len: Stored on the instance; not used by this class itself.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    def __init__(self, model_type="ED-TCN", n_feat=512, n_classes=6, n_nodes=(64, 96),
                 kernel_size=3, causal=True, dropout=0.2, max_len=None):
        super(TCNEmotionClassifier, self).__init__()

        # The default is an immutable tuple so it cannot be shared and
        # mutated across instances; the temporal models only index/iterate it.
        self.model_type = model_type
        self.n_classes = n_classes
        self.max_len = max_len

        # Feature extractor
        self.feature_extractor = VideoFeatureExtractor(n_feat)

        # Temporal model
        if model_type == "DilatedTCN":
            self.temporal_model = DilatedTCN(n_feat, n_classes, n_channels=n_nodes[0],
                                           n_layers=8, kernel_size=kernel_size,
                                           dropout=dropout, causal=causal)
        elif model_type == "ED-TCN":
            self.temporal_model = ED_TCN(n_nodes, kernel_size, n_classes, n_feat,
                                       causal=causal, dropout=dropout)
        elif model_type == "tCNN":
            self.temporal_model = tCNN(n_nodes, kernel_size, n_classes, n_feat,
                                     causal=causal, dropout=dropout)
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    def forward(self, x, mask=None):
        """Score every step of every clip.

        Args:
            x: Frames of shape ``(batch, time, channels, height, width)``.
            mask: Optional ``(batch, time)`` tensor, 1 for real steps and 0
                for padding.

        Returns:
            Sigmoid probabilities of shape ``(batch, time, n_classes)``.
            Steps where ``mask`` is 0 are exactly 0.
        """
        # Extract spatial features
        features = self.feature_extractor(x)

        # Apply temporal model, then squash to probabilities.
        probs = torch.sigmoid(self.temporal_model(features))

        # Mask AFTER the sigmoid: zeroing the logits first would turn padded
        # steps into sigmoid(0) = 0.5 instead of removing them.
        if mask is not None:
            probs = probs * mask.unsqueeze(-1)

        return probs

# ============================================================================
# Dataset and Training
# ============================================================================

class VideoEmotionDataset(Dataset):
    """Dataset yielding fixed-length frame tensors and their clip labels.

    Args:
        video_paths: Sequence of paths handed to ``cv2.VideoCapture``.
        labels: One label vector per video; converted with
            ``torch.FloatTensor``.
        sequence_length: Number of frames returned for every video.
        frame_size: Target size passed to ``cv2.resize``, i.e.
            ``(width, height)``.
    """
    def __init__(self, video_paths, labels, sequence_length=32, frame_size=(224, 224)):
        self.video_paths = video_paths
        self.labels = labels
        self.sequence_length = sequence_length
        self.frame_size = frame_size

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for the video at ``idx``."""
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        # Load video frames
        frames = self.load_video_frames(video_path)

        return frames, torch.FloatTensor(label)

    def _sample_frame_indices(self, total_frames):
        """Pick ``sequence_length`` frame indices from ``total_frames`` frames.

        Longer videos are sampled uniformly; shorter ones repeat their last
        frame. Returns an empty list when the video reports no frames.
        """
        if total_frames <= 0:
            return []
        if total_frames > self.sequence_length:
            return np.linspace(0, total_frames - 1, self.sequence_length, dtype=int).tolist()
        indices = list(range(total_frames))
        indices.extend([indices[-1]] * (self.sequence_length - total_frames))
        return indices

    def load_video_frames(self, video_path):
        """Decode ``sequence_length`` RGB frames scaled to ``[0, 1]``.

        Frames that cannot be read are replaced by all-zero frames; a video
        that reports no frames at all (for example an unreadable file) yields
        an all-zero clip rather than raising.

        Returns:
            Float tensor of shape ``(sequence_length, 3, height, width)``.
        """
        # cv2.resize takes (width, height) but returns (height, width, 3)
        # arrays, so the fallback frame must use the swapped order to stack.
        width, height = self.frame_size
        blank = np.zeros((height, width, 3), dtype=np.float32)

        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            for frame_idx in self._sample_frame_indices(total_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()

                if ret:
                    frame = cv2.resize(frame, self.frame_size)
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = frame.astype(np.float32) / 255.0
                    frames.append(frame)
                else:
                    # Use zero frame if reading fails
                    frames.append(blank)
        finally:
            # Release the capture even if decoding raised.
            cap.release()

        if not frames:
            frames = [blank] * self.sequence_length

        # (T, H, W, C) -> (T, C, H, W)
        frames = np.stack(frames)
        frames = torch.FloatTensor(frames).permute(0, 3, 1, 2)

        return frames

def mask_data(X, Y, max_len, mask_value=-1):
    """Bring every sequence to exactly ``max_len`` steps.

    Sequences shorter than ``max_len`` are extended at the end: features with
    ``mask_value`` rows, labels with zero rows. Sequences of ``max_len`` steps
    or more are cut to their first ``max_len`` steps. The length of each
    feature sequence decides how its label sequence is treated.

    Args:
        X: Iterable of 2-D feature arrays ``(steps, features)``.
        Y: Iterable of 2-D label arrays, paired with ``X``.
        max_len: Target number of steps.
        mask_value: Fill value for padded feature rows.

    Returns:
        Tuple ``(X_masked, Y_masked, masks)`` of stacked arrays; ``masks`` is
        1.0 on real steps and 0.0 on padded ones.
    """
    step_index = np.arange(max_len)
    padded_feats, padded_labels, validity = [], [], []

    for feats, labels in zip(X, Y):
        n_valid = len(feats)
        n_missing = max_len - n_valid

        if n_missing > 0:
            feat_fill = np.full((n_missing, feats.shape[1]), mask_value, dtype=feats.dtype)
            label_fill = np.zeros((n_missing, labels.shape[1]), dtype=labels.dtype)
            feats = np.vstack([feats, feat_fill])
            labels = np.vstack([labels, label_fill])
        else:
            feats, labels = feats[:max_len], labels[:max_len]

        padded_feats.append(feats)
        padded_labels.append(labels)
        # 1.0 for the first n_valid steps, 0.0 afterwards.
        validity.append((step_index < n_valid).astype(np.float64))

    return np.array(padded_feats), np.array(padded_labels), np.array(validity)

def unmask_predictions(predictions, masks):
    """Trim each prediction sequence to its number of unmasked steps.

    The kept length of a sequence is the integer sum of its mask, so the
    first ``int(mask.sum())`` rows are returned for each pair.

    Returns:
        List with one trimmed array per ``(prediction, mask)`` pair.
    """
    return [
        sequence[:int(validity.sum())]
        for sequence, validity in zip(predictions, masks)
    ]

# ============================================================================
# Training and Evaluation
# ============================================================================

class EmotionTrainer:
    """Training and evaluation loop for a TCN emotion classifier.

    The model is expected to expose ``feature_extractor`` and
    ``temporal_model`` sub-modules (used when ``max_len`` is given) and to
    return probabilities when called directly (used otherwise).

    Args:
        model: The classifier; moved to ``device`` on construction.
        device: Device string used for the model and every batch.
    """
    def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.device = device
        self.emotion_labels = ['anger', 'disgust', 'fear', 'happiness', 'sadness', 'surprise']

    def _masked_forward(self, data, max_len):
        """Run the model on a batch padded or truncated to ``max_len`` steps.

        Padding is done on tensors rather than through NumPy so gradients
        still reach the feature extractor. Padded feature rows are filled
        with -1.

        Returns:
            ``(probs, mask)``: sigmoid probabilities ``(batch, max_len,
            classes)`` and a ``(batch, max_len)`` float mask that is 1 on
            real steps.
        """
        features = self.model.feature_extractor(data)
        batch_size, seq_len = features.shape[:2]

        mask = torch.zeros(batch_size, max_len, device=features.device)
        mask[:, :min(seq_len, max_len)] = 1.0

        if seq_len < max_len:
            # Pad tuple runs from the last axis backwards: (feat_l, feat_r,
            # time_l, time_r) -> only append steps at the end of time.
            features = F.pad(features, (0, 0, 0, max_len - seq_len), value=-1.0)
        else:
            features = features[:, :max_len]

        # The temporal model emits logits; BCE and the 0.5 threshold both
        # need probabilities.
        probs = torch.sigmoid(self.model.temporal_model(features))
        return probs, mask

    def train_epoch(self, dataloader, optimizer, criterion, max_len=None):
        """Train for one pass over ``dataloader``.

        Args:
            dataloader: Iterable of ``(frames, clip_labels)`` batches.
            optimizer: Optimizer stepping the model parameters.
            criterion: Loss taking ``(probabilities, targets)``.
            max_len: When truthy, sequences are padded/truncated to this
                length and the loss is computed on real steps only.

        Returns:
            Mean batch loss, or 0.0 when the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(self.device), target.to(self.device)

            if max_len:
                probs, mask = self._masked_forward(data, max_len)

                # Every step of a clip shares the clip-level label.
                frame_target = target.unsqueeze(1).expand(-1, max_len, -1)

                # Selecting the real steps before the loss makes a
                # mean-reducing criterion average over valid steps only.
                valid = mask.bool()
                loss = criterion(probs[valid], frame_target[valid])
                if loss.dim() > 0:
                    # Criterion built with reduction='none'.
                    loss = loss.mean()
            else:
                output = self.model(data)

                # For sequence output, take mean over time
                if output.dim() == 3:
                    output = output.mean(dim=1)

                loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            if batch_idx % 10 == 0:
                print(f'Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}')

        return total_loss / n_batches if n_batches else 0.0

    def evaluate(self, dataloader, criterion, max_len=None):
        """Compute clip-level metrics over ``dataloader``.

        A clip's prediction is its mean per-step probability (over real steps
        when ``max_len`` is given) thresholded at 0.5. ``criterion`` is
        accepted for signature compatibility and is not used.

        Returns:
            ``(accuracy, f1_macro, f1_micro)``; all 0.0 when the loader
            yields no samples.
        """
        self.model.eval()
        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(self.device), target.to(self.device)

                if max_len:
                    probs, mask = self._masked_forward(data, max_len)
                    weights = mask.unsqueeze(-1)
                    # Average over real steps only; the clamp guards the
                    # degenerate zero-length case.
                    clip_probs = (probs * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
                else:
                    clip_probs = self.model(data)
                    if clip_probs.dim() == 3:
                        clip_probs = clip_probs.mean(dim=1)

                predictions = (clip_probs > 0.5).float()
                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(target.cpu().numpy())

        if not all_predictions:
            return 0.0, 0.0, 0.0

        # Compute metrics
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)

        accuracy = sm.accuracy_score(all_labels, all_predictions)
        f1_macro = sm.f1_score(all_labels, all_predictions, average='macro')
        f1_micro = sm.f1_score(all_labels, all_predictions, average='micro')

        return accuracy, f1_macro, f1_micro

    def train(self, train_loader, val_loader, num_epochs=50, lr=0.001, max_len=None):
        """Train with Adam + BCE, validating and checkpointing every epoch.

        The state dict is saved whenever the validation macro-F1 improves.

        Returns:
            Dict with per-epoch ``train_losses``, ``val_accuracies`` and
            ``val_f1_scores``.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        criterion = nn.BCELoss()
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

        # Fall back to the class name for models without a model_type.
        model_tag = getattr(self.model, 'model_type', type(self.model).__name__)

        best_f1 = 0
        train_losses = []
        val_accuracies = []
        val_f1_scores = []

        for epoch in range(num_epochs):
            print(f'\nEpoch {epoch+1}/{num_epochs}')
            print('-' * 50)

            # Training
            train_loss = self.train_epoch(train_loader, optimizer, criterion, max_len)
            train_losses.append(train_loss)

            # Validation
            val_acc, val_f1_macro, val_f1_micro = self.evaluate(val_loader, criterion, max_len)
            val_accuracies.append(val_acc)
            val_f1_scores.append(val_f1_macro)

            scheduler.step()

            print(f'Train Loss: {train_loss:.4f}')
            print(f'Val Acc: {val_acc:.4f}, Val F1 (macro): {val_f1_macro:.4f}, Val F1 (micro): {val_f1_micro:.4f}')

            # Save best model
            if val_f1_macro > best_f1:
                best_f1 = val_f1_macro
                torch.save(self.model.state_dict(), f'best_{model_tag}_emotion_model.pth')
                print(f'New best model saved with Val F1: {val_f1_macro:.4f}')

        return {
            'train_losses': train_losses,
            'val_accuracies': val_accuracies,
            'val_f1_scores': val_f1_scores
        }

# ============================================================================
# Example Usage
# ============================================================================

def main():
    """Build an example classifier and print a short summary of it.

    No training is run here. To train, wrap video paths and multi-label
    targets in ``VideoEmotionDataset`` / ``DataLoader`` and hand them to
    ``EmotionTrainer.train``; see the commented sketch below.
    """
    # model_type options: "DilatedTCN", "ED-TCN", "tCNN".
    # causal=False gives the bidirectional/acausal variant.
    # max_len is the length used when masking variable-length sequences.
    settings = {
        "model_type": "ED-TCN",
        "n_feat": 512,
        "n_classes": 6,  # number of emotion classes
        "n_nodes": [64, 96, 128],
        "kernel_size": 3,
        "causal": True,
        "max_len": 64,
    }

    model = TCNEmotionClassifier(**settings)

    print(f"Model: {settings['model_type']}")
    print(f"Parameters: {sum(p.numel() for p in model.parameters())}")
    print(f"Causal: {settings['causal']}")

    # Training sketch (replace with your data):
    #   sequence_length = 32
    #   train_paths = ["path/to/video1.mp4", ...]
    #   train_labels = [[1, 0, 0, 1, 0, 0], ...]  # multi-label format
    #
    #   train_dataset = VideoEmotionDataset(train_paths, train_labels, sequence_length)
    #   train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    #
    #   trainer = EmotionTrainer(model)
    #   history = trainer.train(train_loader, val_loader, num_epochs=50,
    #                           max_len=settings["max_len"])

    print("Model ready for training!")


# Only run the demo when executed as a script, not on import.
if __name__ == "__main__":
    main()