# DCASE 2022 VQ-VAE + PixelSNAIL — Training on Colab

Runs the project training pipeline on Google Colab with **more epochs**: 50 for phase 1 (VQ-VAE) and 80 for phase 2 (PixelSNAIL prior).

**Setup:** Run cells **1 → 2 → 3 → 4 → 5 in order** (Mount Drive → Clone repo → Project & deps → Config → Run). Set `DATA_ROOT` in the config cell to your dataset path. **Full step-by-step guide:** see `notebooks/COLAB_TRAINING_GUIDE.md` in the repo.

## 1. Mount Google Drive
Required so we can clone the project under `/content/drive/MyDrive/` and (optionally) keep dataset and checkpoints on Drive.

In [1]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive
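
Before cloning, it can help to confirm that the mount is actually visible to the kernel. A minimal sketch, assuming the default Colab mount point used above; `drive_is_mounted` is a hypothetical helper, not part of the project:

```python
import os

def drive_is_mounted(mount_point="/content/drive"):
    """Return True if a `MyDrive` folder is visible under the mount point."""
    return os.path.isdir(os.path.join(mount_point, "MyDrive"))

print("Drive mounted:", drive_is_mounted())
```

If this prints `False`, re-run the mount cell before continuing.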


## 2. Clone project into Drive
Clone the repo under your Drive so the Colab kernel has the full project (including `src/`). If the folder already exists, we pull the latest; otherwise we clone. Change `REPO_DIR` if you want a different path.

In [2]:
import os

# Clone into this folder under your Drive (My Drive = /content/drive/MyDrive)
REPO_DIR = "/content/drive/MyDrive/semcom_asd_vqar"
REPO_URL = "https://github.com/raidantimosquitos/semcom_asd_vqar.git"

if os.path.isdir(REPO_DIR):
    !cd "{REPO_DIR}" && git pull
else:
    os.makedirs(os.path.dirname(REPO_DIR), exist_ok=True)
    !git clone "{REPO_URL}" "{REPO_DIR}"

remote: Enumerating objects: 16, done.
remote: Counting objects: 100% (16/16), done.
remote: Compressing objects: 100% (5/5), done.
remote: Total 9 (delta 4), reused 9 (delta 4), pack-reused 0 (from 0)
Unpacking objects: 100% (9/9), 9.77 KiB | 15.00 KiB/s, done.
From https://github.com/raidantimosquitos/semcom_asd_vqar
   86afe77..6b436c3  main       -> origin/main
Updating 86afe77..6b436c3
Fast-forward
 notebooks/COLAB_TRAINING_GUIDE.md | 255 ++++++++++++++++++++++++++++++++++++++
 notebooks/README.md               |  23 ++--
 notebooks/colab_train.ipynb       | 139 ++++++++++++++++-----
 src/engine/train.py               |   8 +-
 4 files changed, 377 insertions(+), 48 deletions(-)
 create mode 100644 notebooks/COLAB_TRAINING_GUIDE.md
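
The clone-or-pull logic can also be written in plain Python instead of `!` shell magics, so that a failing `git` command raises an error rather than passing silently. This is a sketch under assumptions, not the notebook's own method: `sync_command` and `sync_repo` are hypothetical helpers, and unlike the cell above the sketch checks for a `.git` folder rather than for the directory alone.

```python
import os
import subprocess

def sync_command(repo_dir, repo_url):
    """Build the git command: pull if a clone exists, otherwise clone."""
    if os.path.isdir(os.path.join(repo_dir, ".git")):
        return ["git", "-C", repo_dir, "pull", "--ff-only"]
    return ["git", "clone", repo_url, repo_dir]

def sync_repo(repo_dir, repo_url):
    """Run the command; raises CalledProcessError if git exits non-zero."""
    parent = os.path.dirname(repo_dir)
    if parent:
        os.makedirs(parent, exist_ok=True)
    subprocess.run(sync_command(repo_dir, repo_url), check=True)

# Show which command would run for the paths used in this notebook.
print(sync_command(
    "/content/drive/MyDrive/semcom_asd_vqar",
    "https://github.com/raidantimosquitos/semcom_asd_vqar.git",
))
```

Calling `sync_repo(REPO_DIR, REPO_URL)` would then replace the `if`/`else` block in the cell.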


## 3. Project and dependencies
**Run this cell before any training.** Sets `PROJECT_ROOT` from the clone path above (`REPO_DIR`), adds it to `sys.path`, changes the working directory to it, and installs dependencies so `src` can be imported.

In [3]:
import sys
import os

# Use the clone path from the cell above (or set manually if you cloned elsewhere)
try:
    PROJECT_ROOT = REPO_DIR
except NameError:
    PROJECT_ROOT = "/content/drive/MyDrive/semcom_asd_vqar"

if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

# Install dependencies (Colab usually has torch/numpy/sklearn; add PyYAML if missing)
!pip install -q PyYAML
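
To see whether anything else is missing before training starts, the environment can be probed without importing the packages. A minimal sketch: `missing_modules` is a hypothetical helper, and the list of import names is an assumption taken from the install comment above, not from the project's actual requirements.

```python
import importlib.util

def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Assumed import names: PyYAML is imported as `yaml`, scikit-learn as `sklearn`.
required = ["torch", "numpy", "sklearn", "yaml"]
print("Missing:", missing_modules(required))
```

An empty list means every listed module can be located; any name that appears would need a `pip install` of the corresponding package.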

## 4. Config and paths
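Set `DATA_ROOT` to your dataset (e.g. on Drive: `/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset`). `CHECKPOINT_DIR` and `LOG_DIR` are relative paths; because cell 3 changes the working directory to `PROJECT_ROOT`, checkpoints and logs end up under the cloned project on Drive.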
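
A quick check that the dataset path exists and contains audio can save a failed run. A minimal sketch, assuming the dataset consists of `.wav` files somewhere below the root; `check_data_root` is a hypothetical helper and makes no assumption about the folder layout the project's loader expects.

```python
import os

def check_data_root(path):
    """Return (exists, n_wav): whether `path` is a directory and how many .wav files lie beneath it."""
    if not os.path.isdir(path):
        return False, 0
    n_wav = sum(
        name.lower().endswith(".wav")
        for _, _, files in os.walk(path)
        for name in files
    )
    return True, n_wav

print(check_data_root("/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset"))
```

Walking a large folder on Drive can take a while, so this is best run once after changing the path.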

In [4]:
# Dataset path: on Colab use Drive path after mount, e.g. "/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset"
DATA_ROOT = "/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset"

# Checkpoints and logs: relative to the working directory (PROJECT_ROOT after cell 3), i.e. under the cloned project on Drive
CHECKPOINT_DIR = "./checkpoints"
LOG_DIR = "./logs"

# Overrides: more epochs per phase (merged into the config)
OVERRIDES = {
    "data": {"root_dir": DATA_ROOT},
    "phase1": {
        "num_epochs": 50,
        "checkpoint": f"{CHECKPOINT_DIR}/mobilenetv2_8x_vqvae.pth",
    },
    "phase2": {
        "num_epochs": 80,
        "checkpoint": f"{CHECKPOINT_DIR}/pixelsnail_prior.pth",
    },
    "eval": {
        "vqvae_checkpoint": f"{CHECKPOINT_DIR}/mobilenetv2_8x_vqvae.pth",
        "prior_checkpoint": f"{CHECKPOINT_DIR}/pixelsnail_prior.pth",
    },
    "logging": {"log_dir": LOG_DIR},
}

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
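
How the overrides combine with the YAML config depends on the project's `run` function, which is not shown here. One common approach is a recursive merge, where nested keys in the override replace only the matching keys in the base. The sketch below illustrates that idea; `deep_merge` and the base values are hypothetical and are not taken from the repo.

```python
import copy

def deep_merge(base, overrides):
    """Return a copy of `base` with `overrides` merged in recursively."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged

# Hypothetical base values, standing in for what configs/colab.yaml might contain.
base = {"phase1": {"num_epochs": 10, "lr": 0.001}, "device": "cuda"}
merged = deep_merge(base, {"phase1": {"num_epochs": 50}})
print(merged)  # {'phase1': {'num_epochs': 50, 'lr': 0.001}, 'device': 'cuda'}
```

With this kind of merge, overriding `num_epochs` leaves sibling keys such as `lr` untouched, and the base dictionary itself is not modified.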

## 5. Run training
**Run cells 1–4 first.** Uses `configs/colab.yaml` and the overrides above. Log messages are written to the console and to a file under `LOG_DIR`.

In [None]:
# Sanity check: run cells 1–4 first (Mount, Clone, Dependencies, Config)
assert "PROJECT_ROOT" in globals(), "Run the 'Project and dependencies' cell (section 3) first."
assert "OVERRIDES" in globals(), "Run the 'Config and paths' cell (section 4) first."
from src.main import run

run(
    config_path=os.path.join(PROJECT_ROOT, "configs", "colab.yaml"),
    overrides=OVERRIDES,
    mode="train",
    log_dir=LOG_DIR,
)

2026-02-18 04:00:25 | INFO | Config: {'mode': 'train', 'data': {'root_dir': '/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset', 'appliance': 'fan', 'test_size': 0.1, 'batch_size': 32, 'random_state': 42, 'max_samples_stats': None}, 'device': 'cuda', 'phase1': {'num_epochs': 50, 'lr': 0.001, 'checkpoint': './checkpoints/mobilenetv2_8x_vqvae.pth'}, 'phase2': {'num_epochs': 80, 'lr': 0.0001, 'checkpoint': './checkpoints/pixelsnail_prior.pth'}, 'eval': {'vqvae_checkpoint': './checkpoints/mobilenetv2_8x_vqvae.pth', 'prior_checkpoint': './checkpoints/pixelsnail_prior.pth'}, 'logging': {'log_dir': './logs', 'name': 'main'}}
2026-02-18 04:00:25 | INFO | Mode: train
2026-02-18 04:00:25 | INFO | Device: cuda


2026-02-18 04:34:06 | INFO | Data loading
2026-02-18 04:34:06 | INFO |   root_dir: /content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset
2026-02-18 04:34:06 | INFO |   appliance: fan | mode: train
2026-02-18 04:34:06 | INFO |   train_size: 3307 | val_size: 368
2026-02-18 04:34:06 | INFO |   total_samples (before split): 3675
2026-02-18 04:34:06 | INFO |   batch_size: 32 | num_workers: 0
2026-02-18 04:34:06 | INFO |   train_batches: 103 | val_batches: 11
Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 134MB/s]
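
The split and batch counts in the data-loading log can be reproduced with a little arithmetic. The sketch below uses only numbers from the log; the rounding rules are assumptions chosen because they match the logged values (validation size rounded up, final incomplete batch not counted), not something confirmed from the project code.

```python
import math

total, test_size, batch_size = 3675, 0.1, 32  # values from the log

n_val = math.ceil(total * test_size)   # 367.5 rounds up to 368
n_train = total - n_val                # 3675 - 368 = 3307

# Floor division matches the logged batch counts, which suggests (an assumption)
# that the last incomplete batch is dropped.
train_batches = n_train // batch_size  # 3307 // 32 = 103
val_batches = n_val // batch_size      # 368 // 32 = 11

print(n_train, n_val, train_batches, val_batches)  # 3307 368 103 11
```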


2026-02-18 04:34:07 | INFO | Phase 1: Training VQ-VAE
2026-02-18 04:36:33 | INFO | Epoch 1/50 | train_recon: 0.377169 | train_vq: 0.692358 | train_perp: 8.0597 | val_recon: 0.463087 | val_vq: 0.743335 | val_perp: 2.1474 [best]
2026-02-18 04:39:01 | INFO | Epoch 2/50 | train_recon: 0.411195 | train_vq: 0.835069 | train_perp: 3.2292 | val_recon: 0.713574 | val_vq: 10.913621 | val_perp: 1.7289
2026-02-18 04:41:31 | INFO | Epoch 3/50 | train_recon: 0.439968 | train_vq: 1.406788 | train_perp: 25.8300 | val_recon: 0.390943 | val_vq: 1.711137 | val_perp: 57.4180
2026-02-18 04:44:00 | INFO | Epoch 4/50 | train_recon: 0.325279 | train_vq: 1.706637 | train_perp: 63.5566 | val_recon: 0.297520 | val_vq: 2.471393 | val_perp: 66.4018
2026-02-18 04:46:30 | INFO | Epoch 5/50 | train_recon: 0.270029 | train_vq: 1.114122 | train_perp: 36.5721 | val_recon: 0.249388 | val_vq: 8.552254 | val_perp: 37.6079
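
To plot or compare the per-epoch metrics, the epoch lines can be parsed back into numbers. A minimal sketch, assuming the `Epoch k/N | name: value | ...` layout seen in the log; `parse_epoch_line` is a hypothetical helper, not part of the project.

```python
import re

# Matches "name: number" pairs such as "train_recon: 0.377169".
PAIR = re.compile(r"(\w+): (-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")

def parse_epoch_line(line):
    """Extract epoch number, total epochs, metrics, and the [best] flag from one log line."""
    match = re.search(r"Epoch (\d+)/(\d+)", line)
    if match is None:
        return None  # not an epoch line
    metrics = {key: float(value) for key, value in PAIR.findall(line[match.end():])}
    return {
        "epoch": int(match.group(1)),
        "total": int(match.group(2)),
        "best": "[best]" in line,
        **metrics,
    }

line = ("2026-02-18 04:36:33 | INFO | Epoch 1/50 | train_recon: 0.377169 | "
        "train_vq: 0.692358 | train_perp: 8.0597 | val_recon: 0.463087 | "
        "val_vq: 0.743335 | val_perp: 2.1474 [best]")
print(parse_epoch_line(line))
```

Applying this to every line of the log file and keeping the non-`None` results gives one dictionary per epoch, which can be passed directly to a plotting or tabulation tool.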
