In [None]:
import sys, os

# Put the parent of the current working directory on sys.path (once), so the
# `src.*` packages imported in the following cells can be resolved.
root_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if root_path not in sys.path:
    sys.path.append(root_path)

In [None]:
from src.data.load_tinyimagenet import *

# Build the three dataloaders (train / validation / test) through the project helper.
# All settings are passed by keyword; seed=7 is fixed so the call is repeatable
# to whatever extent the helper honours it.
train_loader, val_loader, test_loader = get_tinyimagenet200_hf_dataloaders(
    batch_size=64,
    data_dir="./data",
    num_workers=2,
    val_split=0.1,
    img_size=64,
    seed=7,
)

---

In [None]:
from src.Model_A_OutGridNet import * 
from src.stage_config import *

from src.training.one_epoch_train import *
from src.training.train_full_model import *

def imagenet200_stages_t4(drop_path=0.08):
    """Return the list of four StageCfg objects used for the 200-class model.

    Args:
        drop_path: value forwarded unchanged as ``drop_path`` to every stage.

    Returns:
        A list with one ``StageCfg`` per stage, in stage order.
    """
    # resolutions: 64 -> 32 -> 16 -> 8
    # One row per stage: (dim, depth, num_heads, grid_size, outlook_heads).
    layout = (
        (64, 2, 2, 8, 2),
        (128, 3, 4, 8, 4),
        (256, 4, 8, 4, 8),
        (384, 2, 6, 2, 6),
    )
    cfgs = []
    for width, n_blocks, heads, grid, out_heads in layout:
        cfgs.append(
            StageCfg(
                dim=width,
                depth=n_blocks,
                num_heads=heads,
                grid_size=grid,
                outlook_heads=out_heads,
                drop_path=drop_path,
            )
        )
    return cfgs

# Stage configuration, built with drop_path=0.1 rather than the helper's default.
stages = imagenet200_stages_t4(drop_path=0.1)

# Instantiate the network for 200 output classes.
model = MaxOutNet(num_classes=200, stages=stages, stem_dim=64, dpr_max=0.11)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
# NOTE(review): `torch` is not imported explicitly in the visible cells —
# presumably it arrives through one of the star imports; confirm.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [29]:
def count_trainable_parameters(model):
    """Return the total element count of the parameters that require gradients.

    Args:
        model: object exposing ``parameters()``; each yielded item must provide
            ``requires_grad`` and ``numel()``.

    Returns:
        Sum of ``numel()`` over the parameters whose ``requires_grad`` is truthy
        (0 when there are none).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Count the parameters that require gradients and print the total with
# thousands separators.
n_params = count_trainable_parameters(model)
print(f"Trainable parameters: {n_params:,}")

Trainable parameters: 22,542,628


In [None]:
# Hand the checkpoint file and `model` to the project's load_checkpoint helper
# (presumably restores the trained weights into `model` — confirm against the helper).
# NOTE(review): the absolute /content/... path is machine-specific (looks like a
# Colab path); adjust it when running elsewhere.
load_checkpoint('/content/Model_A_tinyimagenet_50epocs.pt' , model=model)

In [31]:
# Run the project's evaluation helper over test_loader. The recorded output is a
# (float, dict) pair whose dict holds 'top1', 'top3' and 'top5' entries; the
# float is presumably the loss — TODO confirm.
evaluate_one_epoch(model=model, dataloader=test_loader)

(1.2630521028518678, {'top1': 69.86, 'top3': 83.59, 'top5': 88.03})

---

In [None]:
from src.experiments.mad_metrics import *

In [None]:
def agg_stage(all_res, key):
    """Summarise one metric per stage.

    Args:
        all_res: iterable of result dicts; each must contain ``"stage"``.
        key: name of the metric to aggregate. Records where the key is missing
            or maps to ``None`` are ignored.

    Returns:
        Dict mapping stage -> ``(mean, std, count)`` computed with ``np.mean``
        and ``np.std`` (default arguments). Stages with no usable value are
        absent. Relies on ``np`` being available in the notebook namespace.
    """
    grouped = {}
    for rec in all_res:
        stage_id = rec["stage"]
        value = rec.get(key)
        if value is not None:
            grouped.setdefault(stage_id, []).append(value)

    return {
        stage_id: (float(np.mean(values)), float(np.std(values)), len(values))
        for stage_id, values in grouped.items()
    }


# Number of blocks probed in each stage; these mirror the `depth` values passed
# to StageCfg when the stage list was built.
stage_depths = {0:2, 1:3, 2:4, 3:2}

# Collect the MAD records for every (stage, block) pair. The stage loop is driven
# by stage_depths itself so the stage list cannot drift out of sync with it.
# `device` is the value computed when the model was created, so the sweep follows
# the same CUDA/CPU fallback instead of assuming a GPU is present.
all_res = []
for s in stage_depths:
    for b in range(stage_depths[s]):
        all_res.extend(
            compute_grid_and_outlooker_mad_by_stage(
                model=model,
                loader=test_loader,
                block_idx=b,
                stages=(s,),
                n_images=16,
                seed=10,
                device=device))


# Per-stage (mean, std, count) summaries of the absolute MAD metrics.
grid_abs = agg_stage(all_res, "MAD_grid_abs_mean")
out_abs  = agg_stage(all_res, "MAD_outlook_abs_mean")

def fmt2(x):
    """Format ``x`` with two decimals; ``None`` becomes the string "None"."""
    if x is None:
        return "None"
    return f"{x:.2f}"

def _collect_by_stage(all_res):
    """Group result records by their ``"stage"`` value.

    Returns a dict mapping stage -> list of the original record objects, each
    list keeping the order in which the records appear in ``all_res``.
    """
    grouped = {}
    for rec in all_res:
        grouped.setdefault(rec["stage"], []).append(rec)
    return grouped

def _vals(rs, key):
    """Return the values stored under ``key`` in ``rs``, in order.

    Records where the key is missing or maps to ``None`` are skipped.
    """
    found = []
    for rec in rs:
        if rec.get(key) is not None:
            found.append(rec[key])
    return found

def _mean_std(vals):
    """Return ``(mean, std, count)`` for ``vals``.

    An empty input yields ``(None, None, 0)``. Otherwise the mean and standard
    deviation come from ``np.mean`` / ``np.std`` (default arguments) converted to
    ``float``, and the count is an ``int``.
    """
    count = len(vals)
    if not count:
        return None, None, 0
    return float(np.mean(vals)), float(np.std(vals)), int(count)

def _fmt(mu, sd, n, digits=4):
    """Render ``mean±std (n=count)`` with ``digits`` decimals.

    Returns the string "None" when ``mu`` is ``None``.
    """
    return "None" if mu is None else f"{mu:.{digits}f}±{sd:.{digits}f} (n={n})"

def _fmt2(mu, sd, n):
    """Render ``mean±std (n=count)`` with two decimals.

    Returns the string "None" when ``mu`` is ``None``.
    """
    return "None" if mu is None else f"{mu:.2f}±{sd:.2f} (n={n})"

def _sorted_none_first(values):
    """Sort the distinct items of ``values``, placing ``None`` (if present) first."""
    # The (is-not-None, value) key keeps None from ever being compared with a
    # real value, which would raise TypeError.
    return sorted(set(values), key=lambda v: (v is not None, v))


def _print_metric_family(rs, label, comps_label, prefix, fmt, total_extra=(), comps_suffix=""):
    """Print the overall line and the center/max/min line for one metric family."""
    # Keys follow the pattern "<prefix>_mean" and "<prefix>_<part>_mean".
    print(label, fmt(*_mean_std(_vals(rs, f"{prefix}_mean"))), *total_extra)
    comps = []
    for part in ("center", "max", "min"):
        stats = _mean_std(_vals(rs, f"{prefix}_{part}_mean"))
        comps.append(f"{part}={fmt(*stats)}")
    print(comps_label, " | ".join(comps) + comps_suffix)


def print_all_res_report(all_res, show_block_breakdown=True):
    """Print a per-stage (and optionally per-block) report of the MAD metrics.

    Args:
        all_res: iterable of result dicts. Every record must contain "stage",
            "block", "seed" and "n_images". "gy", "gx", the "grid_*" feature-map
            entries and the "MAD_*" metrics are optional; missing or ``None``
            metric values are left out of the statistics.
        show_block_breakdown: when true, a second section prints one summary
            line per block within each stage.

    Returns:
        None; the report is written to stdout.

    Notes:
        A stage whose records mix present and missing "gy"/"gx" values is
        reported with ``None`` listed first instead of raising ``TypeError``.
    """
    by_stage = _collect_by_stage(all_res)
    stage_ids = sorted(by_stage.keys())
    rule = "=" * 70

    # (overall label, components label, key prefix, formatter,
    #  extra print args for the overall line, suffix for the components line)
    families = (
        ("GRID norm      :", "GRID norm comps:", "MAD_grid", _fmt, (), ""),
        ("GRID abs       :", "GRID abs comps :", "MAD_grid_abs", _fmt2, (), ""),
        ("OUT norm       :", "OUT norm comps :", "MAD_outlook", _fmt, (), ""),
        ("OUT abs        :", "OUT abs comps  :", "MAD_outlook_abs", _fmt2, (" (max=2)",), "  (max=2)"),
    )

    print("\n" + rule)
    print("FULL MAD REPORT (Grid vs Outlooker)")
    print(rule)

    #  STAGE SUMMARY
    for s in stage_ids:
        rs = by_stage[s]
        blocks = sorted({r["block"] for r in rs})
        seeds = sorted({r["seed"] for r in rs})
        nimgs = sorted({r["n_images"] for r in rs})
        gy_set = _sorted_none_first(r.get("gy", None) for r in rs)
        gx_set = _sorted_none_first(r.get("gx", None) for r in rs)

        # featuremap info (grid)
        Hf_u = sorted(set(_vals(rs, "grid_Hf")))
        Wf_u = sorted(set(_vals(rs, "grid_Wf")))
        den_u = sorted({round(d, 6) for d in _vals(rs, "grid_denom")})

        print(f"\n--- stage {s} ---")
        print(f"blocks: {blocks} | seeds: {seeds} | n_images: {nimgs} | (gy,gx): ({gy_set},{gx_set})")
        if Hf_u:
            # Show at most six distinct denominators, then an ellipsis marker.
            den_shown = den_u if len(den_u) <= 6 else den_u[:6] + ['...']
            print(f"grid featuremap: Hf={Hf_u} Wf={Wf_u} | denom={den_shown}")

        for spec in families:
            _print_metric_family(rs, *spec)

    #  OPTIONAL: BLOCK BREAKDOWN
    if show_block_breakdown:
        print("\n" + rule)
        print("BLOCK BREAKDOWN (per stage -> per block; means across seeds/images)")
        print(rule)

        for s in stage_ids:
            rs = by_stage[s]
            blocks = sorted({r["block"] for r in rs})
            print(f"\n--- stage {s} ---")
            for b in blocks:
                rb = [r for r in rs if r["block"] == b]

                gN = _mean_std(_vals(rb, "MAD_grid_mean"))         # norm
                gA = _mean_std(_vals(rb, "MAD_grid_abs_mean"))     # abs
                oN = _mean_std(_vals(rb, "MAD_outlook_mean"))      # norm
                oA = _mean_std(_vals(rb, "MAD_outlook_abs_mean"))  # abs

                print(
                    f"block {b}: "
                    f"GRID norm={_fmt(*gN)} | GRID abs={_fmt2(*gA)} || "
                    f"OUT norm={_fmt(*oN)} | OUT abs={_fmt2(*oA)}")

    print("\n[Done]")

def print_mad_abs_by_stage_simple(all_res):
    """Print one line per stage with the absolute grid and outlooker MAD.

    Args:
        all_res: iterable of result dicts; each must contain "stage". The
            "MAD_grid_abs_mean", "MAD_outlook_abs_mean", "grid_denom", "grid_Hf"
            and "grid_Wf" entries are read when present and not ``None``.

    Returns:
        None; the summary is written to stdout.
    """
    def _pm(mu, sd, n):
        # mean±std with two decimals, or "None" when there were no values
        return "None" if mu is None else f"{mu:.2f}±{sd:.2f} (n={n})"

    grouped = {}
    for rec in all_res:
        grouped.setdefault(rec["stage"], []).append(rec)

    print("\n=== MAD (ABS) by stage — simple view ===")
    print("GRID_abs is in featuremap L1 pixels; max = (Hf-1)+(Wf-1).")
    print("OUT_abs  is in 3×3 L1 steps; max = 2.\n")

    for stage_id in sorted(grouped):
        recs = grouped[stage_id]

        grid_txt = _pm(*_mean_std(_vals(recs, "MAD_grid_abs_mean")))
        denoms = sorted(set(_vals(recs, "grid_denom")))
        heights = sorted(set(_vals(recs, "grid_Hf")))
        widths = sorted(set(_vals(recs, "grid_Wf")))
        out_txt = _pm(*_mean_std(_vals(recs, "MAD_outlook_abs_mean")))

        # A precise scale is printed only when the stage has exactly one
        # denominator and one feature-map size; otherwise fall back to the
        # first (smallest) denominator, or None when there is none.
        if len(denoms) == 1 and len(heights) == 1 and len(widths) == 1:
            scale = f"GRID max={denoms[0]:.0f} (Hf={heights[0]}, Wf={widths[0]}) | OUT max=2"
        else:
            fallback = denoms[0] if denoms else None
            scale = f"GRID max≈{fallback} | OUT max=2"

        print(f"stage {stage_id}:  GRID_abs={grid_txt}   |   OUT_abs={out_txt}   |   {scale}")


In [206]:
# Print the full per-stage report, including the per-block breakdown section.
print_all_res_report(all_res, show_block_breakdown=True)


FULL MAD REPORT (Grid vs Outlooker)

--- stage 0 ---
blocks: [0, 1] | seeds: [10] | n_images: [16] | (gy,gx): ([0],[0])
grid featuremap: Hf=[64] Wf=[64] | denom=[126.0]
GRID norm      : 0.2261±0.0093 (n=2)
GRID norm comps: center=0.1972±0.0137 (n=2) | max=0.1735±0.0177 (n=2) | min=0.3077±0.0036 (n=2)
GRID abs       : 28.49±1.17 (n=2)
GRID abs comps : center=24.84±1.73 (n=2) | max=21.86±2.22 (n=2) | min=38.78±0.45 (n=2)
OUT norm       : 0.5950±0.0916 (n=2)
OUT norm comps : center=0.6091±0.1139 (n=2) | max=0.4305±0.1009 (n=2) | min=0.7453±0.0600 (n=2)
OUT abs        : 1.19±0.18 (n=2)  (max=2)
OUT abs comps  : center=1.22±0.23 (n=2) | max=0.86±0.20 (n=2) | min=1.49±0.12 (n=2)  (max=2)

--- stage 1 ---
blocks: [0, 1, 2] | seeds: [10] | n_images: [16] | (gy,gx): ([0],[0])
grid featuremap: Hf=[32] Wf=[32] | denom=[62.0]
GRID norm      : 0.2952±0.0038 (n=3)
GRID norm comps: center=0.2527±0.0076 (n=3) | max=0.3088±0.0121 (n=3) | min=0.3242±0.0156 (n=3)
GRID abs       : 18.30±0.24 (n=3)
GRID a

In [208]:
# Print the compact one-line-per-stage view of the absolute MAD values.
print_mad_abs_by_stage_simple(all_res)


=== MAD (ABS) by stage — simple view ===
GRID_abs is in featuremap L1 pixels; max = (Hf-1)+(Wf-1).
OUT_abs  is in 3×3 L1 steps; max = 2.

stage 0:  GRID_abs=28.49±1.17 (n=2)   |   OUT_abs=1.19±0.18 (n=2)   |   GRID max=126 (Hf=64, Wf=64) | OUT max=2
stage 1:  GRID_abs=18.30±0.24 (n=3)   |   OUT_abs=1.61±0.13 (n=3)   |   GRID max=62 (Hf=32, Wf=32) | OUT max=2
stage 2:  GRID_abs=9.06±0.21 (n=4)   |   OUT_abs=1.62±0.25 (n=4)   |   GRID max=30 (Hf=16, Wf=16) | OUT max=2
stage 3:  GRID_abs=5.41±0.55 (n=2)   |   OUT_abs=1.69±0.14 (n=2)   |   GRID max=14 (Hf=8, Wf=8) | OUT max=2


---

