In [None]:
# Clone the DiT repository
!git clone https://github.com/facebookresearch/DiT.git
%cd DiT

# Install xformers
!pip install -q xformers
!pip install --upgrade torchvision


Cloning into 'DiT'...
remote: Enumerating objects: 102, done.[K
remote: Counting objects: 100% (82/82), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 102 (delta 57), reused 33 (delta 33), pack-reused 20 (from 1)[K
Receiving objects: 100% (102/102), 6.36 MiB | 27.04 MiB/s, done.
Resolving deltas: 100% (57/57), done.
/kaggle/working/DiT
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.3/43.3 MB[0m [31m34.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m766.7/766.7 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m87.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[

# Efficient Transformer Attention Implementation with XFormers

This code implements an efficient attention mechanism using the XFormers library to optimize memory usage in transformer models. Here's what it does:

## Imports and Dependencies
- Imports PyTorch and neural network modules
- Imports DiT model variants from a local `models` module
- Attempts to import XFormers' memory-efficient attention operation (raising an error if not installed)

## XFormersAttention Class
The class implements a drop-in replacement for standard attention that:
- Maintains the same API as regular attention modules
- Takes hidden size, number of heads, and dropout rate parameters
- Projects input to query, key, and value representations
- Applies proper scaling to query vectors before attention computation
- Uses XFormers' memory-efficient attention implementation
- Projects attention outputs back to the original dimensionality

## Key Optimizations
- Uses memory-efficient attention for better GPU utilization
- Includes critical scaling factor for stable training
- Maintains dimensional transformations needed for compatibility with DiT models

The comment highlights a previously missing scaling operation that's now properly applied to query vectors before the attention mechanism.

In [None]:
import torch
import torch.nn as nn
import time

from models import DiT, DiT_B_4  , DiT_XL_4

# Import xformers memory efficient attention
try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    raise ImportError("xformers is not installed.")

# Define the XFormersAttention module that mimics the original Attention API
class XFormersAttention(nn.Module):
    """Multi-head self-attention backed by xFormers' memory-efficient kernel.

    The module fuses the query/key/value projections into a single linear
    layer (``qkv``), runs ``memory_efficient_attention`` over the token axis,
    and projects the result back with ``out_proj``.

    Args:
        hidden_size: Channel dimension ``C`` of the input tokens.
        num_heads: Number of attention heads; must divide ``hidden_size``.
        dropout: Attention dropout probability, applied inside the kernel
            while the module is in training mode.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super().__init__()
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})."
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Softmax scale 1/sqrt(head_dim); handed to the kernel in forward().
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; returns ``(B, N, C)``."""
        B, N, C = x.shape

        # memory_efficient_attention expects [batch, seq_len, heads, head_dim],
        # so the head axis must stay *after* the token axis. Moving heads in
        # front of tokens would make the kernel attend across heads instead
        # of across tokens.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)  # each: (B, N, num_heads, head_dim)

        # The kernel applies the softmax scale itself, so the queries must
        # not be pre-multiplied (that would apply the scale twice).
        attn_output = memory_efficient_attention(
            q,
            k,
            v,
            p=self.dropout.p if self.training else 0.0,
            scale=self.scale,
        )

        # Output keeps the (B, N, num_heads, head_dim) layout: merge the heads.
        attn_output = attn_output.reshape(B, N, C)
        return self.out_proj(attn_output)

## Cell to replace the model's attention modules with the xformers attention block


In [None]:
def replace_attention_with_xformers(model):
    """
    Replaces the attention module in each DiTBlock of the model with the XFormersAttention version.

    The model is modified in place. For every block in ``model.blocks``:

    * blocks whose ``attn`` is already an ``XFormersAttention`` are skipped,
      so calling this function twice is harmless;
    * the fused QKV and output-projection weights of the original module are
      copied into the replacement when the original exposes them as ``qkv``
      and ``proj`` linear layers (assumed to be the layout of the attention
      module used by DiT -- TODO confirm against models.py); otherwise the
      replacement keeps its fresh initialisation;
    * the replacement is moved to the device/dtype of the original module and
      inherits its train/eval mode.

    Args:
        model: A model exposing ``blocks``, each with ``norm1`` (a LayerNorm)
            and ``attn`` (with a ``num_heads`` attribute).
    """
    for block in model.blocks:
        original_attn = block.attn
        if isinstance(original_attn, XFormersAttention):
            continue  # Already converted.

        hidden_size = block.norm1.normalized_shape[0]
        num_heads = original_attn.num_heads  # Using the original number of heads
        # Replace with our XFormersAttention (dropout is set to 0.0; adjust if needed)
        new_attn = XFormersAttention(hidden_size, num_heads, dropout=0.0)

        # Carry the trained weights over so both models compute the same function.
        src_qkv = getattr(original_attn, "qkv", None)
        src_proj = getattr(original_attn, "proj", None)
        if isinstance(src_qkv, nn.Linear) and isinstance(src_proj, nn.Linear):
            with torch.no_grad():
                for dst, src in ((new_attn.qkv, src_qkv), (new_attn.out_proj, src_proj)):
                    dst.weight.copy_(src.weight)
                    if src.bias is None:
                        # Source layer has no bias: a zero bias is equivalent.
                        dst.bias.zero_()
                    else:
                        dst.bias.copy_(src.bias)

        # Keep the replacement on the same device/dtype as the module it replaces.
        reference = next(original_attn.parameters(), None)
        if reference is not None:
            new_attn.to(device=reference.device, dtype=reference.dtype)
        new_attn.train(original_attn.training)

        block.attn = new_attn


## Benchmarking function to compare the baseline model and the xformers model

In [None]:
def compare_attention_speed(model, device='cuda', num_images=50, num_warmup=10, num_repeats=100):
    """Measure the average forward-pass time of ``model`` on random inputs.

    The model is moved to ``device``, put in eval mode, and called as
    ``model(x, t, y)`` with a random batch of shape
    ``(num_images, model.in_channels, 32, 32)`` plus random integer ``t`` and
    ``y`` vectors drawn from ``[0, 1000)``.

    Args:
        model: Module exposing ``in_channels`` and accepting ``(x, t, y)``.
        device: Target device (string or ``torch.device``).
        num_images: Batch size of the synthetic input.
        num_warmup: Untimed forward passes executed before measuring.
        num_repeats: Timed forward passes that are averaged.

    Returns:
        Average wall-clock seconds per forward pass.

    Raises:
        ValueError: If ``num_repeats`` is not positive.
    """
    if num_repeats <= 0:
        raise ValueError("num_repeats must be a positive integer.")

    use_cuda = torch.device(device).type == 'cuda'

    def _autocast():
        # Mixed precision is only used on CUDA; elsewhere the context is a no-op.
        return torch.autocast(device_type='cuda' if use_cuda else 'cpu', enabled=use_cuda)

    def _synchronize():
        # CUDA kernels run asynchronously; wait for them so the timer is accurate.
        if use_cuda:
            torch.cuda.synchronize()

    model = model.to(device).eval()
    x = torch.randn(num_images, model.in_channels, 32, 32, device=device)
    t = torch.randint(0, 1000, (num_images,), device=device)
    y = torch.randint(0, 1000, (num_images,), device=device)

    # Warm-up
    with torch.no_grad(), _autocast():
        for _ in range(num_warmup):
            _ = model(x, t, y)

    # Timed runs
    _synchronize()
    start = time.perf_counter()  # monotonic, high-resolution clock
    with torch.no_grad(), _autocast():
        for _ in range(num_repeats):
            _ = model(x, t, y)
    _synchronize()

    elapsed = (time.perf_counter() - start) / num_repeats
    print(f"Avg time per {num_images} images: {elapsed:.5f}s")
    return elapsed

In [None]:
import copy

# Pick the device first so every benchmark below runs on the same one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_baseline = DiT_XL_4()
# Create a deep copy for the xformers-modified model and replace its attention modules
model_xformers = copy.deepcopy(model_baseline)
replace_attention_with_xformers(model_xformers)

print("Baseline model timing:")
time_baseline = compare_attention_speed(model_baseline, device=device, num_images=50)

print("\nXFormers-based model timing:")
time_xformers = compare_attention_speed(model_xformers, device=device, num_images=50)

# Guard against a zero denominator when the measured time rounds to 0.
speedup = time_baseline / time_xformers if time_xformers > 0 else float('inf')
print(f"\nSpeedup: {speedup:.2f}x faster with xformers attention.")



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Avg time per 50 images: 0.25898s
Avg time per 50 images: 0.06053s
Baseline model timing:
Avg time per 50 images: 0.24864s

XFormers-based model timing:
Avg time per 50 images: 0.05422s

Speedup: 4.59x faster with xformers attention.


In [None]:
import copy

# Pick the device first so every benchmark below runs on the same one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_baseline = DiT_XL_4()

# Create a deep copy for the xformers-modified model and replace its attention modules.
# Rebuilding it here keeps the cell self-contained: it no longer depends on a
# `model_xformers` left over from a previously executed cell, and the copy is
# derived from the baseline instance created just above.
model_xformers = copy.deepcopy(model_baseline)
replace_attention_with_xformers(model_xformers)

print("Baseline model timing:")
time_baseline = compare_attention_speed(model_baseline, device=device, num_images=50)

print("\nXFormers-based model timing:")
time_xformers = compare_attention_speed(model_xformers, device=device, num_images=50)

# Guard against a zero denominator when the measured time rounds to 0.
speedup = time_baseline / time_xformers if time_xformers > 0 else float('inf')
print(f"\nSpeedup: {speedup:.2f}x faster with xformers attention.")



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Avg time per 50 images: 0.25127s
Avg time per 50 images: 0.06060s
Baseline model timing:
Avg time per 50 images: 0.25291s

XFormers-based model timing:
Avg time per 50 images: 0.05463s

Speedup: 4.63x faster with xformers attention.


In [None]:

import numpy as np
import pandas as pd
import plotly.express as px

# Test across different batch sizes
batch_sizes = [10, 50, 100, 200]
times_baseline, times_xformers = [], []

for bs in batch_sizes:
    print(f"\nTesting batch size: {bs}")
    times_baseline.append(compare_attention_speed(model_baseline, device=device, num_images=bs))
    times_xformers.append(compare_attention_speed(model_xformers, device=device, num_images=bs))

# Create DataFrame in long format: one row per (model, batch size) measurement.
# The batch sizes are repeated twice so they line up with the concatenated
# baseline + xformers timings.
df = pd.DataFrame({
    "Batch Size": np.tile(batch_sizes, 2),
    "Time (s)": times_baseline + times_xformers,
    "Model": ["Baseline"]*len(batch_sizes) + ["XFormers"]*len(batch_sizes)
})

# Interactive line plot
fig = px.line(
    df,
    x="Batch Size",
    y="Time (s)",
    color="Model",
    title="Inference Time vs Batch Size",
    markers=True,
    labels={"Time (s)": "Time (seconds)"},
    hover_data={"Time (s)": ":.4f"}
)

fig.update_layout(
    hovermode="x unified",
    xaxis_title="Number of Images",
    yaxis_title="Time (seconds)",
    template="plotly_white"
)

fig.show()


Testing batch size: 10



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Avg time per 10 images: 0.05103s
Avg time per 10 images: 0.01535s

Testing batch size: 50
Avg time per 50 images: 0.26693s
Avg time per 50 images: 0.05673s

Testing batch size: 100
Avg time per 100 images: 0.45293s
Avg time per 100 images: 0.11553s

Testing batch size: 200
Avg time per 200 images: 0.99126s
Avg time per 200 images: 0.22171s
