In [1]:
# Cell 1: Setup
import sys
sys.path.append("..")

import torch
from medjepa.models.lejepa import LeJEPA
from medjepa.utils.device import get_device_info, get_device

In [2]:
# Cell 2: Check your machine
# Print a summary of the runtime environment, then store the selected device
# in `device`; later cells move the model and the input batch onto it.
# NOTE(review): the recorded output shows "Using CPU (no GPU detected)" twice;
# presumably get_device_info() also performs device selection internally, so
# the message is printed once per helper — confirm in medjepa.utils.device.
get_device_info()
device = get_device()

DEVICE INFORMATION
PyTorch version: 2.10.0+cpu
CUDA available: False
MPS available: False
Using CPU (no GPU detected)
Selected device: cpu
Using CPU (no GPU detected)


In [3]:
# Cell 3: Create a small model for testing
# Using smaller dimensions so it runs fast on CPU

# Seed PyTorch's global RNG before anything stochastic happens so that a
# Restart & Run All reproduces the same numbers. This assumes LeJEPA draws its
# initial weights from the torch RNG — TODO confirm in medjepa.models.lejepa.
SEED = 42
torch.manual_seed(SEED)

model = LeJEPA(
    image_size=224,
    patch_size=16,
    embed_dim=384,        # Smaller than default 768
    encoder_depth=6,      # Fewer layers than default 12
    encoder_heads=6,
    predictor_dim=192,
    predictor_depth=3,
    predictor_heads=3,
    mask_ratio=0.75,
    lambda_reg=1.0,
).to(device)

# Count parameters: sum the element count of every tensor the model registers
# as a parameter.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
print(f"That's {total_params / 1e6:.1f} million parameters")

Total parameters: 12,538,944
That's 12.5 million parameters


In [4]:
# Cell 4: Test forward pass with fake data
# Create fake batch of 4 images. The batch is drawn from a dedicated, seeded
# generator so it is identical every time this cell runs and does not depend on
# how much of the global RNG stream earlier cells consumed. It is sampled on
# CPU and then moved, so the values are the same whichever device is selected.
data_rng = torch.Generator().manual_seed(0)
fake_images = torch.randn(4, 3, 224, 224, generator=data_rng).to(device)

# Run forward pass
model.train()
losses = model(fake_images)

# Fail loudly on NaN/Inf rather than reporting success unconditionally.
for loss_name in ("total_loss", "prediction_loss", "regularization_loss"):
    assert torch.isfinite(losses[loss_name]).all(), f"{loss_name} is not finite"

print(f"Total loss: {losses['total_loss'].item():.4f}")
print(f"Prediction loss: {losses['prediction_loss'].item():.4f}")
print(f"Regularization loss: {losses['regularization_loss'].item():.4f}")
print("\nForward pass successful!")

Total loss: 0.0083
Prediction loss: 0.0052
Regularization loss: 0.0031

Forward pass successful!


In [5]:
# Cell 5: Test encoding (what you'd use for downstream tasks)
model.eval()

# Inference only: disable autograd so no computation graph is recorded while
# the embeddings are produced.
with torch.no_grad():
    embeddings = model.encode(fake_images)

# Expect one row per input image; 384 is the embed_dim chosen in Cell 3.
assert embeddings.shape == (fake_images.shape[0], 384), (
    f"Unexpected embedding shape: {tuple(embeddings.shape)}"
)

print(f"Embedding shape: {embeddings.shape}")  # Should be (4, 384)
print(f"Each image is represented by a vector of {embeddings.shape[1]} numbers")

Embedding shape: torch.Size([4, 384])
Each image is represented by a vector of 384 numbers


In [6]:
# Cell 6: Test that gradients flow (model can learn)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Start from clean gradients so nothing left over from an earlier run of this
# cell is mistaken for the result of this backward pass.
optimizer.zero_grad()

losses = model(fake_images)
losses["total_loss"].backward()  # Compute gradients

# Actually verify the claim: at least one parameter must have received a
# gradient, and every gradient that exists must be finite.
params_with_grad = [p for p in model.parameters() if p.grad is not None]
assert params_with_grad, "No parameter received a gradient"
assert all(torch.isfinite(p.grad).all() for p in params_with_grad), (
    "Non-finite gradient detected"
)

# Snapshot the weights so the optimizer step can be checked for a real update.
weights_before = [p.detach().clone() for p in params_with_grad]

optimizer.step()                  # Update weights

assert any(
    not torch.equal(old, new.detach())
    for old, new in zip(weights_before, params_with_grad)
), "Optimizer step did not change any weights"

optimizer.zero_grad()             # Reset gradients

print("Backward pass successful! Model can learn.")

Backward pass successful! Model can learn.
