# V-JEPA 2 Demo Notebook

This tutorial provides an example of how to load the V-JEPA 2 model in vanilla PyTorch and HuggingFace, extract a video embedding, and then predict an action class. For more details about the paper and model weights, please see https://github.com/facebookresearch/vjepa2.

First, let's import the required libraries and define the helper functions used in this tutorial.

In [19]:
import json
import os
import subprocess
import sys

# Add the parent directory of the current working directory (i.e., the project root)
# to sys.path so that the `src` package imported below can be found
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader
from transformers import AutoVideoProcessor, AutoModel

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.attentive_pooler import AttentiveClassifier
from src.models.vision_transformer import vit_giant_xformers_rope

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

def load_pretrained_vjepa_pt_weights(model, pretrained_weights):
    # Load weights of the VJEPA2 encoder
    # The PyTorch state_dict is already preprocessed to have the right key names
    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location="cpu")["encoder"]
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def load_pretrained_vjepa_classifier_weights(model, pretrained_weights):
    # Load weights of the VJEPA2 classifier
    # The PyTorch state_dict is already preprocessed to have the right key names
    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location="cpu")["classifiers"][0]
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def build_pt_video_transform(img_size):
    short_side_size = int(256.0 / 224 * img_size)
    # Eval transform has no random cropping nor flip
    eval_transform = video_transforms.Compose(
        [
            video_transforms.Resize(short_side_size, interpolation="bilinear"),
            video_transforms.CenterCrop(size=(img_size, img_size)),
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )
    return eval_transform


def get_video():
    vr = VideoReader("sample_video.mp4")
    # Take every second frame of the first 128 frames (64 frames in total); define a more complex sampling strategy if needed
    frame_idx = np.arange(0, 128, 2)
    video = vr.get_batch(frame_idx).asnumpy()
    return video


def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):
    # Run a sample inference with VJEPA
    with torch.inference_mode():
        # Read and preprocess the video
        video = get_video()  # T x H x W x C
        video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
        x_pt = pt_transform(video).unsqueeze(0)
        x_hf = hf_transform(video, return_tensors="pt")["pixel_values_videos"]
        # Extract the patch-wise features from the last layer
        out_patch_features_pt = model_pt(x_pt)
        out_patch_features_hf = model_hf.get_vision_features(x_hf)

    return out_patch_features_hf, out_patch_features_pt


def get_vjepa_video_classification_results(classifier, out_patch_features_pt):
    with open("ssv2_classes.json", "r") as f:
        SOMETHING_SOMETHING_V2_CLASSES = json.load(f)

    with torch.inference_mode():
        out_classifier = classifier(out_patch_features_pt)

    print(f"Classifier output shape: {out_classifier.shape}")

    print("Top 5 predicted class names:")
    top5 = out_classifier.topk(5)
    top5_indices = top5.indices[0]
    # Softmax over the top-5 logits only, so the percentages are renormalized over these 5 classes
    top5_probs = F.softmax(top5.values[0], dim=-1) * 100.0  # convert to percentage
    for idx, prob in zip(top5_indices, top5_probs):
        str_idx = str(idx.item())  # JSON object keys are strings
        print(f"{SOMETHING_SOMETHING_V2_CLASSES[str_idx]} ({prob}%)")
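`get_video` takes every second frame of the first 128 frames (`np.arange(0, 128, 2)`), i.e. 64 frames, matching `num_frames=64` used for the encoder below. That fixed indexing assumes the video has at least 127 frames. A minimal sketch of a slightly more defensive strategy, assuming that repeating the last frame is acceptable for short videos (the helper `sample_frame_indices` is hypothetical, not part of the repository):

```python
def sample_frame_indices(num_video_frames, num_clip_frames=64, stride=2, start=0):
    """Return `num_clip_frames` frame indices spaced `stride` apart, starting at `start`.

    Indices past the end of the video are clamped to the last frame, so a short
    video repeats its final frame instead of raising an IndexError.
    """
    if num_video_frames <= 0:
        raise ValueError("video has no frames")
    last = num_video_frames - 1
    return [min(start + i * stride, last) for i in range(num_clip_frames)]


# Same indices as np.arange(0, 128, 2) when the video is long enough
indices = sample_frame_indices(num_video_frames=300)
print(len(indices), indices[:4], indices[-1])  # 64 [0, 2, 4, 6] 126

# A hypothetical 100-frame video: the later indices are clamped to frame 99
short = sample_frame_indices(num_video_frames=100)
print(short[-3:])  # [99, 99, 99]
```

Assuming decord's `get_batch` accepts a list of indices, `vr.get_batch(sample_frame_indices(len(vr)))` could replace the fixed `np.arange` call inside `get_video`.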

Next, let's download a sample video to the current working directory; if the file already exists, this step is skipped. We also download the label mapping for the 174 action classes of Something-Something V2 (SSV2), so that we can interpret the class predicted by the model.

In [20]:
sample_video_path = "sample_video.mp4"
# Download the video if it is not yet present at the local path
if not os.path.exists(sample_video_path):
    print("Downloading video")
    video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
    command = ["wget", video_url, "-O", sample_video_path]
    subprocess.run(command, check=True)  # raise an error if the download fails

# Download the SSV2 class mapping if it is not yet present
ssv2_classes_path = "ssv2_classes.json"
if not os.path.exists(ssv2_classes_path):
    print("Downloading SSV2 classes")
    command = [
        "wget",
        "https://huggingface.co/datasets/huggingface/label-files/resolve/d79675f2d50a7b1ecf98923d42c30526a51818e2/"
        "something-something-v2-id2label.json",
        "-O",
        ssv2_classes_path,
    ]
    subprocess.run(command, check=True)
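The cell above shells out to `wget`, which must be installed, and `wget -O` can leave an empty or partial file behind if a download fails; the `os.path.exists` check would then skip it on the next run. A minimal pure-Python sketch of the same "download if missing" logic using the standard library's `urllib.request.urlretrieve`, writing to a temporary file first (the helper name `download_if_missing` and its `fetch` parameter are hypothetical, added so the logic can be exercised without network access):

```python
import os
import urllib.request


def download_if_missing(url, path, fetch=urllib.request.urlretrieve):
    """Download `url` to `path` unless `path` already exists.

    Returns True if a download happened, False if the file was already present.
    `fetch(url, filename)` defaults to urllib.request.urlretrieve.
    """
    if os.path.exists(path):
        return False
    print(f"Downloading {url} -> {path}")
    tmp_path = path + ".part"
    fetch(url, tmp_path)        # write to a temporary file first...
    os.replace(tmp_path, path)  # ...and rename only once the download has finished
    return True


# Usage with the sample video from the cell above:
# download_if_missing(
#     "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4",
#     "sample_video.mp4",
# )
```

If `fetch` raises, the final path is never created, so a rerun retries the download instead of silently using a broken file.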

Now, let's load the models both in vanilla PyTorch and through the HuggingFace API. The HuggingFace API downloads the weights automatically through `from_pretrained()`, so no additional download is required for the HuggingFace model.

To download the PyTorch model weights, use wget and specify your preferred target path. See the README for the model weight URLs.
For example:
```
wget https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt -P YOUR_DIR
```
Then set `pt_model_path` to `YOUR_DIR/vitg-384.pt`. Alternatively, the PyTorch model can be loaded with `torch.hub.load`, as shown in the commented-out lines of the next cell.
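The next cell also builds the PyTorch preprocessing with `build_pt_video_transform`, using the crop size of the HuggingFace processor as `img_size`. That transform resizes each frame so its short side is `int(256 / 224 * img_size)` (438 for `img_size = 384`) and then center-crops `img_size × img_size`. The geometry can be sketched in plain Python; the aspect-ratio rounding and crop offsets below are assumptions and may differ by a pixel from the repository's `video_transforms.Resize` and `CenterCrop` (the helper name `eval_resize_and_crop` is hypothetical):

```python
def eval_resize_and_crop(height, width, img_size):
    """Sketch of the eval geometry: short-side resize, then a centered square crop.

    Returns (resized_h, resized_w, top, left), where (top, left) is the corner of
    the img_size x img_size crop inside the resized frame.
    """
    short_side = int(256.0 / 224 * img_size)  # same formula as build_pt_video_transform
    if height <= width:
        new_h, new_w = short_side, int(round(width * short_side / height))
    else:
        new_h, new_w = int(round(height * short_side / width)), short_side
    top = (new_h - img_size) // 2
    left = (new_w - img_size) // 2
    return new_h, new_w, top, left


# A hypothetical 720p landscape frame with the 384-pixel crop used in this notebook
print(eval_resize_and_crop(720, 1280, 384))  # (438, 779, 27, 197)
```

Resizing to slightly more than the crop size and then cropping the center keeps the aspect ratio intact and discards only a thin border, which is why there is no random cropping or flipping at evaluation time.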

In [None]:


# HuggingFace model repo name
# If you switch to a different V-JEPA 2 checkpoint, also update the PyTorch model
# constructor and `pt_model_path` below so that both models stay comparable.
hf_model_name = "facebook/vjepa2-vitg-fpc64-384"
# Path to local PyTorch weights
pt_model_path = "../models/vitg-384.pt"

# Initialize the HuggingFace model, load pretrained weights
model_hf = AutoModel.from_pretrained(hf_model_name)
# To run on a GPU, use model_hf.cuda().eval() instead and move the inputs to the GPU as well
model_hf.eval()

# Build HuggingFace preprocessing transform
hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name)
img_size = hf_transform.crop_size["height"]  # E.g. 384, 256, etc.

# Initialize the PyTorch model, load pretrained weights
model_pt = vit_giant_xformers_rope(img_size=(img_size, img_size), num_frames=64)
# To run on a GPU, use model_pt.cuda().eval() instead and move the inputs to the GPU as well
model_pt.eval()
load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)

### Can also use torch.hub to load the model
# model_pt, _ = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant_384')
# model_pt.cuda().eval()

# Build PyTorch preprocessing transform
pt_video_transform = build_pt_video_transform(img_size=img_size)

Pretrained weights found at ../models/vitg-384.pt and loaded with msg: <All keys matched successfully>
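Both loaders strip the `module.` and `backbone.` prefixes from the checkpoint keys and then call `load_state_dict(..., strict=False)`. With `strict=False`, PyTorch does not raise on mismatched names; it returns the missing and unexpected keys, which the loaders print. The message `<All keys matched successfully>` above means both lists are empty. Since a renaming mistake would otherwise pass silently, the same comparison is sketched below on plain key sets, without PyTorch; the key names in the usage are hypothetical:

```python
def clean_keys(state_dict, prefixes=("module.", "backbone.")):
    """Remove wrapper prefixes the same way the notebook's loaders do (plain str.replace)."""
    cleaned = {}
    for key, value in state_dict.items():
        for prefix in prefixes:
            key = key.replace(prefix, "")
        cleaned[key] = value
    return cleaned


def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected), analogous to what load_state_dict reports."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # expected by the model, absent in the checkpoint
    unexpected = sorted(checkpoint_keys - model_keys)  # present in the checkpoint, unknown to the model
    return missing, unexpected


# Hypothetical key names, for illustration only
ckpt = {"module.backbone.patch_embed.proj.weight": 0, "module.backbone.norm.weight": 1}
model_keys = ["patch_embed.proj.weight", "norm.weight", "norm.bias"]
missing, unexpected = compare_keys(model_keys, clean_keys(ckpt))
print(missing, unexpected)  # ['norm.bias'] []
```

In the notebook itself, returning `msg` from the loaders and asserting `not msg.missing_keys and not msg.unexpected_keys` would turn a silent mismatch into an error.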


Now we can run the encoder on the video to obtain the patch-wise features from its last layer. To check whether the HuggingFace and PyTorch models are equivalent, we compare the two sets of features.
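"Patch-wise" means the encoder returns one embedding per spatiotemporal patch (tubelet) of the clip. Assuming a spatial patch size of 16 pixels and a tubelet depth of 2 frames (the usual V-JEPA configuration, not stated in this notebook), a 64-frame 384 × 384 clip gives (64 / 2) × (384 / 16)² = 32 × 576 = 18432 tokens, which matches the sequence length printed by the next cell; 1408 in that output is the embedding width of each token. A small helper for this arithmetic (the name `num_patch_tokens` is hypothetical):

```python
def num_patch_tokens(num_frames, img_size, patch_size=16, tubelet_size=2):
    """Number of tokens a video ViT produces for one clip.

    Assumes non-overlapping tubelets that are `tubelet_size` frames deep and
    `patch_size` x `patch_size` pixels wide, with dimensions that divide evenly.
    """
    if num_frames % tubelet_size or img_size % patch_size:
        raise ValueError("clip dimensions must be divisible by the tubelet/patch size")
    temporal = num_frames // tubelet_size    # 64 // 2 = 32 time steps
    spatial = (img_size // patch_size) ** 2  # (384 // 16) ** 2 = 576 patches per step
    return temporal * spatial


print(num_patch_tokens(num_frames=64, img_size=384))  # 18432
```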

In [23]:
# Inference on video to get the patch-wise features
out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(
    model_hf, model_pt, hf_transform, pt_video_transform
)

print(
    f"""
    Inference results on video:
    HuggingFace output shape: {out_patch_features_hf.shape}
    PyTorch output shape:     {out_patch_features_pt.shape}
    Absolute difference sum:  {torch.abs(out_patch_features_pt - out_patch_features_hf).sum():.6f}
    Close: {torch.allclose(out_patch_features_pt, out_patch_features_hf, atol=1e-3, rtol=1e-3)}
    """
)


    Inference results on video:
    HuggingFace output shape: torch.Size([1, 18432, 1408])
    PyTorch output shape:     torch.Size([1, 18432, 1408])
    Absolute difference sum:  1688.154175
    Close: False
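`torch.allclose(input, other, rtol, atol)` checks `|input - other| <= atol + rtol * |other|` elementwise and returns `True` only if every element passes, so a single outlier is enough to make it `False`. Dividing the reported sum by the number of elements gives the mean absolute difference instead. Both computations are sketched in plain Python below; the short lists in the second example are made up:

```python
def allclose(a, b, atol=1e-3, rtol=1e-3):
    """Elementwise |a - b| <= atol + rtol * |b| on flat sequences, all elements must pass."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


def mean_abs_diff_from_sum(abs_diff_sum, shape):
    """Turn a summed absolute difference into a per-element mean."""
    numel = 1
    for dim in shape:
        numel *= dim
    return abs_diff_sum / numel


# Numbers reported by the cell above: 1688.15 spread over 1 * 18432 * 1408 values
print(f"{mean_abs_diff_from_sum(1688.154175, (1, 18432, 1408)):.2e}")  # 6.50e-05

# Made-up values: one element off by 0.01 fails the check although the mean error is 1e-5
a = [0.0] * 999 + [1.01]
b = [0.0] * 999 + [1.0]
print(allclose(a, b))  # False
```

Reporting the maximum absolute difference alongside the sum, e.g. `torch.abs(out_patch_features_pt - out_patch_features_hf).max()`, would show directly how large the worst mismatch is.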
    


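Both models produce features of the same shape. `torch.allclose` returns `False` at `atol=rtol=1e-3`, so at least one element differs by more than this tolerance; however, the summed absolute difference of about 1688 is spread over 18432 × 1408 ≈ 26 million values, a mean absolute difference of roughly 6.5e-5 per element. This suggests that the two implementations agree closely, though not exactly. Now let's run a pretrained attentive probe classifier on top of the extracted features to predict an action class for the video, using the Something-Something V2 probe. The repository also includes attentive probe weights for other evaluations, such as EPIC-KITCHENS-100 and Diving48.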

To download the attentive probe weights, use wget and specify your preferred target path. E.g. `wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt -P YOUR_DIR`

Then update `classifier_model_path` with `YOUR_DIR/ssv2-vitg-384-64x2x3.pt`.
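As its name suggests, the attentive probe pools the patch features into a fixed-size representation with attention before classifying them; the repository's `AttentiveClassifier` does this with learned query tokens, configured here with `depth=4` and `num_heads=16`. The core idea, reduced to a single query, a single head, and no learned projections, can be sketched in plain Python. This is a simplification for illustration only, not the repository's implementation, and the toy tokens, query, and head weights are made up:

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attentive_pool(tokens, query):
    """Single-query, single-head attention pooling with keys = values = tokens.

    tokens: list of N vectors of length D; query: vector of length D.
    Returns the attention-weighted average of the tokens and the attention weights.
    """
    d = len(query)
    scores = [sum(q * t for q, t in zip(query, tok)) / math.sqrt(d) for tok in tokens]
    weights = softmax(scores)
    pooled = [sum(w * tok[i] for w, tok in zip(weights, tokens)) for i in range(d)]
    return pooled, weights


def linear(x, weight, bias):
    """Classification head: one logit per row of `weight`."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


# Toy example: 3 tokens of dimension 2 and a query aligned with the first axis
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
pooled, weights = attentive_pool(tokens, query=[2.0, 0.0])
logits = linear(pooled, weight=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.0])
print([round(w, 3) for w in weights], logits)
```

Tokens that align with the query receive larger weights, so the pooled vector, and hence the logits, are dominated by them; in the real probe the query is learned so that it attends to the patches most useful for the SSV2 classes.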

In [24]:
# Initialize the classifier
classifier_model_path = "../models/ssv2-vitg-384-64x2x3.pt"
classifier = (
    AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).eval()
)
load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path)

# Get classification results
get_vjepa_video_classification_results(classifier, out_patch_features_pt)

Pretrained weights found at ../models/ssv2-vitg-384-64x2x3.pt and loaded with msg: <All keys matched successfully>




Classifier output shape: torch.Size([1, 174])
Top 5 predicted class names:
Putting [something] into [something] (44.937313079833984%)
Stuffing [something] into [something] (28.10003089904785%)
Putting [something] onto [something] (14.435855865478516%)
Failing to put [something] into [something] because [something] does not fit (7.636945724487305%)
Putting [number of] [something] onto [something] (4.889848232269287%)
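The five percentages above sum to 100 because `get_vjepa_video_classification_results` applies softmax to the five largest logits only, so they are probabilities renormalized over the top-5 classes rather than over all 174. The sketch below contrasts the two, and also shows why the label lookup uses `str(idx.item())`: JSON object keys are always strings, so the id-to-label file maps `"0"`, `"1"`, ... to class names. The logits and label names are made up for illustration:

```python
import json
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(values, k):
    """Indices of the k largest values, largest first."""
    return sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]


# Hypothetical logits for 6 classes and a matching id-to-label JSON document
logits = [2.0, 0.5, 3.0, -1.0, 1.0, 0.0]
id2label = json.loads('{"0": "class a", "1": "class b", "2": "class c", '
                      '"3": "class d", "4": "class e", "5": "class f"}')

idx = top_k(logits, k=3)
full = softmax(logits)                      # probabilities over all classes
renorm = softmax([logits[i] for i in idx])  # what the notebook prints: over the top-k only

for rank, i in enumerate(idx):
    label = id2label[str(i)]  # JSON keys are strings, hence str(i)
    print(f"{label}: full={100 * full[i]:.1f}%  top-k renormalized={100 * renorm[rank]:.1f}%")
```

To report probabilities over all 174 classes instead, apply `F.softmax(out_classifier, dim=-1)` first and then take `topk(5)` of the result.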




The video shows a man putting a bowling ball into a tube, so the top prediction, "Putting [something] into [something]", makes sense!

This concludes the tutorial. Please see the README and paper for full details on the capabilities of V-JEPA 2 :)