# JEPA Training on Google Colab

This notebook trains the JEPA model using the production package structure.
It assumes the repository has already been cloned to your Google Drive; set `REPO_PATH` in the first cell to that location.

In [None]:
# 1. Mount Google Drive
# Mounts the user's Drive at /content/drive and makes the repository the
# working directory for the rest of the notebook.
from google.colab import drive
drive.mount('/content/drive')

# Set path to your repository (CHANGE THIS if needed)
REPO_PATH = '/content/drive/MyDrive/AV-SSL-Optimization-JEPA'
# `%cd` is an IPython magic; `$REPO_PATH` is expanded from the Python variable
# above. Later cells use repo-relative paths (e.g. 'configs/default.yaml',
# `pip install -e .`), so they depend on this directory change.
%cd $REPO_PATH

In [None]:
# 2. Install Dependencies & Package
!pip install -e .[dev]
!pip install -r requirements.txt

In [None]:
# 3. Load Configuration
import yaml
import torch
from src.jepa.data import JEPADataset, TubeletDataset, MaskTubelet
from src.jepa.models import JEPAModel
from src.jepa.training import Trainer
from torch.utils.data import DataLoader, random_split

# Load default config.
# The encoding is given explicitly so the file is decoded as UTF-8 regardless
# of the runtime's locale default.
with open('configs/default.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# yaml.safe_load returns None for an empty document; fail with a clear message
# here instead of an opaque "'NoneType' object is not subscriptable" below.
if not isinstance(config, dict):
    raise ValueError(
        "configs/default.yaml did not parse to a mapping (is the file empty?)"
    )

# Override config for Colab if needed
config['training']['batch_size'] = 8  # Adjust based on GPU VRAM
# Fall back to CPU when the Colab runtime has no GPU attached.
config['training']['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Using device: {config['training']['device']}")

In [None]:
# 4. Prepare Data
# Builds the masked tubelet dataset, splits it into train/val subsets and
# wraps each subset in a DataLoader.

# Masking transform, configured from the `data` section of the config and
# handed to the dataset as its `transform`.
mask_transform = MaskTubelet(
    mask_ratio=config['data']['mask_ratio'],
    patch_size=config['data']['patch_size']
)

# Load full dataset
full_dataset = TubeletDataset(
    manifest_path='videomae/clips_manifest.jsonl',  # Path to your manifest
    tubelet_size=config['data']['tubelet_size'],
    transform=mask_transform
)

# Split train/val. The train share is truncated to an int; validation takes
# the remainder so the two sizes always sum to len(full_dataset).
train_size = int(config['data']['train_split'] * len(full_dataset))
val_size = len(full_dataset) - train_size

# Use a dedicated, seeded generator so the split is identical across runtime
# restarts. Without it, resuming training after a Colab disconnect would draw
# a new split and could move former validation clips into the training set.
# An optional top-level `seed` entry in the config overrides the default.
split_seed = int(config.get('seed', 42))
split_generator = torch.Generator().manual_seed(split_seed)
train_ds, val_ds = random_split(
    full_dataset,
    [train_size, val_size],
    generator=split_generator,
)

# Settings shared by both loaders; only shuffling differs between them.
loader_kwargs = dict(
    batch_size=config['training']['batch_size'],
    num_workers=2,
)

# Shuffle training batches each epoch; keep validation order fixed.
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")

In [None]:
# 5. Initialize Model
# Construct the JEPA model from the `model` section of the config, place it
# on the selected device and create the optimizer used by the training loop.
model_cfg = config['model']
predictor_cfg = model_cfg['predictor']

model = JEPAModel(
    encoder_name=model_cfg['encoder_name'],
    predictor_hidden=predictor_cfg['hidden_dim'],
    predictor_dropout=predictor_cfg['dropout'],
    freeze_encoder=model_cfg['freeze_encoder'],
)

train_cfg = config['training']
device = torch.device(train_cfg['device'])
model.to(device)

# Optimizer: AdamW over the predictor's parameters only.
# NOTE(review): the encoder's parameters are never handed to the optimizer,
# even if `freeze_encoder` is false in the config - confirm this is intended.
predictor_params = model.predictor.parameters()
# float() tolerates values that arrive as strings (presumably because some
# YAML loaders read literals such as `1e-4` as text - verify against config).
learning_rate = float(train_cfg['lr'])
decay = float(train_cfg['weight_decay'])
optimizer = torch.optim.AdamW(
    predictor_params,
    lr=learning_rate,
    weight_decay=decay,
)

In [None]:
# 6. Training Loop
# Runs train + validation once per epoch and writes a checkpoint whenever the
# validation loss improves or the periodic schedule is due.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    device=device,
    checkpoint_dir=config['training']['checkpoint_dir'],
)

num_epochs = config['training']['epochs']
best_loss = float('inf')  # lowest validation loss observed so far

for epoch in range(num_epochs):
    # One pass over the training data, then one over the validation data.
    train_loss = trainer.train_epoch(train_loader, epoch)
    val_loss = trainer.validate_epoch(val_loader, epoch)

    # Record whether this epoch beat the previous best validation loss.
    improved = val_loss < best_loss
    best_loss = val_loss if improved else best_loss

    # `epoch` is 0-based, so the schedule is checked against epoch + 1.
    on_schedule = (epoch + 1) % config['training']['checkpoint_every'] == 0
    if on_schedule or improved:
        trainer.save_checkpoint(epoch, val_loss, improved, config)