# =====================================================================
# üìå CELL 1 ‚Äî Install dependencies (RUN THIS FIRST IN COLAB)
# =====================================================================


In [1]:
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install timm bitsandbytes flash-attn --no-build-isolation
!pip install matplotlib pandas


Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting flash-attn
  Downloading flash_attn-2.8.3.tar.gz (8.4 MB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m8.4/8.4 MB[0m [31m70.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m59.4/59.4 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.8.3-cp312-cp312-linux_x86_64.whl size=256040057 sha25

# =====================================================================
# üìå CELL 2 ‚Äî Imports, device setup, Tiny-ImageNet paths
# =====================================================================


In [8]:
import os
import time
import zipfile
import urllib.request
from pathlib import Path
from statistics import mean
from typing import Dict, Any, List

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T

import timm
import pandas as pd
import matplotlib.pyplot as plt

# Check flash-attn
try:
    from flash_attn import flash_attn_func
    FLASH_ATTENTION_AVAILABLE = True
    print("‚úÖ FlashAttention-2 available.")
except Exception as e:
    FLASH_ATTENTION_AVAILABLE = False
    print("‚ùå FlashAttention-2 NOT available:", e)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

SEED = 42
torch.manual_seed(SEED)
if device.type == "cuda":
    torch.cuda.manual_seed_all(SEED)

DATA_ROOT = "/content/tiny-imagenet-200"
RESULTS_DIR = Path("/content/results_vit_quant_fa2")
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

BATCH_SIZE = 64
NUM_WORKERS = 2
IMAGE_SIZE = 224
MAX_VAL_BATCHES = 10   # small subset for Colab speed


‚úÖ FlashAttention-2 available.
Using device: cuda


# =====================================================================
# üìå CELL 3 ‚Äî Download + reorganize Tiny-ImageNet
# =====================================================================


In [9]:
if not os.path.exists(DATA_ROOT):
    print("üì• Downloading Tiny-ImageNet...")
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    zip_path = "/content/tiny-imagenet-200.zip"
    urllib.request.urlretrieve(url, zip_path)
    print("Extracting...")
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall("/content")
    print("Done.")
else:
    print("Tiny-ImageNet already present.")

# Fix val/ to ImageFolder format
val_dir = Path(DATA_ROOT) / "val"
images_dir = val_dir / "images"
ann_path = val_dir / "val_annotations.txt"

if images_dir.exists():
    print("üîß Organizing val/...")
    import shutil
    with open(ann_path, "r") as f:
        for line in f:
            img, wnid = line.strip().split("\t")[:2]
            cls_dir = val_dir / wnid
            cls_dir.mkdir(exist_ok=True)
            src = images_dir / img
            dst = cls_dir / img
            if src.exists():
                shutil.move(str(src), str(dst))
    print("val/ reorganized.")
else:
    print("val/ already organized.")

# ... your download + reorg code above ...

# If 'images' directory is empty, remove it so ImageFolder doesn't treat it as a class
if images_dir.exists():
    import shutil
    if not any(images_dir.glob("*")):
        print("Removing empty 'val/images' directory to avoid ImageFolder error.")
        shutil.rmtree(images_dir)

# Validation loader
val_transforms = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

val_dataset = torchvision.datasets.ImageFolder(
    root=os.path.join(DATA_ROOT, "val"),
    transform=val_transforms,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

num_classes = len(val_dataset.classes)
print("Val samples:", len(val_dataset), "| Classes:", num_classes)

Tiny-ImageNet already present.
val/ already organized.
Val samples: 10000 | Classes: 200


# =====================================================================
# üìå CELL 4 ‚Äî Accuracy + latency utilities
# =====================================================================


In [10]:
def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


@torch.no_grad()
def evaluate_accuracy(model, loader, max_batches=None):
    model.eval()
    top1 = top5 = total = 0
    use_half = next(model.parameters()).dtype == torch.float16

    for idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        if use_half:
            x = x.half()

        out = model(x)
        _, p1 = out.topk(1, dim=1)
        _, p5 = out.topk(5, dim=1)

        total += y.size(0)
        top1 += (p1.squeeze() == y).sum().item()
        top5 += (p5 == y.unsqueeze(1)).any(dim=1).sum().item()

        if max_batches and (idx + 1) >= max_batches:
            break

    return top1 / total * 100.0, top5 / total * 100.0


@torch.no_grad()
def measure_latency(model, input_shape=(1,3,224,224), warmup=10, iters=50):
    model.eval()
    dummy = torch.randn(input_shape, device=device)
    if next(model.parameters()).dtype == torch.float16:
        dummy = dummy.half()

    for _ in range(warmup):
        model(dummy)
    if device.type == "cuda":
        torch.cuda.synchronize()

    times = []
    for _ in range(iters):
        if device.type == "cuda":
            torch.cuda.synchronize()
        t0 = time.time()
        model(dummy)
        if device.type == "cuda":
            torch.cuda.synchronize()
        times.append((time.time() - t0) * 1000)

    times_s = sorted(times)
    return {
        "mean_ms": mean(times),
        "p50_ms": times_s[int(0.5*len(times))],
        "p95_ms": times_s[int(0.95*len(times))-1],
        "p99_ms": times_s[int(0.99*len(times))-1],
    }


# =====================================================================
# üìå CELL 5 ‚Äî FlashAttention-2 wrapper for timm ViT
# =====================================================================


In [17]:
class FlashAttentionVit(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.qkv = base.qkv
        self.proj = base.proj
        self.num_heads = base.num_heads
        self.scale = base.scale
        self.attn_drop = getattr(base, "attn_drop", nn.Identity())
        self.proj_drop = getattr(base, "proj_drop", nn.Identity())

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C//self.num_heads)
        qkv = qkv.permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q.permute(0,2,1,3)
        k = k.permute(0,2,1,3)
        v = v.permute(0,2,1,3)

        q = q.half()
        k = k.half()
        v = v.half()

        out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
        out = out.permute(0,2,1,3).reshape(B,N,C)
        out = self.proj(out)
        return self.proj_drop(out)


def apply_flash_attention_to_vit(model):
    # Guard: require CUDA + Ampere (SM80) or newer, and flash_attn installed
    if device.type != "cuda":
        print("‚ö†Ô∏è No CUDA device ‚Äî skipping FlashAttention.")
        return model

    if not FLASH_ATTENTION_AVAILABLE:
        print("‚ö†Ô∏è FA2 not available ‚Äî skipping.")
        return model

    try:
        major, _ = torch.cuda.get_device_capability()
    except Exception:
        major = 0

    if major < 8:
        print("‚ö†Ô∏è FlashAttention requires Ampere (SM80) or newer GPU ‚Äî skipping.")
        return model

    model = model.to(device).half()
    replaced = 0

    def recurse(m):
        nonlocal replaced
        for name, child in list(m.named_children()):
            if child.__class__.__name__ == "Attention":
                setattr(m, name, FlashAttentionVit(child))
                replaced += 1
            else:
                recurse(child)

    recurse(model)
    print(f"‚ú® FlashAttention integrated into {replaced} blocks.")
    return model


# =====================================================================
# üìå CELL 6 ‚Äî bitsandbytes quantization + model builders
# =====================================================================


In [18]:
import bitsandbytes as bnb

def quantize_linear_8bit(m: nn.Linear):
    q = bnb.nn.Linear8bitLt(m.in_features, m.out_features, bias=m.bias is not None)
    q.weight.data = m.weight.data.clone()
    if m.bias is not None:
        q.bias.data = m.bias.data.clone()
    return q

def quantize_linear_4bit(m: nn.Linear):
    q = bnb.nn.Linear4bit(
        m.in_features, m.out_features,
        bias=m.bias is not None,
        quant_type="nf4",
        compute_dtype=torch.float16,
    )
    q.weight.data = m.weight.data.clone()
    if m.bias is not None:
        q.bias.data = m.bias.data.clone()
    return q

def _quantize_block(block, bits):
    for name, child in list(block.named_children()):
        if isinstance(child, nn.Linear):
            new = quantize_linear_4bit(child) if bits==4 else quantize_linear_8bit(child)
            setattr(block, name, new)
        else:
            _quantize_block(child, bits)

def apply_quantization_to_vit(model, bits):
    for blk in model.blocks:
        _quantize_block(blk, bits)
    print(f"üßä Quantized ViT blocks to {bits}-bit.")
    return model

def build_vit_baseline(use_half=False):
    m = timm.create_model("vit_large_patch16_224", pretrained=True, num_classes=num_classes)
    m.to(device)
    return m.half() if use_half else m

def build_vit_quantized(bits, use_fa2):
    m = build_vit_baseline(use_half=use_fa2)
    m = apply_quantization_to_vit(m, bits)
    if use_fa2:
        m = apply_flash_attention_to_vit(m)
    return m.to(device)

def get_model_registry():
    return {
        "vit_fp32_baseline": {
            "desc": "FP32 baseline",
            "builder": lambda: build_vit_baseline(use_half=False),
            "bits": None,
            "fa2": False,
        },
        "vit_4bit_fa2": {
            "desc": "4-bit + FlashAttention-2",
            "builder": lambda: build_vit_quantized(4, True),
            "bits": 4,
            "fa2": True,
        },
        "vit_8bit_fa2": {
            "desc": "8-bit + FlashAttention-2",
            "builder": lambda: build_vit_quantized(8, True),
            "bits": 8,
            "fa2": True,
        },
        "vit_4bit_sdpa": {
            "desc": "4-bit SDPA",
            "builder": lambda: build_vit_quantized(4, False),
            "bits": 4,
            "fa2": False,
        },
        "vit_8bit_sdpa": {
            "desc": "8-bit SDPA",
            "builder": lambda: build_vit_quantized(8, False),
            "bits": 8,
            "fa2": False,
        },
    }


# =====================================================================
# üìå CELL 7 ‚Äî Stage 1 Benchmark (ALL 5 MODELS)
# =====================================================================


In [19]:
registry = get_model_registry()
results = []

for name, cfg in registry.items():
    print("="*80)
    print("MODEL:", name, "|", cfg["desc"])
    print("="*80)

    model = cfg["builder"]()
    print("Params:", count_parameters(model))

    top1, top5 = evaluate_accuracy(model, val_loader, MAX_VAL_BATCHES)
    print(f"Accuracy ‚Äî Top-1: {top1:.2f}% | Top-5: {top5:.2f}%")

    lat = measure_latency(model)
    print("Latency (ms):", lat)

    results.append({
        "model": name,
        "desc": cfg["desc"],
        "bits": cfg["bits"],
        "fa2": cfg["fa2"],
        "top1": top1,
        "top5": top5,
        "lat_mean_ms": lat["mean_ms"],
        "lat_p50": lat["p50_ms"],
        "lat_p95": lat["p95_ms"],
        "lat_p99": lat["p99_ms"],
    })

    del model
    torch.cuda.empty_cache()

df = pd.DataFrame(results)
df


MODEL: vit_fp32_baseline | FP32 baseline
Params: 303506632
Accuracy ‚Äî Top-1: 0.31% | Top-5: 1.41%
Latency (ms): {'mean_ms': 58.863863945007324, 'p50_ms': 56.79464340209961, 'p95_ms': 66.55097007751465, 'p99_ms': 71.10881805419922}
MODEL: vit_4bit_fa2 | 4-bit + FlashAttention-2
üßä Quantized ViT blocks to 4-bit.
‚ö†Ô∏è FlashAttention requires Ampere (SM80) or newer GPU ‚Äî skipping.
Params: 152511688
Accuracy ‚Äî Top-1: 0.78% | Top-5: 3.59%
Latency (ms): {'mean_ms': 35.03471851348877, 'p50_ms': 34.30628776550293, 'p95_ms': 38.0706787109375, 'p99_ms': 44.88730430603027}
MODEL: vit_8bit_fa2 | 8-bit + FlashAttention-2
üßä Quantized ViT blocks to 8-bit.
‚ö†Ô∏è FlashAttention requires Ampere (SM80) or newer GPU ‚Äî skipping.
Params: 303506632
Accuracy ‚Äî Top-1: 2.50% | Top-5: 6.56%
Latency (ms): {'mean_ms': 92.11071014404297, 'p50_ms': 68.52865219116211, 'p95_ms': 220.32570838928223, 'p99_ms': 239.84813690185547}
MODEL: vit_4bit_sdpa | 4-bit SDPA
üßä Quantized ViT blocks to 4-bit.
Para



Accuracy ‚Äî Top-1: 0.00% | Top-5: 1.25%
Latency (ms): {'mean_ms': 60.82291126251221, 'p50_ms': 60.050010681152344, 'p95_ms': 66.54858589172363, 'p99_ms': 70.20211219787598}


Unnamed: 0,model,desc,bits,fa2,top1,top5,lat_mean_ms,lat_p50,lat_p95,lat_p99
0,vit_fp32_baseline,FP32 baseline,,False,0.3125,1.40625,58.863864,56.794643,66.55097,71.108818
1,vit_4bit_fa2,4-bit + FlashAttention-2,4.0,True,0.78125,3.59375,35.034719,34.306288,38.070679,44.887304
2,vit_8bit_fa2,8-bit + FlashAttention-2,8.0,True,2.5,6.5625,92.11071,68.528652,220.325708,239.848137
3,vit_4bit_sdpa,4-bit SDPA,4.0,False,3.59375,9.21875,42.054234,40.718079,48.508883,53.93815
4,vit_8bit_sdpa,8-bit SDPA,8.0,False,0.0,1.25,60.822911,60.050011,66.548586,70.202112


# =====================================================================
# üìå CELL 8 ‚Äî Save results CSV
# =====================================================================


In [20]:
csv_path = "/content/vit_quant_fa2_stage1_results.csv"
df.to_csv(csv_path, index=False)
print("Saved:", csv_path)


Saved: /content/vit_quant_fa2_stage1_results.csv
