<a href="https://colab.research.google.com/github/yazidiyassine/-Clustering_Heart_Disease_Patient_Data/blob/main/dynafusion_debert_XLNI_AR_v1_2_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip uninstall -y torch
!pip install torch==2.3.0

In [None]:
!pip install torchtext==0.18

In [None]:
!pip install datasets

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
import logging
from pathlib import Path
import random

@dataclass
class ModelConfig:
    """Configuration optimized for GPU training with Arabic model"""
    hidden_size: int = 1024
    num_classes: int = 3
    dropout_rate: float = 0.1
    attention_heads: int = 16  # Changed to be divisible into hidden_size
    use_dynafusion: bool = True
    max_length: int = 128
    model_name: str = 'microsoft/mdeberta-v3-base'  # Arabic BERT model
    batch_size: int = 16  # Reduced for GPU memory
    accumulation_steps: int = 2
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    max_epochs: int = 3  # Increased epochs for better convergence
    early_stopping_patience: int = 4  # Increased patience
    early_stopping_threshold: float = 80.0  # Target 80% validation accuracy
    use_fp16: bool = True
    num_workers: int = 2
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    subset_size: int = 200000  # Reduced dataset size
    use_cached_model: bool = True

class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Simplified transformer layer
        self.context_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=8,  # Adjusted for Arabic BERT hidden size (typically 768)
            dim_feedforward=2048,  # Adjusted down for Arabic BERT
            dropout=config.dropout_rate,
            batch_first=True
        )

        # Improved fusion gate
        self.fusion_gate = nn.Sequential(
            nn.LayerNorm(self.hidden_size * 2),
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, 2),
            nn.Softmax(dim=-1)
        )

        # Classifier with dropout
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, config.num_classes)
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        # Handle case when token_type_ids might not be available
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids if token_type_ids is not None else torch.zeros_like(input_ids)
        )

        sequence_output = outputs.last_hidden_state

        # Process sequence through transformer
        context_output = self.context_layer(sequence_output)

        # If token_type_ids is None, create artificial token type IDs
        if token_type_ids is None:
            # Simple approach: first half is question (0), second half is context (1)
            batch_size, seq_len = input_ids.shape
            token_type_ids = torch.cat([
                torch.zeros((batch_size, seq_len // 2), device=input_ids.device).long(),
                torch.ones((batch_size, seq_len - seq_len // 2), device=input_ids.device).long()
            ], dim=1)

        # Separate question and context
        q_mask = (token_type_ids == 0).unsqueeze(-1).float()
        c_mask = (token_type_ids == 1).unsqueeze(-1).float()

        # Mean pooling
        q_seq = (context_output * q_mask).sum(1) / q_mask.sum(1).clamp(min=1e-9)
        c_seq = (context_output * c_mask).sum(1) / c_mask.sum(1).clamp(min=1e-9)

        # Fusion mechanism
        fusion_input = torch.cat([q_seq, c_seq], dim=-1)
        fusion_weights = self.fusion_gate(fusion_input)
        fused_output = (q_seq * fusion_weights[:, 0].unsqueeze(-1) +
                       c_seq * fusion_weights[:, 1].unsqueeze(-1))

        return self.classifier(fused_output)

class FastXNLIDataset(Dataset):
    def __init__(self, premises, hypotheses, labels, tokenizer, max_length: int):
        self.encodings = tokenizer(
            premises,
            hypotheses,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        item = {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'labels': self.labels[idx]
        }

        # Add token_type_ids if they exist
        if 'token_type_ids' in self.encodings:
            item['token_type_ids'] = self.encodings['token_type_ids'][idx]
        else:
            # Create synthetic token_type_ids if needed
            seq_len = len(item['input_ids'])
            item['token_type_ids'] = torch.cat([
                torch.zeros(seq_len // 2, dtype=torch.long),
                torch.ones(seq_len - seq_len // 2, dtype=torch.long)
            ])

        return item

class Trainer:
    def __init__(self, config: ModelConfig, model: nn.Module, device: torch.device):
        self.config = config
        self.model = model
        self.device = device
        self.scaler = GradScaler() if config.use_fp16 else None
        self.best_acc = 0
        self.patience_counter = 0
        self.save_path = Path("best_model_arabic.pth")

    def train_epoch(self, train_loader: DataLoader, optimizer: AdamW, scheduler) -> Tuple[float, float]:
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        optimizer.zero_grad()

        for i, batch in enumerate(tqdm(train_loader, desc="Training")):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)

            # Handle token_type_ids if present
            token_type_ids = batch.get('token_type_ids', None)
            if token_type_ids is not None:
                token_type_ids = token_type_ids.to(self.device)

            with autocast(enabled=self.config.use_fp16):
                outputs = self.model(input_ids, attention_mask, token_type_ids)
                loss = F.cross_entropy(outputs, labels)
                loss = loss / self.config.accumulation_steps

            if self.config.use_fp16:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            if (i + 1) % self.config.accumulation_steps == 0:
                if self.config.use_fp16:
                    self.scaler.unscale_(optimizer)

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

                if self.config.use_fp16:
                    self.scaler.step(optimizer)
                    self.scaler.update()
                else:
                    optimizer.step()

                optimizer.zero_grad()
                scheduler.step()

            total_loss += loss.item() * self.config.accumulation_steps
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return total_loss / len(train_loader), 100 * correct / total

    @torch.no_grad()
    def validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        correct = 0
        total = 0

        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)

            # Handle token_type_ids if present
            token_type_ids = batch.get('token_type_ids', None)
            if token_type_ids is not None:
                token_type_ids = token_type_ids.to(self.device)

            with autocast(enabled=self.config.use_fp16):
                outputs = self.model(input_ids, attention_mask, token_type_ids)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return 100 * correct / total

    def should_stop_early(self, val_acc: float) -> bool:
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.patience_counter = 0
            torch.save(self.model.state_dict(), self.save_path)
            return False

        self.patience_counter += 1
        if self.patience_counter >= self.config.early_stopping_patience:
            print(f"Early stopping triggered! Best validation accuracy: {self.best_acc:.2f}%")
            return True

        if val_acc >= self.config.early_stopping_threshold:
            print(f"Reached target accuracy of {self.config.early_stopping_threshold}%!")
            return True

        return False

def main():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True
    config = ModelConfig()

    # Set random seeds
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    print(f"Using device: {device}")

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset('xnli', 'ar')  # Changed to Arabic
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Subset data
    train_data = dataset['train'].select(range(min(config.subset_size, len(dataset['train']))))
    val_data = dataset['validation'].select(range(min(config.subset_size // 2, len(dataset['validation']))))

    print(f"Training samples: {len(train_data)}")
    print(f"Validation samples: {len(val_data)}")

    # Create datasets
    train_dataset = FastXNLIDataset(
        train_data['premise'],
        train_data['hypothesis'],
        train_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    val_dataset = FastXNLIDataset(
        val_data['premise'],
        val_data['hypothesis'],
        val_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True
    )

    # Initialize model
    print("Initializing model...")
    model = FastBERTDynaFusion(config).to(device)

    # Optimizer setup
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if "bert" not in n],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate * 10,
        },
        {
            "params": [p for n, p in model.bert.named_parameters()],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate,
        }
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon
    )

    # Learning rate scheduler
    num_training_steps = len(train_loader) * config.max_epochs // config.accumulation_steps
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_training_steps * config.warmup_ratio),
        num_training_steps=num_training_steps
    )

    # Initialize trainer
    trainer = Trainer(config, model, device)

    # Training loop
    print("Starting training...")
    for epoch in range(config.max_epochs):
        train_loss, train_acc = trainer.train_epoch(train_loader, optimizer, scheduler)
        val_acc = trainer.validate(val_loader)

        print(f"Epoch {epoch + 1}/{config.max_epochs}:")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Train Accuracy: {train_acc:.2f}%")
        print(f"Validation Accuracy: {val_acc:.2f}%")
        print("-" * 50)

        if trainer.should_stop_early(val_acc):
            break

    # Load best model
    if trainer.save_path.exists():
        model.load_state_dict(torch.load(trainer.save_path))
        print(f"Loaded best model with validation accuracy: {trainer.best_acc:.2f}%")

if __name__ == '__main__':
    main()

Using device: cuda
Loading dataset...


Exception ignored in: <function _xla_gc_callback at 0x78270d8404a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lib/__init__.py", line 96, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 


Training samples: 200000
Validation samples: 2490


KeyboardInterrupt: 

In [9]:
def print_val_acc():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = ModelConfig()

    # Load dataset and create tokenizer
    dataset = load_dataset('xnli', 'ar')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Prepare validation data
    val_data = dataset['validation'].select(range(min(config.subset_size // 2, len(dataset['validation']))))

    # Create validation dataset
    val_dataset = FastXNLIDataset(
        val_data['premise'],
        val_data['hypothesis'],
        val_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    # Create validation dataloader
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True
    )

    # Initialize model
    model = FastBERTDynaFusion(config).to(device)

    # Load saved weights
    model.load_state_dict(torch.load('/content/best_model_arabic.pth'))

    # Evaluate
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)

            with autocast(enabled=config.use_fp16):
                outputs = model(input_ids, attention_mask, token_type_ids)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Validation Accuracy: {accuracy:.2f}%")

# Call the function to print validation accuracy
print_val_acc()

  model.load_state_dict(torch.load('/content/best_model_arabic.pth'))


Validating:   0%|          | 0/156 [00:00<?, ?it/s]

  with autocast(enabled=config.use_fp16):


Validation Accuracy: 78.07%


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
import logging
from pathlib import Path
import random

@dataclass
class ModelConfig:
    """Configuration optimized for GPU training"""
    hidden_size: int = 1024
    num_classes: int = 3
    dropout_rate: float = 0.1
    attention_heads: int = 16  # Changed to be divisible into hidden_size
    use_dynafusion: bool = True
    max_length: int = 128
    model_name: str = 'microsoft/deberta-v3-large'
    batch_size: int = 16  # Reduced for GPU memory
    accumulation_steps: int = 2
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    max_epochs: int = 10
    early_stopping_patience: int = 3
    early_stopping_threshold: float = 86.0
    use_fp16: bool = True
    num_workers: int = 2
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    subset_size: int = 50000  # Reduced dataset size
    use_cached_model: bool = True

class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Simplified transformer layer
        self.context_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=16,  # Must divide hidden_size evenly
            dim_feedforward=4096,
            dropout=config.dropout_rate,
            batch_first=True
        )

        # Improved fusion gate
        self.fusion_gate = nn.Sequential(
            nn.LayerNorm(self.hidden_size * 2),
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, 2),
            nn.Softmax(dim=-1)
        )

        # Classifier with dropout
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, config.num_classes)
        )

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        sequence_output = outputs.last_hidden_state

        # Process sequence through transformer
        context_output = self.context_layer(sequence_output)

        # Separate question and context
        q_mask = (token_type_ids == 0).unsqueeze(-1).float()
        c_mask = (token_type_ids == 1).unsqueeze(-1).float()

        # Mean pooling
        q_seq = (context_output * q_mask).sum(1) / q_mask.sum(1).clamp(min=1e-9)
        c_seq = (context_output * c_mask).sum(1) / c_mask.sum(1).clamp(min=1e-9)

        # Fusion mechanism
        fusion_input = torch.cat([q_seq, c_seq], dim=-1)
        fusion_weights = self.fusion_gate(fusion_input)
        fused_output = (q_seq * fusion_weights[:, 0].unsqueeze(-1) +
                       c_seq * fusion_weights[:, 1].unsqueeze(-1))

        return self.classifier(fused_output)

class FastXNLIDataset(Dataset):
    def __init__(self, premises, hypotheses, labels, tokenizer, max_length: int):
        self.encodings = tokenizer(
            premises,
            hypotheses,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'token_type_ids': self.encodings['token_type_ids'][idx],
            'labels': self.labels[idx]
        }

class Trainer:
    def __init__(self, config: ModelConfig, model: nn.Module, device: torch.device):
        self.config = config
        self.model = model
        self.device = device
        self.scaler = GradScaler() if config.use_fp16 else None
        self.best_acc = 0
        self.patience_counter = 0
        self.save_path = Path("best_model.pth")

    def train_epoch(self, train_loader: DataLoader, optimizer: AdamW, scheduler) -> Tuple[float, float]:
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        optimizer.zero_grad()

        for i, batch in enumerate(tqdm(train_loader, desc="Training")):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            with autocast(enabled=self.config.use_fp16):
                outputs = self.model(input_ids, attention_mask, token_type_ids)
                loss = F.cross_entropy(outputs, labels)
                loss = loss / self.config.accumulation_steps

            self.scaler.scale(loss).backward()

            if (i + 1) % self.config.accumulation_steps == 0:
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.scaler.step(optimizer)
                self.scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            total_loss += loss.item() * self.config.accumulation_steps
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return total_loss / len(train_loader), 100 * correct / total

    @torch.no_grad()
    def validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        correct = 0
        total = 0

        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            with autocast(enabled=self.config.use_fp16):
                outputs = self.model(input_ids, attention_mask, token_type_ids)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return 100 * correct / total

    def should_stop_early(self, val_acc: float) -> bool:
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.patience_counter = 0
            torch.save(self.model.state_dict(), self.save_path)
            return False

        self.patience_counter += 1
        if self.patience_counter >= self.config.early_stopping_patience:
            print(f"Early stopping triggered! Best validation accuracy: {self.best_acc:.2f}%")
            return True

        if val_acc >= self.config.early_stopping_threshold:
            print(f"Reached target accuracy of {self.config.early_stopping_threshold}%!")
            return True

        return False

def main():
    # Setup
    device = torch.device('cuda')
    torch.backends.cudnn.benchmark = True
    config = ModelConfig()

    # Set random seeds
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.cuda.manual_seed_all(42)

    print(f"Using device: {device}")

    # Load Arabic XNLI dataset
    print("Loading dataset...")
    dataset = load_dataset('xnli', 'ar')  # Changed to Arabic
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Subset data
    train_data = dataset['train'].select(range(min(config.subset_size, len(dataset['train']))))
    val_data = dataset['validation'].select(range(min(config.subset_size // 2, len(dataset['validation']))))

    print(f"Training samples: {len(train_data)}")
    print(f"Validation samples: {len(val_data)}")

    # Create datasets
    train_dataset = FastXNLIDataset(
        train_data['premise'],
        train_data['hypothesis'],
        train_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    val_dataset = FastXNLIDataset(
        val_data['premise'],
        val_data['hypothesis'],
        val_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True
    )

    # Initialize model
    print("Initializing model...")
    model = FastBERTDynaFusion(config).to(device)

    # Optimizer setup
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if "bert" not in n],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate * 10,
        },
        {
            "params": [p for n, p in model.bert.named_parameters()],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate,
        }
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon
    )

    # Calculate number of training steps
    num_training_steps = len(train_loader) * config.max_epochs // config.accumulation_steps
    num_warmup_steps = int(num_training_steps * config.warmup_ratio)

    # Initialize scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    # Initialize trainer
    trainer = Trainer(config, model, device)

    # Training loop
    print("Starting training...")
    for epoch in range(config.max_epochs):
        train_loss, train_acc = trainer.train_epoch(train_loader, optimizer, scheduler)
        val_acc = trainer.validate(val_loader)

        print(f"Epoch {epoch + 1}/{config.max_epochs}:")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Train Accuracy: {train_acc:.2f}%")
        print(f"Validation Accuracy: {val_acc:.2f}%")
        print("-" * 50)

        if trainer.should_stop_early(val_acc):
            break

    # Load best model
    if trainer.save_path.exists():
        model.load_state_dict(torch.load(trainer.save_path))
        print(f"Loaded best model with validation accuracy: {trainer.best_acc:.2f}%")

if __name__ == '__main__':
    main()import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
import logging
from pathlib import Path

@dataclass
class ModelConfig:
    """Configuration for model architecture and training"""
    hidden_size: int = 768
    num_classes: int = 3  # XNLI has 3 classes: entailment, neutral, contradiction
    dropout_rate: float = 0.1
    attention_heads: int = 8
    use_dynafusion: bool = True
    max_length: int = 128
    model_name: str = 'microsoft/deberta-v3-small'  # Changed to smaller model
    batch_size: int = 32  # Increased batch size
    accumulation_steps: int = 2  # Reduced accumulation steps
    learning_rate: float = 3e-5  # Slightly increased learning rate
    warmup_ratio: float = 0.1
    max_epochs: int = 5
    early_stopping_patience: int = 2
    early_stopping_threshold: float = 76.0  # Based on mBERT baseline
    use_fp16: bool = True  # Added explicit FP16 control
    num_workers: int = 4  # Explicit control of DataLoader workers

class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Simplified context processing
        self.context_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=config.attention_heads,
            dim_feedforward=1024,  # Reduced dimension
            dropout=config.dropout_rate,
            batch_first=True,
            norm_first=True
        )

        if config.use_dynafusion:
            # Fusion mechanism for premise and hypothesis representations
            self.fusion_gate = nn.Sequential(
                nn.LayerNorm(self.hidden_size * 2),
                nn.Linear(self.hidden_size * 2, 2),
                nn.Softmax(dim=-1)
            )

        # Simplified classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, config.num_classes)
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
        # Get BERT outputs with gradient checkpointing
        self.bert.gradient_checkpointing_enable()
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        sequence_output = outputs.last_hidden_state

        if self.config.use_dynafusion:
            # Process sequence through context layer
            context_output = self.context_layer(sequence_output, src_key_padding_mask=~attention_mask.bool())

            # Separate pooling for premise and hypothesis segments
            q_mask = (token_type_ids == 0).unsqueeze(-1).float()
            c_mask = (token_type_ids == 1).unsqueeze(-1).float()

            # Pool premise and hypothesis representations separately
            q_seq = (sequence_output * q_mask * attention_mask.unsqueeze(-1)).sum(1) / (q_mask * attention_mask.unsqueeze(-1)).sum(1).clamp(min=1e-9)
            c_seq = (sequence_output * c_mask * attention_mask.unsqueeze(-1)).sum(1) / (c_mask * attention_mask.unsqueeze(-1)).sum(1).clamp(min=1e-9)

            # Apply fusion mechanism
            gate_input = torch.cat([q_seq, c_seq], dim=-1)
            gate_weights = self.fusion_gate(gate_input)
            final_output = (q_seq * gate_weights[:, 0].unsqueeze(-1) +
                          c_seq * gate_weights[:, 1].unsqueeze(-1))
        else:
            # Simple masked pooling if fusion is disabled
            mask_expanded = attention_mask.unsqueeze(-1).float()
            final_output = (sequence_output * mask_expanded).sum(1) / mask_expanded.sum(1)

        return self.classifier(final_output)

class FastXNLIDataset(Dataset):
    def __init__(self, premises, hypotheses, labels, tokenizer, max_length: int):
        self.encodings = tokenizer(
            premises,
            hypotheses,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'token_type_ids': self.encodings['token_type_ids'][idx],
            'labels': self.labels[idx]
        }

class Trainer:
    def __init__(self, config: ModelConfig, model: nn.Module, device: torch.device):
        self.config = config
        self.model = model
        self.device = device
        self.scaler = GradScaler() if config.use_fp16 else None
        self.best_acc = 0
        self.patience_counter = 0

    def train_epoch(self, train_loader: DataLoader, optimizer: AdamW, scheduler) -> Tuple[float, float]:
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        optimizer.zero_grad()

        for i, batch in enumerate(tqdm(train_loader)):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            if self.config.use_fp16:
                with autocast():
                    outputs = self.model(input_ids, attention_mask, token_type_ids)
                    loss = F.cross_entropy(outputs, labels)
                    loss = loss / self.config.accumulation_steps

                self.scaler.scale(loss).backward()
            else:
                outputs = self.model(input_ids, attention_mask, token_type_ids)
                loss = F.cross_entropy(outputs, labels)
                loss = loss / self.config.accumulation_steps
                loss.backward()

            if (i + 1) % self.config.accumulation_steps == 0:
                if self.config.use_fp16:
                    self.scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.scaler.step(optimizer)
                    self.scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    optimizer.step()

                optimizer.zero_grad()
                scheduler.step()

            total_loss += loss.item() * self.config.accumulation_steps
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return total_loss / len(train_loader), 100 * correct / total

    def validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in tqdm(val_loader):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                token_type_ids = batch['token_type_ids'].to(self.device)
                labels = batch['labels'].to(self.device)

                with autocast() if self.config.use_fp16 else torch.no_grad():
                    outputs = self.model(input_ids, attention_mask, token_type_ids)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return 100 * correct / total

    def should_stop_early(self, val_acc: float) -> bool:
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.patience_counter = 0
            torch.save(self.model.state_dict(), "best_model.pth")
            return False

        self.patience_counter += 1
        if self.patience_counter >= self.config.early_stopping_patience:
            return True

        if val_acc >= self.config.early_stopping_threshold:
            return True

        return False

def main():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    config = ModelConfig()

    # Load and prepare data
    dataset = load_dataset('xnli', 'fr')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Prepare datasets
    train_dataset = FastXNLIDataset(
        dataset['train']['premise'],
        dataset['train']['hypothesis'],
        dataset['train']['label'],
        tokenizer,
        max_length=config.max_length
    )

    val_dataset = FastXNLIDataset(
        dataset['validation']['premise'],
        dataset['validation']['hypothesis'],
        dataset['validation']['label'],
        tokenizer,
        max_length=config.max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    # Initialize model and training components
    model = FastBERTDynaFusion(config).to(device)

    # Optimizer with simple configuration
    optimizer = AdamW([
        {'params': model.bert.parameters(), 'lr': config.learning_rate},
        {'params': (p for n, p in model.named_parameters() if not n.startswith('bert')),
         'lr': config.learning_rate * 3}
    ], weight_decay=0.01)

    num_training_steps = len(train_loader) * config.max_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_training_steps * config.warmup_ratio),
        num_training_steps=num_training_steps
    )

    # Initialize trainer
    trainer = Trainer(config, model, device)

    # Training loop
    for epoch in range(config.max_epochs):
        train_loss, train_acc = trainer.train_epoch(train_loader, optimizer, scheduler)
        val_acc = trainer.validate(val_loader)

        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, "
              f"Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

        if trainer.should_stop_early(val_acc):
            print("Early stopping triggered!")
            break

    # Load best model and evaluate on test set
    model.load_state_dict(torch.load("best_model.pth"))

    test_dataset = FastXNLIDataset(
        dataset['test']['premise'],
        dataset['test']['hypothesis'],
        dataset['test']['label'],
        tokenizer,
        max_length=config.max_length
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    test_acc = trainer.validate(test_loader)
    print(f"Test Accuracy: {test_acc:.2f}%")

if __name__ == '__main__':
    main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


pytorch_model.bin:   0%|          | 0.00/286M [00:00<?, ?B/s]

  self.scaler = GradScaler() if config.use_fp16 else None


model.safetensors:   0%|          | 0.00/286M [00:00<?, ?B/s]

  0%|          | 0/12272 [00:00<?, ?it/s]

  with autocast():


KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
import logging
from pathlib import Path

@dataclass
class ModelConfig:
    """Configuration for model architecture and training"""
    hidden_size: int = 768
    num_classes: int = 3  # XNLI has 3 classes: entailment, neutral, contradiction
    dropout_rate: float = 0.1
    attention_heads: int = 8
    use_dynafusion: bool = True
    max_length: int = 128
    model_name: str = 'microsoft/deberta-v3-small'  # Small but powerful model
    batch_size: int = 64  # Doubled batch size
    accumulation_steps: int = 1  # No accumulation for faster steps
    learning_rate: float = 5e-5  # Increased learning rate
    warmup_ratio: float = 0.1
    max_epochs: int = 3  # Reduced max epochs
    early_stopping_patience: int = 1  # More aggressive early stopping
    early_stopping_threshold: float = 86.0  # Updated threshold
    use_fp16: bool = True
    num_workers: int = 8  # More workers for data loading
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    use_efficient_pooling: bool = True  # Use efficient pooling instead of pooler_output

class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        if not config.use_efficient_pooling:
            # Lighter context processing for sequence output
            self.context_layer = nn.TransformerEncoderLayer(
                d_model=self.hidden_size,
                nhead=config.attention_heads,
                dim_feedforward=1024,
                dropout=config.dropout_rate,
                batch_first=True,
                norm_first=True
            )

        if config.use_dynafusion and not config.use_efficient_pooling:
            # Fusion mechanism for premise and hypothesis
            self.fusion_gate = nn.Sequential(
                nn.LayerNorm(self.hidden_size * 2),
                nn.Linear(self.hidden_size * 2, 2),
                nn.Softmax(dim=-1)
            )

        # Efficient classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, config.num_classes)
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
        # Get outputs with gradient checkpointing for memory efficiency
        self.bert.gradient_checkpointing_enable()
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        sequence_output = outputs.last_hidden_state

        # Use efficient pooling if configured
        if self.config.use_efficient_pooling:
            # Simple CLS token extraction (first token)
            return self.classifier(sequence_output[:, 0])

        if self.config.use_dynafusion:
            # Process sequence through context layer
            context_output = self.context_layer(sequence_output, src_key_padding_mask=~attention_mask.bool())

            # Separate pooling for premise and hypothesis segments
            q_mask = (token_type_ids == 0).unsqueeze(-1).float()
            c_mask = (token_type_ids == 1).unsqueeze(-1).float()

            # Pool premise and hypothesis representations separately
            q_seq = (sequence_output * q_mask * attention_mask.unsqueeze(-1)).sum(1) / (q_mask * attention_mask.unsqueeze(-1)).sum(1).clamp(min=1e-9)
            c_seq = (sequence_output * c_mask * attention_mask.unsqueeze(-1)).sum(1) / (c_mask * attention_mask.unsqueeze(-1)).sum(1).clamp(min=1e-9)

            # Apply fusion mechanism
            gate_input = torch.cat([q_seq, c_seq], dim=-1)
            gate_weights = self.fusion_gate(gate_input)
            final_output = (q_seq * gate_weights[:, 0].unsqueeze(-1) +
                          c_seq * gate_weights[:, 1].unsqueeze(-1))
        else:
            # Simple masked pooling if fusion is disabled
            mask_expanded = attention_mask.unsqueeze(-1).float()
            final_output = (sequence_output * mask_expanded).sum(1) / mask_expanded.sum(1)

        return self.classifier(final_output)

class FastXNLIDataset(Dataset):
    def __init__(self, premises, hypotheses, labels, tokenizer, max_length: int):
        self.encodings = tokenizer(
            premises,
            hypotheses,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'token_type_ids': self.encodings['token_type_ids'][idx],
            'labels': self.labels[idx]
        }

class Trainer:
    def __init__(self, config: ModelConfig, model: nn.Module, device: torch.device):
        self.config = config
        self.model = model
        self.device = device
        self.scaler = GradScaler() if config.use_fp16 else None
        self.best_acc = 0
        self.patience_counter = 0
        self.save_path = Path("best_model.pth")

    def train_epoch(self, train_loader: DataLoader, optimizer: AdamW, scheduler) -> Tuple[float, float]:
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        for batch in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            if self.config.use_fp16:
                with autocast():
                    outputs = self.model(input_ids, attention_mask, token_type_ids)
                    loss = F.cross_entropy(outputs, labels)

                self.scaler.scale(loss).backward()
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.scaler.step(optimizer)
                self.scaler.update()
            else:
                outputs = self.model(input_ids, attention_mask, token_type_ids)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                optimizer.step()

            scheduler.step()
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return total_loss / len(train_loader), 100 * correct / total

    @torch.no_grad()
    def validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        correct = 0
        total = 0

        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            with autocast() if self.config.use_fp16 else torch.no_grad():
                outputs = self.model(input_ids, attention_mask, token_type_ids)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return 100 * correct / total

    def should_stop_early(self, val_acc: float) -> bool:
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.patience_counter = 0
            torch.save(self.model.state_dict(), self.save_path)
            return False

        self.patience_counter += 1
        if self.patience_counter >= self.config.early_stopping_patience:
            return True

        if val_acc >= self.config.early_stopping_threshold:
            return True

        return False

def main():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    config = ModelConfig()

    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler("training.log")
        ]
    )
    logger = logging.getLogger(__name__)
    logger.info(f"Using device: {device}")
    logger.info(f"Config: {config}")

    # Load and prepare data
    dataset = load_dataset('xnli', 'fr')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Prepare datasets
    train_dataset = FastXNLIDataset(
        dataset['train']['premise'],
        dataset['train']['hypothesis'],
        dataset['train']['label'],
        tokenizer,
        max_length=config.max_length
    )

    val_dataset = FastXNLIDataset(
        dataset['validation']['premise'],
        dataset['validation']['hypothesis'],
        dataset['validation']['label'],
        tokenizer,
        max_length=config.max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    # Initialize model and training components
    model = FastBERTDynaFusion(config).to(device)

    # Optimizer setup
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": config.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon
    )

    num_training_steps = len(train_loader) * config.max_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_training_steps * config.warmup_ratio),
        num_training_steps=num_training_steps
    )

    # Initialize trainer
    trainer = Trainer(config, model, device)

    # Training loop
    logger.info("Starting training...")
    for epoch in range(config.max_epochs):
        train_loss, train_acc = trainer.train_epoch(train_loader, optimizer, scheduler)
        val_acc = trainer.validate(val_loader)

        logger.info(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, "
              f"Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

        if trainer.should_stop_early(val_acc):
            logger.info(f"Early stopping triggered! Best validation accuracy: {trainer.best_acc:.2f}%")
            break

    # Load best model and evaluate on test set
    if trainer.save_path.exists():
        model.load_state_dict(torch.load(trainer.save_path))
        logger.info(f"Loaded best model from {trainer.save_path}")

    test_dataset = FastXNLIDataset(
        dataset['test']['premise'],
        dataset['test']['hypothesis'],
        dataset['test']['label'],
        tokenizer,
        max_length=config.max_length
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    test_acc = trainer.validate(test_loader)
    logger.info(f"Test Accuracy: {test_acc:.2f}%")

if __name__ == '__main__':
    main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  self.scaler = GradScaler() if config.use_fp16 else None


Training:   0%|          | 0/6136 [00:01<?, ?it/s]

  with autocast():


KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
import logging
from pathlib import Path
import random

@dataclass
class ModelConfig:
    """Configuration for model architecture and training"""
    hidden_size: int = 768
    num_classes: int = 3  # XNLI has 3 classes: entailment, neutral, contradiction
    dropout_rate: float = 0.1
    attention_heads: int = 8
    use_dynafusion: bool = True
    max_length: int = 128
    model_name: str = 'microsoft/deberta-v3-small'  # Small but powerful model
    batch_size: int = 128  # Increased batch size for faster training
    accumulation_steps: int = 1  # No accumulation for faster steps
    learning_rate: float = 8e-5  # Increased learning rate for faster convergence
    warmup_ratio: float = 0.05
    max_epochs: int = 10  # More epochs but we'll stop early
    early_stopping_patience: int = 3  # More patience to find the best model
    early_stopping_threshold: float = 86.0  # Updated threshold
    use_fp16: bool = True
    num_workers: int = 4  # Adjusted workers for data loading
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    use_efficient_pooling: bool = True  # Use efficient pooling instead of pooler_output
    subset_size: int = 1000  # Use only this many examples from each split
    use_cached_model: bool = True  # Try to use cached model if available

class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        if not config.use_efficient_pooling:
            # Lighter context processing for sequence output
            self.context_layer = nn.TransformerEncoderLayer(
                d_model=self.hidden_size,
                nhead=config.attention_heads,
                dim_feedforward=1024,
                dropout=config.dropout_rate,
                batch_first=True,
                norm_first=True
            )

        if config.use_dynafusion and not config.use_efficient_pooling:
            # Fusion mechanism for premise and hypothesis
            self.fusion_gate = nn.Sequential(
                nn.LayerNorm(self.hidden_size * 2),
                nn.Linear(self.hidden_size * 2, 2),
                nn.Softmax(dim=-1)
            )

        # Efficient classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, config.num_classes)
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
        # Get outputs with gradient checkpointing for memory efficiency
        self.bert.gradient_checkpointing_enable()
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        sequence_output = outputs.last_hidden_state

        # Use efficient pooling if configured
        if self.config.use_efficient_pooling:
            # Simple CLS token extraction (first token)
            return self.classifier(sequence_output[:, 0])

        if self.config.use_dynafusion:
            # Process sequence through context layer
            context_output = self.context_layer(sequence_output, src_key_padding_mask=~attention_mask.bool())

            # Separate pooling for premise and hypothesis segments
            q_mask = (token_type_ids == 0).unsqueeze(-1).float()
            c_mask = (token_type_ids == 1).unsqueeze(-1).float()

            # Pool premise and hypothesis representations separately
            q_seq = (sequence_output * q_mask * attention_mask.unsqueeze(-1)).sum(1) / (q_mask * attention_mask.unsqueeze(-1)).sum(1).clamp(min=1e-9)
            c_seq = (sequence_output * c_mask * attention_mask.unsqueeze(-1)).sum(1) / (c_mask * attention_mask.unsqueeze(-1)).sum(1).clamp(min=1e-9)

            # Apply fusion mechanism
            gate_input = torch.cat([q_seq, c_seq], dim=-1)
            gate_weights = self.fusion_gate(gate_input)
            final_output = (q_seq * gate_weights[:, 0].unsqueeze(-1) +
                          c_seq * gate_weights[:, 1].unsqueeze(-1))
        else:
            # Simple masked pooling if fusion is disabled
            mask_expanded = attention_mask.unsqueeze(-1).float()
            final_output = (sequence_output * mask_expanded).sum(1) / mask_expanded.sum(1)

        return self.classifier(final_output)

class FastXNLIDataset(Dataset):
    def __init__(self, premises, hypotheses, labels, tokenizer, max_length: int):
        self.encodings = tokenizer(
            premises,
            hypotheses,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'token_type_ids': self.encodings['token_type_ids'][idx],
            'labels': self.labels[idx]
        }

class Trainer:
    def __init__(self, config: ModelConfig, model: nn.Module, device: torch.device):
        self.config = config
        self.model = model
        self.device = device
        self.scaler = GradScaler() if config.use_fp16 else None
        self.best_acc = 0
        self.patience_counter = 0
        self.save_path = Path("best_model.pth")

    def train_epoch(self, train_loader: DataLoader, optimizer: AdamW, scheduler) -> Tuple[float, float]:
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        for batch in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            if self.config.use_fp16:
                with autocast():
                    outputs = self.model(input_ids, attention_mask, token_type_ids)
                    loss = F.cross_entropy(outputs, labels)

                self.scaler.scale(loss).backward()
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.scaler.step(optimizer)
                self.scaler.update()
            else:
                outputs = self.model(input_ids, attention_mask, token_type_ids)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                optimizer.step()

            scheduler.step()
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return total_loss / len(train_loader), 100 * correct / total

    @torch.no_grad()
    def validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        correct = 0
        total = 0

        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            with autocast() if self.config.use_fp16 else torch.no_grad():
                outputs = self.model(input_ids, attention_mask, token_type_ids)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return 100 * correct / total

    def should_stop_early(self, val_acc: float) -> bool:
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.patience_counter = 0
            torch.save(self.model.state_dict(), self.save_path)
            return False

        self.patience_counter += 1
        if self.patience_counter >= self.config.early_stopping_patience:
            return True

        if val_acc >= self.config.early_stopping_threshold:
            logger.info(f"Reached target accuracy of {self.config.early_stopping_threshold}%!")
            return True

        return False

def main():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    config = ModelConfig()

    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler("training.log")
        ]
    )
    global logger
    logger = logging.getLogger(__name__)
    logger.info(f"Using device: {device}")
    logger.info(f"Config: {config}")

    # Load and prepare data - SIMPLIFIED APPROACH
    dataset = load_dataset('xnli', 'fr')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Simple subsetting by taking the first N examples
    train_data = dataset['train'].select(range(min(config.subset_size, len(dataset['train']))))
    val_data = dataset['validation'].select(range(min(config.subset_size // 2, len(dataset['validation']))))
    test_data = dataset['test'].select(range(min(config.subset_size // 2, len(dataset['test']))))

    logger.info(f"Using {len(train_data)} training examples")
    logger.info(f"Using {len(val_data)} validation examples")
    logger.info(f"Using {len(test_data)} test examples")

    # Prepare datasets
    train_dataset = FastXNLIDataset(
        train_data['premise'],
        train_data['hypothesis'],
        train_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    val_dataset = FastXNLIDataset(
        val_data['premise'],
        val_data['hypothesis'],
        val_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    # Initialize model and training components
    model = FastBERTDynaFusion(config).to(device)

    # Try to load cached model if it exists and option is enabled
    model_path = Path("best_model.pth")
    if config.use_cached_model and model_path.exists():
        try:
            model.load_state_dict(torch.load(model_path, map_location=device))
            logger.info(f"Loaded pre-trained model from {model_path}")

            # Quick validation to see if the loaded model already meets our criteria
            trainer = Trainer(config, model, device)
            val_acc = trainer.validate(val_loader)
            logger.info(f"Pre-trained model validation accuracy: {val_acc:.2f}%")

            if val_acc >= config.early_stopping_threshold:
                logger.info(f"Pre-trained model already meets accuracy threshold of {config.early_stopping_threshold}%")

                # Create test dataset and evaluate
                test_dataset = FastXNLIDataset(
                    test_data['premise'],
                    test_data['hypothesis'],
                    test_data['label'],
                    tokenizer,
                    max_length=config.max_length
                )

                test_loader = DataLoader(
                    test_dataset,
                    batch_size=config.batch_size * 2,
                    shuffle=False,
                    num_workers=config.num_workers,
                    pin_memory=True,
                    persistent_workers=True if config.num_workers > 0 else False
                )

                test_acc = trainer.validate(test_loader)
                logger.info(f"Test Accuracy: {test_acc:.2f}%")
                return
        except Exception as e:
            logger.warning(f"Failed to load pre-trained model: {e}")
            logger.info("Training a new model...")

    # Optimizer setup with differential learning rates
    no_decay = ["bias", "LayerNorm.weight"]

    # Higher learning rate for classifier, lower for pre-trained model
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.classifier.named_parameters()],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate * 2,  # Higher learning rate for classifier
        },
        {
            "params": [p for n, p in model.bert.named_parameters()
                      if not any(nd in n for nd in no_decay)],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate,
        },
        {
            "params": [p for n, p in model.bert.named_parameters()
                      if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
            "lr": config.learning_rate,
        },
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon
    )

    num_training_steps = len(train_loader) * config.max_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_training_steps * config.warmup_ratio),
        num_training_steps=num_training_steps
    )

    # Initialize trainer
    trainer = Trainer(config, model, device)

    # Training loop
    logger.info("Starting training...")
    for epoch in range(config.max_epochs):
        train_loss, train_acc = trainer.train_epoch(train_loader, optimizer, scheduler)
        val_acc = trainer.validate(val_loader)

        logger.info(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, "
              f"Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

        if trainer.should_stop_early(val_acc):
            logger.info(f"Early stopping triggered! Best validation accuracy: {trainer.best_acc:.2f}%")
            break

    # Load best model and evaluate on test set
    if trainer.save_path.exists():
        model.load_state_dict(torch.load(trainer.save_path))
        logger.info(f"Loaded best model from {trainer.save_path}")

    test_dataset = FastXNLIDataset(
        test_data['premise'],
        test_data['hypothesis'],
        test_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    test_acc = trainer.validate(test_loader)
    logger.info(f"Test Accuracy: {test_acc:.2f}%")

if __name__ == '__main__':
    # Set seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.cuda.manual_seed_all(42)

    main()

  self.scaler = GradScaler() if config.use_fp16 else None


Training:   0%|          | 0/8 [00:00<?, ?it/s]

  with autocast():


Validating:   0%|          | 0/2 [00:00<?, ?it/s]

  with autocast() if self.config.use_fp16 else torch.no_grad():


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Validating:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/8 [00:00<?, ?it/s]

Validating:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/8 [00:00<?, ?it/s]

Validating:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/8 [00:00<?, ?it/s]

Validating:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/8 [00:00<?, ?it/s]

Validating:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/8 [00:00<?, ?it/s]

Validating:   0%|          | 0/2 [00:00<?, ?it/s]

Training:   0%|          | 0/8 [00:00<?, ?it/s]

Validating:   0%|          | 0/2 [00:00<?, ?it/s]

  model.load_state_dict(torch.load(trainer.save_path))


Validating:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
import logging
from pathlib import Path
import random

@dataclass
class ModelConfig:
    """Configuration for model architecture and training"""
    hidden_size: int = 768
    num_classes: int = 3  # XNLI has 3 classes: entailment, neutral, contradiction
    dropout_rate: float = 0.1
    attention_heads: int = 8
    use_dynafusion: bool = True
    max_length: int = 128
    model_name: str = 'microsoft/deberta-v3-base'  # Small but powerful model
    batch_size: int = 128  # Increased batch size for faster training
    accumulation_steps: int = 1  # No accumulation for faster steps
    learning_rate: float = 8e-5  # Increased learning rate for faster convergence
    warmup_ratio: float = 0.05
    max_epochs: int = 10  # More epochs but we'll stop early
    early_stopping_patience: int = 3  # More patience to find the best model
    early_stopping_threshold: float = 86.0  # Updated threshold
    use_fp16: bool = True
    num_workers: int = 4  # Adjusted workers for data loading
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    use_efficient_pooling: bool = True  # Use efficient pooling instead of pooler_output
    subset_size: int = 100000  # Use only this many examples from each split
    use_cached_model: bool = True  # Try to use cached model if available

class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        if not config.use_efficient_pooling:
            # Lighter context processing for sequence output
            self.context_layer = nn.TransformerEncoderLayer(
                d_model=self.hidden_size,
                nhead=config.attention_heads,
                dim_feedforward=1024,
                dropout=config.dropout_rate,
                batch_first=True,
                norm_first=True
            )

        if config.use_dynafusion and not config.use_efficient_pooling:
            # Fusion mechanism for premise and hypothesis
            self.fusion_gate = nn.Sequential(
                nn.LayerNorm(self.hidden_size * 2),
                nn.Linear(self.hidden_size * 2, 2),
                nn.Softmax(dim=-1)
            )

        # Efficient classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, config.num_classes)
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
        # Get outputs with gradient checkpointing for memory efficiency
        self.bert.gradient_checkpointing_enable()
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        sequence_output = outputs.last_hidden_state

        # Use efficient pooling if configured
        if self.config.use_efficient_pooling:
            # Simple CLS token extraction (first token)
            return self.classifier(sequence_output[:, 0])

        if self.config.use_dynafusion:
            # Process sequence through context layer
            context_output = self.context_layer(sequence_output, src_key_padding_mask=~attention_mask.bool())

            # Separate pooling for premise and hypothesis segments
            q_mask = (token_type_ids == 0).unsqueeze(-1).float()
            c_mask = (token_type_ids == 1).unsqueeze(-1).float()

            # Pool premise and hypothesis representations separately
            q_seq = (sequence_output * q_mask * attention_mask.unsqueeze(-1)).sum(1) / (q_mask * attention_mask.unsqueeze(-1)).sum(1).clamp(min=1e-9)
            c_seq = (sequence_output * c_mask * attention_mask.unsqueeze(-1)).sum(1) / (c_mask * attention_mask.unsqueeze(-1)).sum(1).clamp(min=1e-9)

            # Apply fusion mechanism
            gate_input = torch.cat([q_seq, c_seq], dim=-1)
            gate_weights = self.fusion_gate(gate_input)
            final_output = (q_seq * gate_weights[:, 0].unsqueeze(-1) +
                          c_seq * gate_weights[:, 1].unsqueeze(-1))
        else:
            # Simple masked pooling if fusion is disabled
            mask_expanded = attention_mask.unsqueeze(-1).float()
            final_output = (sequence_output * mask_expanded).sum(1) / mask_expanded.sum(1)

        return self.classifier(final_output)

class FastXNLIDataset(Dataset):
    def __init__(self, premises, hypotheses, labels, tokenizer, max_length: int):
        self.encodings = tokenizer(
            premises,
            hypotheses,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'token_type_ids': self.encodings['token_type_ids'][idx],
            'labels': self.labels[idx]
        }

class Trainer:
    def __init__(self, config: ModelConfig, model: nn.Module, device: torch.device):
        self.config = config
        self.model = model
        self.device = device
        self.scaler = GradScaler() if config.use_fp16 else None
        self.best_acc = 0
        self.patience_counter = 0
        self.save_path = Path("best_model.pth")
        self.val_accuracies = []  # Track validation accuracies

    def train_epoch(self, train_loader: DataLoader, optimizer: AdamW, scheduler) -> Tuple[float, float]:
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        for batch in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            if self.config.use_fp16:
                with autocast():
                    outputs = self.model(input_ids, attention_mask, token_type_ids)
                    loss = F.cross_entropy(outputs, labels)

                self.scaler.scale(loss).backward()
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.scaler.step(optimizer)
                self.scaler.update()
            else:
                outputs = self.model(input_ids, attention_mask, token_type_ids)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                optimizer.step()

            scheduler.step()
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return total_loss / len(train_loader), 100 * correct / total

    @torch.no_grad()
    def validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        correct = 0
        total = 0

        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            with autocast() if self.config.use_fp16 else torch.no_grad():
                outputs = self.model(input_ids, attention_mask, token_type_ids)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        self.val_accuracies.append(acc)  # Store the validation accuracy
        return acc

    def should_stop_early(self, val_acc: float) -> bool:
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.patience_counter = 0
            torch.save(self.model.state_dict(), self.save_path)
            return False

        self.patience_counter += 1
        if self.patience_counter >= self.config.early_stopping_patience:
            print(f"Early stopping triggered! Best validation accuracy: {self.best_acc:.2f}%")
            return True

        if val_acc >= self.config.early_stopping_threshold:
            print(f"Reached target accuracy of {self.config.early_stopping_threshold}%!")
            return True

        return False

def main():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    config = ModelConfig()

    # Set up logging for file only, not console
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("training.log")
        ]
    )
    global logger
    logger = logging.getLogger(__name__)

    print(f"Using device: {device}")

    # Load and prepare data
    print("Loading dataset...")
    dataset = load_dataset('xnli', 'fr')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Simple subsetting by taking the first N examples
    train_data = dataset['train'].select(range(min(config.subset_size, len(dataset['train']))))
    val_data = dataset['validation'].select(range(min(config.subset_size // 2, len(dataset['validation']))))
    test_data = dataset['test'].select(range(min(config.subset_size // 2, len(dataset['test']))))

    print(f"Using {len(train_data)} training examples")
    print(f"Using {len(val_data)} validation examples")
    print(f"Using {len(test_data)} test examples")

    # Prepare datasets
    train_dataset = FastXNLIDataset(
        train_data['premise'],
        train_data['hypothesis'],
        train_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    val_dataset = FastXNLIDataset(
        val_data['premise'],
        val_data['hypothesis'],
        val_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    # Initialize model and training components
    print("Initializing model...")
    model = FastBERTDynaFusion(config).to(device)

    # Try to load cached model if it exists and option is enabled
    model_path = Path("best_model.pth")
    if config.use_cached_model and model_path.exists():
        try:
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded pre-trained model from {model_path}")

            # Quick validation to see if the loaded model already meets our criteria
            trainer = Trainer(config, model, device)
            val_acc = trainer.validate(val_loader)
            print(f"Pre-trained model validation accuracy: {val_acc:.2f}%")

            if val_acc >= config.early_stopping_threshold:
                print(f"Pre-trained model already meets accuracy threshold of {config.early_stopping_threshold}%")

                # Create test dataset and evaluate
                test_dataset = FastXNLIDataset(
                    test_data['premise'],
                    test_data['hypothesis'],
                    test_data['label'],
                    tokenizer,
                    max_length=config.max_length
                )

                test_loader = DataLoader(
                    test_dataset,
                    batch_size=config.batch_size * 2,
                    shuffle=False,
                    num_workers=config.num_workers,
                    pin_memory=True,
                    persistent_workers=True if config.num_workers > 0 else False
                )

                test_acc = trainer.validate(test_loader)
                print(f"Test Accuracy: {test_acc:.2f}%")
                return
        except Exception as e:
            print(f"Failed to load pre-trained model: {e}")
            print("Training a new model...")

    # Optimizer setup with differential learning rates
    no_decay = ["bias", "LayerNorm.weight"]

    # Higher learning rate for classifier, lower for pre-trained model
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.classifier.named_parameters()],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate * 2,  # Higher learning rate for classifier
        },
        {
            "params": [p for n, p in model.bert.named_parameters()
                      if not any(nd in n for nd in no_decay)],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate,
        },
        {
            "params": [p for n, p in model.bert.named_parameters()
                      if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
            "lr": config.learning_rate,
        },
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon
    )

    num_training_steps = len(train_loader) * config.max_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_training_steps * config.warmup_ratio),
        num_training_steps=num_training_steps
    )

    # Initialize trainer
    trainer = Trainer(config, model, device)

    # Training loop
    print("Starting training...")
    for epoch in range(config.max_epochs):
        train_loss, train_acc = trainer.train_epoch(train_loader, optimizer, scheduler)
        val_acc = trainer.validate(val_loader)

        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, "
              f"Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

        if trainer.should_stop_early(val_acc):
            break

    # Print all validation accuracies
    print("\nValidation accuracy for each epoch:")
    for epoch, acc in enumerate(trainer.val_accuracies):
        print(f"Epoch {epoch + 1}: Val Acc: {acc:.2f}%")

    # Load best model and evaluate on test set
    if trainer.save_path.exists():
        model.load_state_dict(torch.load(trainer.save_path))
        print(f"Loaded best model from {trainer.save_path}")

    test_dataset = FastXNLIDataset(
        test_data['premise'],
        test_data['hypothesis'],
        test_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )

    test_acc = trainer.validate(test_loader)
    print(f"Final Test Accuracy: {test_acc:.2f}%")

if __name__ == '__main__':
    # Set seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.cuda.manual_seed_all(42)

    main()

Using device: cuda
Loading dataset...
Using 100000 training examples
Using 2490 validation examples
Using 5010 test examples
Initializing model...


  model.load_state_dict(torch.load(model_path, map_location=device))


Failed to load pre-trained model: Error(s) in loading state_dict for FastBERTDynaFusion:
	Missing key(s) in state_dict: "bert.encoder.layer.6.attention.self.query_proj.weight", "bert.encoder.layer.6.attention.self.query_proj.bias", "bert.encoder.layer.6.attention.self.key_proj.weight", "bert.encoder.layer.6.attention.self.key_proj.bias", "bert.encoder.layer.6.attention.self.value_proj.weight", "bert.encoder.layer.6.attention.self.value_proj.bias", "bert.encoder.layer.6.attention.output.dense.weight", "bert.encoder.layer.6.attention.output.dense.bias", "bert.encoder.layer.6.attention.output.LayerNorm.weight", "bert.encoder.layer.6.attention.output.LayerNorm.bias", "bert.encoder.layer.6.intermediate.dense.weight", "bert.encoder.layer.6.intermediate.dense.bias", "bert.encoder.layer.6.output.dense.weight", "bert.encoder.layer.6.output.dense.bias", "bert.encoder.layer.6.output.LayerNorm.weight", "bert.encoder.layer.6.output.LayerNorm.bias", "bert.encoder.layer.7.attention.self.query_proj.we

  self.scaler = GradScaler() if config.use_fp16 else None


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():


Validating:   0%|          | 0/10 [00:00<?, ?it/s]

  with autocast() if self.config.use_fp16 else torch.no_grad():


Epoch 1: Train Loss: 0.7460, Train Acc: 68.42%, Val Acc: 68.92%


Training:   0%|          | 0/782 [00:00<?, ?it/s]

Validating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 2: Train Loss: 0.6073, Train Acc: 75.53%, Val Acc: 72.17%


Training:   0%|          | 0/782 [00:00<?, ?it/s]

Validating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 3: Train Loss: 0.4899, Train Acc: 81.20%, Val Acc: 74.94%


Training:   0%|          | 0/782 [00:00<?, ?it/s]

Validating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 4: Train Loss: 0.3858, Train Acc: 85.72%, Val Acc: 73.05%


Training:   0%|          | 0/782 [00:00<?, ?it/s]

KeyboardInterrupt: 

# v2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
import logging
from pathlib import Path
import random

@dataclass
class ModelConfig:
    """Configuration optimized for GPU training"""
    hidden_size: int = 1024
    num_classes: int = 3
    dropout_rate: float = 0.1
    attention_heads: int = 16  # Changed to be divisible into hidden_size
    use_dynafusion: bool = True
    max_length: int = 128
    model_name: str = 'microsoft/deberta-v3-large'
    batch_size: int = 16  # Reduced for GPU memory
    accumulation_steps: int = 2
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    max_epochs: int = 10
    early_stopping_patience: int = 3
    early_stopping_threshold: float = 86.0
    use_fp16: bool = True
    num_workers: int = 2
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    subset_size: int = 50000  # Reduced dataset size
    use_cached_model: bool = True

class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Simplified transformer layer
        self.context_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=16,  # Must divide hidden_size evenly
            dim_feedforward=4096,
            dropout=config.dropout_rate,
            batch_first=True
        )

        # Improved fusion gate
        self.fusion_gate = nn.Sequential(
            nn.LayerNorm(self.hidden_size * 2),
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, 2),
            nn.Softmax(dim=-1)
        )

        # Classifier with dropout
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, config.num_classes)
        )

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        sequence_output = outputs.last_hidden_state

        # Process sequence through transformer
        context_output = self.context_layer(sequence_output)

        # Separate question and context
        q_mask = (token_type_ids == 0).unsqueeze(-1).float()
        c_mask = (token_type_ids == 1).unsqueeze(-1).float()

        # Mean pooling
        q_seq = (context_output * q_mask).sum(1) / q_mask.sum(1).clamp(min=1e-9)
        c_seq = (context_output * c_mask).sum(1) / c_mask.sum(1).clamp(min=1e-9)

        # Fusion mechanism
        fusion_input = torch.cat([q_seq, c_seq], dim=-1)
        fusion_weights = self.fusion_gate(fusion_input)
        fused_output = (q_seq * fusion_weights[:, 0].unsqueeze(-1) +
                       c_seq * fusion_weights[:, 1].unsqueeze(-1))

        return self.classifier(fused_output)

class FastXNLIDataset(Dataset):
    def __init__(self, premises, hypotheses, labels, tokenizer, max_length: int):
        self.encodings = tokenizer(
            premises,
            hypotheses,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'token_type_ids': self.encodings['token_type_ids'][idx],
            'labels': self.labels[idx]
        }

class Trainer:
    def __init__(self, config: ModelConfig, model: nn.Module, device: torch.device):
        self.config = config
        self.model = model
        self.device = device
        self.scaler = GradScaler() if config.use_fp16 else None
        self.best_acc = 0
        self.patience_counter = 0
        self.save_path = Path("best_model.pth")

    def train_epoch(self, train_loader: DataLoader, optimizer: AdamW, scheduler) -> Tuple[float, float]:
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        optimizer.zero_grad()

        for i, batch in enumerate(tqdm(train_loader, desc="Training")):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            with autocast(enabled=self.config.use_fp16):
                outputs = self.model(input_ids, attention_mask, token_type_ids)
                loss = F.cross_entropy(outputs, labels)
                loss = loss / self.config.accumulation_steps

            self.scaler.scale(loss).backward()

            if (i + 1) % self.config.accumulation_steps == 0:
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.scaler.step(optimizer)
                self.scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            total_loss += loss.item() * self.config.accumulation_steps
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return total_loss / len(train_loader), 100 * correct / total

    @torch.no_grad()
    def validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        correct = 0
        total = 0

        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            token_type_ids = batch['token_type_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            with autocast(enabled=self.config.use_fp16):
                outputs = self.model(input_ids, attention_mask, token_type_ids)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return 100 * correct / total

    def should_stop_early(self, val_acc: float) -> bool:
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.patience_counter = 0
            torch.save(self.model.state_dict(), self.save_path)
            return False

        self.patience_counter += 1
        if self.patience_counter >= self.config.early_stopping_patience:
            print(f"Early stopping triggered! Best validation accuracy: {self.best_acc:.2f}%")
            return True

        if val_acc >= self.config.early_stopping_threshold:
            print(f"Reached target accuracy of {self.config.early_stopping_threshold}%!")
            return True

        return False

def main():
    # Setup
    device = torch.device('cuda')
    torch.backends.cudnn.benchmark = True
    config = ModelConfig()

    # Set random seeds
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.cuda.manual_seed_all(42)

    print(f"Using device: {device}")

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset('xnli', 'fr')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Subset data
    train_data = dataset['train'].select(range(min(config.subset_size, len(dataset['train']))))
    val_data = dataset['validation'].select(range(min(config.subset_size // 2, len(dataset['validation']))))

    print(f"Training samples: {len(train_data)}")
    print(f"Validation samples: {len(val_data)}")

    # Create datasets
    train_dataset = FastXNLIDataset(
        train_data['premise'],
        train_data['hypothesis'],
        train_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    val_dataset = FastXNLIDataset(
        val_data['premise'],
        val_data['hypothesis'],
        val_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True
    )

    # Initialize model
    print("Initializing model...")
    model = FastBERTDynaFusion(config).to(device)

    # Optimizer setup
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if "bert" not in n],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate * 10,
        },
        {
            "params": [p for n, p in model.bert.named_parameters()],
            "weight_decay": config.weight_decay,
            "lr": config.learning_rate,
        }
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon
    )

    # Learning rate scheduler
    num_training_steps = len(train_loader) * config.max_epochs // config.accumulation_steps
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_training_steps * config.warmup_ratio),
        num_training_steps=num_training_steps
    )

    # Initialize trainer
    trainer = Trainer(config, model, device)

    # Training loop
    print("Starting training...")
    for epoch in range(config.max_epochs):
        train_loss, train_acc = trainer.train_epoch(train_loader, optimizer, scheduler)
        val_acc = trainer.validate(val_loader)

        print(f"Epoch {epoch + 1}/{config.max_epochs}:")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Train Accuracy: {train_acc:.2f}%")
        print(f"Validation Accuracy: {val_acc:.2f}%")
        print("-" * 50)

        if trainer.should_stop_early(val_acc):
            break

    # Load best model
    if trainer.save_path.exists():
        model.load_state_dict(torch.load(trainer.save_path))
        print(f"Loaded best model with validation accuracy: {trainer.best_acc:.2f}%")

if __name__ == '__main__':
    main()

Using device: cuda
Loading dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/55.4M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/360k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/183k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5010 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2490 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/580 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



Training samples: 50000
Validation samples: 2490
Initializing model...


pytorch_model.bin:   0%|          | 0.00/874M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/874M [00:00<?, ?B/s]

Starting training...


  self.scaler = GradScaler() if config.use_fp16 else None


Training:   0%|          | 0/3125 [00:00<?, ?it/s]

  with autocast(enabled=self.config.use_fp16):


Validating:   0%|          | 0/78 [00:00<?, ?it/s]

  with autocast(enabled=self.config.use_fp16):


Epoch 1/10:
Train Loss: 0.5772
Train Accuracy: 76.47%
Validation Accuracy: 83.33%
--------------------------------------------------


Training:   0%|          | 0/3125 [00:00<?, ?it/s]

Validating:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 2/10:
Train Loss: 0.3869
Train Accuracy: 86.06%
Validation Accuracy: 84.26%
--------------------------------------------------


Training:   0%|          | 0/3125 [00:00<?, ?it/s]

Validating:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 3/10:
Train Loss: 0.2349
Train Accuracy: 92.16%
Validation Accuracy: 84.86%
--------------------------------------------------


Training:   0%|          | 0/3125 [00:00<?, ?it/s]

Validating:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 4/10:
Train Loss: 0.1573
Train Accuracy: 95.35%
Validation Accuracy: 84.22%
--------------------------------------------------


Training:   0%|          | 0/3125 [00:00<?, ?it/s]

Validating:   0%|          | 0/78 [00:00<?, ?it/s]

Epoch 5/10:
Train Loss: 0.1054
Train Accuracy: 97.18%
Validation Accuracy: 83.49%
--------------------------------------------------


Training:   0%|          | 0/3125 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.cuda.amp import autocast

def evaluate_saved_model():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = ModelConfig()

    # Load dataset and create tokenizer
    dataset = load_dataset('xnli', 'fr')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Prepare validation data
    val_data = dataset['validation'].select(range(min(config.subset_size // 2, len(dataset['validation']))))

    # Create validation dataset
    val_dataset = FastXNLIDataset(
        val_data['premise'],
        val_data['hypothesis'],
        val_data['label'],
        tokenizer,
        max_length=config.max_length
    )

    # Create validation dataloader
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True
    )

    # Initialize model
    model = FastBERTDynaFusion(config).to(device)

    # Load saved weights
    model.load_state_dict(torch.load('/content/best_model.pth'))

    # Evaluate
    model.eval()
    correct = 0
    total = 0

    print("Evaluating model...")
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)

            with autocast(enabled=config.use_fp16):
                outputs = model(input_ids, attention_mask, token_type_ids)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"\nValidation Accuracy of saved model: {accuracy:.2f}%")

if __name__ == '__main__':
    evaluate_saved_model()

  model.load_state_dict(torch.load('/content/best_model.pth'))


Evaluating model...


Validating:   0%|          | 0/156 [00:00<?, ?it/s]

  with autocast(enabled=config.use_fp16):



Validation Accuracy of saved model: 84.86%


# SST2

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.3.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.0-py3-none-any.whl (484 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m484.9/484.9 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading x

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple
import logging
from pathlib import Path
import re
from bs4 import BeautifulSoup
import nltk
from nltk.tokenize import word_tokenize
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

@dataclass
class ModelConfig:
    """Optimized configuration for faster convergence"""
    hidden_size: int = 768
    num_classes: int = 2
    dropout_rate: float = 0.1  # Reduced dropout for faster convergence
    use_dynafusion: bool = True
    max_length: int = 256  # Reduced sequence length
    model_name: str = 'microsoft/deberta-v3-small'  # Smaller, faster model

class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Enhanced context processing
        self.context_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=self.hidden_size,
                nhead=config.attention_heads,
                dim_feedforward=3072,  # Increased capacity
                dropout=config.dropout_rate,
                batch_first=True
            ) for _ in range(config.num_transformer_layers)
        ])

        # Enhanced fusion mechanism
        if config.use_dynafusion:
            self.fusion_gate = nn.Sequential(
                nn.Linear(self.hidden_size * 2, self.hidden_size),
                nn.GELU(),
                nn.Dropout(config.dropout_rate),
                nn.Linear(self.hidden_size, 2),
                nn.Softmax(dim=-1)
            )

        # Enhanced classifier with skip connections
        self.pre_classifier = nn.Linear(self.hidden_size, self.hidden_size)
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size // 2, config.num_classes)
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        if self.config.use_dynafusion:
            # Enhanced context processing
            context_output = sequence_output
            for layer in self.context_layers:
                context_output = layer(context_output, src_key_padding_mask=~attention_mask.bool())

            # Weighted pooling
            attention_weights = torch.softmax(
                torch.matmul(sequence_output, context_output.transpose(-2, -1)) / np.sqrt(self.hidden_size),
                dim=-1
            )
            pooled_sequence = torch.matmul(attention_weights, sequence_output)
            pooled_context = torch.matmul(attention_weights, context_output)

            # Enhanced fusion
            gate_input = torch.cat([
                torch.mean(pooled_sequence, dim=1),
                torch.mean(pooled_context, dim=1)
            ], dim=-1)
            gate_weights = self.fusion_gate(gate_input)

            final_output = (torch.mean(pooled_sequence, dim=1) * gate_weights[:, 0].unsqueeze(-1) +
                          torch.mean(pooled_context, dim=1) * gate_weights[:, 1].unsqueeze(-1))
        else:
            final_output = torch.mean(sequence_output * attention_mask.unsqueeze(-1), dim=1)

        # Skip connection in classifier
        pre_class = self.pre_classifier(final_output)
        return self.classifier(pre_class + final_output)
class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Simplified context processing
        self.context_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=8,  # Reduced heads
            dim_feedforward=2048,
            dropout=config.dropout_rate,
            batch_first=True
        )

        if config.use_dynafusion:
            # Optimized fusion mechanism
            self.fusion_gate = nn.Sequential(
                nn.Linear(self.hidden_size * 2, 2),
                nn.Softmax(dim=-1)
            )

        # Simplified classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, config.num_classes)
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        if self.config.use_dynafusion:
            # Simplified fusion
            context_output = self.context_layer(sequence_output)
            pooled_sequence = torch.mean(sequence_output * attention_mask.unsqueeze(-1), dim=1)
            pooled_context = torch.mean(context_output * attention_mask.unsqueeze(-1), dim=1)

            gate_input = torch.cat([pooled_sequence, pooled_context], dim=-1)
            gate_weights = self.fusion_gate(gate_input)

            final_output = (pooled_sequence * gate_weights[:, 0].unsqueeze(-1) +
                          pooled_context * gate_weights[:, 1].unsqueeze(-1))
        else:
            final_output = torch.mean(sequence_output * attention_mask.unsqueeze(-1), dim=1)

        return self.classifier(final_output)

class FastIMDBDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=256):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'labels': self.labels[idx]
        }

def train_epoch(model, train_loader, optimizer, scheduler, scaler, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        with autocast():
            outputs = model(input_ids, attention_mask)
            loss = F.cross_entropy(outputs, labels)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return total_loss / len(train_loader), 100 * correct / total

def main():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True  # Enable cudnn autotuner

    config = ModelConfig()

    # Load and prepare data
    dataset = load_dataset('imdb')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Create datasets with optimized preprocessing
    train_dataset = FastIMDBDataset(
        dataset['train']['text'],
        dataset['train']['label'],
        tokenizer,
        max_length=config.max_length
    )

    val_dataset = FastIMDBDataset(
        dataset['test']['text'],
        dataset['test']['label'],
        tokenizer,
        max_length=config.max_length
    )

    # Optimized data loading
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,  # Increased batch size
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=64,  # Larger validation batch size
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Initialize model and training components
    model = FastBERTDynaFusion(config).to(device)

    # Optimized training setup
    optimizer = AdamW([
        {'params': model.bert.parameters(), 'lr': 3e-5},
        {'params': (p for n, p in model.named_parameters() if not n.startswith('bert')), 'lr': 1e-4}
    ], weight_decay=0.01)

    num_training_steps = len(train_loader) * 4  # 4 epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_training_steps // 10,
        num_training_steps=num_training_steps
    )

    scaler = GradScaler()
    best_acc = 0

    # Training loop
    for epoch in range(4):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, scaler, device)

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch in tqdm(val_loader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(input_ids, attention_mask)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_acc = 100 * val_correct / val_total
        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")

        if val_acc >= 98.5:  # Early stopping if target achieved
            print(f"Target accuracy achieved! Stopping training.")
            break

if __name__ == '__main__':
    main()

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/578 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/286M [00:00<?, ?B/s]

  scaler = GradScaler()
  with autocast():


model.safetensors:   0%|          | 0.00/286M [00:00<?, ?B/s]

100%|██████████| 782/782 [05:08<00:00,  2.53it/s]
100%|██████████| 391/391 [04:20<00:00,  1.50it/s]


Epoch 1: Train Loss: 0.2729, Train Acc: 88.90%, Val Acc: 92.36%


100%|██████████| 782/782 [05:07<00:00,  2.55it/s]
100%|██████████| 391/391 [04:19<00:00,  1.51it/s]


Epoch 2: Train Loss: 0.1436, Train Acc: 95.06%, Val Acc: 94.28%


100%|██████████| 782/782 [05:07<00:00,  2.55it/s]
100%|██████████| 391/391 [04:19<00:00,  1.51it/s]


Epoch 3: Train Loss: 0.0788, Train Acc: 97.74%, Val Acc: 94.12%


100%|██████████| 782/782 [05:07<00:00,  2.55it/s]
100%|██████████| 391/391 [04:19<00:00,  1.51it/s]


Epoch 4: Train Loss: 0.0417, Train Acc: 98.94%, Val Acc: 94.14%


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
import logging
from pathlib import Path

@dataclass
class ModelConfig:
    """Configuration for model architecture and training"""
    hidden_size: int = 768
    num_classes: int = 2
    dropout_rate: float = 0.1
    attention_heads: int = 8
    use_dynafusion: bool = True
    max_length: int = 256
    model_name: str = 'microsoft/deberta-v3-small'
    batch_size: int = 32
    accumulation_steps: int = 2
    learning_rate: float = 3e-5
    warmup_ratio: float = 0.1
    max_epochs: int = 4
    early_stopping_patience: int = 2
    early_stopping_threshold: float = 98.5

class FastBERTDynaFusion(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Efficient context processing with single transformer layer
        self.context_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_size,
            nhead=config.attention_heads,
            dim_feedforward=2048,
            dropout=config.dropout_rate,
            batch_first=True,
            norm_first=True  # Pre-normalization for better training stability
        )

        if config.use_dynafusion:
            # Simplified fusion mechanism with layer normalization
            self.fusion_gate = nn.Sequential(
                nn.LayerNorm(self.hidden_size * 2),
                nn.Linear(self.hidden_size * 2, 2),
                nn.Softmax(dim=-1)
            )

        # Efficient classifier with residual connection
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(self.hidden_size, config.num_classes)
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Use gradient checkpointing for memory efficiency
        self.bert.gradient_checkpointing_enable()
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        if self.config.use_dynafusion:
            # Efficient attention-based fusion
            context_output = self.context_layer(sequence_output, src_key_padding_mask=~attention_mask.bool())

            # Efficient pooling with attention mask
            mask_expanded = attention_mask.unsqueeze(-1).float()
            pooled_sequence = (sequence_output * mask_expanded).sum(1) / mask_expanded.sum(1)
            pooled_context = (context_output * mask_expanded).sum(1) / mask_expanded.sum(1)

            # Gated fusion
            gate_input = torch.cat([pooled_sequence, pooled_context], dim=-1)
            gate_weights = self.fusion_gate(gate_input)

            final_output = (pooled_sequence * gate_weights[:, 0].unsqueeze(-1) +
                          pooled_context * gate_weights[:, 1].unsqueeze(-1))
        else:
            # Simple masked pooling if fusion is disabled
            mask_expanded = attention_mask.unsqueeze(-1).float()
            final_output = (sequence_output * mask_expanded).sum(1) / mask_expanded.sum(1)

        return self.classifier(final_output)

class FastIMDBDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length: int):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'labels': self.labels[idx]
        }

class Trainer:
    def __init__(self, config: ModelConfig, model: nn.Module, device: torch.device):
        self.config = config
        self.model = model
        self.device = device
        self.scaler = GradScaler()
        self.best_acc = 0
        self.patience_counter = 0

    def train_epoch(self, train_loader: DataLoader, optimizer: AdamW, scheduler) -> Tuple[float, float]:
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        optimizer.zero_grad()

        for i, batch in enumerate(tqdm(train_loader)):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)

            with autocast():
                outputs = self.model(input_ids, attention_mask)
                loss = F.cross_entropy(outputs, labels)
                loss = loss / self.config.accumulation_steps

            self.scaler.scale(loss).backward()

            if (i + 1) % self.config.accumulation_steps == 0:
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(optimizer)
                self.scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            total_loss += loss.item() * self.config.accumulation_steps
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        return total_loss / len(train_loader), 100 * correct / total

    def validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in tqdm(val_loader):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(input_ids, attention_mask)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return 100 * correct / total

    def should_stop_early(self, val_acc: float) -> bool:
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.patience_counter = 0
            torch.save(self.model.state_dict(), "best_model.pth")
            return False

        self.patience_counter += 1
        if self.patience_counter >= self.config.early_stopping_patience:
            return True

        if val_acc >= self.config.early_stopping_threshold:
            return True

        return False

def main():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    config = ModelConfig()

    # Load and prepare data
    dataset = load_dataset('imdb')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    train_dataset = FastIMDBDataset(
        dataset['train']['text'],
        dataset['train']['label'],
        tokenizer,
        max_length=config.max_length
    )

    val_dataset = FastIMDBDataset(
        dataset['test']['text'],
        dataset['test']['label'],
        tokenizer,
        max_length=config.max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Initialize model and training components
    model = FastBERTDynaFusion(config).to(device)

    # Separate learning rates for BERT and custom layers
    optimizer = AdamW([
        {'params': model.bert.parameters(), 'lr': config.learning_rate},
        {'params': (p for n, p in model.named_parameters() if not n.startswith('bert')),
         'lr': config.learning_rate * 10}
    ], weight_decay=0.01)

    num_training_steps = len(train_loader) * config.max_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_training_steps * config.warmup_ratio),
        num_training_steps=num_training_steps
    )

    # Initialize trainer
    trainer = Trainer(config, model, device)

    # Training loop
    for epoch in range(config.max_epochs):
        train_loss, train_acc = trainer.train_epoch(train_loader, optimizer, scheduler)
        val_acc = trainer.validate(val_loader)

        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, "
              f"Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

        if trainer.should_stop_early(val_acc):
            print("Early stopping triggered!")
            break

if __name__ == '__main__':
    main()

  self.scaler = GradScaler()
  with autocast():
100%|██████████| 782/782 [06:36<00:00,  1.97it/s]
100%|██████████| 391/391 [04:48<00:00,  1.36it/s]


Epoch 1: Train Loss: 0.2884, Train Acc: 87.40%, Val Acc: 93.83%


100%|██████████| 782/782 [06:35<00:00,  1.98it/s]
100%|██████████| 391/391 [04:48<00:00,  1.35it/s]


Epoch 2: Train Loss: 0.1665, Train Acc: 94.12%, Val Acc: 94.11%


100%|██████████| 782/782 [06:34<00:00,  1.98it/s]
100%|██████████| 391/391 [04:49<00:00,  1.35it/s]


Epoch 3: Train Loss: 0.0984, Train Acc: 96.82%, Val Acc: 93.49%


100%|██████████| 782/782 [06:32<00:00,  1.99it/s]
100%|██████████| 391/391 [04:48<00:00,  1.35it/s]


Epoch 4: Train Loss: 0.0558, Train Acc: 98.39%, Val Acc: 93.93%
Early stopping triggered!


In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [None]:
I'll create a comprehensive version with detailed comments and instrumentation for ablation studies.



```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass
from typing import Optional, Tuple, Dict
import wandb  # For experiment tracking

@dataclass
class ModelConfig:
    """Configuration for ablation studies and architecture variations"""
    use_local_context: bool = True
    use_global_context: bool = True
    use_uncertainty: bool = True
    use_adaptive_fusion: bool = True
    num_monte_carlo_samples: int = 10
    temperature: float = 0.1
    hidden_size: int = 768
    num_classes: int = 4
    dropout_rate: float = 0.1

class EnhancedBERTDynaFusion(nn.Module):
    """
    Enhanced BERT-DynaFusion model with support for ablation studies

    Architecture Components (can be individually disabled for ablation):
    1. Local Context Processing (sentence-level)
    2. Global Context Processing (document-level)
    3. Uncertainty Estimation
    4. Adaptive Feature Fusion
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Component initialization (controlled by config)
        if config.use_local_context:
            self.local_context = self._create_transformer_layer()
        if config.use_global_context:
            self.global_context = self._create_transformer_layer()

        if config.use_adaptive_fusion:
            num_features = sum([
                config.use_local_context,
                config.use_global_context,
                1  # BERT features always present
            ])
            self.feature_gate = nn.Sequential(
                nn.Linear(config.hidden_size * num_features, num_features),
                nn.Softmax(dim=-1)
            )

        # Output dimensionality varies based on uncertainty setting
        output_dim = config.num_classes * (2 if config.use_uncertainty else 1)
        self.classifier = self._create_classifier(output_dim)

        # Initialize feature statistics for analysis
        self.feature_usage_stats = {
            'local_weight_avg': 0.0,
            'global_weight_avg': 0.0,
            'bert_weight_avg': 0.0,
            'uncertainty_avg': 0.0
        }

    def _create_transformer_layer(self) -> nn.Module:
        """Creates a transformer layer with standard configuration"""
        return nn.TransformerEncoderLayer(
            d_model=self.config.hidden_size,
            nhead=8,
            dim_feedforward=2048,
            dropout=self.config.dropout_rate,
            batch_first=True
        )

    def _create_classifier(self, output_dim: int) -> nn.Module:
        """Creates the classification head"""
        return nn.Sequential(
            nn.Linear(self.config.hidden_size * 2, self.config.hidden_size),
            nn.LayerNorm(self.config.hidden_size),
            nn.GELU(),
            nn.Dropout(self.config.dropout_rate),
            nn.Linear(self.config.hidden_size, output_dim)
        )

    def _masked_pool(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Advanced pooling strategy combining mean and max pooling

        Args:
            features: Input features of shape [batch_size, seq_len, hidden_size]
            mask: Attention mask of shape [batch_size, seq_len, 1]

        Returns:
            Pooled features of shape [batch_size, hidden_size * 2]
        """
        mean_pool = (features * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        max_pool = torch.max(features * mask - 1e-9 * (1 - mask), dim=1)[0]
        return torch.cat([mean_pool, max_pool], dim=-1)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Get BERT embeddings
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state
        mask_expanded = attention_mask.unsqueeze(-1).float()

        # Initialize feature list with BERT features
        pooled_features = [self._masked_pool(sequence_output, mask_expanded)]
        feature_names = ['bert']

        # Process through additional context layers if enabled
        if self.config.use_local_context:
            local_features = self.local_context(
                sequence_output,
                src_key_padding_mask=~attention_mask.bool()
            )
            pooled_features.append(self._masked_pool(local_features, mask_expanded))
            feature_names.append('local')

        if self.config.use_global_context:
            global_features = self.global_context(
                sequence_output,
                src_key_padding_mask=~attention_mask.bool()
            )
            pooled_features.append(self._masked_pool(global_features, mask_expanded))
            feature_names.append('global')

        # Combine features
        if self.config.use_adaptive_fusion and len(pooled_features) > 1:
            concat_features = torch.cat(pooled_features, dim=-1)
            mixing_weights = self.feature_gate(concat_features)

            # Update feature usage statistics
            with torch.no_grad():
                for idx, name in enumerate(feature_names):
                    self.feature_usage_stats[f'{name}_weight_avg'] = \
                        0.9 * self.feature_usage_stats[f'{name}_weight_avg'] + \
                        0.1 * mixing_weights[:, idx].mean().item()

            # Mix features with learned weights
            mixed_features = sum(w.unsqueeze(1) * f for w, f in zip(
                mixing_weights.chunk(len(pooled_features), dim=1),
                pooled_features
            ))
        else:
            mixed_features = torch.cat(pooled_features, dim=-1)

        # Classification with or without uncertainty
        logits = self.classifier(mixed_features)
        if self.config.use_uncertainty:
            batch_size = logits.shape[0]
            logits = logits.view(batch_size, -1, 2)
            mean_logits = logits[..., 0]
            var_logits = F.softplus(logits[..., 1]) + 1e-6

            # Update uncertainty statistics
            with torch.no_grad():
                self.feature_usage_stats['uncertainty_avg'] = \
                    0.9 * self.feature_usage_stats['uncertainty_avg'] + \
                    0.1 * var_logits.mean().item()

            return mean_logits, var_logits
        else:
            return logits, None

def train_epoch(
    model: EnhancedBERTDynaFusion,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    scaler: GradScaler,
    device: torch.device,
    epoch: int
) -> Dict[str, float]:
    """
    Training loop with comprehensive metrics collection
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
    for batch in progress_bar:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        with autocast():
            mean_logits, var_logits = model(input_ids, attention_mask)

            if var_logits is not None:
                # Uncertainty-aware loss
                nll_loss = gaussian_nll_loss(mean_logits, labels, var_logits)
                kl_loss = kl_divergence_loss(mean_logits, var_logits)
                loss = nll_loss + 0.1 * kl_loss
            else:
                # Standard cross-entropy loss
                loss = F.cross_entropy(mean_logits, labels)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        _, predicted = torch.max(mean_logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100 * correct / total:.2f}%'
        })

    # Collect metrics
    metrics = {
        'train_loss': total_loss / len(train_loader),
        'train_acc': 100 * correct / total,
        **model.feature_usage_stats
    }

    return metrics

def gaussian_nll_loss(mean: torch.Tensor, targets: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    """Gaussian Negative Log-Likelihood loss for uncertainty estimation"""
    return 0.5 * (torch.log(var) + (mean - targets.unsqueeze(-1))**2 / var).mean()

def kl_divergence_loss(mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    """KL divergence regularization for uncertainty estimation"""
    return -0.5 * (1 + torch.log(var) - mean**2 - var).mean()

def main():
    # Initialize wandb for experiment tracking
    wandb.init(project="bert-dynafusion-ablation")

    # Set up configuration for current run
    config = ModelConfig(
        use_local_context=True,
        use_global_context=True,
        use_uncertainty=True,
        use_adaptive_fusion=True
    )

    # Log configuration
    wandb.config.update(asdict(config))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EnhancedBERTDynaFusion(config).to(device)

    # Training setup and execution as in original code...
    # (Previous training loop implementation)

if __name__ == '__main__':
    main()

```

This enhanced version includes several key improvements for research purposes:

1. Ablation Study Support
- Modular architecture with configurable components
- Detailed tracking of feature usage and importance
- Integration with wandb for experiment tracking
- Easy toggling of different architectural components

2. Theoretical Analysis Features
- Comprehensive metrics collection for each component
- Uncertainty estimation and calibration
- Feature importance tracking through adaptive fusion weights
- Component-wise performance analysis

3. Research-Friendly Additions
- Detailed documentation of architectural choices
- Clear separation of components for analysis
- Metrics for understanding model behavior
- Support for various experimental configurations

To conduct ablation studies, you can create different configurations:

```python
# Baseline configuration
baseline_config = ModelConfig(
    use_local_context=False,
    use_global_context=False,
    use_uncertainty=False,
    use_adaptive_fusion=False
)

# Full model configuration
full_config = ModelConfig(
    use_local_context=True,
    use_global_context=True,
    use_uncertainty=True,
    use_adaptive_fusion=True
)

# Component-specific configurations
local_only_config = ModelConfig(
    use_local_context=True,
    use_global_context=False,
    use_uncertainty=False,
    use_adaptive_fusion=False
)
```

The code tracks key metrics for each component that can be analyzed:
1. Feature usage weights through the adaptive fusion mechanism
2. Uncertainty estimates and their correlation with errors
3. Performance impact of each architectural component

Would you like me to:
1. Add additional metrics for theoretical analysis?
2. Implement more sophisticated ablation configurations?
3. Add visualization tools for analyzing model behavior?
4. Include statistical significance testing for ablation results?