# DCASE 2022 VQ-VAE + PixelSNAIL — Training on Colab

Runs the project training pipeline on Google Colab with **more epochs** per phase (15 for the VQ-VAE, 30 for the PixelSNAIL prior; see the overrides in cell 4).

**Setup:** Run cells **1 → 2 → 3 → 4 → 5** in order (Mount Drive → Clone repo → Project & deps → Config → Run training). Cell **6** runs evaluation and compares ROC AUC for reconstruction vs PixelSNAIL. Set `DATA_ROOT` in the config cell to your dataset path. **Full step-by-step guide:** see `notebooks/COLAB_TRAINING_GUIDE.md` in the repo.

## 1. Mount Google Drive
Required so we can clone the project under `/content/drive/MyDrive/` and (optionally) keep dataset and checkpoints on Drive.

In [1]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 2. Clone project into Drive
Clone the repo under your Drive so the Colab kernel has the full project (including `src/`). If the folder already exists, we pull the latest; otherwise we clone. Change `REPO_DIR` if you want a different path.

In [6]:
import os

# Clone into this folder under your Drive (My Drive = /content/drive/MyDrive)
REPO_DIR = "/content/drive/MyDrive/semcom_asd_vqar"
REPO_URL = "https://github.com/raidantimosquitos/semcom_asd_vqar.git"

if os.path.isdir(REPO_DIR):
    !cd "{REPO_DIR}" && git pull
else:
    os.makedirs(os.path.dirname(REPO_DIR), exist_ok=True)
    !git clone {REPO_URL} {REPO_DIR}

remote: Enumerating objects: 13, done.
remote: Counting objects: 100% (13/13), done.
remote: Compressing objects: 100% (2/2), done.
remote: Total 7 (delta 5), reused 7 (delta 5), pack-reused 0 (from 0)
Unpacking objects: 100% (7/7), 4.25 KiB | 9.00 KiB/s, done.
From https://github.com/raidantimosquitos/semcom_asd_vqar
   5ac8692..e40ed10  main       -> origin/main
Updating 5ac8692..e40ed10
Fast-forward
 notebooks/colab_train.ipynb | 925 ++++++++++++++++++++++++++++++++++++++++----
 src/engine/test.py          |  12 +-
 2 files changed, 864 insertions(+), 73 deletions(-)


## 3. Project and dependencies
**Run this cell before any training.** Sets `PROJECT_ROOT` to the clone path above (`REPO_DIR`), adds it to `sys.path`, changes the working directory to it, and installs dependencies so `src` can be imported.

In [3]:
import sys
import os

# Use the clone path from the cell above (or set manually if you cloned elsewhere)
try:
    PROJECT_ROOT = REPO_DIR
except NameError:
    PROJECT_ROOT = "/content/drive/MyDrive/semcom_asd_vqar"

if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

# Install dependencies (Colab usually has torch/numpy/sklearn; add PyYAML if missing)
!pip install -q PyYAML
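
To confirm that the dependencies resolve before starting a long run, a quick import check can help. The sketch below is not part of the repo. The helper name `missing_modules` and the module list are assumptions based on the comment above (torch, numpy, sklearn, PyYAML), so adjust them to the project's actual requirements.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed list of top-level modules; PyYAML is imported as "yaml".
REQUIRED = ["yaml", "torch", "numpy", "sklearn"]

missing = missing_modules(REQUIRED)
print("Missing modules:", missing if missing else "none")
```

Any name reported as missing can be installed with `pip` in the cell above before moving on to the config cell.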

## 4. Config and paths
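Set `DATA_ROOT` to your dataset (e.g. on Drive: `/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset`). Checkpoints and logs are written under the cloned project on Drive, because the relative paths `./checkpoints` and `./logs` resolve against the working directory set in cell 3.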
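
A wrong `DATA_ROOT` is easier to catch before training starts than partway through. The sketch below counts `.wav` files per split. It assumes a `<root>/<appliance>/train` and `<root>/<appliance>/test` layout, which is an assumption about how the dataset was unpacked and is not confirmed by this notebook. The helper `count_wavs` is hypothetical, and the demo uses a temporary directory with empty placeholder files.

```python
import tempfile
from pathlib import Path


def count_wavs(data_root, appliance="fan"):
    """Count .wav files per split, assuming <root>/<appliance>/{train,test}."""
    counts = {}
    for split in ("train", "test"):
        split_dir = Path(data_root) / appliance / split
        # A missing directory counts as zero files instead of raising.
        counts[split] = len(list(split_dir.glob("*.wav"))) if split_dir.is_dir() else 0
    return counts


# Demo on a throwaway directory tree with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for rel in ("fan/train/a.wav", "fan/train/b.wav", "fan/test/c.wav"):
        path = Path(tmp) / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    demo_counts = count_wavs(tmp)

print(demo_counts)
```

In the notebook, calling `count_wavs(DATA_ROOT)` after the config cell should return non-zero counts if the path and layout match this assumption. A result of all zeros suggests the path is wrong or the dataset is organised differently.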

In [4]:
# Dataset path: on Colab use Drive path after mount, e.g. "/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset"
DATA_ROOT = "/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset"

# Checkpoints and logs under the cloned project on Drive
CHECKPOINT_DIR = "./checkpoints"
LOG_DIR = "./logs"

# Overrides: more epochs per phase (merged into the YAML config)
OVERRIDES = {
    "data": {"root_dir": DATA_ROOT},
    "checkpoints": {"dir": CHECKPOINT_DIR},
    "phase1": {
        "num_epochs": 15,
        "lr": 0.002,
        "checkpoint": f"{CHECKPOINT_DIR}/models/mobilenetv2_8x_vqvae.pth",
    },
    "phase2": {
        "num_epochs": 30,
        "lr": 0.0002,
        "checkpoint": f"{CHECKPOINT_DIR}/models/pixelsnail_prior.pth",
    },
    "eval": {
        "vqvae_checkpoint": f"{CHECKPOINT_DIR}/models/mobilenetv2_8x_vqvae.pth",
        "prior_checkpoint": f"{CHECKPOINT_DIR}/models/pixelsnail_prior.pth",
    },
    "logging": {"log_dir": LOG_DIR},
}

os.makedirs(f"{CHECKPOINT_DIR}/models", exist_ok=True)  
os.makedirs(LOG_DIR, exist_ok=True)
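
The nested `OVERRIDES` dictionary only replaces the keys it names. Sibling keys in the YAML config are expected to survive the merge. The sketch below illustrates that behaviour with a hypothetical `merge_overrides` helper. It is a stand-in for illustration, not the repo's `deep_merge` from `src.utils.config`, whose implementation is not shown here. The base values are made up for the example.

```python
import copy


def merge_overrides(base, overrides):
    """Return a new dict: nested dicts merge key by key, other values are replaced."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical base config (values invented for the example).
base_config = {
    "data": {"appliance": "fan", "batch_size": 32},
    "phase1": {"num_epochs": 5, "lr": 0.002},
}
example_overrides = {
    "data": {"root_dir": "/path/to/dataset"},
    "phase1": {"num_epochs": 15},
}

merged_config = merge_overrides(base_config, example_overrides)
print(merged_config["phase1"])  # num_epochs is overridden, lr is kept
print(merged_config["data"])    # root_dir is added, appliance and batch_size are kept
```

This version returns a new dictionary and leaves its inputs untouched. The repo's helper may instead modify the config in place, as the way it is called in the evaluation cell suggests.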

## 5. Run training
**Run cells 1–4 first.** Uses `configs/colab.yaml` and the overrides above. Log messages are written to both the console and a log file under `LOG_DIR`.
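
If each log message shows up twice in the cell output (once in the project's `timestamp | INFO | message` format and once as `INFO:main:message`), a likely cause is that the named logger passes its records on to a handler attached to the root logger. The sketch below reproduces the effect with the standard `logging` module. The logger name `demo_main` and the helper `count_emissions` are hypothetical, and whether the project's `get_logger` already disables propagation is not shown in this notebook.

```python
import io
import logging


def count_emissions(propagate):
    """Log one message and count how many lines reach the shared stream."""
    stream = io.StringIO()

    # Stand-in for a handler that the environment attached to the root logger.
    root = logging.getLogger()
    root_handler = logging.StreamHandler(stream)
    root.addHandler(root_handler)

    # Named logger with its own formatted handler.
    logger = logging.getLogger("demo_main")
    logger.setLevel(logging.INFO)
    own_handler = logging.StreamHandler(stream)
    own_handler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s | %(message)s"))
    logger.addHandler(own_handler)
    logger.propagate = propagate

    logger.info("Mode: train")

    # Clean up so repeated calls do not accumulate handlers.
    logger.removeHandler(own_handler)
    root.removeHandler(root_handler)
    return len(stream.getvalue().splitlines())


print(count_emissions(propagate=True))   # message reaches both handlers
print(count_emissions(propagate=False))  # message reaches only the logger's own handler
```

Under that assumption, setting `logging.getLogger("main").propagate = False` once the project's logger exists would keep only the formatted line.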

In [5]:
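# Sanity check: cells 1–4 (Mount, Clone, Dependencies, Config) must have run first
assert "PROJECT_ROOT" in globals(), "Run the 'Project and dependencies' cell (section 3) first."
assert "OVERRIDES" in globals(), "Run the 'Config and paths' cell (section 4) first."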
from src.main import run

run(
    config_path=os.path.join(PROJECT_ROOT, "configs", "colab.yaml"),
    overrides=OVERRIDES,
    mode="train",
    log_dir=LOG_DIR,
)

2026-02-18 06:53:49 | INFO | Config: {'mode': 'train', 'checkpoints': {'dir': './checkpoints'}, 'data': {'root_dir': '/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset', 'appliance': 'fan', 'test_size': 0.1, 'batch_size': 32, 'random_state': 42, 'max_samples_stats': None}, 'device': 'cuda', 'phase1': {'num_epochs': 15, 'lr': 0.002, 'checkpoint': './checkpoints/models/mobilenetv2_8x_vqvae.pth'}, 'phase2': {'num_epochs': 30, 'lr': 0.0002, 'checkpoint': './checkpoints/models/pixelsnail_prior.pth'}, 'eval': {'vqvae_checkpoint': './checkpoints/models/mobilenetv2_8x_vqvae.pth', 'prior_checkpoint': './checkpoints/models/pixelsnail_prior.pth'}, 'logging': {'log_dir': './logs', 'name': 'main'}}
2026-02-18 06:53:49 | INFO | Mode: train
2026-02-18 06:53:49 | INFO | Device: cuda
2026-02-18 06:53:49 | INFO | Loaded train stats from /content/drive/MyDrive/semcom_asd_vqar/checkpoints/stats/fan_train_stats.pt: mean=-13.015643 | std=6.970860
2026-02-18 06:53:49 | INFO | Data loading
2026-02-18 06:53:49 | INFO |   root_dir: /content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset
2026-02-18 06:53:49 | INFO |   appliance: fan | mode: train
2026-02-18 06:53:49 | INFO |   train_size: 3307 | val_size: 368
2026-02-18 06:53:49 | INFO |   total_samples (before split): 3675
2026-02-18 06:53:49 | INFO |   batch_size: 32 | num_workers: 0
2026-02-18 06:53:49 | INFO |   train_batches: 103 | val_batches: 11
2026-02-18 06:53:49 | INFO | Phase 1: Training VQ-VAE


2026-02-18 07:13:55 | INFO | Epoch 1/15 | train_recon: 0.355797 | train_vq: 0.828874 | train_perp: 11.2684 | val_recon: 0.297578 | val_vq: 1.184144 | val_perp: 4.7936 [best]
2026-02-18 07:16:24 | INFO | Epoch 2/15 | train_recon: 0.277599 | train_vq: 0.743960 | train_perp: 7.6398 | val_recon: 0.275176 | val_vq: 1.191539 | val_perp: 10.9180 [best]
2026-02-18 07:18:54 | INFO | Epoch 3/15 | train_recon: 0.258520 | train_vq: 0.625307 | train_perp: 14.8310 | val_recon: 0.246343 | val_vq: 1.222178 | val_perp: 19.0358
2026-02-18 07:21:23 | INFO | Epoch 4/15 | train_recon: 0.243523 | train_vq: 0.469672 | train_perp: 22.3266 | val_recon: 0.232363 | val_vq: 1.182801 | val_perp: 26.0560 [best]
2026-02-18 07:23:53 | INFO | Epoch 5/15 | train_recon: 0.232101 | train_vq: 0.311580 | train_perp: 29.2899 | val_recon: 0.224506 | val_vq: 0.769045 | val_perp: 33.4242 [best]
2026-02-18 07:26:22 | INFO | Epoch 6/15 | train_recon: 0.226042 | train_vq: 0.181181 | train_perp: 36.7357 | val_recon: 0.218617 | val_vq: 0.324869 | val_perp: 40.8498 [best]
2026-02-18 07:28:51 | INFO | Epoch 7/15 | train_recon: 0.217424 | train_vq: 0.123056 | train_perp: 44.9508 | val_recon: 0.215747 | val_vq: 0.173886 | val_perp: 50.5253 [best]
2026-02-18 07:31:20 | INFO | Epoch 8/15 | train_recon: 0.214303 | train_vq: 0.094800 | train_perp: 55.6250 | val_recon: 0.212721 | val_vq: 0.165460 | val_perp: 60.4585 [best]
2026-02-18 07:33:50 | INFO | Epoch 9/15 | train_recon: 0.213001 | train_vq: 0.077821 | train_perp: 65.0181 | val_recon: 0.210749 | val_vq: 0.141471 | val_perp: 69.7948 [best]
2026-02-18 07:36:19 | INFO | Epoch 10/15 | train_recon: 0.206547 | train_vq: 0.068577 | train_perp: 74.3632 | val_recon: 0.204003 | val_vq: 0.175994 | val_perp: 78.5346
2026-02-18 07:38:48 | INFO | Epoch 11/15 | train_recon: 0.205342 | train_vq: 0.059946 | train_perp: 82.2804 | val_recon: 0.202833 | val_vq: 0.114191 | val_perp: 86.4064 [best]
2026-02-18 07:41:17 | INFO | Epoch 12/15 | train_recon: 0.203130 | train_vq: 0.055399 | train_perp: 90.1621 | val_recon: 0.199541 | val_vq: 0.073166 | val_perp: 93.5836 [best]
2026-02-18 07:43:47 | INFO | Epoch 13/15 | train_recon: 0.200658 | train_vq: 0.051446 | train_perp: 97.4257 | val_recon: 0.194897 | val_vq: 0.080476 | val_perp: 99.6788
2026-02-18 07:46:16 | INFO | Epoch 14/15 | train_recon: 0.198554 | train_vq: 0.048314 | train_perp: 103.0341 | val_recon: 0.194377 | val_vq: 0.062238 | val_perp: 105.1520 [best]
2026-02-18 07:48:46 | INFO | Epoch 15/15 | train_recon: 0.197477 | train_vq: 0.045772 | train_perp: 108.7538 | val_recon: 0.194921 | val_vq: 0.058581 | val_perp: 111.1148 [best]
2026-02-18 07:48:46 | INFO | Training finished.
2026-02-18 07:48:46 | INFO |   best_val_loss: 0.253502
2026-02-18 07:48:46 | INFO |   checkpoint: /content/drive/MyDrive/semcom_asd_vqar/checkpoints/models/mobilenetv2_8x_vqvae.pth


2026-02-18 07:48:46 | INFO | Phase 2: Training PixelSNAIL prior (anomaly criterion)
  WeightNorm.apply(module, name, dim)
2026-02-18 07:51:12 | INFO | Prior Epoch 1/30 | train_nll: 4.895500 | val_nll: 3.405912 [best]
2026-02-18 07:53:37 | INFO | Prior Epoch 2/30 | train_nll: 2.916703 | val_nll: 2.630814 [best]
2026-02-18 07:56:03 | INFO | Prior Epoch 3/30 | train_nll: 2.503174 | val_nll: 2.390116 [best]
2026-02-18 07:58:28 | INFO | Prior Epoch 4/30 | train_nll: 2.356417 | val_nll: 2.298385 [best]
2026-02-18 08:00:53 | INFO | Prior Epoch 5/30 | train_nll: 2.286344 | val_nll: 2.245523 [best]
2026-02-18 08:03:17 | INFO | Prior Epoch 6/30 | train_nll: 2.249933 | val_nll: 2.224596 [best]
2026-02-18 08:05:42 | INFO | Prior Epoch 7/30 | train_nll: 2.223256 | val_nll: 2.201829 [best]
2026-02-18 08:08:08 | INFO | Prior Epoch 8/30 | train_nll: 2.205198 | val_nll: 2.180064 [best]
2026-02-18 08:10:33 | INFO | Prior Epoch 9/30 | train_nll: 2.187623 | val_nll: 2.184603
2026-02-18 08:12:57 | INFO | Prior Epoch 10/30 | train_nll: 2.175998 | val_nll: 2.165101 [best]
2026-02-18 08:15:23 | INFO | Prior Epoch 11/30 | train_nll: 2.166103 | val_nll: 2.149232 [best]
2026-02-18 08:17:48 | INFO | Prior Epoch 12/30 | train_nll: 2.156976 | val_nll: 2.145380 [best]
2026-02-18 08:20:13 | INFO | Prior Epoch 13/30 | train_nll: 2.149378 | val_nll: 2.135883 [best]
2026-02-18 08:22:40 | INFO | Prior Epoch 14/30 | train_nll: 2.143674 | val_nll: 2.134573 [best]
2026-02-18 08:25:08 | INFO | Prior Epoch 15/30 | train_nll: 2.139055 | val_nll: 2.128314 [best]
2026-02-18 08:27:34 | INFO | Prior Epoch 16/30 | train_nll: 2.128778 | val_nll: 2.123200 [best]
2026-02-18 08:29:59 | INFO | Prior Epoch 17/30 | train_nll: 2.127343 | val_nll: 2.116858 [best]
2026-02-18 08:32:24 | INFO | Prior Epoch 18/30 | train_nll: 2.119571 | val_nll: 2.115822 [best]
2026-02-18 08:34:49 | INFO | Prior Epoch 19/30 | train_nll: 2.115727 | val_nll: 2.112962 [best]
2026-02-18 08:37:14 | INFO | Prior Epoch 20/30 | train_nll: 2.111755 | val_nll: 2.108480 [best]
2026-02-18 08:39:39 | INFO | Prior Epoch 21/30 | train_nll: 2.106296 | val_nll: 2.102051 [best]
2026-02-18 08:42:04 | INFO | Prior Epoch 22/30 | train_nll: 2.104457 | val_nll: 2.110945
2026-02-18 08:44:28 | INFO | Prior Epoch 23/30 | train_nll: 2.099117 | val_nll: 2.097527 [best]
2026-02-18 08:46:54 | INFO | Prior Epoch 24/30 | train_nll: 2.096569 | val_nll: 2.092134 [best]
2026-02-18 08:49:19 | INFO | Prior Epoch 25/30 | train_nll: 2.092177 | val_nll: 2.095285
2026-02-18 08:51:43 | INFO | Prior Epoch 26/30 | train_nll: 2.089233 | val_nll: 2.090580 [best]
2026-02-18 08:54:09 | INFO | Prior Epoch 27/30 | train_nll: 2.082852 | val_nll: 2.085929 [best]
2026-02-18 08:56:34 | INFO | Prior Epoch 28/30 | train_nll: 2.082605 | val_nll: 2.082112 [best]
2026-02-18 08:58:58 | INFO | Prior Epoch 29/30 | train_nll: 2.079482 | val_nll: 2.084490
2026-02-18 09:01:21 | INFO | Prior Epoch 30/30 | train_nll: 2.075107 | val_nll: 2.081413 [best]
2026-02-18 09:01:21 | INFO | Prior training finished. best_val_nll: 2.081413 | checkpoint: ./checkpoints/models/pixelsnail_prior.pth


## 6. Evaluate — ROC AUC comparison

Run evaluation on the test set and compare **ROC AUC** for the two anomaly criteria (same as `src.engine.test` and `python -m src.main --config configs/colab.yaml --mode eval`):

1. **Reconstruction (MSE)** — higher reconstruction error = more anomalous  
2. **PixelSNAIL NLL** — higher negative log-likelihood on VQ codes = more anomalous  

Uses the same config and checkpoints as above (from the Config cell). Run this cell after training, or at any time if both checkpoints already exist.
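
For both criteria, ROC AUC can be read as the probability that a randomly chosen anomalous clip receives a higher score than a randomly chosen normal clip. The sketch below computes it directly from that definition in pure Python, using made-up labels and scores. The helper `roc_auc` is hypothetical and is not how `run_evaluation` is necessarily implemented. The pairwise loop is quadratic and only meant for illustration.

```python
def roc_auc(labels, scores):
    """ROC AUC as the fraction of (anomalous, normal) pairs ranked correctly.

    labels: 1 for anomalous, 0 for normal. Tied scores count as half a win.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one normal and one anomalous sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy data (invented): two normal clips followed by two anomalous clips.
toy_labels = [0, 0, 1, 1]
toy_mse_scores = [0.10, 0.40, 0.35, 0.80]  # higher reconstruction error = more anomalous
toy_nll_scores = [2.0, 2.1, 2.3, 2.6]      # higher NLL = more anomalous

toy_auc_mse = roc_auc(toy_labels, toy_mse_scores)
toy_auc_nll = roc_auc(toy_labels, toy_nll_scores)
print(f"MSE criterion: {toy_auc_mse:.2f}")  # 0.75
print(f"NLL criterion: {toy_auc_nll:.2f}")  # 1.00
```

A value of 0.5 corresponds to random ranking and 1.0 to perfect separation of normal and anomalous clips.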

In [7]:
from pathlib import Path
from src.utils.config import load_config, deep_merge
from src.utils.logger import get_logger, log_config
from src.engine.test import run_evaluation

# Same config as training (follows test module and main program)
config = load_config(os.path.join(PROJECT_ROOT, "configs", "colab.yaml"))
deep_merge(config, OVERRIDES)
config["mode"] = "eval"

log_path = Path(LOG_DIR) / "main_eval.log"
log_path.parent.mkdir(parents=True, exist_ok=True)
logger = get_logger("main", log_file=log_path)
log_config(logger, config=config)
logger.info("Mode: eval")

results = run_evaluation(config, logger)

# Compare ROC AUC for the two anomaly criteria
if results:
    print("\n" + "=" * 60)
    print("  ROC AUC comparison (higher = better anomaly detection)")
    print("=" * 60)
    print("  Criterion              | ROC AUC")
    print("-" * 60)
    print(f"  Reconstruction (MSE)   | {results['auc_mse']:.4f}")
    print(f"  PixelSNAIL NLL         | {results['auc_nll']:.4f}")
    print("=" * 60)
    print(f"  Test samples: {results['num_test_samples']}")
    better = "PixelSNAIL NLL" if results["auc_nll"] >= results["auc_mse"] else "Reconstruction (MSE)"
    print(f"  Better criterion: {better}")

2026-02-18 09:20:46 | INFO | Config: {'mode': 'eval', 'checkpoints': {'dir': './checkpoints'}, 'data': {'root_dir': '/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset', 'appliance': 'fan', 'test_size': 0.1, 'batch_size': 32, 'random_state': 42, 'max_samples_stats': None}, 'device': 'cuda', 'phase1': {'num_epochs': 15, 'lr': 0.002, 'checkpoint': './checkpoints/models/mobilenetv2_8x_vqvae.pth'}, 'phase2': {'num_epochs': 30, 'lr': 0.0002, 'checkpoint': './checkpoints/models/pixelsnail_prior.pth'}, 'eval': {'vqvae_checkpoint': './checkpoints/models/mobilenetv2_8x_vqvae.pth', 'prior_checkpoint': './checkpoints/models/pixelsnail_prior.pth'}, 'logging': {'log_dir': './logs', 'name': 'main'}}
2026-02-18 09:20:46 | INFO | Mode: eval
2026-02-18 09:20:46 | INFO | Evaluation config: root_dir=/content/drive/MyDrive/datasets/dcase2020-task2-dev-dataset, appliance=fan
2026-02-18 09:20:46 | INFO | Device: cuda
2026-02-18 09:20:46 | INFO | Loaded normalization stats from /content/drive/MyDrive/semcom_asd_vqar/checkpoints/stats: mean=-13.0156, std=6.9709
2026-02-18 09:20:52 | INFO | Test samples: 1875
2026-02-18 09:20:52 | INFO | Loading VQ-VAE from ./checkpoints/models/mobilenetv2_8x_vqvae.pth
2026-02-18 09:20:52 | INFO | Loading prior from ./checkpoints/models/pixelsnail_prior.pth
2026-02-18 09:20:52 | INFO | Computing anomaly scores...
2026-02-18 09:32:57 | INFO | Evaluation finished.
2026-02-18 09:32:57 | INFO |   num_test_samples: 1875
2026-02-18 09:32:57 | INFO |   vqvae_checkpoint: /content/drive/MyDrive/semcom_asd_vqar/checkpoints/models/mobilenetv2_8x_vqvae.pth
2026-02-18 09:32:57 | INFO |   prior_checkpoint: /content/drive/MyDrive/semcom_asd_vqar/checkpoints/models/pixelsnail_prior.pth
2026-02-18 09:32:57 | INFO |   ROC AUC (MSE criterion):  0.5542
2026-02-18 09:32:57 | INFO |   ROC AUC (PixelSNAIL NLL): 0.5624
