In [1]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import RandAugment

def get_cifar100_datasets(
    data_dir: str = "./data",
    val_split: float = 0.0,
    ra_num_ops: int = 2,
    ra_magnitude: int = 7,
    random_erasing_p: float = 0.25,
    erasing_scale=(0.02, 0.20),
    erasing_ratio=(0.3, 3.3),
    img_size: int = 32,):

    """
    CIFAR-100 datasets con augmentations "mix-friendly":
    diseñadas para complementar Mixup/CutMix (en el loop) sin pasarse.

    img_size:
      - 32 (default): CIFAR nativo.
      - >32: upsample (p.ej. 64) para experimentos (más tokens/compute).
    """
    if img_size < 32:
        raise ValueError(f"img_size must be >= 32 for CIFAR-100. Got {img_size}.")

    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std  = (0.2675, 0.2565, 0.2761)

    # Si subimos resolución, primero hacemos resize y adaptamos crop/padding.
    # Padding recomendado proporcional: 32->4, 64->8, etc.

    crop_padding = max(4, img_size // 8)

    train_ops = []
    if img_size != 32:
        train_ops.append(transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC))

    train_ops += [
        transforms.RandomCrop(img_size, padding=crop_padding),
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=ra_num_ops, magnitude=ra_magnitude),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
        transforms.RandomErasing(
            p=random_erasing_p,
            scale=erasing_scale,
            ratio=erasing_ratio,
            value="random",),]

    train_transform = transforms.Compose(train_ops)

    test_ops = []
    if img_size != 32:
        test_ops.append(transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC))

    test_ops += [
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),]

    test_transform = transforms.Compose(test_ops)

    full_train_dataset = datasets.CIFAR100(
        root=data_dir, train=True, download=True, transform=train_transform)

    test_dataset = datasets.CIFAR100(
        root=data_dir, train=False, download=True, transform=test_transform)

    if val_split > 0.0:
        n_total = len(full_train_dataset)
        n_val = int(n_total * val_split)
        n_train = n_total - n_val
        train_dataset, val_dataset = random_split(
            full_train_dataset,
            [n_train, n_val],
            generator=torch.Generator().manual_seed(7),)

    else:
        train_dataset = full_train_dataset
        val_dataset = None

    return train_dataset, val_dataset, test_dataset


def get_cifar100_dataloaders(
    batch_size: int = 128,
    data_dir: str = "./data",
    num_workers: int = 2,
    val_split: float = 0.0,
    pin_memory: bool = True,
    ra_num_ops: int = 2,
    ra_magnitude: int = 7,
    random_erasing_p: float = 0.25,
    img_size: int = 32,):
    """
    Dataloaders CIFAR-100 listos para entrenar con Mixup/CutMix en el loop.
    Augmentations no tan agresivas.

    img_size:
      - 32 (default): CIFAR nativo.
      - 64: experimento de upsample (ojo: más compute).
    """
    train_ds, val_ds, test_ds = get_cifar100_datasets(
        data_dir=data_dir,
        val_split=val_split,
        ra_num_ops=ra_num_ops,
        ra_magnitude=ra_magnitude,
        random_erasing_p=random_erasing_p,
        img_size=img_size,)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=(num_workers > 0),)

    val_loader = None
    if val_ds is not None:
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=(num_workers > 0),)

    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=(num_workers > 0),)

    return train_loader, val_loader, test_loader

In [55]:
def _ddp_is_on():
    return dist.is_available() and dist.is_initialized()

def _ddp_rank():
    return dist.get_rank() if _ddp_is_on() else 0

def _ddp_barrier():
    if _ddp_is_on():
        dist.barrier()

def get_cifar100_datasets(
    data_dir: str = "./data",
    val_split: float = 0.0,
    ra_num_ops: int = 2,
    ra_magnitude: int = 7,
    random_erasing_p: float = 0.25,
    erasing_scale=(0.02, 0.20),
    erasing_ratio=(0.3, 3.3),
    img_size: int = 32,
    seed: int = 7,
    ddp_safe_download: bool = True):
    """
    CIFAR-100 datasets con aug 'mix-friendly' y soporte DDP:
      - Descarga segura: solo rank0 descarga, luego barrier.
      - Split determinista: train/val indices iguales en todos los ranks.
      - Val usa test_transform (SIN aug estocásticos).
    """
    if img_size < 32:
        raise ValueError(f"img_size must be >= 32 for CIFAR-100. Got {img_size}.")

    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std  = (0.2675, 0.2565, 0.2761)

    crop_padding = max(4, img_size // 8)

    train_ops = []
    if img_size != 32:
        train_ops.append(transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC))
    train_ops += [
        transforms.RandomCrop(img_size, padding=crop_padding),
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=ra_num_ops, magnitude=ra_magnitude),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
        transforms.RandomErasing(
            p=random_erasing_p,
            scale=erasing_scale,
            ratio=erasing_ratio,
            value="random",
        ),
    ]
    train_transform = transforms.Compose(train_ops)

    test_ops = []
    if img_size != 32:
        test_ops.append(transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC))
    test_ops += [
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),]
    
    test_transform = transforms.Compose(test_ops)

    #  DDP-safe download 
    if ddp_safe_download and _ddp_is_on():
        if _ddp_rank() == 0:
            datasets.CIFAR100(root=data_dir, train=True, download=True)
            datasets.CIFAR100(root=data_dir, train=False, download=True)
        _ddp_barrier()
        download_flag = False
    else:
        download_flag = True

    # Base datasets (dos versiones: train aug y eval clean)
    full_train_aug = datasets.CIFAR100(root=data_dir, train=True, download=download_flag, transform=train_transform)
    full_train_eval = datasets.CIFAR100(root=data_dir, train=True, download=False, transform=test_transform)
    test_dataset = datasets.CIFAR100(root=data_dir, train=False, download=download_flag, transform=test_transform)

    if val_split > 0.0:
        n_total = len(full_train_aug)
        n_val = int(n_total * val_split)
        n_train = n_total - n_val

        g = torch.Generator().manual_seed(seed)
        perm = torch.randperm(n_total, generator=g).tolist()
        train_idx = perm[:n_train]
        val_idx = perm[n_train:]

        train_dataset = Subset(full_train_aug, train_idx)
        val_dataset = Subset(full_train_eval, val_idx)   
    else:
        train_dataset = full_train_aug
        val_dataset = None

    return train_dataset, val_dataset, test_dataset

In [46]:
train_loader, val_loader, test_loader = get_cifar100_dataloaders(
    batch_size=256,
    data_dir="./data/cifar100",
    num_workers=2,
    val_split=0.1,
    img_size=32)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbeddingConv(nn.Module):
    """
    Patch embedding estilo Swin.

    - Conv2d con kernel=stride=patch_size para convertir imagen -> grilla de patches.
    - Devuelve el mapa 2D en formato canal-al-final: [B, Hp, Wp, D],
      (más cómodo para window partition).
    - Opcionalmente devuelve tokens [B, N, D].
    - Opcional padding automático si H/W no son divisibles por patch_size.
    """

    def __init__(
        self,
        patch_size: int | tuple[int, int] = 4,
        in_chans: int = 3,
        embed_dim: int = 192,
        norm_layer: type[nn.Module] | None = nn.LayerNorm,
        pad_if_needed: bool = True,
        return_tokens: bool = True):

        super().__init__()

        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)

        self.patch_size = patch_size  # (Ph, Pw)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.pad_if_needed = pad_if_needed
        self.return_tokens = return_tokens

        # [B, C, H, W] -> [B, D, Hp, Wp]
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,)

        # En Swin normalmente LayerNorm sobre la última dimensión
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [B, C, H, W]

        Returns:
            x_map:    [B, Hp, Wp, D]
            (Hp, Wp): tamaño espacial en patches
            x_tokens (opcional): [B, N, D]
            pad_hw (opcional): (pad_h, pad_w) aplicados a la imagen
        """
        B, C, H, W = x.shape
        Ph, Pw = self.patch_size

        pad_h = (Ph - (H % Ph)) % Ph
        pad_w = (Pw - (W % Pw)) % Pw

        if (pad_h != 0 or pad_w != 0):
            if not self.pad_if_needed:
                raise AssertionError(
                    f"Image size ({H}x{W}) no es divisible por patch_size {self.patch_size} "
                    f"y pad_if_needed=False.")

            x = F.pad(x, (0, pad_w, 0, pad_h))

        # [B, D, Hp, Wp]
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]

        # canal al final -> [B, Hp, Wp, D]
        x_map = x.permute(0, 2, 3, 1).contiguous()

        if self.norm is not None:
            x_map = self.norm(x_map)

        if self.return_tokens:
            x_tokens = x_map.view(B, Hp * Wp, self.embed_dim)
            return x_map, (Hp, Wp), x_tokens, (pad_h, pad_w)

        return x_map, (Hp, Wp), (pad_h, pad_w)



In [7]:
def test_patch_embedding_conv():
    torch.manual_seed(0)

    #  tamaño divisible (64 con patch=4)
    B, C, H, W = 2, 3, 64, 64
    x = torch.randn(B, C, H, W)

    pe = PatchEmbeddingConv(
        patch_size=4,
        in_chans=3,
        embed_dim=192,
        norm_layer=torch.nn.LayerNorm,
        pad_if_needed=True,
        return_tokens=True,)

    x_map, (Hp, Wp), x_tok, (pad_h, pad_w) = pe(x)

    assert x_map.shape == (B, Hp, Wp, 192)
    assert x_tok.shape == (B, Hp * Wp, 192)
    assert (pad_h, pad_w) == (0, 0)
    assert (Hp, Wp) == (H // 4, W // 4)

    print("[OK] PatchEmbeddingConv divisible:",
          "x_map", tuple(x_map.shape),
          "| x_tok", tuple(x_tok.shape),
          "| pad", (pad_h, pad_w))

    # tamaño NO divisible (65x63 con patch=4) -> debería paddear
    H2, W2 = 65, 63
    x2 = torch.randn(B, C, H2, W2)

    x_map2, (Hp2, Wp2), x_tok2, (pad_h2, pad_w2) = pe(x2)

    assert (H2 + pad_h2) % 4 == 0
    assert (W2 + pad_w2) % 4 == 0
    assert x_map2.shape == (B, Hp2, Wp2, 192)
    assert x_tok2.shape == (B, Hp2 * Wp2, 192)

    print("[OK] PatchEmbeddingConv non-divisible:",
          "input", (H2, W2),
          "| padded by", (pad_h2, pad_w2),
          "| patches", (Hp2, Wp2),
          "| x_map", tuple(x_map2.shape))

test_patch_embedding_conv()

[OK] PatchEmbeddingConv divisible: x_map (2, 16, 16, 192) | x_tok (2, 256, 192) | pad (0, 0)
[OK] PatchEmbeddingConv non-divisible: input (65, 63) | padded by (3, 1) | patches (17, 16) | x_map (2, 17, 16, 192)


In [8]:
class OutlookAttention(nn.Module):
    """
    Outlook Attention (VOLO): agregación local dinámica sobre ventanas.

    Entrada:  x_map [B, H, W, C]  (channel-last)
    Salida:   y_map [B, H, W, C]

    Parámetros:
      - dim: canales C
      - kernel_size: k (vecindario k×k)
      - stride: s (si s>1 hace downsample tipo "outlook pooling"; para CIFAR típicamente s=1)
      - num_heads: h (partimos canales en cabezas, como MHSA)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 6,
        kernel_size: int = 3,
        stride: int = 1,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,):

        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.kernel_size = kernel_size
        self.stride = stride

        # Genera pesos de atención por posición: [B, H, W, heads * k*k]
        self.attn = nn.Linear(dim, num_heads * kernel_size * kernel_size, bias=True)

        # Proyección para values (antes de unfold)
        self.v = nn.Linear(dim, dim, bias=True)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=True)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x_map: torch.Tensor) -> torch.Tensor:
        """
        x_map: [B, H, W, C]
        """
        B, H, W, C = x_map.shape
        k = self.kernel_size
        s = self.stride
        heads = self.num_heads
        hd = self.head_dim

        # attention weights
        a = self.attn(x_map)
        # si stride>1, la atención se evalúa en posiciones downsampled
        if s > 1:
            # downsample espacialmente (simple avg pool sobre channel-last)
            a = a.permute(0, 3, 1, 2)                       # [B, heads*k*k, H, W]
            a = F.avg_pool2d(a, kernel_size=s, stride=s)    # [B, heads*k*k, Hs, Ws]
            a = a.permute(0, 2, 3, 1).contiguous()          # [B, Hs, Ws, heads*k*k]

        Hs, Ws = a.shape[1], a.shape[2]
        a = a.view(B, Hs * Ws, heads, k * k)
        a = F.softmax(a, dim=-1)
        a = self.attn_drop(a)

        # values map
        v = self.v(x_map)
        v = v.permute(0, 3, 1, 2).contiguous()

        # unfold extrae vecindarios k×k para cada posición (con padding para "same")
        pad = k // 2
        v_unf = F.unfold(v, kernel_size=k, padding=pad, stride=s)
        v_unf = v_unf.view(B, C, k * k, Hs * Ws).permute(0, 3, 1, 2).contiguous()
        v_unf = v_unf.view(B, Hs * Ws, heads, hd, k * k)

        # apply attention: weighted sum over neighborhood
        # a:     [B, Hs*Ws, heads, k*k]
        # v_unf: [B, Hs*Ws, heads, hd, k*k]
        y = (v_unf * a.unsqueeze(3)).sum(dim=-1)
        y = y.reshape(B, Hs * Ws, C)              # concat heads

        # fold back to spatial map
        y_map = y.view(B, Hs, Ws, C)

        y_map = self.proj(y_map)
        y_map = self.proj_drop(y_map)
        return y_map

In [9]:
def test_outlook_attention_stride1():
    torch.manual_seed(0)

    B, H, W, C = 2, 16, 16, 192
    x_map = torch.randn(B, H, W, C, requires_grad=True)

    oa = OutlookAttention(
        dim=C,
        num_heads=6,
        kernel_size=3,
        stride=1,
        attn_drop=0.0,
        proj_drop=0.0)

    y = oa(x_map)
    assert y.shape == x_map.shape, f"Expected {x_map.shape}, got {y.shape}"

    loss = y.mean()
    loss.backward()

    assert x_map.grad is not None, "No gradient flowed to input!"
    assert torch.isfinite(x_map.grad).all(), "Non-finite grads!"

    print("[OK] OutlookAttention stride=1:",
          "in", tuple(x_map.shape),
          "| out", tuple(y.shape),
          "| grad mean", float(x_map.grad.abs().mean()))

test_outlook_attention_stride1()

[OK] OutlookAttention stride=1: in (2, 16, 16, 192) | out (2, 16, 16, 192) | grad mean 2.7283142571832286e-06


In [10]:
def test_outlook_attention_stride2():
    torch.manual_seed(0)

    B, H, W, C = 2, 16, 16, 192
    x_map = torch.randn(B, H, W, C, requires_grad=True)

    oa = OutlookAttention(
        dim=C,
        num_heads=6,
        kernel_size=3,
        stride=2,
        attn_drop=0.0,
        proj_drop=0.0)

    y = oa(x_map)

    assert y.shape[0] == B and y.shape[-1] == C
    assert y.shape[1] == H // 2 and y.shape[2] == W // 2, f"Got {y.shape[1:3]}"

    loss = y.mean()
    loss.backward()
    assert x_map.grad is not None
    assert torch.isfinite(x_map.grad).all()

    print("[OK] OutlookAttention stride=2:",
          "in", (B, H, W, C),
          "| out", tuple(y.shape))

test_outlook_attention_stride2()

[OK] OutlookAttention stride=2: in (2, 16, 16, 192) | out (2, 8, 8, 192)


In [11]:
class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x

        keep_prob = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, drop: float = 0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class OutlookerBlock(nn.Module):
    """
    Bloque VOLO Outlooker:
      x -> LN -> OutlookAttention -> DropPath + residual
        -> LN -> MLP -> DropPath + residual

    Input/Output: [B, H, W, C]
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        kernel_size: int = 3,
        stride: int = 1,
        mlp_ratio: float = 4.0,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        drop_path: float = 0.0,
        mlp_drop: float = 0.0):

        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        self.attn = OutlookAttention(
            dim=dim,
            num_heads=num_heads,
            kernel_size=kernel_size,
            stride=stride,
            attn_drop=attn_drop,
            proj_drop=proj_drop,)

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)

        self.mlp = MLP(dim=dim, hidden_dim=hidden_dim, drop=mlp_drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x_map: torch.Tensor) -> torch.Tensor:
        """
        x_map: tensor de forma (B, C, H, W) o (B, N, C), según el bloque.
        """

        # Primer sub-bloque: Norm -> Atención -> DropPath -> Residual

        # Normalización del input
        x_norm_1 = self.norm1(x_map)

        # Atención
        attn_out = self.attn(x_norm_1)
        attn_out = self.drop_path1(attn_out)

        # Suma residual
        x_map = x_map + attn_out

        # Segundo sub-bloque: Norm -> MLP -> DropPath -> Residual ---

        x_norm_2 = self.norm2(x_map)

        # MLP
        mlp_out = self.mlp(x_norm_2)
        mlp_out = self.drop_path2(mlp_out)

        # Segunda suma residual
        x_out = x_map + mlp_out

        return x_out

In [12]:
def test_outlooker_block():
    torch.manual_seed(0)

    B, H, W, C = 2, 16, 16, 192
    x_map = torch.randn(B, H, W, C, requires_grad=True)

    blk = OutlookerBlock(
        dim=C,
        num_heads=6,
        kernel_size=3,
        stride=1,
        mlp_ratio=4.0,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        mlp_drop=0.0,)

    y = blk(x_map)
    assert y.shape == x_map.shape

    y.mean().backward()
    assert x_map.grad is not None
    assert torch.isfinite(x_map.grad).all()

    print("[OK] OutlookerBlock:",
          "in/out", tuple(y.shape),
          "| grad mean", float(x_map.grad.abs().mean()))

test_outlooker_block()

[OK] OutlookerBlock: in/out (2, 16, 16, 192) | grad mean 1.0187762200075667e-05


In [13]:
def test_embed_then_outlook(img_size=64, patch_size=4, dim=192, heads=6):
    torch.manual_seed(0)

    B = 2
    x = torch.randn(B, 3, img_size, img_size, requires_grad=True)

    pe = PatchEmbeddingConv(
        patch_size=patch_size,
        in_chans=3,
        embed_dim=dim,
        norm_layer=torch.nn.LayerNorm,
        pad_if_needed=True,
        return_tokens=True,)

    blk = OutlookerBlock(
        dim=dim,
        num_heads=heads,
        kernel_size=3,
        stride=1,
        mlp_ratio=4.0,
        drop_path=0.0,)

    x_map, (Hp, Wp), x_tok, pad_hw = pe(x)
    y_map = blk(x_map)

    assert y_map.shape == x_map.shape == (B, Hp, Wp, dim)

    # grad
    y_map.mean().backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()

    print("[OK] Embed->Outlook:",
          "img", (img_size, img_size),
          "| patches", (Hp, Wp),
          "| map", tuple(y_map.shape),
          "| pad", pad_hw)

test_embed_then_outlook(img_size=32)
test_embed_then_outlook(img_size=64)

[OK] Embed->Outlook: img (32, 32) | patches (8, 8) | map (2, 8, 8, 192) | pad (0, 0)
[OK] Embed->Outlook: img (64, 64) | patches (16, 16) | map (2, 16, 16, 192) | pad (0, 0)


In [14]:
class VOLOStage(nn.Module):
    """
    Un stage VOLO basado en OutlookerBlocks.

    Mantiene el formato channel-last:
      Input:  [B, H, W, C]
      Output: [B, H, W, C]  (si stride=1)
    Si quisieras un stage que haga downsample, usa stride>1 en los bloques
    (pero en CIFAR te recomiendo stride=1 en el stage inicial).
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        kernel_size: int = 3,
        stride: int = 1,
        mlp_ratio: float = 4.0,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        drop_path: float | list[float] = 0.0,
        mlp_drop: float = 0.0,):

        super().__init__()

        if isinstance(drop_path, float):
            dpr = [drop_path] * depth
        else:
            assert len(drop_path) == depth, "drop_path list must have length=depth"
            dpr = drop_path

        self.blocks = nn.ModuleList([
            OutlookerBlock(
                dim=dim,
                num_heads=num_heads,
                kernel_size=kernel_size,
                stride=stride,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                drop_path=dpr[i],
                mlp_drop=mlp_drop,) for i in range(depth)])

    def forward(self, x_map: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            x_map = blk(x_map)
        return x_map

In [15]:
def test_volo_stage():
    torch.manual_seed(0)

    B, H, W, C = 2, 16, 16, 192
    x = torch.randn(B, H, W, C, requires_grad=True)

    stage = VOLOStage(
        dim=C,
        depth=3,
        num_heads=6,
        kernel_size=3,
        stride=1,
        drop_path=[0.0, 0.05, 0.1])

    y = stage(x)
    assert y.shape == x.shape
    y.mean().backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()

    print("[OK] VOLOStage:", tuple(y.shape), "| grad mean", float(x.grad.abs().mean()))

test_volo_stage()

[OK] VOLOStage: (2, 16, 16, 192) | grad mean 1.0873188330151606e-05


## Attention

In [16]:
def scaled_dot_product_attention(q, k, v, mask=None, attn_dropout_p: float = 0.0, training: bool = True):
    """
    q: (B, H, Lq, d)
    k: (B, H, Lk, d)
    v: (B, H, Lk, d)
    mask: broadcastable a (B, H, Lq, Lk)
          - bool: True = BLOQUEAR (poner -inf)
          - float: 1.0 = permitir, 0.0 = bloquear
    """
    scores = torch.matmul(q, k.transpose(-2, -1))
    dk = q.size(-1)
    scores = scores / (dk ** 0.5)

    if mask is not None:
        if mask.dtype == torch.bool:
            scores = scores.masked_fill(mask, float("-inf"))
        else:
            scores = scores.masked_fill(mask <= 0, float("-inf"))

    attn = F.softmax(scores, dim=-1)
    if attn_dropout_p > 0.0:
        attn = F.dropout(attn, p=attn_dropout_p, training=training)

    output = torch.matmul(attn, v)
    return output, attn

In [17]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model debe ser múltiplo de num_heads"

        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.d_model = d_model

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        # "dropout" lo usaremos como dropout de atención (sobre attn)
        self.attn_dropout_p = dropout
        # y también dejamos dropout de salida si quieres (común en ViT)
        self.out_dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        B, L, _ = x.shape
        return x.view(B, L, self.num_heads, self.d_head).transpose(1, 2)

    def _combine_heads(self, x):
        B, H, L, D = x.shape
        return x.transpose(1, 2).contiguous().view(B, L, H * D)

    def forward(self, x_q, x_kv, mask=None):
        q = self._split_heads(self.w_q(x_q))
        k = self._split_heads(self.w_k(x_kv))
        v = self._split_heads(self.w_v(x_kv))

        if mask is not None:
            if mask.dim() == 2:
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                mask = mask[:, None, :, :]
            elif mask.dim() == 4:
                pass
            else:
                raise ValueError(f"Máscara con dims no soportadas: {mask.shape}")

            if mask.dtype != torch.bool:
                mask = (mask <= 0)

        attn_out, _ = scaled_dot_product_attention(
            q, k, v,
            mask=mask,
            attn_dropout_p=self.attn_dropout_p,
            training=self.training)

        attn_out = self._combine_heads(attn_out)

        attn_out = self.w_o(attn_out)
        attn_out = self.out_dropout(attn_out)
        return attn_out

In [18]:
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """
    Bloque encoder para ViT (pre-norm):
    x -> LN -> MHA -> DropPath -> +residual
       -> LN -> MLP -> DropPath -> +residual
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_dropout: float = 0.0,
        dropout: float = 0.1,
        drop_path: float = 0.0):

        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(d_model=dim, num_heads=num_heads, dropout=attn_dropout)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = FeedForward(dim, hidden_dim, dropout=dropout)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.attn(self.norm1(x), self.norm1(x), mask=None))
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x


In [19]:
class TransformerStack(nn.Module):
    """Stack simple de TransformerBlock sobre tokens [B, N, C]."""
    def __init__(self, dim: int, depth: int, num_heads: int, mlp_ratio=4.0,
                 attn_dropout=0.0, dropout=0.1, drop_path: float | list[float] = 0.0):
        super().__init__()
        if isinstance(drop_path, float):
            dpr = [drop_path] * depth
        else:
            assert len(drop_path) == depth
            dpr = drop_path

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attn_dropout=attn_dropout,
                dropout=dropout,
                drop_path=dpr[i] if "drop_path" in TransformerBlock.__init__.__code__.co_varnames else 0.0) for i in range(depth)])

    def forward(self, x_tok: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            x_tok = blk(x_tok)
        return x_tok

In [20]:
def test_transformer_block():
    torch.manual_seed(0)
    B, N, C = 2, 256, 192
    x = torch.randn(B, N, C, requires_grad=True)

    blk = TransformerBlock(dim=C, num_heads=6, mlp_ratio=4.0, attn_dropout=0.0, dropout=0.1, drop_path=0.0)
    y = blk(x)
    assert y.shape == x.shape
    y.mean().backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()
    print("[OK] TransformerBlock:", tuple(y.shape), "grad", float(x.grad.abs().mean()))

test_transformer_block()

[OK] TransformerBlock: (2, 256, 192) grad 1.018048442347208e-05


# Hiratical

In [21]:
class MapDownsample(nn.Module):
    """
    Downsample para mapas channel-last: [B, H, W, C_in] -> [B, H/2, W/2, C_out]
    usando conv2d stride=2 en formato channel-first internamente.
    """
    def __init__(self, dim_in: int, dim_out: int, kernel_size: int = 3, norm_layer=nn.LayerNorm):
        super().__init__()
        pad = kernel_size // 2
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=2, padding=pad, bias=True)
        self.norm = norm_layer(dim_out) if norm_layer is not None else None

    def forward(self, x_map: torch.Tensor):
        # x_map: [B, H, W, C_in]
        B, H, W, C = x_map.shape
        x = x_map.permute(0, 3, 1, 2).contiguous()     # [B, C, H, W]
        x = self.conv(x)                               # [B, C_out, H2, W2]
        x_map = x.permute(0, 2, 3, 1).contiguous()     # [B, H2, W2, C_out]
        if self.norm is not None:
            x_map = self.norm(x_map)
        return x_map

In [22]:
class PoolingLayer(nn.Module):
    """
    Pooling jerárquico para ViT:

    - Toma tokens [B, N, D_in] + grid_size (H, W)
    - Los reinterpreta como feature map [B, D_in, H, W]
    - Aplica:
        depthwise conv (3x3, stride=2, padding=1)
        pointwise conv (1x1) para cambiar D_in -> D_out
    - Devuelve:
        tokens [B, N_out, D_out] y nuevo grid_size (H_out, W_out)
    """

    def __init__(self,
        dim_in: int,
        dim_out: int,
        kernel_size: int = 3,
        stride: int = 2,
        norm_layer: type[nn.Module] | None = nn.LayerNorm):

        super().__init__()
        padding = kernel_size // 2

        # Depthwise conv: cada canal se filtra por separado
        self.depthwise_conv = nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=dim_in)

        # Pointwise conv: mezcla canales y cambia dim
        self.pointwise_conv = nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=1,
            stride=1,
            padding=0)

        self.norm = norm_layer(dim_out) if norm_layer is not None else None

        self.dim_in = dim_in
        self.dim_out = dim_out
        self.stride = stride

    def forward(self, x: torch.Tensor, grid_size: tuple[int, int]):
        """
        Args:
            x: tokens [B, N, D_in]
            grid_size: (H, W) tal que H*W = N

        Returns:
            x_out: tokens [B, N_out, D_out]
            new_grid: (H_out, W_out)
        """
        B, N, D_in = x.shape
        H, W = grid_size

        assert D_in == self.dim_in, f"dim_in {D_in} != {self.dim_in}"
        assert H * W == N, f"H*W={H*W} no coincide con N={N}"

        # [B, N, D_in] -> [B, D_in, H, W]
        x = x.view(B, H, W, D_in).permute(0, 3, 1, 2)

        # Depthwise + pointwise
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)

        B, D_out, H_out, W_out = x.shape
        N_out = H_out * W_out

        # Volver a tokens: [B, D_out, H_out, W_out] -> [B, N_out, D_out]
        x = x.flatten(2).transpose(1, 2)

        if self.norm is not None:
            x = self.norm(x)

        new_grid = (H_out, W_out)
        return x, new_grid

# VOLO BackBone

In [24]:
def map_to_tokens(x_map: torch.Tensor) -> torch.Tensor:
    B, H, W, C = x_map.shape
    return x_map.view(B, H * W, C)

def tokens_to_map(x_tok: torch.Tensor, H: int, W: int) -> torch.Tensor:
    B, N, C = x_tok.shape
    assert N == H * W
    return x_tok.view(B, H, W, C)

class VOLOPyramid(nn.Module):
    """
    Backbone jerárquico para VOLO (sin classifier head aún).
    - Local: VOLOStage (Outlooker)
    - Global: TransformerStack (opcional)
    - Downsample: map-space (recomendado) o token-space (PoolingLayer tuyo)
    """
    def __init__(
        self,
        dims: tuple[int, ...],                 # ej (192, 256, 384)
        outlooker_depths: tuple[int, ...],     # ej (4, 2, 0)  (0 si no hay outlooker en ese nivel)
        outlooker_heads: tuple[int, ...],      # ej (6, 8, 12)
        transformer_depths: tuple[int, ...],   # ej (0, 4, 6)
        transformer_heads: tuple[int, ...],    # ej (6, 8, 12)
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        downsample_kind: str = "map",          # "map" o "token"
        drop_path_rate: float = 0.0):

        super().__init__()
        L = len(dims)

        assert len(outlooker_depths) == L
        assert len(outlooker_heads) == L
        assert len(transformer_depths) == L
        assert len(transformer_heads) == L

        # schedule lineal de droppath a través de todos los bloques (local+global)
        total_blocks = sum(outlooker_depths) + sum(transformer_depths)
        dpr = torch.linspace(0, drop_path_rate, total_blocks).tolist() if total_blocks > 0 else []
        dp_i = 0

        self.levels = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        self.downsample_kind = downsample_kind

        for i in range(L):
            dim = dims[i]

            # Local stage (Outlooker)
            local = None
            if outlooker_depths[i] > 0:
                local_dpr = dpr[dp_i: dp_i + outlooker_depths[i]]
                dp_i += outlooker_depths[i]
                local = VOLOStage(
                    dim=dim,
                    depth=outlooker_depths[i],
                    num_heads=outlooker_heads[i],
                    kernel_size=kernel_size,
                    stride=1,
                    mlp_ratio=mlp_ratio,
                    drop_path=local_dpr)

            # Global stage (Transformer)
            global_ = None
            if transformer_depths[i] > 0:
                glob_dpr = dpr[dp_i: dp_i + transformer_depths[i]]
                dp_i += transformer_depths[i]

                global_ = TransformerStack(
                    dim=dim,
                    depth=transformer_depths[i],
                    num_heads=transformer_heads[i],
                    mlp_ratio=mlp_ratio,
                    attn_dropout=0.0,
                    dropout=0.1,
                    drop_path=glob_dpr,)

            self.levels.append(nn.ModuleDict({"local": local, "global": global_}))

            # Downsample para pasar dim_i -> dim_{i+1} (si no es el último nivel)
            if i < L - 1:
                if downsample_kind == "map":
                    self.downsamples.append(MapDownsample(dim_in=dim, dim_out=dims[i + 1], kernel_size=3))
                elif downsample_kind == "token":
                    # reusar PoolingLayer
                    self.downsamples.append(PoolingLayer(dim_in=dim, dim_out=dims[i + 1], kernel_size=3, stride=2))
                else:
                    raise ValueError(f"downsample_kind must be 'map' or 'token'. Got {downsample_kind}")

        assert dp_i == total_blocks

    def forward(self, x_map: torch.Tensor):
        """
        x_map: [B, H, W, C0]
        returns:
          x_final_tokens: [B, N_last, C_last]
          last_grid: (H_last, W_last)
        """
        B, H, W, C = x_map.shape

        for i, lvl in enumerate(self.levels):
            # local stage en map
            if lvl["local"] is not None:
                x_map = lvl["local"](x_map)

            # global stage en tokens (si existe)
            if lvl["global"] is not None:
                x_tok = map_to_tokens(x_map)
                x_tok = lvl["global"](x_tok)
                x_map = tokens_to_map(x_tok, H, W)

            # downsample (si aplica)
            if i < len(self.downsamples):
                ds = self.downsamples[i]
                if self.downsample_kind == "map":
                    x_map = ds(x_map)
                    H, W = x_map.shape[1], x_map.shape[2]
                else:
                    # token downsample: necesita grid
                    x_tok = map_to_tokens(x_map)
                    x_tok, (H, W) = ds(x_tok, (H, W))
                    x_map = tokens_to_map(x_tok, H, W)

        x_final = map_to_tokens(x_map)
        return x_final, (H, W)

In [25]:
def test_volo_pyramid_map():
    torch.manual_seed(0)
    B = 2
    H = W = 16
    x_map = torch.randn(B, H, W, 192)

    pyr = VOLOPyramid(
        dims=(192, 256, 384),
        outlooker_depths=(2, 2, 0),
        outlooker_heads=(6, 8, 12),
        transformer_depths=(0, 2, 2),
        transformer_heads=(6, 8, 12),
        downsample_kind="map",
        drop_path_rate=0.1,)

    x_tok, (Hf, Wf) = pyr(x_map)
    print("[OK] Pyramid-map:", x_tok.shape, "grid", (Hf, Wf))
    assert x_tok.shape[0] == B
    assert x_tok.shape[2] == 384
    assert Hf * Wf == x_tok.shape[1]

test_volo_pyramid_map()

[OK] Pyramid-map: torch.Size([2, 16, 384]) grid (4, 4)


In [26]:
def test_volo_pyramid_token():
    torch.manual_seed(0)
    B = 2
    H = W = 16
    x_map = torch.randn(B, H, W, 192)

    pyr = VOLOPyramid(
        dims=(192, 256, 384),
        outlooker_depths=(2, 2, 0),
        outlooker_heads=(6, 8, 12),
        transformer_depths=(0, 2, 2),
        transformer_heads=(6, 8, 12),
        downsample_kind="token",
        drop_path_rate=0.1)

    x_tok, (Hf, Wf) = pyr(x_map)
    print("[OK] Pyramid-token:", x_tok.shape, "grid", (Hf, Wf))
    assert x_tok.shape[2] == 384
    assert Hf * Wf == x_tok.shape[1]

test_volo_pyramid_token()

[OK] Pyramid-token: torch.Size([2, 16, 384]) grid (4, 4)


In [44]:
class ClassAttention(nn.Module):
    """
    Class Attention: sólo el CLS atiende al conjunto [CLS | tokens].
    Inputs:
      cls:    [B, 1, C]
      tokens: [B, N, C]
    Output:
      cls_out: [B, 1, C]
    """
    def __init__(self, dim: int, num_heads: int, attn_dropout: float = 0.0, proj_dropout: float = 0.0):
        super().__init__()
        self.attn = MultiHeadAttention(d_model=dim, num_heads=num_heads, dropout=attn_dropout)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(self, cls: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        kv = torch.cat([cls, tokens], dim=1)        # [B, 1+N, C]
        cls_out = self.attn(cls, kv, mask=None) # [B, 1, C] (solo CLS sale actualizado)
        return self.proj_drop(cls_out)


class ClassAttentionBlock(nn.Module):
    """
    Pre-norm (CaiT-style):
      cls -> LN -> ClassAttn(cls, [cls|tokens]) -> +res
          -> LN -> MLP -> +res
    Nota: tokens NO se actualizan.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_dropout: float = 0.0,
        dropout: float = 0.0):

        super().__init__()
        self.norm_cls = nn.LayerNorm(dim)
        self.norm_tok = nn.LayerNorm(dim)
        self.ca = ClassAttention(dim, num_heads, attn_dropout=attn_dropout, proj_dropout=dropout)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, int(dim * mlp_ratio), dropout=dropout)

    def forward(self, cls: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        # Class attention update (solo CLS)
        cls_norm = self.norm_cls(cls)
        tok_norm = self.norm_tok(tokens)
        cls = cls + self.ca(cls_norm, tok_norm)

        # MLP update (solo CLS)
        cls = cls + self.mlp(self.norm2(cls))
        return cls

In [66]:
class CLIPool(nn.Module):
    """
    "CLI" style pooling: mezcla aprendible entre CLS y mean(tokens).
      z = alpha * cls + (1-alpha) * mean
    """
    def __init__(self, init_alpha: float = 0.5):
        super().__init__()
        # parametriza alpha en logits para mantenerlo en (0,1)
        init_alpha = float(init_alpha)
        init_alpha = min(max(init_alpha, 1e-4), 1 - 1e-4)
        logit = math.log(init_alpha / (1 - init_alpha))
        self.alpha_logit = nn.Parameter(torch.tensor([logit], dtype=torch.float32))

    def forward(self, cls_vec: torch.Tensor, tok_mean: torch.Tensor) -> torch.Tensor:
        """
        cls_vec:  [B, C]
        tok_mean: [B, C]
        """
        alpha = torch.sigmoid(self.alpha_logit)  # scalar in (0,1)
        return alpha * cls_vec + (1.0 - alpha) * tok_mean

# VOLO

In [28]:
import math

def trunc_normal_(tensor, mean=0., std=1.):
    with torch.no_grad():
        return tensor.normal_(mean=mean, std=std)

class PosEmbed2D(nn.Module):
    """
    Positional embedding aprendible para grilla (H, W) en tokens.

    Guarda [1, H*W, C]. Si en forward llega otro (H,W), interpola.
    """
    def __init__(self, H: int, W: int, dim: int):
        super().__init__()
        self.H0 = H
        self.W0 = W
        self.dim = dim
        self.pos = nn.Parameter(torch.zeros(1, H * W, dim))
        trunc_normal_(self.pos, std=0.02)

    def forward(self, x_tok: torch.Tensor, grid: tuple[int, int]):
        """
        x_tok: [B, N, C]
        grid: (H, W)
        """
        B, N, C = x_tok.shape
        H, W = grid
        if (H == self.H0) and (W == self.W0):
            return x_tok + self.pos

        # Interpola pos emb como mapa [1, C, H, W] -> nuevo tamaño
        pos = self.pos.reshape(1, self.H0, self.W0, self.dim).permute(0, 3, 1, 2)  # [1,C,H0,W0]
        pos = nn.functional.interpolate(pos, size=(H, W), mode="bicubic", align_corners=False)
        pos = pos.permute(0, 2, 3, 1).reshape(1, H * W, self.dim)
        return x_tok + pos

class VOLOClassifier(nn.Module):
    """
    VOLO para CIFAR-100 (y similares), con dos modos:
      - flat: OutlookerStage -> TransformerStack (sin downsample)
              pooling: mean | cls | cli (cls via class-attn final)
      - hierarchical: pirámide con downsample (map o token)
              pooling: SOLO mean (por ahora)

    Flujo base:
      x [B,3,H,W]
        -> PatchEmbeddingConv -> x_tok [B, N, C0]
        -> pos emb (opcional)
        -> backbone (flat o pyramid)
        -> pooling
        -> head
    """

    def __init__(
        self,
        num_classes: int = 100,
        img_size: int = 32,
        in_chans: int = 3,
        patch_size: int = 4,

        # mode
        hierarchical: bool = False,
        downsample_kind: str = "map",   # si hierarchical=True: "map" o "token"

        # dims / depths (flat)
        embed_dim: int = 192,
        outlooker_depth: int = 4,
        outlooker_heads: int = 6,
        transformer_depth: int = 6,
        transformer_heads: int = 6,

        # hierarchical configs (si hierarchical=True)
        dims: tuple[int, ...] = (192, 256, 384),
        outlooker_depths: tuple[int, ...] = (2, 2, 0),
        outlooker_heads_list: tuple[int, ...] = (6, 8, 12),
        transformer_depths: tuple[int, ...] = (0, 2, 2),
        transformer_heads_list: tuple[int, ...] = (6, 8, 12),

        # block hyperparams
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        attn_dropout: float = 0.0,
        drop_path_rate: float = 0.0,

        # head / pooling
        pooling: str = "mean",          # flat: "mean"|"cls"|"cli" ; hierarchical: "mean"
        use_pos_embed: bool = True,

        # cls refinamiento (flat)
        cls_attn_depth: int = 2,        # # capas ClassAttentionBlock
        cli_init_alpha: float = 0.5,    # init alpha para pooling="cli"
        use_cls_pos: bool = True):

        super().__init__()

        self.hierarchical = hierarchical
        self.use_pos_embed = use_pos_embed

        if self.hierarchical:
            assert pooling == "mean", "Por ahora hierarchical solo soporta pooling='mean'."
        else:
            assert pooling in ["mean", "cls", "cli"], "pooling en flat debe ser 'mean', 'cls' o 'cli'."
        self.pooling = pooling

        # ---- Patch Embedding ----
        C0 = (dims[0] if hierarchical else embed_dim)

        self.patch_embed = PatchEmbeddingConv(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=C0,
            norm_layer=nn.LayerNorm,
            pad_if_needed=True,
            return_tokens=True,)

        Hp0 = math.ceil(img_size / patch_size)
        Wp0 = math.ceil(img_size / patch_size)

        self.pos_embed = PosEmbed2D(Hp0, Wp0, C0) if use_pos_embed else None
        self.pos_drop = nn.Dropout(dropout)

        # ---- Backbone ----
        if not hierarchical:
            total = outlooker_depth + transformer_depth
            dpr = torch.linspace(0, drop_path_rate, total).tolist() if total > 0 else []
            dpr_local = dpr[:outlooker_depth]
            dpr_glob = dpr[outlooker_depth:]

            self.local_stage = VOLOStage(
                dim=embed_dim,
                depth=outlooker_depth,
                num_heads=outlooker_heads,
                kernel_size=kernel_size,
                stride=1,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_dropout,
                proj_drop=dropout,
                drop_path=dpr_local if len(dpr_local) else 0.0,
                mlp_drop=dropout)

            self.global_blocks = nn.ModuleList([
                TransformerBlock(
                    dim=embed_dim,
                    num_heads=transformer_heads,
                    mlp_ratio=mlp_ratio,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    drop_path=(dpr_glob[i] if len(dpr_glob) else 0.0),
                ) for i in range(transformer_depth)])

            # --- CLS  (solo si pooling usa cls/cli) ---
            self.use_cls = (pooling in ["cls", "cli"])
            if self.use_cls:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
                trunc_normal_(self.cls_token, std=0.02)

                self.cls_pos = None
                if use_cls_pos:
                    self.cls_pos = nn.Parameter(torch.zeros(1, 1, embed_dim))
                    trunc_normal_(self.cls_pos, std=0.02)

                self.cls_attn_blocks = nn.ModuleList([
                    ClassAttentionBlock(
                        dim=embed_dim,
                        num_heads=transformer_heads,
                        mlp_ratio=mlp_ratio,
                        attn_dropout=attn_dropout,
                        dropout=dropout,) for _ in range(int(cls_attn_depth))])

                self.cli_pool = CLIPool(init_alpha=cli_init_alpha) if pooling == "cli" else None
            else:
                self.cls_token = None
                self.cls_pos = None
                self.cls_attn_blocks = None
                self.cli_pool = None


            self.norm = nn.LayerNorm(embed_dim)
            self.norm_feat = nn.LayerNorm(embed_dim)

            self.head = nn.Linear(embed_dim, num_classes)

        else:
            self.pyramid = VOLOPyramid(
                dims=dims,
                outlooker_depths=outlooker_depths,
                outlooker_heads=outlooker_heads_list,
                transformer_depths=transformer_depths,
                transformer_heads=transformer_heads_list,
                kernel_size=kernel_size,
                mlp_ratio=mlp_ratio,
                downsample_kind=downsample_kind,
                drop_path_rate=drop_path_rate,)


            self.norm = nn.LayerNorm(dims[-1])
            self.norm_feat = nn.LayerNorm(dims[-1])
            self.head = nn.Linear(dims[-1], num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Patch embedding
        x_map, (Hp, Wp), x_tok, _pad = self.patch_embed(x)   # x_tok [B,N,C0]
        B, N, C0 = x_tok.shape

        # Pos emb sobre tokens del grid
        if self.use_pos_embed and (self.pos_embed is not None):
            x_tok = self.pos_embed(x_tok, (Hp, Wp))
        x_tok = self.pos_drop(x_tok)

        if not self.hierarchical:
            # ---- Flat backbone ----
            # Outlooker trabaja en map (sin CLS)
            x_map = x_tok.view(B, Hp, Wp, C0)
            x_map = self.local_stage(x_map)
            x_tok = x_map.view(B, Hp * Wp, C0)  # [B,N,C]

            # Transformer global (tokens sin CLS)
            for blk in self.global_blocks:
                x_tok = blk(x_tok)

            #  Pooling
            if self.pooling == "mean":
                # Normaliza tokens y promedia
                x_tok_n = self.norm(x_tok)           # [B,N,C]
                feat = x_tok_n.mean(dim=1)           # [B,C]
                feat = self.norm_feat(feat)          # [B,C]
                return self.head(feat)

            # CLS refinado con class-attn final (CaiT-style)
            cls = self.cls_token.expand(B, -1, -1)   # [B,1,C]
            if self.cls_pos is not None:
                cls = cls + self.cls_pos

            for cab in self.cls_attn_blocks:
                cls = cab(cls, x_tok)               # [B,1,C]

            cls_vec = cls.squeeze(1)                # [B,C]
            cls_vec = self.norm_feat(cls_vec)

            if self.pooling == "cls":
                feat = cls_vec
                return self.head(feat)

            # pooling == "cli": mezcla CLS con mean(tokens) normalizado
            tok_mean = self.norm(x_tok).mean(dim=1)  # [B,C]
            feat = self.cli_pool(cls_vec, tok_mean)
            feat = self.norm_feat(feat)
            return self.head(feat)

        else:
            # ---- Hierarchical backbone (solo mean) ----
            x_map = x_tok.view(B, Hp, Wp, C0)
            x_last, (Hf, Wf) = self.pyramid(x_map)      # x_last: [B, Nf, C_last]

            x_last = self.norm(x_last)
            feat = x_last.mean(dim=1)
            feat = self.norm_feat(feat)
            return self.head(feat)


In [29]:
def test_volo_classifier_flat():
    torch.manual_seed(0)
    model = VOLOClassifier(
        num_classes=100,
        img_size=64,
        patch_size=4,
        hierarchical=False,
        embed_dim=192,
        outlooker_depth=2,
        transformer_depth=2,
        outlooker_heads=6,
        transformer_heads=6,
        pooling="mean")

    x = torch.randn(2, 3, 64, 64)
    y = model(x)
    print("[OK] flat logits:", y.shape)
    assert y.shape == (2, 100)

def test_volo_classifier_hier():
    torch.manual_seed(0)
    model = VOLOClassifier(
        num_classes=100,
        img_size=64,
        patch_size=4,
        hierarchical=True,
        downsample_kind="map",
        dims=(192, 256, 384),
        outlooker_depths=(2, 2, 0),
        outlooker_heads_list=(6, 8, 12),
        transformer_depths=(0, 2, 2),
        transformer_heads_list=(6, 8, 12),
        pooling="mean",)

    x = torch.randn(2, 3, 64, 64)
    y = model(x)
    print("[OK] hier logits:", y.shape)
    assert y.shape == (2, 100)

test_volo_classifier_flat()
test_volo_classifier_hier()

[OK] flat logits: torch.Size([2, 100])
[OK] hier logits: torch.Size([2, 100])


In [31]:
def _fmt_out(output):
    if isinstance(output, (tuple, list)):
        shapes = []
        for o in output:
            if hasattr(o, "shape"):
                shapes.append(tuple(o.shape))
            else:
                shapes.append(type(o).__name__)
        return shapes
    if hasattr(output, "shape"):
        return tuple(output.shape)
    return type(output).__name__


def attach_shape_hooks_volo(model: nn.Module, verbose: bool = True):
    hooks = []

    def add_hook(mod: nn.Module, name: str):
        if mod is None:
            return
        def hook(_m, _inp, out):
            print(f"{name:35s} -> {_fmt_out(out)}")
        hooks.append(mod.register_forward_hook(hook))

    # Top-level components
    add_hook(getattr(model, "patch_embed", None), "patch_embed")
    add_hook(getattr(model, "local_stage", None), "local_stage (outlooker)")
    add_hook(getattr(model, "pyramid", None), "pyramid (top)")
    add_hook(getattr(model, "norm", None), "norm")
    add_hook(getattr(model, "head", None), "head")

    # Global blocks (flat)
    if hasattr(model, "global_blocks"):
        for i, blk in enumerate(model.global_blocks):
            add_hook(blk, f"global_block[{i}]")

    # Pyramid internals (hierarchical)
    pyr = getattr(model, "pyramid", None)
    if pyr is not None:
        if hasattr(pyr, "levels"):
            for i, lvl in enumerate(pyr.levels):
                # lvl es nn.ModuleDict: NO tiene .get
                loc = lvl["local"] if "local" in lvl else None
                glob = lvl["global"] if "global" in lvl else None
                add_hook(loc,  f"pyr.level[{i}].local")
                add_hook(glob, f"pyr.level[{i}].global")

        if hasattr(pyr, "downsamples"):
            for i, ds in enumerate(pyr.downsamples):
                add_hook(ds, f"pyr.down[{i}]")

    return hooks

def remove_hooks(hooks):
    for h in hooks:
        h.remove()

@torch.no_grad()
def debug_forward_shapes(model: nn.Module, img_size: int, device: str = "cpu", batch_size: int = 2):
    model = model.to(device).eval()
    hooks = attach_shape_hooks_volo(model)

    x = torch.randn(batch_size, 3, img_size, img_size, device=device)
    print(f"\n=== Forward debug | img_size={img_size} | model={model.__class__.__name__} ===")
    y = model(x)
    print(f"{'OUTPUT logits':35s} -> {tuple(y.shape)}")

    remove_hooks(hooks)


In [32]:
model_flat64 = VOLOClassifier(
    num_classes=100,
    img_size=64,
    patch_size=4,
    hierarchical=False,
    embed_dim=192,
    outlooker_depth=2,
    outlooker_heads=6,
    transformer_depth=2,
    transformer_heads=6,
    pooling="mean",
    use_pos_embed=True,)

debug_forward_shapes(model_flat64, img_size=64, device="cpu")



=== Forward debug | img_size=64 | model=VOLOClassifier ===
patch_embed                         -> [(2, 16, 16, 192), 'tuple', (2, 256, 192), 'tuple']
local_stage (outlooker)             -> (2, 16, 16, 192)
global_block[0]                     -> (2, 256, 192)
global_block[1]                     -> (2, 256, 192)
norm                                -> (2, 256, 192)
head                                -> (2, 100)
OUTPUT logits                       -> (2, 100)


In [33]:
model_hier64 = VOLOClassifier(
    num_classes=100,
    img_size=64,
    patch_size=4,
    hierarchical=True,
    downsample_kind="map",
    dims=(192, 256, 384),
    outlooker_depths=(2, 2, 0),
    outlooker_heads_list=(6, 8, 12),
    transformer_depths=(0, 2, 2),
    transformer_heads_list=(6, 8, 12),
    pooling="mean",
    use_pos_embed=True,)

debug_forward_shapes(model_hier64, img_size=64, device="cpu")


=== Forward debug | img_size=64 | model=VOLOClassifier ===
patch_embed                         -> [(2, 16, 16, 192), 'tuple', (2, 256, 192), 'tuple']
pyr.level[0].local                  -> (2, 16, 16, 192)
pyr.down[0]                         -> (2, 8, 8, 256)
pyr.level[1].local                  -> (2, 8, 8, 256)
pyr.level[1].global                 -> (2, 64, 256)
pyr.down[1]                         -> (2, 4, 4, 384)
pyr.level[2].global                 -> (2, 16, 384)
pyramid (top)                       -> [(2, 16, 384), 'tuple']
norm                                -> (2, 16, 384)
head                                -> (2, 100)
OUTPUT logits                       -> (2, 100)


In [47]:
model_hier64_tok = VOLOClassifier(
    num_classes=100,
    img_size=64,
    patch_size=4,
    hierarchical=True,
    downsample_kind="token",
    dims=(192, 256, 384),
    outlooker_depths=(2, 2, 0),
    outlooker_heads_list=(6, 8, 12),
    transformer_depths=(0, 2, 2),
    transformer_heads_list=(6, 8, 12),
    pooling="mean",
    use_pos_embed=True,)

debug_forward_shapes(model_hier64_tok, img_size=64, device="cpu")


=== Forward debug | img_size=64 | model=VOLOClassifier ===
patch_embed                         -> [(2, 16, 16, 192), 'tuple', (2, 256, 192), 'tuple']
pyr.level[0].local                  -> (2, 16, 16, 192)
pyr.down[0]                         -> [(2, 64, 256), 'tuple']
pyr.level[1].local                  -> (2, 8, 8, 256)
pyr.level[1].global                 -> (2, 64, 256)
pyr.down[1]                         -> [(2, 16, 384), 'tuple']
pyr.level[2].global                 -> (2, 16, 384)
pyramid (top)                       -> [(2, 16, 384), 'tuple']
norm                                -> (2, 16, 384)
head                                -> (2, 100)
OUTPUT logits                       -> (2, 100)


---

In [34]:
import os, math, random, inspect
from contextlib import contextmanager, nullcontext
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

def seed_everything(seed: int = 0, deterministic: bool = False):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True

_DTYPE_MAP = {
    "bf16": torch.bfloat16, "bfloat16": torch.bfloat16,
    "fp16": torch.float16,  "float16": torch.float16,
    "fp32": torch.float32,  "float32": torch.float32,}

def _cuda_dtype_supported(dtype: torch.dtype) -> bool:
    if not torch.cuda.is_available():
        return False
    return dtype in (torch.float16, torch.bfloat16)

def make_grad_scaler(device: str = "cuda", enabled: bool = True):
    if not enabled:
        return None

    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        try:
            sig = inspect.signature(torch.amp.GradScaler)
            if len(sig.parameters) >= 1:
                return torch.amp.GradScaler(device if device in ("cuda", "cpu") else "cuda")
            return torch.amp.GradScaler()
        except Exception:
            pass

    if hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "GradScaler"):
        return torch.cuda.amp.GradScaler()
    return None


@contextmanager
def autocast_ctx(
    device: str = "cuda",
    enabled: bool = True,
    dtype: str = "fp16",
    cache_enabled: bool = True,):
    """
    Context manager de autocast:
      - cuda: fp16 por defecto (ideal en T4)
      - cpu: bfloat16 si está disponible
    """
    if not enabled:
        with nullcontext():
            yield
        return

    if device == "cuda":
        want = _DTYPE_MAP.get(dtype.lower(), torch.float16)
        use = want if _cuda_dtype_supported(want) else torch.float16
        with torch.amp.autocast(device_type="cuda", dtype=use, cache_enabled=cache_enabled):
            yield
        return

    if device == "cpu":
        try:
            with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled):
                yield
        except Exception:
            with nullcontext():
                yield
        return

    with nullcontext():
        yield

In [36]:
def build_param_groups_no_wd(model: nn.Module, weight_decay: float):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue

        name_l = name.lower()
        # no decay for biases + norms + positional/class tokens
        if (
            name.endswith(".bias")
            or ("norm" in name_l)
            or ("bn" in name_l)
            or ("ln" in name_l)
            or ("pos" in name_l)         # pos_embed / pos
            or ("cls_token" in name_l)
        ):
            no_decay.append(p)
        else:
            decay.append(p)

    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0}]


class WarmupCosineLR:
    """Warmup linear for warmup_steps, then cosine to min_lr. Step-based."""
    def __init__(self, optimizer, total_steps: int, warmup_steps: int, min_lr: float = 0.0):
        self.optimizer = optimizer
        self.total_steps = int(total_steps)
        self.warmup_steps = int(warmup_steps)
        self.min_lr = float(min_lr)
        self.base_lrs = [g["lr"] for g in optimizer.param_groups]
        self.step_num = 0

    def step(self):
        self.step_num += 1
        t = self.step_num

        for i, group in enumerate(self.optimizer.param_groups):
            base = self.base_lrs[i]
            if t <= self.warmup_steps and self.warmup_steps > 0:
                lr = base * (t / self.warmup_steps)
            else:
                tt = min(t, self.total_steps)
                denom = max(1, self.total_steps - self.warmup_steps)
                progress = (tt - self.warmup_steps) / denom
                cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
                lr = self.min_lr + (base - self.min_lr) * cosine
            group["lr"] = lr

    def state_dict(self):
        return {"step_num": self.step_num}

    def load_state_dict(self, d):
        self.step_num = int(d.get("step_num", 0))

In [37]:
def save_checkpoint(
    path: str,
    model,
    optimizer,
    scheduler,
    scaler,
    epoch: int,
    best_top1: float,
    extra: dict | None = None,):

    ckpt = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "scaler": scaler.state_dict() if scaler is not None else None,
        "epoch": epoch,
        "best_top1": best_top1,
        "extra": extra or {},}
    torch.save(ckpt, path)


def load_checkpoint(
    path: str,
    model,
    optimizer=None,
    scheduler=None,
    scaler=None,
    map_location="cpu",
    strict: bool = True,):
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt["model"], strict=strict)

    if optimizer is not None and ckpt.get("optimizer") is not None:
        optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and ckpt.get("scheduler") is not None:
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and ckpt.get("scaler") is not None:
        scaler.load_state_dict(ckpt["scaler"])
    return ckpt

In [38]:
# -------------------------
# Mixup / CutMix + Loss
# -------------------------
def _one_hot(targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    return F.one_hot(targets, num_classes=num_classes).float()


def soft_target_cross_entropy(logits: torch.Tensor, targets_soft: torch.Tensor) -> torch.Tensor:
    logp = F.log_softmax(logits, dim=1)
    return -(targets_soft * logp).sum(dim=1).mean()


def apply_mixup_cutmix(
    images: torch.Tensor,
    targets: torch.Tensor,
    num_classes: int,
    mixup_alpha: float = 0.0,
    cutmix_alpha: float = 0.0,
    prob: float = 1.0,):
    """
    Returns:
      images_aug: [B,3,H,W]
      targets_soft: [B,K]
    """
    if prob <= 0.0 or (mixup_alpha <= 0.0 and cutmix_alpha <= 0.0):
        return images, _one_hot(targets, num_classes)

    if random.random() > prob:
        return images, _one_hot(targets, num_classes)

    use_cutmix = (cutmix_alpha > 0.0) and (mixup_alpha <= 0.0 or random.random() < 0.5)
    B, _, H, W = images.shape
    perm = torch.randperm(B, device=images.device)

    y1 = _one_hot(targets, num_classes)
    y2 = _one_hot(targets[perm], num_classes)

    if use_cutmix:
        lam = torch.distributions.Beta(cutmix_alpha, cutmix_alpha).sample().item()
        cut_w = int(W * math.sqrt(1.0 - lam))
        cut_h = int(H * math.sqrt(1.0 - lam))
        cx = random.randint(0, W - 1)
        cy = random.randint(0, H - 1)

        x1 = max(cx - cut_w // 2, 0)
        x2 = min(cx + cut_w // 2, W)
        y1b = max(cy - cut_h // 2, 0)
        y2b = min(cy + cut_h // 2, H)

        images_aug = images.clone()
        images_aug[:, :, y1b:y2b, x1:x2] = images[perm, :, y1b:y2b, x1:x2]

        # adjust lambda based on actual area swapped
        area = (x2 - x1) * (y2b - y1b)
        lam = 1.0 - area / float(W * H)
    else:
        lam = torch.distributions.Beta(mixup_alpha, mixup_alpha).sample().item()
        images_aug = images * lam + images[perm] * (1.0 - lam)

    targets_soft = y1 * lam + y2 * (1.0 - lam)
    return images_aug, targets_soft

In [39]:
# -------------------------
# Metrics
# -------------------------
@torch.no_grad()
def accuracy_topk(logits: torch.Tensor, targets: torch.Tensor, ks=(1, 3, 5)) -> Dict[int, float]:
    """
    targets can be:
      - int64 class indices [B]
      - soft targets [B, num_classes] (we'll argmax for accuracy reporting)
    """
    if targets.ndim == 2:
        targets = targets.argmax(dim=1)

    max_k = max(ks)
    B = targets.size(0)
    _, pred = torch.topk(logits, k=max_k, dim=1)
    correct = pred.eq(targets.view(-1, 1).expand_as(pred))
    out = {}
    for k in ks:
        out[k] = 100.0 * correct[:, :k].any(dim=1).float().sum().item() / B
    return out

In [47]:
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

def ddp_is_on() -> bool:
    return dist.is_available() and dist.is_initialized()

def ddp_rank() -> int:
    return dist.get_rank() if ddp_is_on() else 0

def is_main_process() -> bool:
    return (not ddp_is_on()) or ddp_rank() == 0

def ddp_sum_(tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce SUM in-place and return tensor."""
    if ddp_is_on():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor

def ddp_broadcast_bool(flag: bool, device: torch.device | str) -> bool:
    """Broadcast a stop flag from rank0 to all ranks."""
    t = torch.tensor([1 if flag else 0], device=device)
    if ddp_is_on():
        dist.broadcast(t, src=0)
    return bool(t.item())

In [48]:
from typing import Optional
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_one_epoch(
    model: nn.Module,
    dataloader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: str = "cuda",
    scaler=None,
    autocast_dtype: str = "fp16",
    use_amp: bool = True,
    grad_clip_norm: Optional[float] = 1.0,
    label_smoothing: float = 0.1,
    mixup_alpha: float = 0.0,
    cutmix_alpha: float = 0.0,
    mix_prob: float = 1.0,
    num_classes: int = 100,
    channels_last: bool = False,
    print_every: int = 100,
):
    model.train()

    use_scaler = (scaler is not None) and use_amp and autocast_dtype.lower() in ("fp16", "float16")

    running_loss = 0.0
    total = 0
    c1 = c3 = c5 = 0.0  # “correct counts” acumulados (no %)

    t0 = time.time()
    for step, (images, targets) in enumerate(dataloader, start=1):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        B = targets.size(0)

        images_aug, targets_soft = apply_mixup_cutmix(
            images, targets,
            num_classes=num_classes,
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            prob=mix_prob
        )

        use_mix = (mixup_alpha > 0.0) or (cutmix_alpha > 0.0)
        targets_for_acc = targets_soft if use_mix else targets

        optimizer.zero_grad(set_to_none=True)

        with autocast_ctx(device=device, enabled=use_amp, dtype=autocast_dtype, cache_enabled=True):
            logits = model(images_aug)

        if use_mix:
            loss = soft_target_cross_entropy(logits.float(), targets_soft)
        else:
            loss = F.cross_entropy(logits.float(), targets, label_smoothing=label_smoothing)

        if use_scaler:
            scaler.scale(loss).backward()
            if grad_clip_norm is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # métricas “locales”
        running_loss += loss.item() * B
        total += B
        accs = accuracy_topk(logits.detach(), targets_for_acc, ks=(1, 3, 5))
        c1 += accs[1] * B / 100.0
        c3 += accs[3] * B / 100.0
        c5 += accs[5] * B / 100.0

        # log solo en rank0 (si DDP)
        if print_every and (step % print_every == 0) and is_main_process():
            dt = time.time() - t0
            imgs_sec = total / max(dt, 1e-9)
            print(
                f"[train step {step}/{len(dataloader)}] "
                f"loss {running_loss/total:.4f} | "
                f"top1 {100*c1/total:.2f}% | top3 {100*c3/total:.2f}% | top5 {100*c5/total:.2f}% | "
                f"{imgs_sec:.1f} img/s | lr {optimizer.param_groups[0]['lr']:.2e}"
            )

    # ---- REDUCCIÓN GLOBAL (DDP) ----
    stats = torch.tensor([running_loss, total, c1, c3, c5], device=device, dtype=torch.float64)
    ddp_sum_(stats)
    running_loss_g, total_g, c1_g, c3_g, c5_g = stats.tolist()

    avg_loss = running_loss_g / max(total_g, 1e-12)
    metrics = {
        "top1": 100.0 * c1_g / max(total_g, 1e-12),
        "top3": 100.0 * c3_g / max(total_g, 1e-12),
        "top5": 100.0 * c5_g / max(total_g, 1e-12),}
    
    return avg_loss, metrics

In [49]:

@torch.no_grad()
def evaluate_one_epoch(
    model: nn.Module,
    dataloader,
    device: str = "cuda",
    autocast_dtype: str = "fp16",
    use_amp: bool = True,
    label_smoothing: float = 0.0,
    channels_last: bool = False,
):
    model.eval()

    running_loss = 0.0
    total = 0
    c1 = c3 = c5 = 0.0

    for images, targets in dataloader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        B = targets.size(0)

        with autocast_ctx(device=device, enabled=use_amp, dtype=autocast_dtype, cache_enabled=True):
            logits = model(images)

        loss = F.cross_entropy(logits.float(), targets, label_smoothing=label_smoothing)

        running_loss += loss.item() * B
        total += B

        accs = accuracy_topk(logits, targets, ks=(1, 3, 5))
        c1 += accs[1] * B / 100.0
        c3 += accs[3] * B / 100.0
        c5 += accs[5] * B / 100.0

    # ---- REDUCCIÓN GLOBAL (DDP) ----
    stats = torch.tensor([running_loss, total, c1, c3, c5], device=device, dtype=torch.float64)
    ddp_sum_(stats)
    running_loss_g, total_g, c1_g, c3_g, c5_g = stats.tolist()

    avg_loss = running_loss_g / max(total_g, 1e-12)
    metrics = {
        "top1": 100.0 * c1_g / max(total_g, 1e-12),
        "top3": 100.0 * c3_g / max(total_g, 1e-12),
        "top5": 100.0 * c5_g / max(total_g, 1e-12),}
    
    return avg_loss, metrics

In [63]:
import time
import torch
import torch.nn as nn

def train_model(
    model: nn.Module,
    train_loader,
    epochs: int,
    val_loader=None,
    device: str = "cuda",
    lr: float = 5e-4,
    weight_decay: float = 0.05,
    autocast_dtype: str = "fp16",
    use_amp: bool = True,
    grad_clip_norm: float | None = 1.0,
    warmup_ratio: float = 0.05,
    min_lr: float = 0.0,
    label_smoothing: float = 0.1,
    print_every: int = 100,
    save_path: str = "best_model.pt",
    last_path: str = "last_model.pt",
    resume_path: str | None = None,

    mixup_alpha: float = 0.0,
    cutmix_alpha: float = 0.0,
    mix_prob: float = 1.0,
    num_classes: int = 100,
    channels_last: bool = False,

    early_stop: bool = True,
    early_stop_metric: str = "top1",
    early_stop_patience: int = 10,
    early_stop_min_delta: float = 0.0,
    early_stop_require_monotonic: bool = False):

    model.to(device)

    # Optimizer
    param_groups = build_param_groups_no_wd(model, weight_decay=weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.999), eps=1e-8)

    # Scheduler warmup + cosine (step-based)
    total_steps = epochs * len(train_loader)
    warmup_steps = int(total_steps * warmup_ratio)
    scheduler = WarmupCosineLR(
        optimizer,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        min_lr=min_lr,
    )

    scaler = None
    if use_amp and autocast_dtype.lower() in ("fp16", "float16"):
        scaler = make_grad_scaler(device=device, enabled=True)

    # Resume
    start_epoch = 0
    best_val_top1 = -float("inf")
    best_val_loss = float("inf")
    best_epoch = 0

    if resume_path is not None:
        ckpt = load_checkpoint(
            resume_path, model,
            optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            map_location=device,
            strict=True,
        )

        start_epoch = int(ckpt.get("epoch", 0))
        best_val_top1 = float(ckpt.get("best_top1", best_val_top1))
        extra = ckpt.get("extra", {}) or {}
        best_val_loss = float(extra.get("best_val_loss", best_val_loss))
        best_epoch = int(extra.get("best_epoch", best_epoch))

        if is_main_process():
            print(f"Resumed from {resume_path} at epoch {start_epoch} | best_top1 {best_val_top1:.2f}% | best_loss {best_val_loss:.4f}")

    history = {
        "train_loss": [], "train_top1": [], "train_top3": [], "train_top5": [],
        "val_loss": [], "val_top1": [], "val_top3": [], "val_top5": [],
        "lr": [],
    } if is_main_process() else None  # <- solo rank0 guarda history

    # Early stop state (solo rank0 lleva el estado)
    metric = early_stop_metric.lower()
    assert metric in ("top1", "loss")
    patience = int(early_stop_patience)
    mode = "max" if metric == "top1" else "min"
    best_metric = best_val_top1 if metric == "top1" else best_val_loss
    bad_epochs = 0
    last_vals = []

    def _is_improvement(curr: float, best: float) -> bool:
        d = float(early_stop_min_delta)
        return (curr > (best + d)) if mode == "max" else (curr < (best - d))

    def _degradation_monotonic(vals: list[float]) -> bool:
        if not early_stop_require_monotonic or len(vals) < 2:
            return True
        if mode == "max":
            return all(vals[i] >= vals[i + 1] for i in range(len(vals) - 1))
        else:
            return all(vals[i] <= vals[i + 1] for i in range(len(vals) - 1))

    for epoch in range(start_epoch + 1, epochs + 1):
        if is_main_process():
            print(f"\n=== Epoch {epoch}/{epochs} ===")
        t_epoch = time.time()

        # ✅ DDP: reshuffle correcto por epoch
        if hasattr(train_loader, "sampler") and isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)
        if val_loader is not None and hasattr(val_loader, "sampler") and isinstance(val_loader.sampler, DistributedSampler):
            val_loader.sampler.set_epoch(epoch)

        # --- Train ---
        tr_loss, tr_m = train_one_epoch(
            model=model,
            dataloader=train_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            scaler=scaler,
            autocast_dtype=autocast_dtype,
            use_amp=use_amp,
            grad_clip_norm=grad_clip_norm,
            label_smoothing=label_smoothing,
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            mix_prob=mix_prob,
            num_classes=num_classes,
            channels_last=channels_last,
            print_every=print_every,
        )

        if is_main_process():
            history["train_loss"].append(tr_loss)
            history["train_top1"].append(tr_m["top1"])
            history["train_top3"].append(tr_m["top3"])
            history["train_top5"].append(tr_m["top5"])
            history["lr"].append(optimizer.param_groups[0]["lr"])

            print(f"[Train] loss {tr_loss:.4f} | top1 {tr_m['top1']:.2f}% | top3 {tr_m['top3']:.2f}% | top5 {tr_m['top5']:.2f}% | lr {optimizer.param_groups[0]['lr']:.2e}")

            # ✅ guardar "last" SOLO en rank0
            save_checkpoint(
                last_path, model, optimizer, scheduler, scaler,
                epoch=epoch, best_top1=best_val_top1,
                extra={
                    "autocast_dtype": autocast_dtype,
                    "use_amp": use_amp,
                    "best_val_loss": best_val_loss,
                    "best_epoch": best_epoch,
                    "early_stop_metric": metric,
                    "early_stop_patience": patience,
                    "early_stop_min_delta": float(early_stop_min_delta),
                },
            )

        stop_now = False

        # --- Val ---
        if val_loader is not None:
            va_loss, va_m = evaluate_one_epoch(
                model=model,
                dataloader=val_loader,
                device=device,
                autocast_dtype=autocast_dtype,
                use_amp=use_amp,
                label_smoothing=0.0,
                channels_last=channels_last,
            )

            if is_main_process():
                history["val_loss"].append(va_loss)
                history["val_top1"].append(va_m["top1"])
                history["val_top3"].append(va_m["top3"])
                history["val_top5"].append(va_m["top5"])

                print(f"[Val]   loss {va_loss:.4f} | top1 {va_m['top1']:.2f}% | top3 {va_m['top3']:.2f}% | top5 {va_m['top5']:.2f}%")

                # Best saved por top1
                if va_m["top1"] > best_val_top1:
                    best_val_top1 = va_m["top1"]
                    if va_loss < best_val_loss:
                        best_val_loss = va_loss
                        best_epoch = epoch

                    save_checkpoint(
                        save_path, model, optimizer, scheduler, scaler,
                        epoch=epoch, best_top1=best_val_top1,
                        extra={
                            "autocast_dtype": autocast_dtype,
                            "use_amp": use_amp,
                            "best_val_loss": best_val_loss,
                            "best_epoch": best_epoch,
                        },
                    )
                    print(f"Best saved to {save_path} (val top1 {best_val_top1:.2f}%)")

                # Early stop (solo rank0 decide)
                if early_stop:
                    curr_metric = va_m["top1"] if metric == "top1" else va_loss

                    last_vals.append(float(curr_metric))
                    if len(last_vals) > patience:
                        last_vals = last_vals[-patience:]

                    if _is_improvement(curr_metric, best_metric):
                        best_metric = float(curr_metric)
                        bad_epochs = 0
                    else:
                        bad_epochs += 1

                    if bad_epochs >= patience and _degradation_monotonic(last_vals):
                        print(f"Early-stop: no improvement on val_{metric} for {patience} epochs.")
                        stop_now = True

        # ✅ DDP: sincroniza el “stop” a todos los ranks
        stop_now = ddp_broadcast_bool(stop_now, device=device)
        if stop_now:
            break

        if is_main_process():
            dt = time.time() - t_epoch
            print(f"Epoch time: {dt/60:.2f} min")

    # return: history solo en rank0; en otros ranks devuelve None
    return history, (model.module if hasattr(model, "module") else model)

In [88]:


%%writefile volo.py

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import RandAugment
import torch.distributed as dist
from torch.utils.data import Subset
from torch.utils.data import Subset, DataLoader

def get_cifar100_datasets(
    data_dir: str = "./data",
    val_split: float = 0.0,
    ra_num_ops: int = 2,
    ra_magnitude: int = 7,
    random_erasing_p: float = 0.25,
    erasing_scale=(0.02, 0.20),
    erasing_ratio=(0.3, 3.3),
    img_size: int = 32,):

    """
    CIFAR-100 datasets con augmentations "mix-friendly":
    diseñadas para complementar Mixup/CutMix (en el loop) sin pasarse.

    img_size:
      - 32 (default): CIFAR nativo.
      - >32: upsample (p.ej. 64) para experimentos (más tokens/compute).
    """
    if img_size < 32:
        raise ValueError(f"img_size must be >= 32 for CIFAR-100. Got {img_size}.")

    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std  = (0.2675, 0.2565, 0.2761)

    # Si subimos resolución, primero hacemos resize y adaptamos crop/padding.
    # Padding recomendado proporcional: 32->4, 64->8, etc.

    crop_padding = max(4, img_size // 8)

    train_ops = []
    if img_size != 32:
        train_ops.append(transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC))

    train_ops += [
        transforms.RandomCrop(img_size, padding=crop_padding),
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=ra_num_ops, magnitude=ra_magnitude),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
        transforms.RandomErasing(
            p=random_erasing_p,
            scale=erasing_scale,
            ratio=erasing_ratio,
            value="random",),]

    train_transform = transforms.Compose(train_ops)

    test_ops = []
    if img_size != 32:
        test_ops.append(transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC))

    test_ops += [
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),]

    test_transform = transforms.Compose(test_ops)

    full_train_dataset = datasets.CIFAR100(
        root=data_dir, train=True, download=True, transform=train_transform)

    test_dataset = datasets.CIFAR100(
        root=data_dir, train=False, download=True, transform=test_transform)

    if val_split > 0.0:
        n_total = len(full_train_dataset)
        n_val = int(n_total * val_split)
        n_train = n_total - n_val
        train_dataset, val_dataset = random_split(
            full_train_dataset,
            [n_train, n_val],
            generator=torch.Generator().manual_seed(7),)

    else:
        train_dataset = full_train_dataset
        val_dataset = None

    return train_dataset, val_dataset, test_dataset


def get_cifar100_dataloaders(
    batch_size: int = 128,
    data_dir: str = "./data",
    num_workers: int = 2,
    val_split: float = 0.0,
    pin_memory: bool = True,
    ra_num_ops: int = 2,
    ra_magnitude: int = 7,
    random_erasing_p: float = 0.25,
    img_size: int = 32,):
    """
    Dataloaders CIFAR-100 listos para entrenar con Mixup/CutMix en el loop.
    Augmentations no tan agresivas.

    img_size:
      - 32 (default): CIFAR nativo.
      - 64: experimento de upsample (ojo: más compute).
    """
    train_ds, val_ds, test_ds = get_cifar100_datasets(
        data_dir=data_dir,
        val_split=val_split,
        ra_num_ops=ra_num_ops,
        ra_magnitude=ra_magnitude,
        random_erasing_p=random_erasing_p,
        img_size=img_size,)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=(num_workers > 0),)

    val_loader = None
    if val_ds is not None:
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=(num_workers > 0),)

    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=(num_workers > 0),)

    return train_loader, val_loader, test_loader

def _ddp_is_on():
    return dist.is_available() and dist.is_initialized()

def _ddp_rank():
    return dist.get_rank() if _ddp_is_on() else 0

def _ddp_barrier():
    if _ddp_is_on():
        dist.barrier()

def get_cifar100_datasets(
    data_dir: str = "./data",
    val_split: float = 0.0,
    ra_num_ops: int = 1,
    ra_magnitude: int = 5,
    random_erasing_p: float = 0.1,
    erasing_scale=(0.02, 0.20),
    erasing_ratio=(0.3, 3.3),
    img_size: int = 32,
    seed: int = 7,
    ddp_safe_download: bool = True):
    """
    CIFAR-100 datasets con aug 'mix-friendly' y soporte DDP:
      - Descarga segura: solo rank0 descarga, luego barrier.
      - Split determinista: train/val indices iguales en todos los ranks.
      - Val usa test_transform (SIN aug estocásticos).
    """
    if img_size < 32:
        raise ValueError(f"img_size must be >= 32 for CIFAR-100. Got {img_size}.")

    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std  = (0.2675, 0.2565, 0.2761)

    crop_padding = max(4, img_size // 8)

    train_ops = []
    if img_size != 32:
        train_ops.append(transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC))
    train_ops += [
        transforms.RandomCrop(img_size, padding=crop_padding),
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=ra_num_ops, magnitude=ra_magnitude),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
        transforms.RandomErasing(
            p=random_erasing_p,
            scale=erasing_scale,
            ratio=erasing_ratio,
            value="random",
        ),
    ]
    train_transform = transforms.Compose(train_ops)

    test_ops = []
    if img_size != 32:
        test_ops.append(transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC))
    test_ops += [
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),]

    test_transform = transforms.Compose(test_ops)

    #  DDP-safe download
    if ddp_safe_download and _ddp_is_on():
        if _ddp_rank() == 0:
            datasets.CIFAR100(root=data_dir, train=True, download=True)
            datasets.CIFAR100(root=data_dir, train=False, download=True)
        _ddp_barrier()
        download_flag = False
    else:
        download_flag = True

    # Base datasets (dos versiones: train aug y eval clean)
    full_train_aug = datasets.CIFAR100(root=data_dir, train=True, download=download_flag, transform=train_transform)
    full_train_eval = datasets.CIFAR100(root=data_dir, train=True, download=False, transform=test_transform)
    test_dataset = datasets.CIFAR100(root=data_dir, train=False, download=download_flag, transform=test_transform)

    if val_split > 0.0:
        n_total = len(full_train_aug)
        n_val = int(n_total * val_split)
        n_train = n_total - n_val

        g = torch.Generator().manual_seed(seed)
        perm = torch.randperm(n_total, generator=g).tolist()
        train_idx = perm[:n_train]
        val_idx = perm[n_train:]

        train_dataset = Subset(full_train_aug, train_idx)
        val_dataset = Subset(full_train_eval, val_idx)
    else:
        train_dataset = full_train_aug
        val_dataset = None

    return train_dataset, val_dataset, test_dataset



import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbeddingConv(nn.Module):
    """
    Patch embedding estilo Swin.

    - Conv2d con kernel=stride=patch_size para convertir imagen -> grilla de patches.
    - Devuelve el mapa 2D en formato canal-al-final: [B, Hp, Wp, D],
      (más cómodo para window partition).
    - Opcionalmente devuelve tokens [B, N, D].
    - Opcional padding automático si H/W no son divisibles por patch_size.
    """

    def __init__(
        self,
        patch_size: int | tuple[int, int] = 4,
        in_chans: int = 3,
        embed_dim: int = 192,
        norm_layer: type[nn.Module] | None = nn.LayerNorm,
        pad_if_needed: bool = True,
        return_tokens: bool = True):

        super().__init__()

        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)

        self.patch_size = patch_size  # (Ph, Pw)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.pad_if_needed = pad_if_needed
        self.return_tokens = return_tokens

        # [B, C, H, W] -> [B, D, Hp, Wp]
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,)

        # En Swin normalmente LayerNorm sobre la última dimensión
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [B, C, H, W]

        Returns:
            x_map:    [B, Hp, Wp, D]
            (Hp, Wp): tamaño espacial en patches
            x_tokens (opcional): [B, N, D]
            pad_hw (opcional): (pad_h, pad_w) aplicados a la imagen
        """
        B, C, H, W = x.shape
        Ph, Pw = self.patch_size

        pad_h = (Ph - (H % Ph)) % Ph
        pad_w = (Pw - (W % Pw)) % Pw

        if (pad_h != 0 or pad_w != 0):
            if not self.pad_if_needed:
                raise AssertionError(
                    f"Image size ({H}x{W}) no es divisible por patch_size {self.patch_size} "
                    f"y pad_if_needed=False.")

            x = F.pad(x, (0, pad_w, 0, pad_h))

        # [B, D, Hp, Wp]
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]

        # canal al final -> [B, Hp, Wp, D]
        x_map = x.permute(0, 2, 3, 1).contiguous()

        if self.norm is not None:
            x_map = self.norm(x_map)

        if self.return_tokens:
            x_tokens = x_map.view(B, Hp * Wp, self.embed_dim)
            return x_map, (Hp, Wp), x_tokens, (pad_h, pad_w)

        return x_map, (Hp, Wp), (pad_h, pad_w)

def test_patch_embedding_conv():
    torch.manual_seed(0)

    #  tamaño divisible (64 con patch=4)
    B, C, H, W = 2, 3, 64, 64
    x = torch.randn(B, C, H, W)

    pe = PatchEmbeddingConv(
        patch_size=4,
        in_chans=3,
        embed_dim=192,
        norm_layer=torch.nn.LayerNorm,
        pad_if_needed=True,
        return_tokens=True,)

    x_map, (Hp, Wp), x_tok, (pad_h, pad_w) = pe(x)

    assert x_map.shape == (B, Hp, Wp, 192)
    assert x_tok.shape == (B, Hp * Wp, 192)
    assert (pad_h, pad_w) == (0, 0)
    assert (Hp, Wp) == (H // 4, W // 4)

    print("[OK] PatchEmbeddingConv divisible:",
          "x_map", tuple(x_map.shape),
          "| x_tok", tuple(x_tok.shape),
          "| pad", (pad_h, pad_w))

    # tamaño NO divisible (65x63 con patch=4) -> debería paddear
    H2, W2 = 65, 63
    x2 = torch.randn(B, C, H2, W2)

    x_map2, (Hp2, Wp2), x_tok2, (pad_h2, pad_w2) = pe(x2)

    assert (H2 + pad_h2) % 4 == 0
    assert (W2 + pad_w2) % 4 == 0
    assert x_map2.shape == (B, Hp2, Wp2, 192)
    assert x_tok2.shape == (B, Hp2 * Wp2, 192)

    print("[OK] PatchEmbeddingConv non-divisible:",
          "input", (H2, W2),
          "| padded by", (pad_h2, pad_w2),
          "| patches", (Hp2, Wp2),
          "| x_map", tuple(x_map2.shape))



class OutlookAttention(nn.Module):
    """
    Outlook Attention (VOLO): agregación local dinámica sobre ventanas.

    Entrada:  x_map [B, H, W, C]  (channel-last)
    Salida:   y_map [B, H, W, C]

    Parámetros:
      - dim: canales C
      - kernel_size: k (vecindario k×k)
      - stride: s (si s>1 hace downsample tipo "outlook pooling"; para CIFAR típicamente s=1)
      - num_heads: h (partimos canales en cabezas, como MHSA)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 6,
        kernel_size: int = 3,
        stride: int = 1,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,):

        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.kernel_size = kernel_size
        self.stride = stride

        # Genera pesos de atención por posición: [B, H, W, heads * k*k]
        self.attn = nn.Linear(dim, num_heads * kernel_size * kernel_size, bias=True)

        # Proyección para values (antes de unfold)
        self.v = nn.Linear(dim, dim, bias=True)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=True)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x_map: torch.Tensor) -> torch.Tensor:
        """
        x_map: [B, H, W, C]
        """
        B, H, W, C = x_map.shape
        k = self.kernel_size
        s = self.stride
        heads = self.num_heads
        hd = self.head_dim

        # attention weights
        a = self.attn(x_map)
        # si stride>1, la atención se evalúa en posiciones downsampled
        if s > 1:
            # downsample espacialmente (simple avg pool sobre channel-last)
            a = a.permute(0, 3, 1, 2)                       # [B, heads*k*k, H, W]
            a = F.avg_pool2d(a, kernel_size=s, stride=s)    # [B, heads*k*k, Hs, Ws]
            a = a.permute(0, 2, 3, 1).contiguous()          # [B, Hs, Ws, heads*k*k]

        Hs, Ws = a.shape[1], a.shape[2]
        a = a.view(B, Hs * Ws, heads, k * k)
        a = F.softmax(a, dim=-1)
        a = self.attn_drop(a)

        # values map
        v = self.v(x_map)
        v = v.permute(0, 3, 1, 2).contiguous()

        # unfold extrae vecindarios k×k para cada posición (con padding para "same")
        pad = k // 2
        v_unf = F.unfold(v, kernel_size=k, padding=pad, stride=s)
        v_unf = v_unf.view(B, C, k * k, Hs * Ws).permute(0, 3, 1, 2).contiguous()
        v_unf = v_unf.view(B, Hs * Ws, heads, hd, k * k)

        # apply attention: weighted sum over neighborhood
        # a:     [B, Hs*Ws, heads, k*k]
        # v_unf: [B, Hs*Ws, heads, hd, k*k]
        y = (v_unf * a.unsqueeze(3)).sum(dim=-1)
        y = y.reshape(B, Hs * Ws, C)              # concat heads

        # fold back to spatial map
        y_map = y.view(B, Hs, Ws, C)

        y_map = self.proj(y_map)
        y_map = self.proj_drop(y_map)
        return y_map

def test_outlook_attention_stride1():
    torch.manual_seed(0)

    B, H, W, C = 2, 16, 16, 192
    x_map = torch.randn(B, H, W, C, requires_grad=True)

    oa = OutlookAttention(
        dim=C,
        num_heads=6,
        kernel_size=3,
        stride=1,
        attn_drop=0.0,
        proj_drop=0.0)

    y = oa(x_map)
    assert y.shape == x_map.shape, f"Expected {x_map.shape}, got {y.shape}"

    loss = y.mean()
    loss.backward()

    assert x_map.grad is not None, "No gradient flowed to input!"
    assert torch.isfinite(x_map.grad).all(), "Non-finite grads!"

    print("[OK] OutlookAttention stride=1:",
          "in", tuple(x_map.shape),
          "| out", tuple(y.shape),
          "| grad mean", float(x_map.grad.abs().mean()))



def test_outlook_attention_stride2():
    torch.manual_seed(0)

    B, H, W, C = 2, 16, 16, 192
    x_map = torch.randn(B, H, W, C, requires_grad=True)

    oa = OutlookAttention(
        dim=C,
        num_heads=6,
        kernel_size=3,
        stride=2,
        attn_drop=0.0,
        proj_drop=0.0)

    y = oa(x_map)

    assert y.shape[0] == B and y.shape[-1] == C
    assert y.shape[1] == H // 2 and y.shape[2] == W // 2, f"Got {y.shape[1:3]}"

    loss = y.mean()
    loss.backward()
    assert x_map.grad is not None
    assert torch.isfinite(x_map.grad).all()

    print("[OK] OutlookAttention stride=2:",
          "in", (B, H, W, C),
          "| out", tuple(y.shape))



class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x

        keep_prob = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, drop: float = 0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class OutlookerBlock(nn.Module):
    """
    Bloque VOLO Outlooker:
      x -> LN -> OutlookAttention -> DropPath + residual
        -> LN -> MLP -> DropPath + residual

    Input/Output: [B, H, W, C]
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        kernel_size: int = 3,
        stride: int = 1,
        mlp_ratio: float = 4.0,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        drop_path: float = 0.0,
        mlp_drop: float = 0.0):

        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        self.attn = OutlookAttention(
            dim=dim,
            num_heads=num_heads,
            kernel_size=kernel_size,
            stride=stride,
            attn_drop=attn_drop,
            proj_drop=proj_drop,)

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)

        self.mlp = MLP(dim=dim, hidden_dim=hidden_dim, drop=mlp_drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x_map: torch.Tensor) -> torch.Tensor:
        """
        x_map: tensor de forma (B, C, H, W) o (B, N, C), según el bloque.
        """

        # Primer sub-bloque: Norm -> Atención -> DropPath -> Residual

        # Normalización del input
        x_norm_1 = self.norm1(x_map)

        # Atención
        attn_out = self.attn(x_norm_1)
        attn_out = self.drop_path1(attn_out)

        # Suma residual
        x_map = x_map + attn_out

        # Segundo sub-bloque: Norm -> MLP -> DropPath -> Residual ---

        x_norm_2 = self.norm2(x_map)

        # MLP
        mlp_out = self.mlp(x_norm_2)
        mlp_out = self.drop_path2(mlp_out)

        # Segunda suma residual
        x_out = x_map + mlp_out

        return x_out

def test_outlooker_block():
    torch.manual_seed(0)

    B, H, W, C = 2, 16, 16, 192
    x_map = torch.randn(B, H, W, C, requires_grad=True)

    blk = OutlookerBlock(
        dim=C,
        num_heads=6,
        kernel_size=3,
        stride=1,
        mlp_ratio=4.0,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        mlp_drop=0.0,)

    y = blk(x_map)
    assert y.shape == x_map.shape

    y.mean().backward()
    assert x_map.grad is not None
    assert torch.isfinite(x_map.grad).all()

    print("[OK] OutlookerBlock:",
          "in/out", tuple(y.shape),
          "| grad mean", float(x_map.grad.abs().mean()))


def test_embed_then_outlook(img_size=64, patch_size=4, dim=192, heads=6):
    torch.manual_seed(0)

    B = 2
    x = torch.randn(B, 3, img_size, img_size, requires_grad=True)

    pe = PatchEmbeddingConv(
        patch_size=patch_size,
        in_chans=3,
        embed_dim=dim,
        norm_layer=torch.nn.LayerNorm,
        pad_if_needed=True,
        return_tokens=True,)

    blk = OutlookerBlock(
        dim=dim,
        num_heads=heads,
        kernel_size=3,
        stride=1,
        mlp_ratio=4.0,
        drop_path=0.0,)

    x_map, (Hp, Wp), x_tok, pad_hw = pe(x)
    y_map = blk(x_map)

    assert y_map.shape == x_map.shape == (B, Hp, Wp, dim)

    # grad
    y_map.mean().backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()

    print("[OK] Embed->Outlook:",
          "img", (img_size, img_size),
          "| patches", (Hp, Wp),
          "| map", tuple(y_map.shape),
          "| pad", pad_hw)



class VOLOStage(nn.Module):
    """
    Un stage VOLO basado en OutlookerBlocks.

    Mantiene el formato channel-last:
      Input:  [B, H, W, C]
      Output: [B, H, W, C]  (si stride=1)
    Si quisieras un stage que haga downsample, usa stride>1 en los bloques
    (pero en CIFAR te recomiendo stride=1 en el stage inicial).
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        kernel_size: int = 3,
        stride: int = 1,
        mlp_ratio: float = 4.0,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        drop_path: float | list[float] = 0.0,
        mlp_drop: float = 0.0,):

        super().__init__()

        if isinstance(drop_path, float):
            dpr = [drop_path] * depth
        else:
            assert len(drop_path) == depth, "drop_path list must have length=depth"
            dpr = drop_path

        self.blocks = nn.ModuleList([
            OutlookerBlock(
                dim=dim,
                num_heads=num_heads,
                kernel_size=kernel_size,
                stride=stride,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                drop_path=dpr[i],
                mlp_drop=mlp_drop,) for i in range(depth)])

    def forward(self, x_map: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            x_map = blk(x_map)
        return x_map

def test_volo_stage():
    torch.manual_seed(0)

    B, H, W, C = 2, 16, 16, 192
    x = torch.randn(B, H, W, C, requires_grad=True)

    stage = VOLOStage(
        dim=C,
        depth=3,
        num_heads=6,
        kernel_size=3,
        stride=1,
        drop_path=[0.0, 0.05, 0.1])

    y = stage(x)
    assert y.shape == x.shape
    y.mean().backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()

    print("[OK] VOLOStage:", tuple(y.shape), "| grad mean", float(x.grad.abs().mean()))



"""## Attention"""

def scaled_dot_product_attention(q, k, v, mask=None, attn_dropout_p: float = 0.0, training: bool = True):
    """
    q: (B, H, Lq, d)
    k: (B, H, Lk, d)
    v: (B, H, Lk, d)
    mask: broadcastable a (B, H, Lq, Lk)
          - bool: True = BLOQUEAR (poner -inf)
          - float: 1.0 = permitir, 0.0 = bloquear
    """
    scores = torch.matmul(q, k.transpose(-2, -1))
    dk = q.size(-1)
    scores = scores / (dk ** 0.5)

    if mask is not None:
        if mask.dtype == torch.bool:
            scores = scores.masked_fill(mask, float("-inf"))
        else:
            scores = scores.masked_fill(mask <= 0, float("-inf"))

    attn = F.softmax(scores, dim=-1)
    if attn_dropout_p > 0.0:
        attn = F.dropout(attn, p=attn_dropout_p, training=training)

    output = torch.matmul(attn, v)
    return output, attn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model debe ser múltiplo de num_heads"

        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.d_model = d_model

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        # "dropout" lo usaremos como dropout de atención (sobre attn)
        self.attn_dropout_p = dropout
        # y también dejamos dropout de salida si quieres (común en ViT)
        self.out_dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        B, L, _ = x.shape
        return x.view(B, L, self.num_heads, self.d_head).transpose(1, 2)

    def _combine_heads(self, x):
        B, H, L, D = x.shape
        return x.transpose(1, 2).contiguous().view(B, L, H * D)

    def forward(self, x_q, x_kv, mask=None):
        q = self._split_heads(self.w_q(x_q))
        k = self._split_heads(self.w_k(x_kv))
        v = self._split_heads(self.w_v(x_kv))

        if mask is not None:
            if mask.dim() == 2:
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                mask = mask[:, None, :, :]
            elif mask.dim() == 4:
                pass
            else:
                raise ValueError(f"Máscara con dims no soportadas: {mask.shape}")

            if mask.dtype != torch.bool:
                mask = (mask <= 0)

        attn_out, _ = scaled_dot_product_attention(
            q, k, v,
            mask=mask,
            attn_dropout_p=self.attn_dropout_p,
            training=self.training)

        attn_out = self._combine_heads(attn_out)

        attn_out = self.w_o(attn_out)
        attn_out = self.out_dropout(attn_out)
        return attn_out

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """
    Bloque encoder para ViT (pre-norm):
    x -> LN -> MHA -> DropPath -> +residual
       -> LN -> MLP -> DropPath -> +residual
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_dropout: float = 0.0,
        dropout: float = 0.1,
        drop_path: float = 0.0):

        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(d_model=dim, num_heads=num_heads, dropout=attn_dropout)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = FeedForward(dim, hidden_dim, dropout=dropout)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.attn(self.norm1(x), self.norm1(x), mask=None))
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x

class TransformerStack(nn.Module):
    """Stack simple de TransformerBlock sobre tokens [B, N, C]."""
    def __init__(self, dim: int, depth: int, num_heads: int, mlp_ratio=4.0,
                 attn_dropout=0.0, dropout=0.1, drop_path: float | list[float] = 0.0):
        super().__init__()
        if isinstance(drop_path, float):
            dpr = [drop_path] * depth
        else:
            assert len(drop_path) == depth
            dpr = drop_path

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attn_dropout=attn_dropout,
                dropout=dropout,
                drop_path=dpr[i] if "drop_path" in TransformerBlock.__init__.__code__.co_varnames else 0.0) for i in range(depth)])

    def forward(self, x_tok: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            x_tok = blk(x_tok)
        return x_tok

def test_transformer_block():
    torch.manual_seed(0)
    B, N, C = 2, 256, 192
    x = torch.randn(B, N, C, requires_grad=True)

    blk = TransformerBlock(dim=C, num_heads=6, mlp_ratio=4.0, attn_dropout=0.0, dropout=0.1, drop_path=0.0)
    y = blk(x)
    assert y.shape == x.shape
    y.mean().backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()
    print("[OK] TransformerBlock:", tuple(y.shape), "grad", float(x.grad.abs().mean()))


"""# Hiratical"""

class MapDownsample(nn.Module):
    """
    Downsample para mapas channel-last: [B, H, W, C_in] -> [B, H/2, W/2, C_out]
    usando conv2d stride=2 en formato channel-first internamente.
    """
    def __init__(self, dim_in: int, dim_out: int, kernel_size: int = 3, norm_layer=nn.LayerNorm):
        super().__init__()
        pad = kernel_size // 2
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=2, padding=pad, bias=True)
        self.norm = norm_layer(dim_out) if norm_layer is not None else None

    def forward(self, x_map: torch.Tensor):
        # x_map: [B, H, W, C_in]
        B, H, W, C = x_map.shape
        x = x_map.permute(0, 3, 1, 2).contiguous()     # [B, C, H, W]
        x = self.conv(x)                               # [B, C_out, H2, W2]
        x_map = x.permute(0, 2, 3, 1).contiguous()     # [B, H2, W2, C_out]
        if self.norm is not None:
            x_map = self.norm(x_map)
        return x_map

class PoolingLayer(nn.Module):
    """
    Pooling jerárquico para ViT:

    - Toma tokens [B, N, D_in] + grid_size (H, W)
    - Los reinterpreta como feature map [B, D_in, H, W]
    - Aplica:
        depthwise conv (3x3, stride=2, padding=1)
        pointwise conv (1x1) para cambiar D_in -> D_out
    - Devuelve:
        tokens [B, N_out, D_out] y nuevo grid_size (H_out, W_out)
    """

    def __init__(self,
        dim_in: int,
        dim_out: int,
        kernel_size: int = 3,
        stride: int = 2,
        norm_layer: type[nn.Module] | None = nn.LayerNorm):

        super().__init__()
        padding = kernel_size // 2

        # Depthwise conv: cada canal se filtra por separado
        self.depthwise_conv = nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=dim_in)

        # Pointwise conv: mezcla canales y cambia dim
        self.pointwise_conv = nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=1,
            stride=1,
            padding=0)

        self.norm = norm_layer(dim_out) if norm_layer is not None else None

        self.dim_in = dim_in
        self.dim_out = dim_out
        self.stride = stride

    def forward(self, x: torch.Tensor, grid_size: tuple[int, int]):
        """
        Args:
            x: tokens [B, N, D_in]
            grid_size: (H, W) tal que H*W = N

        Returns:
            x_out: tokens [B, N_out, D_out]
            new_grid: (H_out, W_out)
        """
        B, N, D_in = x.shape
        H, W = grid_size

        assert D_in == self.dim_in, f"dim_in {D_in} != {self.dim_in}"
        assert H * W == N, f"H*W={H*W} no coincide con N={N}"

        # [B, N, D_in] -> [B, D_in, H, W]
        x = x.view(B, H, W, D_in).permute(0, 3, 1, 2)

        # Depthwise + pointwise
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)

        B, D_out, H_out, W_out = x.shape
        N_out = H_out * W_out

        # Volver a tokens: [B, D_out, H_out, W_out] -> [B, N_out, D_out]
        x = x.flatten(2).transpose(1, 2)

        if self.norm is not None:
            x = self.norm(x)

        new_grid = (H_out, W_out)
        return x, new_grid

"""# VOLO BackBone"""

def map_to_tokens(x_map: torch.Tensor) -> torch.Tensor:
    B, H, W, C = x_map.shape
    return x_map.view(B, H * W, C)

def tokens_to_map(x_tok: torch.Tensor, H: int, W: int) -> torch.Tensor:
    B, N, C = x_tok.shape
    assert N == H * W
    return x_tok.view(B, H, W, C)

class VOLOPyramid(nn.Module):
    """
    Backbone jerárquico para VOLO (sin classifier head aún).
    - Local: VOLOStage (Outlooker)
    - Global: TransformerStack (opcional)
    - Downsample: map-space (recomendado) o token-space (PoolingLayer tuyo)
    """
    def __init__(
        self,
        dims: tuple[int, ...],                 # ej (192, 256, 384)
        outlooker_depths: tuple[int, ...],     # ej (4, 2, 0)  (0 si no hay outlooker en ese nivel)
        outlooker_heads: tuple[int, ...],      # ej (6, 8, 12)
        transformer_depths: tuple[int, ...],   # ej (0, 4, 6)
        transformer_heads: tuple[int, ...],    # ej (6, 8, 12)
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        downsample_kind: str = "map",          # "map" o "token"
        drop_path_rate: float = 0.0):

        super().__init__()
        L = len(dims)

        assert len(outlooker_depths) == L
        assert len(outlooker_heads) == L
        assert len(transformer_depths) == L
        assert len(transformer_heads) == L

        # schedule lineal de droppath a través de todos los bloques (local+global)
        total_blocks = sum(outlooker_depths) + sum(transformer_depths)
        dpr = torch.linspace(0, drop_path_rate, total_blocks).tolist() if total_blocks > 0 else []
        dp_i = 0

        self.levels = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        self.downsample_kind = downsample_kind

        for i in range(L):
            dim = dims[i]

            # Local stage (Outlooker)
            local = None
            if outlooker_depths[i] > 0:
                local_dpr = dpr[dp_i: dp_i + outlooker_depths[i]]
                dp_i += outlooker_depths[i]
                local = VOLOStage(
                    dim=dim,
                    depth=outlooker_depths[i],
                    num_heads=outlooker_heads[i],
                    kernel_size=kernel_size,
                    stride=1,
                    mlp_ratio=mlp_ratio,
                    drop_path=local_dpr)

            # Global stage (Transformer)
            global_ = None
            if transformer_depths[i] > 0:
                glob_dpr = dpr[dp_i: dp_i + transformer_depths[i]]
                dp_i += transformer_depths[i]

                global_ = TransformerStack(
                    dim=dim,
                    depth=transformer_depths[i],
                    num_heads=transformer_heads[i],
                    mlp_ratio=mlp_ratio,
                    attn_dropout=0.0,
                    dropout=0.1,
                    drop_path=glob_dpr,)

            self.levels.append(nn.ModuleDict({"local": local, "global": global_}))

            # Downsample para pasar dim_i -> dim_{i+1} (si no es el último nivel)
            if i < L - 1:
                if downsample_kind == "map":
                    self.downsamples.append(MapDownsample(dim_in=dim, dim_out=dims[i + 1], kernel_size=3))
                elif downsample_kind == "token":
                    # reusar PoolingLayer
                    self.downsamples.append(PoolingLayer(dim_in=dim, dim_out=dims[i + 1], kernel_size=3, stride=2))
                else:
                    raise ValueError(f"downsample_kind must be 'map' or 'token'. Got {downsample_kind}")

        assert dp_i == total_blocks

    def forward(self, x_map: torch.Tensor):
        """
        x_map: [B, H, W, C0]
        returns:
          x_final_tokens: [B, N_last, C_last]
          last_grid: (H_last, W_last)
        """
        B, H, W, C = x_map.shape

        for i, lvl in enumerate(self.levels):
            # local stage en map
            if lvl["local"] is not None:
                x_map = lvl["local"](x_map)

            # global stage en tokens (si existe)
            if lvl["global"] is not None:
                x_tok = map_to_tokens(x_map)
                x_tok = lvl["global"](x_tok)
                x_map = tokens_to_map(x_tok, H, W)

            # downsample (si aplica)
            if i < len(self.downsamples):
                ds = self.downsamples[i]
                if self.downsample_kind == "map":
                    x_map = ds(x_map)
                    H, W = x_map.shape[1], x_map.shape[2]
                else:
                    # token downsample: necesita grid
                    x_tok = map_to_tokens(x_map)
                    x_tok, (H, W) = ds(x_tok, (H, W))
                    x_map = tokens_to_map(x_tok, H, W)

        x_final = map_to_tokens(x_map)
        return x_final, (H, W)


def test_volo_pyramid_map():
    torch.manual_seed(0)
    B = 2
    H = W = 16
    x_map = torch.randn(B, H, W, 192)

    pyr = VOLOPyramid(
        dims=(192, 256, 384),
        outlooker_depths=(2, 2, 0),
        outlooker_heads=(6, 8, 12),
        transformer_depths=(0, 2, 2),
        transformer_heads=(6, 8, 12),
        downsample_kind="map",
        drop_path_rate=0.1,)

    x_tok, (Hf, Wf) = pyr(x_map)
    print("[OK] Pyramid-map:", x_tok.shape, "grid", (Hf, Wf))
    assert x_tok.shape[0] == B
    assert x_tok.shape[2] == 384
    assert Hf * Wf == x_tok.shape[1]



def test_volo_pyramid_token():
    torch.manual_seed(0)
    B = 2
    H = W = 16
    x_map = torch.randn(B, H, W, 192)

    pyr = VOLOPyramid(
        dims=(192, 256, 384),
        outlooker_depths=(2, 2, 0),
        outlooker_heads=(6, 8, 12),
        transformer_depths=(0, 2, 2),
        transformer_heads=(6, 8, 12),
        downsample_kind="token",
        drop_path_rate=0.1)

    x_tok, (Hf, Wf) = pyr(x_map)
    print("[OK] Pyramid-token:", x_tok.shape, "grid", (Hf, Wf))
    assert x_tok.shape[2] == 384
    assert Hf * Wf == x_tok.shape[1]



class ClassAttention(nn.Module):
    """
    Class Attention: sólo el CLS atiende al conjunto [CLS | tokens].
    Inputs:
      cls:    [B, 1, C]
      tokens: [B, N, C]
    Output:
      cls_out: [B, 1, C]
    """
    def __init__(self, dim: int, num_heads: int, attn_dropout: float = 0.0, proj_dropout: float = 0.0):
        super().__init__()
        self.attn = MultiHeadAttention(d_model=dim, num_heads=num_heads, dropout=attn_dropout)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(self, cls: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        kv = torch.cat([cls, tokens], dim=1)        # [B, 1+N, C]
        cls_out = self.attn(cls, kv, mask=None) # [B, 1, C] (solo CLS sale actualizado)
        return self.proj_drop(cls_out)


class ClassAttentionBlock(nn.Module):
    """
    Pre-norm (CaiT-style):
      cls -> LN -> ClassAttn(cls, [cls|tokens]) -> +res
          -> LN -> MLP -> +res
    Nota: tokens NO se actualizan.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_dropout: float = 0.0,
        dropout: float = 0.0):

        super().__init__()
        self.norm_cls = nn.LayerNorm(dim)
        self.norm_tok = nn.LayerNorm(dim)
        self.ca = ClassAttention(dim, num_heads, attn_dropout=attn_dropout, proj_dropout=dropout)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, int(dim * mlp_ratio), dropout=dropout)

    def forward(self, cls: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        # Class attention update (solo CLS)
        cls_norm = self.norm_cls(cls)
        tok_norm = self.norm_tok(tokens)
        cls = cls + self.ca(cls_norm, tok_norm)

        # MLP update (solo CLS)
        cls = cls + self.mlp(self.norm2(cls))
        return cls

class CLIPool(nn.Module):
    """
    "CLI" style pooling: mezcla aprendible entre CLS y mean(tokens).
      z = alpha * cls + (1-alpha) * mean
    """
    def __init__(self, init_alpha: float = 0.5):
        super().__init__()
        # parametriza alpha en logits para mantenerlo en (0,1)
        init_alpha = float(init_alpha)
        init_alpha = min(max(init_alpha, 1e-4), 1 - 1e-4)
        logit = math.log(init_alpha / (1 - init_alpha))
        self.alpha_logit = nn.Parameter(torch.tensor([logit], dtype=torch.float32))

    def forward(self, cls_vec: torch.Tensor, tok_mean: torch.Tensor) -> torch.Tensor:
        """
        cls_vec:  [B, C]
        tok_mean: [B, C]
        """
        alpha = torch.sigmoid(self.alpha_logit)  # scalar in (0,1)
        return alpha * cls_vec + (1.0 - alpha) * tok_mean

"""# VOLO"""

import math

def trunc_normal_(tensor, mean=0., std=1.):
    with torch.no_grad():
        return tensor.normal_(mean=mean, std=std)

class PosEmbed2D(nn.Module):
    """
    Positional embedding aprendible para grilla (H, W) en tokens.

    Guarda [1, H*W, C]. Si en forward llega otro (H,W), interpola.
    """
    def __init__(self, H: int, W: int, dim: int):
        super().__init__()
        self.H0 = H
        self.W0 = W
        self.dim = dim
        self.pos = nn.Parameter(torch.zeros(1, H * W, dim))
        trunc_normal_(self.pos, std=0.02)

    def forward(self, x_tok: torch.Tensor, grid: tuple[int, int]):
        """
        x_tok: [B, N, C]
        grid: (H, W)
        """
        B, N, C = x_tok.shape
        H, W = grid
        if (H == self.H0) and (W == self.W0):
            return x_tok + self.pos

        # Interpola pos emb como mapa [1, C, H, W] -> nuevo tamaño
        pos = self.pos.reshape(1, self.H0, self.W0, self.dim).permute(0, 3, 1, 2)  # [1,C,H0,W0]
        pos = nn.functional.interpolate(pos, size=(H, W), mode="bicubic", align_corners=False)
        pos = pos.permute(0, 2, 3, 1).reshape(1, H * W, self.dim)
        return x_tok + pos

class VOLOClassifier(nn.Module):
    """
    VOLO para CIFAR-100 (y similares), con dos modos:
      - flat: OutlookerStage -> TransformerStack (sin downsample)
              pooling: mean | cls | cli (cls via class-attn final)
      - hierarchical: pirámide con downsample (map o token)
              pooling: SOLO mean (por ahora)

    Flujo base:
      x [B,3,H,W]
        -> PatchEmbeddingConv -> x_tok [B, N, C0]
        -> pos emb (opcional)
        -> backbone (flat o pyramid)
        -> pooling
        -> head
    """

    def __init__(
        self,
        num_classes: int = 100,
        img_size: int = 32,
        in_chans: int = 3,
        patch_size: int = 4,

        # mode
        hierarchical: bool = False,
        downsample_kind: str = "map",   # si hierarchical=True: "map" o "token"

        # dims / depths (flat)
        embed_dim: int = 192,
        outlooker_depth: int = 4,
        outlooker_heads: int = 6,
        transformer_depth: int = 6,
        transformer_heads: int = 6,

        # hierarchical configs (si hierarchical=True)
        dims: tuple[int, ...] = (192, 256, 384),
        outlooker_depths: tuple[int, ...] = (2, 2, 0),
        outlooker_heads_list: tuple[int, ...] = (6, 8, 12),
        transformer_depths: tuple[int, ...] = (0, 2, 2),
        transformer_heads_list: tuple[int, ...] = (6, 8, 12),

        # block hyperparams
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        attn_dropout: float = 0.0,
        drop_path_rate: float = 0.0,

        # head / pooling
        pooling: str = "mean",          # flat: "mean"|"cls"|"cli" ; hierarchical: "mean"
        use_pos_embed: bool = True,

        # cls refinamiento (flat)
        cls_attn_depth: int = 2,        # # capas ClassAttentionBlock
        cli_init_alpha: float = 0.5,    # init alpha para pooling="cli"
        use_cls_pos: bool = True):

        super().__init__()

        self.hierarchical = hierarchical
        self.use_pos_embed = use_pos_embed

        if self.hierarchical:
            assert pooling == "mean", "Por ahora hierarchical solo soporta pooling='mean'."
        else:
            assert pooling in ["mean", "cls", "cli"], "pooling en flat debe ser 'mean', 'cls' o 'cli'."
        self.pooling = pooling

        # ---- Patch Embedding ----
        C0 = (dims[0] if hierarchical else embed_dim)

        self.patch_embed = PatchEmbeddingConv(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=C0,
            norm_layer=nn.LayerNorm,
            pad_if_needed=True,
            return_tokens=True,)

        Hp0 = math.ceil(img_size / patch_size)
        Wp0 = math.ceil(img_size / patch_size)

        self.pos_embed = PosEmbed2D(Hp0, Wp0, C0) if use_pos_embed else None
        self.pos_drop = nn.Dropout(dropout)

        # ---- Backbone ----
        if not hierarchical:
            total = outlooker_depth + transformer_depth
            dpr = torch.linspace(0, drop_path_rate, total).tolist() if total > 0 else []
            dpr_local = dpr[:outlooker_depth]
            dpr_glob = dpr[outlooker_depth:]

            self.local_stage = VOLOStage(
                dim=embed_dim,
                depth=outlooker_depth,
                num_heads=outlooker_heads,
                kernel_size=kernel_size,
                stride=1,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_dropout,
                proj_drop=dropout,
                drop_path=dpr_local if len(dpr_local) else 0.0,
                mlp_drop=dropout)

            self.global_blocks = nn.ModuleList([
                TransformerBlock(
                    dim=embed_dim,
                    num_heads=transformer_heads,
                    mlp_ratio=mlp_ratio,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    drop_path=(dpr_glob[i] if len(dpr_glob) else 0.0),
                ) for i in range(transformer_depth)])

            # --- CLS  (solo si pooling usa cls/cli) ---
            self.use_cls = (pooling in ["cls", "cli"])
            if self.use_cls:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
                trunc_normal_(self.cls_token, std=0.02)

                self.cls_pos = None
                if use_cls_pos:
                    self.cls_pos = nn.Parameter(torch.zeros(1, 1, embed_dim))
                    trunc_normal_(self.cls_pos, std=0.02)

                self.cls_attn_blocks = nn.ModuleList([
                    ClassAttentionBlock(
                        dim=embed_dim,
                        num_heads=transformer_heads,
                        mlp_ratio=mlp_ratio,
                        attn_dropout=attn_dropout,
                        dropout=dropout,) for _ in range(int(cls_attn_depth))])

                self.cli_pool = CLIPool(init_alpha=cli_init_alpha) if pooling == "cli" else None
            else:
                self.cls_token = None
                self.cls_pos = None
                self.cls_attn_blocks = None
                self.cli_pool = None


            self.norm = nn.LayerNorm(embed_dim)
            self.norm_feat = nn.LayerNorm(embed_dim)

            self.head = nn.Linear(embed_dim, num_classes)

        else:
            self.pyramid = VOLOPyramid(
                dims=dims,
                outlooker_depths=outlooker_depths,
                outlooker_heads=outlooker_heads_list,
                transformer_depths=transformer_depths,
                transformer_heads=transformer_heads_list,
                kernel_size=kernel_size,
                mlp_ratio=mlp_ratio,
                downsample_kind=downsample_kind,
                drop_path_rate=drop_path_rate,)


            self.norm = nn.LayerNorm(dims[-1])
            self.norm_feat = nn.LayerNorm(dims[-1])
            self.head = nn.Linear(dims[-1], num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Patch embedding
        x_map, (Hp, Wp), x_tok, _pad = self.patch_embed(x)   # x_tok [B,N,C0]
        B, N, C0 = x_tok.shape

        # Pos emb sobre tokens del grid
        if self.use_pos_embed and (self.pos_embed is not None):
            x_tok = self.pos_embed(x_tok, (Hp, Wp))
        x_tok = self.pos_drop(x_tok)

        if not self.hierarchical:
            # ---- Flat backbone ----
            # Outlooker trabaja en map (sin CLS)
            x_map = x_tok.view(B, Hp, Wp, C0)
            x_map = self.local_stage(x_map)
            x_tok = x_map.view(B, Hp * Wp, C0)  # [B,N,C]

            # Transformer global (tokens sin CLS)
            for blk in self.global_blocks:
                x_tok = blk(x_tok)

            #  Pooling
            if self.pooling == "mean":
                # Normaliza tokens y promedia
                x_tok_n = self.norm(x_tok)           # [B,N,C]
                feat = x_tok_n.mean(dim=1)           # [B,C]
                feat = self.norm_feat(feat)          # [B,C]
                return self.head(feat)

            # CLS refinado con class-attn final (CaiT-style)
            cls = self.cls_token.expand(B, -1, -1)   # [B,1,C]
            if self.cls_pos is not None:
                cls = cls + self.cls_pos

            for cab in self.cls_attn_blocks:
                cls = cab(cls, x_tok)               # [B,1,C]

            cls_vec = cls.squeeze(1)                # [B,C]
            cls_vec = self.norm_feat(cls_vec)

            if self.pooling == "cls":
                feat = cls_vec
                return self.head(feat)

            # pooling == "cli": mezcla CLS con mean(tokens) normalizado
            tok_mean = self.norm(x_tok).mean(dim=1)  # [B,C]
            feat = self.cli_pool(cls_vec, tok_mean)
            feat = self.norm_feat(feat)
            return self.head(feat)

        else:
            # ---- Hierarchical backbone (solo mean) ----
            x_map = x_tok.view(B, Hp, Wp, C0)
            x_last, (Hf, Wf) = self.pyramid(x_map)      # x_last: [B, Nf, C_last]

            x_last = self.norm(x_last)
            feat = x_last.mean(dim=1)
            feat = self.norm_feat(feat)
            return self.head(feat)

def test_volo_classifier_flat():
    torch.manual_seed(0)
    model = VOLOClassifier(
        num_classes=100,
        img_size=64,
        patch_size=4,
        hierarchical=False,
        embed_dim=192,
        outlooker_depth=2,
        transformer_depth=2,
        outlooker_heads=6,
        transformer_heads=6,
        pooling="mean")

    x = torch.randn(2, 3, 64, 64)
    y = model(x)
    print("[OK] flat logits:", y.shape)
    assert y.shape == (2, 100)

def test_volo_classifier_hier():
    torch.manual_seed(0)
    model = VOLOClassifier(
        num_classes=100,
        img_size=64,
        patch_size=4,
        hierarchical=True,
        downsample_kind="map",
        dims=(192, 256, 384),
        outlooker_depths=(2, 2, 0),
        outlooker_heads_list=(6, 8, 12),
        transformer_depths=(0, 2, 2),
        transformer_heads_list=(6, 8, 12),
        pooling="mean",)

    x = torch.randn(2, 3, 64, 64)
    y = model(x)
    print("[OK] hier logits:", y.shape)
    assert y.shape == (2, 100)



def _fmt_out(output):
    if isinstance(output, (tuple, list)):
        shapes = []
        for o in output:
            if hasattr(o, "shape"):
                shapes.append(tuple(o.shape))
            else:
                shapes.append(type(o).__name__)
        return shapes
    if hasattr(output, "shape"):
        return tuple(output.shape)
    return type(output).__name__


def attach_shape_hooks_volo(model: nn.Module, verbose: bool = True):
    hooks = []

    def add_hook(mod: nn.Module, name: str):
        if mod is None:
            return
        def hook(_m, _inp, out):
            print(f"{name:35s} -> {_fmt_out(out)}")
        hooks.append(mod.register_forward_hook(hook))

    # Top-level components
    add_hook(getattr(model, "patch_embed", None), "patch_embed")
    add_hook(getattr(model, "local_stage", None), "local_stage (outlooker)")
    add_hook(getattr(model, "pyramid", None), "pyramid (top)")
    add_hook(getattr(model, "norm", None), "norm")
    add_hook(getattr(model, "head", None), "head")

    # Global blocks (flat)
    if hasattr(model, "global_blocks"):
        for i, blk in enumerate(model.global_blocks):
            add_hook(blk, f"global_block[{i}]")

    # Pyramid internals (hierarchical)
    pyr = getattr(model, "pyramid", None)
    if pyr is not None:
        if hasattr(pyr, "levels"):
            for i, lvl in enumerate(pyr.levels):
                # lvl es nn.ModuleDict: NO tiene .get
                loc = lvl["local"] if "local" in lvl else None
                glob = lvl["global"] if "global" in lvl else None
                add_hook(loc,  f"pyr.level[{i}].local")
                add_hook(glob, f"pyr.level[{i}].global")

        if hasattr(pyr, "downsamples"):
            for i, ds in enumerate(pyr.downsamples):
                add_hook(ds, f"pyr.down[{i}]")

    return hooks

def remove_hooks(hooks):
    for h in hooks:
        h.remove()

@torch.no_grad()
def debug_forward_shapes(model: nn.Module, img_size: int, device: str = "cpu", batch_size: int = 2):
    model = model.to(device).eval()
    hooks = attach_shape_hooks_volo(model)

    x = torch.randn(batch_size, 3, img_size, img_size, device=device)
    print(f"\n=== Forward debug | img_size={img_size} | model={model.__class__.__name__} ===")
    y = model(x)
    print(f"{'OUTPUT logits':35s} -> {tuple(y.shape)}")

    remove_hooks(hooks)



"""---"""

import os, math, random, inspect
from contextlib import contextmanager, nullcontext
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

def seed_everything(seed: int = 0, deterministic: bool = False):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True

_DTYPE_MAP = {
    "bf16": torch.bfloat16, "bfloat16": torch.bfloat16,
    "fp16": torch.float16,  "float16": torch.float16,
    "fp32": torch.float32,  "float32": torch.float32,}

def _cuda_dtype_supported(dtype: torch.dtype) -> bool:
    if not torch.cuda.is_available():
        return False
    return dtype in (torch.float16, torch.bfloat16)

def make_grad_scaler(device: str = "cuda", enabled: bool = True):
    if not enabled:
        return None

    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        try:
            sig = inspect.signature(torch.amp.GradScaler)
            if len(sig.parameters) >= 1:
                return torch.amp.GradScaler(device if device in ("cuda", "cpu") else "cuda")
            return torch.amp.GradScaler()
        except Exception:
            pass

    if hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "GradScaler"):
        return torch.cuda.amp.GradScaler()
    return None


@contextmanager
def autocast_ctx(
    device: str = "cuda",
    enabled: bool = True,
    dtype: str = "fp16",
    cache_enabled: bool = True,):
    """
    Context manager de autocast:
      - cuda: fp16 por defecto (ideal en T4)
      - cpu: bfloat16 si está disponible
    """
    if not enabled:
        with nullcontext():
            yield
        return

    if device == "cuda":
        want = _DTYPE_MAP.get(dtype.lower(), torch.float16)
        use = want if _cuda_dtype_supported(want) else torch.float16
        with torch.amp.autocast(device_type="cuda", dtype=use, cache_enabled=cache_enabled):
            yield
        return

    if device == "cpu":
        try:
            with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled):
                yield
        except Exception:
            with nullcontext():
                yield
        return

    with nullcontext():
        yield

def build_param_groups_no_wd(model: nn.Module, weight_decay: float):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue

        name_l = name.lower()
        # no decay for biases + norms + positional/class tokens
        if (
            name.endswith(".bias")
            or ("norm" in name_l)
            or ("bn" in name_l)
            or ("ln" in name_l)
            or ("pos" in name_l)         # pos_embed / pos
            or ("cls_token" in name_l)
        ):
            no_decay.append(p)
        else:
            decay.append(p)

    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0}]


class WarmupCosineLR:
    """Warmup linear for warmup_steps, then cosine to min_lr. Step-based."""
    def __init__(self, optimizer, total_steps: int, warmup_steps: int, min_lr: float = 0.0):
        self.optimizer = optimizer
        self.total_steps = int(total_steps)
        self.warmup_steps = int(warmup_steps)
        self.min_lr = float(min_lr)
        self.base_lrs = [g["lr"] for g in optimizer.param_groups]
        self.step_num = 0

    def step(self):
        self.step_num += 1
        t = self.step_num

        for i, group in enumerate(self.optimizer.param_groups):
            base = self.base_lrs[i]
            if t <= self.warmup_steps and self.warmup_steps > 0:
                lr = base * (t / self.warmup_steps)
            else:
                tt = min(t, self.total_steps)
                denom = max(1, self.total_steps - self.warmup_steps)
                progress = (tt - self.warmup_steps) / denom
                cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
                lr = self.min_lr + (base - self.min_lr) * cosine
            group["lr"] = lr

    def state_dict(self):
        return {"step_num": self.step_num}

    def load_state_dict(self, d):
        self.step_num = int(d.get("step_num", 0))

def save_checkpoint(
    path: str,
    model,
    optimizer,
    scheduler,
    scaler,
    epoch: int,
    best_top1: float,
    extra: dict | None = None,):

    ckpt = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "scaler": scaler.state_dict() if scaler is not None else None,
        "epoch": epoch,
        "best_top1": best_top1,
        "extra": extra or {},}
    torch.save(ckpt, path)


def load_checkpoint(
    path: str,
    model,
    optimizer=None,
    scheduler=None,
    scaler=None,
    map_location="cpu",
    strict: bool = True,):
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt["model"], strict=strict)

    if optimizer is not None and ckpt.get("optimizer") is not None:
        optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and ckpt.get("scheduler") is not None:
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and ckpt.get("scaler") is not None:
        scaler.load_state_dict(ckpt["scaler"])
    return ckpt

# -------------------------
# Mixup / CutMix + Loss
# -------------------------
def _one_hot(targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    return F.one_hot(targets, num_classes=num_classes).float()


def soft_target_cross_entropy(logits: torch.Tensor, targets_soft: torch.Tensor) -> torch.Tensor:
    logp = F.log_softmax(logits, dim=1)
    return -(targets_soft * logp).sum(dim=1).mean()


def apply_mixup_cutmix(
    images: torch.Tensor,
    targets: torch.Tensor,
    num_classes: int,
    mixup_alpha: float = 0.0,
    cutmix_alpha: float = 0.0,
    prob: float = 1.0,):
    """
    Returns:
      images_aug: [B,3,H,W]
      targets_soft: [B,K]
    """
    if prob <= 0.0 or (mixup_alpha <= 0.0 and cutmix_alpha <= 0.0):
        return images, _one_hot(targets, num_classes)

    if random.random() > prob:
        return images, _one_hot(targets, num_classes)

    use_cutmix = (cutmix_alpha > 0.0) and (mixup_alpha <= 0.0 or random.random() < 0.5)
    B, _, H, W = images.shape
    perm = torch.randperm(B, device=images.device)

    y1 = _one_hot(targets, num_classes)
    y2 = _one_hot(targets[perm], num_classes)

    if use_cutmix:
        lam = torch.distributions.Beta(cutmix_alpha, cutmix_alpha).sample().item()
        cut_w = int(W * math.sqrt(1.0 - lam))
        cut_h = int(H * math.sqrt(1.0 - lam))
        cx = random.randint(0, W - 1)
        cy = random.randint(0, H - 1)

        x1 = max(cx - cut_w // 2, 0)
        x2 = min(cx + cut_w // 2, W)
        y1b = max(cy - cut_h // 2, 0)
        y2b = min(cy + cut_h // 2, H)

        images_aug = images.clone()
        images_aug[:, :, y1b:y2b, x1:x2] = images[perm, :, y1b:y2b, x1:x2]

        # adjust lambda based on actual area swapped
        area = (x2 - x1) * (y2b - y1b)
        lam = 1.0 - area / float(W * H)
    else:
        lam = torch.distributions.Beta(mixup_alpha, mixup_alpha).sample().item()
        images_aug = images * lam + images[perm] * (1.0 - lam)

    targets_soft = y1 * lam + y2 * (1.0 - lam)
    return images_aug, targets_soft

# -------------------------
# Metrics
# -------------------------
@torch.no_grad()
def accuracy_topk(logits: torch.Tensor, targets: torch.Tensor, ks=(1, 3, 5)) -> Dict[int, float]:
    """
    targets can be:
      - int64 class indices [B]
      - soft targets [B, num_classes] (we'll argmax for accuracy reporting)
    """
    if targets.ndim == 2:
        targets = targets.argmax(dim=1)

    max_k = max(ks)
    B = targets.size(0)
    _, pred = torch.topk(logits, k=max_k, dim=1)
    correct = pred.eq(targets.view(-1, 1).expand_as(pred))
    out = {}
    for k in ks:
        out[k] = 100.0 * correct[:, :k].any(dim=1).float().sum().item() / B
    return out

import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

def ddp_is_on() -> bool:
    return dist.is_available() and dist.is_initialized()

def ddp_rank() -> int:
    return dist.get_rank() if ddp_is_on() else 0

def is_main_process() -> bool:
    return (not ddp_is_on()) or ddp_rank() == 0

def ddp_sum_(tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce SUM in-place and return tensor."""
    if ddp_is_on():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor

def ddp_broadcast_bool(flag: bool, device: torch.device | str) -> bool:
    """Broadcast a stop flag from rank0 to all ranks."""
    t = torch.tensor([1 if flag else 0], device=device)
    if ddp_is_on():
        dist.broadcast(t, src=0)
    return bool(t.item())

from typing import Optional
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_one_epoch(
    model: nn.Module,
    dataloader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: str = "cuda",
    scaler=None,
    autocast_dtype: str = "fp16",
    use_amp: bool = True,
    grad_clip_norm: Optional[float] = 1.0,
    label_smoothing: float = 0.1,
    mixup_alpha: float = 0.0,
    cutmix_alpha: float = 0.0,
    mix_prob: float = 1.0,
    num_classes: int = 100,
    channels_last: bool = False,
    print_every: int = 100,
):
    model.train()

    use_scaler = (scaler is not None) and use_amp and autocast_dtype.lower() in ("fp16", "float16")

    running_loss = 0.0
    total = 0
    c1 = c3 = c5 = 0.0  # “correct counts” acumulados (no %)

    t0 = time.time()
    for step, (images, targets) in enumerate(dataloader, start=1):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        B = targets.size(0)

        images_aug, targets_soft = apply_mixup_cutmix(
            images, targets,
            num_classes=num_classes,
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            prob=mix_prob
        )

        use_mix = (mixup_alpha > 0.0) or (cutmix_alpha > 0.0)
        targets_for_acc = targets_soft.argmax(dim=1)

        optimizer.zero_grad(set_to_none=True)
        
        with autocast_ctx(device=device, enabled=use_amp, dtype=autocast_dtype, cache_enabled=True):
            logits = model(images_aug)

            assert logits.ndim == 2 and logits.size(1) == num_classes, f"logits shape {logits.shape} != [B,{num_classes}]"
            tmin = int(targets.min().item()); tmax = int(targets.max().item())
            assert 0 <= tmin and tmax < num_classes, f"targets out of range: min={tmin}, max={tmax}, K={num_classes}"
            
        if use_mix:
            loss = soft_target_cross_entropy(logits.float(), targets_soft)
        else:
            loss = F.cross_entropy(logits.float(), targets, label_smoothing=label_smoothing)

        if use_scaler:
            scaler.scale(loss).backward()
            if grad_clip_norm is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # métricas “locales”
        running_loss += loss.item() * B
        total += B
        accs = accuracy_topk(logits.detach(), targets_for_acc, ks=(1, 3, 5))
        c1 += accs[1] * B / 100.0
        c3 += accs[3] * B / 100.0
        c5 += accs[5] * B / 100.0

        # log solo en rank0 (si DDP)
        if print_every and (step % print_every == 0) and is_main_process():
            dt = time.time() - t0
            imgs_sec = total / max(dt, 1e-9)
            print(
                f"[train step {step}/{len(dataloader)}] "
                f"loss {running_loss/total:.4f} | "
                f"top1 {100*c1/total:.2f}% | top3 {100*c3/total:.2f}% | top5 {100*c5/total:.2f}% | "
                f"{imgs_sec:.1f} img/s | lr {optimizer.param_groups[0]['lr']:.2e}"
            )

    # ---- REDUCCIÓN GLOBAL (DDP) ----
    stats = torch.tensor([running_loss, total, c1, c3, c5], device=device, dtype=torch.float64)
    ddp_sum_(stats)
    running_loss_g, total_g, c1_g, c3_g, c5_g = stats.tolist()

    avg_loss = running_loss_g / max(total_g, 1e-12)
    metrics = {
        "top1": 100.0 * c1_g / max(total_g, 1e-12),
        "top3": 100.0 * c3_g / max(total_g, 1e-12),
        "top5": 100.0 * c5_g / max(total_g, 1e-12),}

    return avg_loss, metrics

@torch.no_grad()
def evaluate_one_epoch(
    model: nn.Module,
    dataloader,
    device: str = "cuda",
    autocast_dtype: str = "fp16",
    use_amp: bool = True,
    label_smoothing: float = 0.0,
    channels_last: bool = False,
):
    model.eval()

    running_loss = 0.0
    total = 0
    c1 = c3 = c5 = 0.0

    for images, targets in dataloader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        B = targets.size(0)

        with autocast_ctx(device=device, enabled=use_amp, dtype=autocast_dtype, cache_enabled=True):
            logits = model(images)

        loss = F.cross_entropy(logits.float(), targets, label_smoothing=label_smoothing)

        running_loss += loss.item() * B
        total += B

        accs = accuracy_topk(logits, targets, ks=(1, 3, 5))
        c1 += accs[1] * B / 100.0
        c3 += accs[3] * B / 100.0
        c5 += accs[5] * B / 100.0

    # ---- REDUCCIÓN GLOBAL (DDP) ----
    stats = torch.tensor([running_loss, total, c1, c3, c5], device=device, dtype=torch.float64)
    ddp_sum_(stats)
    running_loss_g, total_g, c1_g, c3_g, c5_g = stats.tolist()

    avg_loss = running_loss_g / max(total_g, 1e-12)
    metrics = {
        "top1": 100.0 * c1_g / max(total_g, 1e-12),
        "top3": 100.0 * c3_g / max(total_g, 1e-12),
        "top5": 100.0 * c5_g / max(total_g, 1e-12),}

    return avg_loss, metrics

import time
import torch
import torch.nn as nn

def train_model(
    model: nn.Module,
    train_loader,
    epochs: int,
    val_loader=None,
    device: str = "cuda",
    lr: float = 5e-4,
    weight_decay: float = 0.05,
    autocast_dtype: str = "fp16",
    use_amp: bool = True,
    grad_clip_norm: float | None = 1.0,
    warmup_ratio: float = 0.05,
    min_lr: float = 0.0,
    label_smoothing: float = 0.1,
    print_every: int = 100,
    save_path: str = "best_model.pt",
    last_path: str = "last_model.pt",
    resume_path: str | None = None,

    mixup_alpha: float = 0.0,
    cutmix_alpha: float = 0.0,
    mix_prob: float = 1.0,
    num_classes: int = 100,
    channels_last: bool = False,

    early_stop: bool = True,
    early_stop_metric: str = "top1",
    early_stop_patience: int = 10,
    early_stop_min_delta: float = 0.0,
    early_stop_require_monotonic: bool = False):

    model.to(device)

    # Optimizer
    param_groups = build_param_groups_no_wd(model, weight_decay=weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.999), eps=1e-8)

    # Scheduler warmup + cosine (step-based)
    total_steps = epochs * len(train_loader)
    warmup_steps = int(total_steps * warmup_ratio)
    scheduler = WarmupCosineLR(
        optimizer,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        min_lr=min_lr,
    )

    scaler = None
    if use_amp and autocast_dtype.lower() in ("fp16", "float16"):
        scaler = make_grad_scaler(device=device, enabled=True)

    # Resume
    start_epoch = 0
    best_val_top1 = -float("inf")
    best_val_loss = float("inf")
    best_epoch = 0

    if resume_path is not None:
        ckpt = load_checkpoint(
            resume_path, model,
            optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            map_location=device,
            strict=True,
        )

        start_epoch = int(ckpt.get("epoch", 0))
        best_val_top1 = float(ckpt.get("best_top1", best_val_top1))
        extra = ckpt.get("extra", {}) or {}
        best_val_loss = float(extra.get("best_val_loss", best_val_loss))
        best_epoch = int(extra.get("best_epoch", best_epoch))

        if is_main_process():
            print(f"Resumed from {resume_path} at epoch {start_epoch} | best_top1 {best_val_top1:.2f}% | best_loss {best_val_loss:.4f}")

    history = {
        "train_loss": [], "train_top1": [], "train_top3": [], "train_top5": [],
        "val_loss": [], "val_top1": [], "val_top3": [], "val_top5": [],
        "lr": [],
    } if is_main_process() else None  # <- solo rank0 guarda history

    # Early stop state (solo rank0 lleva el estado)
    metric = early_stop_metric.lower()
    assert metric in ("top1", "loss")
    patience = int(early_stop_patience)
    mode = "max" if metric == "top1" else "min"
    best_metric = best_val_top1 if metric == "top1" else best_val_loss
    bad_epochs = 0
    last_vals = []

    def _is_improvement(curr: float, best: float) -> bool:
        d = float(early_stop_min_delta)
        return (curr > (best + d)) if mode == "max" else (curr < (best - d))

    def _degradation_monotonic(vals: list[float]) -> bool:
        if not early_stop_require_monotonic or len(vals) < 2:
            return True
        if mode == "max":
            return all(vals[i] >= vals[i + 1] for i in range(len(vals) - 1))
        else:
            return all(vals[i] <= vals[i + 1] for i in range(len(vals) - 1))

    for epoch in range(start_epoch + 1, epochs + 1):
        if is_main_process():
            print(f"\n=== Epoch {epoch}/{epochs} ===")
        t_epoch = time.time()

        # ✅ DDP: reshuffle correcto por epoch
        if hasattr(train_loader, "sampler") and isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)
        if val_loader is not None and hasattr(val_loader, "sampler") and isinstance(val_loader.sampler, DistributedSampler):
            val_loader.sampler.set_epoch(epoch)

        # --- Train ---
        tr_loss, tr_m = train_one_epoch(
            model=model,
            dataloader=train_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            scaler=scaler,
            autocast_dtype=autocast_dtype,
            use_amp=use_amp,
            grad_clip_norm=grad_clip_norm,
            label_smoothing=label_smoothing,
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            mix_prob=mix_prob,
            num_classes=num_classes,
            channels_last=channels_last,
            print_every=print_every,
        )

        if is_main_process():
            history["train_loss"].append(tr_loss)
            history["train_top1"].append(tr_m["top1"])
            history["train_top3"].append(tr_m["top3"])
            history["train_top5"].append(tr_m["top5"])
            history["lr"].append(optimizer.param_groups[0]["lr"])

            print(f"[Train] loss {tr_loss:.4f} | top1 {tr_m['top1']:.2f}% | top3 {tr_m['top3']:.2f}% | top5 {tr_m['top5']:.2f}% | lr {optimizer.param_groups[0]['lr']:.2e}")

            # ✅ guardar "last" SOLO en rank0
            save_checkpoint(
                last_path, model, optimizer, scheduler, scaler,
                epoch=epoch, best_top1=best_val_top1,
                extra={
                    "autocast_dtype": autocast_dtype,
                    "use_amp": use_amp,
                    "best_val_loss": best_val_loss,
                    "best_epoch": best_epoch,
                    "early_stop_metric": metric,
                    "early_stop_patience": patience,
                    "early_stop_min_delta": float(early_stop_min_delta),
                },
            )

        stop_now = False

        # --- Val ---
        if val_loader is not None:
            va_loss, va_m = evaluate_one_epoch(
                model=model,
                dataloader=val_loader,
                device=device,
                autocast_dtype=autocast_dtype,
                use_amp=use_amp,
                label_smoothing=0.0,
                channels_last=channels_last,
            )

            if is_main_process():
                history["val_loss"].append(va_loss)
                history["val_top1"].append(va_m["top1"])
                history["val_top3"].append(va_m["top3"])
                history["val_top5"].append(va_m["top5"])

                print(f"[Val]   loss {va_loss:.4f} | top1 {va_m['top1']:.2f}% | top3 {va_m['top3']:.2f}% | top5 {va_m['top5']:.2f}%")

                # Best saved por top1
                if va_m["top1"] > best_val_top1:
                    best_val_top1 = va_m["top1"]
                    if va_loss < best_val_loss:
                        best_val_loss = va_loss
                        best_epoch = epoch

                    save_checkpoint(
                        save_path, model, optimizer, scheduler, scaler,
                        epoch=epoch, best_top1=best_val_top1,
                        extra={
                            "autocast_dtype": autocast_dtype,
                            "use_amp": use_amp,
                            "best_val_loss": best_val_loss,
                            "best_epoch": best_epoch,
                        },
                    )
                    print(f"Best saved to {save_path} (val top1 {best_val_top1:.2f}%)")

                # Early stop (solo rank0 decide)
                if early_stop:
                    curr_metric = va_m["top1"] if metric == "top1" else va_loss

                    last_vals.append(float(curr_metric))
                    if len(last_vals) > patience:
                        last_vals = last_vals[-patience:]

                    if _is_improvement(curr_metric, best_metric):
                        best_metric = float(curr_metric)
                        bad_epochs = 0
                    else:
                        bad_epochs += 1

                    if bad_epochs >= patience and _degradation_monotonic(last_vals):
                        print(f"Early-stop: no improvement on val_{metric} for {patience} epochs.")
                        stop_now = True

        # ✅ DDP: sincroniza el “stop” a todos los ranks
        stop_now = ddp_broadcast_bool(stop_now, device=device)
        if stop_now:
            break

        if is_main_process():
            dt = time.time() - t_epoch
            print(f"Epoch time: {dt/60:.2f} min")

    # return: history solo en rank0; en otros ranks devuelve None
    return history, (model.module if hasattr(model, "module") else model)

Overwriting volo.py


---

# Training

In [91]:
%%writefile train_ddp.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from volo import *


def setup_ddp():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank

def is_main():
    return (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank() == 0

def main():
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    
    local_rank = setup_ddp()
    device = torch.device(f"cuda:{local_rank}")

    train_ds, val_ds, test_ds = get_cifar100_datasets(
        data_dir="./data/cifar100",
        val_split=0.1,
        img_size=32,
        ddp_safe_download=True,)

    train_sampler = DistributedSampler(train_ds, shuffle=True, drop_last=True)

    train_loader = DataLoader(
            train_ds,
            batch_size=256,
            sampler=train_sampler,
            num_workers=2,
            pin_memory=True,
            persistent_workers=True,
            prefetch_factor=2,)

    val_loader = None
    
    if val_ds is not None:
        val_sampler = DistributedSampler(val_ds, shuffle=False, drop_last=False)
        val_loader = DataLoader(
            val_ds,
            batch_size=256,
            sampler=val_sampler,
            num_workers=2,
            pin_memory=True,
            persistent_workers=True,)

    model = VOLOClassifier(
        num_classes=100,
        img_size=32,
        patch_size=4,
        hierarchical=False,
        embed_dim=320,
        outlooker_depth=5,
        outlooker_heads=10,
        transformer_depth=10,
        transformer_heads=10,
        kernel_size=3,
        mlp_ratio=4.0,
        dropout=0.12,
        attn_dropout=0.05,
        drop_path_rate=0.20,
        pooling="cls",
        cls_attn_depth=2,
        use_pos_embed=True,
        use_cls_pos=True,).to(device)

    model = DDP(model, device_ids=[local_rank], output_device=local_rank ,find_unused_parameters=True)

    history, best = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=130,
        device=str(device),
        lr=5e-4,
        weight_decay=0.05,
        use_amp=True,
        autocast_dtype="fp16",
        print_every=25,
        num_classes=100,
        save_path="best_model.pt",
        last_path="last_model.pt",)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()


Overwriting train_ddp.py


In [92]:
!torchrun --nproc_per_node=2 train_ddp.py

W1231 03:46:26.248000 2847 torch/distributed/run.py:792] 
W1231 03:46:26.248000 2847 torch/distributed/run.py:792] *****************************************
W1231 03:46:26.248000 2847 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1231 03:46:26.248000 2847 torch/distributed/run.py:792] *****************************************
[rank1]:[W1231 03:46:29.459079996 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank0]:[W1231 03:46:30.310899516 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perfor

In [98]:
model = VOLOClassifier(
        num_classes=100,
        img_size=32,
        patch_size=4,
        hierarchical=False,
        embed_dim=320,
        outlooker_depth=5,
        outlooker_heads=10,
        transformer_depth=10,
        transformer_heads=10,
        kernel_size=3,
        mlp_ratio=4.0,
        dropout=0.12,
        attn_dropout=0.05,
        drop_path_rate=0.20,
        pooling="cls",
        cls_attn_depth=2,
        use_pos_embed=True,
        use_cls_pos=True,)

state = torch.load("best_model.pt", map_location="cpu")

if isinstance(state, dict) and ("model" in state or "state_dict" in state):
    sd = state.get("model", state.get("state_dict"))
else:
    sd = state 


if any(k.startswith("module.") for k in sd.keys()):
    sd = {k.replace("module.", "", 1): v for k, v in sd.items()}


missing, unexpected = model.load_state_dict(sd, strict=True)
print("missing:", missing)
print("unexpected:", unexpected)

device = torch.device("cuda:0")
model = model.to(device) 

model.eval()

test_loss, test_m = evaluate_one_epoch(
    model=model,
    dataloader=test_loader,
    device="cuda",
    use_amp=False,         
    autocast_dtype="fp16")

print("[Test VOLO paper-like] loss", test_loss, "|", test_m)

missing: []
unexpected: []
[Test VOLO paper-like] loss 1.3181868873596192 | {'top1': 67.9, 'top3': 83.93, 'top5': 88.22}
