# DCASE 2022 VQ-VAE + PixelSNAIL — Training on Colab

Runs the project's training pipeline on Google Colab with **more training epochs**.

**Before running:**
1. Upload the DCASE2020 Task 2 dev dataset to Colab (or mount Drive and set path).
2. Upload this project (zip → Colab, then unzip) or clone from Git; set `PROJECT_ROOT` below.
3. Set `DATA_ROOT` and optional epoch overrides in the config cell.
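
After uploading, it can help to confirm that the audio files actually arrived before starting a long run. The following is a minimal sketch, not part of the project: `count_wavs_per_subdir` is a hypothetical helper that assumes the dataset root contains one subdirectory per machine type with `.wav` files somewhere beneath it.

```python
from pathlib import Path


def count_wavs_per_subdir(data_root):
    """Return {subdirectory name: number of .wav files found recursively}.

    Hypothetical helper; assumes one top-level folder per machine type.
    """
    root = Path(data_root)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset root not found: {root}")
    counts = {}
    for sub in sorted(p for p in root.iterdir() if p.is_dir()):
        # rglob walks train/ and test/ (or any other nesting) below each folder
        counts[sub.name] = sum(1 for _ in sub.rglob("*.wav"))
    return counts


# Example (run after DATA_ROOT is set in the config cell):
# print(count_wavs_per_subdir(DATA_ROOT))
```

A folder reporting zero files usually points to an incomplete upload or a wrong path.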

## 1. Mount Google Drive (optional)
If your dataset and/or project live on Drive, mount it first.

In [None]:
from google.colab import drive
drive.mount("/content/drive")

## 2. Project and dependencies
Set `PROJECT_ROOT` to the repo location (e.g. after unzipping or cloning). The cell below adds the project to `sys.path`, changes the working directory to it, and installs dependencies.

In [None]:
import sys
import os

# Path to the project root (where configs/ and src/ live)
PROJECT_ROOT = "/content/dcase-2022-vq-vae-ar"  # or "/content/drive/MyDrive/dcase-2022-vq-vae-ar"

if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

# Install dependencies (Colab usually has torch/numpy/sklearn; add PyYAML if missing)
!pip install -q PyYAML
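
If `PROJECT_ROOT` points one level too high or too low (a common result of unzipping), the later import of `src.main` fails with a less obvious error. A small sketch of a layout check, assuming only what the comment above states, namely that `configs/` and `src/` live directly under the project root; `missing_project_dirs` is a hypothetical helper name.

```python
import os


def missing_project_dirs(project_root, required=("configs", "src")):
    """Return the names of required subdirectories that are absent under project_root."""
    return [
        name
        for name in required
        if not os.path.isdir(os.path.join(project_root, name))
    ]


# Example (run after PROJECT_ROOT is set):
# missing = missing_project_dirs(PROJECT_ROOT)
# if missing:
#     raise FileNotFoundError(f"{PROJECT_ROOT} is missing: {missing}")
```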

## 3. Config and paths
Set the dataset root and, if you want, override epochs (and checkpoint paths) for longer training.

In [None]:
# Dataset path on Colab
DATA_ROOT = "/content/dcase2020-task2-dev-dataset"  # or "/content/drive/MyDrive/..."

# Optional: point this at Drive so checkpoints persist after the runtime disconnects
CHECKPOINT_DIR = "/content/checkpoints"  # or "/content/drive/MyDrive/checkpoints"
LOG_DIR = "/content/logs"

# Overrides merged into the base config: more epochs, custom checkpoint and log paths
OVERRIDES = {
    "data": {"root_dir": DATA_ROOT},
    "phase1": {
        "num_epochs": 50,
        "checkpoint": f"{CHECKPOINT_DIR}/mobilenetv2_8x_vqvae.pth",
    },
    "phase2": {
        "num_epochs": 80,
        "checkpoint": f"{CHECKPOINT_DIR}/pixelsnail_prior.pth",
    },
    "eval": {
        "vqvae_checkpoint": f"{CHECKPOINT_DIR}/mobilenetv2_8x_vqvae.pth",
        "prior_checkpoint": f"{CHECKPOINT_DIR}/pixelsnail_prior.pth",
    },
    "logging": {"log_dir": LOG_DIR},
}

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
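
The overrides are nested, so they only make sense if they are merged key by key rather than replacing whole sections. How `src.main` does this is not shown here; the sketch below illustrates one plausible recursive merge. `deep_merge` and the base values (including `lr`) are hypothetical and chosen only for illustration.

```python
import copy


def deep_merge(base, overrides):
    """Return a new dict where nested keys in overrides replace those in base."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge recursively so sibling keys survive
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical base config, standing in for the contents of a YAML file
base = {"phase1": {"num_epochs": 10, "lr": 1e-3}, "data": {"root_dir": "/tmp"}}
merged = deep_merge(base, {"phase1": {"num_epochs": 50}})
print(merged["phase1"])
```

With a merge of this kind, overriding `num_epochs` leaves the other keys in `phase1` untouched, and the base dict itself is not modified.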

## 4. Run training
Uses `configs/colab.yaml` as the base config and applies the overrides above. Log output goes to both the console and a log file.

In [None]:
from src.main import run

run(
    config_path=os.path.join(PROJECT_ROOT, "configs", "colab.yaml"),
    overrides=OVERRIDES,
    mode="train",
    log_dir=LOG_DIR,
)
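
Once training finishes, it is worth confirming that the checkpoint files exist before the Colab runtime is recycled. The sketch below assumes that `run` writes to the checkpoint paths given in the overrides, which this notebook does not confirm; `checkpoint_report` is a hypothetical helper.

```python
import os


def checkpoint_report(paths):
    """Map each path to its size in bytes, or None if the file does not exist."""
    return {p: (os.path.getsize(p) if os.path.isfile(p) else None) for p in paths}


# Example (run after training):
# print(checkpoint_report([
#     OVERRIDES["phase1"]["checkpoint"],
#     OVERRIDES["phase2"]["checkpoint"],
# ]))
```

A `None` entry means the file was not found at the expected path.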