# V-JEPA 2 Demo Notebook for AMD ROCm GPUs

This tutorial demonstrates how to run V-JEPA 2 on AMD GPUs using ROCm. It covers:
- Loading models in vanilla PyTorch and HuggingFace
- Extracting video embeddings
- Predicting action classes
- AMD GPU-specific optimizations

For more details about the model, see https://github.com/facebookresearch/vjepa2.

**AMD GPU Environment:**
- ROCm support via PyTorch's `torch.cuda` API (ROCm builds of PyTorch expose AMD GPUs through the same CUDA-style interface)
- Multi-GPU support (8 AMD GPUs detected)
- Optimized for AMD Instinct MI210/MI250/MI300X series

## 1. Check AMD GPU Environment

First, let's verify that AMD GPUs are detected and accessible.

In [1]:
import torch
import os

# Report the PyTorch build and the GPUs it can see. AMD GPUs are queried
# through the torch.cuda namespace, so the same calls work on ROCm and CUDA.
banner = "=" * 60
gpu_available = torch.cuda.is_available()
gpu_count = torch.cuda.device_count()

print(banner)
print("AMD GPU Environment Check")
print(banner)
print(f"PyTorch version: {torch.__version__}")
print(f"ROCm/CUDA available: {gpu_available}")
print(f"Number of GPUs: {gpu_count}")

if not gpu_available:
    print("Warning: No GPUs detected. Running on CPU will be very slow.")
else:
    # One short summary per visible device.
    for gpu_index in range(gpu_count):
        gpu_props = torch.cuda.get_device_properties(gpu_index)
        memory_gib = gpu_props.total_memory / 1024**3
        print(f"  GPU {gpu_index}: {gpu_props.name}")
        print(f"    Memory: {memory_gib:.1f} GB")
        print(f"    Compute capability: {gpu_props.major}.{gpu_props.minor}")

# Show the ROCm-related environment variables (or 'not set' when absent).
print("\nROCm environment variables:")
rocm_vars = ['HSA_OVERRIDE_GFX_VERSION', 'PYTORCH_ROCM_ARCH', 'HIP_VISIBLE_DEVICES']
for env_name in rocm_vars:
    print(f"  {env_name}: {os.environ.get(env_name, 'not set')}")

print(banner)

# Pick the device used by the rest of the notebook: GPU when available, else CPU.
device = torch.device('cuda') if gpu_available else torch.device('cpu')
print(f"\nUsing device: {device}")

AMD GPU Environment Check
PyTorch version: 2.5.1+rocm6.2
ROCm/CUDA available: True
Number of GPUs: 8
  GPU 0: AMD Instinct MI300X VF
    Memory: 191.7 GB
    Compute capability: 9.4
  GPU 1: AMD Instinct MI300X VF
    Memory: 191.7 GB
    Compute capability: 9.4
  GPU 2: AMD Instinct MI300X VF
    Memory: 191.7 GB
    Compute capability: 9.4
  GPU 3: AMD Instinct MI300X VF
    Memory: 191.7 GB
    Compute capability: 9.4
  GPU 4: AMD Instinct MI300X VF
    Memory: 191.7 GB
    Compute capability: 9.4
  GPU 5: AMD Instinct MI300X VF
    Memory: 191.7 GB
    Compute capability: 9.4
  GPU 6: AMD Instinct MI300X VF
    Memory: 191.7 GB
    Compute capability: 9.4
  GPU 7: AMD Instinct MI300X VF
    Memory: 191.7 GB
    Compute capability: 9.4

ROCm environment variables:
  HSA_OVERRIDE_GFX_VERSION: not set
  PYTORCH_ROCM_ARCH: not set
  HIP_VISIBLE_DEVICES: not set

Using device: cuda


## 2. Import Libraries and Define Helper Functions

Import all necessary libraries and define the helper functions for loading models and processing videos.

In [2]:
import json
import subprocess
import time

import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader
from transformers import AutoVideoProcessor, AutoModel

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.attentive_pooler import AttentiveClassifier
from src.models.vision_transformer import vit_giant_xformers_rope

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

def load_pretrained_vjepa_pt_weights(model, pretrained_weights):
    """Load V-JEPA 2 encoder weights from a checkpoint into ``model``.

    The checkpoint is read onto the CPU with ``weights_only=True`` and its
    ``"encoder"`` entry is used as the state dict. Leading ``module.`` and
    ``backbone.`` wrapper prefixes are removed from every key, in any order
    and any number of times. Only *leading* prefixes are removed, so a
    parameter that lives under a submodule whose name merely ends in
    ``module`` or ``backbone`` (e.g. ``my_module.weight``) keeps its name.

    Args:
        model: Module that receives the weights through ``load_state_dict``.
        pretrained_weights: Checkpoint location passed to ``torch.load``.

    Returns:
        None. The ``load_state_dict`` result is printed, not returned.
    """
    wrapper_prefixes = ("module.", "backbone.")

    def strip_wrapper_prefixes(key):
        # Peel wrapper prefixes off the front of the key until none is left.
        stripped = True
        while stripped:
            stripped = False
            for prefix in wrapper_prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    stripped = True
        return key

    checkpoint = torch.load(pretrained_weights, weights_only=True, map_location="cpu")
    pretrained_dict = {strip_wrapper_prefixes(k): v for k, v in checkpoint["encoder"].items()}
    # strict=False: missing/unexpected keys are reported in `msg` instead of raising.
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print(f"Pretrained weights found at {pretrained_weights} and loaded with msg: {msg}")


def load_pretrained_vjepa_classifier_weights(model, pretrained_weights):
    """Load V-JEPA 2 attentive-probe classifier weights into ``model``.

    The checkpoint is read onto the CPU with ``weights_only=True`` and the
    first element of its ``"classifiers"`` entry is used as the state dict.
    Leading ``module.`` wrapper prefixes are removed from every key. Only a
    *leading* prefix is removed, so a parameter under a submodule whose name
    merely ends in ``module`` (e.g. ``cross_module.weight``) keeps its name.

    Args:
        model: Module that receives the weights through ``load_state_dict``.
        pretrained_weights: Checkpoint location passed to ``torch.load``.

    Returns:
        None. The ``load_state_dict`` result is printed, not returned.
    """
    wrapper_prefix = "module."

    def strip_wrapper_prefix(key):
        # Drop every leading occurrence of the wrapper prefix, nothing else.
        while key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        return key

    checkpoint = torch.load(pretrained_weights, weights_only=True, map_location="cpu")
    pretrained_dict = {strip_wrapper_prefix(k): v for k, v in checkpoint["classifiers"][0].items()}
    # strict=False: missing/unexpected keys are reported in `msg` instead of raising.
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print(f"Pretrained weights found at {pretrained_weights} and loaded with msg: {msg}")


def build_pt_video_transform(img_size):
    """Compose the evaluation preprocessing pipeline for the PyTorch model.

    The pipeline applies, in order: ``Resize`` with a target of
    ``int(256 / 224 * img_size)`` and bilinear interpolation, a center crop
    to ``img_size x img_size``, ``ClipToTensor``, and normalization with the
    ImageNet default mean and standard deviation.

    Args:
        img_size: Side length of the square crop fed to the model.

    Returns:
        The composed ``video_transforms.Compose`` object.
    """
    # Keep the resize target at the 256:224 ratio relative to the final crop.
    short_side_size = int(256.0 / 224 * img_size)
    crop_shape = (img_size, img_size)

    pipeline_steps = [
        video_transforms.Resize(short_side_size, interpolation="bilinear"),
        video_transforms.CenterCrop(size=crop_shape),
        volume_transforms.ClipToTensor(),
        video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    return video_transforms.Compose(pipeline_steps)


def get_video(video_path="sample_video.mp4", num_frames=64, frame_step=2):
    """Load evenly spaced frames from a video file using decord.

    Frames are taken at indices ``0, frame_step, 2 * frame_step, ...`` until
    ``num_frames`` frames have been selected. With the defaults this is every
    2nd frame of the first 128 frames (64 frames in total), so the video must
    contain at least ``num_frames * frame_step - frame_step + 1`` frames.

    Args:
        video_path: Path of the video file to open.
        num_frames: Number of frames to sample.
        frame_step: Distance, in frames, between two sampled frames.

    Returns:
        The sampled frames as the NumPy array produced by
        ``VideoReader.get_batch(...).asnumpy()`` (callers in this notebook
        treat it as T x H x W x C).
    """
    vr = VideoReader(video_path)
    # Same indices as the previous hard-coded np.arange(0, 128, 2) by default.
    frame_idx = np.arange(0, num_frames * frame_step, frame_step)
    return vr.get_batch(frame_idx).asnumpy()


def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform, device):
    """Extract patch features from the sample video with both model variants.

    The video returned by ``get_video()`` is preprocessed once with each
    transform, then each model is run once under ``torch.inference_mode()``
    and timed. On a GPU device both models get one untimed warm-up pass first,
    so neither timing includes first-run overhead such as kernel compilation
    (previously only the PyTorch model was warmed up, which biased the
    comparison against the HuggingFace model).

    Args:
        model_hf: HuggingFace model exposing ``get_vision_features``.
        model_pt: PyTorch encoder, called directly on the input batch.
        hf_transform: HuggingFace video processor; its ``pixel_values_videos``
            output is used as the model input.
        pt_transform: Preprocessing callable for the PyTorch model.
        device: ``torch.device`` the inputs are moved to.

    Returns:
        Tuple ``(out_patch_features_hf, out_patch_features_pt)``.
    """
    on_gpu = device.type == 'cuda'

    with torch.inference_mode():
        # Read and pre-process the video
        video = get_video()  # T x H x W x C
        video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
        x_pt = pt_transform(video).to(device).unsqueeze(0)
        x_hf = hf_transform(video, return_tensors="pt")["pixel_values_videos"].to(device)

        # Warm-up run (important for AMD GPUs due to kernel compilation).
        # Both models are warmed up so the two timings are comparable.
        if on_gpu:
            print("Warming up AMD GPUs (kernel compilation)...")
            _ = model_pt(x_pt)
            _ = model_hf.get_vision_features(x_hf)
            torch.cuda.synchronize()

        # Timed inference. perf_counter is a monotonic, high-resolution clock;
        # the synchronize calls make sure queued GPU work has finished before
        # the clock is read.
        start_time = time.perf_counter()
        out_patch_features_pt = model_pt(x_pt)
        if on_gpu:
            torch.cuda.synchronize()
        pt_time = time.perf_counter() - start_time

        start_time = time.perf_counter()
        out_patch_features_hf = model_hf.get_vision_features(x_hf)
        if on_gpu:
            torch.cuda.synchronize()
        hf_time = time.perf_counter() - start_time

        print(f"PyTorch model inference time: {pt_time * 1000:.2f} ms")
        print(f"HuggingFace model inference time: {hf_time * 1000:.2f} ms")

    return out_patch_features_hf, out_patch_features_pt


def get_vjepa_video_classification_results(classifier, out_patch_features_pt, device):
    """Classify extracted features and print the top-5 class names.

    Class names are read from ``ssv2_classes.json`` in the current working
    directory, a JSON object keyed by the class index as a string.

    Args:
        classifier: Callable applied to ``out_patch_features_pt``; its output
            is indexed as ``[batch, class]`` and only batch item 0 is reported.
        out_patch_features_pt: Features passed to ``classifier``.
        device: Unused; kept so existing callers keep working.

    Returns:
        None. Results are printed.
    """
    # Use a context manager so the label file is closed deterministically.
    with open("ssv2_classes.json", "r", encoding="utf-8") as labels_file:
        SOMETHING_SOMETHING_V2_CLASSES = json.load(labels_file)

    with torch.inference_mode():
        out_classifier = classifier(out_patch_features_pt)

    print(f"Classifier output shape: {out_classifier.shape}")
    print("\nTop 5 predicted class names:")
    # Compute the top-5 once and reuse both its indices and its values.
    top5 = out_classifier.topk(5)
    top5_indices = top5.indices[0]
    # NOTE: the softmax is taken over the five top logits only, so the printed
    # percentages sum to 100% across the top-5 rather than across all classes.
    top5_probs = F.softmax(top5.values[0], dim=0) * 100.0
    for idx, prob in zip(top5_indices, top5_probs):
        str_idx = str(idx.item())
        print(f"  {SOMETHING_SOMETHING_V2_CLASSES[str_idx]}: {prob:.2f}%")

print("Helper functions loaded successfully!")

Helper functions loaded successfully!




## 3. Download Sample Video and Labels

Download a sample video and the Something-Something V2 action class labels.

In [3]:
sample_video_path = "sample_video.mp4"
ssv2_classes_path = "ssv2_classes.json"


def _download_with_wget(url, destination):
    """Download ``url`` to ``destination`` with wget, never leaving a partial file.

    ``wget -O`` creates its output file even when the transfer fails, which
    would make the "already downloaded" checks below succeed on the next run.
    The download therefore goes to ``<destination>.part`` and is renamed only
    after wget exits successfully.

    Raises:
        subprocess.CalledProcessError: If wget exits with a non-zero status.
    """
    partial_path = destination + ".part"
    try:
        subprocess.run(["wget", url, "-O", partial_path], check=True)
        os.replace(partial_path, destination)
    finally:
        # Still present only when the download or the rename did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# Download the video if not yet downloaded
if not os.path.exists(sample_video_path):
    print("Downloading sample video...")
    video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
    _download_with_wget(video_url, sample_video_path)
    print("Video downloaded successfully!")
else:
    print(f"Video already exists at {sample_video_path}")

# Download SSV2 classes if not already present
if not os.path.exists(ssv2_classes_path):
    print("Downloading SSV2 class labels...")
    ssv2_classes_url = (
        "https://huggingface.co/datasets/huggingface/label-files/resolve/d79675f2d50a7b1ecf98923d42c30526a51818e2/"
        "something-something-v2-id2label.json"
    )
    # Write to the same path that the existence check above looks at.
    _download_with_wget(ssv2_classes_url, ssv2_classes_path)
    print("SSV2 classes downloaded successfully!")
else:
    print(f"SSV2 classes already exist at {ssv2_classes_path}")

Video already exists at sample_video.mp4
SSV2 classes already exist at ssv2_classes.json


## 4. Load Models (PyTorch & HuggingFace)

Load the V-JEPA 2 models using both PyTorch and HuggingFace.


To manually download the PyTorch encoder weights:
```bash
wget https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt -P ./weights/
```

**AMD GPU Optimization:**
- Models are loaded to AMD GPU using `.to(device)`

In [6]:
# Model configuration: one checkpoint name is used for both the model and
# its video processor below.
hf_model_name = "facebook/vjepa2-vitg-fpc64-384"  # Options: vitl, vith, vitg with 256 or 384 resolution

print("Loading HuggingFace model...")
model_hf = AutoModel.from_pretrained(hf_model_name)
# Move the model to the selected device and switch it to evaluation mode.
model_hf.to(device).eval()
print(f"HuggingFace model loaded on {device}")

# Build HuggingFace preprocessing transform
hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name)
# The processor's crop height is reused as the image size for the PyTorch
# model and its transform in later cells.
img_size = hf_transform.crop_size["height"]  # E.g. 384 or 256
print(f"Image size: {img_size}x{img_size}")

Loading HuggingFace model...
HuggingFace model loaded on cuda
Image size: 384x384


In [None]:
# Manually download the model weights
!wget https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt -P ./weights/

In [None]:
# Load model from local weights.
# Builds the `vit_giant_xformers_rope` encoder for 64-frame clips at
# `img_size` (set in the HuggingFace cell), moves it to `device`, puts it in
# evaluation mode, and then loads the checkpoint into it.
pt_model_path = "./weights/vitg-384.pt"  # Update this path
model_pt = vit_giant_xformers_rope(img_size=(img_size, img_size), num_frames=64)
model_pt.to(device).eval()
load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)

# Build PyTorch preprocessing transform
pt_video_transform = build_pt_video_transform(img_size=img_size)
print("\nModels loaded successfully!")
# Total number of parameter elements in the PyTorch encoder, printed in millions.
num_params = sum(p.numel() for p in model_pt.parameters())
print(f"Model parameters: {num_params / 1e6:.1f}M")

Pretrained weights found at ./weights/vitg-384.pt and loaded with msg: <All keys matched successfully>

Models loaded successfully!
Model parameters: 1012.2M


## 5. Run Video Inference

Extract patch-wise features from the video using both models and verify they produce equivalent results.

**AMD GPU Performance Notes:**
- First inference includes MIOpen kernel autotuning (may be slow)
- Subsequent inferences will be much faster due to cached kernels
- Inference time is measured with GPU synchronization for accuracy

In [7]:
# Extract patch-wise features with both models, then compare them.
print("Running inference on video...\n")
out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(
    model_hf, model_pt, hf_transform, pt_video_transform, device
)

# Compute the comparison once and reuse it for the report and the verdict.
abs_diff_sum = torch.abs(out_patch_features_pt - out_patch_features_hf).sum()
features_match = torch.allclose(out_patch_features_pt, out_patch_features_hf, atol=1e-3, rtol=1e-3)

print(f"""
Inference results on video:
  HuggingFace output shape: {out_patch_features_hf.shape}
  PyTorch output shape:     {out_patch_features_pt.shape}
  Absolute difference sum:  {abs_diff_sum:.6f}
  Close: {features_match}
""")

print(
    "✓ Models produce equivalent features!"
    if features_match
    else "⚠ Warning: Models produce different features (this may be expected for different precision)"
)

Running inference on video...

Warming up AMD GPUs (kernel compilation)...


  self.gen = func(*args, **kwds)


PyTorch model inference time: 1643.84 ms
HuggingFace model inference time: 1603.17 ms

Inference results on video:
  HuggingFace output shape: torch.Size([1, 18432, 1408])
  PyTorch output shape:     torch.Size([1, 18432, 1408])
  Absolute difference sum:  1964.459473
  Close: False



## 6. Run Action Classification

Use a pretrained attentive probe classifier to predict action classes from the extracted features.

To download the attentive probe weights:
```bash
wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt -P ./weights/
```

Then update `classifier_model_path` below.

In [None]:
!wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt -P ./weights/

In [8]:
# Location of the attentive probe checkpoint (change it if stored elsewhere).
classifier_model_path = "./weights/ssv2-vitg-384-64x2x3.pt"

# Only build and run the classifier when its weights are available locally;
# otherwise print the download instructions.
if os.path.exists(classifier_model_path):
    print("Loading attentive probe classifier...")
    probe = AttentiveClassifier(
        embed_dim=model_pt.embed_dim,
        num_heads=16,
        depth=4,
        num_classes=174,
    )
    classifier = probe.to(device).eval()

    load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path)
    print("Classifier loaded successfully!\n")

    # Print the top-5 predictions for the features extracted earlier.
    get_vjepa_video_classification_results(classifier, out_patch_features_pt, device)
else:
    print(f"⚠ Classifier weights not found at {classifier_model_path}")
    print("Please download the weights using:")
    print("wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt -P ./weights/")

Loading attentive probe classifier...
Pretrained weights found at ./weights/ssv2-vitg-384-64x2x3.pt and loaded with msg: <All keys matched successfully>
Classifier loaded successfully!

Classifier output shape: torch.Size([1, 174])

Top 5 predicted class names:
  Putting [something] into [something]: 44.93%
  Stuffing [something] into [something]: 28.10%
  Putting [something] onto [something]: 14.44%
  Failing to put [something] into [something] because [something] does not fit: 7.64%
  Putting [number of] [something] onto [something]: 4.89%


## 7. AMD GPU Performance Benchmarking (Optional)

Run multiple inference iterations to benchmark AMD GPU performance.

In [None]:
# Benchmark inference performance of the PyTorch model on a single input.
num_iterations = 10
print(f"Running benchmark with {num_iterations} iterations...\n")

# Prepare input
video = get_video()
video = torch.from_numpy(video).permute(0, 3, 1, 2)
x_pt = pt_video_transform(video).to(device).unsqueeze(0)

on_gpu = device.type == 'cuda'
times = []
with torch.inference_mode():
    # Untimed warm-up pass so the first measured iteration does not include
    # one-off start-up cost when this cell is the model's first GPU run.
    _ = model_pt(x_pt)

    for i in range(num_iterations):
        # Drain pending GPU work so it is not charged to this iteration.
        if on_gpu:
            torch.cuda.synchronize()

        # perf_counter is a monotonic, high-resolution clock meant for
        # measuring short durations.
        start_time = time.perf_counter()
        _ = model_pt(x_pt)

        # Wait for the forward pass to actually finish before stopping the clock.
        if on_gpu:
            torch.cuda.synchronize()

        elapsed = time.perf_counter() - start_time
        times.append(elapsed * 1000)  # Convert to ms
        print(f"Iteration {i+1}/{num_iterations}: {elapsed * 1000:.2f} ms")

print(f"""
Benchmark Results:
  Average: {np.mean(times):.2f} ms
  Std Dev: {np.std(times):.2f} ms
  Min:     {np.min(times):.2f} ms
  Max:     {np.max(times):.2f} ms
""")

Running benchmark with 10 iterations...

Iteration 1/10: 1647.77 ms
Iteration 2/10: 1645.19 ms
Iteration 3/10: 1748.78 ms
Iteration 4/10: 1715.32 ms
Iteration 5/10: 1652.44 ms
Iteration 6/10: 1659.66 ms
Iteration 7/10: 1666.28 ms
Iteration 8/10: 1766.66 ms
Iteration 9/10: 1718.87 ms
Iteration 10/10: 1749.62 ms

Benchmark Results:
  Average: 1697.06 ms
  Std Dev: 45.33 ms
  Min:     1645.19 ms
  Max:     1766.66 ms



## 8. Multi-GPU Inference (Optional)

Example of using multiple AMD GPUs for batch processing.

In [None]:
# Multi-GPU demo: run one batch through the PyTorch model wrapped in DataParallel.
if torch.cuda.device_count() > 1:
    print(f"Using DataParallel with {torch.cuda.device_count()} AMD GPUs\n")

    # Wrap model with DataParallel
    model_pt_multi = torch.nn.DataParallel(model_pt)

    # Build the single-video input here so this cell also works when the
    # optional benchmark cell (which used to define x_pt) was skipped.
    video = torch.from_numpy(get_video()).permute(0, 3, 1, 2)
    x_pt = pt_video_transform(video).to(device).unsqueeze(0)

    # Create a batch of videos (duplicate for demo): one copy per GPU.
    batch_size = torch.cuda.device_count()
    x_batch = x_pt.repeat(batch_size, 1, 1, 1, 1)

    print(f"Processing batch of {batch_size} videos...")
    with torch.inference_mode():
        # Untimed warm-up: the first DataParallel call would otherwise charge
        # one-off start-up cost on the additional GPUs to the measurement.
        _ = model_pt_multi(x_batch)
        torch.cuda.synchronize()

        # perf_counter is a monotonic, high-resolution clock for short durations.
        start_time = time.perf_counter()
        outputs = model_pt_multi(x_batch)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

    print(f"Batch inference time: {elapsed * 1000:.2f} ms")
    print(f"Per-video time: {elapsed * 1000 / batch_size:.2f} ms")
    print(f"Output shape: {outputs.shape}")
else:
    print("Only 1 GPU available, skipping multi-GPU demo")

Using DataParallel with 8 AMD GPUs

Processing batch of 8 videos...
Batch inference time: 17134.59 ms
Per-video time: 2141.82 ms
Output shape: torch.Size([8, 18432, 1408])


## Conclusion

This notebook includes the following:
- Loading V-JEPA 2 models on AMD GPUs
- Video feature extraction
- Action classification
- Performance benchmarking on AMD ROCm (inference)
- Multi-GPU inference

Monitor GPU usage with: `rocm-smi`

Forked from the [V-JEPA 2 repository](https://github.com/facebookresearch/vjepa2).