# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for root, _subdirs, names in os.walk('/kaggle/input'):
    # Print the full path of every file found beneath the input directory.
    for path in (os.path.join(root, name) for name in names):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# In [None]:
# Multimodal Early Fusion Implementation
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, roc_auc_score, roc_curve
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torchvision import transforms, models
from torchvision.models import swin_v2_b, Swin_V2_B_Weights
from PIL import Image
import copy
import time
import math
from collections import Counter
from tqdm import tqdm
from torch.amp import autocast, GradScaler

# Load datasets
# Later code reads the 'text', 'image' and 'class_idx' columns from these frames.
# NOTE(review): every CSV path is an empty string here — presumably placeholders
# that must be filled in before this cell can run; confirm the intended locations.
train_df = pd.read_csv('')
val_df = pd.read_csv('')
test_df = pd.read_csv('')

# Hyperparameters from your best models
# Text model hyperparameters
DROPOUT_RATE = 0.5  # dropout applied to the text (CLS) features; passed to the model as text_dropout_rate
WEIGHT_DECAY = 0.01  # AdamW weight decay for the text encoder (the classifier group uses half of it)
seed = 42  # seed for the torch / numpy RNGs set below
MAX_LENGTH = 256  # tokenizer max_length; texts are padded/truncated to this many tokens
TEXT_MODEL_NAME = "microsoft/mdeberta-v3-base"
TEXT_LEARNING_RATE = 1e-5

# Visual model hyperparameters
D_O = 0.5  # Visual dropout (passed to the model as visual_dropout_rate)
VISUAL_LEARNING_RATE = 0.0001
IMAGE_SIZE = 224  # only used for the blank fallback image and the saved config in this file

# Common hyperparameters
BATCH_SIZE = 8
NUM_EPOCHS = 500  # upper bound; training normally ends earlier through early stopping
NUM_CLASSES = 3
PATIENCE = 3  # epochs without validation-F1 improvement before early stopping
# NOTE(review): WARMUP_STEPS is not referenced anywhere else in the visible code;
# the scheduler below derives its warmup from 10% of the total step count instead.
WARMUP_STEPS = 0
# NOTE(review): empty string — presumably a placeholder for the image folder; confirm.
IMAGE_DIR = ""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Echo the main settings so they appear in the run log
print(f'Drop Out Rate {DROPOUT_RATE}')
print(f'Weight Decay {WEIGHT_DECAY}')
print(f'Seed = {seed}')
print(f'Max Length {MAX_LENGTH}')
print(f"Visual Dropout = {D_O}")
print(f"Using device: {DEVICE}")

# Ensure reproducibility
# Seeds the torch (CPU + current CUDA device) and numpy generators and asks cuDNN
# for deterministic kernels with auto-tuning disabled.
# NOTE(review): Python's built-in `random` module and the DataLoader worker
# processes are not seeded in the visible code — confirm whether that matters here.
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Display names for the classes, used in the classification report and plots
target_classes = ['x', 'y', 'z']

# Loss functions (different for text and visual as in your original code)
# Text: Normal CrossEntropyLoss
text_criterion = nn.CrossEntropyLoss()
print('Text: Normal Loss Function')

# Visual: Weighted CrossEntropyLoss
# Each class weight is log(total_samples / class_count) computed on the training
# split, so rarer classes receive larger weights.
# NOTE(review): indexing class_counts with range(num_classes) assumes the labels
# are exactly 0..num_classes-1; a missing label yields a Counter value of 0 and a
# ZeroDivisionError. num_classes is derived from the data, independently of
# NUM_CLASSES — confirm that the two always agree.
class_counts = Counter(train_df['class_idx'])
total_samples = sum(class_counts.values())
num_classes = len(class_counts)
class_weights = torch.tensor([
    math.log(total_samples / class_counts[i]) for i in range(num_classes)
], dtype=torch.float32).to(DEVICE)
visual_criterion = nn.CrossEntropyLoss(weight=class_weights)
print('Visual: Weighted Loss Function')

# Initialize tokenizer
# use_fast=False requests the slow (Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, use_fast=False)

# Data transformations for images
# Both splits use the preprocessing preset bundled with the pretrained Swin-V2-B
# weights; no extra training-time augmentation is added here.
weights = Swin_V2_B_Weights.IMAGENET1K_V1
train_transforms = weights.transforms()
val_transforms = weights.transforms()

# Custom multimodal dataset
class MultimodalDataset(Dataset):
    """Paired text/image dataset backed by a pandas DataFrame.

    Every row must provide a ``text`` column, an ``image`` column (file name
    joined onto ``image_dir``) and a ``class_idx`` column (integer label).

    Args:
        dataframe: Source rows, accessed positionally via ``iloc``.
        image_dir: Directory that the ``image`` file names are relative to.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors (called with ``return_tensors="pt"``).
        max_length: Length every text is padded/truncated to.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, dataframe, image_dir, tokenizer, max_length, transform=None):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return one sample as a dict of ``input_ids``, ``attention_mask``, ``image`` and ``label``."""
        # Fetch the row once instead of re-materialising it for every column.
        row = self.dataframe.iloc[idx]

        # Text processing
        encoding = self.tokenizer(
            str(row['text']),
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        # Image processing
        img_path = os.path.join(self.image_dir, row['image'])
        try:
            # The context manager closes the underlying file as soon as the
            # pixels have been converted.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as exc:
            # Best-effort fallback is kept (a blank white image), but the
            # failure is reported instead of being swallowed silently.
            import warnings
            warnings.warn(
                f"Could not load image '{img_path}' ({exc!r}); using a blank placeholder.",
                RuntimeWarning,
            )
            image = Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color='white')

        if self.transform is not None:
            image = self.transform(image)

        return {
            # The tokenizer returns (1, max_length) tensors; flatten drops the batch axis.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'image': image,
            'label': torch.tensor(row['class_idx'], dtype=torch.long)
        }

# Create datasets
# The validation and test splits share the same (validation) image transform.
train_dataset = MultimodalDataset(train_df, IMAGE_DIR, tokenizer, MAX_LENGTH, train_transforms)
val_dataset = MultimodalDataset(val_df, IMAGE_DIR, tokenizer, MAX_LENGTH, val_transforms)
test_dataset = MultimodalDataset(test_df, IMAGE_DIR, tokenizer, MAX_LENGTH, val_transforms)

# Create data loaders
# Only the training loader shuffles; each loader uses 4 worker processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Early Fusion Multimodal Model
class EarlyFusionMultimodalModel(nn.Module):
    """Text + image classifier that concatenates both encoders' features.

    A Hugging Face text encoder supplies the first-token (CLS) embedding and a
    Swin-V2-B backbone (classification head removed) supplies the image
    embedding. Both are passed through dropout, concatenated and classified by
    a two-layer MLP.

    Args:
        text_model_name: Name or path given to ``AutoModel.from_pretrained``.
        num_classes: Number of output classes.
        text_dropout_rate: Dropout probability applied to the text features.
        visual_dropout_rate: Dropout probability applied to the image features.
        fusion_hidden_size: Width of the hidden layer of the fusion classifier.
        fusion_dropout_rate: Dropout probability inside the fusion classifier.
    """

    def __init__(self, text_model_name, num_classes, text_dropout_rate=0.3, visual_dropout_rate=0.5,
                 fusion_hidden_size=512, fusion_dropout_rate=0.3):
        super().__init__()
        print(f'Text Dropout rate {text_dropout_rate}')
        print(f'Visual Dropout rate {visual_dropout_rate}')

        # Text encoder (DeBERTa-v3)
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.text_dropout = nn.Dropout(text_dropout_rate)

        # Visual encoder (Swin Transformer V2)
        weights = Swin_V2_B_Weights.IMAGENET1K_V1
        self.visual_encoder = swin_v2_b(weights=weights)
        # Read the feature width from the stock head before discarding it, so
        # the fusion layer stays correct without a hard-coded constant.
        visual_hidden_size = self.visual_encoder.head.in_features
        # Remove the final classification head
        self.visual_encoder.head = nn.Identity()
        self.visual_dropout = nn.Dropout(visual_dropout_rate)

        # Fusion layer - concatenate text and visual features
        text_hidden_size = self.text_encoder.config.hidden_size
        fused_size = text_hidden_size + visual_hidden_size

        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(fused_size, fusion_hidden_size),
            nn.ReLU(),
            nn.Dropout(fusion_dropout_rate),
            nn.Linear(fusion_hidden_size, num_classes)
        )

    def forward(self, input_ids, attention_mask, images):
        """Run both encoders and classify the concatenated features.

        Returns:
            A tuple ``(logits, text_features, visual_features)`` where both
            feature tensors are the post-dropout encoder outputs that were
            concatenated to produce ``logits``.
        """
        # Text encoding: keep the hidden state of the first (CLS) token
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state[:, 0, :]
        text_features = self.text_dropout(text_features)

        # Visual encoding: the head is an identity, so this is the backbone output
        visual_features = self.visual_encoder(images)
        visual_features = self.visual_dropout(visual_features)

        # Early fusion - concatenate features along the feature axis
        fused_features = torch.cat([text_features, visual_features], dim=1)

        # Classification
        logits = self.classifier(fused_features)
        return logits, text_features, visual_features

# Create model
model = EarlyFusionMultimodalModel(TEXT_MODEL_NAME, NUM_CLASSES, DROPOUT_RATE, D_O)
model = model.to(DEVICE)

# Optimizer with different learning rates for different components
# (nn.Dropout modules hold no parameters, so they contribute nothing to these lists)
text_params = list(model.text_encoder.parameters()) + list(model.text_dropout.parameters())
visual_params = list(model.visual_encoder.parameters()) + list(model.visual_dropout.parameters())
fusion_params = list(model.classifier.parameters())

# Three parameter groups: text encoder, visual encoder (no weight decay) and the
# fusion classifier, which uses the mean of both learning rates and half the decay.
optimizer = AdamW([
    {'params': text_params, 'lr': TEXT_LEARNING_RATE, 'weight_decay': WEIGHT_DECAY},
    {'params': visual_params, 'lr': VISUAL_LEARNING_RATE, 'weight_decay': 0},
    {'params': fusion_params, 'lr': (TEXT_LEARNING_RATE + VISUAL_LEARNING_RATE) / 2, 'weight_decay': WEIGHT_DECAY/2}
])

# Scheduler
# NOTE(review): total_steps assumes all NUM_EPOCHS epochs run, and warmup is
# hard-coded to 10% of that; the WARMUP_STEPS constant is not used. With early
# stopping the run may end long before the schedule completes (possibly still
# inside warmup) — confirm this is intended.
total_steps = len(train_loader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

# Combined loss function
# Cache of fallback auxiliary heads, keyed by (modality, feature width, device).
# Creating a new randomly-initialised nn.Linear on every call would make the
# auxiliary losses (and therefore the reported validation loss) differ from call
# to call even for identical inputs.
_AUX_HEADS = {}


def _get_aux_head(kind, features):
    """Return the cached frozen linear head for ``features``, creating it on first use."""
    key = (kind, features.size(1), str(features.device))
    head = _AUX_HEADS.get(key)
    if head is None:
        head = nn.Linear(features.size(1), NUM_CLASSES).to(features.device)
        # These heads are not registered with any optimizer, so keep them frozen:
        # gradients still flow back into the features, but no stale gradients
        # accumulate on the heads themselves.
        head.requires_grad_(False)
        _AUX_HEADS[key] = head
    return head


def combined_loss(logits, text_features, visual_features, labels, alpha=0.6, beta=0.4,
                  text_head=None, visual_head=None):
    """
    Combined loss with different weights for text and visual components.

    Args:
        logits: Fused-classifier outputs used for the main cross-entropy term.
        text_features: Text features fed to the auxiliary text head.
        visual_features: Visual features fed to the auxiliary visual head.
        labels: Integer class labels.
        alpha: weight for text loss.
        beta: weight for visual loss.
        text_head: Optional module mapping text features to class logits. Pass a
            head whose parameters are in the optimizer to train it; when omitted,
            a fixed head created once and cached is used.
        visual_head: Same as ``text_head`` for the visual features.

    Returns:
        ``main_loss + alpha * text_loss + beta * visual_loss`` as a scalar tensor.
    """
    # Main multimodal loss
    main_loss = nn.functional.cross_entropy(logits, labels)

    # Auxiliary text loss (using normal CrossEntropyLoss as in original)
    if text_head is None:
        text_head = _get_aux_head('text', text_features)
    text_loss = text_criterion(text_head(text_features), labels)

    # Auxiliary visual loss (using weighted CrossEntropyLoss as in original)
    if visual_head is None:
        visual_head = _get_aux_head('visual', visual_features)
    visual_loss = visual_criterion(visual_head(visual_features), labels)

    # Combined loss
    return main_loss + alpha * text_loss + beta * visual_loss

def evaluate_model(model, data_loader, device, name):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Args:
        model: Module returning ``(logits, text_features, visual_features)``.
        data_loader: Loader yielding dict batches with ``input_ids``,
            ``attention_mask``, ``image`` and ``label`` entries.
        device: Device the batches are moved to.
        name: Label shown on the progress bar.

    Returns:
        ``(avg_loss, accuracy, macro_f1, auc, all_preds, all_labels)``.
        ``auc`` is the one-vs-rest ROC-AUC computed from softmax probabilities;
        it falls back to ``0.0`` when scikit-learn cannot compute it (for
        example when a class is absent from the labels).
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=name):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            logits, text_features, visual_features = model(input_ids, attention_mask, images)
            loss = combined_loss(logits, text_features, visual_features, labels)

            # Weight by batch size so the final average is per sample.
            total_loss += loss.item() * input_ids.size(0)

            _, preds = torch.max(logits, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            # Keep class probabilities: AUC needs continuous scores, not hard
            # labels. Float64 keeps each row summing to 1 within sklearn's tolerance.
            all_probs.append(torch.softmax(logits.detach().cpu().double(), dim=1).numpy())

    avg_loss = total_loss / len(data_loader.dataset)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')

    try:
        scores = np.concatenate(all_probs, axis=0)
        auc = roc_auc_score(
            all_labels,
            scores,
            multi_class='ovr',
            labels=list(range(scores.shape[1])),
        )
    except ValueError:
        auc = 0.0

    return avg_loss, accuracy, f1, auc, all_preds, all_labels

# Training loop
# Tracks the best validation macro-F1, keeps a copy of the corresponding weights
# and stops after PATIENCE epochs without improvement.
best_f1 = 0.0
no_improve_epochs = 0
best_model_wts = copy.deepcopy(model.state_dict())
# Mixed precision is only enabled when actually running on CUDA; on CPU both the
# autocast context and the gradient scaler become no-ops.
use_amp = DEVICE.type == 'cuda'
scaler = GradScaler(enabled=use_amp)

print("Starting training...")
for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 30)

    # Training phase
    model.train()
    train_loss = 0.0
    train_preds = []
    train_labels = []

    progress_bar = tqdm(train_loader, desc="Training")
    for batch in progress_bar:
        # Get batch data
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        images = batch['image'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        # Forward pass
        optimizer.zero_grad()
        with autocast(DEVICE.type, enabled=use_amp):
            logits, text_features, visual_features = model(input_ids, attention_mask, images)
            loss = combined_loss(logits, text_features, visual_features, labels)

        # Backward pass: unscale before clipping so the threshold applies to
        # the true gradient norm.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        # The scaler lowers its scale when it skipped the optimizer step
        # (inf/NaN gradients); only advance the LR schedule after a real step.
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        # Track loss and predictions (read the loss from the device once)
        batch_loss = loss.item()
        train_loss += batch_loss * input_ids.size(0)
        _, preds = torch.max(logits, 1)
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())

        # Update progress bar
        progress_bar.set_postfix({"batch_loss": batch_loss})

    # Calculate training metrics
    train_loss = train_loss / len(train_loader.dataset)
    train_acc = accuracy_score(train_labels, train_preds)
    train_f1 = f1_score(train_labels, train_preds, average='macro')

    # Validation phase
    val_loss, val_acc, val_f1, val_auc, _, _ = evaluate_model(model, val_loader, DEVICE, 'Validating')

    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f}')
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val AUC: {val_auc:.4f}')

    # Early stopping based on validation macro F1 score
    if val_f1 > best_f1:
        print(f'Validation F1 improved from {best_f1:.4f} to {val_f1:.4f}')
        best_f1 = val_f1
        best_model_wts = copy.deepcopy(model.state_dict())
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        print(f'No improvement for {no_improve_epochs} epochs')

    if no_improve_epochs >= PATIENCE:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break

print(f'Best Validation F1: {best_f1:.4f}')

# Load best model weights
# (the snapshot taken at the epoch with the highest validation macro-F1)
model.load_state_dict(best_model_wts)

# Evaluate model on test set
_, test_acc, test_f1, test_auc, test_preds, test_labels = evaluate_model(model, test_loader, DEVICE, 'Testing')

print('Early Fusion Multimodal Model')
print("\nTest Classification Report:")
print(classification_report(test_labels, test_preds, target_names=target_classes, digits=4))
print(f'Test Acc: {test_acc:.4f} | Test F1: {test_f1:.4f} | Test AUC: {test_auc:.4f}')

# Confusion Matrix
# Rendered as an annotated heatmap and written to the working directory.
plt.figure(figsize=(10, 8))
cm = confusion_matrix(test_labels, test_preds)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=target_classes, yticklabels=target_classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Early Fusion - Confusion Matrix')
plt.tight_layout()
plt.savefig('early_fusion_confusion_matrix.png')
plt.show()

# ROC Curve Visualization
# One one-vs-rest curve per class.
# NOTE(review): the scores passed to roc_curve / roc_auc_score are boolean hard
# predictions (test_preds == i), not class probabilities, so each curve is built
# from a binary score rather than a ranking — confirm this is what is wanted.
plt.figure(figsize=(10, 8))
for i in range(NUM_CLASSES):
    fpr, tpr, _ = roc_curve(np.array(test_labels) == i, np.array(test_preds) == i)
    plt.plot(fpr, tpr, label=f'Class {target_classes[i]} (AUC = {roc_auc_score(np.array(test_labels) == i, np.array(test_preds) == i):.4f})')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Early Fusion - ROC Curve')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
# Diagonal reference line (drawn after the legend call, so it has no legend entry)
plt.plot([0, 1], [0, 1], 'r--')
plt.savefig('early_fusion_roc_curve.png')
plt.show()

# Save the model
# The checkpoint holds the best weights (loaded above) together with the settings
# needed to rebuild the model.
# NOTE(review): the optimizer state saved here is the one from the end of
# training, not from the best epoch — confirm whether resuming relies on it.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': {
        'max_length': MAX_LENGTH,
        'num_classes': NUM_CLASSES,
        'text_model_name': TEXT_MODEL_NAME,
        'text_dropout_rate': DROPOUT_RATE,
        'visual_dropout_rate': D_O,
        'image_size': IMAGE_SIZE
    }
}, 'early_fusion_multimodal_classifier.pt')

print("Early fusion model training completed and saved!")

# Inference function
# NOTE: the definition below is wrapped in a triple-quoted string, so it is only
# an unused string expression — `predict_multimodal` is NOT defined at runtime.
# As written, it would tokenize one text, preprocess one image with the Swin-V2-B
# preset transforms, run the model and return the predicted class name together
# with the softmax probabilities. Remove the surrounding quotes to enable it.
'''def predict_multimodal(text, image_path, model, tokenizer, config, device=torch.device('cpu')):
    model.eval()
    
    # Process text
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=config['max_length'],
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt"
    )
    
    # Process image
    weights = Swin_V2_B_Weights.IMAGENET1K_V1
    transform = weights.transforms()
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)
    
    # Move to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    image = image.to(device)
    
    with torch.no_grad():
        logits, _, _ = model(input_ids, attention_mask, image)
        _, preds = torch.max(logits, 1)
        predicted_class = target_classes[preds.item()]
        
        # Get probabilities
        probs = torch.nn.functional.softmax(logits, dim=1)
    
    return predicted_class, probs.cpu().numpy()[0]'''