# Efficient Adaptation and Analysis of Vision Transformers using LoRA

In [None]:
import torch
import sys

# Sanity-check the runtime: report Python/PyTorch versions and CUDA availability, then
# allocate a small random tensor on the GPU (or on the CPU when no GPU is usable).
print("=== PyTorch Environment Test ===")
cuda_ok = torch.cuda.is_available()
for label, value in (
    ("Python version", sys.version),
    ("PyTorch version", torch.__version__),
    ("CUDA available", cuda_ok),
):
    print(f"{label}: {value}")

if not cuda_ok:
    print("CUDA not available - using CPU")
    x = torch.randn(3, 3)
    print(f"CPU tensor created: {x.device}")
else:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

    # Create on the host, then copy to the default GPU.
    x = torch.randn(3, 3).cuda()
    print(f"\nGPU tensor created: {x.device}")
    print(f"Tensor shape: {x.shape}")

print("\n✅ PyTorch test completed!")

=== PyTorch Environment Test ===
Python version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
PyTorch version: 2.8.0+cu126
CUDA available: True
CUDA version: 12.6
Number of GPUs: 1
GPU name: NVIDIA A100-SXM4-80GB

GPU tensor created: cuda:0
Tensor shape: torch.Size([3, 3])

✅ PyTorch test completed!


In [None]:
# Install the Weights & Biases client and authenticate; `wandb login` prompts for an API key
# interactively (see the cell output) and stores it in ~/.netrc.
!pip install wandb
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msiddpath[0m ([33msiddpath-university-of-maryland[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


#### Downloading the data

In [None]:
import torchvision
import torchvision.transforms as transforms

# Initial preprocessing: convert images to tensors and normalise each RGB channel with
# mean 0.5 / std 0.5. (A later cell swaps in resize + augmentation transforms.)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Both CIFAR-100 splits share the same root directory, download flag and transform.
cifar_kwargs = {"root": './data', "download": True, "transform": transform}
trainset = torchvision.datasets.CIFAR100(train=True, **cifar_kwargs)
testset = torchvision.datasets.CIFAR100(train=False, **cifar_kwargs)

for line in (
    "CIFAR-100 dataset imported successfully.",
    f"Training set size: {len(trainset)}",
    f"Test set size: {len(testset)}",
):
    print(line)

100%|██████████| 169M/169M [00:05<00:00, 30.0MB/s]


CIFAR-100 dataset imported successfully.
Training set size: 50000
Test set size: 10000


#### Resizing (and augmenting) the data for the ViT model


In [None]:
import torchvision.transforms as transforms

# ViT-B/16 consumes 224x224 images, so CIFAR-100's 32x32 images are upsampled.
# Both pipelines end with the same tensor conversion + per-channel (0.5, 0.5) normalisation.
to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training: resize, then light augmentation (random horizontal flip + padded random crop).
# NOTE(review): padding=4 applied AFTER upsampling shifts the image by at most ~4/7 of an
# original CIFAR pixel; cropping at 32x32 before resizing would give the usual strength — confirm intent.
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224, padding=4),
    ]
    + to_normalized_tensor
)

# Evaluation: deterministic resize only.
test_transform = transforms.Compose([transforms.Resize((224, 224))] + to_normalized_tensor)

# Replace the transforms attached when the datasets were loaded.
trainset.transform = train_transform
testset.transform = test_transform

print("Data preparation complete. Transforms applied to datasets.")

Data preparation complete. Transforms applied to datasets.


#### Loading the pre-trained ViT (all parameters left trainable for full fine-tuning)

In [None]:
from transformers import ViTForImageClassification

# Start from the 1000-class ViT-B/16 checkpoint. Its head does not match CIFAR-100, so
# ignore_mismatched_sizes=True lets transformers create a fresh 100-class classifier
# (see the re-initialisation warning in the cell output).
# No parameters are frozen here: this notebook path is full fine-tuning.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=100,
    ignore_mismatched_sizes=True,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([100, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# %pip install peft transformers datasets

# Here we decide whether to fine-tune the full model or only some of its parameters (LoRA)

In [None]:
# Going for full

# from peft import LoraConfig, get_peft_model

# # Define LoRA configuration
# config = LoraConfig(
#     r=16, # Rank of the update matrices.
#     lora_alpha=16, # Scaling factor for the LoRA update.
#     target_modules=["query", "value"], # Modules to apply LoRA to.
#     lora_dropout=0.1, # Dropout probability for LoRA layers.
#     bias="none", # Bias type.
# )

# # Get the LoRA-infused model
# model = get_peft_model(model, config)

# # Print trainable parameters
# model.print_trainable_parameters()

# print("LoRA adapters integrated into the model.")

In [None]:
# Tally the elements of every parameter tensor that will receive gradients.
trainable_params = 0
for param in model.parameters():
    if param.requires_grad:
        trainable_params += param.numel()
print(f"Number of trainable parameters: {trainable_params}")

Number of trainable parameters: 85875556


In [None]:
# model

In [None]:
# You need to reinstall DeepSpeed and force it to compile this special CPU Adam extension.

# # # Uninstall the old version first
# !pip uninstall deepspeed -y

# # Re-install with the build flag for CPUAdam
# !DS_BUILD_CPU_ADAM=1 pip install deepspeed


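Next, we need to create a DeepSpeed configuration file. This is typically a JSON file that specifies the optimization settings DeepSpeed should use. Here is an example configuration for mixed-precision training and ZeRO Stage 2 optimization, with the optimizer state offloaded to the CPU. This combination is often used for memory efficiency.

Note that `"auto"` values (such as `fp16.enabled` below) are normally filled in by the Hugging Face Trainer integration. When calling `deepspeed.initialize` directly, as this notebook does, replace them with explicit values to be sure the intended setting takes effect.

The next cell saves this configuration to a file named `deepspeed_config.json` using `%%writefile`.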

In [None]:
%%writefile deepspeed_config.json
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        }
    },
    "zero_force_ds_cpu_optimizer": false,
    "train_batch_size": 16,
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "steps_per_print": 200
}

Writing deepspeed_config.json


## Step 1: Create the DataLoaders

In [None]:
from torch.utils.data import DataLoader

# One batch size for both splits; kept equal to "train_micro_batch_size_per_gpu" (16)
# in deepspeed_config.json. Only the training split is shuffled.
LOADER_BATCH_SIZE = 16
train_loader, test_loader = (
    DataLoader(trainset, batch_size=LOADER_BATCH_SIZE, shuffle=True),
    DataLoader(testset, batch_size=LOADER_BATCH_SIZE, shuffle=False),
)

print("DataLoaders created.")

DataLoaders created.


## Step 2: Enable Gradient Checkpointing

In [None]:
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
print("Gradient checkpointing enabled.")

Gradient checkpointing enabled.


## Step 3: Initialize DeepSpeed

In [None]:
# mpi4py is installed for DeepSpeed. It is presumably needed so that deepspeed.initialize() can
# set up its process group when called directly in this notebook, without the launcher — confirm.
%pip install mpi4py

Collecting mpi4py
  Downloading mpi4py-4.1.1-cp312-cp312-manylinux1_x86_64.manylinux_2_5_x86_64.whl.metadata (16 kB)
Downloading mpi4py-4.1.1-cp312-cp312-manylinux1_x86_64.manylinux_2_5_x86_64.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m44.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mpi4py
Successfully installed mpi4py-4.1.1


In [None]:
# %pip install deepspeed

In [None]:
import torch.optim as optim
import deepspeed

# Build a plain torch AdamW ourselves and hand it to DeepSpeed instead of letting DeepSpeed
# create its own. deepspeed_config.json sets "zero_force_ds_cpu_optimizer": false, which
# allows this non-DeepSpeed optimizer to be used together with CPU optimizer offload.
optimizer = optim.AdamW(model.parameters(), lr=5e-5)

# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler); only the first
# two are kept, and `optimizer` is rebound to the DeepSpeed-wrapped optimizer.
# NOTE(review): this engine is not used later in the notebook. Training runs in the separate
# scripts below, yet this kernel keeps the model and optimizer alive (presumably on the GPU).
model_engine, optimizer, _, _ = deepspeed.initialize(
    config_params='deepspeed_config.json',
    optimizer=optimizer,
    model=model,
)

print("DeepSpeed engine initialized with PyTorch AdamW (forced).")

DeepSpeed engine initialized with PyTorch AdamW (forced).


## Step 4: Write the training script (run `LoRA-DeepSpeed-T4-Baseline`)

In [None]:
%%writefile train.py
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
from peft import LoraConfig, get_peft_model
import deepspeed
from sklearn.metrics import confusion_matrix
import warnings
import wandb
import os

# LoRA + DeepSpeed fine-tuning of ViT-B/16 on CIFAR-100.
# Pipeline: data -> frozen ViT with LoRA adapters on query/value + trainable classifier head
# -> gradient checkpointing -> DeepSpeed engine (deepspeed_config.json) -> train -> evaluate.
# Metrics go to Weights & Biases from rank 0 only.
# Launch with the DeepSpeed launcher, e.g. `deepspeed --num_gpus=1 train.py`.

# Suppress warnings.
# NOTE(review): this blanket filter also hides diagnostics from torch/transformers/peft;
# consider narrowing it once the run is known to be healthy.
warnings.filterwarnings("ignore")

print("--- Initializing Training Script ---")

# --- Run configuration: single source of truth for the code AND the W&B config ---
WANDB_PROJECT = "optimized-vit-periodic-labs"
WANDB_RUN_NAME = "LoRA-DeepSpeed-T4-Baseline"
LEARNING_RATE = 5e-5
NUM_EPOCHS = 3
# Must equal "train_micro_batch_size_per_gpu" in deepspeed_config.json (checked after init).
BATCH_SIZE = 16
LORA_R = 16
LOG_EVERY = 100  # print the step loss every LOG_EVERY steps

# --- 1. Data Prep ---
print("Setting up data transformations...")
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

print("Loading CIFAR-100 dataset...")
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)

train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
print("DataLoaders created.")

# --- 2. Model Setup ---
print("Loading pre-trained ViT model...")
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=100, ignore_mismatched_sizes=True)

# Freeze the whole backbone; only the LoRA adapters and the classifier head will train.
for param in model.parameters():
    param.requires_grad = False

# --- 3. LoRA Setup ---
print("Applying LoRA adapters...")
config = LoraConfig(
    r=LORA_R, lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1, bias="none",
)
model = get_peft_model(model, config)
print("LoRA adapters applied.")

# The 100-class head is freshly initialised, so it must train too. This has to happen AFTER
# get_peft_model, which leaves only the adapter weights trainable.
print("Unfreezing classification head...")
for param in model.classifier.parameters():
    param.requires_grad = True

print("New trainable parameters:")
model.print_trainable_parameters()

# --- 4. Gradient Checkpointing ---
# Use non-reentrant checkpointing. The embeddings are frozen, so no tensor entering the
# checkpointed encoder layers requires grad. With the reentrant variant (transformers' default
# when no kwargs are given), the LoRA weights inside those layers would then get no gradients
# and only the classifier head would learn.
print("Enabling gradient checkpointing...")
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# --- 5. DeepSpeed Initialization ---
print("Initializing DeepSpeed...")
# Give the optimizer only the parameters that actually train (LoRA adapters + head).
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_parameters, lr=LEARNING_RATE)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config_params='deepspeed_config.json'
)
# DeepSpeed counts steps in micro-batches of the configured size; the DataLoader must agree.
if model_engine.train_micro_batch_size_per_gpu() != BATCH_SIZE:
    raise ValueError(
        f"DataLoader batch size ({BATCH_SIZE}) does not match DeepSpeed "
        f"train_micro_batch_size_per_gpu ({model_engine.train_micro_batch_size_per_gpu()})"
    )
print("DeepSpeed engine initialized successfully.")

# --- 6. WANDB INITIALIZATION ---
# The DeepSpeed launcher provides RANK; default to 0 for a single-process launch.
rank = int(os.environ.get('RANK', 0))
is_main_process = rank == 0
if is_main_process:  # Only initialize W&B on the main process
    wandb.init(
        project=WANDB_PROJECT,
        name=WANDB_RUN_NAME,
        config={
            "learning_rate": LEARNING_RATE,
            "epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            "lora_r": LORA_R,
            "model": "vit-base",
            "optimization": "DeepSpeed ZeRO-Offload + LoRA"
        }
    )

# --- 7. Training Loop ---
device = model_engine.device

print(f"--- Starting training for {NUM_EPOCHS} epochs ---")
for epoch in range(NUM_EPOCHS):
    model_engine.train()
    total_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model_engine(inputs, labels=labels)
        loss = outputs.loss

        model_engine.backward(loss)
        model_engine.step()

        # Copy the loss to the host once per step and reuse the value below.
        step_loss = loss.item()
        total_loss += step_loss

        if is_main_process:  # Only log from the main process
            wandb.log({"step_loss": step_loss})

        if i % LOG_EVERY == 0:
            print(f"  Epoch {epoch+1}, Step {i}: Loss = {step_loss:.4f}")

    avg_train_loss = total_loss / len(train_loader)
    print(f"**Epoch {epoch+1}/{NUM_EPOCHS} - Avg. Training Loss: {avg_train_loss:.4f}**")

    if is_main_process:
        wandb.log({"epoch": epoch+1, "avg_train_loss": avg_train_loss})

print("--- Training complete ---")

# --- 8. Evaluation ---
print("--- Starting evaluation ---")
model_engine.eval()
correct = 0
total = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        predicted = model_engine(inputs).logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Plain Python ints work with both sklearn and wandb.plot.confusion_matrix.
        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

accuracy = 100 * correct / total
print(f"**Final Test Accuracy: {accuracy:.2f}%**")

# --- 9. Confusion Matrix ---
print("Generating confusion matrix...")
cm = confusion_matrix(all_labels, all_preds)
print("Confusion Matrix (first 10x10):")
print(cm[:10, :10])

# --- 10. LOG FINAL METRICS TO WANDB ---
if is_main_process:
    wandb.log({"final_test_accuracy": accuracy})

    # Confusion matrix as a W&B plot, labelled with the CIFAR-100 class names.
    class_names = trainset.classes
    wandb_cm = wandb.plot.confusion_matrix(
        preds=all_preds,
        y_true=all_labels,
        class_names=class_names
    )
    wandb.log({"confusion_matrix": wandb_cm})

    wandb.finish()  # Finish the run

print("--- Run complete ---")

Writing train.py


In [None]:
# Install libraries for the script, just in case
# !pip install deepspeed scikit-learn

In [None]:
# Launch the LoRA + DeepSpeed training script (train.py) on a single GPU via the DeepSpeed launcher.
# --master_port 29501 overrides the launcher's default rendezvous port. This is presumably because
# the default was already in use (e.g. by the in-notebook deepspeed.initialize() call) — confirm.
!deepspeed --num_gpus=1 --master_port 29501 train.py

2025-10-22 23:30:26.124413: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761175826.143434    2809 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761175826.149236    2809 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761175826.163884    2809 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761175826.163909    2809 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761175826.163912    2809 computation_placer.cc:177] computation placer alr

## "LoRA-Only" baseline training


In [None]:
%%writefile train_lora_only.py
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
from peft import LoraConfig, get_peft_model
# No deepspeed import needed
from sklearn.metrics import confusion_matrix
import warnings
import wandb
import os
import time

# LoRA-only baseline for ViT-B/16 on CIFAR-100. It uses the same data, model and LoRA setup as
# the DeepSpeed script, but trains with a plain PyTorch loop so the two runs can be compared in W&B.

# Suppress warnings.
# NOTE(review): this blanket filter also hides diagnostics from torch/transformers/peft.
warnings.filterwarnings("ignore")

print("--- Initializing LoRA-Only Training Script ---")

# --- Run configuration: single source of truth for the code AND the W&B config ---
WANDB_PROJECT = "optimized-vit-periodic-labs"
# Give this run a distinct name for comparison
WANDB_RUN_NAME = "LoRA-Only-T4"
LEARNING_RATE = 5e-5
NUM_EPOCHS = 3
BATCH_SIZE = 16
LORA_R = 16
LOG_EVERY = 100  # print the step loss every LOG_EVERY steps

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 1. Data Prep ---
print("Setting up data transformations...")
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

print("Loading CIFAR-100 dataset...")
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)

train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
print("DataLoaders created.")

# --- 2. Model Setup ---
print("Loading pre-trained ViT model...")
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=100, ignore_mismatched_sizes=True)
# Freeze the whole backbone; only the LoRA adapters and the classifier head will train.
for param in model.parameters():
    param.requires_grad = False

# --- 3. LoRA Setup ---
print("Applying LoRA adapters...")
config = LoraConfig(
    r=LORA_R, lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1, bias="none",
)
model = get_peft_model(model, config)
print("LoRA adapters applied.")

# --- Unfreeze classifier AFTER LoRA (get_peft_model leaves only adapters trainable) ---
print("Unfreezing classification head...")
for param in model.classifier.parameters():
    param.requires_grad = True
print("New trainable parameters:")
model.print_trainable_parameters()

# --- 4. Gradient Checkpointing ---
# Use non-reentrant checkpointing. The embeddings are frozen, so no tensor entering the
# checkpointed encoder layers requires grad. With the reentrant variant (transformers' default
# when no kwargs are given), the LoRA weights inside those layers would then get no gradients
# and only the classifier head would learn.
print("Enabling gradient checkpointing...")
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# --- Move model to GPU ---
model.to(device)
print(f"Model moved to {device}.")

# --- 5. Standard PyTorch Optimizer (NO DEEPSPEED) ---
print("Initializing standard PyTorch AdamW optimizer...")
# Only parameters requiring gradients are optimized (LoRA adapters + head).
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_parameters, lr=LEARNING_RATE)

# --- 6. WANDB INITIALIZATION ---
wandb.init(
    project=WANDB_PROJECT,
    name=WANDB_RUN_NAME,
    config={
        "learning_rate": LEARNING_RATE,
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "lora_r": LORA_R,
        "model": "vit-base",
        "optimization": "LoRA + Gradient Checkpointing (No DeepSpeed)"
    }
)

# --- 7. Standard Training Loop (NO DEEPSPEED) ---
print(f"--- Starting training for {NUM_EPOCHS} epochs ---")
# perf_counter is monotonic and intended for measuring durations.
start_time = time.perf_counter()

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    epoch_start_time = time.perf_counter()
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard forward / backward / optimizer step
        outputs = model(inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # Clear gradients for the next step

        # Copy the loss to the host once per step and reuse the value below.
        step_loss = loss.item()
        total_loss += step_loss
        wandb.log({"step_loss": step_loss})

        if i % LOG_EVERY == 0:
            print(f"  Epoch {epoch+1}, Step {i}: Loss = {step_loss:.4f}")

    epoch_duration = time.perf_counter() - epoch_start_time
    avg_train_loss = total_loss / len(train_loader)
    print(f"**Epoch {epoch+1}/{NUM_EPOCHS} - Avg. Training Loss: {avg_train_loss:.4f} (Duration: {epoch_duration:.2f}s)**")

    # Log epoch metrics to W&B
    wandb.log({"epoch": epoch+1, "avg_train_loss": avg_train_loss, "epoch_duration_sec": epoch_duration})

total_training_time = time.perf_counter() - start_time
print(f"--- Training complete in {total_training_time:.2f} seconds ---")

# --- 8. Evaluation ---
print("--- Starting evaluation ---")
model.eval()
correct = 0
total = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        predicted = model(inputs).logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Plain Python ints work with both sklearn and wandb.plot.confusion_matrix.
        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

accuracy = 100 * correct / total
print(f"**Final Test Accuracy: {accuracy:.2f}%**")

# --- 9. Confusion Matrix ---
print("Generating confusion matrix...")
cm = confusion_matrix(all_labels, all_preds)
print("Confusion Matrix (first 10x10):")
print(cm[:10, :10])

# --- 10. LOG FINAL METRICS TO WANDB ---
wandb.log({
    "final_test_accuracy": accuracy,
    "total_training_time_sec": total_training_time
})

# Confusion matrix as a W&B plot, labelled with the CIFAR-100 class names.
class_names = trainset.classes
wandb_cm = wandb.plot.confusion_matrix(
    preds=all_preds, y_true=all_labels, class_names=class_names
)
wandb.log({"confusion_matrix": wandb_cm})

wandb.finish()  # Finish the run

print("--- Run complete ---")

Writing train_lora_only.py


In [None]:
# Run the LoRA-only baseline (plain PyTorch loop, no DeepSpeed) in a separate Python process.
!python train_lora_only.py

2025-10-23 00:09:20.234732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761178160.256259   12648 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761178160.262536   12648 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761178160.279651   12648 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761178160.279680   12648 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761178160.279686   12648 computation_placer.cc:177] computation placer alr

# The following code snippets are tuned specifically for the A100

In [None]:
%%writefile deepspeed_config_A100_bs64.json
{
    "bf16": {
        "enabled": true  // Use BFloat16 for A100 performance
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu", // Still offload optimizer to save VRAM, even with 80GB
            "pin_memory": true
        }
    },
    "zero_force_ds_cpu_optimizer": false, // Use fallback AdamW if DeepSpeedCPUAdam isn't built
    "gradient_accumulation_steps": 1,     // No accumulation: one optimizer step per micro-batch
    "gradient_clipping": 1.0,
    "train_batch_size": 64,               // Target effective batch size
    "train_micro_batch_size_per_gpu": 64, // Process 64 samples at once per GPU
    "steps_per_print": 100                // Log more frequently
}

Writing deepspeed_config_A100_bs64.json


## Script for Standard Full Fine-Tune (No DeepSpeed - A100, BS=64)

In [None]:
%%writefile train_standard_full_A100_bs64.py
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
import warnings
import wandb
import os
import time

# Full fine-tuning benchmark of ViT-B/16 on CIFAR-100 WITHOUT DeepSpeed: plain PyTorch loop,
# FP32, gradient checkpointing, batch size 64. Logs loss and timing to W&B. Training only,
# with no evaluation pass.

# Suppress warnings
warnings.filterwarnings("ignore")
print("--- Initializing STANDARD FULL-TUNE Script (A100 BS=64) ---")

# --- Run configuration: single source of truth for the code AND the W&B config ---
WANDB_PROJECT = "optimized-vit-periodic-labs"
WANDB_RUN_NAME = "Standard-Full-Tune-bs64-GC-fp32-A100"  # Larger batch, GC, FP32
LEARNING_RATE = 5e-5
NUM_EPOCHS = 1  # the training loop below honours this value
MICRO_BATCH_SIZE = 64
LOG_EVERY = 50  # print more often, since a larger batch means fewer steps per epoch

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 1. Data Prep ---
print("Setting up data transformations...")
train_transform = transforms.Compose([
    transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4), transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)), transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
print("Loading CIFAR-100 dataset...")
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)

# --- Use LARGER BATCH SIZE ---
train_loader = DataLoader(trainset, batch_size=MICRO_BATCH_SIZE, shuffle=True, num_workers=4)  # Increased workers
# NOTE(review): the test split is loaded but never used by this training-only benchmark.
test_loader = DataLoader(testset, batch_size=MICRO_BATCH_SIZE, shuffle=False, num_workers=4)
print(f"DataLoaders created with micro_batch_size={MICRO_BATCH_SIZE}.")

# --- 2. Model Setup (NO FREEZING) ---
print("Loading pre-trained ViT model for full fine-tuning...")
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=100, ignore_mismatched_sizes=True)

# --- Enable Gradient Checkpointing (Safety for BS=64) ---
print("Enabling gradient checkpointing...")
model.gradient_checkpointing_enable()

model.to(device)
print(f"Model moved to {device}.")

# --- 3. Standard PyTorch Optimizer ---
print("Initializing standard PyTorch AdamW optimizer...")
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
print(f"Optimizer created. Training {sum(p.numel() for p in model.parameters() if p.requires_grad):,} parameters.")

# --- 4. WANDB INITIALIZATION ---
wandb.init(
    project=WANDB_PROJECT,
    name=WANDB_RUN_NAME,
    config={"learning_rate": LEARNING_RATE, "epochs": NUM_EPOCHS, "batch_size": MICRO_BATCH_SIZE, "model": "vit-base", "optimization": "Standard Full-Tune (bs=64 + GC + FP32)"}
)

# --- 5. Standard Training Loop ---
epoch_word = "epoch" if NUM_EPOCHS == 1 else "epochs"
print(f"--- Starting training for {NUM_EPOCHS} {epoch_word} ---")
# perf_counter is monotonic and intended for measuring durations.
start_time = time.perf_counter()

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    total_loss = 0.0
    epoch_start_time = time.perf_counter()
    # No GradScaler needed for FP32

    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs, labels=labels)
        loss = outputs.loss

        loss.backward()  # with gradient checkpointing, activations are recomputed here
        # NOTE(review): no gradient clipping here, whereas the DeepSpeed comparison run uses
        # "gradient_clipping": 1.0 (deepspeed_config_A100_bs64.json).
        optimizer.step()
        optimizer.zero_grad()

        # Copy the loss to the host once per step and reuse the value below.
        step_loss = loss.item()
        total_loss += step_loss
        wandb.log({"step_loss": step_loss})

        if i % LOG_EVERY == 0:
            print(f"  Epoch {epoch}, Step {i}: Loss = {step_loss:.4f}")

    epoch_duration = time.perf_counter() - epoch_start_time
    avg_train_loss = total_loss / len(train_loader)
    print(f"**Epoch {epoch}/{NUM_EPOCHS} - Avg. Training Loss: {avg_train_loss:.4f} (Duration: {epoch_duration:.2f}s)**")
    wandb.log({"epoch": epoch, "avg_train_loss": avg_train_loss, "epoch_duration_sec": epoch_duration})

total_training_time = time.perf_counter() - start_time
print(f"--- Training complete in {total_training_time:.2f} seconds ---")

wandb.log({"total_training_time_sec": total_training_time})
wandb.finish()
print("--- Run complete ---")

Writing train_standard_full_A100_bs64.py


## Script for DeepSpeed Full Fine-Tune (A100 Optimized, BS=64)

In [None]:
%%writefile train_deepspeed_full_A100_bs64.py
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
# No peft import needed
import deepspeed
import warnings
import wandb
import os
import time

# Full fine-tuning benchmark of ViT-B/16 on CIFAR-100 WITH DeepSpeed, configured by
# deepspeed_config_A100_bs64.json (bf16, ZeRO-2 with CPU optimizer offload, batch size 64).
# No gradient checkpointing. Logs loss and timing to W&B from rank 0; there is no evaluation pass.

# Suppress warnings
warnings.filterwarnings("ignore")
print("--- Initializing DEEPSPEED FULL-TUNE Script (A100 BS=64 BF16) ---")

# --- Run configuration: single source of truth for the code AND the W&B config ---
WANDB_PROJECT = "optimized-vit-periodic-labs"
WANDB_RUN_NAME = "DeepSpeed-Full-Tune-bs64-bf16-A100"
LEARNING_RATE = 5e-5
NUM_EPOCHS = 1  # the training loop below honours this value
# Must equal "train_micro_batch_size_per_gpu" in the DeepSpeed config (checked after init).
MICRO_BATCH_SIZE = 64
LOG_EVERY = 50  # print more often, since a larger batch means fewer steps per epoch
DS_CONFIG_PATH = 'deepspeed_config_A100_bs64.json'

# --- 1. Data Prep ---
print("Setting up data transformations...")
train_transform = transforms.Compose([
    transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4), transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)), transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
print("Loading CIFAR-100 dataset...")
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)

# --- Use LARGER Batch Size ---
train_loader = DataLoader(trainset, batch_size=MICRO_BATCH_SIZE, shuffle=True, num_workers=4)  # Increased workers
# NOTE(review): the test split is loaded but never used by this training-only benchmark.
test_loader = DataLoader(testset, batch_size=MICRO_BATCH_SIZE, shuffle=False, num_workers=4)
print(f"DataLoaders created with micro_batch_size={MICRO_BATCH_SIZE}.")

# --- 2. Model Setup (NO FREEZING) ---
print("Loading pre-trained ViT model for full fine-tuning...")
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=100, ignore_mismatched_sizes=True)
# Gradient checkpointing is deliberately left off for this bf16/DeepSpeed run on the 80GB GPU.
print("Gradient Checkpointing NOT enabled for this run.")

# --- 3. DeepSpeed Initialization ---
print("Initializing DeepSpeed...")
# Plain torch AdamW; the config sets "zero_force_ds_cpu_optimizer": false so it can be used
# with CPU optimizer offload instead of DeepSpeedCPUAdam.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config_params=DS_CONFIG_PATH
)
# DeepSpeed counts steps in micro-batches of the configured size; the DataLoader must agree.
if model_engine.train_micro_batch_size_per_gpu() != MICRO_BATCH_SIZE:
    raise ValueError(
        f"DataLoader batch size ({MICRO_BATCH_SIZE}) does not match DeepSpeed "
        f"train_micro_batch_size_per_gpu ({model_engine.train_micro_batch_size_per_gpu()})"
    )
print(f"DeepSpeed engine initialized. Training {sum(p.numel() for p in model_engine.module.parameters() if p.requires_grad):,} parameters.")

# --- 4. WANDB INITIALIZATION ---
# The DeepSpeed launcher provides RANK; default to 0 for a single-process launch.
rank = int(os.environ.get('RANK', 0))
is_main_process = rank == 0
if is_main_process:
    wandb.init(
        project=WANDB_PROJECT,
        name=WANDB_RUN_NAME,
        config={"learning_rate": LEARNING_RATE, "epochs": NUM_EPOCHS, "batch_size": MICRO_BATCH_SIZE, "model": "vit-base", "precision": "bf16", "optimization": "DeepSpeed Full-Tune (bs=64 + bf16 + ZeRO2-Offload)"}
    )

# --- 5. DeepSpeed Training Loop ---
device = model_engine.device
epoch_word = "epoch" if NUM_EPOCHS == 1 else "epochs"
print(f"--- Starting training for {NUM_EPOCHS} {epoch_word} ---")
# perf_counter is monotonic and intended for measuring durations.
start_time = time.perf_counter()

for epoch in range(1, NUM_EPOCHS + 1):
    model_engine.train()
    total_loss = 0.0
    epoch_start_time = time.perf_counter()

    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model_engine(inputs, labels=labels)
        loss = outputs.loss

        model_engine.backward(loss)
        model_engine.step()

        # Copy the loss to the host once per step and reuse the value below.
        step_loss = loss.item()
        total_loss += step_loss
        if is_main_process:
            wandb.log({"step_loss": step_loss})

        if i % LOG_EVERY == 0:
            print(f"  Epoch {epoch}, Step {i}: Loss = {step_loss:.4f}")

    epoch_duration = time.perf_counter() - epoch_start_time
    avg_train_loss = total_loss / len(train_loader)
    print(f"**Epoch {epoch}/{NUM_EPOCHS} - Avg. Training Loss: {avg_train_loss:.4f} (Duration: {epoch_duration:.2f}s)**")

    if is_main_process:
        wandb.log({"epoch": epoch, "avg_train_loss": avg_train_loss, "epoch_duration_sec": epoch_duration})

total_training_time = time.perf_counter() - start_time
print(f"--- Training complete in {total_training_time:.2f} seconds ---")

if is_main_process:
    wandb.log({"total_training_time_sec": total_training_time})
    wandb.finish()
print("--- Run complete ---")

Writing train_deepspeed_full_A100_bs64.py


## Running the standard script

In [None]:
# Full fine-tune without DeepSpeed (FP32, gradient checkpointing, batch size 64) in a separate process.
!python train_standard_full_A100_bs64.py

2025-10-24 20:27:25.449833: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761337645.471274   11394 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761337645.477765   11394 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761337645.494238   11394 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761337645.494269   11394 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761337645.494272   11394 computation_placer.cc:177] computation placer alr

## Run the DeepSpeed script

In [None]:
# Launch the DeepSpeed full fine-tune (bf16, ZeRO-2 with CPU optimizer offload, batch size 64)
# on a single GPU. --master_port 29501 selects a non-default rendezvous port, presumably to
# avoid a port that is already in use — confirm.
!deepspeed --num_gpus=1 --master_port 29501 train_deepspeed_full_A100_bs64.py

2025-10-24 20:38:03.573806: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761338283.594997   14481 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761338283.601502   14481 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761338283.617949   14481 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761338283.617978   14481 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761338283.617981   14481 computation_placer.cc:177] computation placer alr