In [None]:
import sys, os

# Make the parent of the current working directory importable (the cells below
# import from `src`, presumably located there).
root_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(root_path) == 0:
    sys.path.append(root_path)

In [None]:
from src.data.load_tinyimagenet import *

# Build the train / validation / test dataloaders for Tiny-ImageNet-200.
train_loader, val_loader, test_loader = get_tinyimagenet200_hf_dataloaders(
    data_dir="./data",
    img_size=64,
    batch_size=64,
    num_workers=2,
    val_split=0.1,
    seed=7,
)

In [None]:
from src.Model_A_OutGridNet import * 
from src.stage_config import *

from src.training.one_epoch_train import *
from src.training.train_full_model import *

def imagenet200_stages_t4(drop_path=0.08):
    """Build the four-stage configuration used for the Tiny-ImageNet-200 model.

    Args:
        drop_path: Value forwarded unchanged as ``drop_path`` to every ``StageCfg``.

    Returns:
        A list of four ``StageCfg`` objects, ordered from the first stage to the last.
    """
    # Resolutions: 64 -> 32 -> 16 -> 8
    # One row per stage: (dim, depth, heads, grid_size). The same head count is
    # passed as both ``num_heads`` and ``outlook_heads``.
    layout = (
        (64, 2, 2, 8),
        (128, 3, 4, 8),
        (256, 4, 8, 4),
        (384, 2, 6, 2),
    )
    return [
        StageCfg(
            dim=dim,
            depth=depth,
            num_heads=heads,
            grid_size=grid,
            outlook_heads=heads,
            drop_path=drop_path,
        )
        for dim, depth, heads, grid in layout
    ]

# Stage configuration (drop_path=0.1) and the 200-class network built from it.
stages = imagenet200_stages_t4(drop_path=0.1)
model = MaxOutNet(num_classes=200, stages=stages, stem_dim=64, dpr_max=0.11)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
# NOTE(review): `model` is not moved to `device` in this cell — confirm the
# downstream helpers take care of that.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [22]:
def count_trainable_parameters(model):
    """Count the elements of all parameters that have ``requires_grad`` set.

    Args:
        model: Object exposing ``parameters()``; each yielded item must provide
            ``numel()`` and a ``requires_grad`` attribute.

    Returns:
        The sum of ``numel()`` over the trainable parameters (0 when there are none).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the number of trainable parameters of `model`, with thousands separators.
n_params = count_trainable_parameters(model)
print("Trainable parameters: {:,}".format(n_params))

Trainable parameters: 22,542,628


In [None]:
# Load the saved checkpoint into `model` (file name suggests a 50-epoch Tiny-ImageNet run).
# NOTE(review): absolute "/content/..." path — adjust when running outside that environment.
load_checkpoint('/content/Model_A_tinyimagenet_50epocs.pt' , model=model)

In [31]:
import numpy as np

# Evaluate `model` on the test loader. As the last expression of the cell, the
# returned value is what the notebook displays below.
evaluate_one_epoch(model=model, dataloader=test_loader)

(1.2630521028518678, {'top1': 69.86, 'top3': 83.59, 'top5': 88.03})

---

# Visualizaciones

In [35]:
def _collect_by_stage(all_res):
    by = {}
    for r in all_res:
        s = r["stage"]
        by.setdefault(s, []).append(r)
    return by

def _vals(rs, key):
    return [r[key] for r in rs if r.get(key, None) is not None]

def _mean_std(vals):
    if len(vals) == 0:
        return None, None, 0
    return float(np.mean(vals)), float(np.std(vals)), int(len(vals))

def _fmt(mu, sd, n, digits=4):
    if mu is None:
        return "None"
    return f"{mu:.{digits}f}±{sd:.{digits}f} (n={n})"

def _fmt2(mu, sd, n):
    if mu is None:
        return "None"
    return f"{mu:.2f}±{sd:.2f} (n={n})"

def print_all_res_report(all_res, show_block_breakdown=True):
    """Print a text report of the grid / outlooker MAD metrics, stage by stage.

    For every stage the report lists the blocks, seeds and image counts found in
    the records, the sampling settings read from the stage's first record, the
    feature-map sizes / denominators (when present), and mean±std of the four
    MAD metrics. Optionally the same metrics are broken down per block.

    Args:
        all_res: Iterable of dict records. Each record must have ``"stage"``,
            ``"block"``, ``"seed"`` and ``"n_images"``; metric and config keys
            are optional and skipped when missing or ``None``.
        show_block_breakdown: When true, append the per-block section.

    Returns:
        None. Everything is written with ``print``.
    """
    rule = "=" * 70
    grouped = _collect_by_stage(all_res)
    stage_ids = sorted(grouped.keys())

    def summarize(records, key):
        # mean / std / count over the records that carry a value for `key`
        return _mean_std(_vals(records, key))

    # (label, record key, formatter, trailing note or None) for the stage summary
    summary_rows = (
        ("GRID norm      :", "MAD_grid_mean", _fmt, None),
        ("GRID abs       :", "MAD_grid_abs_mean", _fmt2, " (units: featuremap L1 pixels)"),
        ("OUT norm       :", "MAD_outlook_mean", _fmt, " (max=1)"),
        ("OUT abs        :", "MAD_outlook_abs_mean", _fmt2, " (max=2; units: 3×3 L1 steps)"),
    )

    print("\n" + rule)
    print("FULL MAD REPORT (Grid vs Outlooker) — BLINDED SAMPLING")
    print(rule)

    # ---------- STAGE SUMMARY ----------
    for stage in stage_ids:
        records = grouped[stage]
        block_ids = sorted({rec["block"] for rec in records})
        seed_ids = sorted({rec["seed"] for rec in records})
        image_counts = sorted({rec["n_images"] for rec in records})

        # Sampling settings are taken from the first record of the stage only.
        first = records[0]
        grid_n_q = first.get("grid_n_q", None)
        grid_border = first.get("grid_exclude_border", None)
        grid_groups = first.get("grid_avg_over_groups", None)
        out_n_xy = first.get("out_n_xy", None)
        out_border = first.get("out_exclude_border", None)

        # Distinct feature-map sizes and normalisation denominators seen in the stage.
        heights = sorted(set(_vals(records, "grid_Hf")))
        widths = sorted(set(_vals(records, "grid_Wf")))
        denominators = sorted({round(x, 6) for x in _vals(records, "grid_denom")})

        print(f"\n--- stage {stage} ---")
        print(f"blocks: {block_ids} | seeds: {seed_ids} | n_images: {image_counts}")
        print(
            f"sampling: GRID n_q={grid_n_q}, excl_border={grid_border}, "
            f"avg_over_groups={grid_groups} | OUT n_xy={out_n_xy}, excl_border={out_border}"
        )

        if heights and widths and denominators:
            print(f"grid featuremap: Hf={heights} Wf={widths} | denom={denominators}")

        for label, key, render, note in summary_rows:
            text = render(*summarize(records, key))
            if note is None:
                print(label, text)
            else:
                print(label, text, note)

    # ---------- BLOCK BREAKDOWN ----------
    if show_block_breakdown:
        print("\n" + rule)
        print("BLOCK BREAKDOWN (per stage -> per block; means across seeds/images)")
        print(rule)

        for stage in stage_ids:
            records = grouped[stage]
            block_ids = sorted({rec["block"] for rec in records})
            print(f"\n--- stage {stage} ---")
            for block in block_ids:
                in_block = [rec for rec in records if rec["block"] == block]

                grid_norm = _fmt(*summarize(in_block, "MAD_grid_mean"))
                grid_abs = _fmt2(*summarize(in_block, "MAD_grid_abs_mean"))
                out_norm = _fmt(*summarize(in_block, "MAD_outlook_mean"))
                out_abs = _fmt2(*summarize(in_block, "MAD_outlook_abs_mean"))

                print(
                    f"block {block}: "
                    f"GRID norm={grid_norm} | GRID abs={grid_abs} || "
                    f"OUT norm={out_norm} | OUT abs={out_abs}"
                )

    print("\n[Done]")


def print_mad_abs_by_stage_simple(all_res):
    """Print one line per stage with the absolute grid / outlooker MAD values.

    Each line shows mean±std (two decimals) of ``"MAD_grid_abs_mean"`` and
    ``"MAD_outlook_abs_mean"`` over the stage's records, followed by the scale
    the grid value should be read against.

    Args:
        all_res: Iterable of dict records with a ``"stage"`` key; the metric
            keys and ``"grid_denom"`` / ``"grid_Hf"`` / ``"grid_Wf"`` are
            optional and skipped when missing or ``None``.

    Returns:
        None. Everything is written with ``print``.
    """
    def render(mu, sd, n):
        # Two-decimal "mean±std (n=count)"; the string "None" when there is no data.
        if mu is None:
            return "None"
        return f"{mu:.2f}±{sd:.2f} (n={n})"

    by_stage = {}
    for r in all_res:
        by_stage.setdefault(r["stage"], []).append(r)

    print("\n=== MAD (ABS) by stage — simple view ===")
    print("GRID_abs is in featuremap L1 pixels; max = (Hf-1)+(Wf-1).")
    print("OUT_abs  is in 3×3 L1 steps; max = 2.\n")

    for s in sorted(by_stage.keys()):
        rs = by_stage[s]

        g_mu, g_sd, g_n = _mean_std(_vals(rs, "MAD_grid_abs_mean"))
        o_mu, o_sd, o_n = _mean_std(_vals(rs, "MAD_outlook_abs_mean"))

        denoms = sorted(set(_vals(rs, "grid_denom")))
        Hfs    = sorted(set(_vals(rs, "grid_Hf")))
        Wfs    = sorted(set(_vals(rs, "grid_Wf")))

        if len(denoms) == 1 and len(Hfs) == 1 and len(Wfs) == 1:
            scale = f"GRID max={denoms[0]:.0f} (Hf={Hfs[0]}, Wf={Wfs[0]}) | OUT max=2"
        elif denoms:
            # Several feature-map sizes in one stage: only the largest denominator
            # bounds every record, so report that one (denoms is sorted ascending).
            scale = f"GRID max≈{denoms[-1]} | OUT max=2"
        else:
            # No denominator recorded for this stage.
            scale = "GRID max=n/a | OUT max=2"

        print(f"stage {s}:  GRID_abs={render(g_mu,g_sd,g_n)}   |   OUT_abs={render(o_mu,o_sd,o_n)}   |   {scale}")


In [None]:
from src.experiments.entropy_metrics import * 
from src.experiments.mad_metrics import *

def run_mad_pipeline(
    model,
    test_loader,
    stage_depths,
    seeds=(0,1,2),
    n_images=128,
    grid_n_q=32,
    out_n_xy=64,
    device="cuda"):
    """Run the grid / outlooker MAD computation for every seed, stage and block.

    ``compute_grid_and_outlooker_mad_by_stage`` is called once per
    (seed, stage, block) combination and the returned records are concatenated.

    Args:
        model: Model forwarded to the MAD helper.
        test_loader: Data loader forwarded to the MAD helper as ``loader``.
        stage_depths: Number of blocks per stage. Either a mapping
            ``{stage_id: depth}`` (its keys are visited in sorted order) or a
            sequence indexed by stage id.
        seeds: Seeds to repeat the measurement with.
        n_images: Forwarded to the MAD helper.
        grid_n_q: Forwarded to the MAD helper.
        out_n_xy: Forwarded to the MAD helper.
        device: Forwarded to the MAD helper. Defaults to ``"cuda"``; pass the
            device explicitly on machines without CUDA.

    Returns:
        A flat list with the records returned by all helper calls.
    """
    # Take the stage ids from `stage_depths` itself rather than assuming four
    # stages, so models with a different number of stages are handled.
    if hasattr(stage_depths, "keys"):
        stage_ids = sorted(stage_depths.keys())
    else:
        stage_ids = range(len(stage_depths))

    all_res = []
    for seed in seeds:
        for s in stage_ids:
            for b in range(stage_depths[s]):
                all_res.extend(compute_grid_and_outlooker_mad_by_stage(
                    model=model,
                    loader=test_loader,
                    block_idx=b,
                    stages=(s,),
                    n_images=n_images,
                    seed=seed,
                    device=device,
                    normalize_grid=True,
                    grid_n_q=grid_n_q,
                    grid_exclude_border=1,
                    grid_avg_over_groups=True,
                    out_n_xy=out_n_xy,
                    out_exclude_border=1,))

    return all_res

In [42]:
# Number of blocks per stage (same values as the `depth` entries in imagenet200_stages_t4).
stage_depths = {0:2, 1:3, 2:4, 3:2}
# Pass the notebook-level `device` so the run falls back to the CPU when CUDA is
# unavailable, instead of relying on the function's "cuda" default.
# NOTE(review): the saved report below lists n_images: [64] although 128 is
# requested here — the stored output may be stale; confirm by re-running.
all_res = run_mad_pipeline(model, test_loader, stage_depths, seeds=(3,1,2), n_images=128, grid_n_q=32, out_n_xy=64, device=device)

In [43]:
# Print the full MAD report (stage summary plus per-block breakdown) for `all_res`.
print_all_res_report(all_res, show_block_breakdown=True)


FULL MAD REPORT (Grid vs Outlooker) — BLINDED SAMPLING

--- stage 0 ---
blocks: [0, 1] | seeds: [1, 2, 3] | n_images: [64]
sampling: GRID n_q=32, excl_border=1, avg_over_groups=True | OUT n_xy=64, excl_border=1
grid featuremap: Hf=[64] Wf=[64] | denom=[126.0]
GRID norm      : 0.2504±0.0100 (n=6)
GRID abs       : 31.55±1.26 (n=6)  (units: featuremap L1 pixels)
OUT norm       : 0.6629±0.0654 (n=6)  (max=1)
OUT abs        : 1.33±0.13 (n=6)  (max=2; units: 3×3 L1 steps)

--- stage 1 ---
blocks: [0, 1, 2] | seeds: [1, 2, 3] | n_images: [64]
sampling: GRID n_q=32, excl_border=1, avg_over_groups=True | OUT n_xy=64, excl_border=1
grid featuremap: Hf=[32] Wf=[32] | denom=[62.0]
GRID norm      : 0.2484±0.0024 (n=9)
GRID abs       : 15.40±0.15 (n=9)  (units: featuremap L1 pixels)
OUT norm       : 0.8351±0.0623 (n=9)  (max=1)
OUT abs        : 1.67±0.12 (n=9)  (max=2; units: 3×3 L1 steps)

--- stage 2 ---
blocks: [0, 1, 2, 3] | seeds: [1, 2, 3] | n_images: [64]
sampling: GRID n_q=32, excl_border=1

In [44]:
# Print the compact per-stage view of the absolute MAD values in `all_res`.
print_mad_abs_by_stage_simple(all_res)


=== MAD (ABS) by stage — simple view ===
GRID_abs is in featuremap L1 pixels; max = (Hf-1)+(Wf-1).
OUT_abs  is in 3×3 L1 steps; max = 2.

stage 0:  GRID_abs=31.55±1.26 (n=6)   |   OUT_abs=1.33±0.13 (n=6)   |   GRID max=126 (Hf=64, Wf=64) | OUT max=2
stage 1:  GRID_abs=15.40±0.15 (n=9)   |   OUT_abs=1.67±0.12 (n=9)   |   GRID max=62 (Hf=32, Wf=32) | OUT max=2
stage 2:  GRID_abs=7.87±0.17 (n=12)   |   OUT_abs=1.68±0.23 (n=12)   |   GRID max=30 (Hf=16, Wf=16) | OUT max=2
stage 3:  GRID_abs=4.47±0.34 (n=6)   |   OUT_abs=1.74±0.14 (n=6)   |   GRID max=14 (Hf=8, Wf=8) | OUT max=2


---



In [None]:

def print_all_res_report_mad_entropy(all_res, show_block_breakdown=True):
    """Print a text report combining MAD and entropy metrics, stage by stage.

    For every stage the report lists the blocks, seeds and image counts found in
    the records, the sampling settings read from the stage's first record, the
    feature-map sizes / denominators (when present), mean±std of the four MAD
    metrics and of the raw / normalised entropies. Optionally the normalised
    MAD and entropy values are broken down per block.

    Args:
        all_res: Iterable of dict records. Each record must have ``"stage"``,
            ``"block"``, ``"seed"`` and ``"n_images"``; metric and config keys
            are optional and skipped when missing or ``None``.
        show_block_breakdown: When true, append the per-block section.

    Returns:
        None. Everything is written with ``print``.
    """
    rule = "=" * 70
    grouped = _collect_by_stage(all_res)
    stage_ids = sorted(grouped.keys())

    def summarize(records, key):
        # mean / std / count over the records that carry a value for `key`
        return _mean_std(_vals(records, key))

    # (label, record key, formatter, trailing note or None)
    mad_rows = (
        ("GRID MAD norm  :", "MAD_grid_mean", _fmt, None),
        ("GRID MAD abs   :", "MAD_grid_abs_mean", _fmt2, " (featuremap L1)"),
        ("OUT  MAD norm  :", "MAD_outlook_mean", _fmt, " (max=1)"),
        ("OUT  MAD abs   :", "MAD_outlook_abs_mean", _fmt2, " (max=2)"),
    )
    # (label, raw-entropy key, normalised-entropy key, trailing note)
    entropy_rows = (
        ("GRID H (nats)  :", "H_grid_mean", "Hn_grid_mean", "(norm by log(N))"),
        ("OUT  H (nats)  :", "H_out_mean", "Hn_out_mean", "(norm by log(9))"),
    )

    print("\n" + rule)
    print("FULL REPORT (MAD + Entropy) — Grid vs Outlooker")
    print(rule)

    for stage in stage_ids:
        records = grouped[stage]
        block_ids = sorted({rec["block"] for rec in records})
        seed_ids = sorted({rec["seed"] for rec in records})
        image_counts = sorted({rec["n_images"] for rec in records})

        # Sampling settings are taken from the first record of the stage only.
        first = records[0]
        grid_n_q = first.get("grid_n_q")
        grid_border = first.get("grid_exclude_border")
        grid_groups = first.get("grid_avg_over_groups")
        out_n_xy = first.get("out_n_xy")
        out_border = first.get("out_exclude_border")

        print(f"\n--- stage {stage} ---")
        print(f"blocks: {block_ids} | seeds: {seed_ids} | n_images: {image_counts}")
        print(
            f"sampling: GRID n_q={grid_n_q}, excl_border={grid_border}, "
            f"avg_over_groups={grid_groups} | "
            f"OUT n_xy={out_n_xy}, excl_border={out_border}"
        )

        # Distinct feature-map sizes and normalisation denominators seen in the stage.
        heights = sorted(set(_vals(records, "grid_Hf")))
        widths = sorted(set(_vals(records, "grid_Wf")))
        denominators = sorted({round(x, 6) for x in _vals(records, "grid_denom")})
        if heights and widths and denominators:
            print(f"grid featuremap: Hf={heights} Wf={widths} | denom={denominators}")

        # MAD
        for label, key, render, note in mad_rows:
            text = render(*summarize(records, key))
            if note is None:
                print(label, text)
            else:
                print(label, text, note)

        # Entropy
        for label, raw_key, norm_key, note in entropy_rows:
            raw_stats = summarize(records, raw_key)
            norm_stats = summarize(records, norm_key)
            print(label, _fmt(*raw_stats), f"| H_norm={_fmt(*norm_stats)}  {note}")

    if show_block_breakdown:
        print("\n" + rule)
        print("BLOCK BREAKDOWN (per stage -> per block)")
        print(rule)

        for stage in stage_ids:
            records = grouped[stage]
            block_ids = sorted({rec["block"] for rec in records})
            print(f"\n--- stage {stage} ---")
            for block in block_ids:
                in_block = [rec for rec in records if rec["block"] == block]

                grid_mad = _fmt(*summarize(in_block, "MAD_grid_mean"))
                out_mad = _fmt(*summarize(in_block, "MAD_outlook_mean"))
                grid_hn = _fmt(*summarize(in_block, "Hn_grid_mean"))
                out_hn = _fmt(*summarize(in_block, "Hn_out_mean"))

                print(
                    f"block {block}: "
                    f"GRID MAD={grid_mad} | GRID Hn={grid_hn} || "
                    f"OUT MAD={out_mad} | OUT Hn={out_hn}"
                )

    print("\n[Done]")


In [48]:
# Number of blocks per stage (same values as the `depth` entries in imagenet200_stages_t4).
stage_depths = {0: 2, 1: 3, 2: 4, 3: 2}
# NOTE(review): `run_mad_entropy_pipeline` is not defined in the visible cells —
# presumably it comes from one of the star imports; confirm.
all_res = run_mad_entropy_pipeline(
    model,
    test_loader,
    stage_depths,
    seeds=(0, 1, 2),
    n_images=64,
    grid_n_q=32,
    out_n_xy=64,
)


In [49]:
# Print the combined MAD + entropy report (with per-block breakdown) for `all_res`.
print_all_res_report_mad_entropy(all_res, show_block_breakdown=True)


FULL REPORT (MAD + Entropy) — Grid vs Outlooker

--- stage 0 ---
blocks: [0, 1] | seeds: [0, 1, 2] | n_images: [64]
sampling: GRID n_q=32, excl_border=1, avg_over_groups=True | OUT n_xy=64, excl_border=1
grid featuremap: Hf=[64] Wf=[64] | denom=[126.0]
GRID MAD norm  : 0.2523±0.0090 (n=6)
GRID MAD abs   : 31.79±1.13 (n=6)  (featuremap L1)
OUT  MAD norm  : 0.6644±0.0630 (n=6)  (max=1)
OUT  MAD abs   : 1.33±0.13 (n=6)  (max=2)
GRID H (nats)  : 3.4034±0.1518 (n=6) | H_norm=0.8183±0.0365 (n=6)  (norm by log(N))
OUT  H (nats)  : 1.9649±0.0103 (n=6) | H_norm=0.8943±0.0047 (n=6)  (norm by log(9))

--- stage 1 ---
blocks: [0, 1, 2] | seeds: [0, 1, 2] | n_images: [64]
sampling: GRID n_q=32, excl_border=1, avg_over_groups=True | OUT n_xy=64, excl_border=1
grid featuremap: Hf=[32] Wf=[32] | denom=[62.0]
GRID MAD norm  : 0.2485±0.0025 (n=9)
GRID MAD abs   : 15.41±0.15 (n=9)  (featuremap L1)
OUT  MAD norm  : 0.8324±0.0621 (n=9)  (max=1)
OUT  MAD abs   : 1.66±0.12 (n=9)  (max=2)
GRID H (nats)  : 2.