In [11]:
%pip install datasets
%pip install transformers torch torchvision torchaudio

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.
Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import numpy as np
import time
import datetime

In [7]:
import os
print("CUDA_HOME:", os.environ.get('CUDA_HOME', 'Not set'))
print("PATH:", os.environ.get('PATH'))
print("LD_LIBRARY_PATH:", os.environ.get('LD_LIBRARY_PATH', 'Not set'))

# Check PyTorch installation
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch CUDA version: {torch.version.cuda}")
print(f"Is CUDA build: {torch.cuda.is_available()}")

CUDA_HOME: Not set
PATH: /usr/bin:/home/ubuntu/.vscode-server/cli/servers/Stable-c306e94f98122556ca081f527b466015e1bc37b0/server/bin/remote-cli:/home/ubuntu/.local/bin:/home/ubuntu/.local/bin:/usr/mpi/gcc/openmpi-4.1.7rc1/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/home/ubuntu/.vscode-server/cli/servers/Stable-c306e94f98122556ca081f527b466015e1bc37b0/server/bin/remote-cli:/home/ubuntu/.local/bin:/home/ubuntu/.local/bin:/usr/mpi/gcc/openmpi-4.1.7rc1/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin
LD_LIBRARY_PATH: /usr/mpi/gcc/openmpi-4.1.7rc1/lib:/usr/mpi/gcc/openmpi-4.1.7rc1/lib64
PyTorch version: 2.8.0+cpu
PyTorch CUDA version: None
Is CUDA build: False


In [8]:
import os
os.environ['CUDA_HOME'] = '/usr/local/cuda'
os.environ['PATH'] = '/usr/local/cuda/bin:' + os.environ.get('PATH', '')
os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')

In [9]:
# To fix the CUDA environment issue, we can reinstall PyTorch with the correct CUDA version.
import subprocess
import sys

# Uninstall current PyTorch
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "torch", "torchvision", "torchaudio", "-y"])

# Install CUDA version
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cu121"])

Found existing installation: torch 2.8.0
Uninstalling torch-2.8.0:
  Successfully uninstalled torch-2.8.0
Found existing installation: torchvision 0.22.0
Not uninstalling torchvision at /usr/lib/python3/dist-packages, outside environment /usr
Can't uninstall 'torchvision'. No files were found to uninstall.
Found existing installation: torchaudio 2.8.0
Uninstalling torchaudio-2.8.0:
  Successfully uninstalled torchaudio-2.8.0
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/torchaudio-2.2.0-cp310-cp310-linux_aarch64.whl (1.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 156.0 MB/s eta 0:00:00
  Downloading https://download.pytorch.org/whl/torchaudio-2.1.2-cp310-cp310-linux_aarch64.whl (1.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 175.9 MB/s eta 0:00:00
  Downloading https://download.pytorch.org/whl

0

In [10]:
# Hyperparameters
BATCH_SIZE = 512
EMBED_DIM = 256
NUM_ITERS = 4
ALPHA = 0.5
LR = 5e-5
EPOCHS = 10
MAX_LENGTH = 4096 # Maximum token length for padding/truncation
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PAD_TO_MULTIPLE_OF=8
GRADIENT_CLIPPING = 1.0

# Test if CUDA is available
print(f"Using device: {DEVICE}")

Using device: cpu


In [4]:
# Load AG News dataset
dataset = load_dataset('ag_news')

In [5]:
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased",
                                          padding="max_length",
                                          truncation=True,
                                          max_length=MAX_LENGTH,
                                          pad_to_multiple_of=PAD_TO_MULTIPLE_OF)

In [6]:
# Encode labels
label_encoder = LabelEncoder()
label_encoder.fit(dataset['train']['label'])

LabelEncoder()

In [7]:
# Custom Dataset Class
class AGNewsDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'label': torch.tensor(label, dtype=torch.long)
        }

In [8]:
# Prepare datasets
train_texts = dataset['train']['text']
train_labels = label_encoder.transform(dataset['train']['label'])
test_texts = dataset['test']['text']
test_labels = label_encoder.transform(dataset['test']['label'])

In [9]:
train_dataset = AGNewsDataset(train_texts, train_labels, tokenizer, MAX_LENGTH)
test_dataset = AGNewsDataset(test_texts, test_labels, tokenizer, MAX_LENGTH)

train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=12, pin_memory=True,
                          prefetch_factor=4,
                          persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [10]:
# Get a validation set before training start
import random

# Select a small random subset from our test dataset
subset_size = 10  # Adjust as needed
subset_indices = random.sample(range(len(test_loader.dataset)), subset_size)

# Create a new DataLoader for this subset
from torch.utils.data import Subset

test_subset = Subset(test_loader.dataset, subset_indices)
test_subset_loader = DataLoader(test_subset, batch_size=BATCH_SIZE, shuffle=False)

In [11]:
# ========== Step 2: Define the Model ==========
class DiffusionAttentionFreeModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_iters=NUM_ITERS, alpha=ALPHA, num_classes=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.noise_std = 0.1  # Initial noise
        self.alpha = alpha  # Decay factor
        self.num_iters = num_iters  # Iterative updates
        self.update_mlp = nn.Linear(embed_dim, embed_dim)  # Local transformation
        self.output_mlp = nn.Linear(embed_dim, num_classes)  # Classifier

    def forward(self, input_ids, attention_mask):
        # Step 1: Embed + Add Noise
        h = self.embedding(input_ids) + self.noise_std * torch.randn_like(self.embedding(input_ids))

        # Step 2: Iterative Refinement (Diffusion Process)
        for _ in range(self.num_iters):
            # Multi-Neighbor Updates
            h_left = torch.roll(h, shifts=1, dims=1)
            h_right = torch.roll(h, shifts=-1, dims=1)
            h_update = self.update_mlp(h_left) + self.update_mlp(h_right)

            # Weighted update rule (diffusion-like)
            h = self.alpha * h + (1 - self.alpha) * h_update

        # Step 3: Pooling + Classification
        h = (h * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)  # Masked mean pooling
        logits = self.output_mlp(h)
        return logits

In [12]:
def evaluate(model, test_loader, criterion):
    model.eval()
    total_loss, correct, total = 0, 0, 0

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    return total_loss / len(test_loader), correct / total

In [13]:
print("Learning rate", LR)
vocab_size = tokenizer.vocab_size
diff_model = DiffusionAttentionFreeModel(vocab_size, EMBED_DIM).to(DEVICE)
optimizer = optim.AdamW(diff_model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

Learning rate 5e-05


In [14]:
import time
for i, batch in enumerate(train_loader):
    start_time = time.time()
    batch_data = batch["input_ids"].to(DEVICE)  # Load batch to GPU
    print(f"Batch {i+1}: Load Time = {time.time() - start_time:.4f} sec")

    if i == 10:  # Stop after 10 batches
        break




Batch 1: Load Time = 0.0000 sec
Batch 2: Load Time = 0.0001 sec
Batch 3: Load Time = 0.0001 sec
Batch 4: Load Time = 0.0001 sec
Batch 5: Load Time = 0.0001 sec
Batch 6: Load Time = 0.0001 sec
Batch 7: Load Time = 0.0001 sec
Batch 8: Load Time = 0.0001 sec
Batch 9: Load Time = 0.0001 sec
Batch 10: Load Time = 0.0001 sec
Batch 11: Load Time = 0.0001 sec


In [15]:
class ImprovedDiffusionAttentionFreeModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_iters=4, alpha=0.7, num_classes=4):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.num_iters = num_iters
        self.alpha = alpha
        self.noise_std = 0.05  # Reduced for FP16 stability
        
        # Token embedding with proper initialization
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # Multi-head neighbor interaction (more sophisticated than single MLP)
        self.neighbor_proj = nn.ModuleList([
            nn.Linear(embed_dim, embed_dim, bias=False) for _ in range(3)
        ])  # left, right, self projections
        
        # Layer normalization for stability
        self.layer_norm = nn.LayerNorm(embed_dim)
        
        # Nonlinear transformation with residual connection
        self.update_mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim * 2, embed_dim)
        )
        
        # Classification head with dropout
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, num_classes)
        )
        
        # Initialize weights for FP16 stability
        self._init_weights()
    
    def _init_weights(self):
        """Proper weight initialization for FP16 training"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # Xavier initialization scaled for FP16
                nn.init.xavier_normal_(module.weight, gain=0.02)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.bias, 0)
                nn.init.constant_(module.weight, 1.0)
    
    def forward(self, input_ids, attention_mask):
        batch_size, seq_len = input_ids.shape
        
        # Step 1: Embed + Add controlled noise
        h = self.embedding(input_ids)
        
        if self.training:
            # Add noise only during training, with proper scaling
            noise = torch.randn_like(h, dtype=h.dtype, device=h.device) * self.noise_std
            h = h + noise
        
        # Step 2: Iterative refinement with proper neighbor handling
        for iteration in range(self.num_iters):
            # Get neighbor representations
            h_left = torch.cat([h[:, -1:, :], h[:, :-1, :]], dim=1)  # Proper circular shift
            h_right = torch.cat([h[:, 1:, :], h[:, :1, :]], dim=1)   # Proper circular shift
            
            # Apply different projections to each neighbor type
            h_left_proj = self.neighbor_proj[0](h_left)
            h_right_proj = self.neighbor_proj[1](h_right)
            h_self_proj = self.neighbor_proj[2](h)
            
            # Combine neighbor information
            neighbor_sum = h_left_proj + h_right_proj + h_self_proj
            
            # Apply nonlinear transformation
            h_update = self.update_mlp(neighbor_sum)
            
            # Residual connection + weighted update
            h_new = self.alpha * h + (1 - self.alpha) * h_update
            
            # Apply layer normalization for stability
            h = self.layer_norm(h_new)
        
        # Step 3: Masked pooling (handle padding properly)
        if attention_mask is not None:
            # Expand attention mask for broadcasting
            mask_expanded = attention_mask.unsqueeze(-1).float()
            h_masked = h * mask_expanded
            
            # Avoid division by zero
            mask_sum = mask_expanded.sum(dim=1).clamp(min=1e-8)
            pooled = h_masked.sum(dim=1) / mask_sum
        else:
            pooled = h.mean(dim=1)
        
        # Step 4: Classification
        logits = self.classifier(pooled)
        
        return logits

In [16]:
def train_improved_model(model, train_loader, test_loader, device, 
                        epochs=50, lr=5e-5, checkpoint_path="./checkpoints"):
    """
    Training function with proper FP16 support and gradient scaling
    """
    
    # Ensure checkpoint directory exists
    os.makedirs(checkpoint_path, exist_ok=True)
    
    # Initialize optimizer with FP16-friendly settings
    optimizer = optim.AdamW(
        model.parameters(), 
        lr=lr,
        weight_decay=0.01,
        eps=1e-8  # Increased epsilon for FP16 numerical stability
    )
    
    # Initialize gradient scaler for mixed precision
    scaler = GradScaler('cuda')
    
    # Loss function
    criterion = nn.CrossEntropyLoss()
    
    # Learning rate scheduler
    scheduler = ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5 
    )
    
    # Load checkpoint if exists
    latest_checkpoint = os.path.join(checkpoint_path, "latest_model.pth")
    initial_epoch = 1
    
    if os.path.exists(latest_checkpoint):
        print("Loading checkpoint...")
        checkpoint = torch.load(latest_checkpoint, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scaler.load_state_dict(checkpoint["scaler_state_dict"])
        initial_epoch = checkpoint["epoch"] + 1
        print(f"Resuming training from epoch {initial_epoch}")
    
    # Training loop
    log_file = os.path.join(checkpoint_path, f"training_log_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt")
    
    with open(log_file, "w") as f:
        f.write(f"=== Training Start - {datetime.datetime.now()} ===\n")
        f.write(f"Model: ImprovedDiffusionAttentionFreeModel\n")
        f.write(f"Learning Rate: {lr}\n")
        f.write(f"Epochs: {epochs}\n")
        f.write(f"Initial Scaler Scale: {scaler.get_scale()}\n")
        f.write("=" * 50 + "\n")
        
        for epoch in range(initial_epoch, epochs + 1):
            start_time = time.time()
            
            # Training phase
            model.train()
            total_loss = 0
            correct = 0
            total_samples = 0
            
            for batch_idx, batch in enumerate(train_loader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                
                optimizer.zero_grad()
                
                # Forward pass with autocast for mixed precision
                with autocast('cuda'):
                    outputs = model(input_ids, attention_mask)
                    loss = criterion(outputs, labels)
                
                # Check for loss explosion early
                if loss.item() > 100:
                    print(f"WARNING: Loss explosion detected: {loss.item():.2f}")
                    print(f"Scaler scale: {scaler.get_scale()}")
                    f.write(f"WARNING: Loss explosion at epoch {epoch}, batch {batch_idx}\n")
                
                # Backward pass with gradient scaling
                scaler.scale(loss).backward()
                
                # Gradient clipping (unscale first)
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                # Optimizer step
                scaler.step(optimizer)
                scaler.update()
                
                # Statistics
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total_samples += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                
                # Log progress every 50 batches
                if batch_idx % 50 == 0:
                    current_acc = 100. * correct / total_samples
                    print(f"Epoch {epoch}, Batch {batch_idx}: Loss = {loss.item():.4f}, Acc = {current_acc:.2f}%")
            
            epoch_time = time.time() - start_time
            train_accuracy = correct / total_samples
            avg_train_loss = total_loss / len(train_loader)
            
            # Evaluation phase
            model.eval()
            test_loss = 0
            test_correct = 0
            test_total = 0
            
            with torch.no_grad():
                for batch in test_loader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['label'].to(device)
                    
                    with autocast('cuda'):
                        outputs = model(input_ids, attention_mask)
                        loss = criterion(outputs, labels)
                    
                    test_loss += loss.item()
                    _, predicted = outputs.max(1)
                    test_total += labels.size(0)
                    test_correct += predicted.eq(labels).sum().item()
            
            test_accuracy = test_correct / test_total
            avg_test_loss = test_loss / len(test_loader)
            
            # Update learning rate
            scheduler.step(avg_test_loss)
            current_lr = optimizer.param_groups[0]['lr']
            
            # Save checkpoint
            checkpoint = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scaler_state_dict": scaler.state_dict(),
                "train_loss": avg_train_loss,
                "test_loss": avg_test_loss,
                "train_acc": train_accuracy,
                "test_acc": test_accuracy,
            }
            
            torch.save(checkpoint, latest_checkpoint)
            if epoch % 5 == 0:  # Save every 5 epochs
                torch.save(checkpoint, os.path.join(checkpoint_path, f"model_epoch_{epoch}.pth"))
            
            # Logging
            log_msg = f"Epoch {epoch}/{epochs}:\n"
            log_msg += f"  Train - Loss: {avg_train_loss:.4f}, Acc: {train_accuracy:.4f}\n"
            log_msg += f"  Test  - Loss: {avg_test_loss:.4f}, Acc: {test_accuracy:.4f}\n"
            log_msg += f"  Time: {epoch_time:.2f}s, LR: {current_lr:.2e}\n"
            log_msg += f"  Scaler Scale: {scaler.get_scale()}\n"
            log_msg += "-" * 50 + "\n"
            
            print(log_msg)
            f.write(log_msg)
            f.flush()
    
    return model


In [17]:
def setup_improved_training():
    """
    Setup function to replace your current training loop
    """
    
    # Your existing hyperparameters
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    EMBED_DIM = 256
    NUM_ITERS = 4
    ALPHA = 0.7  # Increased from 0.5 for better stability
    LR = 2e-5    # Reduced from 5e-5 for FP16 stability
    EPOCHS = 10
    
    # Initialize improved model
    vocab_size = 30522  # Your tokenizer vocab size
    model = ImprovedDiffusionAttentionFreeModel(
        vocab_size=vocab_size,
        embed_dim=EMBED_DIM,
        num_iters=NUM_ITERS,
        alpha=ALPHA,
        num_classes=4
    ).to(DEVICE)
    
    # Convert to half precision
    model = model.half()
    
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Model size (MB): {sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2:.2f}")
    
    return model

In [18]:
def train(train_loader, test_loader, tokenizer):
    """
    Drop-in replacement for your current training cell
    """
    
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Create improved model
    model = ImprovedDiffusionAttentionFreeModel(
        vocab_size=tokenizer.vocab_size,
        embed_dim=256,  # Your EMBED_DIM
        num_iters=4,    # Your NUM_ITERS  
        alpha=0.7,      # Improved from your 0.5
        num_classes=4
    ).to(DEVICE)
    
    # Convert to FP16 - Removing this for the time being
    #model = model.half()
    
    # Start training with proper FP16 support
    trained_model = train_improved_model(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        device=DEVICE,
        epochs=5,
        lr=2e-5,  # Reduced for FP16 stability
        checkpoint_path="./checkpoints_improved"
    )
    
    return trained_model


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler  # Updated API
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import time
import datetime
import os
train(train_loader, test_loader, tokenizer)

