In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import RandAugment

def get_cifar100_datasets(
    data_dir: str = "./data",
    val_split: float = 0.0,
    ra_num_ops: int = 2,
    ra_magnitude: int = 7,
    random_erasing_p: float = 0.25,
    erasing_scale=(0.02, 0.20),
    erasing_ratio=(0.3, 3.3),
    img_size: int = 32,
    split_seed: int = 7,):

    """
    Build CIFAR-100 datasets with "mix-friendly" augmentations, designed to
    complement Mixup/CutMix (applied in the training loop) without overdoing it.

    Args:
        data_dir: root directory for torchvision's CIFAR100 (downloaded if missing).
        val_split: fraction of the 50k training images held out for validation,
            in [0, 1). 0 disables the split and val_dataset is None.
        ra_num_ops, ra_magnitude: RandAugment parameters (train transform only).
        random_erasing_p, erasing_scale, erasing_ratio: RandomErasing parameters
            (train transform only, applied after normalization).
        img_size: output resolution.
          - 32 (default): native CIFAR.
          - >32: bicubic upsample (e.g. 64) for experiments (more tokens/compute).
        split_seed: seed of the generator used for the train/val split.

    Returns:
        (train_dataset, val_dataset, test_dataset). With val_split > 0, train and
        val are disjoint Subsets of the training set. The validation subset uses
        the deterministic test-time transform (no RandomCrop / flip / RandAugment /
        RandomErasing), so validation metrics are not computed on augmented images.

    Raises:
        ValueError: if img_size < 32 or val_split is outside [0, 1).
    """
    if img_size < 32:
        raise ValueError(f"img_size must be >= 32 for CIFAR-100. Got {img_size}.")
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1). Got {val_split}.")

    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std  = (0.2675, 0.2565, 0.2761)

    # When upsampling, resize first and scale the crop padding proportionally:
    # 32 -> 4, 64 -> 8, etc.
    crop_padding = max(4, img_size // 8)

    def _resize_ops():
        # Fresh Resize op per pipeline; empty at the native 32x32 resolution.
        if img_size == 32:
            return []
        return [transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC)]

    train_transform = transforms.Compose(_resize_ops() + [
        transforms.RandomCrop(img_size, padding=crop_padding),
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=ra_num_ops, magnitude=ra_magnitude),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
        transforms.RandomErasing(
            p=random_erasing_p,
            scale=erasing_scale,
            ratio=erasing_ratio,
            value="random",),])

    test_transform = transforms.Compose(_resize_ops() + [
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),])

    full_train_dataset = datasets.CIFAR100(
        root=data_dir, train=True, download=True, transform=train_transform)

    test_dataset = datasets.CIFAR100(
        root=data_dir, train=False, download=True, transform=test_transform)

    if val_split == 0.0:
        return full_train_dataset, None, test_dataset

    n_total = len(full_train_dataset)
    n_val = int(n_total * val_split)
    n_train = n_total - n_val
    train_dataset, val_aug_subset = random_split(
        full_train_dataset,
        [n_train, n_val],
        generator=torch.Generator().manual_seed(split_seed),)

    # random_split shares the underlying dataset (and therefore its *training*
    # transform) between both subsets. Re-wrap the held-out indices around a
    # second view of the training images that uses the eval transform.
    eval_train_base = datasets.CIFAR100(
        root=data_dir, train=True, download=False, transform=test_transform)
    val_dataset = torch.utils.data.Subset(eval_train_base, val_aug_subset.indices)

    return train_dataset, val_dataset, test_dataset


def get_cifar100_dataloaders(
    batch_size: int = 128,
    data_dir: str = "./data",
    num_workers: int = 2,
    val_split: float = 0.0,
    pin_memory: bool = True,
    ra_num_ops: int = 2,
    ra_magnitude: int = 7,
    random_erasing_p: float = 0.25,
    img_size: int = 32,
    erasing_scale=(0.02, 0.20),
    erasing_ratio=(0.3, 3.3),):
    """
    CIFAR-100 DataLoaders ready for training with Mixup/CutMix in the loop.
    Augmentations are deliberately moderate.

    Args:
        batch_size, num_workers, pin_memory: forwarded to every DataLoader.
        data_dir, val_split, ra_num_ops, ra_magnitude, random_erasing_p,
        erasing_scale, erasing_ratio, img_size: forwarded to get_cifar100_datasets.
          img_size 32 (default) is native CIFAR; 64 is an upsampling experiment
          (more compute).

    Returns:
        (train_loader, val_loader, test_loader). The train loader shuffles, the
        others do not. val_loader is None when val_split == 0.
    """
    train_ds, val_ds, test_ds = get_cifar100_datasets(
        data_dir=data_dir,
        val_split=val_split,
        ra_num_ops=ra_num_ops,
        ra_magnitude=ra_magnitude,
        random_erasing_p=random_erasing_p,
        erasing_scale=erasing_scale,
        erasing_ratio=erasing_ratio,
        img_size=img_size,)

    def _make_loader(ds, shuffle: bool):
        # Shared loader settings; workers are kept alive between epochs when used.
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=(num_workers > 0),)

    train_loader = _make_loader(train_ds, shuffle=True)
    val_loader = _make_loader(val_ds, shuffle=False) if val_ds is not None else None
    test_loader = _make_loader(test_ds, shuffle=False)

    return train_loader, val_loader, test_loader

In [None]:
# Loaders at 64x64 (bicubic upsample) with 10% of the training set held out for validation.
train_loader, val_loader, test_loader = get_cifar100_dataloaders(
    data_dir="./data/cifar100",
    batch_size=128,
    val_split=0.1,
    num_workers=2,
    img_size=64,
)

100%|██████████| 169M/169M [00:04<00:00, 35.3MB/s]


In [None]:

import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision import datasets
from collections import Counter, defaultdict
import math

# Per-channel normalization statistics. These duplicate the values hard-coded in
# get_cifar100_datasets; keep them in sync so unnormalize() inverts Normalize exactly.
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD  = (0.2675, 0.2565, 0.2761)



def describe_loader(loader, name="loader", max_batches_for_stats=50):
    ds = loader.dataset
    n = len(ds)

    print("\n" + "="*90)
    print(f"{name.upper()} SUMMARY")
    print("="*90)

    print(f"Dataset type        : {type(ds).__name__}")
    if hasattr(ds, "dataset") and hasattr(ds, "indices"):
        print(f"  ↳ Wrapped dataset  : {type(ds.dataset).__name__} (Subset-like)")
        print(f"  ↳ Subset size      : {len(ds.indices)}")

    print(f"Num samples         : {n}")
    print(f"Batch size          : {getattr(loader, 'batch_size', None)}")
    print(f"Num workers         : {getattr(loader, 'num_workers', None)}")
    print(f"Pin memory          : {getattr(loader, 'pin_memory', None)}")
    print(f"Drop last           : {getattr(loader, 'drop_last', None)}")

    sampler = getattr(loader, "sampler", None)
    sampler_name = type(sampler).__name__ if sampler is not None else None
    print(f"Sampler             : {sampler_name}")

    num_batches = len(loader)
    bs = loader.batch_size if loader.batch_size is not None else "?"
    approx_batches = math.ceil(n / loader.batch_size) if loader.batch_size else "?"
    print(f"len(loader) (#batches): {num_batches} (≈ ceil({n}/{bs}) = {approx_batches})")

    x, y = next(iter(loader))
    print("\nFirst batch:")
    print(f"  x.shape           : {tuple(x.shape)}")
    print(f"  y.shape           : {tuple(y.shape)}")
    print(f"  x.dtype           : {x.dtype}")
    print(f"  y.dtype           : {y.dtype}")
    print(f"  x.min/max         : {float(x.min()):.4f} / {float(x.max()):.4f}")
    print(f"  y.min/max         : {int(y.min())} / {int(y.max())}")
    print(f"  unique labels (batch): {len(torch.unique(y))}")
    print(f"\nQuick stats over up to {max_batches_for_stats} batches:")

    n_seen = 0
    sum_ = 0.0
    sumsq_ = 0.0
    class_counts = Counter()

    for bi, (xb, yb) in enumerate(loader):
        if bi >= max_batches_for_stats:
            break
        xb = xb.float()
        n_pix = xb.numel()
        sum_ += xb.sum().item()
        sumsq_ += (xb * xb).sum().item()
        n_seen += n_pix

        class_counts.update(yb.tolist())

    mean = sum_ / max(1, n_seen)
    var = (sumsq_ / max(1, n_seen)) - mean**2
    std = math.sqrt(max(0.0, var))

    print(f"  Approx mean        : {mean:.6f}")
    print(f"  Approx std         : {std:.6f}")
    top5 = class_counts.most_common(5)
    print(f"  Seen label counts  : {len(class_counts)} classes (in sampled batches)")
    print(f"  Top-5 labels       : {top5}")

    targets = None
    if hasattr(ds, "targets"):
        targets = ds.targets
    elif hasattr(ds, "labels"):
        targets = ds.labels
    elif hasattr(ds, "dataset") and hasattr(ds.dataset, "targets") and hasattr(ds, "indices"):
        base_targets = ds.dataset.targets
        targets = [base_targets[i] for i in ds.indices]

    if targets is not None:
        full_counts = Counter(list(map(int, targets)))
        k = len(full_counts)
        print(f"\nFull dataset label distribution:")
        print(f"  #classes detected  : {k}")
        if k > 0:
            mn = min(full_counts.values())
            mx = max(full_counts.values())
            print(f"  min/max per class  : {mn} / {mx}")
            first10 = sorted(full_counts.items(), key=lambda t: t[0])[:10]
            print(f"  first 10 classes   : {first10}")
            if mn == mx:
                print("  balance check      : perfectly balanced")
            else:
                print("  balance check      : not perfectly balanced")
    else:
        print("\nFull dataset label distribution: (couldn't find targets/labels attribute)")

    print("="*90)


# Summarize the training loader (stats sampled from the first 50 batches).
describe_loader(train_loader, name="train_loader", max_batches_for_stats=50)


TRAIN_LOADER SUMMARY
Dataset type        : Subset
  ↳ Wrapped dataset  : CIFAR100 (Subset-like)
  ↳ Subset size      : 45000
Num samples         : 45000
Batch size          : 128
Num workers         : 2
Pin memory          : True
Drop last           : False
Sampler             : RandomSampler
len(loader) (#batches): 352 (≈ ceil(45000/128) = 352)

First batch:
  x.shape           : (128, 3, 64, 64)
  y.shape           : (128,)
  x.dtype           : torch.float32
  y.dtype           : torch.int64
  x.min/max         : -4.3048 / 4.0542
  y.min/max         : 0 / 99
  unique labels (batch): 74

Quick stats over up to 50 batches:
  Approx mean        : -0.287313
  Approx std         : 1.113161
  Seen label counts  : 100 classes (in sampled batches)
  Top-5 labels       : [(2, 82), (18, 81), (58, 79), (22, 79), (99, 79)]

Full dataset label distribution:
  #classes detected  : 100
  min/max per class  : 436 / 463
  first 10 classes   : [(0, 457), (1, 439), (2, 448), (3, 455), (4, 446), (5, 4

In [None]:


def unnormalize(images: torch.Tensor,
                mean=CIFAR100_MEAN,
                std=CIFAR100_STD):
    """
    Undo per-channel normalization: returns images * std + mean.

    Args:
        images: normalized tensor, batched [B, C, H, W] or a single image [C, H, W].
        mean, std: per-channel statistics that were passed to Normalize (length C).

    Returns:
        Tensor with the same shape and device as `images`. Floating-point inputs
        keep their dtype; integer inputs are promoted to float32.
    """
    dtype = images.dtype if images.is_floating_point() else torch.float32
    # (-1, 1, 1) broadcasts over the channel axis for both [C,H,W] and [B,C,H,W]
    # without adding a batch dimension to unbatched inputs.
    mean_t = torch.as_tensor(mean, dtype=dtype, device=images.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=images.device).view(-1, 1, 1)
    return images.to(dtype) * std_t + mean_t


def show_batch(images: torch.Tensor,
               labels: torch.Tensor,
               class_names=None,
               n: int = 8):
    """
    Show the first n images of a batch in a single row, optionally titled with class names.

    Args:
        images: normalized tensor [B, C, H, W].
        labels: tensor [B] of integer class ids.
        class_names: sequence mapping class id -> name (e.g. 100 CIFAR-100 names).
        n: maximum number of images to show.

    Raises:
        ValueError: if the batch is empty.
    """
    images = images[:n].detach().cpu()
    labels = labels[:n].detach().cpu()
    k = images.shape[0]
    if k == 0:
        raise ValueError("show_batch received an empty batch.")

    # RandAugment / RandomErasing(value="random") can push values outside [0, 1]
    # after un-normalizing; clamp explicitly instead of relying on imshow's clipping.
    images_unnorm = unnormalize(images).clamp(0.0, 1.0)

    grid = make_grid(images_unnorm, nrow=k, padding=2)
    npimg = grid.permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(figsize=(2 * k, 2.5))
    ax.imshow(npimg)
    ax.axis("off")

    if class_names is not None:
        title = " | ".join(class_names[int(lbl)] for lbl in labels)
        ax.set_title(title, fontsize=10)
    plt.show()

# Un-transformed handle on the training split, used here for its class-name list.
# The loader cell already downloaded the files into this directory, hence download=False.
cifar100_train = datasets.CIFAR100(root="./data/cifar100", train=True, download=False)
class_names = cifar100_train.classes

# Preview the first 8 (augmented) training images with their class names.
images, labels = next(iter(train_loader))
show_batch(images, labels, n=8, class_names=class_names)

---


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2d(nn.Module):
    """
    LayerNorm over the channel axis of an NCHW tensor.

    Each spatial position's C-vector is normalized independently, using an
    nn.LayerNorm(C) applied in channels-last layout. Returns a contiguous
    [B, C, H, W] tensor.
    """
    def __init__(self, C, eps=1e-6):
        super().__init__()
        self.ln = nn.LayerNorm(C, eps=eps)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)       # [B, H, W, C]
        normed = self.ln(channels_last)
        return normed.permute(0, 3, 1, 2).contiguous()  # back to [B, C, H, W]





def _make_activation(act: str) -> nn.Module:
    act = act.lower()
    if act == "silu":
        return nn.SiLU(inplace=True)
    if act == "relu":
        return nn.ReLU(inplace=True)
    if act == "gelu":
        return nn.GELU()
    raise ValueError(f"Unknown activation '{act}'. Use one of: silu|gelu|relu")


class MLP2d(nn.Module):
    """
    Pointwise feed-forward network on NCHW tensors, built from 1x1 convolutions.

    dim -> hidden (= max(1, int(dim * mlp_ratio))) -> activation -> dropout
        -> dim -> dropout. The spatial size is unchanged.
    """
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0, act="gelu"):
        super().__init__()
        hidden_ch = max(1, int(dim * mlp_ratio))
        # Construction order is kept (fc1 before fc2) so random init is reproducible.
        self.fc1 = nn.Conv2d(dim, hidden_ch, 1)
        self.act = _make_activation(act)
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_ch, dim, 1)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))


class OutlookAttention2d(nn.Module):
    """
    Outlook-style attention on [B, C, H, W] (NCHW) tensors with dynamic local aggregation.

    For every output location, a 1x1 conv on the input predicts k*k logits per head.
    After a softmax, these weight the k x k (zero-padded) neighborhood of the
    1x1-projected values. The per-head results are then mixed by a 1x1 output
    projection.

    With stride s > 1 the logits are average-pooled by s (ceil mode) and the values
    are gathered with the same stride. The output is therefore
    [B, C, ceil(H / s), ceil(W / s)] for any H, W; with s == 1 the spatial size is
    preserved.

    Args:
        dim: channel count C (must be divisible by num_heads).
        num_heads: number of heads; each head aggregates dim // num_heads channels.
        kernel_size: odd neighborhood size k.
        stride: output stride s (> 0).
        attn_drop: dropout on the attention weights.
        proj_drop: dropout on the projected output.
        qkv_bias: bias for the logit and value 1x1 convs.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 6,
        kernel_size: int = 3,
        stride: int = 1,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qkv_bias: bool = True,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        if kernel_size <= 0 or kernel_size % 2 == 0:
            raise ValueError("kernel_size must be odd and >0 (e.g., 3,5,7)")
        if stride <= 0:
            raise ValueError("stride must be > 0")

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.kernel_size = kernel_size
        self.stride = stride

        kk = kernel_size * kernel_size
        bias = bool(qkv_bias)

        # attention logits: num_heads * k*k per spatial position
        self.attn = nn.Conv2d(dim, num_heads * kk, kernel_size=1, bias=bias)
        # values
        self.v = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        k = self.kernel_size
        s = self.stride
        heads = self.num_heads
        hd = self.head_dim
        kk = k * k

        # attn logits: [B, heads*kk, H, W] -> pooled to the output grid when stride > 1.
        a = self.attn(x)
        if s > 1:
            # ceil_mode gives ceil(H/s) x ceil(W/s) positions, matching the number of
            # unfold windows below (k odd, padding k//2). Floor mode would give
            # floor(H/s) and break the reshape whenever H or W is not a multiple of s.
            a = F.avg_pool2d(a, kernel_size=s, stride=s, ceil_mode=True)
        _, _, Hs, Ws = a.shape

        # [B, heads, kk, Hs, Ws] -> [B, Hs*Ws, heads, kk], softmax over the k*k neighbors
        a = a.view(B, heads, kk, Hs, Ws).flatten(3).permute(0, 3, 1, 2).contiguous()
        a = F.softmax(a, dim=-1)
        a = self.attn_drop(a)

        # values + unfold neighborhoods (zero padding at the borders)
        v = self.v(x)  # [B,C,H,W]
        pad = k // 2
        v_unf = F.unfold(v, kernel_size=k, padding=pad, stride=s)  # [B, C*kk, Hs*Ws]

        # unfold orders its channel axis as (C, kk); C splits as (heads, hd).
        # -> [B, Hs*Ws, heads, hd, kk]
        v_unf = v_unf.view(B, heads, hd, kk, Hs * Ws).permute(0, 4, 1, 2, 3).contiguous()

        # weighted sum over the kk neighbors
        y = (v_unf * a.unsqueeze(3)).sum(dim=-1)  # [B, Hs*Ws, heads, hd]
        y = y.permute(0, 2, 3, 1).contiguous().view(B, C, Hs, Ws)

        y = self.proj(y)
        y = self.proj_drop(y)
        return y

In [None]:
class DropPath(nn.Module):
    """
    DropPath / Stochastic Depth. Works for any tensor shape with batch in dim 0.

    In training mode each sample is kept with probability 1 - drop_prob and the
    kept samples are rescaled by 1 / (1 - drop_prob). In eval mode, or when
    drop_prob == 0, the input is returned unchanged. drop_prob == 1 zeroes every
    sample; dividing by keep_prob == 0 would otherwise give 0/0 = NaN.

    Raises:
        ValueError: if drop_prob is outside [0, 1].
    """
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        drop_prob = float(drop_prob)
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1]. Got {drop_prob}.")
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or (not self.training):
            return x
        if self.drop_prob >= 1.0:
            return torch.zeros_like(x)
        keep_prob = 1.0 - self.drop_prob
        # shape: [B, 1, 1, 1, ...] -> one keep/drop decision per sample
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, device=x.device, dtype=x.dtype).bernoulli_(keep_prob)
        return x * mask / keep_prob


class OutlookerBlock2d(nn.Module):
    """
    Pre-norm Outlooker block on NCHW tensors:

    x (NCHW) -> LN2d -> OutlookAttention2d -> DropPath + res
             -> LN2d -> MLP2d            -> DropPath + res

    Both sub-layers are residual, so the spatial size must be preserved. Only
    stride=1 is therefore supported; any other stride is rejected at construction.

    Raises:
        ValueError: if stride != 1 (or on invalid attention/activation settings).
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        kernel_size: int = 3,
        stride: int = 1,
        mlp_ratio: float = 2.0,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        drop_path: float = 0.0,
        mlp_drop: float = 0.0,
        act: str = "gelu",
        norm_eps: float = 1e-6,
        qkv_bias: bool = True):

        super().__init__()
        # A strided attention output cannot be added to `x`; fail here with a clear
        # message instead of a broadcasting error inside forward().
        if stride != 1:
            raise ValueError(
                f"OutlookerBlock2d requires stride=1 (residual connection). Got stride={stride}.")

        self.norm1 = LayerNorm2d(dim, eps=norm_eps)
        self.attn = OutlookAttention2d(
            dim=dim,
            num_heads=num_heads,
            kernel_size=kernel_size,
            stride=stride,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            qkv_bias=qkv_bias)

        self.dp1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        self.norm2 = LayerNorm2d(dim, eps=norm_eps)
        self.mlp = MLP2d(dim=dim, mlp_ratio=mlp_ratio, drop=mlp_drop, act=act)
        self.dp2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.dp1(self.attn(self.norm1(x)))
        x = x + self.dp2(self.mlp(self.norm2(x)))
        return x

In [None]:
# Smoke test: a single Outlooker block must preserve the NCHW shape.
x = torch.randn(8, 96, 16, 16)
blk = OutlookerBlock2d(dim=96, num_heads=6, kernel_size=3, stride=1)
y = blk(x)
assert y.shape == x.shape, f"OutlookerBlock2d changed shape: {tuple(x.shape)} -> {tuple(y.shape)}"
print(x.shape, y.shape)

torch.Size([8, 96, 16, 16]) torch.Size([8, 96, 16, 16])


In [None]:
from typing import Literal
from dataclasses import dataclass

class SqueezeExcite(nn.Module):
    """
    Squeeze-and-Excitation channel gating for NCHW tensors.

    Global average pool -> 1x1 conv (channels -> hidden) -> activation ->
    1x1 conv (hidden -> channels) -> sigmoid. The resulting per-channel gate
    rescales `x`. hidden = max(1, int(channels * se_ratio)).

    Raises:
        ValueError: if se_ratio is not in (0, 1].
    """
    def __init__(self, channels: int, se_ratio: float = 0.25, act: str = "silu"):
        super().__init__()
        if not 0.0 < se_ratio <= 1.0:
            raise ValueError("se_ratio must be in (0, 1].")

        squeeze_ch = max(1, int(channels * se_ratio))
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, squeeze_ch, kernel_size=1, bias=True)
        self.act = _make_activation(act)
        self.fc2 = nn.Conv2d(squeeze_ch, channels, kernel_size=1, bias=True)
        self.gate = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        squeezed = self.pool(x)                              # [B, C, 1, 1]
        excited = self.fc2(self.act(self.fc1(squeezed)))
        return x * self.gate(excited)


# Activation names accepted by _make_activation.
ActType = Literal["silu", "gelu", "relu"]

@dataclass(frozen=True)
class MBConvConfig:
    """Immutable hyper-parameters for an MBConv block."""
    expand_ratio: float = 4.0   # hidden width = max(1, round(in_ch * expand_ratio))
    se_ratio: float = 0.25      # Squeeze-Excite bottleneck ratio; <= 0 disables SE
    act: ActType = "silu"       # activation after expand/depthwise and inside SE
    use_bn: bool = True         # BatchNorm after each conv (conv bias is off when True)
    drop_path: float = 0.0      # stochastic depth on the residual branch (only if residual is used)

class MBConv(nn.Module):
    """
    Inverted-residual MBConv block on NCHW tensors:

      [Expand 1x1 -> norm -> act]   (skipped when the expanded width equals in_ch)
      Depthwise 3x3 (stride) -> norm -> act
      Squeeze-Excite                (skipped when cfg.se_ratio <= 0)
      Project 1x1 -> norm           (no activation)

    The residual (with optional DropPath) is added only when stride == 1 and
    in_ch == out_ch. "norm" is BatchNorm2d when cfg.use_bn, else Identity. Conv
    biases are enabled only when BatchNorm is off.

    Raises:
        ValueError: for non-positive channel counts or stride not in {1, 2}.
    """
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1, cfg: MBConvConfig = MBConvConfig()):
        super().__init__()
        if in_ch <= 0 or out_ch <= 0:
            raise ValueError("in_ch and out_ch must be > 0")
        if stride not in (1, 2):
            raise ValueError("stride must be 1 or 2")

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.stride = stride

        def make_norm(channels):
            return nn.BatchNorm2d(channels) if cfg.use_bn else nn.Identity()

        conv_bias = not cfg.use_bn
        # One stateless activation module is reused at every site.
        activation = _make_activation(cfg.act)
        hidden = max(1, int(round(in_ch * cfg.expand_ratio)))

        if hidden == in_ch:
            self.expand = nn.Identity()
        else:
            self.expand = nn.Sequential(
                nn.Conv2d(in_ch, hidden, kernel_size=1, bias=conv_bias),
                make_norm(hidden),
                activation,
            )

        self.depthwise = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1,
                      groups=hidden, bias=conv_bias),
            make_norm(hidden),
            activation,
        )

        if cfg.se_ratio > 0:
            self.se = SqueezeExcite(hidden, se_ratio=cfg.se_ratio, act=cfg.act)
        else:
            self.se = nn.Identity()

        self.project = nn.Sequential(
            nn.Conv2d(hidden, out_ch, kernel_size=1, bias=conv_bias),
            make_norm(out_ch),
        )

        self.use_res = stride == 1 and in_ch == out_ch
        wants_drop_path = bool(cfg.drop_path) and cfg.drop_path > 0
        self.drop_path = DropPath(cfg.drop_path) if wants_drop_path else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for stage in (self.expand, self.depthwise, self.se, self.project):
            out = stage(out)
        if not self.use_res:
            return out
        return x + self.drop_path(out)



---

In [None]:
def grid_partition(x: torch.Tensor, grid_size: int):
    """
    Split a BHWC tensor into g*g strided ("dilated") groups.

    Group (i, j) collects the tokens at rows i, i+g, i+2g, ... and columns
    j, j+g, ..., so every group spans the whole map with (H//g) x (W//g) tokens.

    Returns:
        grids: [B*g*g, H//g, W//g, C], ordered as b*g*g + i*g + j.
        meta:  (B, H, W, C, g), to be passed to grid_unpartition.

    Raises:
        ValueError: if x is not 4-D, grid_size <= 0, or H/W are not multiples of g.
    """
    if x.ndim != 4:
        raise ValueError(f"Expected x.ndim==4 (BHWC). Got shape {tuple(x.shape)}")
    B, H, W, C = x.shape
    g = grid_size
    if g <= 0:
        raise ValueError("grid_size must be > 0")
    if H % g or W % g:
        raise ValueError(f"H and W must be divisible by grid_size. Got H={H}, W={W}, g={g}")

    rows, cols = H // g, W // g
    split = x.view(B, rows, g, cols, g, C)                  # [B, rows, i, cols, j, C]
    grouped = split.permute(0, 2, 4, 1, 3, 5).contiguous()  # [B, i, j, rows, cols, C]
    return grouped.view(B * g * g, rows, cols, C), (B, H, W, C, g)


def grid_unpartition(grids: torch.Tensor, meta) -> torch.Tensor:
    """
    Inverse of grid_partition: stitch [B*g*g, H//g, W//g, C] groups back into [B, H, W, C].

    Args:
        grids: grouped tokens, ordered as b*g*g + i*g + j.
        meta: the (B, H, W, C, g) tuple returned by grid_partition.

    Raises:
        ValueError: if grids is not 4-D or its shape does not match `meta`.
    """
    if grids.ndim != 4:
        raise ValueError(f"Expected grids.ndim==4. Got shape {tuple(grids.shape)}")
    B, H, W, C, g = meta
    Hg, Wg = H // g, W // g
    expected_groups = B * g * g
    if grids.shape[0] != expected_groups:
        raise ValueError(f"grids.shape[0] must be B*g*g = {expected_groups}. Got {grids.shape[0]}")
    if tuple(grids.shape[1:]) != (Hg, Wg, C):
        raise ValueError(f"grids shape mismatch. Expected (*,{Hg},{Wg},{C}) got {tuple(grids.shape)}")

    per_group = grids.view(B, g, g, Hg, Wg, C)              # [B, i, j, rows, cols, C]
    interleaved = per_group.permute(0, 3, 1, 4, 2, 5)       # [B, rows, i, cols, j, C]
    return interleaved.contiguous().view(B, H, W, C)

In [None]:
from dataclasses import dataclass
from typing import Literal


# Partition modes accepted by LocalAttention2D (only "grid" is implemented).
AttnMode = Literal["grid"]

@dataclass(frozen=True)
class AttentionConfig:
    """Immutable hyper-parameters for MultiHeadSelfAttention."""
    dim: int                  # token channel count C (must be divisible by num_heads)
    num_heads: int
    qkv_bias: bool = True     # bias on the fused qkv projection
    attn_drop: float = 0.0    # dropout on the attention weights
    proj_drop: float = 0.0    # dropout after the output projection


@dataclass(frozen=True)
class LocalAttention2DConfig:
    """Immutable hyper-parameters for LocalAttention2D (grid attention on BHWC maps)."""
    mode: AttnMode
    dim: int
    num_heads: int
    grid_size: int            # g: the map is split into g*g strided groups; H and W must be multiples of g
    window_size: int = 1      # not read by LocalAttention2D in grid mode
    qkv_bias: bool = True
    attn_drop: float = 0.0
    proj_drop: float = 0.0


class MultiHeadSelfAttention(nn.Module):
    """
    Standard multi-head self-attention over token sequences.

    Input:  x [B, N, C]
    Output: y [B, N, C]

    Computes softmax(Q K^T / sqrt(head_dim)) V per head, with dropout on the
    attention weights and after the output projection. Works for both window and
    grid partitions because both can be flattened to [Bgrp, N, C].

    Raises:
        ValueError: on invalid dim/num_heads in the config, or on an input that is
            not 3-D or whose last dim differs from cfg.dim.
    """

    def __init__(self, cfg: AttentionConfig):
        super().__init__()
        if cfg.dim <= 0:
            raise ValueError("cfg.dim must be > 0")
        if cfg.num_heads <= 0:
            raise ValueError("cfg.num_heads must be > 0")
        if cfg.dim % cfg.num_heads != 0:
            raise ValueError(f"dim ({cfg.dim}) must be divisible by num_heads ({cfg.num_heads})")

        self.dim = cfg.dim
        self.num_heads = cfg.num_heads
        self.head_dim = self.dim // self.num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(self.dim, 3 * self.dim, bias=cfg.qkv_bias)
        self.attn_drop = nn.Dropout(cfg.attn_drop)
        self.proj = nn.Linear(self.dim, self.dim, bias=True)
        self.proj_drop = nn.Dropout(cfg.proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim != 3:
            raise ValueError(f"Expected x.ndim==3 with shape [B, N, C]. Got {tuple(x.shape)}")
        B, N, C = x.shape
        if C != self.dim:
            raise ValueError(f"Expected last dim C={self.dim}. Got C={C}")

        # [B, N, 3C] -> [3, B, heads, N, head_dim], then split into q, k, v
        q, k, v = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale   # [B, heads, N, N]
        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = torch.matmul(weights, v)                              # [B, heads, N, head_dim]
        mixed = mixed.transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(mixed))


class LocalAttention2D(nn.Module):
    """
    Grid-attention wrapper over BHWC feature maps.

    The map is split into g*g strided groups with grid_partition (g = cfg.grid_size).
    MultiHeadSelfAttention runs independently inside each group over its
    (H/g)*(W/g) tokens, and the groups are stitched back with grid_unpartition.

    Input/Output: x BHWC [B,H,W,C] -> [B,H,W,C]

    Raises:
        ValueError: if cfg.mode != "grid", x is not 4-D, or C != cfg.dim.
    """
    def __init__(self, cfg: LocalAttention2DConfig):
        super().__init__()
        if cfg.mode != "grid":
            raise ValueError("This minimal version only supports mode='grid'")
        self.cfg = cfg
        attn_cfg = AttentionConfig(
            dim=cfg.dim,
            num_heads=cfg.num_heads,
            qkv_bias=cfg.qkv_bias,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.mhsa = MultiHeadSelfAttention(attn_cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim != 4:
            raise ValueError(f"Expected x.ndim==4 (BHWC). Got {tuple(x.shape)}")
        C = x.shape[-1]
        if C != self.cfg.dim:
            raise ValueError(f"Expected C=={self.cfg.dim}. Got C={C}")

        groups, meta = grid_partition(x, self.cfg.grid_size)        # [B*g*g, Hg, Wg, C]
        n_groups, gh, gw, _ = groups.shape
        attended = self.mhsa(groups.view(n_groups, gh * gw, C))     # tokens per group
        return grid_unpartition(attended.view(n_groups, gh, gw, C), meta)

In [None]:
class MLP(nn.Module):
    """
    Channels-last feed-forward network: acts on the last dim C of any [..., C] tensor
    (e.g. BHWC).

    dim -> hidden (= max(1, int(dim * mlp_ratio))) -> activation -> dropout
        -> dim -> dropout.

    Raises:
        ValueError: if the input's last dim differs from `dim`.
    """
    def __init__(self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.0, act: str = "gelu"):
        super().__init__()
        hidden_dim = max(1, int(dim * mlp_ratio))
        # Construction order is kept (fc1 before fc2) so random init is reproducible.
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = _make_activation(act)
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expected = self.fc1.in_features
        if x.shape[-1] != expected:
            raise ValueError(f"MLP expected last dim={expected}, got {x.shape[-1]}")
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))


class OutGridBlock(nn.Module):
    """
    Hybrid block: Outlooker (dynamic local aggregation) -> MBConv -> Grid-MHSA -> MLP.
    Input/Output: [B, C, H, W]; H and W must be multiples of cfg.grid_size.

    Required cfg fields:
      - dim;
      - outlook_heads, outlook_kernel, outlook_mlp_ratio;
      - mbconv_expand_ratio, mbconv_se_ratio, mbconv_act, use_bn;
      - num_heads, grid_size;
      - attn_drop, proj_drop, ffn_drop, drop_path;
      - mlp_ratio, mlp_act.
    Optional (same convention as GridOnlyBlock): window_size (not used by grid
    attention, default 1) and qkv_bias (default True).

    Raises:
        ValueError: if the input is not 4-D.
    """
    def __init__(self, cfg):
        super().__init__()
        C = cfg.dim

        # Outlooker on NCHW (residual inside the block)
        self.outlook = OutlookerBlock2d(
            dim=C,
            num_heads=cfg.outlook_heads,
            kernel_size=cfg.outlook_kernel,
            stride=1,
            mlp_ratio=cfg.outlook_mlp_ratio,      # Outlooker MLP expansion ratio
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
            mlp_drop=cfg.ffn_drop,
            drop_path=cfg.drop_path,
            act=cfg.mlp_act,)

        # MBConv on NCHW (in_ch == out_ch, stride 1 -> residual inside the block)
        self.mbconv = MBConv(
            in_ch=C, out_ch=C, stride=1,
            cfg=MBConvConfig(
                expand_ratio=cfg.mbconv_expand_ratio,
                se_ratio=cfg.mbconv_se_ratio,
                act=cfg.mbconv_act,
                use_bn=cfg.use_bn,
                drop_path=0.0,
            ),)

        # Grid attention on BHWC (pre-norm, residual in forward)
        self.norm2 = nn.LayerNorm(C)
        self.grid_attn = LocalAttention2D(
            LocalAttention2DConfig(
                mode="grid",
                dim=C,
                num_heads=cfg.num_heads,
                window_size=getattr(cfg, "window_size", 1),
                grid_size=cfg.grid_size,
                qkv_bias=getattr(cfg, "qkv_bias", True),
                attn_drop=cfg.attn_drop,
                proj_drop=cfg.proj_drop,
            ))
        self.dp2 = DropPath(cfg.drop_path) if cfg.drop_path > 0 else nn.Identity()

        # MLP on BHWC (pre-norm, residual in forward)
        self.norm3 = nn.LayerNorm(C)
        self.mlp = MLP(dim=C, mlp_ratio=cfg.mlp_ratio, drop=cfg.ffn_drop, act=cfg.mlp_act)
        self.dp3 = DropPath(cfg.drop_path) if cfg.drop_path > 0 else nn.Identity()

    def forward(self, x):
        if x.ndim != 4:
            raise ValueError(f"OutGridBlock expects [B, C, H, W]. Got shape {tuple(x.shape)}")

        # Outlooker + MBConv (NCHW)
        x = self.outlook(x)
        x = self.mbconv(x)

        # To BHWC for grid attention + MLP
        x_bhwc = x.permute(0, 2, 3, 1).contiguous()
        x_bhwc = x_bhwc + self.dp2(self.grid_attn(self.norm2(x_bhwc)))
        x_bhwc = x_bhwc + self.dp3(self.mlp(self.norm3(x_bhwc)))

        # Back to NCHW
        return x_bhwc.permute(0, 3, 1, 2).contiguous()

## Test del nuevo modelo

In [None]:
@dataclass
class DummyCfg:
    """Small, mutable configuration used to unit-test OutGridBlock."""
    dim: int = 96

    # Outlooker
    outlook_heads: int = 6
    outlook_kernel: int = 3
    outlook_mlp_ratio: float = 2.0

    # MBConv
    mbconv_expand_ratio: float = 4.0
    mbconv_se_ratio: float = 0.25
    mbconv_act: str = "silu"
    use_bn: bool = True

    # Grid MHSA
    num_heads: int = 6
    grid_size: int = 4
    window_size: int = 8  # unused in grid-only mode, but OutGridBlock passes it to the attention config

    # Drops
    attn_drop: float = 0.0
    proj_drop: float = 0.0
    ffn_drop: float = 0.0
    drop_path: float = 0.0

    # MLP (BHWC)
    mlp_ratio: float = 4.0
    mlp_act: str = "gelu"

In [None]:
def _assert_shape(x: torch.Tensor, shape: tuple, name: str = "tensor"):
    assert tuple(x.shape) == tuple(shape), f"{name}: expected shape {shape}, got {tuple(x.shape)}"

def _assert_ndim(x: torch.Tensor, ndim: int, name: str = "tensor"):
    assert x.ndim == ndim, f"{name}: expected ndim={ndim}, got ndim={x.ndim}, shape={tuple(x.shape)}"

def _assert_finite(x: torch.Tensor, name: str = "tensor"):
    assert torch.isfinite(x).all().item(), f"{name}: found non-finite values (nan/inf)"

def _assert_divisible_hw(H: int, W: int, g: int):
    assert (H % g) == 0 and (W % g) == 0, f"H,W must be divisible by grid_size g={g}. Got H={H}, W={W}"


In [None]:
@torch.no_grad()
def test_outlooker_stage(block: OutGridBlock, x: torch.Tensor):
    """
    Run only the Outlooker sub-block and return its output.

    Checks that x is 4-D and that its channel count matches the Outlooker's
    LayerNorm width. Also checks that the output keeps [B, C, H, W] and is finite.
    """
    _assert_ndim(x, 4, "x")
    batch, channels, height, width = x.shape
    expected_c = block.outlook.norm1.ln.normalized_shape[0]
    _assert_shape(x, (batch, expected_c, height, width), "x (pre)")  # C check

    out = block.outlook(x)
    _assert_shape(out, (batch, channels, height, width), "outlook(x)")
    _assert_finite(out, "outlook(x)")
    return out

@torch.no_grad()
def test_mbconv_stage(block: OutGridBlock, x: torch.Tensor):
    """Run only the MBConv sub-block; it must keep [B, C, H, W] and stay finite."""
    batch, channels, height, width = x.shape
    out = block.mbconv(x)
    _assert_shape(out, (batch, channels, height, width), "mbconv(x)")
    _assert_finite(out, "mbconv(x)")
    return out


@torch.no_grad()
def test_grid_stage(block: OutGridBlock, x_nchw: torch.Tensor):
    """
    Reproduce the grid-attention sub-layer of OutGridBlock.forward by hand:
    NCHW -> BHWC, norm2, grid MHSA, residual with dp2.

    Shapes and finiteness are checked at every step; returns the BHWC result.
    """
    B, C, H, W = x_nchw.shape
    bhwc = (B, H, W, C)
    x_bhwc = x_nchw.permute(0, 2, 3, 1).contiguous()
    _assert_shape(x_bhwc, bhwc, "x_bhwc")

    # the grid partition needs H and W to be multiples of grid_size
    _assert_divisible_hw(H, W, block.grid_attn.cfg.grid_size)

    normed = block.norm2(x_bhwc)
    _assert_shape(normed, bhwc, "norm2(x_bhwc)")
    attended = block.grid_attn(normed)
    _assert_shape(attended, bhwc, "grid_attn(norm2(x_bhwc))")
    _assert_finite(attended, "grid_attn output")

    out = x_bhwc + block.dp2(attended)
    _assert_shape(out, bhwc, "residual after grid")
    _assert_finite(out, "after grid residual")
    return out  # BHWC


@torch.no_grad()
def test_mlp_stage(block: OutGridBlock, x_bhwc: torch.Tensor):
    """
    Reproduce the MLP sub-layer of OutGridBlock.forward on a BHWC tensor:
    norm3, MLP, residual with dp3.

    Shapes and finiteness are checked at every step; returns the BHWC result.
    """
    B, H, W, C = x_bhwc.shape
    bhwc = (B, H, W, C)

    normed = block.norm3(x_bhwc)
    _assert_shape(normed, bhwc, "norm3(x_bhwc)")
    projected = block.mlp(normed)
    _assert_shape(projected, bhwc, "mlp(norm3(x_bhwc))")
    _assert_finite(projected, "mlp output")

    out = x_bhwc + block.dp3(projected)
    _assert_shape(out, bhwc, "residual after mlp")
    _assert_finite(out, "after mlp residual")
    return out

@torch.no_grad()
def test_full_forward_matches_stages(block: OutGridBlock, x: torch.Tensor, atol=1e-6, rtol=1e-5):
    """
    Check that running the four stages by hand reproduces block.forward(x).

    The comparison runs in eval mode so DropPath is a no-op and both paths are
    deterministic. The block's original train/eval mode is restored afterwards,
    even if an assertion fails.

    Returns:
        block(x) as computed by the module's own forward.
    """
    was_training = block.training
    block.eval()
    try:
        # manual pipeline
        a = test_outlooker_stage(block, x)
        b = test_mbconv_stage(block, a)
        c = test_grid_stage(block, b)         # BHWC
        d = test_mlp_stage(block, c)          # BHWC
        manual = d.permute(0, 3, 1, 2).contiguous()

        # direct forward
        direct = block(x)

        _assert_shape(direct, x.shape, "block(x)")
        _assert_finite(direct, "block(x)")
        assert torch.allclose(manual, direct, atol=atol, rtol=rtol), \
            "Manual staged pipeline != block.forward output (check wiring/residuals/norms)."
    finally:
        # Don't leave a caller's training-mode block stuck in eval mode.
        block.train(was_training)

    return direct


In [None]:
def run_all_tests():
    """
    End-to-end check of OutGridBlock on a small random input.

    Each stage must keep its shape and stay finite, and the staged pipeline must
    match block.forward. Prints the output shape on success.
    """
    torch.manual_seed(0)

    cfg = DummyCfg(dim=96, grid_size=4)
    blk = OutGridBlock(cfg).eval()

    x = torch.randn(2, 96, 16, 16)
    _, _, height, width = x.shape
    assert height % cfg.grid_size == 0 and width % cfg.grid_size == 0

    y = test_full_forward_matches_stages(blk, x)
    print("All tests passed. y:", y.shape)

run_all_tests()

All tests passed. y: torch.Size([2, 96, 16, 16])


---

In [None]:
class MaxOutStage(nn.Module):
    """Sequential stack of ``depth`` OutGridBlock modules built from one shared ``block_cfg``."""
    def __init__(self, block_cfg, depth: int):
        super().__init__()
        self.blocks = nn.ModuleList(OutGridBlock(block_cfg) for _ in range(depth))

    def forward(self, x):
        out = x
        for idx in range(len(self.blocks)):
            out = self.blocks[idx](out)
        return out

class GridOnlyBlock(nn.Module):
    """
    MBConv -> Grid-MHSA -> MLP block (no window attention).

    Input/Output: [B, C, H, W]. The MBConv runs channels-first; the two pre-norm
    residual branches (LayerNorm -> grid attention, LayerNorm -> MLP) run on a
    channels-last [B, H, W, C] view. Each branch is followed by DropPath when
    ``cfg.drop_path > 0`` (the MBConv itself is built with drop_path=0.0).
    """
    def __init__(self, cfg):
        super().__init__()
        dim = cfg.dim
        use_drop_path = cfg.drop_path > 0

        mbconv_cfg = MBConvConfig(
            expand_ratio=cfg.mbconv_expand_ratio,
            se_ratio=cfg.mbconv_se_ratio,
            act=cfg.mbconv_act,
            use_bn=cfg.use_bn,
            drop_path=0.0,
        )
        self.mbconv = MBConv(in_ch=dim, out_ch=dim, stride=1, cfg=mbconv_cfg)

        # Grid attention branch (submodules created in the same order as before, so
        # parameter init / state_dict keys are unchanged).
        self.norm2 = nn.LayerNorm(dim)
        attn_cfg = LocalAttention2DConfig(
            mode="grid",
            dim=dim,
            num_heads=cfg.num_heads,
            window_size=getattr(cfg, "window_size", 1),
            grid_size=cfg.grid_size,
            qkv_bias=True,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.grid_attn = LocalAttention2D(attn_cfg)
        self.dp2 = DropPath(cfg.drop_path) if use_drop_path else nn.Identity()

        # MLP branch.
        self.norm3 = nn.LayerNorm(dim)
        self.mlp = MLP(dim=dim, mlp_ratio=cfg.mlp_ratio, drop=cfg.ffn_drop, act=cfg.mlp_act)
        self.dp3 = DropPath(cfg.drop_path) if use_drop_path else nn.Identity()

    def forward(self, x):
        _b, _c, _h, _w = x.shape  # unpacked only to fail fast on non-4-D input

        h = self.mbconv(x).permute(0, 2, 3, 1).contiguous()    # -> [B, H, W, C]
        h = h + self.dp2(self.grid_attn(self.norm2(h)))
        h = h + self.dp3(self.mlp(self.norm3(h)))
        return h.permute(0, 3, 1, 2).contiguous()              # -> [B, C, H, W]


In [None]:
class StageOutThenGrid(nn.Module):
    """
    Stage made of ``out_depth`` OutlookerBlock2d modules (applied first) followed by
    ``depth`` GridOnlyBlock modules, all configured from the same ``cfg``.
    """
    def __init__(self, cfg, depth: int, out_depth: int = 1):
        super().__init__()

        def make_outlooker():
            # One stride-1 outlooker at the stage width, using the stage's drop rates.
            return OutlookerBlock2d(
                dim=cfg.dim,
                num_heads=cfg.outlook_heads,
                kernel_size=cfg.outlook_kernel,
                stride=1,
                mlp_ratio=cfg.outlook_mlp_ratio,
                attn_drop=cfg.attn_drop,
                proj_drop=cfg.proj_drop,
                mlp_drop=cfg.ffn_drop,
                drop_path=cfg.drop_path,
                act=cfg.mlp_act,)

        self.outlookers = nn.ModuleList(make_outlooker() for _ in range(out_depth))
        self.blocks = nn.ModuleList(GridOnlyBlock(cfg) for _ in range(depth))

    def forward(self, x):
        for module in (*self.outlookers, *self.blocks):
            x = module(x)
        return x

# Downsampling

In [None]:
DownsampleType = Literal["conv", "pool"]
ActType = Literal["silu", "gelu", "relu"]

def _make_activation(act) -> nn.Module:
    act = act.lower()
    if act == "silu":
        return nn.SiLU(inplace=True)
    if act == "relu":
        return nn.ReLU(inplace=True)
    if act == "gelu":
        return nn.GELU()
    raise ValueError(f"Unknown activation '{act}'. Use one of: silu|gelu|relu")

@dataclass(frozen=True)
class DownsampleConfig:
    """Immutable settings for Downsample (frozen, so an instance is safe as a default argument)."""
    kind: DownsampleType = "conv"  # "conv" or "pool"; "conv" = strided 3x3 conv, "pool" = 2x2 avg-pool + 1x1 conv
    act: ActType = "silu"  # activation name, resolved by _make_activation
    use_bn: bool = True  # BatchNorm after the conv; when True the conv is built without bias


class Downsample(nn.Module):
    """
    Halve spatial resolution while changing width in_ch -> out_ch.

      - "conv": Conv3x3 stride2 padding1 (in_ch -> out_ch) + BN + Act
      - "pool": AvgPool2x2 + Conv1x1 (in_ch -> out_ch) + BN + Act
    BN is replaced by Identity when ``cfg.use_bn`` is False (conv then has a bias).

    Input:  [B, in_ch, H, W]
    Output: [B, out_ch, H/2, W/2]  (for odd sizes "conv" rounds up, "pool" rounds down)
    """

    def __init__(self, in_ch: int, out_ch: int, cfg: DownsampleConfig = DownsampleConfig()):
        super().__init__()
        if in_ch <= 0 or out_ch <= 0:
            raise ValueError("in_ch and out_ch must be > 0")

        self.in_ch, self.out_ch, self.kind = in_ch, out_ch, cfg.kind

        activation = _make_activation(cfg.act)
        conv_bias = not cfg.use_bn

        if cfg.kind == "conv":
            resample = [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=conv_bias),
            ]
        elif cfg.kind == "pool":
            resample = [
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=conv_bias),
            ]
        else:
            raise ValueError("cfg.kind must be 'conv' or 'pool'")

        norm = nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity()
        self.op = nn.Sequential(*resample, norm, activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.op(x)

---

In [None]:
from dataclasses import dataclass
from typing import List

@dataclass
class StageCfg:
    """
    Hyperparameters for one stage; every block of the stage is built from it.

    Field order is part of the constructor interface (dim, depth, num_heads, grid_size
    are required and positional). OutlookerFrontGridNet and MaxOutNet copy this config
    per block and overwrite ``drop_path`` with their own global linear schedule.
    """
    # core dims
    dim: int  # channels of the stage's feature map
    depth: int  # number of blocks in the stage

    # grid attention
    num_heads: int
    grid_size: int  # should divide the stage's H and W
    window_size: int = 8  # not used in grid-only mode, but kept for config compatibility

    # outlooker
    outlook_heads: int = 6
    outlook_kernel: int = 3
    outlook_mlp_ratio: float = 2.0

    # MBConv
    mbconv_expand_ratio: float = 4.0
    mbconv_se_ratio: float = 0.25
    mbconv_act: str = "silu"
    use_bn: bool = True

    # drops
    attn_drop: float = 0.0
    proj_drop: float = 0.0
    ffn_drop: float = 0.0
    drop_path: float = 0.0

    # MLP (BHWC)
    mlp_ratio: float = 4.0
    mlp_act: str = "gelu"


def make_dpr(total_blocks: int, dpr_max: float) -> List[float]:
    """
    Linear stochastic-depth schedule from 0.0 to ``dpr_max`` over ``total_blocks``.

    With 0 or 1 blocks the result is the single value ``[dpr_max]``.
    """
    if total_blocks > 1:
        last = total_blocks - 1
        return [dpr_max * i / last for i in range(total_blocks)]
    return [dpr_max]


class ConvStem(nn.Module):
    """Stride-1 stem: Conv3x3 (padding 1) -> BatchNorm (Identity if use_bn=False) -> activation."""
    def __init__(self, in_ch: int, out_ch: int, act: str = "silu", use_bn: bool = True):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=not use_bn)
        norm = nn.BatchNorm2d(out_ch) if use_bn else nn.Identity()
        self.stem = nn.Sequential(conv, norm, _make_activation(act))

    def forward(self, x):
        return self.stem(x)




In [None]:
class OutlookerFrontGridNet(nn.Module):
    """
    Model A:
      Stem -> OutlookerFront (L blocks) -> (Stage: GridOnlyBlock x depth + Downsample) -> Head

    The front outlookers run at stage-1 width. A single linear drop-path schedule
    (0 -> ``dpr_max``) covers the front blocks first and then every grid block, and
    replaces the ``drop_path`` stored in each StageCfg. The head is BatchNorm2d ->
    global average pool -> Linear.
    """
    def __init__(
        self,
        num_classes: int,
        stages: List[StageCfg],
        in_ch: int = 3,
        stem_dim: int = 64,
        outlooker_front_depth: int = 2,   # <- several VOLO-style outlookers
        dpr_max: float = 0.1,
        down_cfg: DownsampleConfig = DownsampleConfig(kind="conv", act="silu", use_bn=True),):

        super().__init__()
        assert len(stages) >= 1
        self.stem = ConvStem(in_ch, stem_dim, act="silu", use_bn=True)

        first = stages[0]
        # 1x1 projection into stage-1 width, only needed if the stem width differs.
        self.proj_in = (
            nn.Conv2d(stem_dim, first.dim, kernel_size=1, bias=True)
            if stem_dim != first.dim else nn.Identity())

        # One drop-path rate per block, consumed in build order (front, then stages).
        n_blocks = outlooker_front_depth + sum(s.depth for s in stages)
        drop_rates = iter(make_dpr(n_blocks, dpr_max))

        self.front = nn.ModuleList(
            OutlookerBlock2d(
                dim=first.dim,
                num_heads=first.outlook_heads,
                kernel_size=first.outlook_kernel,
                stride=1,
                mlp_ratio=first.outlook_mlp_ratio,
                attn_drop=first.attn_drop,
                proj_drop=first.proj_drop,
                mlp_drop=first.ffn_drop,
                drop_path=next(drop_rates),
                act=first.mlp_act,)
            for _ in range(outlooker_front_depth))

        self.stages = nn.ModuleList()
        self.downs = nn.ModuleList()
        for si, scfg in enumerate(stages):
            # Per-block copy of the stage cfg carrying its scheduled drop_path.
            self.stages.append(nn.ModuleList(
                GridOnlyBlock(StageCfg(**{**vars(scfg), "drop_path": next(drop_rates)}))
                for _ in range(scfg.depth)))

            # Downsample between consecutive stages (none after the last one).
            if si + 1 < len(stages):
                self.downs.append(Downsample(scfg.dim, stages[si + 1].dim, cfg=down_cfg))

        width = stages[-1].dim
        self.head_norm = nn.BatchNorm2d(width)
        self.classifier = nn.Linear(width, num_classes)

    def forward(self, x):
        x = self.proj_in(self.stem(x))

        for blk in self.front:
            x = blk(x)

        for si, blocks in enumerate(self.stages):
            for blk in blocks:
                x = blk(x)
            if si < len(self.downs):
                x = self.downs[si](x)

        pooled = self.head_norm(x).mean(dim=(2, 3))
        return self.classifier(pooled)

In [None]:
class MaxOutNet(nn.Module):
    """
    Model B:
      Stem -> (Stage: OutGridBlock x depth + Downsample) -> Head

    A single linear drop-path schedule (0 -> ``dpr_max``) spans all OutGridBlocks
    and replaces the ``drop_path`` stored in each StageCfg. The head is
    BatchNorm2d -> global average pool -> Linear.
    """
    def __init__(
        self,
        num_classes: int,
        stages: List[StageCfg],
        in_ch: int = 3,
        stem_dim: int = 64,
        dpr_max: float = 0.1,
        down_cfg: DownsampleConfig = DownsampleConfig(kind="conv", act="silu", use_bn=True),):

        super().__init__()
        assert len(stages) >= 1
        self.stem = ConvStem(in_ch, stem_dim, act="silu", use_bn=True)

        # 1x1 projection into stage-1 width, only needed if the stem width differs.
        self.proj_in = (
            nn.Conv2d(stem_dim, stages[0].dim, kernel_size=1, bias=True)
            if stem_dim != stages[0].dim else nn.Identity())

        drop_rates = iter(make_dpr(sum(s.depth for s in stages), dpr_max))

        self.stages = nn.ModuleList()
        self.downs = nn.ModuleList()
        for si, scfg in enumerate(stages):
            # Per-block copy of the stage cfg carrying its scheduled drop_path.
            self.stages.append(nn.ModuleList(
                OutGridBlock(StageCfg(**{**vars(scfg), "drop_path": next(drop_rates)}))
                for _ in range(scfg.depth)))

            if si + 1 < len(stages):
                self.downs.append(Downsample(scfg.dim, stages[si + 1].dim, cfg=down_cfg))

        width = stages[-1].dim
        self.head_norm = nn.BatchNorm2d(width)
        self.classifier = nn.Linear(width, num_classes)

    def forward(self, x):
        x = self.proj_in(self.stem(x))

        for si, blocks in enumerate(self.stages):
            for blk in blocks:
                x = blk(x)
            if si < len(self.downs):
                x = self.downs[si](x)

        pooled = self.head_norm(x).mean(dim=(2, 3))
        return self.classifier(pooled)

---


In [None]:
def cifar64_stages_tiny():
    """
    Four-stage "tiny" config for 64x64 inputs (e.g. CIFAR upsampled to 64).

    Expected feature-map resolutions: 64 -> 32 -> 16 -> 8 (the stem keeps the input
    size and there is one Downsample between consecutive stages, i.e. 3 halvings).

    Fields:
      dim: number of channels of the feature map in that stage
      depth: how many blocks are repeated in that stage
      num_heads: number of heads of the Grid-MHSA in that stage
      grid_size: how the feature map is partitioned for grid attention (must divide H and W)
      outlook_heads: number of Outlooker heads (if the model uses an Outlooker in that stage)
    """
    return [
        StageCfg(dim=96,  depth=2, num_heads=3, grid_size=8, outlook_heads=3),
        StageCfg(dim=192, depth=2, num_heads=6, grid_size=8, outlook_heads=6),
        StageCfg(dim=384, depth=5, num_heads=12, grid_size=4, outlook_heads=12),
        StageCfg(dim=768, depth=2, num_heads=12, grid_size=2, outlook_heads=12),]


stages = cifar64_stages_tiny()

mA = OutlookerFrontGridNet(num_classes=100, stages=stages, stem_dim=96, outlooker_front_depth=2, dpr_max=0.1)
mB = MaxOutNet(num_classes=100, stages=stages, stem_dim=96, dpr_max=0.1)

x = torch.randn(2, 3, 64, 64)
yA = mA(x)
yB = mB(x)
print(yA.shape, yB.shape)

torch.Size([2, 100]) torch.Size([2, 100])


---

# Training

In [None]:
import os
import math
import time
import random
import inspect
from dataclasses import dataclass
from contextlib import contextmanager, nullcontext
from typing import Optional, Dict, Tuple, Any
import torch

def seed_everything(seed: int = 0, deterministic: bool = False):
    """
    Seed Python's ``random``, NumPy (when installed) and PyTorch (CPU + all CUDA devices).

    Args:
        seed: seed value. NumPy only accepts [0, 2**32), so it receives ``seed % 2**32``.
        deterministic: if True, force deterministic cuDNN kernels and disable
            autotuning; otherwise enable ``cudnn.benchmark`` for speed.

    Note: ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it here does
    not change hashing in the running process; it only reaches processes launched
    afterwards.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    # NumPy's global RNG was previously left unseeded (augmentations / helpers that
    # use np.random were not reproducible).
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None:
        np.random.seed(seed % (2 ** 32))

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True


DTYPE_MAP = {
    "bf16": torch.bfloat16, "bfloat16": torch.bfloat16,
    "fp16": torch.float16,  "float16": torch.float16,
    "fp32": torch.float32,  "float32": torch.float32,}

def _cuda_dtype_supported(dtype: torch.dtype) -> bool:
    if not torch.cuda.is_available():
        return False
    return dtype in (torch.float16, torch.bfloat16)

def make_grad_scaler(device: str = "cuda", enabled: bool = True):
    """
    Return a GradScaler compatible with the installed PyTorch, or None.

    Prefers ``torch.amp.GradScaler`` (passing the device when the constructor takes
    arguments; anything other than "cuda"/"cpu" maps to "cuda"). If that API is
    missing or construction fails, falls back to the legacy
    ``torch.cuda.amp.GradScaler``. Returns None when ``enabled`` is False or no
    scaler class exists.
    """
    if not enabled:
        return None

    amp_mod = getattr(torch, "amp", None)
    if amp_mod is not None and hasattr(amp_mod, "GradScaler"):
        try:
            takes_args = len(inspect.signature(torch.amp.GradScaler).parameters) >= 1
            if not takes_args:
                return torch.amp.GradScaler()
            target = device if device in ("cuda", "cpu") else "cuda"
            return torch.amp.GradScaler(target)
        except Exception:
            pass  # deliberate: fall through to the legacy CUDA scaler below

    legacy_amp = getattr(torch.cuda, "amp", None)
    if legacy_amp is not None and hasattr(legacy_amp, "GradScaler"):
        return torch.cuda.amp.GradScaler()
    return None


@contextmanager
def autocast_ctx(
    device: str = "cuda",
    enabled: bool = True,
    dtype: str = "fp16",
    cache_enabled: bool = True,):
    """
    Autocast context manager.

      - cuda (also "cuda:N" / torch.device): autocast in the requested half dtype
        (fp16 by default, ideal on T4); unknown or unsupported dtypes use fp16.
      - cpu: bfloat16 autocast; if it cannot be entered, the body runs in fp32.
      - ``enabled=False``, dtype "fp32"/"float32", or any other device: no autocast.

    Exceptions raised inside the ``with`` body always propagate unchanged.
    """
    # The previous CPU branch wrapped the ``yield`` in try/except and yielded a second
    # time on error, so any exception from the body turned into
    # "RuntimeError: generator didn't stop after throw()" and the real error was lost.
    from contextlib import ExitStack

    dtype_name = str(dtype).lower()
    if not enabled or dtype_name in ("fp32", "float32"):
        # fp32 previously fell back to fp16 on CUDA; full precision means "no autocast".
        yield
        return

    dev_type = str(device).split(":", 1)[0].lower()

    if dev_type == "cuda":
        want = DTYPE_MAP.get(dtype_name, torch.float16)
        use = want if _cuda_dtype_supported(want) else torch.float16
        with torch.amp.autocast(device_type="cuda", dtype=use, cache_enabled=cache_enabled):
            yield
        return

    if dev_type == "cpu":
        with ExitStack() as stack:
            try:
                stack.enter_context(
                    torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled))
            except Exception:
                pass  # deliberate best-effort: CPU autocast unavailable -> run in fp32
            # The body runs outside the try, so its exceptions are never swallowed.
            yield
        return

    yield

In [None]:

def save_checkpoint(
    path: str,
    model,
    optimizer,
    scheduler,
    scaler,
    epoch: int,
    best_top1: float,
    extra: dict | None = None,):

    ckpt = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "scaler": scaler.state_dict() if scaler is not None else None,
        "epoch": epoch,
        "best_top1": best_top1,
        "extra": extra or {},}
    torch.save(ckpt, path)


def load_checkpoint(
    path: str,
    model,
    optimizer=None,
    scheduler=None,
    scaler=None,
    map_location="cpu",
    strict: bool = True,):
    """
    Load a checkpoint written by ``save_checkpoint`` and restore state in place.

    The model is always restored (``strict`` is forwarded to ``load_state_dict``).
    Optimizer, scheduler and scaler are restored only when the object is passed in
    and the checkpoint holds a non-None state for it. Returns the raw checkpoint dict.

    Security: ``torch.load`` unpickles data, so only load checkpoint files you trust.
    """
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt["model"], strict=strict)

    for component, key in ((optimizer, "optimizer"), (scheduler, "scheduler"), (scaler, "scaler")):
        saved_state = ckpt.get(key)
        if component is not None and saved_state is not None:
            component.load_state_dict(saved_state)
    return ckpt

In [None]:
@torch.no_grad()
def accuracy_topk(logits: torch.Tensor, targets: torch.Tensor, ks=(1, 3, 5)) -> Dict[int, float]:
    """
    Top-k accuracy, in percent, for every k in ``ks``.

    Args:
        logits: [B, num_classes] scores.
        targets: either
          - int64 class indices [B], or
          - soft targets [B, num_classes] (reduced with argmax, so the dominant class counts).
        ks: k values to report.

    Returns:
        ``{k: percent}``. A k larger than num_classes is clamped to num_classes
        (``torch.topk`` would raise), which gives 100%. An empty batch gives 0.0 for
        every k instead of dividing by zero.
    """
    if targets.ndim == 2:
        targets = targets.argmax(dim=1)

    B = targets.size(0)
    if B == 0:
        return {k: 0.0 for k in ks}

    max_k = min(max(ks), logits.size(1))
    _, pred = torch.topk(logits, k=max_k, dim=1)
    correct = pred.eq(targets.view(-1, 1).expand_as(pred))
    return {k: 100.0 * correct[:, :k].any(dim=1).float().sum().item() / B for k in ks}

In [None]:
def _one_hot(targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    return F.one_hot(targets, num_classes=num_classes).float()


def soft_target_cross_entropy(logits: torch.Tensor, targets_soft: torch.Tensor) -> torch.Tensor:
    """Batch-mean cross-entropy against per-class target probabilities (class dim = 1)."""
    per_sample = torch.sum(-targets_soft * F.log_softmax(logits, dim=1), dim=1)
    return per_sample.mean()


def apply_mixup_cutmix(
    images: torch.Tensor,
    targets: torch.Tensor,
    num_classes: int,
    mixup_alpha: float = 0.0,
    cutmix_alpha: float = 0.0,
    prob: float = 1.0,):
    """
    Batch-level Mixup *or* CutMix (at most one per call) with matching soft targets.

    The batch is paired with a random permutation of itself. With probability
    ``prob`` one of the two is applied; when both alphas are > 0 they are chosen
    50/50. CutMix pastes one random box (clipped to the image) and re-derives lambda
    from the pasted area. When nothing is applied, the inputs come back unchanged
    with one-hot targets.

    Returns:
      images_aug: [B,3,H,W]
      targets_soft: [B,K]
    """
    nothing_enabled = prob <= 0.0 or (mixup_alpha <= 0.0 and cutmix_alpha <= 0.0)
    if nothing_enabled or random.random() > prob:
        return images, _one_hot(targets, num_classes)

    # Only draw the coin flip when both methods are enabled (keeps the RNG stream identical).
    if cutmix_alpha <= 0.0:
        use_cutmix = False
    elif mixup_alpha <= 0.0:
        use_cutmix = True
    else:
        use_cutmix = random.random() < 0.5

    B, _, H, W = images.shape
    perm = torch.randperm(B, device=images.device)

    onehot_a = _one_hot(targets, num_classes)
    onehot_b = _one_hot(targets[perm], num_classes)

    if use_cutmix:
        lam = torch.distributions.Beta(cutmix_alpha, cutmix_alpha).sample().item()
        side = math.sqrt(1.0 - lam)
        cut_w, cut_h = int(W * side), int(H * side)
        cx, cy = random.randint(0, W - 1), random.randint(0, H - 1)

        left, right = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)
        top, bottom = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)

        images_aug = images.clone()
        images_aug[:, :, top:bottom, left:right] = images[perm, :, top:bottom, left:right]

        # Lambda from the area actually swapped (the box may have been clipped).
        lam = 1.0 - (right - left) * (bottom - top) / float(W * H)
    else:
        lam = torch.distributions.Beta(mixup_alpha, mixup_alpha).sample().item()
        images_aug = images * lam + images[perm] * (1.0 - lam)

    return images_aug, onehot_a * lam + onehot_b * (1.0 - lam)

In [None]:
def build_param_groups_no_wd(model: nn.Module, weight_decay: float):
    """
    Split trainable parameters into [decay, no_decay] optimizer param groups.

    No weight decay for:
      - every parameter with ndim <= 1 (biases, BatchNorm/LayerNorm weights, scalars),
      - names ending in ".bias",
      - names containing "norm", "bn", "ln", "pos" (pos_embed / pos) or "cls_token".
    The ndim rule is needed because norm layers inside nn.Sequential have index-based
    names (e.g. "stem.stem.1.weight", "downs.0.op.1.weight") that match none of the
    name patterns, so their weights were being decayed.

    Returns a list of two dicts (decay group first); either list may be empty.
    Frozen parameters (requires_grad=False) are skipped.
    """
    no_decay_keys = ("norm", "bn", "ln", "pos", "cls_token")
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue

        name_l = name.lower()
        if p.ndim <= 1 or name.endswith(".bias") or any(key in name_l for key in no_decay_keys):
            no_decay.append(p)
        else:
            decay.append(p)

    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0}]


class WarmupCosineLR:
    """
    Step-based LR schedule: linear warmup, then cosine decay to ``min_lr``.

    For step t (number of ``step()`` calls so far):
      - t <= warmup_steps (and warmup_steps > 0): lr = base * t / warmup_steps
      - otherwise: cosine from base (at warmup_steps) to min_lr (at total_steps),
        held at min_lr afterwards.

    The step-0 lr is applied to the optimizer at construction, in the same way
    ``torch.optim.lr_scheduler.LambdaLR`` does. Previously the first
    ``optimizer.step()`` ran at the full base lr, which defeated the warmup. Call
    ``step()`` once after every ``optimizer.step()``.

    Base lrs are read from each group's ``initial_lr`` (set from ``lr`` if absent), so
    building a second scheduler on the same optimizer does not pick up an
    already-scheduled lr.
    """
    def __init__(self, optimizer, total_steps: int, warmup_steps: int, min_lr: float = 0.0):
        self.optimizer = optimizer
        self.total_steps = int(total_steps)
        self.warmup_steps = int(warmup_steps)
        self.min_lr = float(min_lr)
        self.base_lrs = [g.setdefault("initial_lr", g["lr"]) for g in optimizer.param_groups]
        self.step_num = 0
        self._apply_lr()

    def _lr_at(self, base: float, t: int) -> float:
        """LR for one group with base lr ``base`` after ``t`` scheduler steps."""
        if t <= self.warmup_steps and self.warmup_steps > 0:
            return base * (t / self.warmup_steps)
        tt = min(t, self.total_steps)
        denom = max(1, self.total_steps - self.warmup_steps)
        progress = (tt - self.warmup_steps) / denom
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return self.min_lr + (base - self.min_lr) * cosine

    def _apply_lr(self):
        """Write the lr for the current ``step_num`` into every param group."""
        for group, base in zip(self.optimizer.param_groups, self.base_lrs):
            group["lr"] = self._lr_at(base, self.step_num)

    def step(self):
        self.step_num += 1
        self._apply_lr()

    def state_dict(self):
        return {"step_num": self.step_num}

    def load_state_dict(self, d):
        self.step_num = int(d.get("step_num", 0))
        # Re-sync the optimizer so a resume is correct even if its state wasn't loaded.
        self._apply_lr()

In [None]:
def train_one_epoch(
    model: nn.Module,
    dataloader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: str = "cuda", 
    scaler=None,
    autocast_dtype: str = "bf16",
    use_amp: bool = True,
    grad_clip_norm: Optional[float] = 1.0,
    label_smoothing: float = 0.1,
    mixup_alpha: float = 0.0,
    cutmix_alpha: float = 0.0,
    mix_prob: float = 1.0,
    num_classes: int = 100,
    channels_last: bool = False,
    print_every: int = 100,) -> Tuple[float, Dict[str, float]]:
    """
    Single-process train loop (no DDP, no EMA): one pass over ``dataloader``.

    Per batch: move to ``device`` (optionally channels_last), apply Mixup/CutMix,
    forward under ``autocast_ctx``, compute the loss in fp32, backward (through the
    GradScaler when it is used), optional grad-norm clipping, ``optimizer.step()``,
    then ``scheduler.step()``. The scheduler is stepped once per batch (step-based).

    Args:
        scaler: GradScaler; only used when ``use_amp`` is True and
            ``autocast_dtype`` is "fp16"/"float16".
        grad_clip_norm: max total gradient norm; None disables clipping. With a
            scaler, gradients are unscaled before clipping.
        label_smoothing: only used when both ``mixup_alpha`` and ``cutmix_alpha`` are
            0 (hard-label cross-entropy). Otherwise soft-target cross-entropy is used
            for every batch, including batches where ``mix_prob`` skipped mixing
            (their targets are plain one-hot, so no smoothing is applied).
        print_every: log every N steps (0/None disables). img/s is measured from the
            start of the loop, so it includes data-loading time.

    Returns:
        (avg_loss, metrics): sample-weighted mean loss and top1/top3/top5 accuracy
        in percent. On mixed batches accuracy is scored against the argmax of the
        soft targets.

    Expects helpers already defined in your file:
      - autocast_ctx
      - apply_mixup_cutmix
      - soft_target_cross_entropy
      - accuracy_topk
    """
    model.train()

    # Loss scaling is only needed (and only applied) for fp16 autocast.
    use_scaler = (scaler is not None) and use_amp and autocast_dtype.lower() in ("fp16", "float16")

    running_loss = 0.0
    total = 0
    c1 = c3 = c5 = 0.0  # running counts of correct samples for top-1/3/5

    t0 = time.time()
    for step, (images, targets) in enumerate(dataloader, start=1):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        B = targets.size(0)

        # mixup/cutmix => soft targets
        images_aug, targets_soft = apply_mixup_cutmix(
            images, targets,
            num_classes=num_classes,
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            prob=mix_prob,)

        optimizer.zero_grad(set_to_none=True)

        with autocast_ctx(device=device, enabled=use_amp, dtype=autocast_dtype, cache_enabled=True):
            logits = model(images_aug)  # [B, K]

        # loss in fp32
        if (mixup_alpha > 0.0) or (cutmix_alpha > 0.0):
            # With mixup/cutmix, label smoothing is usually redundant.
            loss = soft_target_cross_entropy(logits.float(), targets_soft)
        else:
            loss = F.cross_entropy(logits.float(), targets, label_smoothing=label_smoothing)

        # Order matters: scale -> backward -> unscale (so clipping sees true grads)
        # -> clip -> scaler.step (may skip on inf/nan) -> scaler.update.
        if use_scaler:
            scaler.scale(loss).backward()
            if grad_clip_norm is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()

        # Step-based schedule: advanced once per batch, after the optimizer step.
        if scheduler is not None:
            scheduler.step()

        # metrics
        running_loss += loss.item() * B
        total += B
        accs = accuracy_topk(
            logits.detach(),
            targets_soft if targets_soft.ndim == 2 else targets,
            ks=(1, 3, 5),)

        # accuracy_topk returns percentages; convert back to correct-sample counts.
        c1 += accs[1] * B / 100.0
        c3 += accs[3] * B / 100.0
        c5 += accs[5] * B / 100.0

        if print_every and (step % print_every == 0):
            dt = time.time() - t0
            imgs_sec = total / max(dt, 1e-9)
            print(
                f"[train step {step}/{len(dataloader)}] "
                f"loss {running_loss/total:.4f} | "
                f"top1 {100*c1/total:.2f}% | top3 {100*c3/total:.2f}% | top5 {100*c5/total:.2f}% | "
                f"{imgs_sec:.1f} img/s | lr {optimizer.param_groups[0]['lr']:.2e}")

    avg_loss = running_loss / max(1, total)
    metrics = {
        "top1": 100.0 * c1 / max(1, total),
        "top3": 100.0 * c3 / max(1, total),
        "top5": 100.0 * c5 / max(1, total),}

    return avg_loss, metrics

In [None]:

@torch.no_grad()
def evaluate_one_epoch(
    model: nn.Module,
    dataloader,
    device: str = "cuda",
    autocast_dtype: str = "bf16",
    use_amp: bool = True,
    label_smoothing: float = 0.0,
    channels_last: bool = False) -> Tuple[float, Dict[str, float]]:
    """
    Single-process evaluation loop (no DDP, no EMA), run without gradients.

    Moves the model to ``device`` and switches it to eval mode (it is left there).
    The forward runs under ``autocast_ctx``; the cross-entropy (with
    ``label_smoothing``) is computed on fp32 logits.

    Returns:
        (avg_loss, metrics): sample-weighted mean loss and top1/top3/top5 accuracy in percent.

    Expects helpers already defined in your file:
      - autocast_ctx
      - accuracy_topk
    """
    model.eval().to(device)

    loss_sum = 0.0
    seen = 0
    hits = {1: 0.0, 3: 0.0, 5: 0.0}  # correct-sample counts per k

    for batch_images, batch_targets in dataloader:
        batch_images = batch_images.to(device, non_blocking=True)
        batch_targets = batch_targets.to(device, non_blocking=True)

        if channels_last:
            batch_images = batch_images.contiguous(memory_format=torch.channels_last)

        n = batch_targets.size(0)

        with autocast_ctx(device=device, enabled=use_amp, dtype=autocast_dtype, cache_enabled=True):
            logits = model(batch_images)

        batch_loss = F.cross_entropy(logits.float(), batch_targets, label_smoothing=label_smoothing)

        loss_sum += batch_loss.item() * n
        seen += n

        accs = accuracy_topk(logits, batch_targets, ks=(1, 3, 5))
        for k in hits:
            hits[k] += accs[k] * n / 100.0

    denom = max(1, seen)
    avg_loss = loss_sum / denom
    metrics = {f"top{k}": 100.0 * hits[k] / denom for k in (1, 3, 5)}

    return avg_loss, metrics

In [None]:
def train_model(
    model: nn.Module,
    train_loader,
    epochs: int,
    val_loader=None,
    device: str = "cuda",
    lr: float = 5e-4,
    weight_decay: float = 0.05,
    autocast_dtype: str = "fp16",
    use_amp: bool = True,
    grad_clip_norm: float | None = 1.0,
    warmup_ratio: float = 0.05,
    min_lr: float = 0.0,
    label_smoothing: float = 0.1,
    print_every: int = 100,
    save_path: str = "best_model.pt",
    last_path: str = "last_model.pt",
    resume_path: str | None = None,
    mixup_alpha: float = 0.0,
    cutmix_alpha: float = 0.0,
    mix_prob: float = 1.0,
    num_classes: int = 100,
    channels_last: bool = False,
    early_stop: bool = True,
    early_stop_metric: str = "top1",          # "top1" or "loss"
    early_stop_patience: int = 10,
    early_stop_min_delta: float = 0.0,
    early_stop_require_monotonic: bool = False,) -> Tuple[Dict[str, list], nn.Module]:

    """
    Single-process trainer (no DDP, no EMA).

    Sets up AdamW with decay / no-decay groups, a step-based WarmupCosineLR over
    ``epochs * len(train_loader)`` steps, and a GradScaler when ``use_amp`` with fp16.
    Each epoch: train, optionally validate, update best / early-stop state, then write
    the "last" checkpoint.

    Checkpoints:
      - ``save_path``: written whenever val top-1 strictly improves; ``best_epoch``
        is the epoch stored there.
      - ``last_path``: written at the end of every epoch, *after* validation, so its
        ``best_top1`` / ``best_val_loss`` are current. (Previously it was written
        before validation, so resuming from it could let a worse model overwrite
        ``save_path``.)
    ``best_val_loss`` is the lowest validation loss seen, tracked independently of
    top-1.

    Early stopping (only with ``val_loader``): stop after ``early_stop_patience``
    epochs without an improvement larger than ``early_stop_min_delta`` on
    ``early_stop_metric`` ("top1" is maximized, "loss" minimized). With
    ``early_stop_require_monotonic``, the recent values must also be monotonically
    non-improving.

    Returns:
        (history, model): per-epoch lists of train/val loss, top1/3/5 and lr.

    Expects helpers already defined in your file:
      - build_param_groups_no_wd
      - WarmupCosineLR
      - make_grad_scaler
      - save_checkpoint / load_checkpoint
      - train_one_epoch (the one above)
      - evaluate_one_epoch
    """
    model.to(device)

    # Optimizer
    param_groups = build_param_groups_no_wd(model, weight_decay=weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.999), eps=1e-8)

    # Scheduler warmup + cosine (step-based)
    total_steps = epochs * len(train_loader)
    warmup_steps = int(total_steps * warmup_ratio)
    scheduler = WarmupCosineLR(
        optimizer,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        min_lr=min_lr)

    # AMP scaler (only fp16 needs loss scaling)
    scaler = None
    if use_amp and autocast_dtype.lower() in ("fp16", "float16"):
        scaler = make_grad_scaler(device=device, enabled=True)

    # Resume
    start_epoch = 0
    best_val_top1 = -float("inf")
    best_val_loss = float("inf")
    best_epoch = 0

    if resume_path is not None:
        ckpt = load_checkpoint(
            resume_path, model,
            optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            map_location=device,
            strict=True,)

        start_epoch = int(ckpt.get("epoch", 0))
        best_val_top1 = float(ckpt.get("best_top1", best_val_top1))
        extra = ckpt.get("extra", {}) or {}
        best_val_loss = float(extra.get("best_val_loss", best_val_loss))
        best_epoch = int(extra.get("best_epoch", best_epoch))
        print(f"Resumed from {resume_path} at epoch {start_epoch} | best_top1 {best_val_top1:.2f}% | best_loss {best_val_loss:.4f}")

    history = {
        "train_loss": [], "train_top1": [], "train_top3": [], "train_top5": [],
        "val_loss": [], "val_top1": [], "val_top3": [], "val_top5": [],
        "lr": [],}

    metric = early_stop_metric.lower()
    assert metric in ("top1", "loss")
    mode = "max" if metric == "top1" else "min"
    best_metric = best_val_top1 if metric == "top1" else best_val_loss
    patience = int(early_stop_patience)
    bad_epochs = 0
    last_vals: list[float] = []

    def _is_improvement(curr: float, best: float) -> bool:
        d = float(early_stop_min_delta)
        return (curr > (best + d)) if mode == "max" else (curr < (best - d))

    def _degradation_monotonic(vals: list[float]) -> bool:
        if not early_stop_require_monotonic or len(vals) < 2:
            return True
        if mode == "max":
            return all(vals[i] >= vals[i + 1] for i in range(len(vals) - 1))
        else:
            return all(vals[i] <= vals[i + 1] for i in range(len(vals) - 1))

    for epoch in range(start_epoch + 1, epochs + 1):
        print(f"\n=== Epoch {epoch}/{epochs} ===")
        t_epoch = time.time()

        # If a sampler supports set_epoch, reshuffle deterministically per epoch (works even without DDP)
        if hasattr(train_loader, "sampler") and hasattr(train_loader.sampler, "set_epoch"):
            train_loader.sampler.set_epoch(epoch)
        if val_loader is not None and hasattr(val_loader, "sampler") and hasattr(val_loader.sampler, "set_epoch"):
            val_loader.sampler.set_epoch(epoch)

        # --- Train ---
        tr_loss, tr_m = train_one_epoch(
            model=model,
            dataloader=train_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            scaler=scaler,
            autocast_dtype=autocast_dtype,
            use_amp=use_amp,
            grad_clip_norm=grad_clip_norm,
            label_smoothing=label_smoothing,
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            mix_prob=mix_prob,
            num_classes=num_classes,
            channels_last=channels_last,
            print_every=print_every,)

        history["train_loss"].append(tr_loss)
        history["train_top1"].append(tr_m["top1"])
        history["train_top3"].append(tr_m["top3"])
        history["train_top5"].append(tr_m["top5"])
        history["lr"].append(float(optimizer.param_groups[0]["lr"]))

        print(
            f"[Train] loss {tr_loss:.4f} | top1 {tr_m['top1']:.2f}% | top3 {tr_m['top3']:.2f}% | "
            f"top5 {tr_m['top5']:.2f}% | lr {optimizer.param_groups[0]['lr']:.2e}")

        stop_now = False

        # --- Val ---
        if val_loader is not None:
            va_loss, va_m = evaluate_one_epoch(
                model=model,
                dataloader=val_loader,
                device=device,
                autocast_dtype=autocast_dtype,
                use_amp=use_amp,
                label_smoothing=0.0,
                channels_last=channels_last,)

            history["val_loss"].append(va_loss)
            history["val_top1"].append(va_m["top1"])
            history["val_top3"].append(va_m["top3"])
            history["val_top5"].append(va_m["top5"])

            print(
                f"[Val]   loss {va_loss:.4f} | top1 {va_m['top1']:.2f}% | top3 {va_m['top3']:.2f}% | top5 {va_m['top5']:.2f}%")

            # Lowest val loss is tracked on its own; it need not coincide with the best-top1 epoch.
            if va_loss < best_val_loss:
                best_val_loss = float(va_loss)

            # Best checkpoint by val_top1
            if va_m["top1"] > best_val_top1:
                best_val_top1 = float(va_m["top1"])
                best_epoch = int(epoch)

                save_checkpoint(
                    save_path, model, optimizer, scheduler, scaler,
                    epoch=epoch, best_top1=best_val_top1,
                    extra={
                        "autocast_dtype": autocast_dtype,
                        "use_amp": use_amp,
                        "best_val_loss": best_val_loss,
                        "best_epoch": best_epoch,},)

                print(f"Best saved to {save_path} (val top1 {best_val_top1:.2f}%)")

            # Early stop on chosen metric
            if early_stop:
                curr_metric = float(va_m["top1"]) if metric == "top1" else float(va_loss)

                last_vals.append(curr_metric)
                if len(last_vals) > patience:
                    last_vals = last_vals[-patience:]

                if _is_improvement(curr_metric, best_metric):
                    best_metric = curr_metric
                    bad_epochs = 0
                else:
                    bad_epochs += 1

                if bad_epochs >= patience and _degradation_monotonic(last_vals):
                    print(f"Early-stop: no improvement on val_{metric} for {patience} epochs.")
                    stop_now = True

        # Save "last" checkpoint every epoch, after validation, so the stored best_top1 /
        # best_val_loss already include this epoch (safe to resume from).
        save_checkpoint(
            last_path, model, optimizer, scheduler, scaler,
            epoch=epoch, best_top1=best_val_top1,
            extra={
                "autocast_dtype": autocast_dtype,
                "use_amp": use_amp,
                "best_val_loss": best_val_loss,
                "best_epoch": best_epoch,
                "early_stop_metric": metric,
                "early_stop_patience": patience,
                "early_stop_min_delta": float(early_stop_min_delta),},)

        if stop_now:
            break

        dt = time.time() - t_epoch
        print(f"Epoch time: {dt/60:.2f} min")

    return history, model

In [None]:
import gc

def free_all_cuda(*names, verbose=True, globals_dict=None, locals_dict=None):
    """
    Delete variables by name and release cached CUDA memory.

    In notebooks, lingering global references (an old ``model``, ``optimizer``, ...)
    keep GPU tensors alive. This helper removes those names, runs the garbage
    collector and, when CUDA is available, empties the CUDA caching allocator.

    Args:
        *names: Variable names (strings) to delete. Names that are not present
            in a namespace are silently ignored.
        verbose: If True and CUDA is available, print allocated / reserved
            memory (in MB) after cleanup.
        globals_dict: Namespace to delete from. Defaults to ``globals()`` of the
            module where this function is defined (the notebook's user
            namespace when defined in a notebook cell).
        locals_dict: Optional additional namespace to delete from. If None
            (default), no additional namespace is touched. This helper's own
            ``locals()`` is deliberately not used: it can never contain the
            caller's variables.

    Returns:
        None
    """
    if globals_dict is None:
        globals_dict = globals()

    # Collect the namespaces to clean. Skip locals_dict when it is the very same
    # dict as globals_dict (e.g. locals_dict=globals() at notebook top level).
    namespaces = [globals_dict]
    if locals_dict is not None and locals_dict is not globals_dict:
        namespaces.append(locals_dict)

    for n in names:
        for ns in namespaces:
            if n in ns:
                del ns[n]

    # Drop now-unreferenced Python objects first so their CUDA tensors are freed,
    # then hand the cached blocks back to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

        if verbose:
            alloc = torch.cuda.memory_allocated() / 1024**2
            res   = torch.cuda.memory_reserved() / 1024**2
            print(f"[CUDA] allocated={alloc:.1f} MB | reserved(cache)={res:.1f} MB")

# Release GPU memory held by objects from a previous run before building a new model.
free_all_cuda(
    # Model / optimisation objects.
    "model", "optimizer", "scaler", "scheduler",
    # Per-step tensors, in case any ended up as notebook globals.
    "batch", "loss", "outputs", "logits",
    verbose=True,
)

[CUDA] allocated=0.0 MB | reserved(cache)=0.0 MB


In [None]:
def cifar64_stages_t4_tinyplus(drop_path=0.08):
    """Return the four ``StageCfg`` entries of the "T4 tiny-plus" layout.

    The layout is named for 64x64 CIFAR inputs. Per the original note, the
    feature-map resolution is expected to go 64 -> 32 -> 16 -> 8 across stages;
    the downsampling itself happens in the model, not here.

    Args:
        drop_path: Drop-path rate assigned to every stage.

    Returns:
        list of StageCfg, one per stage, in order.
    """
    # (dim, depth, heads, grid_size) per stage; outlook heads equal attention heads.
    layout = (
        (64,  2, 2, 8),
        (128, 2, 4, 8),
        (256, 3, 8, 4),
        (384, 1, 6, 2),
    )
    return [
        StageCfg(
            dim=dim,
            depth=depth,
            num_heads=heads,
            grid_size=grid,
            outlook_heads=heads,
            drop_path=drop_path,
        )
        for dim, depth, heads, grid in layout
    ]

# Target device: the GPU when PyTorch can see one, CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

# OutlookerFrontGridNet with the 4-stage "tiny-plus" layout.
# NOTE(review): the stage layout is named for 64x64 inputs; confirm the
# loaders were built with img_size=64.
stages = cifar64_stages_t4_tinyplus()
model = OutlookerFrontGridNet(
    num_classes=100,
    stages=stages,
    stem_dim=64,
    outlooker_front_depth=1,
    dpr_max=0.1,
)

history, model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=50,

    # Optimiser and LR schedule.
    lr=5e-4,
    weight_decay=0.05,
    warmup_ratio=0.05,
    min_lr=1e-6,
    grad_clip_norm=1.0,
    label_smoothing=0.0,

    # Precision: fp16 autocast with AMP on CUDA, plain fp32 on CPU.
    autocast_dtype=("fp16" if device == "cuda" else "fp32"),
    use_amp=(device == "cuda"),
    channels_last=True,

    # Batch-level augmentation; mixup_alpha=0.0 presumably leaves only CutMix active.
    mix_prob=0.5,
    mixup_alpha=0.0,
    cutmix_alpha=1.0,
    num_classes=100,

    # Logging and checkpoints.
    # NOTE(review): these filenames say "maxout_medium" although this cell trains
    # the OutlookerFrontGridNet tiny-plus config; confirm they do not overwrite
    # checkpoints from another experiment.
    print_every=100,
    save_path="best_maxout_medium.pt",
    last_path="last_maxout_medium.pt",
    resume_path=None,
)




=== Epoch 1/50 ===
[train step 100/352] loss 4.4824 | top1 3.52% | top3 8.54% | top5 12.80% | 116.3 img/s | lr 5.68e-05
[train step 200/352] loss 4.3496 | top1 5.12% | top3 12.41% | top5 18.10% | 120.4 img/s | lr 1.14e-04
[train step 300/352] loss 4.2459 | top1 6.55% | top3 15.18% | top5 21.68% | 121.8 img/s | lr 1.70e-04
[Train] loss 4.2034 | top1 7.19% | top3 16.40% | top5 23.26% | lr 2.00e-04
[Val]   loss 3.7152 | top1 13.10% | top3 27.08% | top5 36.98%
Best saved to best_maxout_medium.pt (val top1 13.10%)
Epoch time: 6.52 min

=== Epoch 2/50 ===
[train step 100/352] loss 3.7666 | top1 13.28% | top3 28.25% | top5 37.82% | 124.1 img/s | lr 2.57e-04
[train step 200/352] loss 3.6872 | top1 14.67% | top3 30.08% | top5 39.87% | 124.3 img/s | lr 3.14e-04
[train step 300/352] loss 3.6362 | top1 15.90% | top3 32.04% | top5 41.82% | 124.5 img/s | lr 3.70e-04
[Train] loss 3.5984 | top1 16.66% | top3 33.18% | top5 43.03% | lr 4.00e-04
[Val]   loss 3.1387 | top1 22.26% | top3 41.72% | top5 51.