# ENERGIZE NILM — Post-Training Static Quantization (PTQ)

This notebook loads a **fine-tuned pruned model** (output of `03_pruning.ipynb`) and applies
**static post-training quantization (PTQ)** using PyTorch's eager-mode `torch.ao.quantization` API
with the `onednn` quantized backend (CPU).

Pipeline:
1. **Configure** — choose model, appliance and pruning ratio (must match `03_pruning.ipynb`)
2. **Load data** — same split used for pruning
3. **Rebuild pruned model** — fresh baseline → re-apply pruning → load fine-tuned weights
4. **Apply Static PTQ** — insert observers → calibrate → convert to INT8
5. **Evaluate** — test-set metrics on the quantized model
6. **Export** — append quantized row to existing Excel, save quantized model

> Pruning / evaluation utilities live in `src_pytorch/pruner.py`.

---

## Google Colab Setup
1. Upload your `ENERGIZE` folder to Google Drive
2. Edit `DRIVE_PROJECT_PATH` in the cell below, then run that cell before any other

In [None]:
# ============================================================================
# COLAB SETUP — run this cell first
# ============================================================================
import os
import subprocess
import sys
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

    # check=True: fail here with pip's error instead of a confusing ImportError later.
    # NOTE(review): versions are unpinned; pin them for reproducible runs.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-q', 'torch_pruning', 'openpyxl'],
        check=True,
    )

    # =========================================================================
    DRIVE_PROJECT_PATH = '/content/drive/MyDrive/ENERGIZE'  # <-- EDIT THIS
    # =========================================================================

    project_root = Path(DRIVE_PROJECT_PATH)
    if not project_root.exists():
        raise FileNotFoundError(f"Project folder not found: {project_root}")
    os.chdir(project_root)
else:
    # Assumes the notebook is launched from a direct sub-directory of the project
    # root (e.g. notebooks/), so the root is one level up.
    project_root = Path(os.getcwd()).parent

# Put the project root first on the import path so `src_pytorch` resolves to it.
# Remove any earlier copy first so re-running the cell does not stack duplicates.
_root_str = str(project_root)
if _root_str in sys.path:
    sys.path.remove(_root_str)
sys.path.insert(0, _root_str)

if IN_COLAB:
    print(f"Project root: {project_root}")
else:
    print(f"Running locally. Project root: {project_root}")

## 1. Imports

In [None]:
# Standard library
import itertools
import warnings
from types import SimpleNamespace

# Third-party
import pandas as pd
import torch
import torch.ao.quantization as tq
import torch.nn as nn
from tqdm import tqdm

# NILM package
from src_pytorch import (
    CNN_NILM, TCN_NILM,
    SimpleNILMDataLoader,
    set_seeds, get_device,
    apply_torch_pruning,
    get_model_stats,
    evaluate_model,
)

set_seeds(42)
device = get_device()

print(f"PyTorch version : {torch.__version__}")
print(f"Device          : {device}")
# Report the quantized engines this PyTorch build actually ships, rather than a
# hard-coded backend name (the PTQ cell below selects one of these).
print(f"Quantized engines available: {torch.backends.quantized.supported_engines}")

## 2. Configuration

**Edit only this cell.** The `MODEL_NAME`, `APPLIANCE_NAME` and `PRUNING_RATIO` must match
what was used in `03_pruning.ipynb` so the correct fine-tuned checkpoint is loaded.

In [None]:
# ============================================================================
# USER CONFIGURATION  (must match 03_pruning.ipynb settings)
# ============================================================================
MODEL_NAME      = 'tcn'      # 'cnn'  |  'tcn'
APPLIANCE_NAME  = 'boiler'   # 'boiler'  |  'ac_1'  |  'washing_machine'
PRUNING_RATIO   = 0.5        # must match the pruning ratio used in 03_pruning.ipynb
N_CALIB_BATCHES = 100        # number of training batches used for PTQ calibration
DATASET_NAME    = 'plegma'

# ============================================================================
# AUTO-DERIVED — do not edit below this line
# ============================================================================

# Per-architecture hyper-parameters. rebuild_pruned_model() re-instantiates the
# architecture from these, so they must mirror the values used at training time.
_MODEL_CFGS = {
    'cnn': {
        'window'          : 299,
        'batch_size'      : 1024,
        'args_window_size': 1,
    },
    'tcn': {
        'window'          : 600,
        'batch_size'      : 50,
        'depth'           : 9,
        'filters'         : [512, 256, 256, 128, 128, 256, 256, 256, 512],
        'dropout'         : 0.2,
        'stacks'          : 1,
        'args_window_size': 600,
    },
}

# Per-appliance on/off threshold and power cutoff, in watts.
_APPLIANCE_CFGS = {
    'boiler'          : {'threshold': 50, 'cutoff': 5000},
    'ac_1'            : {'threshold': 50, 'cutoff': 2300},
    'washing_machine' : {'threshold': 50, 'cutoff': 2600},
}

# Fail early with an actionable message instead of a bare KeyError.
if MODEL_NAME not in _MODEL_CFGS:
    raise ValueError(f"Unknown MODEL_NAME {MODEL_NAME!r}; expected one of {sorted(_MODEL_CFGS)}")
if APPLIANCE_NAME not in _APPLIANCE_CFGS:
    raise ValueError(f"Unknown APPLIANCE_NAME {APPLIANCE_NAME!r}; expected one of {sorted(_APPLIANCE_CFGS)}")

cfg     = _MODEL_CFGS[MODEL_NAME]
app_cfg = _APPLIANCE_CFGS[APPLIANCE_NAME]

WINDOW     = cfg['window']
BATCH_SIZE = cfg['batch_size']
THRESHOLD  = app_cfg['threshold']
CUTOFF     = app_cfg['cutoff']
# int() truncates float noise (e.g. int(0.29 * 100) == 28). pct is part of every
# checkpoint / Excel file name below, and must equal what 03_pruning.ipynb used.
# NOTE(review): presumably 03 uses the same int() expression — TODO confirm.
# The expression is kept, with a warning when it disagrees with the rounded percentage.
pct        = int(PRUNING_RATIO * 100)
if pct != round(PRUNING_RATIO * 100):
    warnings.warn(
        f"PRUNING_RATIO={PRUNING_RATIO} maps to pct={pct} via int() truncation "
        f"(rounded value would be {round(PRUNING_RATIO * 100)}); checkpoint file names "
        "use pct, so make sure 03_pruning.ipynb produced the same value."
    )

DATA_DIR    = project_root / 'data' / 'processed' / DATASET_NAME / APPLIANCE_NAME
CKPT_PATH   = project_root / 'outputs' / f'{MODEL_NAME}_{APPLIANCE_NAME}' / 'checkpoint' / 'model.pt'
RESULTS_DIR = project_root / 'outputs' / f'{MODEL_NAME}_{APPLIANCE_NAME}' / 'comparative_results'
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

finetuned_ckpt = RESULTS_DIR / f'{MODEL_NAME}_{APPLIANCE_NAME}_pruned_{pct}pct_finetuned.pt'

print(f"Model          : {MODEL_NAME.upper()}")
print(f"Appliance      : {APPLIANCE_NAME}")
print(f"Pruning ratio  : {pct}%")
print(f"Calib batches  : {N_CALIB_BATCHES}")
print(f"Window         : {WINDOW}")
print(f"Threshold      : {THRESHOLD} W  |  Cutoff: {CUTOFF} W")
print(f"Baseline ckpt  : {CKPT_PATH}  ({'found' if CKPT_PATH.exists() else 'NOT FOUND'})")
print(f"FT pruned ckpt : {finetuned_ckpt}  ({'found' if finetuned_ckpt.exists() else 'NOT FOUND'})")
print(f"Results dir    : {RESULTS_DIR}")

## 3. Load Data

In [None]:
# Same windowing / batch size as pruning, so calibration sees the training distribution.
loader_kwargs = dict(
    data_dir=str(DATA_DIR),
    model_name=MODEL_NAME,
    batch_size=BATCH_SIZE,
    input_window_length=WINDOW,
    train=True,
    num_workers=0,
)
data_loader = SimpleNILMDataLoader(**loader_kwargs)

# Report the number of batches per split (label left-padded to 5 chars for alignment).
for split_name in ("train", "val", "test"):
    print(f"{split_name.capitalize():<5} batches : {len(getattr(data_loader, split_name))}")

## 4. Helper — Model Factory

`rebuild_pruned_model` reproduces the pruned architecture deterministically:
1. Build a fresh baseline model and load the original checkpoint
2. Re-apply the same pruning (magnitude-based, same ratio → same channels removed)
3. Replace the weights with the fine-tuned checkpoint

In [None]:
def build_model(model_name: str, cfg: dict, ckpt_path, target_device) -> nn.Module:
    """Instantiate a NILM architecture and load its weights from a checkpoint.

    Args:
        model_name: 'cnn' or 'tcn'.
        cfg: architecture config; 'window' is always read, and the TCN also reads
            'depth', 'filters', 'dropout' and 'stacks'.
        ckpt_path: path to a state_dict saved with torch.save.
        target_device: device used both as torch.load map_location and for .to().

    Returns:
        The model on ``target_device`` in eval mode.

    Raises:
        ValueError: if ``model_name`` is not a known architecture.
    """
    # Constructors are wrapped in lambdas so only the selected branch reads its cfg keys.
    factories = {
        'cnn': lambda: CNN_NILM(input_window_length=cfg['window']),
        'tcn': lambda: TCN_NILM(
            input_window_length=cfg['window'],
            depth=cfg['depth'],
            nb_filters=cfg['filters'],
            dropout=cfg['dropout'],
            stacks=cfg['stacks'],
        ),
    }
    if model_name not in factories:
        raise ValueError(f"Unknown model: {model_name}")

    net = factories[model_name]()
    weights = torch.load(ckpt_path, map_location=target_device)
    net.load_state_dict(weights)
    net = net.to(target_device)
    return net.eval()


def rebuild_pruned_model(model_name, cfg, baseline_ckpt, pruning_ratio, finetuned_ckpt_path, target_device):
    """
    Rebuild a pruned architecture and load fine-tuned weights.

    Parameters
    ----------
    model_name : str
        'cnn' or 'tcn' (forwarded to build_model).
    cfg : dict
        Architecture config. Reads 'window' and 'args_window_size', plus the keys
        build_model reads for the chosen architecture.
    baseline_ckpt : path-like
        State dict of the original (unpruned) model.
    pruning_ratio : float
        Ratio forwarded to apply_torch_pruning; must equal the one used before fine-tuning.
    finetuned_ckpt_path : str or pathlib.Path
        State dict saved after pruning + fine-tuning.
    target_device : torch.device or str
        Device the model is built and loaded on.

    Returns
    -------
    nn.Module
        The pruned model with fine-tuned weights, in eval mode.

    Raises
    ------
    FileNotFoundError
        If ``finetuned_ckpt_path`` does not exist (checked before any model is built).

    Steps
    -----
    1. Instantiate baseline model with original weights
    2. Re-apply the same structured pruning. Assumes apply_torch_pruning is
       deterministic (same weights + ratio -> same channels removed); otherwise
       step 3 fails with a shape mismatch.
    3. Load the fine-tuned state_dict into the pruned architecture
    """
    from pathlib import Path

    # Accept plain strings as well as Path objects.
    finetuned_ckpt_path = Path(finetuned_ckpt_path)
    if not finetuned_ckpt_path.exists():
        raise FileNotFoundError(
            f"Fine-tuned checkpoint not found:\n  {finetuned_ckpt_path}\n"
            "Run 03_pruning.ipynb first with matching MODEL_NAME / APPLIANCE_NAME / PRUNING_RATIO."
        )

    # Step 1: fresh baseline
    model = build_model(model_name, cfg, baseline_ckpt, target_device)

    # Step 2: re-apply pruning to reproduce the pruned architecture
    pruning_args = SimpleNamespace(window_size=cfg['args_window_size'])
    dummy        = torch.randn(1, cfg['window']).to(target_device)
    model, _     = apply_torch_pruning(model, pruning_args, dummy, pruning_ratio)

    # Step 3: load fine-tuned weights
    model.load_state_dict(torch.load(finetuned_ckpt_path, map_location=target_device))
    return model.eval()


print("build_model() and rebuild_pruned_model() defined.")

## 5. Load Fine-tuned Pruned Model

The quantization pipeline starts from the fine-tuned pruned model.
PyTorch's quantized INT8 kernels (`onednn` / `fbgemm`) run only on **CPU**, so the model is rebuilt on CPU.

In [None]:
# Quantized INT8 kernels (onednn / fbgemm) execute on CPU only, so rebuilding,
# calibration, conversion and evaluation all happen on CPU.
quant_device = torch.device('cpu')

print("Rebuilding pruned model and loading fine-tuned weights...")
quant_model = rebuild_pruned_model(
    model_name=MODEL_NAME,
    cfg=cfg,
    baseline_ckpt=CKPT_PATH,
    pruning_ratio=PRUNING_RATIO,
    finetuned_ckpt_path=finetuned_ckpt,
    target_device=quant_device,
)

# Snapshot the float model's cost now; PTQ does not change the layer structure,
# so these params/MACs also describe the quantized model.
dummy_input = torch.randn(1, WINDOW, device=quant_device)
pruned_params, pruned_macs, pruned_mb = get_model_stats(quant_model, dummy_input)

stats_report = "\n".join([
    "\nPruned model loaded (pre-quantization):",
    f"  Parameters : {pruned_params:,}",
    f"  MACs       : {pruned_macs:,}",
    f"  Size       : {pruned_mb:.3f} MB",
])
print(stats_report)

## 6. Apply Static Post-Training Quantization (PTQ)

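After wrapping the model with `QuantStub` / `DeQuantStub` boundaries and attaching a quantization config, static PTQ is a three-step process: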
1. **Prepare** — insert activation observers into the model
2. **Calibrate** — run forward passes to collect activation statistics
3. **Convert** — replace float ops with INT8 quantized ops

In [None]:
def _fix_string_padding(model: nn.Module) -> None:
    """Replace Conv1d string padding ('same'/'valid') with integers.
    The quantization C++ backend requires integer padding, not strings.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv1d) and isinstance(m.padding, str):
            if m.padding == 'valid':
                m.padding = (0,)
            elif m.padding == 'same':
                k, d = m.kernel_size[0], m.dilation[0]
                m.padding = ((k - 1) * d // 2,)


class QuantWrapper(nn.Module):
    """
    Eager-mode quantization boundary around a float model.

    ``quant`` (QuantStub) observes the float input during calibration and, after
    ``tq.convert``, quantizes it. ``dequant`` (DeQuantStub) turns the model output
    back into float, so evaluation code keeps receiving float tensors.

    Same idea as ``torch.ao.quantization.QuantWrapper``, except the inner model is
    stored as ``base``. The attribute names quant / base / dequant define the
    module tree and state_dict keys of the saved model, so keep them stable.

    NOTE(review): eager mode only swaps supported leaf modules. Functional ops
    inside ``base`` (e.g. tensor ``+`` in residual blocks) would need
    ``FloatFunctional`` — confirm against the CNN_NILM / TCN_NILM definitions.
    """

    def __init__(self, base: nn.Module):
        super().__init__()
        # Registration order fixes the child / state_dict order: quant, base, dequant.
        self.quant = tq.QuantStub()
        self.base = base
        self.dequant = tq.DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # float -> (quantized) -> base model -> float
        return self.dequant(self.base(self.quant(x)))


# Re-running this cell would wrap an already-converted model; rebuild the float
# model (section 5) first instead.
if isinstance(quant_model, QuantWrapper):
    raise RuntimeError(
        "quant_model is already wrapped/quantized — re-run section 5 to rebuild the float model first."
    )
if N_CALIB_BATCHES < 1:
    raise ValueError(f"N_CALIB_BATCHES must be >= 1 (got {N_CALIB_BATCHES}); observers need data.")

# Patch string padding before wrapping (quantized conv needs integer padding)
_fix_string_padding(quant_model)

# --- Step 1: Wrap and configure ---
# A single constant drives both the kernel engine and the qconfig so they cannot drift apart.
QUANT_BACKEND = "onednn"
torch.backends.quantized.engine = QUANT_BACKEND

wrapped = QuantWrapper(quant_model).eval()
wrapped.qconfig = tq.get_default_qconfig(QUANT_BACKEND)
# NOTE(review): no module fusion (e.g. tq.fuse_modules for conv+relu) is done;
# fusing could improve accuracy/speed if the model's module names allow it.

# --- Step 2: Prepare (eager mode — no tracing, no TraceError) ---
tq.prepare(wrapped, inplace=True)
print("Observers inserted.")

# --- Step 3: Calibrate ---
# islice stops after exactly N_CALIB_BATCHES batches; min() keeps the progress
# bar total correct when the loader is shorter than that.
n_calib = min(N_CALIB_BATCHES, len(data_loader.train))
print(f"Calibrating on up to {N_CALIB_BATCHES} training batches...")
with torch.inference_mode():
    calib_batches = itertools.islice(data_loader.train, N_CALIB_BATCHES)
    for x, _ in tqdm(calib_batches, total=n_calib, desc="Calibration"):
        wrapped(x.to(quant_device))

# --- Step 4: Convert to INT8 ---
tq.convert(wrapped, inplace=True)
quant_model = wrapped   # rest of the notebook keeps using quant_model
print("\nStatic quantization applied — model converted to INT8.")

## 7. Evaluate Quantized Model

In [None]:
print(f"Evaluating quantized {MODEL_NAME.upper()} on test set...")

# Same evaluation settings as the float models, but on CPU (quantized kernels).
eval_kwargs = dict(
    model=quant_model,
    data_loader=data_loader,
    model_name=MODEL_NAME,
    cutoff=CUTOFF,
    threshold=THRESHOLD,
    device=quant_device,
    input_window_length=WINDOW,
)
quantized_metrics = evaluate_model(**eval_kwargs)

print("\nQuantized Model Results:")
for metric_name, metric_value in quantized_metrics.items():
    print(f"  {metric_name:<25}: {metric_value:.4f}")

## 8. Save Quantized Model

In [None]:
quant_ckpt = RESULTS_DIR / f'{MODEL_NAME}_{APPLIANCE_NAME}_pruned_{pct}pct_finetuned_quantized.pt'

# Eager-mode tq.convert() returns the same QuantWrapper with quantized submodules
# swapped in. The whole module object is saved (not just its state_dict) because
# rebuilding the pruned + quantized structure from scratch would mean replaying
# pruning, wrapping, prepare and convert before load_state_dict.
# NOTE: this is a pickle. Reloading it requires QuantWrapper (defined in this
# notebook) and the src_pytorch classes to be importable, and on PyTorch >= 2.6
# torch.load(..., weights_only=False). Only load files you trust.
torch.save(quant_model, quant_ckpt)

# On-disk size of the pickled module: INT8 packed weights plus structural overhead.
quant_mb = round(quant_ckpt.stat().st_size / (1024 ** 2), 3)

# NOTE(review): pruned_mb comes from get_model_stats(); if that is an in-memory
# parameter estimate rather than a file size, the reduction below is approximate.
print(f"Quantized model saved : {quant_ckpt}")
print(f"On-disk size          : {quant_mb:.3f} MB  (vs {pruned_mb:.3f} MB before quantization)")
if pruned_mb > 0:
    print(f"Size reduction        : {(1 - quant_mb / pruned_mb) * 100:.1f}%")
else:
    print("Size reduction        : n/a (pre-quantization size is 0)")

## 9. Update Results Excel

A row for the quantized model is added to the
`{model}_{appliance}_pruning_results.xlsx` created by `03_pruning.ipynb` (the file is created if it does not exist yet).

In [None]:
excel_path = RESULTS_DIR / f'{MODEL_NAME}_{APPLIANCE_NAME}_pruning_results.xlsx'

# Params / MACs are the pre-quantization counts (PTQ keeps the layer structure);
# MB is the on-disk size of the quantized checkpoint.
new_row = {
    'Model'           : f'{MODEL_NAME.upper()} Pruned {pct}% + FT + Quantized',
    'Architecture'    : MODEL_NAME.upper(),
    'Appliance'       : APPLIANCE_NAME,
    'Pruning_Ratio_%' : pct,
    'Params'          : pruned_params,
    'MACs'            : pruned_macs,
    'MB'              : quant_mb,
    'MAE'             : round(quantized_metrics['mae'],                  4),
    'F1'              : round(quantized_metrics['f1'],                   4),
    'Precision'       : round(quantized_metrics['precision'],            4),
    'Recall'          : round(quantized_metrics['recall'],               4),
    'Accuracy'        : round(quantized_metrics['accuracy'],             4),
    'Energy_Error_%'  : round(quantized_metrics['energy_error_percent'], 2),
}

if excel_path.exists():
    existing_df = pd.read_excel(excel_path)
    # Drop rows written by an earlier run of this cell for the same configuration,
    # so re-running the notebook replaces the quantized row instead of duplicating it.
    if 'Model' in existing_df.columns:
        stale = existing_df['Model'] == new_row['Model']
        if stale.any():
            print(f"Replacing {int(stale.sum())} existing row(s) labelled '{new_row['Model']}'")
        existing_df = existing_df.loc[~stale]
    updated_df  = pd.concat([existing_df, pd.DataFrame([new_row])], ignore_index=True)
    print(f"Appending quantized row to existing Excel: {excel_path}")
else:
    updated_df = pd.DataFrame([new_row])
    print(f"Excel not found — creating new file: {excel_path}")

updated_df.to_excel(excel_path, index=False)
print(f"Excel updated: {excel_path}")
updated_df

## 10. Summary

In [None]:
# Table geometry: one left-aligned label column and two right-aligned value columns.
LABEL_W, VALUE_W = 22, 18
TABLE_W = LABEL_W + 2 * VALUE_W + 2
THICK_RULE = "=" * TABLE_W
THIN_RULE = "-" * TABLE_W
ROW_FMT = f"{{:<{LABEL_W}}} {{:>{VALUE_W}}} {{:>{VALUE_W}}}"


def print_row(label, left, right):
    """Print one summary-table row (values are converted with str())."""
    print(ROW_FMT.format(label, str(left), str(right)))


print(THICK_RULE)
print(f"QUANTIZATION SUMMARY  |  {MODEL_NAME.upper()}  |  {APPLIANCE_NAME}  |  pruning={pct}%")
print(THICK_RULE)
print_row('Metric', f"Pruned {pct}% + FT", "+ Quantized (INT8)")
print(THIN_RULE)

# Fine-tuned (pre-quantization) metrics are read back from the Excel written by
# 03_pruning.ipynb; every value stays NaN when the file or the row is missing.
ft_label = f'{MODEL_NAME.upper()} Pruned {pct}% + Fine-tuned 1ep'
_ft_columns = ['MAE', 'F1', 'Precision', 'Recall', 'Accuracy', 'Energy_Error_%', 'MB']
ft_values = dict.fromkeys(_ft_columns, float('nan'))
if excel_path.exists():
    results_df = pd.read_excel(excel_path)
    ft_rows = results_df[results_df['Model'] == ft_label]
    if not ft_rows.empty:
        first_ft = ft_rows.iloc[0]
        ft_values = {column: first_ft[column] for column in _ft_columns}

# NOTE(review): the fine-tuned MB comes from 03_pruning.ipynb while quant_mb is an
# on-disk file size; confirm both are measured the same way before comparing.
print_row('MB', format(ft_values['MB'], '.3f'), format(quant_mb, '.3f'))
print(THIN_RULE)

# (row label, Excel column, quantized_metrics key, number format)
_metric_rows = [
    ('MAE (W)',      'MAE',            'mae',                  '.4f'),
    ('F1',           'F1',             'f1',                   '.4f'),
    ('Precision',    'Precision',      'precision',            '.4f'),
    ('Recall',       'Recall',         'recall',               '.4f'),
    ('Accuracy',     'Accuracy',       'accuracy',             '.4f'),
    ('Energy Err %', 'Energy_Error_%', 'energy_error_percent', '.2f'),
]
for row_label, ft_column, metric_key, number_fmt in _metric_rows:
    print_row(
        row_label,
        format(ft_values[ft_column], number_fmt),
        format(quantized_metrics[metric_key], number_fmt),
    )

print(THICK_RULE)
print(f"Excel  : {excel_path}")
print(f"Ckpt   : {quant_ckpt.name}")
print(THICK_RULE)