# DCASE 2022 VQ-VAE + PixelSNAIL — Training on Colab

Runs the project training pipeline on Google Colab with **more training epochs** (50 for phase 1 and 80 for phase 2, set in the config cell).

**Setup:** Run cells **1 → 2 → 3 → 4 → 5 in order** (Mount Drive → Clone repo → Project & deps → Config → Run). Set `DATA_ROOT` in the config cell to your dataset path. **Full step-by-step guide:** see `notebooks/COLAB_TRAINING_GUIDE.md` in the repo.

## 1. Mount Google Drive
Required so we can clone the project under `/content/drive/MyDrive/` and (optionally) keep dataset and checkpoints on Drive.

In [26]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 2. Clone project into Drive
Clone the repo under your Drive so the Colab kernel has the full project (including `src/`). If the folder already exists, we pull the latest; otherwise we clone. Change `REPO_DIR` if you want a different path.

In [27]:
import os

# Clone into this folder under your Drive (My Drive = /content/drive/MyDrive)
REPO_DIR = "/content/drive/MyDrive/semcom_asd_vqar"
REPO_URL = "https://github.com/raidantimosquitos/semcom_asd_vqar.git"

if os.path.isdir(REPO_DIR):
    !cd "{REPO_DIR}" && git pull
else:
    os.makedirs(os.path.dirname(REPO_DIR), exist_ok=True)
    !git clone "{REPO_URL}" "{REPO_DIR}"

remote: Enumerating objects: 9, done.
remote: Counting objects: 100% (9/9), done.
remote: Compressing objects: 100% (3/3), done.
remote: Total 6 (delta 3), reused 6 (delta 3), pack-reused 0 (from 0)
Unpacking objects: 100% (6/6), 1.95 KiB | 19.00 KiB/s, done.
From https://github.com/raidantimosquitos/semcom_asd_vqar
   0a7d8a7..86afe77  main       -> origin/main
Updating 0a7d8a7..86afe77
Fast-forward
 .gitignore                |   4 +-
 src/data/__init__.py      |   0
 src/data/preprocessing.py | 133 ++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 135 insertions(+), 2 deletions(-)
 create mode 100644 src/data/__init__.py
 create mode 100644 src/data/preprocessing.py


## 3. Project and dependencies
**Run this cell before any training.** Sets `PROJECT_ROOT` to the clone path above (`REPO_DIR`), adds it to `sys.path` so `src` can be imported, changes the working directory to it, and installs PyYAML.

In [28]:
import sys
import os

# Use the clone path from the cell above (or set manually if you cloned elsewhere)
try:
    PROJECT_ROOT = REPO_DIR
except NameError:
    PROJECT_ROOT = "/content/drive/MyDrive/semcom_asd_vqar"

if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

# Install dependencies (Colab usually has torch/numpy/sklearn; add PyYAML if missing)
!pip install -q PyYAML

## 4. Config and paths
Set `DATA_ROOT` to your dataset (e.g. on Drive: `/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset`). `CHECKPOINT_DIR` and `LOG_DIR` are relative paths, so they resolve inside the cloned project on Drive (the working directory set in section 3).

In [29]:
# Dataset path: on Colab use Drive path after mount, e.g. "/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset"
DATA_ROOT = "/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset"

# Checkpoints and logs under the cloned project on Drive
CHECKPOINT_DIR = "./checkpoints"
LOG_DIR = "./logs"

# Overrides merged into the config: more epochs, plus dataset, checkpoint and log paths
OVERRIDES = {
    "data": {"root_dir": DATA_ROOT},
    "phase1": {
        "num_epochs": 50,
        "checkpoint": f"{CHECKPOINT_DIR}/mobilenetv2_8x_vqvae.pth",
    },
    "phase2": {
        "num_epochs": 80,
        "checkpoint": f"{CHECKPOINT_DIR}/pixelsnail_prior.pth",
    },
    "eval": {
        "vqvae_checkpoint": f"{CHECKPOINT_DIR}/mobilenetv2_8x_vqvae.pth",
        "prior_checkpoint": f"{CHECKPOINT_DIR}/pixelsnail_prior.pth",
    },
    "logging": {"log_dir": LOG_DIR},
}

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
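
The overrides above are nested dictionaries that get merged into the YAML config. How `src.main.run` performs that merge is not shown in this notebook; the sketch below is one plausible way to do it, a recursive merge in which override values replace base values key by key. The function `deep_merge` and the `base` dictionary are hypothetical and only illustrate the idea.

```python
from copy import deepcopy


def deep_merge(base, overrides):
    """Return a new dict: `base` with `overrides` merged in, recursing into nested dicts."""
    merged = deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge them instead of replacing the whole section.
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Hypothetical base config standing in for configs/colab.yaml (values are illustrative).
base = {
    "data": {"root_dir": "./data", "batch_size": 32},
    "phase1": {"num_epochs": 10, "lr": 0.001},
}
overrides = {
    "data": {"root_dir": "/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset"},
    "phase1": {"num_epochs": 50},
}

merged = deep_merge(base, overrides)
print(merged["phase1"])  # {'num_epochs': 50, 'lr': 0.001}
```

With a merge of this kind, keys that the overrides do not mention (here `lr` and `batch_size`) keep their values from the base config, which is why `OVERRIDES` only needs to list what changes.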

## 5. Run training
**Run cells 1–4 first.** Uses `configs/colab.yaml` and the overrides above. Log output goes to the console and to a file (the run is given `log_dir=LOG_DIR`).
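
For readers who want to see what a console-plus-file logger can look like, here is a minimal sketch using the standard `logging` module. It is not the project's logger: `make_logger`, the `train.log` filename, and the format string are assumptions made for illustration. Setting `propagate = False` is included because propagation to the root logger is one common reason a message shows up twice in notebook output, once per format.

```python
import logging
import os
import tempfile


def make_logger(name, log_dir, filename="train.log"):
    """Hypothetical helper: a logger that writes to the console and to log_dir/filename."""
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep the root logger from re-emitting each record
    if not logger.handlers:  # re-running a notebook cell must not stack handlers
        fmt = logging.Formatter(
            "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
        handlers = (
            logging.StreamHandler(),
            logging.FileHandler(os.path.join(log_dir, filename)),
        )
        for handler in handlers:
            handler.setFormatter(fmt)
            logger.addHandler(handler)
    return logger


demo_dir = tempfile.mkdtemp()  # throwaway directory for the demo
log = make_logger("demo", demo_dir)
log = make_logger("demo", demo_dir)  # second call reuses the existing handlers
log.info("Mode: train")
```

The `if not logger.handlers` guard matters in notebooks, where the same cell is often run several times in one kernel session.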

In [31]:
# Sanity check: run cells 1–4 first (Mount, Clone, Dependencies, Config)
assert "PROJECT_ROOT" in globals(), "Run the 'Project and dependencies' cell (section 3) first."
assert "OVERRIDES" in globals(), "Run the 'Config and paths' cell (section 4) first."
from src.main import run

run(
    config_path=os.path.join(PROJECT_ROOT, "configs", "colab.yaml"),
    overrides=OVERRIDES,
    mode="train",
    log_dir=LOG_DIR,
)

2026-02-18 03:51:10 | INFO | Config: {'mode': 'train', 'data': {'root_dir': '/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset', 'appliance': 'fan', 'test_size': 0.1, 'batch_size': 32, 'random_state': 42, 'max_samples_stats': None}, 'device': 'cuda', 'phase1': {'num_epochs': 50, 'lr': 0.001, 'checkpoint': './checkpoints/mobilenetv2_8x_vqvae.pth'}, 'phase2': {'num_epochs': 80, 'lr': 0.0001, 'checkpoint': './checkpoints/pixelsnail_prior.pth'}, 'eval': {'vqvae_checkpoint': './checkpoints/mobilenetv2_8x_vqvae.pth', 'prior_checkpoint': './checkpoints/pixelsnail_prior.pth'}, 'logging': {'log_dir': './logs', 'name': 'main'}}


INFO:main:Config: {'mode': 'train', 'data': {'root_dir': '/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset', 'appliance': 'fan', 'test_size': 0.1, 'batch_size': 32, 'random_state': 42, 'max_samples_stats': None}, 'device': 'cuda', 'phase1': {'num_epochs': 50, 'lr': 0.001, 'checkpoint': './checkpoints/mobilenetv2_8x_vqvae.pth'}, 'phase2': {'num_epochs': 80, 'lr': 0.0001, 'checkpoint': './checkpoints/pixelsnail_prior.pth'}, 'eval': {'vqvae_checkpoint': './checkpoints/mobilenetv2_8x_vqvae.pth', 'prior_checkpoint': './checkpoints/pixelsnail_prior.pth'}, 'logging': {'log_dir': './logs', 'name': 'main'}}


2026-02-18 03:51:10 | INFO | Mode: train


INFO:main:Mode: train


2026-02-18 03:51:10 | INFO | Device: cpu


INFO:main:Device: cpu


KeyboardInterrupt: 
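
The log above shows the config requesting `'device': 'cuda'` while the run reports `Device: cpu`, which suggests the Colab runtime had no GPU attached when training started; the run was then interrupted. A small check before calling `run` can catch this early. The sketch below is an assumption about how such a check might look: `resolve_device` is a hypothetical helper, not part of the project, and only `torch.cuda.is_available()` is a real PyTorch call.

```python
def resolve_device(requested, cuda_available):
    """Return the device to use, failing early when a GPU was requested but is absent."""
    if requested == "cuda" and not cuda_available:
        raise RuntimeError(
            "Config requests 'cuda' but no GPU is visible. In Colab, use "
            "Runtime > Change runtime type, select a GPU, then rerun the cells."
        )
    return requested


# torch is normally preinstalled on Colab; fall back gracefully if it is missing.
try:
    import torch

    cuda_available = torch.cuda.is_available()
except ImportError:
    cuda_available = False

try:
    device = resolve_device("cuda", cuda_available)
    print(f"Using device: {device}")
except RuntimeError as err:
    print(f"Stopping before training: {err}")
```

Changing the runtime type restarts the kernel, so cells 1 to 4 need to be run again before training.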