In [1]:
# Cell 1: Setup
# Imports used by the LeJEPA smoke-test cells: PyTorch, the project's LeJEPA
# model, and the project's device helpers.
import sys
# Add the parent directory to the import path so the `medjepa` package below
# can be resolved — presumably the notebook sits one level under the repo
# root (TODO confirm against the repo layout).
sys.path.append("..")

import torch
from medjepa.models.lejepa import LeJEPA
from medjepa.utils.device import get_device_info, get_device

In [2]:
# Cell 2: Check your machine
# Report the runtime environment, then pick the device that every later cell
# moves its models and tensors onto. Judging by the recorded output, both
# helpers print to stdout, so their order determines the order of the log.
get_device_info()
device = get_device()

DEVICE INFORMATION
PyTorch version: 2.10.0+cpu
CUDA available: False
MPS available: False
Using CPU (no GPU detected)
Selected device: cpu
Using CPU (no GPU detected)


In [3]:
# Cell 3: Create a small model for testing
# Seed PyTorch's global RNG before the first random operation so that a
# Restart & Run All reproduces the same numbers on the same machine
# (assumes LeJEPA draws its initial weights and masks from torch's RNG —
# TODO confirm).
SEED = 42
torch.manual_seed(SEED)

# Using smaller dimensions so it runs fast on CPU
model = LeJEPA(
    image_size=224,
    patch_size=16,
    embed_dim=384,        # Smaller than default 768
    encoder_depth=6,      # Fewer layers than default 12
    encoder_heads=6,
    predictor_dim=192,
    predictor_depth=3,
    predictor_heads=3,
    mask_ratio=0.75,
    lambda_reg=1.0,
).to(device)

# Count parameters (every tensor returned by model.parameters(), summed by element count)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
print(f"That's {total_params / 1e6:.1f} million parameters")

Total parameters: 12,538,944
That's 12.5 million parameters


In [4]:
# Cell 4: Smoke-test the forward pass on synthetic input
# Synthetic batch: 4 three-channel images of 224x224, created on the CPU and
# then moved to the selected device.
fake_images = torch.randn((4, 3, 224, 224)).to(device)

# Switch to training mode and run a single forward pass; the model returns a
# mapping that holds the three loss tensors reported below.
model.train()
losses = model(fake_images)

# One print call, newline-separated, emits the same four lines of output.
print(
    f"Total loss: {losses['total_loss'].item():.4f}",
    f"Prediction loss: {losses['prediction_loss'].item():.4f}",
    f"Regularization loss: {losses['regularization_loss'].item():.4f}",
    "\nForward pass successful!",
    sep="\n",
)

Total loss: 0.0083
Prediction loss: 0.0052
Regularization loss: 0.0031

Forward pass successful!


In [5]:
# Cell 5: Test encoding (what you'd use for downstream tasks)
model.eval()

# Encoding is pure inference: disable autograd so no computation graph is
# built and kept alive alongside the embeddings.
with torch.no_grad():
    embeddings = model.encode(fake_images)

# One vector per input image; 384 is the embed_dim the model was built with in Cell 3.
expected_shape = (fake_images.shape[0], 384)
assert embeddings.shape == expected_shape, (
    f"Expected embeddings of shape {expected_shape}, got {tuple(embeddings.shape)}"
)

print(f"Embedding shape: {embeddings.shape}")  # Should be (4, 384)
print(f"Each image is represented by a vector of {embeddings.shape[1]} numbers")

Embedding shape: torch.Size([4, 384])
Each image is represented by a vector of 384 numbers


In [6]:
# Cell 6: Test that gradients flow (model can learn)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

losses = model(fake_images)
losses["total_loss"].backward()  # Compute gradients

# Inspect the gradients *before* step()/zero_grad() discard them; otherwise the
# cell only proves that backward() did not raise, not that anything can learn.
grads = [p.grad for p in model.parameters() if p.grad is not None]
assert grads, "backward() did not populate a gradient on any parameter"
assert all(bool(torch.isfinite(g).all()) for g in grads), "Found NaN/Inf gradients"
assert any(bool(g.ne(0).any()) for g in grads), "All gradients are exactly zero"

optimizer.step()                  # Update weights
optimizer.zero_grad()             # Reset gradients

print("Backward pass successful! Model can learn.")

Backward pass successful! Model can learn.


## V-JEPA Model Test (3D Volumes)

Test the V-JEPA architecture used for BraTS and Decathlon 3D data.

In [None]:
# Cell 8: V-JEPA model — forward pass test
from medjepa.models.vjepa import VJEPA

# Deliberately small configuration so the 3D model builds and runs quickly.
vjepa = VJEPA(
    volume_size=(64, 64, 32),   # Smaller for testing
    patch_size=(16, 16, 8),
    in_channels=1,
    embed_dim=384,
    depth=4,
    num_heads=6,
    predictor_dim=192,
    predictor_depth=2,
    mask_ratio=0.75,
    lambda_reg=1.0,
).to(device)

total_params_v = sum(p.numel() for p in vjepa.parameters())
print(f"V-JEPA parameters: {total_params_v:,} ({total_params_v/1e6:.1f}M)")

# Forward pass with fake 3D volume: batch of 2, single channel, matching volume_size above
fake_volume = torch.randn(2, 1, 64, 64, 32).to(device)
vjepa.train()
v_losses = vjepa(fake_volume)
print(f"V-JEPA Total loss:      {v_losses['total_loss'].item():.4f}")
print(f"V-JEPA Prediction loss: {v_losses['prediction_loss'].item():.4f}")
print(f"V-JEPA Reg loss:        {v_losses['regularization_loss'].item():.4f}")

# Test 3D encoding. This is inference only, so autograd is switched off to
# avoid building a computation graph that nothing will ever backpropagate through.
vjepa.eval()
with torch.no_grad():
    v_emb = vjepa.encode(fake_volume)
print(f"V-JEPA embedding shape: {v_emb.shape}")  # expected (2, 384) — not yet confirmed by a recorded run
print("V-JEPA forward + encode successful!")

## Load Pre-trained Checkpoint

Load a real pre-trained model from `checkpoints/` and verify it works.

In [None]:
# Cell 10: Load pre-trained LeJEPA checkpoint
import os

CHECKPOINT_PATH = "../checkpoints/best_model.pt"

# LeJEPA constructor arguments that are forwarded only when the checkpoint's
# config records them. If the checkpoint was trained with non-default values,
# rebuilding the model without them would not match the saved architecture.
# When a key is absent, LeJEPA's own default applies, exactly as before.
OPTIONAL_MODEL_KEYS = ("encoder_heads", "predictor_dim", "predictor_heads")

if os.path.exists(CHECKPOINT_PATH):
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load checkpoints you trust.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
    # Treat a missing *or* None config as empty so the .get() lookups below work.
    cfg = ckpt.get("config") or {}
    print("Checkpoint keys:", list(ckpt.keys()))
    print("Config:", cfg)

    # Reconstruct model from checkpoint config
    image_size = cfg.get("image_size", 224)
    model_kwargs = {
        "image_size": image_size,
        "patch_size": cfg.get("patch_size", 16),
        "embed_dim": cfg.get("embed_dim", 768),
        "encoder_depth": cfg.get("encoder_depth", 12),
        "predictor_depth": cfg.get("predictor_depth", 6),
    }
    model_kwargs.update(
        {key: cfg[key] for key in OPTIONAL_MODEL_KEYS if cfg.get(key) is not None}
    )
    loaded_model = LeJEPA(**model_kwargs)
    loaded_model.load_state_dict(ckpt["model_state_dict"])
    loaded_model = loaded_model.to(device).eval()

    total_params_loaded = sum(p.numel() for p in loaded_model.parameters())
    print(f"\nLoaded model: {total_params_loaded:,} parameters")

    if "epoch" in ckpt:
        print(f"Trained for {ckpt['epoch']} epochs")
    if "loss" in ckpt:
        print(f"Best loss: {ckpt['loss']:.6f}")

    # Quick test with fake image (inference only, so autograd is disabled)
    test_img = torch.randn(1, 3, image_size, image_size).to(device)
    with torch.no_grad():
        emb = loaded_model.encode(test_img)
    print(f"Embedding shape: {emb.shape}")
    print("Pre-trained checkpoint loaded and verified!")
else:
    print(f"No checkpoint at {CHECKPOINT_PATH}")
    print("Run pre-training first: python scripts/run_gpu_full.py")

# Check V-JEPA checkpoint too (only its metadata is reported; no model is rebuilt)
VJEPA_CKPT = "../checkpoints/best_vjepa_model.pt"
if os.path.exists(VJEPA_CKPT):
    # Same trust caveat as above: weights_only=False performs full unpickling.
    v_ckpt = torch.load(VJEPA_CKPT, map_location=device, weights_only=False)
    print(f"\nV-JEPA checkpoint also found! Epoch: {v_ckpt.get('epoch')}, Loss: {v_ckpt.get('loss', 'N/A')}")
else:
    print(f"\nNo V-JEPA checkpoint yet at {VJEPA_CKPT}")