# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount and print the full path of every file found.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for entry in files:
        print(os.path.join(root, entry))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, roc_auc_score, roc_curve
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torchvision import transforms, models
from torchvision.models import swin_v2_b, Swin_V2_B_Weights
from PIL import Image
import copy
import time
import math
from collections import Counter
from tqdm import tqdm
from torch.amp import autocast, GradScaler
import warnings
import torch.nn.functional as F
warnings.filterwarnings("ignore")

class FocalLoss(nn.Module):
    """
    Focal Loss implementation for addressing class imbalance
    Standard parameters: alpha=1.0, gamma=2.0

    Per-sample loss: ``alpha * w[target] * (1 - p_t) ** gamma * (-log p_t)``
    where ``p_t`` is the softmax probability of the true class and ``w`` is
    the optional per-class weight.

    Args:
        alpha: Global scaling factor applied to every sample.
        gamma: Focusing exponent; 0 reduces the loss to (weighted) cross entropy.
        weight: Optional 1-D tensor of per-class weights, or None.
        reduction: 'mean' (plain average over the batch), 'sum', or anything
            else to get the unreduced per-sample losses.
    """
    def __init__(self, alpha=1.0, gamma=2.0, weight=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss for raw ``inputs`` logits and integer ``targets``."""
        # p_t must come from the UNWEIGHTED cross entropy: exp(-w * ce) would be
        # p_t ** w, which distorts the focusing term for re-weighted classes.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        # The class weight only rescales the loss magnitude.
        if self.weight is not None:
            weighted_ce = F.cross_entropy(inputs, targets, weight=self.weight, reduction='none')
        else:
            weighted_ce = ce_loss

        focal_loss = self.alpha * (1 - pt) ** self.gamma * weighted_ce

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class LabelSmoothingCrossEntropy(nn.Module):
    """
    Cross entropy against label-smoothed targets.

    The true class receives ``1 - smoothing`` and every other class receives
    ``smoothing / (num_classes - 1)``. When ``weight`` is given, each class
    column of the smoothed target is multiplied by that class's weight.
    Standard parameter: smoothing=0.1
    """
    def __init__(self, smoothing=0.1, weight=None):
        super().__init__()
        self.weight = weight
        self.smoothing = smoothing

    def forward(self, inputs, targets):
        """Return the batch-mean smoothed cross entropy for logits ``inputs``."""
        n_cls = inputs.size(-1)
        log_p = F.log_softmax(inputs, dim=-1)

        # The target distribution is a constant; keep it out of autograd.
        with torch.no_grad():
            off_value = self.smoothing / (n_cls - 1)
            soft = torch.full_like(log_p, off_value)
            soft.scatter_(1, targets.unsqueeze(1), 1.0 - self.smoothing)

        # Optional per-class re-weighting, broadcast over the batch dimension.
        if self.weight is not None:
            soft = soft * self.weight.unsqueeze(0)

        per_sample = (-soft * log_p).sum(dim=-1)
        return per_sample.mean()

class FocalLossWithLabelSmoothing(nn.Module):
    """
    Focal modulation applied on top of a label-smoothed cross entropy.

    The smoothed cross entropy ``ce`` is computed first, ``pt = exp(-ce)`` is
    derived from it, and the per-sample loss is
    ``alpha * (1 - pt) ** gamma * ce``, optionally scaled by the class weight
    of the sample's original (hard) target.
    Standard parameters: alpha=1.0, gamma=2.0, smoothing=0.1
    """
    def __init__(self, alpha=1.0, gamma=2.0, smoothing=0.1, weight=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.smoothing = smoothing
        self.weight = weight

    def forward(self, inputs, targets):
        """Return the batch-mean loss for logits ``inputs`` and integer ``targets``."""
        n_cls = inputs.size(-1)

        # Smoothed one-hot targets: constant w.r.t. the graph.
        with torch.no_grad():
            soft = torch.full_like(inputs, self.smoothing / (n_cls - 1))
            soft.scatter_(1, targets.unsqueeze(1), 1.0 - self.smoothing)

        log_p = F.log_softmax(inputs, dim=-1)
        smoothed_ce = (-soft * log_p).sum(dim=-1)

        # Focal term is driven by the smoothed CE, not by the raw p_t.
        pt = torch.exp(-smoothed_ce)
        modulated = self.alpha * (1 - pt) ** self.gamma * smoothed_ce

        # Per-sample class weight looked up from the hard labels.
        if self.weight is not None:
            modulated = modulated * self.weight[targets]

        return modulated.mean()


# Load datasets
# NOTE(review): the CSV paths are empty placeholders and must be filled in
# before this script can run.
train_df = pd.read_csv('')
val_df = pd.read_csv('')
test_df = pd.read_csv('')

# Hyperparameters - Optimized for T4 GPU
SEED = 42
BATCH_SIZE = 4  # Reduced for memory efficiency
GRADIENT_ACCUMULATION_STEPS = 2  # Effective batch size = 4 * 2 = 8
NUM_EPOCHS = 500  # Upper bound only; early stopping (PATIENCE) normally ends training
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
DROPOUT_RATE = 0.5
MAX_LENGTH = 256   # increased from 128
IMAGE_SIZE = 224
NUM_CLASSES = 3
PATIENCE = 5
WARMUP_RATIO = 0.1

# CONFIGURATION FLAGS - CHANGE THESE FOR YOUR EXPERIMENTS
# Loss flag precedence when several are True:
# USE_COMBINED_LOSS > USE_FOCAL_LOSS > USE_LABEL_SMOOTHING > plain cross entropy.
FREEZE_EARLY_LAYERS = False  # Set to False to unfreeze early layers
USE_FOCAL_LOSS = True         # Set to True to use Focal Loss
USE_SEPARATE_LOSS= False      # Set to True to add auxiliary per-modality losses
USE_LABEL_SMOOTHING = True     # Set to True to use Label Smoothing
USE_COMBINED_LOSS = True       # Set to True to use both together
CO_ATTENTION_HIDDEN_DIM = 512  # Change this for different hidden dimensions (256, 512, 768, 1024)


# Model configurations
TEXT_MODEL_NAME = "microsoft/mdeberta-v3-base"
IMAGE_DIR = ""  # NOTE(review): placeholder; directory that holds the image files
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


print(f"Using device: {DEVICE}")
print(f"Batch size: {BATCH_SIZE}, Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")

# Set seeds for reproducibility
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Target classes
target_classes = ['x', 'y', 'z']

# Compute class weights for imbalanced dataset
# ("balanced" heuristic: total / (num_classes * count_of_class)).
class_counts = Counter(train_df['class_idx'])
total_samples = sum(class_counts.values())
# The weight vector must line up with the model's NUM_CLASSES logits, so size
# it from NUM_CLASSES rather than from whichever labels happen to be present.
num_classes = NUM_CLASSES
missing_classes = [i for i in range(num_classes) if class_counts[i] == 0]
unexpected_labels = sorted(set(class_counts) - set(range(num_classes)))
if missing_classes or unexpected_labels:
    # Fail early with a clear message instead of a bare ZeroDivisionError or a
    # shape/index error deep inside the loss.
    raise ValueError(
        f"train_df['class_idx'] must contain every label in 0..{num_classes - 1} and nothing else; "
        f"missing labels: {missing_classes}, unexpected labels: {unexpected_labels}"
    )
class_weights = torch.tensor([
    total_samples / (num_classes * class_counts[i]) for i in range(num_classes)
], dtype=torch.float32).to(DEVICE)


# LOSS FUNCTION CONFIGURATION - CHANGE #3
# Within each mode the flags are checked in this order: USE_COMBINED_LOSS,
# then USE_FOCAL_LOSS, then USE_LABEL_SMOOTHING, otherwise plain cross entropy.
if not USE_SEPARATE_LOSS:
    # Single-criterion mode: one class-weighted loss bound to `criterion`.
    if USE_COMBINED_LOSS:
        criterion = FocalLossWithLabelSmoothing(alpha=1.0, gamma=2.0, smoothing=0.1, weight=class_weights)
        print('Using combined Focal+LabelSmoothing loss')
    elif USE_FOCAL_LOSS:
        criterion = FocalLoss(alpha=1.0, gamma=2.0, weight=class_weights)
        print('Using Focal loss')
    elif USE_LABEL_SMOOTHING:
        criterion = LabelSmoothingCrossEntropy(smoothing=0.1, weight=class_weights)
        print('Using Label Smoothing loss')
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print('Using original weighted CrossEntropy loss')
else:
    # Separate mode: `text_criterion` is unweighted, `visual_criterion` carries
    # the class weights.
    if USE_COMBINED_LOSS:
        text_criterion = FocalLossWithLabelSmoothing(alpha=1.0, gamma=2.0, smoothing=0.1)
        visual_criterion = FocalLossWithLabelSmoothing(alpha=1.0, gamma=2.0, smoothing=0.1, weight=class_weights)
        print('Using separate combined Focal+LabelSmoothing losses')
    elif USE_FOCAL_LOSS:
        text_criterion = FocalLoss(alpha=1.0, gamma=2.0)
        visual_criterion = FocalLoss(alpha=1.0, gamma=2.0, weight=class_weights)
        print('Using separate Focal losses')
    elif USE_LABEL_SMOOTHING:
        text_criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
        visual_criterion = LabelSmoothingCrossEntropy(smoothing=0.1, weight=class_weights)
        print('Using separate Label Smoothing losses')
    else:
        text_criterion = nn.CrossEntropyLoss()
        visual_criterion = nn.CrossEntropyLoss(weight=class_weights)
        print('Using original separate losses')

print(f'Loss configuration: Focal={USE_FOCAL_LOSS}, LabelSmoothing={USE_LABEL_SMOOTHING}, Combined={USE_COMBINED_LOSS}')


# Image preprocessing: reuse the transform pipeline bundled with the
# pretrained Swin-V2-B ImageNet weights.
weights = Swin_V2_B_Weights.IMAGENET1K_V1
image_transforms = weights.transforms()

# Text preprocessing: tokenizer matching the text encoder checkpoint
# (slow tokenizer requested explicitly via use_fast=False).
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, use_fast=False)

class MultimodalDataset(Dataset):
    """Dataset yielding tokenized text, a transformed image and a label per row.

    Each row of ``dataframe`` is read through its ``'text'``, ``'image'`` and
    ``'class_idx'`` columns; ``'image'`` is joined onto ``image_dir`` to form
    the image path.
    """
    def __init__(self, dataframe, image_dir, tokenizer, max_length, image_transforms):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.image_transforms = image_transforms

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return a dict with 'input_ids', 'attention_mask', 'image' and 'label'."""
        record = self.dataframe.iloc[idx]

        # Text: pad/truncate to a fixed length so samples can be batched.
        tokens = self.tokenizer(
            str(record['text']),
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        # Image: best-effort load; any failure falls back to a blank white
        # image so a single bad file does not abort the whole run.
        image_path = os.path.join(self.image_dir, record['image'])
        try:
            picture = Image.open(image_path).convert('RGB')
        except Exception:
            picture = Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color='white')

        if self.image_transforms:
            picture = self.image_transforms(picture)

        target = record['class_idx']

        # The tokenizer output carries a leading batch axis of size 1; flatten
        # it away so the DataLoader can stack samples.
        return {
            'input_ids': tokens['input_ids'].flatten(),
            'attention_mask': tokens['attention_mask'].flatten(),
            'image': picture,
            'label': torch.tensor(target, dtype=torch.long)
        }

class CoAttentionLayer(nn.Module):
    """Co-attention mechanism for multimodal fusion.

    Projects a text token sequence and a single image feature vector into a
    shared ``hidden_dim`` space, then runs two cross-attentions: text tokens
    attend to the image token, and the image token attends to the text tokens.
    Each stream is followed by a residual LayerNorm and a residual feed-forward
    block.

    Args:
        text_dim: Feature size of the incoming text tokens.
        image_dim: Feature size of the incoming image vector.
        hidden_dim: Shared attention width (CHANGE #2).
    """
    def __init__(self, text_dim, image_dim, hidden_dim=512):
        super(CoAttentionLayer, self).__init__()
        self.text_dim = text_dim
        self.image_dim = image_dim
        self.hidden_dim = hidden_dim  # CHANGE #2: This controls co-attention dimensions

        # Linear projections
        self.text_proj = nn.Linear(text_dim, hidden_dim)
        self.image_proj = nn.Linear(image_dim, hidden_dim)

        # Attention weights - head count adapts to hidden_dim
        num_heads = self._pick_num_heads(hidden_dim)
        self.text_to_image_attn = nn.MultiheadAttention(hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True)
        self.image_to_text_attn = nn.MultiheadAttention(hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True)

        # Layer normalization
        self.text_ln = nn.LayerNorm(hidden_dim)
        self.image_ln = nn.LayerNorm(hidden_dim)

        # Feed forward networks - dimensions scale with hidden_dim
        ffn_dim = hidden_dim * 2
        self.text_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(ffn_dim, hidden_dim)
        )

        self.image_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(ffn_dim, hidden_dim)
        )

    @staticmethod
    def _pick_num_heads(hidden_dim, max_heads=8, min_head_dim=64):
        """Return the largest valid head count for ``hidden_dim``.

        Starts from ``min(max_heads, hidden_dim // min_head_dim)`` (at least 1)
        and walks down until the count divides ``hidden_dim``, which
        ``nn.MultiheadAttention`` requires. For 256/512/768/1024 this yields
        the same 4/8/8/8 as the plain ``min(8, hidden_dim // 64)`` rule.
        """
        upper = max(1, min(max_heads, hidden_dim // min_head_dim))
        for heads in range(upper, 0, -1):
            if hidden_dim % heads == 0:
                return heads
        return 1

    def forward(self, text_features, image_features, text_attention_mask=None):
        """Run both cross-attention directions.

        Args:
            text_features: Text token features, [batch, seq_len, text_dim].
            image_features: Image features, [batch, image_dim].
            text_attention_mask: Optional [batch, seq_len] mask that is
                non-zero for real tokens and 0 for padding. When given,
                padding tokens are excluded as keys in the image->text
                attention. Each row needs at least one real token, otherwise
                the softmax over an all-masked row is undefined.

        Returns:
            Tuple ``(text_attended, image_attended)`` shaped
            [batch, seq_len, hidden_dim] and [batch, hidden_dim].
        """
        # Project to same dimension
        text_proj = self.text_proj(text_features)  # [batch, seq_len, hidden_dim]
        image_proj = self.image_proj(image_features).unsqueeze(1)  # [batch, 1, hidden_dim]

        # Co-attention: text attends to image (single key, nothing to mask)
        text_attended, _ = self.text_to_image_attn(
            query=text_proj,
            key=image_proj,
            value=image_proj
        )
        text_attended = self.text_ln(text_attended + text_proj)
        text_attended = text_attended + self.text_ffn(text_attended)

        # Co-attention: image attends to text. Without a key_padding_mask the
        # image query would spread attention over padding tokens as well.
        key_padding_mask = None
        if text_attention_mask is not None:
            key_padding_mask = text_attention_mask == 0  # True = ignore this key
        image_attended, _ = self.image_to_text_attn(
            query=image_proj,
            key=text_proj,
            value=text_proj,
            key_padding_mask=key_padding_mask
        )
        image_attended = self.image_ln(image_attended + image_proj)
        image_attended = image_attended + self.image_ffn(image_attended)

        return text_attended, image_attended.squeeze(1)

class MultimodalCoAttentionClassifier(nn.Module):
    """Text + image classifier fused through a ``CoAttentionLayer``.

    A pretrained text encoder (loaded with ``AutoModel``) and a Swin-V2-B image
    encoder feed the co-attention layer; the mask-pooled text stream and the
    image stream are concatenated, passed through a small MLP and classified.
    Behaviour is further controlled by the module-level flags
    ``FREEZE_EARLY_LAYERS``, ``CO_ATTENTION_HIDDEN_DIM`` and
    ``USE_SEPARATE_LOSS``.

    Args:
        text_model_name: Checkpoint name passed to ``AutoModel.from_pretrained``.
        num_classes: Number of output classes.
        dropout_rate: Dropout probability used inside the fusion MLP.
    """
    def __init__(self, text_model_name, num_classes, dropout_rate=0.3):
        super(MultimodalCoAttentionClassifier, self).__init__()

        # Text encoder
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        text_hidden_size = self.text_encoder.config.hidden_size

        # Image encoder
        self.image_encoder = swin_v2_b(weights=Swin_V2_B_Weights.IMAGENET1K_V1)
        # Remove the final classification head
        self.image_encoder.head = nn.Identity()
        image_hidden_size = 1024  # Swin-V2-B output dimension

        # LAYER FREEZING CONFIGURATION - CHANGE #1
        if FREEZE_EARLY_LAYERS:
            print("Freezing early layers to save memory...")
            # Freeze early layers of text encoder
            for param in self.text_encoder.embeddings.parameters():
                param.requires_grad = False
            for i in range(6):  # Freeze first 6 layers
                for param in self.text_encoder.encoder.layer[i].parameters():
                    param.requires_grad = False

            # Freeze early layers of image encoder
            for param in self.image_encoder.features[:4].parameters():
                param.requires_grad = False
        else:
            print("All layers are unfrozen for full training...")

        # Co-attention layer - CHANGE #2: Use configurable hidden dimension
        self.co_attention = CoAttentionLayer(text_hidden_size, image_hidden_size, CO_ATTENTION_HIDDEN_DIM)

        # Fusion and classification - dimensions adapt to co-attention hidden dim
        fusion_input_dim = CO_ATTENTION_HIDDEN_DIM * 2  # text + image features
        fusion_hidden_dim = CO_ATTENTION_HIDDEN_DIM

        self.fusion_layer = nn.Sequential(
            nn.Linear(fusion_input_dim, fusion_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(fusion_hidden_dim, fusion_hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(fusion_hidden_dim // 2, num_classes)

        # For separate loss computation
        if USE_SEPARATE_LOSS:
            self.text_classifier = nn.Linear(CO_ATTENTION_HIDDEN_DIM, num_classes)
            self.image_classifier = nn.Linear(CO_ATTENTION_HIDDEN_DIM, num_classes)

    def forward(self, input_ids, attention_mask, images):
        """Classify a batch of (text, image) pairs.

        Args:
            input_ids: Token ids, [batch, seq_len].
            attention_mask: 1 for real tokens, 0 for padding, [batch, seq_len].
            images: Preprocessed image batch for the Swin encoder.

        Returns:
            ``logits`` when ``USE_SEPARATE_LOSS`` is False, otherwise the tuple
            ``(final_logits, text_logits, image_logits)``.
        """
        # Text encoding
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        text_features = text_outputs.last_hidden_state  # [batch, seq_len, hidden_size]

        # Image encoding
        image_features = self.image_encoder(images)  # [batch, hidden_size]

        # Co-attention; the mask keeps padding tokens out of the image->text attention
        text_attended, image_attended = self.co_attention(
            text_features, image_features, text_attention_mask=attention_mask
        )

        # Masked mean pooling over real tokens; the clamp guards against a
        # division by zero for a row without any real token.
        text_mask = attention_mask.unsqueeze(-1).float()
        token_counts = text_mask.sum(dim=1).clamp(min=1.0)
        text_pooled = (text_attended * text_mask).sum(dim=1) / token_counts

        # Multimodal fusion (shared by both output modes)
        fused_features = torch.cat([text_pooled, image_attended], dim=-1)
        fused_features = self.fusion_layer(fused_features)
        logits = self.classifier(fused_features)

        # CHANGE #3: Handle separate loss computation
        if USE_SEPARATE_LOSS:
            # Individual modality predictions for the auxiliary losses
            text_logits = self.text_classifier(text_pooled)
            image_logits = self.image_classifier(image_attended)
            return logits, text_logits, image_logits

        return logits

# Datasets for the three splits share every setting except the dataframe.
train_dataset, val_dataset, test_dataset = (
    MultimodalDataset(frame, IMAGE_DIR, tokenizer, MAX_LENGTH, image_transforms)
    for frame in (train_df, val_df, test_df)
)

# Only the training loader shuffles; validation and test keep file order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    for split in (val_dataset, test_dataset)
)

# Build the model and move it to the target device in one step.
model = MultimodalCoAttentionClassifier(TEXT_MODEL_NAME, NUM_CLASSES, DROPOUT_RATE).to(DEVICE)

# Report total and trainable parameter counts.
total_params = sum(param.numel() for param in model.parameters())
trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# AdamW over all parameters, with a linear warmup/decay schedule that is
# stepped once per optimizer update (not once per micro-batch).
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS
warmup_steps = int(WARMUP_RATIO * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Loss scaler for mixed-precision training.
scaler = GradScaler()

def evaluate_model(model, data_loader, device, name):
    """Evaluate ``model`` on ``data_loader`` and return loss and metrics.

    The model is switched to eval mode and left there; callers that resume
    training must call ``model.train()`` again.

    Args:
        model: The classifier. It returns logits, or a
            ``(final, text, image)`` logits tuple when ``USE_SEPARATE_LOSS``
            is set.
        data_loader: Loader yielding dict batches with 'input_ids',
            'attention_mask', 'image' and 'label'.
        device: Device the batch tensors are moved to.
        name: Label shown on the progress bar.

    Returns:
        Tuple ``(avg_loss, accuracy, macro_f1, auc, all_preds, all_labels)``.
        ``auc`` is the one-vs-rest ROC-AUC computed from softmax
        probabilities, or 0.0 when scikit-learn cannot compute it (for
        example, a class is absent from the labels).
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    total_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=name):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            with autocast('cuda'):
                if USE_SEPARATE_LOSS:
                    outputs, text_outputs, image_outputs = model(input_ids, attention_mask, images)
                    # In separate-loss mode the fused head always uses
                    # text_criterion, which is plain nn.CrossEntropyLoss()
                    # when no loss flag is set.
                    final_loss = text_criterion(outputs, labels)
                    text_loss = text_criterion(text_outputs, labels)
                    image_loss = visual_criterion(image_outputs, labels)
                    loss = 0.5 * final_loss + 0.25 * text_loss + 0.25 * image_loss
                else:
                    outputs = model(input_ids, attention_mask, images)
                    loss = criterion(outputs, labels)

            total_loss += loss.item() * input_ids.size(0)

            # Class probabilities in float64 so the rows sum to 1 within the
            # tolerance scikit-learn checks for multiclass AUC.
            probs = torch.softmax(outputs.double(), dim=1)
            _, preds = torch.max(outputs, 1)
            all_probs.extend(probs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Average over the samples actually seen; guard against an empty loader.
    avg_loss = total_loss / max(len(all_labels), 1)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')

    try:
        # AUC needs continuous scores: one-hot hard predictions collapse each
        # ROC curve to a single operating point.
        auc = roc_auc_score(all_labels, np.asarray(all_probs), multi_class='ovr')
    except ValueError:
        auc = 0.0

    return avg_loss, accuracy, f1, auc, all_preds, all_labels

# Training loop
# Mixed-precision training with gradient accumulation and early stopping on
# validation macro-F1. The best weights are kept in `best_model_wts`.
print("Starting training...")
best_f1 = 0.0
no_improve_epochs = 0
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 50)

    # Training phase
    model.train()
    train_loss = 0.0
    train_preds = []
    train_labels = []

    progress_bar = tqdm(train_loader, desc="Training")
    num_batches = len(train_loader)

    for step, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        images = batch['image'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        with autocast('cuda'):
            if USE_SEPARATE_LOSS:
                outputs, text_outputs, image_outputs = model(input_ids, attention_mask, images)
                # The fused head uses text_criterion, which is plain
                # nn.CrossEntropyLoss() when no loss flag is set.
                final_loss = text_criterion(outputs, labels)
                text_loss = text_criterion(text_outputs, labels)
                image_loss = visual_criterion(image_outputs, labels)
                loss = (0.5 * final_loss + 0.25 * text_loss + 0.25 * image_loss) / GRADIENT_ACCUMULATION_STEPS
            else:
                outputs = model(input_ids, attention_mask, images)
                loss = criterion(outputs, labels) / GRADIENT_ACCUMULATION_STEPS

        scaler.scale(loss).backward()

        # Gradient accumulation. Also step on the last batch of the epoch:
        # otherwise, when the number of batches is not a multiple of
        # GRADIENT_ACCUMULATION_STEPS, the trailing gradients are never applied
        # and leak into the first update of the next epoch.
        # That final partial group is still divided by
        # GRADIENT_ACCUMULATION_STEPS, so its update is slightly smaller.
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or is_last_batch:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        # Track metrics (read the loss once; .item() forces a device sync).
        batch_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS
        train_loss += batch_loss * input_ids.size(0)
        _, preds = torch.max(outputs, 1)
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())

        progress_bar.set_postfix({
            "batch_loss": batch_loss,
            "lr": scheduler.get_last_lr()[0]
        })

    # Calculate training metrics
    train_loss = train_loss / len(train_loader.dataset)
    train_acc = accuracy_score(train_labels, train_preds)
    train_f1 = f1_score(train_labels, train_preds, average='macro')

    # Validation phase
    val_loss, val_acc, val_f1, val_auc, _, _ = evaluate_model(model, val_loader, DEVICE, 'Validation')

    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f}')
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val AUC: {val_auc:.4f}')

    # Early stopping on validation macro-F1
    if val_f1 > best_f1:
        print(f'Validation F1 improved from {best_f1:.4f} to {val_f1:.4f}')
        best_f1 = val_f1
        best_model_wts = copy.deepcopy(model.state_dict())
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        print(f'No improvement for {no_improve_epochs} epochs')

    if no_improve_epochs >= PATIENCE:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break
        
print(f'\nBest Validation F1: {best_f1:.4f}')

# Restore the checkpoint with the best validation F1 before testing.
model.load_state_dict(best_model_wts)

# Held-out test evaluation (the loss value is not reported).
test_metrics = evaluate_model(model, test_loader, DEVICE, 'Testing')
_, test_acc, test_f1, test_auc, test_preds, test_labels = test_metrics

# Results banner followed by the per-class report and headline metrics.
print('\n' + '='*60)
print('MULTIMODAL CO-ATTENTION FUSION RESULTS')
print('='*60)
print("\nTest Classification Report:")
test_report = classification_report(test_labels, test_preds, target_names=target_classes, digits=4)
print(test_report)
print(f'\nTest Accuracy: {test_acc:.4f}')
print(f'Test F1-Score: {test_f1:.4f}')
print(f'Test ROC-AUC: {test_auc:.4f}')

# Visualizations: confusion matrix on the left, per-class ROC on the right.
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
cm_ax, roc_ax = axes[0], axes[1]

# Confusion Matrix
cm = confusion_matrix(test_labels, test_preds)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_classes, yticklabels=target_classes, ax=cm_ax)
cm_ax.set_xlabel('Predicted')
cm_ax.set_ylabel('True')
cm_ax.set_title('Confusion Matrix')

# ROC Curves (one-vs-rest, drawn from the hard predicted labels)
true_arr = np.array(test_labels)
pred_arr = np.array(test_preds)
for i in range(NUM_CLASSES):
    is_true = true_arr == i
    # A ROC curve needs both positives and negatives for this class.
    if len(np.unique(is_true)) <= 1:
        continue
    is_pred = pred_arr == i
    fpr, tpr, _ = roc_curve(is_true, is_pred)
    auc_score = roc_auc_score(is_true, is_pred)
    roc_ax.plot(fpr, tpr, label=f'{target_classes[i]} (AUC = {auc_score:.4f})')

# Chance diagonal and axis decoration.
roc_ax.plot([0, 1], [0, 1], 'r--', alpha=0.5)
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('ROC Curves')
roc_ax.legend()
roc_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('multimodal_results.png', dpi=300, bbox_inches='tight')
plt.show()

# Save the model: weights, optimizer state and the settings needed to rebuild it.
checkpoint_config = {
    'text_model_name': TEXT_MODEL_NAME,
    'num_classes': NUM_CLASSES,
    'dropout_rate': DROPOUT_RATE,
    'max_length': MAX_LENGTH,
    'image_size': IMAGE_SIZE
}
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': checkpoint_config,
    'best_f1': best_f1,
    'target_classes': target_classes
}
torch.save(checkpoint, 'multimodal_coattention_classifier.pt')

print(f"\nModel saved as 'multimodal_coattention_classifier.pt'")
print(f"Best validation F1-score: {best_f1:.4f}")

# Memory cleanup: release cached GPU blocks held by the allocator.
torch.cuda.empty_cache()
print("\nTraining completed successfully!")