# JEPA Training on Google Colab

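This notebook trains the JEPA model using the production package structure (`src/jepa`).
On Google Colab it assumes the repository is cloned to your Google Drive; it can also be run locally from the project root or from the `notebooks/` directory.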

In [1]:
# 1. Environment Setup (Colab & Local Support)
import sys
import os

# Detect environment
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Running on Google Colab")
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Path to repo in Drive (CHANGE THIS if needed)
    REPO_PATH = '/content/drive/MyDrive/AV-SSL-Optimization-JEPA'
    
    if os.path.exists(REPO_PATH):
        os.chdir(REPO_PATH)
        print(f"üìÇ Working directory set to: {os.getcwd()}")
        
        # Install packages (Colab only)
        print("üì¶ Installing dependencies...")
        !pip install -e .[dev]
        !pip install -r requirements.txt
    else:
        print(f"‚ö†Ô∏è Repo not found at {REPO_PATH}. Please clone it to Drive first.")

else:
    print("Running Locally")
    # If running from 'notebooks/' directory, move up to root
    current_dir = os.getcwd()
    if current_dir.endswith('notebooks'):
        os.chdir('..')
        print(f"Moved up to project root: {os.getcwd()}")
    
    # Add project root to sys.path to find 'src' module
    if os.getcwd() not in sys.path:
        sys.path.append(os.getcwd())
        print("Added project root to sys.path")


Running Locally
Moved up to project root: /Users/shamik/Documents/AV-SSL-Optimization-JEPA
Added project root to sys.path


In [2]:
# 3. Load Configuration
import yaml
import torch
from src.jepa.data import JEPADataset, TubeletDataset, MaskTubelet
from src.jepa.models import JEPAModel
from src.jepa.training import Trainer
from torch.utils.data import DataLoader, random_split

# Load default config.
# The encoding is given explicitly so the YAML is decoded the same way on every
# platform instead of falling back to the locale's default text encoding.
with open('configs/default.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Override config for this session.
# NOTE: these overrides are applied unconditionally, on Colab and locally alike.
config['training']['batch_size'] = 8  # Adjust based on GPU VRAM
config['training']['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Using device: {config['training']['device']}")

# --- Dynamic Checkpoint Directory ---
def get_next_run_dir(base_dir):
    """Return an unused, date-stamped run directory path under ``base_dir``.

    The first candidate is ``<base_dir>/<YYYY-MM-DD>`` for today's date. If that
    path already exists, ``_2``, ``_3``, ... are appended in turn until a path
    that does not exist is found. The directory itself is NOT created here.

    Args:
        base_dir: Parent directory under which run directories live.

    Returns:
        The first candidate path that does not exist on disk.
    """
    from datetime import datetime

    today = datetime.now().strftime('%Y-%m-%d')
    dated_dir = os.path.join(base_dir, today)

    # Probe the bare dated path first, then numbered variants starting at 2.
    candidate = dated_dir
    suffix = 1
    while os.path.exists(candidate):
        suffix += 1
        candidate = f"{dated_dir}_{suffix}"
    return candidate

# Pick a fresh, date-stamped directory under the configured checkpoint root and
# point the training config at it, so this run's checkpoints stay separate from
# those of earlier runs.
base_ckpt_dir = config['training'].get('checkpoint_dir', 'experiments/checkpoints')
run_ckpt_dir = get_next_run_dir(base_ckpt_dir)
config['training']['checkpoint_dir'] = run_ckpt_dir

# get_next_run_dir only returns a path; create it here.
os.makedirs(run_ckpt_dir, exist_ok=True)
print(f"🚀 Checkpoints will be saved to: {run_ckpt_dir}")


  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [3]:
# 4. Prepare Data
# Masking transform handed to the dataset; ratio and patch size come from the config.
mask_transform = MaskTubelet(
    mask_ratio=config['data']['mask_ratio'],
    patch_size=config['data']['patch_size']
)

# Load full dataset
full_dataset = TubeletDataset(
    manifest_path=config['data']['manifest_path'],
    data_root=config['data'].get('data_root'),  # Handle relative paths
    tubelet_size=config['data']['tubelet_size'],
    transform=mask_transform
)

# Split train/val.
# The split is driven by a seeded generator so the same samples land in the
# train and validation sets on every Restart & Run All. The seed is read from
# the optional `data.split_seed` config key and falls back to 42 when absent.
train_size = int(config['data']['train_split'] * len(full_dataset))
val_size = len(full_dataset) - train_size
split_seed = int(config['data'].get('split_seed', 42))
train_ds, val_ds = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed)
)

train_loader = DataLoader(
    train_ds,
    batch_size=config['training']['batch_size'],
    shuffle=True,
    num_workers=2
)

# Validation order does not matter, so shuffling is disabled.
val_loader = DataLoader(
    val_ds,
    batch_size=config['training']['batch_size'],
    shuffle=False,
    num_workers=2
)

print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")

Train samples: 134, Val samples: 34


In [4]:
# 5. Initialize Model
# Resolve the target device first; it is reused by later cells.
device = torch.device(config['training']['device'])

# Short aliases for the config sections read below.
model_cfg = config['model']
predictor_cfg = model_cfg['predictor']
optim_cfg = config['training']

model = JEPAModel(
    encoder_name=model_cfg['encoder_name'],
    predictor_hidden=predictor_cfg['hidden_dim'],
    predictor_dropout=predictor_cfg['dropout'],
    freeze_encoder=model_cfg['freeze_encoder'],
)
model.to(device)

# Optimizer
# Only the predictor's parameters are handed to AdamW. The float() casts
# normalise lr / weight_decay in case the config supplies them as strings.
learning_rate = float(optim_cfg['lr'])
decay = float(optim_cfg['weight_decay'])
optimizer = torch.optim.AdamW(
    model.predictor.parameters(),
    lr=learning_rate,
    weight_decay=decay,
)

Loading weights: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 587/587 [00:00<00:00, 1585.85it/s, Materializing param=predictor.proj.weight]                           


In [5]:
# 6. Training Loop
training_cfg = config['training']

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    device=device,
    checkpoint_dir=training_cfg['checkpoint_dir'],
)

num_epochs = training_cfg['epochs']
best_loss = float('inf')  # lowest validation loss seen so far

for epoch in range(num_epochs):
    # One training pass followed by one validation pass.
    train_loss = trainer.train_epoch(train_loader, epoch)
    val_loss = trainer.validate_epoch(val_loader, epoch)

    # Track whether this epoch improved on the best validation loss.
    is_best = val_loss < best_loss
    best_loss = val_loss if is_best else best_loss

    # Save on the periodic schedule, and also whenever a new best is reached.
    on_schedule = (epoch + 1) % training_cfg['checkpoint_every'] == 0
    if on_schedule or is_best:
        trainer.save_checkpoint(epoch, val_loss, is_best, config)

Train Epoch 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 17/17 [00:55<00:00,  3.28s/it, loss=0.0166]
Val Epoch 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:23<00:00,  4.64s/it, loss=0.0150]
Train Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 17/17 [00:53<00:00,  3.12s/it, loss=0.0130]
Val Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:23<00:00,  4.68s/it, loss=0.0117]
Train Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 17/17 [00:53<00:00,  3.16s/it, loss=0.0125]
Val Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:23<00:00,  4.68s/it, loss=0.0098]
Train Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 17/17 [00:56<00:00,  3.33s/it, loss=0.0099]
Val Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:23<00:00,  4.70s/it, loss=0.0086]
Train Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 17/17 [00:58<00:00,  3.42s/it, loss=0.0091]
Val Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:26<00:00,  5.29s/it, loss=0.0080]
Train Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 17/17 [0