<a href="https://colab.research.google.com/github/pbanavara/experimental_attention_free_diff/blob/main/AttnFreeDiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Diffusion Model

In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
import numpy as np
import time
import datetime

In [55]:
BATCH_SIZE = 512

In [76]:
# ========== Hyperparameters ==========
EMBED_DIM = 256
NUM_ITERS = 4
ALPHA = 0.5
LR = 5e-5
EPOCHS = 10
MAX_LENGTH = 4096 # Maximum token length for padding/truncation
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PAD_TO_MULTIPLE_OF=8
GRADIENT_CLIPPING = 1.0

In [77]:
# Load AG News dataset
dataset = load_dataset('ag_news')

In [78]:
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased",
                                          padding="max_length",
                                          truncation=True,
                                          max_length=MAX_LENGTH,
                                          pad_to_multiple_of=PAD_TO_MULTIPLE_OF)


In [79]:
# Encode labels
label_encoder = LabelEncoder()
label_encoder.fit(dataset['train']['label'])

In [80]:
# Custom Dataset Class
class AGNewsDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'label': torch.tensor(label, dtype=torch.long)
        }

In [81]:
# Prepare datasets
train_texts = dataset['train']['text']
train_labels = label_encoder.transform(dataset['train']['label'])
test_texts = dataset['test']['text']
test_labels = label_encoder.transform(dataset['test']['label'])

In [82]:
train_dataset = AGNewsDataset(train_texts, train_labels, tokenizer, MAX_LENGTH)
test_dataset = AGNewsDataset(test_texts, test_labels, tokenizer, MAX_LENGTH)

train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=12, pin_memory=True,
                          prefetch_factor=4,
                          persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)


In [83]:
# Get a validation set before training start
import random

# Select a small random subset from our test dataset
subset_size = 10  # Adjust as needed
subset_indices = random.sample(range(len(test_loader.dataset)), subset_size)

# Create a new DataLoader for this subset
from torch.utils.data import Subset

test_subset = Subset(test_loader.dataset, subset_indices)
test_subset_loader = DataLoader(test_subset, batch_size=16, shuffle=False)


In [84]:
# ========== Step 2: Define the Model ==========
class DiffusionAttentionFreeModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_iters=NUM_ITERS, alpha=ALPHA, num_classes=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.noise_std = 0.1  # Initial noise
        self.alpha = alpha  # Decay factor
        self.num_iters = num_iters  # Iterative updates
        self.update_mlp = nn.Linear(embed_dim, embed_dim)  # Local transformation
        self.output_mlp = nn.Linear(embed_dim, num_classes)  # Classifier

    def forward(self, input_ids, attention_mask):
        # Step 1: Embed + Add Noise
        h = self.embedding(input_ids) + self.noise_std * torch.randn_like(self.embedding(input_ids))

        # Step 2: Iterative Refinement (Diffusion Process)
        for _ in range(self.num_iters):
            # Multi-Neighbor Updates
            h_left = torch.roll(h, shifts=1, dims=1)
            h_right = torch.roll(h, shifts=-1, dims=1)
            h_update = self.update_mlp(h_left) + self.update_mlp(h_right)

            # Weighted update rule (diffusion-like)
            h = self.alpha * h + (1 - self.alpha) * h_update

        # Step 3: Pooling + Classification
        h = (h * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)  # Masked mean pooling
        logits = self.output_mlp(h)
        return logits

In [85]:
def evaluate(model, test_loader, criterion):
    model.eval()
    total_loss, correct, total = 0, 0, 0

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    return total_loss / len(test_loader), correct / total

In [86]:
print("Learning rate", LR)
vocab_size = tokenizer.vocab_size
diff_model = DiffusionAttentionFreeModel(vocab_size, EMBED_DIM).to(DEVICE)
optimizer = optim.AdamW(diff_model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

Learning rate 5e-05


In [87]:
import time
for i, batch in enumerate(train_loader):
    start_time = time.time()
    batch_data = batch["input_ids"].to(DEVICE)  # Load batch to GPU
    print(f"Batch {i+1}: Load Time = {time.time() - start_time:.4f} sec")

    if i == 10:  # Stop after 10 batches
        break


Batch 1: Load Time = 0.0017 sec
Batch 2: Load Time = 0.0015 sec
Batch 3: Load Time = 0.0015 sec
Batch 4: Load Time = 0.0016 sec
Batch 5: Load Time = 0.0014 sec
Batch 6: Load Time = 0.0015 sec
Batch 7: Load Time = 0.0014 sec
Batch 8: Load Time = 0.0016 sec
Batch 9: Load Time = 0.0015 sec
Batch 10: Load Time = 0.0014 sec
Batch 11: Load Time = 0.0014 sec


In [92]:
import torch
import os
import datetime
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler
import time

EPOCHS = 50
# Define hyperparameters
checkpoint_path = "drive/MyDrive/model_checkpoints"

# Ensure checkpoint directory exists
os.makedirs(checkpoint_path, exist_ok=True)

# Initialize Model & Optimizer
vocab_size = tokenizer.vocab_size
diff_model = DiffusionAttentionFreeModel(vocab_size, EMBED_DIM).to(DEVICE)
optimizer = optim.AdamW(diff_model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# Learning Rate Scheduler (Reduce LR when validation loss stops improving)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3, verbose=True)
scaler = GradScaler()

# Load previous checkpoint if exists
latest_checkpoint = os.path.join(checkpoint_path, "latest_model.pth")
if os.path.exists(latest_checkpoint):
    print("Loading checkpoint...")
    checkpoint = torch.load(latest_checkpoint)
    diff_model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    initial_epoch = checkpoint["epoch"] + 1
    print(f"Resuming training from epoch {initial_epoch}")
else:
    initial_epoch = 1


torch.backends.cuda.matmul.allow_tf32 = True  # Enables Tensor Cores for faster FP16

# Training loop
with open(os.path.join(checkpoint_path,
    "training_log_" + str(datetime.datetime.now()) + ".txt"), "a") as f:  # Open once to avoid multiple file creations
    f.write(f"\n=== Training Start - {datetime.datetime.now()} ===\n")
    f.write(f"Batch Size: {BATCH_SIZE}\n")
    f.write(f"Max Sequence Length: {MAX_LENGTH}\n")
    #f.write(f"Gradient Clipping: {gradient_clipping if gradient_clipping else 'None'}\n")
    f.write(f"Number of Epochs: {EPOCHS}\n")
    f.write("=" * 50 + "\n")
    for epoch in range(initial_epoch, EPOCHS + 1):
        start_time = time.time()
        diff_model.train()
        total_loss, correct, total = 0, 0, 0

        for batch in train_loader:
            labels = batch["label"].to(DEVICE)
            texts = batch["input_ids"].to(DEVICE)
            masks = batch["attention_mask"].to(DEVICE)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():  # FP16 Mixed Precision
                output = diff_model(input_ids=texts, attention_mask=masks)
                loss = criterion(output, labels)

            scaler.scale(loss).backward()

            if GRADIENT_CLIPPING:
                torch.nn.utils.clip_grad_norm_(diff_model.parameters(), max_norm=GRADIENT_CLIPPING)

            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            correct += (output.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        epoch_time = time.time() - start_time
        accuracy = correct / total

        # Save checkpoint after every epoch
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": diff_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        torch.save(checkpoint, os.path.join(checkpoint_path, "latest_model.pth"))
        torch.save(checkpoint, os.path.join(checkpoint_path, f"model_epoch_{epoch}.pth"))  # Save per epoch

        # Adjust learning rate based on validation loss
        test_loss, test_acc = evaluate(diff_model, test_loader, criterion)
        scheduler.step(test_loss)  # Reduce LR if validation loss plateaus

        print(f"Epoch {epoch} - Loss: {total_loss:.4f}, Accuracy: {accuracy:.4f}, Time: {epoch_time:.2f} sec")
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

        # Write to the log file
        f.write(f"\n=== Epoch {epoch+1} - {datetime.datetime.now()} ===\n")
        f.write(f"Loss: {total_loss:.4f}, Accuracy: {accuracy:.4f}, Time: {epoch_time:.2f} sec\n")

        # Append GPU stats
        f.write(os.popen("nvidia-smi").read())  # More efficient than os.system
        f.flush()  # Ensure data is written immediately
        # Log learning rate
        current_lr = optimizer.param_groups[0]["lr"]
        f.write(f"Current Learning Rate: {current_lr:.8f}")
        print(f"Current Learning Rate: {current_lr:.8f}")



  scaler = GradScaler()
  checkpoint = torch.load(latest_checkpoint)


Loading checkpoint...
Resuming training from epoch 41


  with torch.cuda.amp.autocast():  # FP16 Mixed Precision


Epoch 41 - Loss: 158.4427, Accuracy: 0.7477, Time: 80.65 sec
Test Loss: 0.6858, Test Accuracy: 0.7428
Current Learning Rate: 0.00005000
Epoch 42 - Loss: 157.3184, Accuracy: 0.7496, Time: 80.86 sec
Test Loss: 0.6796, Test Accuracy: 0.7471
Current Learning Rate: 0.00005000
Epoch 43 - Loss: 156.3316, Accuracy: 0.7517, Time: 80.73 sec
Test Loss: 0.6784, Test Accuracy: 0.7442
Current Learning Rate: 0.00005000
Epoch 44 - Loss: 155.4265, Accuracy: 0.7537, Time: 80.99 sec
Test Loss: 0.6727, Test Accuracy: 0.7492
Current Learning Rate: 0.00005000
Epoch 45 - Loss: 154.3936, Accuracy: 0.7563, Time: 81.02 sec
Test Loss: 0.6691, Test Accuracy: 0.7471
Current Learning Rate: 0.00005000
Epoch 46 - Loss: 153.4276, Accuracy: 0.7573, Time: 80.89 sec
Test Loss: 0.6667, Test Accuracy: 0.7501
Current Learning Rate: 0.00005000
Epoch 47 - Loss: 152.6161, Accuracy: 0.7588, Time: 80.68 sec
Test Loss: 0.6610, Test Accuracy: 0.7541
Current Learning Rate: 0.00005000
Epoch 48 - Loss: 151.6926, Accuracy: 0.7609, Tim

In [94]:
# Evaluate on our own test subset
final_subset_loss, final_subset_acc = evaluate(diff_model, test_subset_loader, criterion)

print(f"🔥 Test Subset Loss: {final_subset_loss:.4f}, Test Accuracy: {final_subset_acc:.4f}")


🔥 Test Subset Loss: 0.2916, Test Accuracy: 0.9000
