# VideoMAE 2D Training Notebook

This notebook trains a 2D MAE model using an HDF5 dataset:
- Load an HDF5 file from a path
- Split trials into train/val using `split_data`
- Create datasets via `create_dataset('vsd_mae', ...)` and DataLoaders
- Build a 2D MAE (ResNet18 backbone + lightweight decoder)
- Train and log metrics to TensorBoard



In [1]:
# Imports and configuration
import os
import torch
from torch.utils.data import DataLoader

from src.data.dataset_factory import create_dataset
from src.data.split_data import split_data
from src.models.backbone.mae_backbone_2d import MAEResNet18Backbone
from src.models.heads.mae_decoder_2d import MAEDecoder2D
from src.models.systems.mae_system import MAESystem
from src.utils.logger import TBLogger, set_seed

# --- User config ---
HDF5_PATH = r"G:\My Drive\HDF5_DATA_AFTER_PREPROCESSING2\vsd_video_data.hdf5"
LOG_DIR = "logs"
CKPT_DIR = "checkpoints"
SEED = 42

# Data params
CLIP_LENGTH = 1           # 1 -> 2D (single frame); >1 creates clips but MAE2D squeezes T=1
BATCH_SIZE = 8
NUM_WORKERS = 2
SPLIT_RATIO = 0.8

# Masking params for dataset
MASK_RATIO = 0.75
PATCH_SIZE = (1, 16, 16)  # (T, H, W); T=1 for 2D

# Training params
EPOCHS = 5
LR = 1e-4
WEIGHT_DECAY = 0.05

# Setup
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)
set_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")


Using device: cpu


In [None]:
# Split data into train/val by trials
train_trials, val_trials, index_entries = split_data(HDF5_PATH, split_ratio=SPLIT_RATIO, random_seed=SEED)
print(f"Train trials: {len(train_trials)} | Val trials: {len(val_trials)}")

Global data split summary:
  Total trials: 212
  Train trials: 169 (79.7%)
  Val trials: 43 (20.3%)
Train trials: 169 | Val trials: 43


In [3]:
# Build datasets and dataloaders
train_ds = create_dataset(
    "vsd_mae",
    hdf5_path=HDF5_PATH,
    clip_length=CLIP_LENGTH,
    trial_indices=train_trials,
    index_entries=index_entries,
    normalize=False,
    mask_ratio=MASK_RATIO,
    patch_size=PATCH_SIZE,
)

val_ds = create_dataset(
    "vsd_mae",
    hdf5_path=HDF5_PATH,
    clip_length=CLIP_LENGTH,
    trial_indices=val_trials,
    index_entries=index_entries,
    normalize=False,
    mask_ratio=MASK_RATIO,
    patch_size=PATCH_SIZE,
)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)}")


Train samples: 43095 | Val samples: 10965


In [4]:
# Build MAE 2D model and optimizer
encoder = MAEResNet18Backbone(pretrained=False, in_channels=1)
decoder = MAEDecoder2D(in_channels=encoder.feature_dim, out_channels=1, hidden_dim=256)

config = {
    "training": {"lr": LR, "weight_decay": WEIGHT_DECAY},
    "loss": {"normalize": True},
}

model = MAESystem(encoder=encoder, decoder=decoder, config=config).to(DEVICE)
optimizer = model.get_optimizer()
logger = TBLogger(log_dir=LOG_DIR)

print(model.__class__.__name__, "built.")


MAESystem built.


In [None]:
# Training loop with TensorBoard logging and simple validation
from tqdm import tqdm

global_step = 0
scaler = torch.cuda.amp.GradScaler()

for epoch in range(EPOCHS):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch in pbar:
        # Move batch tensors to device and match MAESystem input keys
        batch = {
            "video_masked": batch.get("video_masked", torch.zeros(1)),
            "video_target": batch.get("video_target", torch.zeros(1)),
            "mask": batch.get("mask", torch.zeros(1)),
        }
        batch = {k: v.to(DEVICE) for k, v in batch.items()}

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            out = model(batch)
            loss = out["loss"]
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Logs
        pbar.set_postfix({"loss": float(loss.item())})
        logger.log_scalar("train/loss", float(loss.item()), global_step)
        for mname, mval in out.get("metrics", {}).items():
            logger.log_scalar(f"train/{mname}", mval, global_step)
        global_step += 1

    # Simple validation (average metrics over one pass)
    model.eval()
    with torch.no_grad():
        val_losses = []
        val_metrics = {}
        for i, batch in enumerate(val_loader):
            batch = {
                "video_masked": batch.get("video_masked", torch.zeros(1)),
                "video_target": batch.get("video_target", torch.zeros(1)),
                "mask": batch.get("mask", torch.zeros(1)),
            }
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            out = model(batch)
            val_losses.append(out["loss"].item())
            for mname, mval in out.get("metrics", {}).items():
                val_metrics.setdefault(mname, []).append(mval)
        if val_losses:
            logger.log_scalar("val/loss", sum(val_losses)/len(val_losses), epoch)
            for mname, arr in val_metrics.items():
                logger.log_scalar(f"val/{mname}", sum(arr)/len(arr), epoch)

    # Save checkpoint each epoch
    ckpt_path = os.path.join(CKPT_DIR, f"mae2d_epoch_{epoch+1}.pt")
    torch.save({
        "encoder": model.encoder.state_dict(),
        "decoder": model.decoder.state_dict(),
        "config": config,
        "epoch": epoch+1,
    }, ckpt_path)
    print(f"Saved checkpoint: {ckpt_path}")

print("Training complete.")


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
Epoch 1/5:   1%|          | 44/5387 [04:04<7:24:11,  4.99s/it, loss=nan] 