# ENERGIZE NILM — Structured Pruning, Fine-tuning & Evaluation

This notebook supports **CNN** and **TCN** models and any PLEGMA appliance.

Pipeline:
1. **Configure** — choose model, appliance and pruning ratio in one cell
2. **Baseline** — load the trained checkpoint, measure cost, evaluate on test set
3. **Prune** — apply global structured channel pruning (50% by default)
4. **Evaluate pruned** — test-set metrics immediately after pruning
5. **Fine-tune** — 1-epoch recovery training on the training set
6. **Evaluate fine-tuned** — test-set metrics after fine-tuning
7. **Export** — save all results (Params, MACs, MB + metrics) to Excel

> Pruning functions live in `src_pytorch/pruner.py` and are imported here.

---

## Google Colab Setup
1. Upload your `ENERGIZE` folder to Google Drive.
2. In the cell below, set `DRIVE_PROJECT_PATH` to that folder, then run the cell before any other.

In [None]:
# ============================================================================
# COLAB SETUP — run this cell first
# ============================================================================
# Detects Colab, mounts Drive and installs the extra packages there, then
# points the working directory / import path at the project root.
import os
import subprocess
import sys
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

    # check=True: stop here if the install fails, instead of surfacing later
    # as a confusing ImportError in the imports cell.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-q', 'torch_pruning', 'openpyxl'],
        check=True,
    )

    # =========================================================================
    DRIVE_PROJECT_PATH = '/content/drive/MyDrive/ENERGIZE'  # <-- EDIT THIS
    # =========================================================================

    project_root = Path(DRIVE_PROJECT_PATH)
    if not project_root.exists():
        raise FileNotFoundError(f"Project folder not found: {project_root}")
    os.chdir(project_root)
    print(f"Project root: {project_root}")
else:
    # Locally, the notebook is expected to sit one directory below the project root.
    project_root = Path(os.getcwd()).parent
    print(f"Running locally. Project root: {project_root}")

# Make `src_pytorch` importable. The guard keeps re-runs of this cell from
# stacking duplicate entries onto sys.path.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

## 1. Imports

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch_pruning as tp
from types import SimpleNamespace
from tqdm import tqdm

# NILM package
from src_pytorch import (
    CNN_NILM, TCN_NILM,
    SimpleNILMDataLoader,
    set_seeds, get_device, count_parameters,
    # Pruning utilities (src_pytorch/pruner.py)
    count_parameters_per_layer,
    get_model_stats,
    apply_torch_pruning,
    evaluate_model,
)

# Fix the random seed before anything stochastic runs (see src_pytorch.set_seeds).
set_seeds(42)
# Device selection is delegated to src_pytorch.get_device.
device = get_device()

# Environment report: one line each for PyTorch, torch_pruning and the device.
print("\n".join([
    f"PyTorch version : {torch.__version__}",
    f"torch_pruning   : {tp.__version__}",
    f"Device          : {device}",
]))

## 2. Configuration

**Edit only this cell.** Everything else adapts automatically.

In [None]:
# ============================================================================
# USER CONFIGURATION
# ============================================================================
MODEL_NAME     = 'tcn'      # 'cnn'  |  'tcn'
APPLIANCE_NAME = 'boiler'   # 'boiler'  |  'ac_1'  |  'washing_machine'
PRUNING_RATIO  = 0.5        # fraction of channels to remove, in [0.0, 1.0)
FINETUNE_LR    = 1e-4       # Adam learning rate for 1-epoch fine-tuning
DATASET_NAME   = 'plegma'

# ============================================================================
# AUTO-DERIVED — do not edit below this line
# ============================================================================

# Model-specific architecture parameters
_MODEL_CFGS = {
    'cnn': {
        'window'          : 299,
        'batch_size'      : 1024,
        'args_window_size': 1,      # out_features of final Linear — protects it
    },
    'tcn': {
        'window'          : 600,
        'batch_size'      : 50,
        'depth'           : 9,
        'filters'         : [512, 256, 256, 128, 128, 256, 256, 256, 512],
        'dropout'         : 0.2,
        'stacks'          : 1,
        'args_window_size': 600,    # no Linear layers in TCN, kept for clarity
    },
}

# PLEGMA appliance thresholds and cutoffs
_APPLIANCE_CFGS = {
    'boiler'          : {'threshold': 50, 'cutoff': 5000},
    'ac_1'            : {'threshold': 50, 'cutoff': 2300},
    'washing_machine' : {'threshold': 50, 'cutoff': 2600},
}

# Fail fast with an actionable message rather than a bare KeyError below
# or a degenerate (fully pruned) model much later in the notebook.
if MODEL_NAME not in _MODEL_CFGS:
    raise ValueError(
        f"Unknown MODEL_NAME {MODEL_NAME!r}; choose one of {sorted(_MODEL_CFGS)}"
    )
if APPLIANCE_NAME not in _APPLIANCE_CFGS:
    raise ValueError(
        f"Unknown APPLIANCE_NAME {APPLIANCE_NAME!r}; choose one of {sorted(_APPLIANCE_CFGS)}"
    )
if not 0.0 <= PRUNING_RATIO < 1.0:
    raise ValueError(f"PRUNING_RATIO must be in [0.0, 1.0), got {PRUNING_RATIO}")

cfg     = _MODEL_CFGS[MODEL_NAME]
app_cfg = _APPLIANCE_CFGS[APPLIANCE_NAME]

WINDOW     = cfg['window']
BATCH_SIZE = cfg['batch_size']
THRESHOLD  = app_cfg['threshold']
CUTOFF     = app_cfg['cutoff']

DATA_DIR    = project_root / 'data' / 'processed' / DATASET_NAME / APPLIANCE_NAME
CKPT_PATH   = project_root / 'outputs' / f'{MODEL_NAME}_{APPLIANCE_NAME}' / 'checkpoint' / 'model.pt'
# Results saved alongside the model outputs for this experiment
RESULTS_DIR = project_root / 'outputs' / f'{MODEL_NAME}_{APPLIANCE_NAME}' / 'comparative_results'
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print(f"Model          : {MODEL_NAME.upper()}")
print(f"Appliance      : {APPLIANCE_NAME}")
print(f"Pruning ratio  : {PRUNING_RATIO * 100:.0f}%")
print(f"Fine-tune LR   : {FINETUNE_LR}")
print(f"Window length  : {WINDOW}")
print(f"Threshold      : {THRESHOLD} W  |  Cutoff: {CUTOFF} W")
print(f"Checkpoint     : {CKPT_PATH}  ({'found' if CKPT_PATH.exists() else 'NOT FOUND'})")
print(f"Results dir    : {RESULTS_DIR}")

## 3. Load Data

In [None]:
# Build the data loaders for the selected model / appliance.
# Section 7 iterates `data_loader.train`; train=True presumably makes that
# split available (see SimpleNILMDataLoader).
data_loader = SimpleNILMDataLoader(
    data_dir=str(DATA_DIR),
    input_window_length=WINDOW,
    model_name=MODEL_NAME,
    batch_size=BATCH_SIZE,
    num_workers=0,
    train=True,
)

print(
    f"Train batches : {len(data_loader.train)}",
    f"Val   batches : {len(data_loader.val)}",
    f"Test  batches : {len(data_loader.test)}",
    sep="\n",
)

## 4. Helper — Model Factory

A single function that builds and loads the correct architecture from a checkpoint.

In [None]:
def build_model(model_name: str, cfg: dict, ckpt_path, device) -> nn.Module:
    """Instantiate a NILM model and load its trained weights from a checkpoint.

    Args:
        model_name: ``'cnn'`` or ``'tcn'``.
        cfg: Architecture settings for ``model_name`` (an entry of
            ``_MODEL_CFGS``). Uses ``window`` for both models, plus ``depth``,
            ``filters``, ``dropout`` and ``stacks`` for the TCN.
        ckpt_path: Path (``str`` or ``Path``) to a ``state_dict`` saved with
            ``torch.save``.
        device: Used as ``map_location`` for loading; the model is moved there.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        ValueError: If ``model_name`` is not ``'cnn'`` or ``'tcn'``.
        FileNotFoundError: If ``ckpt_path`` does not point to a file.
    """
    from pathlib import Path

    if model_name not in ('cnn', 'tcn'):
        raise ValueError(f"Unknown model: {model_name}. Choose 'cnn' or 'tcn'.")

    ckpt_path = Path(ckpt_path)
    # The config cell only *prints* 'NOT FOUND'; stop here with the exact path
    # rather than letting torch.load fail less clearly.
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    if model_name == 'cnn':
        model = CNN_NILM(input_window_length=cfg['window'])
    else:
        model = TCN_NILM(
            input_window_length=cfg['window'],
            depth=cfg['depth'],
            nb_filters=cfg['filters'],
            dropout=cfg['dropout'],
            stacks=cfg['stacks'],
        )

    # The file is a plain state_dict (it goes straight into load_state_dict), so
    # weights_only=True avoids unpickling arbitrary objects. Requires torch >= 1.13.
    state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


print("build_model() defined.")

## 5. Baseline — Load & Evaluate

In [None]:
# Baseline = the trained checkpoint, unmodified.
baseline_model = build_model(MODEL_NAME, cfg, CKPT_PATH, device)

# Random example input (batch of 1 window), reused below for cost
# measurement and as the example input to pruning.
dummy_input = torch.randn(1, WINDOW).to(device)

baseline_stats = get_model_stats(baseline_model, dummy_input)
baseline_params, baseline_macs, baseline_mb = baseline_stats

print(
    f"Model      : {MODEL_NAME.upper()}",
    f"Parameters : {baseline_params:,}",
    f"MACs       : {baseline_macs:,}",
    f"Size       : {baseline_mb:.3f} MB",
    sep="\n",
)

In [None]:
# Reference point: test-set metrics of the unpruned model.
print(f"Evaluating {MODEL_NAME.upper()} baseline on test set...")
baseline_metrics = evaluate_model(
    model=baseline_model,
    data_loader=data_loader,
    device=device,
    model_name=MODEL_NAME,
    input_window_length=WINDOW,
    threshold=THRESHOLD,
    cutoff=CUTOFF,
)

print("\nBaseline Results:")
for metric_name, metric_value in baseline_metrics.items():
    print(f"  {metric_name:<25}: {metric_value:.4f}")

## 6. Prune & Evaluate

In [None]:
# Pruning is irreversible, so work on a freshly loaded copy and keep
# `baseline_model` intact for comparison.
pruned_model = build_model(MODEL_NAME, cfg, CKPT_PATH, device)

# `window_size` is taken from args_window_size in the model config
# (for the CNN, that comment notes it protects the final Linear layer).
pruning_args = SimpleNamespace(window_size=cfg['args_window_size'])

banner = "=" * 60
print(banner)
print(f"Pruning {MODEL_NAME.upper()} at {PRUNING_RATIO * 100:.0f}%")
print(banner)

pruned_model, _ = apply_torch_pruning(
    model=pruned_model,
    args=pruning_args,
    inputs=dummy_input,
    pruning_ratio=PRUNING_RATIO,
)

In [None]:
# Cost of the pruned model and its relative reduction versus the baseline.
pruned_params, pruned_macs, pruned_mb = get_model_stats(pruned_model, dummy_input)

params_cut = (1 - pruned_params / baseline_params) * 100
macs_cut   = (1 - pruned_macs / baseline_macs) * 100
mb_cut     = (1 - pruned_mb / baseline_mb) * 100

print(f"Parameters : {pruned_params:,}  ({params_cut:.1f}% reduction)")
print(f"MACs       : {pruned_macs:,}  ({macs_cut:.1f}% reduction)")
print(f"Size       : {pruned_mb:.3f} MB  ({mb_cut:.1f}% reduction)")

print("\nPer-layer parameter counts (after pruning):")
layer_counts = count_parameters_per_layer(pruned_model)
for layer_name, layer_count in layer_counts.items():
    print(f"  {layer_name:<60} {layer_count:>10,}")

In [None]:
# Test-set metrics straight after pruning, before any recovery training.
print(f"Evaluating pruned {MODEL_NAME.upper()} on test set...")
pruned_metrics = evaluate_model(
    model=pruned_model,
    data_loader=data_loader,
    device=device,
    model_name=MODEL_NAME,
    input_window_length=WINDOW,
    threshold=THRESHOLD,
    cutoff=CUTOFF,
)

print("\nPruned Results:")
for metric_name, metric_value in pruned_metrics.items():
    print(f"  {metric_name:<25}: {metric_value:.4f}")

In [None]:
# Save pruned checkpoint (weights right after pruning, before fine-tuning).
# round(), not int(): int(0.29 * 100) == 28 because of floating-point error,
# which would disagree with the "{:.0f}%" value printed elsewhere.
# NOTE(review): a structurally pruned state_dict presumably cannot be loaded
# back through build_model() (layer widths differ) — re-apply pruning first.
pruned_ckpt = RESULTS_DIR / f'{MODEL_NAME}_{APPLIANCE_NAME}_pruned_{round(PRUNING_RATIO * 100)}pct.pt'
torch.save(pruned_model.state_dict(), pruned_ckpt)
print(f"Pruned checkpoint saved: {pruned_ckpt}")

## 7. Fine-tune (1 epoch)

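One epoch of MSE training on the training set with Adam at `FINETUNE_LR`.
The pruned architecture is kept as-is — pruned channels are not re-grown.
Fine-tuning updates `pruned_model` in place (its post-pruning weights were saved in the previous cell), so re-running this cell trains an additional epoch; re-run from Section 6 to reproduce the single-epoch result.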

In [None]:
# One epoch of recovery training. NOTE: this updates `pruned_model` in place,
# so re-running the cell trains an additional epoch on top of the previous one.
optimizer = torch.optim.Adam(pruned_model.parameters(), lr=FINETUNE_LR)
loss_fn   = nn.MSELoss()

pruned_model.train()
total_loss  = 0.0   # sum over samples of the batch-mean MSE
total_mae   = 0.0   # sum over samples of the batch-mean MAE
n_samples   = 0
n_batches   = 0

for batch_x, batch_y in tqdm(data_loader.train, desc="Fine-tuning epoch 1"):
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device)

    optimizer.zero_grad()
    outputs = pruned_model(batch_x)

    # Align target shape to model output
    # CNN  : output (B, 1),          target (B,)         -> (B, 1)
    # TCN  : output (B, seq_len, 1), target (B, seq_len) -> (B, seq_len, 1)
    # Reshaping to outputs.shape covers both cases and guarantees nn.MSELoss
    # never silently broadcasts mismatched shapes into a meaningless loss.
    if batch_y.shape != outputs.shape:
        if batch_y.numel() != outputs.numel():
            raise ValueError(
                f"Target shape {tuple(batch_y.shape)} cannot be aligned with "
                f"model output shape {tuple(outputs.shape)}"
            )
        batch_y = batch_y.reshape(outputs.shape)

    loss = loss_fn(outputs, batch_y)
    loss.backward()
    optimizer.step()

    # Weight by batch size so a smaller final batch does not skew the averages.
    batch_size  = batch_x.size(0)
    total_loss += loss.item() * batch_size
    total_mae  += torch.mean(torch.abs(outputs.detach() - batch_y)).item() * batch_size
    n_samples  += batch_size
    n_batches  += 1

pruned_model.eval()

if n_samples == 0:
    raise RuntimeError("Training loader yielded no batches; nothing was fine-tuned.")

print(f"\nFine-tune epoch 1 — avg MSE loss: {total_loss/n_samples:.6f}  "
      f"avg MAE: {total_mae/n_samples:.6f}")

In [None]:
# Save fine-tuned checkpoint.
# round(), not int(): int(0.29 * 100) == 28 because of floating-point error,
# which would disagree with the "{:.0f}%" value printed elsewhere.
finetuned_ckpt = RESULTS_DIR / f'{MODEL_NAME}_{APPLIANCE_NAME}_pruned_{round(PRUNING_RATIO * 100)}pct_finetuned.pt'
torch.save(pruned_model.state_dict(), finetuned_ckpt)
print(f"Fine-tuned checkpoint saved: {finetuned_ckpt}")

## 8. Evaluate Fine-tuned Model

Architecture (and therefore Params / MACs / MB) is unchanged by fine-tuning.

In [None]:
# Cost is re-measured (not copied) even though fine-tuning only touches weights.
finetuned_stats = get_model_stats(pruned_model, dummy_input)
finetuned_params, finetuned_macs, finetuned_mb = finetuned_stats

# Test-set metrics after the recovery epoch.
print(f"Evaluating fine-tuned {MODEL_NAME.upper()} on test set...")
finetuned_metrics = evaluate_model(
    model=pruned_model,
    data_loader=data_loader,
    device=device,
    model_name=MODEL_NAME,
    input_window_length=WINDOW,
    threshold=THRESHOLD,
    cutoff=CUTOFF,
)

print("\nFine-tuned Results:")
for metric_name, metric_value in finetuned_metrics.items():
    print(f"  {metric_name:<25}: {metric_value:.4f}")

## 9. Export Results to Excel

In [None]:
def make_row(label, pruning_pct, params, macs, mb, metrics,
             architecture=None, appliance=None):
    """Build one results-table row (a dict of Excel columns) for a model variant.

    Args:
        label: Human-readable variant name, e.g. ``'TCN Pruned 50%'``.
        pruning_pct: Pruning ratio in percent (0 for the baseline).
        params, macs, mb: Model cost, as returned by ``get_model_stats``.
        metrics: Dict as returned by ``evaluate_model``; must contain ``'mae'``,
            ``'f1'``, ``'precision'``, ``'recall'``, ``'accuracy'`` and
            ``'energy_error_percent'``.
        architecture: Value for the ``'Architecture'`` column. Defaults to
            ``MODEL_NAME.upper()`` from the configuration cell.
        appliance: Value for the ``'Appliance'`` column. Defaults to
            ``APPLIANCE_NAME`` from the configuration cell.

    Returns:
        dict mapping column name to value. Metrics are rounded to 4 decimals,
        except the energy error, which is rounded to 2.
    """
    # Defaults are resolved lazily so the function also works outside this notebook.
    if architecture is None:
        architecture = MODEL_NAME.upper()
    if appliance is None:
        appliance = APPLIANCE_NAME

    return {
        'Model'           : label,
        'Architecture'    : architecture,
        'Appliance'       : appliance,
        'Pruning_Ratio_%' : pruning_pct,
        'Params'          : params,
        'MACs'            : macs,
        'MB'              : mb,
        'MAE'             : round(metrics['mae'],                  4),
        'F1'              : round(metrics['f1'],                   4),
        'Precision'       : round(metrics['precision'],            4),
        'Recall'          : round(metrics['recall'],               4),
        'Accuracy'        : round(metrics['accuracy'],             4),
        'Energy_Error_%'  : round(metrics['energy_error_percent'], 2),
    }


# Integer pruning percentage used in labels. round(), not int():
# int(0.29 * 100) == 28 because of floating-point error, which would mislabel
# the rows relative to the "{:.0f}%" values printed earlier.
pct = round(PRUNING_RATIO * 100)

# One row per stage: baseline, pruned, pruned + fine-tuned.
results_df = pd.DataFrame([
    make_row(f'{MODEL_NAME.upper()} Baseline',
             0, baseline_params, baseline_macs, baseline_mb, baseline_metrics),
    make_row(f'{MODEL_NAME.upper()} Pruned {pct}%',
             pct, pruned_params, pruned_macs, pruned_mb, pruned_metrics),
    make_row(f'{MODEL_NAME.upper()} Pruned {pct}% + Fine-tuned 1ep',
             pct, finetuned_params, finetuned_macs, finetuned_mb, finetuned_metrics),
])

# Writing .xlsx needs openpyxl (installed by the Colab setup cell).
excel_path = RESULTS_DIR / f'{MODEL_NAME}_{APPLIANCE_NAME}_pruning_results.xlsx'
results_df.to_excel(excel_path, index=False)

print(f"Results saved to: {excel_path}")
results_df

## 10. Summary

In [None]:
# Fixed-width console table: a label column plus three value columns
# (baseline / pruned / pruned + fine-tuned), separated by single spaces.
LABEL_W, VALUE_W = 22, 16
TABLE_W = LABEL_W + 3 * VALUE_W + 3
SEP = "=" * TABLE_W
sep = "-" * TABLE_W
hfmt = " ".join([f"{{:<{LABEL_W}}}"] + [f"{{:>{VALUE_W}}}"] * 3)
rfmt = hfmt  # header and data rows share one layout

col_baseline  = f"{MODEL_NAME.upper()} Baseline"
col_pruned    = f"Pruned {pct}%"
col_finetuned = "Pruned+FT 1ep"

print(SEP)
print(f"PRUNING SUMMARY  |  {MODEL_NAME.upper()}  |  {APPLIANCE_NAME}  |  pruning={pct}%  |  LR={FINETUNE_LR}")
print(SEP)
print(hfmt.format('Metric', col_baseline, col_pruned, col_finetuned))
print(sep)


def rrow(label, b, p, f):
    """Print one table row: label, then baseline / pruned / fine-tuned values."""
    print(rfmt.format(label, b, p, f))


# Cost rows: (label, format spec, values per stage).
cost_rows = [
    ('Params', ',',   (baseline_params, pruned_params, finetuned_params)),
    ('MACs',   ',',   (baseline_macs,   pruned_macs,   finetuned_macs)),
    ('MB',     '.3f', (baseline_mb,     pruned_mb,     finetuned_mb)),
]
for row_label, row_spec, row_values in cost_rows:
    rrow(row_label, *(format(value, row_spec) for value in row_values))
print(sep)

# Metric rows: (label, metrics-dict key, format spec), one value per stage.
stage_metrics = (baseline_metrics, pruned_metrics, finetuned_metrics)
metric_rows = [
    ('MAE (W)',      'mae',                  '.4f'),
    ('F1',           'f1',                   '.4f'),
    ('Precision',    'precision',            '.4f'),
    ('Recall',       'recall',               '.4f'),
    ('Accuracy',     'accuracy',             '.4f'),
    ('Energy Err %', 'energy_error_percent', '.2f'),
]
for row_label, metric_key, row_spec in metric_rows:
    rrow(row_label, *(format(stage[metric_key], row_spec) for stage in stage_metrics))

print(SEP)
print(f"Excel  : {excel_path}")
print(f"Ckpts  : {pruned_ckpt.name}")
print(f"         {finetuned_ckpt.name}")
print(SEP)