# Efficient Adaptation and Analysis of Vision Transformers


In [1]:
import torch
import sys

# Environment smoke test: report interpreter / PyTorch / CUDA details and
# allocate a small 3x3 probe tensor `x` on the GPU when one is available,
# otherwise on the CPU.
has_cuda = torch.cuda.is_available()

for line in (
    "=== PyTorch Environment Test ===",
    f"Python version: {sys.version}",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {has_cuda}",
):
    print(line)

if not has_cuda:
    print("CUDA not available - using CPU")
    x = torch.randn(3, 3)
    print(f"CPU tensor created: {x.device}")
else:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

    # Probe tensor: created on the host, then moved to the default GPU.
    x = torch.randn(3, 3).cuda()
    print(f"\nGPU tensor created: {x.device}")
    print(f"Tensor shape: {x.shape}")

print("\n‚úÖ PyTorch test completed!")

=== PyTorch Environment Test ===
Python version: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
Number of GPUs: 2
GPU name: Tesla T4

GPU tensor created: cuda:0
Tensor shape: torch.Size([3, 3])

‚úÖ PyTorch test completed!


In [2]:
# Shell out to nvidia-smi to show the driver/CUDA versions and per-GPU memory and utilisation.
!nvidia-smi

Sun Nov 30 18:39:10 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   44C    P0             26W /   70W |     105MiB /  15360MiB |      3%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla T4                       Off |   00

In [3]:
!pip install wandb
# !wandb login

import wandb
from kaggle_secrets import UserSecretsClient

# Get the secret value
user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")

# Log in using the key
wandb.login(key=wandb_api_key)

print("‚úÖ Successfully logged into W&B!")



[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msiddpath[0m ([33msiddpath-university-of-maryland[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


‚úÖ Successfully logged into W&B!


#### Downloading the data

In [4]:
import torchvision
import torchvision.transforms as transforms

# Initial preprocessing for the Food-101 images: resize to the 224x224 input
# resolution of the ViT checkpoint and normalise every channel with
# mean = std = 0.5, i.e. the same statistics the following transform cell and
# the training scripts use (previously this cell used ImageNet statistics,
# which disagreed with the rest of the notebook).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the Food-101 training dataset
# Food-101 uses 'split="train"' for training
trainset = torchvision.datasets.Food101(root='./data', split="train",
                                        download=True, transform=transform)

# Load the Food-101 test (validation) dataset
# Food-101 uses 'split="test"' for testing
testset = torchvision.datasets.Food101(root='./data', split="test",
                                       download=True, transform=transform)

# Report the split sizes as a quick sanity check of the download.
print("Food-101 dataset imported successfully.")
print(f"Training set size: {len(trainset)}")
print(f"Test set size: {len(testset)}")

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5.00G/5.00G [00:22<00:00, 223MB/s] 


Food-101 dataset imported successfully.
Training set size: 75750
Test set size: 25250


#### Resizing the data for ViT model


In [5]:
import torchvision.transforms as transforms

# Define transforms for training and validation/testing.
# Both pipelines resize to 224x224 and normalise each channel with mean = std = 0.5;
# only the training pipeline adds random augmentation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the 224x224 ViT input resolution
    transforms.RandomHorizontalFlip(), # Augmentation: random left-right flip
    transforms.RandomCrop(224, padding=4), # Augmentation: pad by 4 px, then take a random 224x224 crop
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)), # Resize to the 224x224 ViT input resolution
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Apply the transforms to the datasets.
# This overwrites the `transform` attribute the datasets were constructed with.
trainset.transform = train_transform
testset.transform = test_transform

print("Data preparation complete. Transforms applied to datasets.")

Data preparation complete. Transforms applied to datasets.


#### Loading the ViT (all parameters left trainable, no freezing)

In [6]:
from transformers import ViTForImageClassification

# Load the pre-trained ViT-Huge checkpoint with a fresh classification head.
# num_labels=101 matches the 101 Food-101 classes; ignore_mismatched_sizes
# lets the checkpoint's original classifier head be replaced.
vit_load_kwargs = dict(num_labels=101, ignore_mismatched_sizes=True)
model = ViTForImageClassification.from_pretrained(
    'google/vit-huge-patch14-224-in21k', **vit_load_kwargs
)

print("Pre-trained ViT-Huge model loaded.")

2025-11-30 18:41:27.998628: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764528088.134481      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764528088.176033      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

config.json:   0%|          | 0.00/503 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.53G [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-huge-patch14-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Pre-trained ViT-Huge model loaded.


In [7]:
# %pip install peft transformers datasets

# Here we decide whether to train the full model or only a subset of its parameters.

In [8]:
# Tally the elements of every parameter that will receive gradients.
trainable_params = 0
for p in model.parameters():
    if p.requires_grad:
        trainable_params += p.numel()
print(f"Number of trainable parameters: {trainable_params}")

Number of trainable parameters: 630894181


In [9]:
# model

## Step 1: Create the DataLoaders

In [10]:
from torch.utils.data import DataLoader

# Create DataLoaders
# Both loaders yield batches of 16; only the training loader shuffles.
train_loader = DataLoader(trainset, batch_size=16, shuffle=True)
test_loader = DataLoader(testset, batch_size=16, shuffle=False)

print("DataLoaders created.")

DataLoaders created.


## Step 2: Enable Gradient Checkpointing

In [11]:
# Enable gradient checkpointing on the in-notebook model.
# NOTE(review): this affects only this kernel's `model`; the training scripts
# written below load and configure their own copy.
model.gradient_checkpointing_enable()
print("Gradient checkpointing enabled.")

Gradient checkpointing enabled.


## Step 3: Initialize DeepSpeed

In [12]:
# Installed ahead of DeepSpeed; presumably for its MPI-based launch path - TODO confirm it is required.
%pip install mpi4py

Collecting mpi4py
  Downloading mpi4py-4.1.1-cp311-cp311-manylinux1_x86_64.manylinux_2_5_x86_64.whl.metadata (16 kB)
Downloading mpi4py-4.1.1-cp311-cp311-manylinux1_x86_64.manylinux_2_5_x86_64.whl (1.4 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.4/1.4 MB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hInstalling collected packages: mpi4py
Successfully installed mpi4py-4.1.1
Note: you may need to restart the kernel to use updated packages.


In [13]:
# DeepSpeed provides the ZeRO Stage 2 / Stage 3 training engine used by the scripts below.
%pip install deepspeed

Collecting deepspeed
  Downloading deepspeed-0.18.2.tar.gz (1.6 MB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.6/1.6 MB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting hjson (from deepspeed)
  Downloading hjson-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->deepspeed)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->deepspeed)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->deepspeed)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->deepspee

In [14]:
# System package; presumably needed for building DeepSpeed's async-I/O op - TODO confirm.
!apt-get install -y libaio-dev

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following NEW packages will be installed:
  libaio-dev
0 upgraded, 1 newly installed, 0 to remove and 165 not upgraded.
Need to get 21.2 kB of archives.
After this operation, 71.7 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libaio-dev amd64 0.3.112-13build1 [21.2 kB]
Fetched 21.2 kB in 0s (238 kB/s)      
Selecting previously unselected package libaio-dev:amd64.
(Reading database ... 128639 files and directories currently installed.)
Preparing to unpack .../libaio-dev_0.3.112-13build1_amd64.deb ...
Unpacking libaio-dev:amd64 (0.3.112-13build1) ...
Setting up libaio-dev:amd64 (0.3.112-13build1) ...
Processing triggers for man-db (2.10.2-1) ...


## Script for standard full fine-tuning with DDP on 2x T4 GPUs (batch size 16 per GPU)

In [15]:
%%writefile train_standard_ddp_vit_huge_bs16.py
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTImageProcessor
import warnings
import wandb
import os
import time
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import GradScaler, autocast
import torch.profiler 

# --- 1. Setup & Helper Functions ---
# Silence library warnings and announce the run (printed by every process).
warnings.filterwarnings("ignore")
print("--- Initializing Standard DDP (PRO RUN) ---")

# Process coordinates, read from the environment variables set by the launcher.
rank, local_rank, world_size = (
    int(os.environ[name]) for name in ('RANK', 'LOCAL_RANK', 'WORLD_SIZE')
)

# Join the NCCL process group and bind this process to its own GPU.
dist.init_process_group(backend='nccl')
device = torch.device("cuda", local_rank)
torch.cuda.set_device(local_rank)

def inspect_model_sharding(model, name_tag):
    """Print name, shape and element count of the first two parameters (rank 0 only).

    Args:
        model: The model to inspect; if it exposes a ``module`` attribute
            (as a DDP wrapper does), the wrapped module is inspected instead.
        name_tag: Label inserted into the header line of the report.

    Every listed parameter is reported as "REPLICATED (Full Copy)"; the status
    text is fixed and is not derived from the parameter itself.
    """
    if rank != 0:
        return

    print(f"\n[Rank {rank}] --- INSPECTING SHARDING ({name_tag}) ---")
    # Look through the wrapper to reach the underlying model's parameters.
    actual_model = getattr(model, "module", model)

    for shown, (name, param) in enumerate(actual_model.named_parameters(), start=1):
        print(f"Param: {name}")
        # The shape printed here is the shape of the tensor held by this process.
        print(f"  Physical Shape: {param.shape} (Elements: {param.numel():,})")
        print(f"  Status:         REPLICATED (Full Copy)")
        print("-" * 40)
        if shown >= 2:
            break
    print("----------------------------------\n")

# --- 2. Data & Model ---
if rank == 0: print("Setting up Data & Model...")
# The checkpoint's image processor supplies the normalisation mean/std used below.
processor = ViTImageProcessor.from_pretrained('google/vit-huge-patch14-224-in21k')
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

# Rank 0 downloads the dataset first; the other ranks wait at the barrier and
# then open the already-downloaded copy with download=False.
if rank == 0:
    trainset = torchvision.datasets.Food101(root='./data', split="train", download=True, transform=train_transform)
dist.barrier()
if rank != 0:
    trainset = torchvision.datasets.Food101(root='./data', split="train", download=False, transform=train_transform)

# Each rank draws its own shard of the training set; batch_size is per process.
train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
train_loader = DataLoader(trainset, batch_size=16, sampler=train_sampler, num_workers=2, pin_memory=True)

# Full fine-tune: load the checkpoint with a 101-class head, enable gradient
# checkpointing, move to this rank's GPU, then wrap in DDP.
model = ViTForImageClassification.from_pretrained('google/vit-huge-patch14-224-in21k', num_labels=101, ignore_mismatched_sizes=True)
model.gradient_checkpointing_enable()
model.to(device)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

# AdamW over all parameters; the GradScaler handles loss scaling for the
# autocast (mixed-precision) forward pass in the training loop.
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
scaler = GradScaler(enabled=True)

# --- 3. Logging & Inspection ---
# Only rank 0 owns the W&B run.
if rank == 0:
    wandb.init(
        project="Distributed ViT training systems-Latest_run",
        name="Standard-DDP-ViT-Huge-Food101-2xT4-Batch-size-16-per-gpu",
        config={"model": "vit-huge", "mode": "DDP"}
    )

# PRO CHECK: prints the first two parameter shapes on rank 0 only; the
# "REPLICATED" status it reports is a fixed label, not a measured property.
inspect_model_sharding(model, "Standard DDP")

# --- 4. Training Loop with Labels ---
# Runs a handful of mixed-precision training steps, each wrapped in profiler
# labels, and logs loss and GPU memory to W&B from rank 0.
print(f"[Rank {rank}] Starting training...")

# from_pretrained() hands back the model in eval mode, and Hugging Face only
# applies gradient checkpointing while the module is in training mode. Without
# this call the checkpointing enabled earlier is silently inactive.
model.train()

for i, batch in enumerate(train_loader):
    # Stop early to save time (we just need traces and logs)
    if i >= 6: break

    # LABEL 0: Top Level Step
    with torch.profiler.record_function(f"## Training Step {i} ##"):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        # LABEL 1: Forward (under autocast; the model computes the loss itself
        # because labels are passed in)
        with torch.profiler.record_function("## Forward Pass ##"):
            with autocast(enabled=True):
                outputs = model(inputs, labels=labels)
                loss = outputs.loss

        # LABEL 2: Backward
        with torch.profiler.record_function("## Backward Pass ##"):
            # This is where it usually crashes (OOM)
            scaler.scale(loss).backward()

        # LABEL 3: Optimizer
        with torch.profiler.record_function("## Optimizer Step ##"):
            scaler.step(optimizer)
            scaler.update()

        # --- PRO LOGGING ---
        if rank == 0:
            # Read the loss once; each .item() call forces a GPU sync.
            step_loss = loss.item()
            mem = torch.cuda.memory_allocated() / 1e9
            max_mem = torch.cuda.max_memory_allocated() / 1e9
            wandb.log({
                "loss": step_loss,
                "System/Memory_Allocated_GB": mem,
                "System/Max_Memory_GB": max_mem
            })
            print(f"Step {i}: Loss={step_loss:.4f} | Mem={mem:.2f}GB")

# Close the W&B run on the rank that opened it so buffered metrics are flushed
# before the process group is torn down.
if rank == 0:
    wandb.finish()

dist.destroy_process_group()

Writing train_standard_ddp_vit_huge_bs16.py


In [16]:
# 1. Install/Login
# Shell-level install of the script dependencies, then a W&B login for the shell environment.
!pip install -q transformers datasets peft deepspeed wandb scikit-learn seaborn tensorboard
!wandb login

# 2. Launch the DDP script
# Earlier launch line kept for reference (it names a different script file):
# !torchrun --nproc_per_node=2 train_standard_ddp.py
# Two processes on this node; the script reads RANK / LOCAL_RANK / WORLD_SIZE from the environment.
!torchrun --nproc_per_node=2 --master_port 29501 train_standard_ddp_vit_huge_bs16.py

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m47.7/47.7 MB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
pylibcudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
cudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
bigframes 2.12.0 requires rich<14,>=12.4.4, but you have rich 14.2.0 which is incompatible.
cudf-polars-cu12 25.6.0 requires pylibcudf-cu12==25.6.*, but you have pylibcudf-cu12 25.2.2 which is incompatible.[0m[31m
[34m[1mwa

## Script for DeepSpeed ZeRO Stage 2 full fine-tuning on 2x T4 GPUs

In [17]:
%%writefile deepspeed_config_stage2_vit_huge_bs16.json
{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 5e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "overlap_comm": true,
        "contiguous_gradients": true
    },
    "train_batch_size": 32,
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "steps_per_print": 50,
    "wall_clock_breakdown": true,
    "flops_profiler": {
        "enabled": true,
        "profile_step": 5,
        "module_depth": -1,
        "top_modules": 1,
        "detailed": true,
        "output_file": null
    }
}

Writing deepspeed_config_stage2_vit_huge_bs16.json


In [18]:
%%writefile train_deepspeed_stage2_vit_huge_bs16.py
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTImageProcessor
import deepspeed
import warnings
import wandb
import os
import time
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from sklearn.metrics import confusion_matrix
import torch.profiler # <--- PRO ADDITION

# Silence library warnings and announce the run (printed by every process).
warnings.filterwarnings("ignore")
print("--- Initializing DEEPSPEED ZeRO Stage 2 (PRO RUN) ---")

# --- 1. Setup ---
# Process coordinates come from the launcher's environment variables; each
# process is then bound to its own GPU.
local_rank, rank, world_size = (
    int(os.environ[name]) for name in ('LOCAL_RANK', 'RANK', 'WORLD_SIZE')
)
torch.cuda.set_device(local_rank)
print(f"[Rank {rank}] Initializing process...")

# --- W&B Setup ---
# The run identifiers are only defined on rank 0, the only rank that uses them.
if rank == 0:
    WANDB_PROJECT, WANDB_RUN_NAME = (
        "Distributed ViT training systems-Latest_run",
        "DeepSpeed-Stage2-ViT-Huge-Food101-2xT4-MBS16-GAS1-Batch-size-16-per-gpu",
    )

# --- HELPER: PROVE NO SHARDING ---
def inspect_model_sharding(model):
    """Print the name and shape of the first two parameters of ``model`` (rank 0 only).

    Args:
        model: Module whose ``named_parameters()`` are listed.

    Every listed parameter is labelled "Full Layer Present (High Memory)";
    that label is fixed text and is not derived from the parameter.
    """
    if rank != 0:
        return

    print(f"\n[Rank {rank}] --- INSPECTING SHARDING (Stage 2) ---")
    for shown, (name, param) in enumerate(model.named_parameters(), start=1):
        # The shape printed is that of the tensor held by this process.
        print(f"Param: {name}")
        print(f"  Physical Shape: {param.shape}")
        print(f"  Status:         Full Layer Present (High Memory)")
        print("-" * 40)
        if shown >= 2:
            break
    print("----------------------------------\n")

# --- 2. Model Setup ---
# Full fine-tune: load the checkpoint with a 101-class head and enable
# gradient checkpointing before handing the model to DeepSpeed.
if rank == 0:
    print("Loading pre-trained ViT-Huge model...")
model = ViTForImageClassification.from_pretrained('google/vit-huge-patch14-224-in21k', num_labels=101, ignore_mismatched_sizes=True)
model.gradient_checkpointing_enable()

# --- 3. DeepSpeed Init ---
# The engine is built from the Stage 2 JSON config written by the previous cell;
# that file also defines the optimizer, so none is created here.
if rank == 0: print("Initializing DeepSpeed...")
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), 
    config_params='deepspeed_config_stage2_vit_huge_bs16.json'
)

# --- 4. Data Prep ---
if rank == 0: print("Setting up data...")
# The checkpoint's image processor supplies the normalisation mean/std used below.
processor = ViTImageProcessor.from_pretrained('google/vit-huge-patch14-224-in21k')
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])
# Evaluation pipeline: same resize/normalisation, no random augmentation.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

# Rank 0 downloads the dataset first; the other ranks wait at the barrier and
# then open the already-downloaded copy with download=False.
if rank == 0:
    trainset = torchvision.datasets.Food101(root='./data', split="train", download=True, transform=train_transform)
    testset = torchvision.datasets.Food101(root='./data', split="test", download=True, transform=test_transform)
dist.barrier() 
if rank != 0:
    trainset = torchvision.datasets.Food101(root='./data', split="train", download=False, transform=train_transform)
    testset = torchvision.datasets.Food101(root='./data', split="test", download=False, transform=test_transform)

# Per-GPU batch size; matches "train_micro_batch_size_per_gpu" in the JSON config.
MICRO_BATCH_SIZE = 16
# Each rank gets its own shard of both splits; only the training shard is shuffled.
train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
test_sampler = DistributedSampler(testset, num_replicas=world_size, rank=rank, shuffle=False)
train_loader = DataLoader(trainset, batch_size=MICRO_BATCH_SIZE, sampler=train_sampler, num_workers=2, pin_memory=True)
test_loader = DataLoader(testset, batch_size=MICRO_BATCH_SIZE, sampler=test_sampler, num_workers=2, pin_memory=True)

# --- 6. WANDB INIT ---
# Only rank 0 owns the W&B run (WANDB_PROJECT / WANDB_RUN_NAME exist only there).
if rank == 0:
    wandb.init(
        project=WANDB_PROJECT, name=WANDB_RUN_NAME,
        config={"model": "vit-huge", "optimization": "Stage 2"}
    )

# Run Check: list the first two parameters of the wrapped module on rank 0.
inspect_model_sharding(model_engine.module)

# --- 7. Training Loop ---
# One pass over the training shard per epoch. DeepSpeed's engine drives
# forward/backward/step; rank 0 logs per-step loss, memory and throughput.
device = model_engine.device
num_epochs = 1
start_time = time.time()

for epoch in range(num_epochs):
    # Per-epoch clock, so "epoch_duration_sec" stays correct for num_epochs > 1.
    epoch_start_time = time.time()
    model_engine.train()
    # Tell the sampler which epoch this is so its shuffle changes between epochs.
    train_sampler.set_epoch(epoch)
    total_loss_rank = 0.0

    for i, batch in enumerate(train_loader):

        # PRO LABEL: Top Level
        with torch.profiler.record_function(f"## Training Step {i} ##"):
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            # PRO LABEL: Forward (the model computes the loss because labels are passed)
            with torch.profiler.record_function("## Forward Pass ##"):
                outputs = model_engine(inputs, labels=labels)
                loss = outputs.loss

            # PRO LABEL: Backward
            with torch.profiler.record_function("## Backward Pass ##"):
                model_engine.backward(loss)

            # PRO LABEL: Optimizer
            with torch.profiler.record_function("## Optimizer Step ##"):
                model_engine.step()

            # Read the loss once; each .item() call forces a GPU sync.
            step_loss = loss.item()
            total_loss_rank += step_loss

            # --- PRO LOGGING ---
            if rank == 0:
                mem = torch.cuda.memory_allocated() / 1e9

                # `avg_samples_per_sec` is a method on DeepSpeed's throughput
                # timer, so it must be called; logging the attribute itself
                # would send a bound method to W&B. Fall back to 0 while the
                # timer has no finite reading yet.
                tput = 0
                tput_timer = getattr(model_engine, 'tput_timer', None)
                if tput_timer is not None:
                    try:
                        reading = tput_timer.avg_samples_per_sec
                        if callable(reading):
                            reading = reading()
                        if float("-inf") < reading < float("inf"):
                            tput = reading
                    except (ZeroDivisionError, TypeError):
                        # Best effort: the timer has not accumulated enough steps.
                        tput = 0

                wandb.log({"step_loss": step_loss, "System/Memory_Allocated_GB": mem, "Performance/Throughput": tput})
                if i % 100 == 0: print(f"  Step {i}: Loss={step_loss:.4f} | Mem={mem:.2f}GB")

    # Sync Loss: sum the per-rank loss totals so rank 0 can report a global mean.
    total_loss_all_ranks = torch.tensor(total_loss_rank, dtype=torch.float, device=device)
    dist.all_reduce(total_loss_all_ranks, op=dist.ReduceOp.SUM)

    if rank == 0:
        avg_train_loss = total_loss_all_ranks.item() / (len(train_loader) * world_size)
        duration = time.time() - epoch_start_time
        # Log the actual (1-based) epoch number rather than a hard-coded 1.
        wandb.log({"epoch": epoch + 1, "avg_train_loss": avg_train_loss, "epoch_duration_sec": duration})

if rank == 0:
    total_time = time.time() - start_time
    wandb.log({"total_training_time_sec": total_time})

# --- 8. Evaluation ---
# Every rank scores its shard of the test split; the per-rank predictions and
# labels are gathered on all ranks and rank 0 computes and logs the metrics.
if rank == 0:
    print("--- Starting evaluation ---")
model_engine.eval()
rank_preds, rank_labels = [], []

with torch.no_grad():
    for inputs, labels in test_loader:
        logits = model_engine(inputs.to(device)).logits
        # Index of the largest logit per sample = predicted class id.
        predicted = torch.max(logits.data, 1)[1]
        rank_preds.extend(predicted.cpu().numpy())
        rank_labels.extend(labels.cpu().numpy())

# One slot per rank; all_gather_object fills them with each rank's list.
all_preds = [None for _ in range(world_size)]
all_labels = [None for _ in range(world_size)]
dist.all_gather_object(all_preds, rank_preds)
dist.all_gather_object(all_labels, rank_labels)

if rank == 0:
    print("Computing metrics...")
    # Concatenate the per-rank lists in rank order.
    flat_preds, flat_labels = [], []
    for shard_preds, shard_labels in zip(all_preds, all_labels):
        flat_preds.extend(shard_preds)
        flat_labels.extend(shard_labels)

    matches = torch.tensor(flat_preds) == torch.tensor(flat_labels)
    correct = matches.sum().item()
    accuracy = 100 * correct / len(flat_preds)
    print(f"**Final Test Accuracy: {accuracy:.2f}%**")

    wandb.log({"final_test_accuracy": accuracy})
    wandb_cm = wandb.plot.confusion_matrix(preds=flat_preds, y_true=flat_labels, class_names=trainset.classes)
    wandb.log({"confusion_matrix": wandb_cm})
    wandb.finish()

dist.destroy_process_group()
print(f"--- [Rank {rank}] Run complete ---")

Writing train_deepspeed_stage2_vit_huge_bs16.py


In [19]:
# Launch the Stage 2 script with the DeepSpeed launcher on both GPUs.
!deepspeed --num_gpus=2 train_deepspeed_stage2_vit_huge_bs16.py

2025-11-24 00:55:34.322063: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763945734.344443     407 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763945734.352258     407 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
[2025-11-24 00:55:39,427] [INFO] [runner.py:630:main] cmd = /usr/bin/python3 

## Script for DeepSpeed ZeRO Stage 3 full fine-tuning on 2x T4 GPUs

In [20]:
%%writefile deepspeed_config_stage3_vit_huge_bs16.json
{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 5e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },
    "zero_optimization": {
        "stage": 3,
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "overlap_comm": true,
        "contiguous_gradients": true
    },
    "train_batch_size": 32,
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "steps_per_print": 50,
    "wall_clock_breakdown": true,
    "flops_profiler": {
        "enabled": true,
        "profile_step": 5,
        "module_depth": -1,
        "top_modules": 1,
        "detailed": true,
        "output_file": null
    }
}

Writing deepspeed_config_stage3_vit_huge_bs16.json


In [21]:
%%writefile train_deepspeed_stage3_vit_huge_bs16.py
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTImageProcessor
import deepspeed
import warnings
import wandb
import os
import time
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from sklearn.metrics import confusion_matrix
import torch.profiler # <--- PRO ADDITION

# Silence library warnings and announce the run (printed by every process).
warnings.filterwarnings("ignore")
print("--- Initializing DEEPSPEED ZeRO Stage 3 (PRO RUN - SYSTEMS CHECK) ---")

# --- 1. DDP/DeepSpeed Setup ---
# Process coordinates come from the launcher's environment variables.
local_rank, rank, world_size = (
    int(os.environ[name]) for name in ('LOCAL_RANK', 'RANK', 'WORLD_SIZE')
)

# Bind this process to its own GPU before any CUDA work happens.
torch.cuda.set_device(local_rank)
print(f"[Rank {rank}] Initializing process...")

# --- HELPER: PROVE ACTIVE SHARDING ---
def inspect_model_sharding(model, max_params=2):
    """Print logical vs. per-rank sizes of the first few ZeRO-partitioned parameters.

    Only rank 0 prints; other ranks return immediately. A parameter counts as
    partitioned when it carries DeepSpeed's ``ds_numel`` attribute.

    Under ZeRO Stage 3 a partitioned parameter's own tensor is an empty
    placeholder, so ``param.numel()`` / ``param.shape`` report 0 elements. The
    slice owned by this rank is held in ``param.ds_tensor``, which is what is
    measured here. If ``ds_tensor`` is absent the parameter's own tensor is
    used as a fallback.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        max_params: Maximum number of partitioned parameters to report.
    """
    if rank != 0:
        return

    print(f"\n[Rank {rank}] --- INSPECTING SHARDING (Stage 3) ---")
    shown = 0
    for name, param in model.named_parameters():
        if shown >= max_params:
            break
        # Parameters without DeepSpeed metadata are not partitioned; skip them.
        if not hasattr(param, "ds_numel"):
            continue

        logical = param.ds_numel  # Full (unpartitioned) element count
        shard = getattr(param, "ds_tensor", None)
        if shard is not None:
            # Prefer the shard's recorded element count when DeepSpeed provides it.
            physical = getattr(shard, "ds_numel", shard.numel())
            physical_shape = shard.shape
        else:
            physical = param.numel()
            physical_shape = param.shape
        # Guard against zero-element parameters to avoid ZeroDivisionError.
        ratio = physical / logical if logical else 0.0

        print(f"Param: {name}")
        print(f"  Logical (Full):   {param.ds_shape} ({logical:,})")
        print(f"  Physical (Shard): {physical_shape} ({physical:,})")
        print(f"  Shard Ratio:      {ratio:.2%} (Target: ~{100/world_size:.1f}%)")
        print("-" * 40)
        shown += 1

    if shown == 0:
        print("No DeepSpeed-partitioned parameters found.")
    print("----------------------------------\n")

# --- 2. Model Setup (ViT-Huge, NO FREEZING) ---
# Every rank loads the full pre-trained checkpoint. The classifier head is
# re-created with 101 outputs, so ignore_mismatched_sizes lets the checkpoint's
# differently-sized head be discarded instead of raising.
if rank == 0:
    print("Loading pre-trained ViT-Huge model for full fine-tuning...")
model = ViTForImageClassification.from_pretrained(
    'google/vit-huge-patch14-224-in21k',
    num_labels=101,
    ignore_mismatched_sizes=True,
)
print(f"[Rank {rank}] Model loaded.")

# --- !! KEEPING GRADIENT CHECKPOINTING !! ---
print("Enabling gradient checkpointing...")
model.gradient_checkpointing_enable()

# --- 3. DeepSpeed Initialization ---
if rank == 0:
    print("Initializing DeepSpeed...")
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params='deepspeed_config_stage3_vit_huge_bs16.json' # <-- USE PRO CONFIG
)
if rank == 0:
    # After ZeRO Stage 3 partitioning each parameter's local tensor is an empty
    # placeholder, so numel() would sum to 0. DeepSpeed keeps the full element
    # count in `ds_numel`; fall back to numel() for parameters without it.
    trainable_params = sum(
        getattr(p, "ds_numel", p.numel())
        for p in model_engine.module.parameters()
        if p.requires_grad
    )
    print(f"DeepSpeed engine initialized. Training {trainable_params:,} parameters.")

# --- 4. Data Prep (Food-101) ---
if rank == 0:
    print("Setting up data transformations...")
processor = ViTImageProcessor.from_pretrained('google/vit-huge-patch14-224-in21k')


def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps using the processor's mean/std."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ]


# Training pipeline: resize, random flip, padded random crop, then normalize.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    *_tensor_and_normalize(),
])
# Evaluation pipeline: deterministic resize + normalize only.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    *_tensor_and_normalize(),
])


def _open_food101(download):
    """Build the (train, test) Food-101 datasets rooted at ./data."""
    train_split = torchvision.datasets.Food101(
        root='./data', split="train", download=download, transform=train_transform
    )
    test_split = torchvision.datasets.Food101(
        root='./data', split="test", download=download, transform=test_transform
    )
    return train_split, test_split


# Rank 0 is the only process allowed to download; the others wait at the
# barrier and then open the dataset with download disabled.
if rank == 0:
    print("Loading Food-101 dataset...")
    trainset, testset = _open_food101(download=True)

dist.barrier()

if rank != 0:
    trainset, testset = _open_food101(download=False)

# --- 5. Data Samplers ---
MICRO_BATCH_SIZE = 16
_shard_spec = dict(num_replicas=world_size, rank=rank)
_loader_spec = dict(batch_size=MICRO_BATCH_SIZE, num_workers=2, pin_memory=True)

# Each rank gets its own slice of the data; only the training slice is shuffled.
train_sampler = DistributedSampler(trainset, shuffle=True, **_shard_spec)
test_sampler = DistributedSampler(testset, shuffle=False, **_shard_spec)
train_loader = DataLoader(trainset, sampler=train_sampler, **_loader_spec)
test_loader = DataLoader(testset, sampler=test_sampler, **_loader_spec)

# --- 6. WANDB INITIALIZATION ---
# Only rank 0 opens a Weights & Biases run; the other ranks never log.
if rank == 0:
    run_config = {
        "model": "vit-huge",
        "optimization": "DeepSpeed ZeRO Stage 3 (GC)",
        # Samples per optimizer step across all ranks.
        "batch_size": MICRO_BATCH_SIZE * world_size,
    }
    wandb.init(
        project="Distributed ViT training systems-Latest_run",
        name="DeepSpeed-Stage3-ViT-Huge-Food101-2xT4-BS16-GC-Batch-size-16-per-gpu",
        config=run_config,
    )

# PRO CHECK: Verify Sharding
# Called on every rank; the helper itself prints on rank 0 only.
inspect_model_sharding(model_engine.module)

# --- 7. DeepSpeed Training Loop ---
def _avg_samples_per_sec(engine):
    """Return the engine's average throughput in samples/sec as a finite float.

    DeepSpeed's throughput timer exposes ``avg_samples_per_sec`` as a method,
    so it has to be called; logging the attribute itself would log a bound
    method. A plain numeric attribute is still accepted. Returns 0.0 when the
    timer is missing, not yet measurable (division by zero), or non-finite.
    """
    timer = getattr(engine, 'tput_timer', None)
    reading = getattr(timer, 'avg_samples_per_sec', None)
    if reading is None:
        return 0.0
    try:
        value = float(reading() if callable(reading) else reading)
    except ZeroDivisionError:
        # The timer may not have accumulated elapsed time yet in early steps.
        return 0.0
    # `value != value` is the NaN check; infinities are rejected explicitly.
    if value != value or value in (float('inf'), float('-inf')):
        return 0.0
    return value


device = model_engine.device
num_epochs = 1
if rank == 0:
    print(f"--- Starting training for {num_epochs} epoch ---")
start_time = time.time()

for epoch in range(num_epochs):
    epoch_start_time = time.time()
    model_engine.train()
    # Reseed the sampler so each epoch uses a different shuffle.
    train_sampler.set_epoch(epoch)
    total_loss_rank = 0.0

    for i, batch in enumerate(train_loader):
        # Optional: Break early for profiling if you don't want full epoch
        # if i >= 10: break

        # LABEL 0: Top Level
        with torch.profiler.record_function(f"## Training Step {i} ##"):
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            # LABEL 1: Forward Pass
            # EXPECTATION: Pink AllGather bars INSIDE this block (Overlap)
            with torch.profiler.record_function("## Forward Pass ##"):
                outputs = model_engine(inputs, labels=labels)
                loss = outputs.loss

            # LABEL 2: Backward Pass
            # EXPECTATION: Pink AllGather bars INSIDE this block (Overlap)
            with torch.profiler.record_function("## Backward Pass ##"):
                model_engine.backward(loss)

            # LABEL 3: Optimizer
            with torch.profiler.record_function("## Optimizer Step ##"):
                model_engine.step()

            # Read the loss once per step; each .item() forces a GPU sync.
            step_loss = loss.item()
            total_loss_rank += step_loss

            # --- PRO LOGGING ---
            if rank == 0:
                # 1. Log Loss
                wandb.log({"step_loss": step_loss})

                # 2. Log System Stats (every 10 steps)
                if i % 10 == 0:
                    mem = torch.cuda.memory_allocated() / 1e9
                    tput = _avg_samples_per_sec(model_engine)

                    wandb.log({
                        "System/Memory_Allocated_GB": mem,
                        "Performance/Throughput": tput
                    })
                    print(f"  Epoch {epoch + 1}, Step {i}: Loss = {step_loss:.4f} | Mem: {mem:.2f}GB")

    # Sync Loss for epoch average (collective: every rank must take part).
    total_loss_all_ranks = torch.tensor(total_loss_rank, dtype=torch.float, device=device)
    dist.all_reduce(total_loss_all_ranks, op=dist.ReduceOp.SUM)

    if rank == 0:
        # Sum of per-step losses over all ranks divided by total step count.
        avg_train_loss = total_loss_all_ranks.item() / (len(train_loader) * world_size)
        # Time this epoch alone rather than since the start of training.
        epoch_duration = time.time() - epoch_start_time
        print(f"**Epoch {epoch + 1}/{num_epochs} - Avg. Training Loss: {avg_train_loss:.4f} (Duration: {epoch_duration:.2f}s)**")
        wandb.log({"epoch": epoch + 1, "avg_train_loss": avg_train_loss, "epoch_duration_sec": epoch_duration})

if rank == 0:
    total_time = time.time() - start_time
    wandb.log({"total_training_time_sec": total_time})

# --- 8. Evaluation ---
if rank == 0:
    print("--- Starting evaluation ---")
model_engine.eval()
rank_preds, rank_labels = [], []

# Inference only: each rank scores its own slice of the test set.
with torch.no_grad():
    for inputs, labels in test_loader:
        logits = model_engine(inputs.to(device)).logits
        # Index 1 of torch.max(..., dim) holds the argmax class ids.
        predicted = torch.max(logits.data, 1)[1]
        rank_preds.extend(predicted.cpu().numpy())
        rank_labels.extend(labels.cpu().numpy())

# One slot per rank. Both gathers are collectives, so every rank must call
# them, and in this order.
all_preds = [None for _ in range(world_size)]
all_labels = [None for _ in range(world_size)]
dist.all_gather_object(all_preds, rank_preds)
dist.all_gather_object(all_labels, rank_labels)

if rank == 0:
    print("[Rank 0] All results gathered. Computing final accuracy...")
    # Concatenate the per-rank lists; predictions and labels stay aligned
    # because both are flattened in the same rank order.
    flat_preds, flat_labels = [], []
    for per_rank_preds in all_preds:
        flat_preds.extend(per_rank_preds)
    for per_rank_labels in all_labels:
        flat_labels.extend(per_rank_labels)

    matches = torch.tensor(flat_preds) == torch.tensor(flat_labels)
    correct = matches.sum().item()
    accuracy = 100 * correct / len(flat_preds)
    print(f"**Final Test Accuracy: {accuracy:.2f}%**")

    wandb.log({"final_test_accuracy": accuracy})
    class_names = trainset.classes
    wandb_cm = wandb.plot.confusion_matrix(
        preds=flat_preds,
        y_true=flat_labels,
        class_names=class_names,
    )
    wandb.log({"confusion_matrix": wandb_cm})
    wandb.finish()

print(f"--- [Rank {rank}] Run complete ---")

Writing train_deepspeed_stage3_vit_huge_bs16.py


In [None]:
!deepspeed --num_gpus=2 train_deepspeed_stage3_vit_huge_bs16.py

2025-11-24 02:31:05.280645: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763951465.304501     773 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763951465.312573     773 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
[2025-11-24 02:31:10,817] [INFO] [runner.py:630:main] cmd = /usr/bin/python3 

## Profiling of the processes

### DeepSpeed full fine-tune profiling script: ZeRO Stage 2 on 2× T4 GPUs

In [None]:
%%writefile train_deepspeed_stage2_profiler.py
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTImageProcessor
import deepspeed
import warnings
import os
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import torch.profiler # Import Profiler

# Silence all Python warnings so the profiling log stays readable.
warnings.filterwarnings("ignore")
print("--- Initializing DEEPSPEED ZeRO Stage 2 (PROFILING RUN - LABELED) ---")

# --- 1. DDP/DeepSpeed Setup ---
# Process topology comes from the environment; a missing variable raises KeyError.
local_rank, rank, world_size = (
    int(os.environ[env_key]) for env_key in ('LOCAL_RANK', 'RANK', 'WORLD_SIZE')
)

# Make this process's local GPU the default CUDA device.
torch.cuda.set_device(local_rank)

# --- 2. Model Setup (ViT-Huge) ---
if rank == 0:
    print("Loading pre-trained ViT-Huge model...")
model = ViTForImageClassification.from_pretrained(
    'google/vit-huge-patch14-224-in21k',
    num_labels=101,
    ignore_mismatched_sizes=True,
)

# Enable Gradient Checkpointing
model.gradient_checkpointing_enable()

# --- 3. DeepSpeed Initialization ---
# The Stage 2 JSON config is not written by this script and must already
# exist in the working directory.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params='deepspeed_config_stage2_vit_huge_bs16.json',
)

# --- 4. Data Prep (Food-101) ---
processor = ViTImageProcessor.from_pretrained('google/vit-huge-patch14-224-in21k')
_augmentations = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
]
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
]
train_transform = transforms.Compose(_augmentations + _tensor_steps)


def _open_trainset(download):
    """Build the Food-101 training split rooted at ./data."""
    return torchvision.datasets.Food101(
        root='./data', split="train", download=download, transform=train_transform
    )


# Rank 0 downloads first; the other ranks wait at the barrier, then open
# the dataset with download disabled.
if rank == 0:
    print("Loading dataset...")
    trainset = _open_trainset(download=True)
dist.barrier()
if rank != 0:
    trainset = _open_trainset(download=False)

# --- 5. Data Samplers ---
MICRO_BATCH_SIZE = 16
train_sampler = DistributedSampler(
    trainset, num_replicas=world_size, rank=rank, shuffle=True
)
train_loader = DataLoader(
    trainset,
    batch_size=MICRO_BATCH_SIZE,
    sampler=train_sampler,
    num_workers=2,
    pin_memory=True,
)

# --- 6. Define Profiler Handler ---
def trace_handler(prof):
    """Export the finished profiler trace to ./logs/stage2_trace (rank 0 only).

    The file uses the ``.pt.trace.json`` suffix because TensorBoard's PyTorch
    Profiler plugin only discovers traces with that ending; a plain
    ``*_trace.json`` file under the log directory is ignored by it. The file
    is still a regular Chrome trace and opens in chrome://tracing or Perfetto.

    Args:
        prof: The profiler object passed to the ``on_trace_ready`` callback;
            only its ``export_chrome_trace(path)`` method is used.
    """
    if rank != 0:
        return
    trace_dir = "./logs/stage2_trace"
    print("Profiler trace ready. Saving to ./logs/stage2_trace ...")
    os.makedirs(trace_dir, exist_ok=True)
    prof.export_chrome_trace(f"{trace_dir}/rank{rank}.pt.trace.json")
    print("Trace saved.")

# --- 7. Training Loop with Profiler & LABELS ---
device = model_engine.device
model_engine.train()

print("Starting profiling run...")

# Short schedule: Wait 1, Warmup 1, Active 3 (Just enough to get clear data)
profile_schedule = torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1)
profiler_session = torch.profiler.profile(
    schedule=profile_schedule,
    on_trace_ready=trace_handler,
    record_shapes=True,
    with_stack=True,
)

with profiler_session as prof:
    for i, batch in enumerate(train_loader):
        # Six steps cover the five scheduled profiler steps plus one extra.
        if i >= 6:
            break

        # Top-level label wrapping the whole step in the trace.
        with torch.profiler.record_function(f"## Training Step {i} ##"):
            inputs, labels = (tensor.to(device) for tensor in batch)

            # --- LABEL 1: FORWARD PASS ---
            # In Stage 3, look for AllGather INSIDE this block.
            # In Stage 2, it should be clean (mostly compute).
            with torch.profiler.record_function("## Forward Pass ##"):
                outputs = model_engine(inputs, labels=labels)
                loss = outputs.loss

            # --- LABEL 2: BACKWARD PASS ---
            # In Stage 2, look for ReduceScatter at the END of this block (blocking).
            with torch.profiler.record_function("## Backward Pass ##"):
                model_engine.backward(loss)

            # --- LABEL 3: OPTIMIZER ---
            with torch.profiler.record_function("## Optimizer Step ##"):
                model_engine.step()

            if rank == 0:
                print(f"Step {i}: Loss = {loss.item():.4f}")

        # Advance the profiler schedule once per training step.
        prof.step()

print("Profiling complete.")

In [None]:
!deepspeed --num_gpus=2 train_deepspeed_stage2_profiler.py

### DeepSpeed full fine-tune profiling script: ZeRO Stage 3 on 2× T4 GPUs

In [None]:
%%writefile train_deepspeed_stage3_profiler.py
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTImageProcessor
import deepspeed
import warnings
import os
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import torch.profiler # Import Profiler

# Silence all Python warnings so the profiling log stays readable.
warnings.filterwarnings("ignore")
print("--- Initializing DEEPSPEED ZeRO Stage 3 (PROFILING RUN - LABELED) ---")

# --- 1. DDP/DeepSpeed Setup ---
# Process topology comes from the environment; a missing variable raises KeyError.
local_rank, rank, world_size = (
    int(os.environ[env_key]) for env_key in ('LOCAL_RANK', 'RANK', 'WORLD_SIZE')
)

# Make this process's local GPU the default CUDA device.
torch.cuda.set_device(local_rank)

# --- 2. Model Setup (ViT-Huge) ---
if rank == 0:
    print("Loading pre-trained ViT-Huge model...")
model = ViTForImageClassification.from_pretrained(
    'google/vit-huge-patch14-224-in21k',
    num_labels=101,
    ignore_mismatched_sizes=True,
)

# Enable Gradient Checkpointing
model.gradient_checkpointing_enable()

# --- 3. DeepSpeed Initialization ---
# Note: Using Stage 3 Config
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params='deepspeed_config_stage3_vit_huge_bs16.json',
)

# --- 4. Data Prep (Food-101) ---
processor = ViTImageProcessor.from_pretrained('google/vit-huge-patch14-224-in21k')
_augmentations = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
]
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
]
train_transform = transforms.Compose(_augmentations + _tensor_steps)


def _open_trainset(download):
    """Build the Food-101 training split rooted at ./data."""
    return torchvision.datasets.Food101(
        root='./data', split="train", download=download, transform=train_transform
    )


# Rank 0 downloads first; the other ranks wait at the barrier, then open
# the dataset with download disabled.
if rank == 0:
    print("Loading dataset...")
    trainset = _open_trainset(download=True)

dist.barrier()

if rank != 0:
    trainset = _open_trainset(download=False)

# --- 5. Samplers ---
MICRO_BATCH_SIZE = 16
train_sampler = DistributedSampler(
    trainset, num_replicas=world_size, rank=rank, shuffle=True
)
train_loader = DataLoader(
    trainset,
    batch_size=MICRO_BATCH_SIZE,
    sampler=train_sampler,
    num_workers=2,
    pin_memory=True,
)

# --- 6. Define Profiler Handler ---
def trace_handler(prof):
    """Export the finished profiler trace to ./logs/stage3_trace (rank 0 only).

    The file uses the ``.pt.trace.json`` suffix because TensorBoard's PyTorch
    Profiler plugin only discovers traces with that ending; a plain
    ``*_trace.json`` file under the log directory is ignored by it. The file
    is still a regular Chrome trace and opens in chrome://tracing or Perfetto.

    Args:
        prof: The profiler object passed to the ``on_trace_ready`` callback;
            only its ``export_chrome_trace(path)`` method is used.
    """
    if rank != 0:
        return
    trace_dir = "./logs/stage3_trace"
    print("Profiler trace ready. Saving to ./logs/stage3_trace ...")
    os.makedirs(trace_dir, exist_ok=True)
    prof.export_chrome_trace(f"{trace_dir}/rank{rank}.pt.trace.json")
    print("Trace saved.")

# --- 7. Training Loop with Profiler & LABELS ---
device = model_engine.device
model_engine.train()

print("Starting profiling run...")

# Schedule: skip 1 step, warm up for 1, record 3, and do that cycle once.
profile_schedule = torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1)
profiler_session = torch.profiler.profile(
    schedule=profile_schedule,
    on_trace_ready=trace_handler,
    record_shapes=True,
    with_stack=True,
)

with profiler_session as prof:
    for i, batch in enumerate(train_loader):
        # Stop early! We only need enough steps for the schedule
        if i >= 6:
            break

        # Top-level label wrapping the whole step in the trace.
        with torch.profiler.record_function(f"## Training Step {i} ##"):
            inputs, labels = (tensor.to(device) for tensor in batch)

            # --- LABEL 1: FORWARD PASS ---
            # PROOF POINT: In Stage 3, you will see communication INSIDE this block
            with torch.profiler.record_function("## Forward Pass ##"):
                outputs = model_engine(inputs, labels=labels)
                loss = outputs.loss

            # --- LABEL 2: BACKWARD PASS ---
            # PROOF POINT: In Stage 3, you will see communication INSIDE this block
            with torch.profiler.record_function("## Backward Pass ##"):
                model_engine.backward(loss)

            # --- LABEL 3: OPTIMIZER ---
            with torch.profiler.record_function("## Optimizer Step ##"):
                model_engine.step()

            if rank == 0:
                print(f"Step {i}: Loss = {loss.item():.4f}")

        # Advance the profiler schedule once per training step.
        prof.step()

print("Profiling complete.")

In [None]:
!deepspeed --num_gpus=2 train_deepspeed_stage3_profiler.py

### Checking the results

In [None]:
!pip install -q tensorboard

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./logs

In [None]:
%tensorboard --logdir ./logs


In [None]:

!pkill tensorboard


In [None]:
!zip -r my_profiler_logs.zip ./logs