# JEPA Training on Google Colab

This notebook trains the JEPA model using the production package structure.

## Features
- **Dynamic Repository Cloning**: Automatically clones or updates the repository from GitHub
- **Branch Support**: Specify any branch to use (default: `main`)
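- **Colab & Local**: Works on both Google Colab and local environments (locally the repository is not cloned — start the notebook from inside an existing checkout so the project root can be found)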

## Configuration
Edit the `REPO_BRANCH` variable in the next cell to specify which branch to use:
- `main` (default)
- `feat/jepa_evaluation`
- Any other branch name

In [None]:
# 1. Environment Setup (Colab & Local Support)
import sys
import os
import subprocess
import shutil

# --- Repository configuration ---------------------------------------------
REPO_URL = "https://github.com/shamikkarkhanis/AV-SSL-Optimization-JEPA.git"
# Branch to clone / check out; set e.g. "feat/jepa_evaluation" to train another branch.
REPO_BRANCH = "main"

# --- Runtime detection ----------------------------------------------------
# `google.colab` is only importable inside a Colab runtime.
IN_COLAB = False
try:
    import google.colab
except ImportError:
    pass
else:
    IN_COLAB = True

def clone_or_update_repo(repo_url, branch, target_path):
    """Clone ``repo_url`` at ``branch`` into ``target_path``, or update an existing clone.

    The clone directory is ``target_path/<name>``. ``<name>`` is the last
    component of ``repo_url`` with a trailing ``.git`` removed, the same name
    ``git clone`` picks.

    Side effect: changes the process working directory. It moves into the
    existing clone when updating, or into ``target_path`` when cloning.

    Args:
        repo_url: Git remote URL.
        branch: Branch to clone, or to switch an existing clone to.
        target_path: Directory that holds (or will hold) the clone.

    Returns:
        Path of the clone directory.

    Raises:
        subprocess.CalledProcessError: if a fresh ``git clone`` fails.
        RuntimeError: if an existing clone cannot be switched to ``branch``.
            Continuing would silently run the wrong branch's code.
    """
    repo_name = os.path.basename(repo_url.rstrip('/'))
    # Strip only a *trailing* ".git". str.replace would also mangle names such
    # as "user.github.io.git" and no longer match the directory git creates.
    if repo_name.endswith('.git'):
        repo_name = repo_name[:-len('.git')]
    full_path = os.path.join(target_path, repo_name)

    if os.path.exists(full_path):
        print(f"üìÇ Repository found at {full_path}")
        print(f"üîÑ Checking out branch: {branch}")
        os.chdir(full_path)

        # Fetch the requested branch with an explicit refspec. A clone made with
        # --single-branch only tracks its original branch, so a bare
        # `git fetch origin` never creates origin/<branch> for a different one.
        fetch = subprocess.run(
            ['git', 'fetch', 'origin', f'+refs/heads/{branch}:refs/remotes/origin/{branch}'],
            capture_output=True, text=True,
        )
        if fetch.returncode != 0:
            # Best effort (e.g. offline): fall back to whatever is already local.
            print(f"WARNING: git fetch failed, using local state: {fetch.stderr.strip()}")

        checkout = subprocess.run(['git', 'checkout', branch], capture_output=True, text=True)
        if checkout.returncode != 0:
            # Local branch missing: create it from the remote-tracking branch.
            checkout = subprocess.run(
                ['git', 'checkout', '-b', branch, f'origin/{branch}'],
                capture_output=True, text=True,
            )
        if checkout.returncode != 0:
            # Never pull in this state: `git pull origin <branch>` would merge
            # <branch> into whatever branch is currently checked out.
            raise RuntimeError(
                f"Could not check out branch {branch!r} in {full_path}: {checkout.stderr.strip()}"
            )

        pull = subprocess.run(['git', 'pull', 'origin', branch], capture_output=True, text=True)
        if pull.returncode != 0:
            print(f"WARNING: git pull failed, continuing with local {branch}: {pull.stderr.strip()}")
        else:
            print(f"‚úÖ Repository updated to latest {branch}")
    else:
        print(f"üì• Cloning repository from {repo_url} (branch: {branch})...")
        os.makedirs(target_path, exist_ok=True)
        os.chdir(target_path)

        # Clone with specific branch
        subprocess.run(['git', 'clone', '--branch', branch, '--single-branch', repo_url],
                       check=True)
        print(f"‚úÖ Repository cloned successfully")

    return full_path

if IN_COLAB:
    print("Running on Google Colab")
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Clone/update repo in Colab content directory (faster than Drive)
    REPO_BASE_PATH = '/content'
    REPO_PATH = clone_or_update_repo(REPO_URL, REPO_BRANCH, REPO_BASE_PATH)
    
    os.chdir(REPO_PATH)
    print(f"üìÇ Working directory set to: {os.getcwd()}")
    
    # Install packages (Colab only)
    print("üì¶ Installing dependencies...")
    !pip install -e .[dev]
    !pip install -r requirements.txt

else:
    print("Running Locally")
    current_dir = os.getcwd()
    
    # Check if we are already inside the repository
    if "AV-SSL-Optimization-JEPA" in current_dir:
        # Find the actual root (where .git or pyproject.toml exists)
        temp_dir = current_dir
        while temp_dir != "/" and not (os.path.exists(os.path.join(temp_dir, ".git")) or os.path.exists(os.path.join(temp_dir, "pyproject.toml"))):
            temp_dir = os.path.dirname(temp_dir)

        if os.path.exists(os.path.join(temp_dir, "pyproject.toml")):
            REPO_PATH = temp_dir
            print(f"‚úÖ Already in repository root: {REPO_PATH}")
        else:
            REPO_PATH = current_dir
    else:
        REPO_PATH = current_dir

    os.chdir(REPO_PATH)
    print(f"üìÇ Working directory set to: {os.getcwd()}")
    
    # Add project root to sys.path to find 'src' module
    if os.getcwd() not in sys.path:
        sys.path.append(os.getcwd())
        print("Added project root to sys.path")

Running on Google Colab
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
‚ö†Ô∏è Repo not found at /content/drive/MyDrive/AV-SSL-Optimization-JEPA. Please clone it to Drive first.


In [2]:
# 3. Load Configuration
import yaml
import torch
from src.jepa.data import JEPADataset, TubeletDataset, MaskTubelet
from src.jepa.models import JEPAModel
from src.jepa.training import Trainer
from torch.utils.data import DataLoader, random_split

# Start from the repository defaults, then layer Colab-specific overrides on top.
with open('configs/default.yaml', 'r') as f:
    config = yaml.safe_load(f)

config['training'].update(
    batch_size=8,  # Adjust based on GPU VRAM
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

print(f"Using device: {config['training']['device']}")

# --- Dynamic Checkpoint Directory ---
def get_next_run_dir(base_dir):
    """Return a run directory path under ``base_dir`` that does not exist yet.

    The first run of the day is ``<base_dir>/YYYY-MM-DD``. Later runs on the
    same day get ``_2``, ``_3``, ... appended. The directory is not created.
    """
    from datetime import datetime

    stem = os.path.join(base_dir, datetime.now().strftime('%Y-%m-%d'))
    candidate, suffix = stem, 1
    while os.path.exists(candidate):
        suffix += 1
        candidate = f"{stem}_{suffix}"
    return candidate

# Every run gets its own dated checkpoint directory under the configured base.
base_ckpt_dir = config['training'].get('checkpoint_dir', 'experiments/checkpoints')
run_ckpt_dir = get_next_run_dir(base_ckpt_dir)
os.makedirs(run_ckpt_dir, exist_ok=True)
config['training']['checkpoint_dir'] = run_ckpt_dir  # read later when building the Trainer
print(f"üöÄ Checkpoints will be saved to: {run_ckpt_dir}")


ModuleNotFoundError: No module named 'src'

In [None]:
# 4. Prepare Data
mask_transform = MaskTubelet(
    mask_ratio=config['data']['mask_ratio'],
    patch_size=config['data']['patch_size']
)

# Load full dataset
full_dataset = TubeletDataset(
    manifest_path=config['data']['manifest_path'],
    data_root=config['data'].get('data_root'),  # Handle relative paths
    tubelet_size=config['data']['tubelet_size'],
    transform=mask_transform
)

# Split train/val with a seeded generator. The split is then identical across
# runs and kernel restarts, so validation losses (and the "best" checkpoint
# choice) are comparable between runs. A top-level `seed` from the config is
# used when present, otherwise 42.
# NOTE(review): confirm the seed key name used by configs/default.yaml.
_seed_cfg = config.get('seed')
SPLIT_SEED = 42 if _seed_cfg is None else int(_seed_cfg)

train_size = int(config['data']['train_split'] * len(full_dataset))
val_size = len(full_dataset) - train_size
assert train_size > 0 and val_size > 0, (
    f"train_split={config['data']['train_split']} leaves an empty split "
    f"(train={train_size}, val={val_size}) for {len(full_dataset)} samples"
)
train_ds, val_ds = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(
    train_ds,
    batch_size=config['training']['batch_size'],
    shuffle=True,
    num_workers=2
)

val_loader = DataLoader(
    val_ds,
    batch_size=config['training']['batch_size'],
    shuffle=False,
    num_workers=2
)

print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")

Train samples: 134, Val samples: 34


In [None]:
# 5. Initialize Model
device = torch.device(config['training']['device'])

model = JEPAModel(
    encoder_name=config['model']['encoder_name'],
    predictor_hidden=config['model']['predictor']['hidden_dim'],
    predictor_dropout=config['model']['predictor']['dropout'],
    freeze_encoder=config['model']['freeze_encoder'],
)
model.to(device)

# Optimizer: only the predictor's parameters are handed to AdamW, so this loop
# never updates the encoder, whatever `freeze_encoder` says.
# lr / weight_decay are cast with float() because YAML may load values such as
# "1e-4" as strings.
optimizer = torch.optim.AdamW(
    params=model.predictor.parameters(),
    lr=float(config['training']['lr']),
    weight_decay=float(config['training']['weight_decay']),
)

Loading weights:   0%|          | 0/587 [00:00<?, ?it/s]

: 

In [None]:
# BENCHMARKING: Performance measurement utilities
import time
import json
import statistics
import threading
from collections import deque
from pathlib import Path

# BENCHMARKING: Configuration
BENCHMARK_CONFIG = dict(
    warmup_iterations=50,       # iterations ignored before measurement starts
    benchmark_iterations=300,   # iterations measured once warmup is over
    nvidia_smi_interval=0.25,   # seconds between nvidia-smi samples (if available)
    enable_nvidia_smi=True,     # optional GPU utilization / power sampling
)

# BENCHMARKING: Mutable global state shared by the benchmarking helpers
benchmark_state = dict(
    step_times=deque(maxlen=BENCHMARK_CONFIG['benchmark_iterations']),  # ms per iteration
    throughput=deque(maxlen=BENCHMARK_CONFIG['benchmark_iterations']),  # items per second
    gpu_samples=[],
    nvidia_smi_thread=None,
    nvidia_smi_running=False,
    current_iteration=0,
    benchmark_started=False,
    peak_memory_mb=0.0,
)

def reset_memory_stats():
    """BENCHMARKING: Clear CUDA peak-memory counters and release cached allocator blocks.

    Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

def get_peak_memory_mb():
    """BENCHMARKING: Peak CUDA memory allocated since the last reset, in MB (2**20 bytes).

    Returns 0.0 when CUDA is unavailable.
    """
    bytes_per_mb = 1024 ** 2
    return torch.cuda.max_memory_allocated() / bytes_per_mb if torch.cuda.is_available() else 0.0

def sample_nvidia_smi():
    """BENCHMARKING: Take one nvidia-smi reading of GPU utilization and power draw.

    nvidia-smi prints one CSV line per GPU. Only the first line is parsed, so
    multi-GPU hosts still yield samples instead of failing the parse every
    time. nvidia-smi's GPU order is not necessarily CUDA's device order.

    Returns:
        ``{'utilization_gpu': float, 'power_draw': float, 'timestamp': float}``.
        Utilization is a percentage and power is in watts, both as printed by
        nvidia-smi with ``nounits``. The timestamp comes from
        ``time.perf_counter()``.

        Returns ``None`` in any of these cases:
        - sampling is disabled;
        - nvidia-smi is missing, not executable, fails or times out;
        - a field is not numeric (e.g. ``[N/A]``).
    """
    if not BENCHMARK_CONFIG['enable_nvidia_smi']:
        return None

    import subprocess

    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=utilization.gpu,power.draw', '--format=csv,noheader,nounits'],
            capture_output=True,
            text=True,
            timeout=1.0
        )
        if result.returncode != 0:
            return None
        lines = result.stdout.strip().splitlines()
        if not lines:
            return None
        parts = [field.strip() for field in lines[0].split(',')]
        if len(parts) != 2:
            return None
        return {
            'utilization_gpu': float(parts[0]),
            'power_draw': float(parts[1]),
            'timestamp': time.perf_counter()
        }
    except (OSError, subprocess.TimeoutExpired, ValueError):
        # Best effort: the reading is simply skipped. This covers a missing or
        # non-executable binary (OSError, incl. FileNotFoundError/PermissionError),
        # a hung driver, or unparsable output.
        return None

def nvidia_smi_sampler():
    """BENCHMARKING: Thread body that polls nvidia-smi until told to stop.

    Each pass runs while ``benchmark_state['nvidia_smi_running']`` is true. It
    appends any non-empty ``sample_nvidia_smi()`` reading to
    ``benchmark_state['gpu_samples']`` and then sleeps
    ``BENCHMARK_CONFIG['nvidia_smi_interval']`` seconds.
    """
    while True:
        if not benchmark_state['nvidia_smi_running']:
            return
        reading = sample_nvidia_smi()
        if reading:
            benchmark_state['gpu_samples'].append(reading)
        time.sleep(BENCHMARK_CONFIG['nvidia_smi_interval'])

def start_nvidia_smi_sampling():
    """BENCHMARKING: Launch the daemon thread that runs ``nvidia_smi_sampler``.

    This is a no-op unless sampling is enabled in ``BENCHMARK_CONFIG`` and CUDA
    is available. Any previously collected ``gpu_samples`` are discarded first.
    """
    if not BENCHMARK_CONFIG['enable_nvidia_smi']:
        return
    if not torch.cuda.is_available():
        return
    benchmark_state['nvidia_smi_running'] = True
    benchmark_state['gpu_samples'] = []
    sampler = threading.Thread(target=nvidia_smi_sampler, daemon=True)
    benchmark_state['nvidia_smi_thread'] = sampler
    sampler.start()

def stop_nvidia_smi_sampling():
    """BENCHMARKING: Tell the sampler thread to exit and wait up to 1 s for it.

    Safe to call repeatedly, and safe when no thread was ever started.
    """
    benchmark_state['nvidia_smi_running'] = False
    sampler = benchmark_state['nvidia_smi_thread']
    if not sampler:
        return
    sampler.join(timeout=1.0)

def benchmarked_train_epoch(trainer, loader, epoch, batch_size):
    """
    BENCHMARKING: Training epoch with per-iteration benchmarking.

    Replicates the forward/backward/step logic of ``trainer.train_epoch()`` and
    times each iteration. Iterations are counted globally in
    ``benchmark_state['current_iteration']``. After
    ``BENCHMARK_CONFIG['warmup_iterations']`` of them, the step time,
    throughput and peak CUDA memory are recorded until
    ``benchmark_iterations`` step times exist.

    nvidia-smi sampling starts with the first measured iteration, so warmup is
    excluded. It stops as soon as the window is full. Samples taken while
    validation runs between epochs inside the window are still included.

    Args:
        trainer: Exposes ``model`` (with ``encoder``/``predictor``),
            ``optimizer`` and ``device``.
        loader: Iterable of batches with ``masked_frames``, ``clean_frames``
            and ``mask_frac`` tensors.
        epoch: Epoch index. Benchmark state is initialised on epoch 0.
        batch_size: Kept for backward compatibility. Throughput now counts the
            real number of items in each batch (the dim-0 size produced by
            DataLoader collation). The final batch of an epoch can be smaller
            than the nominal size.

    Returns:
        Mean training loss over the epoch's batches, or NaN if the loader
        yielded no batches.
    """
    from src.jepa.training.losses import jepa_loss
    from tqdm import tqdm

    warmup = BENCHMARK_CONFIG['warmup_iterations']
    window = BENCHMARK_CONFIG['benchmark_iterations']
    use_cuda = torch.cuda.is_available()

    trainer.model.train()
    trainer.model.encoder.eval()
    trainer.model.predictor.train()

    total_loss = 0.0
    num_batches = 0

    # BENCHMARKING: Reset memory stats before training (only once, at start)
    if epoch == 0 and not benchmark_state['benchmark_started']:
        reset_memory_stats()
        benchmark_state['benchmark_started'] = True
        benchmark_state['current_iteration'] = 0  # Initialize global counter

    pbar = tqdm(loader, desc=f"Train Epoch {epoch}")
    for batch in pbar:
        # BENCHMARKING: start GPU sampling with the first measured iteration.
        # current_iteration only increases, so this fires at most once.
        if benchmark_state['current_iteration'] == warmup:
            start_nvidia_smi_sampling()

        # BENCHMARKING: drain pending GPU work *before* starting the clock, so
        # the timed interval covers only this iteration's work.
        if use_cuda:
            torch.cuda.synchronize()
        iter_start = time.perf_counter()

        # Move data to device
        masked = batch["masked_frames"].to(trainer.device)
        clean = batch["clean_frames"].to(trainer.device)
        mask_frac = batch["mask_frac"].to(trainer.device)

        # Forward pass
        clean_emb, pred_emb = trainer.model(clean, masked, mask_frac)

        # Calculate loss
        loss = jepa_loss(pred_emb, clean_emb, normalize=True)

        # Backward pass
        trainer.optimizer.zero_grad()
        loss.backward()
        trainer.optimizer.step()

        # BENCHMARKING: CUDA kernels run asynchronously; wait for them to finish
        # before stopping the clock.
        if use_cuda:
            torch.cuda.synchronize()
        iter_time_ms = (time.perf_counter() - iter_start) * 1000.0

        # Update metrics
        loss_val = loss.item()
        total_loss += loss_val
        num_batches += 1
        pbar.set_postfix({"loss": f"{loss_val:.4f}"})

        # BENCHMARKING: Collect metrics after warmup, until the window is full
        step_times = benchmark_state['step_times']
        if benchmark_state['current_iteration'] >= warmup and len(step_times) < window:
            step_times.append(iter_time_ms)
            items_in_batch = clean.shape[0]
            images_per_sec = (items_in_batch * 1000.0) / iter_time_ms if iter_time_ms > 0 else 0.0
            benchmark_state['throughput'].append(images_per_sec)

            # Update peak memory
            benchmark_state['peak_memory_mb'] = max(benchmark_state['peak_memory_mb'], get_peak_memory_mb())

            # Stop sampling the moment the window fills, not at epoch end, so
            # later iterations/validation don't leak into the GPU statistics.
            if len(step_times) >= window:
                stop_nvidia_smi_sampling()

        benchmark_state['current_iteration'] += 1

    return total_loss / num_batches if num_batches else float('nan')

def print_benchmark_summary():
    """BENCHMARKING: Print a concise summary of the measured benchmark window.

    Reads ``benchmark_state``: step times in ms, throughput in items/s, peak
    memory in MB, and optional nvidia-smi samples. Also reads
    ``BENCHMARK_CONFIG``. Prints a warning and returns early when no step
    times were recorded.

    The energy estimate is mean power draw (W) x mean step time (s), which
    gives joules per iteration. It averages every nvidia-smi sample collected,
    so any non-training time covered by the sampler is included.
    """
    step_times = list(benchmark_state['step_times'])
    if not step_times:
        print("\n‚ö†Ô∏è  BENCHMARKING: No benchmark data collected yet.")
        return
    throughputs = list(benchmark_state['throughput'])

    mean_step_time = statistics.mean(step_times)
    std_step_time = statistics.stdev(step_times) if len(step_times) > 1 else 0.0

    mean_throughput = statistics.mean(throughputs)
    std_throughput = statistics.stdev(throughputs) if len(throughputs) > 1 else 0.0

    print("\n" + "="*60)
    print("BENCHMARKING: Upstream Training Performance Summary")
    print("="*60)
    print(f"Benchmark Window: {len(step_times)} iterations")
    print(f"Warmup Iterations: {BENCHMARK_CONFIG['warmup_iterations']}")
    print()
    print("Throughput:")
    print(f"  Images/sec: {mean_throughput:.2f} ¬± {std_throughput:.2f}")
    print(f"  Clips/sec:  {mean_throughput:.2f} ¬± {std_throughput:.2f}  (assuming 1 clip = 1 image)")
    print()
    print("Step Time:")
    print(f"  Average: {mean_step_time:.2f} ¬± {std_step_time:.2f} ms/iter")
    print()
    print("GPU Memory:")
    print(f"  Peak Allocated VRAM: {benchmark_state['peak_memory_mb']:.2f} MB")

    # BENCHMARKING: Optional nvidia-smi summary
    if benchmark_state['gpu_samples']:
        utilizations = [s['utilization_gpu'] for s in benchmark_state['gpu_samples']]
        powers = [s['power_draw'] for s in benchmark_state['gpu_samples']]
        mean_util = statistics.mean(utilizations)
        mean_power = statistics.mean(powers)

        # Rough energy estimate: W x s = J. Step time is in ms, hence the /1000.
        energy_per_iter = mean_power * (mean_step_time / 1000.0)

        print()
        print("GPU Utilization (nvidia-smi):")
        print(f"  Average GPU Utilization: {mean_util:.1f}%")
        print(f"  Average Power Draw: {mean_power:.2f} W")
        print(f"  Estimated Energy per Iteration: {energy_per_iter:.4f} J")

    print("="*60)

def save_benchmark_results(output_dir):
    """BENCHMARKING: Write the benchmark window to ``<output_dir>/benchmark_results.json``.

    Does nothing when no step times have been recorded. Creates ``output_dir``
    (and its parents) when missing.

    The JSON holds:
    - ``config``: ``BENCHMARK_CONFIG``.
    - ``summary``: means and stds, peak memory and the iteration count.
    - ``raw_data``: per-iteration values.

    When nvidia-smi samples exist, ``summary`` also gets mean utilization,
    mean power and an energy-per-iteration estimate in joules (W x s), and the
    raw samples are saved too.
    """
    step_times = list(benchmark_state['step_times'])
    if not step_times:
        return
    throughputs = list(benchmark_state['throughput'])
    mean_step_time = statistics.mean(step_times)

    results = {
        'config': BENCHMARK_CONFIG,
        'summary': {
            'mean_step_time_ms': mean_step_time,
            'std_step_time_ms': statistics.stdev(step_times) if len(step_times) > 1 else 0.0,
            'mean_throughput_images_per_sec': statistics.mean(throughputs),
            'std_throughput_images_per_sec': statistics.stdev(throughputs) if len(throughputs) > 1 else 0.0,
            'peak_memory_mb': benchmark_state['peak_memory_mb'],
            'benchmark_iterations': len(step_times),
        },
        'raw_data': {
            'step_times_ms': step_times,
            'throughput_images_per_sec': throughputs,
        }
    }

    # Add nvidia-smi data if available
    if benchmark_state['gpu_samples']:
        utilizations = [s['utilization_gpu'] for s in benchmark_state['gpu_samples']]
        powers = [s['power_draw'] for s in benchmark_state['gpu_samples']]
        mean_power = statistics.mean(powers)
        results['summary']['mean_gpu_utilization_percent'] = statistics.mean(utilizations)
        results['summary']['mean_power_draw_watts'] = mean_power
        # W x s = J; step time is stored in ms, hence the /1000.
        results['summary']['estimated_energy_per_iter_joules'] = mean_power * mean_step_time / 1000.0
        results['raw_data']['gpu_samples'] = benchmark_state['gpu_samples']

    output_path = Path(output_dir) / 'benchmark_results.json'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nüíæ BENCHMARKING: Results saved to {output_path}")

# Confirms the benchmarking helpers above are defined in this kernel.
print("‚úÖ BENCHMARKING:", "Performance measurement utilities loaded")

In [None]:
# 6. Training Loop
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    device=device,
    checkpoint_dir=config['training']['checkpoint_dir']
)

num_epochs = config['training']['epochs']
best_loss = float('inf')
batch_size = config['training']['batch_size']

# BENCHMARKING: Use benchmarked training function for accurate measurements.
# The try/finally always stops the background nvidia-smi sampler. Without it,
# the sampler only stops once the benchmark window is full, so a short run
# (fewer iterations than warmup + window) or an interrupted cell would leave
# it polling nvidia-smi for the rest of the kernel's life.
try:
    for epoch in range(num_epochs):
        # BENCHMARKING: Use benchmarked_train_epoch instead of trainer.train_epoch
        # This replicates the exact same training logic but with performance instrumentation
        train_loss = benchmarked_train_epoch(trainer, train_loader, epoch, batch_size)

        # Validate (no benchmarking on validation)
        val_loss = trainer.validate_epoch(val_loader, epoch)

        # Checkpoint
        is_best = val_loss < best_loss
        if is_best:
            best_loss = val_loss

        if (epoch + 1) % config['training']['checkpoint_every'] == 0 or is_best:
            trainer.save_checkpoint(epoch, val_loss, is_best, config)

        # BENCHMARKING: Print summary and save results when benchmark window is complete
        if len(benchmark_state['step_times']) >= BENCHMARK_CONFIG['benchmark_iterations']:
            if not benchmark_state.get('_summary_printed', False):
                print_benchmark_summary()
                save_benchmark_results(config['training']['checkpoint_dir'])
                benchmark_state['_summary_printed'] = True
finally:
    stop_nvidia_smi_sampling()

# BENCHMARKING: Print final summary if training completes before benchmark window
if len(benchmark_state['step_times']) > 0 and not benchmark_state.get('_summary_printed', False):
    print_benchmark_summary()
    save_benchmark_results(config['training']['checkpoint_dir'])

Train Epoch 0:  12%|‚ñà‚ñè        | 2/17 [00:10<01:11,  4.74s/it, loss=0.0288]