In [1]:
#!/usr/bin/env python3
"""
Streamed (memory-safe) trainer for MEG LXe energy regression.

- Input: Npho (4760), split into SiPM(4092)→93x44 + PMT(668)
- Preprocess: log1p(Npho / 2e5)
- Model: CNN(SiPM) → flatten + PMT → FC(128) → FC(1)
- Target: standardized z = (E - μ)/σ   (auto-computed)
- Optional rebalance: "loss" | "resample"
- Streaming: uproot.iterate (no full file in RAM)
- ONNX export
"""

import os, re, time, argparse, warnings
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import uproot
import mlflow
import mlflow.pytorch


# ----------------------------------------------------------------------
#  Utilities
# ----------------------------------------------------------------------
def iterate_chunks(path, tree, branches, step_size=4000):
    """Yield NumPy dicts (float64) chunk-by-chunk from ROOT TTree."""
    with uproot.open(path) as f:
        t = f[tree]
        for arrays in t.iterate(branches, step_size=step_size, library="np"):
            yield arrays


def scan_energy_stats(root, tree, step_size=4000):
    """Single streamed pass to compute mean/std of 'energy'."""
    n, mean, M2 = 0, 0.0, 0.0
    for arr in iterate_chunks(root, tree, ["energy"], step_size):
        E = arr["energy"].astype("float64").ravel()
        for x in E:
            n += 1
            delta = x - mean
            mean += delta / n
            M2 += delta * (x - mean)
    if n < 2:
        raise RuntimeError("Not enough samples for stats.")
    var = M2 / (n - 1)
    return float(mean), float(np.sqrt(max(var, 1e-12)))


def scan_energy_hist(root, tree, nbins=30, step_size=4000):
    """Streamed energy histogram for imbalance reweighting."""
    Emin, Emax = +np.inf, -np.inf
    for arr in iterate_chunks(root, tree, ["energy"], step_size):
        E = arr["energy"].astype("float64").ravel()
        if E.size:
            Emin, Emax = min(Emin, E.min()), max(Emax, E.max())
    if not np.isfinite(Emin) or not np.isfinite(Emax):
        raise RuntimeError("Invalid energy range.")
    edges = np.linspace(Emin, Emax, nbins + 1)
    counts = np.zeros(nbins, np.int64)
    for arr in iterate_chunks(root, tree, ["energy"], step_size):
        E = arr["energy"].astype("float64").ravel()
        if E.size:
            h, _ = np.histogram(E, bins=edges)
            counts += h
    eps = 1e-9
    inv = 1.0 / (counts + eps)
    weights = inv / inv.mean()
    return edges, counts, weights


# ----------------------------------------------------------------------
#  Model
# ----------------------------------------------------------------------
class RegressorSimple(nn.Module):
    def __init__(self, nlayer=16, nfc=128, drop_conv=0.3, drop_fc=0.0):
        super().__init__()
        self.conv1 = nn.Conv2d(1, nlayer, (7,6), stride=(4,4), padding=(3,3))
        self.conv2 = nn.Conv2d(nlayer, 2*nlayer, (4,4), stride=(2,2), padding=(1,1))
        self.conv3 = nn.Conv2d(2*nlayer, 4*nlayer, (2,2), stride=(2,2))
        self.bn1, self.bn2, self.bn3 = (
            nn.BatchNorm2d(nlayer),
            nn.BatchNorm2d(2*nlayer),
            nn.BatchNorm2d(4*nlayer),
        )
        self.act = nn.LeakyReLU(0.1)
        self.dropc = nn.Dropout(drop_conv)
        self.dropf = nn.Dropout(drop_fc)
        self.fc1 = nn.Linear((4*nlayer)*18 + 668, nfc)
        self.fc2 = nn.Linear(nfc, 1)

    def forward(self, pm4760):
        x_mppc, x_pmt = torch.split(pm4760, [4092, 668], dim=1)
        x_mppc = x_mppc.view(-1, 1, 93, 44)
        x_mppc = self.act(self.dropc(self.bn1(self.conv1(x_mppc))))
        x_mppc = self.act(self.dropc(self.bn2(self.conv2(x_mppc))))
        x_mppc = self.act(self.dropc(self.bn3(self.conv3(x_mppc))))
        x_mppc = x_mppc.view(x_mppc.size(0), -1)
        z = torch.cat([x_mppc, x_pmt], dim=1)
        z = self.act(self.dropf(self.fc1(z)))
        out = self.fc2(z)        # linear output (predicts z)
        return out


# ----------------------------------------------------------------------
#  Training / Validation (streamed)
# ----------------------------------------------------------------------
def run_epoch_stream(
    model, optimizer, device, root, tree,
    step_size=4000, batch_size=128, train=True, amp=True,
    rebalance="none", edges=None, weights=None, per_bin_cap=0, max_chunks=None,
    std_target=True, mu=0.0, sigma=1.0
):
    model.train(train)
    loss_fn = nn.SmoothL1Loss(reduction="none")
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    scaler = torch.amp.GradScaler(device_type, enabled=(amp and device_type == "cuda"))
    total_loss, nobs = 0.0, 0
    all_pred, all_true = [], []
    NphoScale = 2e5

    chunks_done = 0
    for arr in iterate_chunks(root, tree, ["Npho", "energy"], step_size):
        if max_chunks and chunks_done >= max_chunks:
            break
        chunks_done += 1

        PM = np.log1p(arr["Npho"].astype("float32") / NphoScale)
        E = arr["energy"].astype("float32").reshape(-1, 1)
        Ez = (E - mu) / sigma if std_target else E

        # --- balanced resampling per chunk
        if rebalance == "resample" and edges is not None:
            eflat = E.ravel()
            bin_ids = np.clip(np.digitize(eflat, edges) - 1, 0, len(edges)-2)
            idxs_by_bin = [np.where(bin_ids == i)[0] for i in range(len(edges)-1)]
            sizes = [len(ix) for ix in idxs_by_bin if len(ix)]
            if sizes:
                take = per_bin_cap or min(sizes)
                sel = np.concatenate([
                    np.random.choice(ix, take, replace=False)
                    for ix in idxs_by_bin if len(ix) >= take
                ])
                PM, E, Ez = PM[sel], E[sel], Ez[sel]

        # --- dataset (optionally with weights)
        if rebalance == "loss" and edges is not None and weights is not None:
            eflat = E.ravel()
            bin_ids = np.clip(np.digitize(eflat, edges) - 1, 0, len(edges)-2)
            w_np = weights[bin_ids].astype("float32").reshape(-1, 1)
            ds = TensorDataset(torch.from_numpy(Ez), torch.from_numpy(PM), torch.from_numpy(w_np))
        else:
            ds = TensorDataset(torch.from_numpy(Ez), torch.from_numpy(PM))
        loader = DataLoader(ds, batch_size=batch_size, shuffle=train, drop_last=False)

        # --- loop over mini-batches
        for batch in loader:
            if rebalance == "loss":
                Ez_b, PM_b, W_b = batch
                W_b = W_b.to(device)
            else:
                Ez_b, PM_b = batch
                W_b = None
            Ez_b, PM_b = Ez_b.to(device), PM_b.to(device)

            if train:
                optimizer.zero_grad(set_to_none=True)
                with torch.amp.autocast(device_type, enabled=(amp and device_type == "cuda"), dtype=torch.bfloat16):
                    pred_z = model(PM_b)
                    lvec = loss_fn(pred_z, Ez_b)
                    loss = (lvec * W_b).mean() if W_b is not None else lvec.mean()
                if amp and device_type == "cuda":
                    scaler.scale(loss).backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer); scaler.update()
                else:
                    loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
            else:
                with torch.no_grad():
                    pred_z = model(PM_b)
                    lvec = loss_fn(pred_z, Ez_b)
                    loss = (lvec * W_b).mean() if W_b is not None else lvec.mean()
                    pred_e = (pred_z * sigma + mu).cpu().numpy()
                    true_e = (Ez_b * sigma + mu).cpu().numpy()
                    all_pred.append(pred_e); all_true.append(true_e)

            total_loss += loss.item() * PM_b.size(0)
            nobs += PM_b.size(0)
        torch.cuda.empty_cache()

    loss_avg = total_loss / max(1, nobs)
    if not train:
        pred_np = np.concatenate(all_pred).ravel()
        true_np = np.concatenate(all_true).ravel()
        res = pred_np - true_np
        return loss_avg, float(np.nanmean(res)), float(np.nanmean(np.abs(res)))
    return loss_avg, float("nan"), float("nan")


# ----------------------------------------------------------------------
#  Evaluation / Plot
# ----------------------------------------------------------------------
@torch.no_grad()
def predict_stream(model, device, root, tree, step_size=4000, batch_size=256,
                   std_target=True, mu=0.0, sigma=1.0):
    model.eval()
    preds, truths = [], []
    for arr in iterate_chunks(root, tree, ["Npho", "energy"], step_size):
        PM = np.log1p(arr["Npho"].astype("float32") / 2e5)
        E = arr["energy"].astype("float32").reshape(-1, 1)
        PM_t = torch.from_numpy(PM).to(device)
        outs = []
        for i in range(0, len(PM_t), batch_size):
            z = model(PM_t[i:i+batch_size]).cpu().numpy()
            outs.append((z * sigma + mu) if std_target else z)
        preds.append(np.concatenate(outs)); truths.append(E)
        torch.cuda.empty_cache()
    return np.concatenate(preds).ravel(), np.concatenate(truths).ravel()


def eval_plots(pred, true, outfile=None):
    res = pred - true
    plt.figure(figsize=(6,4))
    plt.hist(res, bins=200)
    plt.title("Residuals: Pred - True"); 
    plt.xlabel("ΔE"); 
    plt.ylabel("count")
    plt.tight_layout();
    if outfile is None:
        plt.show()
    else: 
        plt.savefig(outfile, dpi=50)
        plt.close()


# ----------------------------------------------------------------------
#  Main entry
# ----------------------------------------------------------------------
def main_with_args(
    root, tree="tout", epochs=10, batch=256, chunksize=4000, lr=3e-4,
    weight_decay=1e-4, amp=True, rebalance="resample",
    bins=20, per_bin_cap=200, max_chunks=None, onnx="meg2enereg.onnx",
    mlflow_experiment="meg2_energy", run_name=None
):
    root = os.path.expanduser(root)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RegressorSimple().to(device)
    with torch.no_grad(): 
        model.fc2.bias.fill_(1.374)

    # optimizer (no decay for bias / BN)
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad: 
            continue
        (no_decay if n.endswith(".bias") or "bn" in n.lower() else decay).append(p)
    opt = torch.optim.RMSprop(
        [{"params": decay, "weight_decay": weight_decay},
         {"params": no_decay, "weight_decay": 0.0}],
        lr=lr
    )

    # --- stats + histogram
    mu, sigma = scan_energy_stats(root, tree, step_size=chunksize)
    print(f"[info] target mean={mu:.6f}  std={sigma:.6f}")
    edges = weights = None
    if rebalance in ("loss", "resample"):
        print(f"[info] scanning histogram (bins={bins}) ...")
        edges, counts, weights = scan_energy_hist(root, tree, nbins=bins, step_size=chunksize)
        print(f"[info] non-empty bins {np.sum(counts>0)}/{bins}")

    # --- MLflow + TensorBoard setup
    mlflow.set_experiment(mlflow_experiment)
    if run_name is None:
        run_name = time.strftime("run_%Y%m%d_%H%M%S")
    writer = SummaryWriter(log_dir=os.path.join("runs", run_name))

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params({
            "root": root,
            "tree": tree,
            "epochs": epochs,
            "batch": batch,
            "chunksize": chunksize,
            "lr": lr,
            "weight_decay": weight_decay,
            "amp": amp,
            "rebalance": rebalance,
            "bins": bins,
            "per_bin_cap": per_bin_cap,
            "max_chunks": max_chunks,
            "model_nlayer": 16,
            "model_nfc": 128,
        })
        mlflow.log_metrics({
            "target_mean": mu,
            "target_std": sigma,
        })

        best_val, best_state = float("inf"), None

        # --- training
        for ep in range(1, epochs+1):
            t0 = time.time()
            tr_loss, _, _ = run_epoch_stream(
                model, opt, device, root, tree,
                step_size=chunksize, batch_size=batch,
                train=True, amp=amp,
                rebalance=rebalance, edges=edges, weights=weights,
                per_bin_cap=per_bin_cap, max_chunks=max_chunks,
                std_target=True, mu=mu, sigma=sigma
            )
            val_loss, res_mean, mae = run_epoch_stream(
                model, opt, device, root, tree,
                step_size=chunksize, batch_size=max(batch,256),
                train=False, amp=False,
                rebalance=rebalance, edges=edges, weights=weights,
                per_bin_cap=per_bin_cap, max_chunks=max_chunks,
                std_target=True, mu=mu, sigma=sigma
            )
            sec = time.time() - t0
            print(f"[{ep:03d}] train {tr_loss:.6f}  val {val_loss:.6f}  "
                  f"res_mean {res_mean:+.6f}  MAE {mae:.6f}  time {sec:.1f}s")
    
            # --- log to MLflow per epoch
            mlflow.log_metrics({
                "train_loss": tr_loss,
                "val_loss": val_loss,
                "res_mean": res_mean,
                "mae": mae,
                "epoch_time_sec": sec,
            }, step=ep)
    
            # --- log to TensorBoard
            writer.add_scalar("loss/train", tr_loss, ep)
            writer.add_scalar("loss/val", val_loss, ep)
            writer.add_scalar("metric/mae", mae, ep)
            writer.add_scalar("metric/res_mean", res_mean, ep)
            writer.add_scalar("time/epoch_sec", sec, ep)
            
            if val_loss < best_val:
                best_val, best_state = val_loss, {k:v.cpu() for k,v in model.state_dict().items()}
    
        # --- evaluation
        if best_state: 
            model.load_state_dict(best_state)
        
        pred, true = predict_stream(
            model, device, root, tree,
            step_size=chunksize, batch_size=256,
            std_target=True, mu=mu, sigma=sigma
        )
    
        # save residual plot
        residual_png = f"residuals_{run_name}.png"
        eval_plots(pred, true, outfile=residual_png)
        mlflow.log_artifact(residual_png)
    
        # --- export ONNX and log model
        model.eval()
        dummy = torch.randn(1, 4760, device=device)
        torch.onnx.export(model, dummy, onnx,
                          input_names=["Npho4760"], output_names=["energy_z"])
        print(f"[OK] Exported ONNX to {onnx}")
    
        # log ONNX file + PyTorch model to MLflow
        if os.path.exists(onnx):
            mlflow.log_artifact(onnx)
        mlflow.pytorch.log_model(model, "pytorch_model")
    
        writer.close()


# ----------------------------------------------------------------------
#  Example quick test
# ----------------------------------------------------------------------
# if __name__ == "__main__":
#     main_with_args(
#         root="~/meghome/xec-ml-wl/CWMC.root",
#         tree="tout",
#         epochs=10,
#         batch=256,
#         chunksize=4000,
#         lr=3e-4,
#         amp=True,
#         rebalance="resample",   # or "none" / "loss"
#         bins=20,
#         per_bin_cap=200,
#         max_chunks=3
#     )


In [2]:
main_with_args(
    root="~/meghome/xec-ml-wl/CWMC.root",
    tree="tout",
    epochs=50,
    batch=256,
    chunksize=4000,
    lr=2e-4,
    amp=True,
    rebalance="resample",   # or "none" / "loss"
    bins=20,
    per_bin_cap=200,
    max_chunks=4000,
    mlflow_experiment="past_study_CW_energy",
    run_name="energy_resample_2e-4_50ep"
)

  return FileStore(store_uri, store_uri)
2025/11/17 17:20:50 INFO mlflow.tracking.fluent: Experiment with name 'past_study_CW_energy' does not exist. Creating a new experiment.


[info] target mean=0.806783  std=0.113780
[info] scanning histogram (bins=20) ...
[info] non-empty bins 20/20
[001] train 1.077576  val 1.108456  res_mean +0.181385  MAE 0.181385  time 34.2s
[002] train 1.044638  val 1.061611  res_mean +0.175543  MAE 0.175543  time 17.4s
[003] train 1.020584  val 1.019318  res_mean +0.170198  MAE 0.170198  time 17.4s
[004] train 0.931840  val 0.969394  res_mean +0.163768  MAE 0.163768  time 17.3s
[005] train 0.876044  val 0.881293  res_mean +0.152049  MAE 0.152049  time 17.3s
[006] train 0.831680  val 0.815999  res_mean +0.143030  MAE 0.143030  time 17.4s
[007] train 0.815546  val 0.802568  res_mean +0.141141  MAE 0.141141  time 17.3s
[008] train 0.798796  val 0.797630  res_mean +0.140407  MAE 0.140407  time 17.4s
[009] train 0.779845  val 0.792326  res_mean +0.139672  MAE 0.139672  time 17.4s
[010] train 0.761143  val 0.737944  res_mean +0.131784  MAE 0.131787  time 17.5s
[011] train 0.737634  val 0.752550  res_mean +0.133900  MAE 0.133901  time 17.3s



[OK] Exported ONNX to meg2enereg.onnx


