In [None]:
from data.load_data import *

# Build the CIFAR-100 train/val/test loaders. `test_loader` is consumed by the
# checkpoint evaluation cell at the end of the notebook; img_size=32 matches
# the img_size the evaluated model is constructed with.
# NOTE(review): presumably val_split=0.1 holds out 10% of the training set for
# validation — confirm against data.load_data.get_cifar100_dataloaders.
train_loader, val_loader, test_loader = get_cifar100_dataloaders(
    batch_size=256,
    data_dir="./data/cifar100",
    num_workers=2,
    val_split=0.1,
    img_size=32)

In [None]:
from model.embeddings import *

def test_patch_embedding_conv():
    """Shape-check PatchEmbeddingConv on a divisible and a non-divisible input.

    Case 1: a 64x64 input with patch size 4 must report zero padding and an
    (H // 4, W // 4) patch grid.
    Case 2: a 65x63 input must be padded up to a multiple of the patch size,
    with less than one patch of padding per axis, and the reported patch grid
    must match the padded size.

    In both cases the map output is expected to be (B, Hp, Wp, embed_dim) and
    the token output (B, Hp * Wp, embed_dim).

    Raises:
        AssertionError: if any shape, padding or grid check fails.
    """
    torch.manual_seed(0)

    patch_size, embed_dim = 4, 192
    B, C, H, W = 2, 3, 64, 64
    x = torch.randn(B, C, H, W)

    pe = PatchEmbeddingConv(
        patch_size=patch_size,
        in_chans=C,
        embed_dim=embed_dim,
        norm_layer=torch.nn.LayerNorm,
        pad_if_needed=True,
        return_tokens=True,)

    # --- Case 1: spatial size divisible by the patch size -> no padding.
    x_map, (Hp, Wp), x_tok, (pad_h, pad_w) = pe(x)

    assert x_map.shape == (B, Hp, Wp, embed_dim), f"x_map shape {tuple(x_map.shape)}"
    assert x_tok.shape == (B, Hp * Wp, embed_dim), f"x_tok shape {tuple(x_tok.shape)}"
    assert (pad_h, pad_w) == (0, 0), f"unexpected padding {(pad_h, pad_w)}"
    assert (Hp, Wp) == (H // patch_size, W // patch_size), f"patch grid {(Hp, Wp)}"

    print("[OK] PatchEmbeddingConv divisible:",
          "x_map", tuple(x_map.shape),
          "| x_tok", tuple(x_tok.shape),
          "| pad", (pad_h, pad_w))

    # --- Case 2: size NOT divisible (65x63 with patch=4) -> should be padded.
    H2, W2 = 65, 63
    x2 = torch.randn(B, C, H2, W2)

    x_map2, (Hp2, Wp2), x_tok2, (pad_h2, pad_w2) = pe(x2)

    for size, pad in ((H2, pad_h2), (W2, pad_w2)):
        # Padding must complete the last patch and never add a whole extra one.
        assert 0 <= pad < patch_size, f"padding {pad} out of range for size {size}"
        assert (size + pad) % patch_size == 0, f"{size} + {pad} not divisible by {patch_size}"

    # The reported grid has to agree with the padded spatial size.
    expected_grid = ((H2 + pad_h2) // patch_size, (W2 + pad_w2) // patch_size)
    assert (Hp2, Wp2) == expected_grid, f"patch grid {(Hp2, Wp2)} != {expected_grid}"
    assert x_map2.shape == (B, Hp2, Wp2, embed_dim), f"x_map shape {tuple(x_map2.shape)}"
    assert x_tok2.shape == (B, Hp2 * Wp2, embed_dim), f"x_tok shape {tuple(x_tok2.shape)}"

    print("[OK] PatchEmbeddingConv non-divisible:",
          "input", (H2, W2),
          "| padded by", (pad_h2, pad_w2),
          "| patches", (Hp2, Wp2),
          "| x_map", tuple(x_map2.shape))

# Run the check right away so the cell output records the result.
test_patch_embedding_conv()

[OK] PatchEmbeddingConv divisible: x_map (2, 16, 16, 192) | x_tok (2, 256, 192) | pad (0, 0)
[OK] PatchEmbeddingConv non-divisible: input (65, 63) | padded by (3, 1) | patches (17, 16) | x_map (2, 17, 16, 192)


In [9]:
# This cell uses OutlookAttention, so import model.outlook here as well; without
# it the cell only works if a later cell has already been executed.
# NOTE(review): presumably OutlookAttention is exported by model.outlook — confirm.
from model.outlook import *

def test_outlook_attention_stride1():
    """Check OutlookAttention with stride=1 on a (2, 16, 16, 192) feature map.

    Verifies that the output shape equals the input shape and that a backward
    pass through the mean of the output produces a finite gradient on the input.

    Raises:
        AssertionError: if the shape changes, no gradient reaches the input,
            or the gradient contains non-finite values.
    """
    torch.manual_seed(0)

    B, H, W, C = 2, 16, 16, 192
    x_map = torch.randn(B, H, W, C, requires_grad=True)

    oa = OutlookAttention(
        dim=C,
        num_heads=6,
        kernel_size=3,
        stride=1,
        attn_drop=0.0,
        proj_drop=0.0)

    y = oa(x_map)
    assert y.shape == x_map.shape, f"Expected {x_map.shape}, got {y.shape}"

    # Scalar loss so that backward() needs no explicit gradient argument.
    loss = y.mean()
    loss.backward()

    assert x_map.grad is not None, "No gradient flowed to input!"
    assert torch.isfinite(x_map.grad).all(), "Non-finite grads!"

    print("[OK] OutlookAttention stride=1:",
          "in", tuple(x_map.shape),
          "| out", tuple(y.shape),
          "| grad mean", float(x_map.grad.abs().mean()))

# Run the check right away so the cell output records the result.
test_outlook_attention_stride1()

[OK] OutlookAttention stride=1: in (2, 16, 16, 192) | out (2, 16, 16, 192) | grad mean 2.7283142571832286e-06


In [None]:
from model.outlook import *

def test_outlook_attention_stride2():
    """Check OutlookAttention with stride=2 on a (2, 16, 16, 192) feature map.

    The batch and channel dimensions must be preserved, both spatial
    dimensions must be halved, and the backward pass must give the input a
    finite gradient.
    """
    torch.manual_seed(0)

    in_shape = (2, 16, 16, 192)
    batch, height, width, channels = in_shape
    feats = torch.randn(*in_shape, requires_grad=True)

    attn = OutlookAttention(
        dim=channels,
        num_heads=6,
        kernel_size=3,
        stride=2,
        attn_drop=0.0,
        proj_drop=0.0,
    )

    out = attn(feats)

    # Batch and channel axes are untouched by the strided attention.
    assert out.shape[0] == batch and out.shape[-1] == channels
    # Spatial axes shrink by the stride.
    assert out.shape[1] == height // 2 and out.shape[2] == width // 2, f"Got {out.shape[1:3]}"

    out.mean().backward()
    grad = feats.grad
    assert grad is not None
    assert torch.isfinite(grad).all()

    print("[OK] OutlookAttention stride=2:", "in", in_shape, "| out", tuple(out.shape))

test_outlook_attention_stride2()

[OK] OutlookAttention stride=2: in (2, 16, 16, 192) | out (2, 8, 8, 192)


In [12]:
def test_outlooker_block():
    """Check that a stride-1 OutlookerBlock keeps the input shape and back-propagates.

    Runs the block on a (2, 16, 16, 192) map with every dropout / drop-path
    rate set to zero, then asserts the output shape equals the input shape and
    the input gradient exists and is finite.
    """
    torch.manual_seed(0)

    channels = 192
    feats = torch.randn(2, 16, 16, channels, requires_grad=True)

    block_kwargs = dict(
        dim=channels,
        num_heads=6,
        kernel_size=3,
        stride=1,
        mlp_ratio=4.0,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        mlp_drop=0.0,
    )
    block = OutlookerBlock(**block_kwargs)

    out = block(feats)
    assert out.shape == feats.shape

    out.mean().backward()
    grad = feats.grad
    assert grad is not None
    assert torch.isfinite(grad).all()

    print("[OK] OutlookerBlock:",
          "in/out", tuple(out.shape),
          "| grad mean", float(grad.abs().mean()))

test_outlooker_block()

[OK] OutlookerBlock: in/out (2, 16, 16, 192) | grad mean 1.0187762200075667e-05


In [None]:


def test_embed_then_outlook(img_size=64, patch_size=4, dim=192, heads=6):
    """Chain PatchEmbeddingConv into an OutlookerBlock and check shapes and gradients.

    Args:
        img_size: side length of the square random input image.
        patch_size: patch size passed to the patch embedding.
        dim: embedding dimension shared by the embedding and the block.
        heads: number of heads passed to the OutlookerBlock.

    Asserts that the block output has the same shape as the embedded map,
    namely (batch, Hp, Wp, dim), and that a finite gradient reaches the image.
    """
    torch.manual_seed(0)

    batch = 2
    images = torch.randn(batch, 3, img_size, img_size, requires_grad=True)

    embed = PatchEmbeddingConv(
        patch_size=patch_size,
        in_chans=3,
        embed_dim=dim,
        norm_layer=torch.nn.LayerNorm,
        pad_if_needed=True,
        return_tokens=True,
    )
    outlooker = OutlookerBlock(
        dim=dim,
        num_heads=heads,
        kernel_size=3,
        stride=1,
        mlp_ratio=4.0,
        drop_path=0.0,
    )

    # The token view returned by the embedding is not needed in this test.
    feat_map, (grid_h, grid_w), _tokens, padding = embed(images)
    out_map = outlooker(feat_map)

    expected_shape = (batch, grid_h, grid_w, dim)
    assert out_map.shape == feat_map.shape == expected_shape

    # The gradient has to flow all the way back to the raw image.
    out_map.mean().backward()
    assert images.grad is not None
    assert torch.isfinite(images.grad).all()

    print("[OK] Embed->Outlook:",
          "img", (img_size, img_size),
          "| patches", (grid_h, grid_w),
          "| map", tuple(out_map.shape),
          "| pad", padding)

test_embed_then_outlook(img_size=32)
test_embed_then_outlook(img_size=64)

[OK] Embed->Outlook: img (32, 32) | patches (8, 8) | map (2, 8, 8, 192) | pad (0, 0)
[OK] Embed->Outlook: img (64, 64) | patches (16, 16) | map (2, 16, 16, 192) | pad (0, 0)


In [None]:
from model.volo_stage import *

def test_volo_stage():
    """Check that a depth-3, stride-1 VOLOStage keeps the input shape and back-propagates.

    Uses a per-block drop-path schedule of [0.0, 0.05, 0.1] on a
    (2, 16, 16, 192) map and asserts the input gradient exists and is finite.
    """
    torch.manual_seed(0)

    feat_shape = (2, 16, 16, 192)
    feats = torch.randn(*feat_shape, requires_grad=True)

    stage = VOLOStage(
        dim=feat_shape[-1],
        depth=3,
        num_heads=6,
        kernel_size=3,
        stride=1,
        drop_path=[0.0, 0.05, 0.1],
    )

    out = stage(feats)
    assert out.shape == feats.shape

    out.mean().backward()
    grad = feats.grad
    assert grad is not None
    assert torch.isfinite(grad).all()

    print("[OK] VOLOStage:", tuple(out.shape), "| grad mean", float(grad.abs().mean()))

test_volo_stage()

[OK] VOLOStage: (2, 16, 16, 192) | grad mean 1.0873188330151606e-05


In [None]:
from model.attention import *

def test_transformer_block():
    """Check that TransformerBlock keeps the (B, N, C) token shape and back-propagates.

    Runs the block on (2, 256, 192) random tokens and asserts that the input
    gradient exists and is finite.
    """
    torch.manual_seed(0)

    token_shape = (2, 256, 192)
    tokens = torch.randn(*token_shape, requires_grad=True)

    block = TransformerBlock(
        dim=token_shape[-1],
        num_heads=6,
        mlp_ratio=4.0,
        attn_dropout=0.0,
        dropout=0.1,
        drop_path=0.0,
    )

    out = block(tokens)
    assert out.shape == tokens.shape

    out.mean().backward()
    grad = tokens.grad
    assert grad is not None
    assert torch.isfinite(grad).all()

    print("[OK] TransformerBlock:", tuple(out.shape), "grad", float(grad.abs().mean()))

test_transformer_block()

[OK] TransformerBlock: (2, 256, 192) grad 1.018048442347208e-05


In [None]:
from model.pooling_volo_blocks import *

def test_volo_pyramid_map():
    """Shape-check VOLOPyramid with downsample_kind="map".

    Feeds a (2, 16, 16, 192) feature map through a three-level pyramid and
    checks that the returned token tensor keeps the batch size, ends with 384
    channels, and has exactly Hf * Wf tokens for the returned grid (Hf, Wf).

    Raises:
        AssertionError: if any of the shape checks fails.
    """
    torch.manual_seed(0)
    B = 2
    H = W = 16
    x_map = torch.randn(B, H, W, 192)

    pyr = VOLOPyramid(
        dims=(192, 256, 384),
        outlooker_depths=(2, 2, 0),
        outlooker_heads=(6, 8, 12),
        transformer_depths=(0, 2, 2),
        transformer_heads=(6, 8, 12),
        downsample_kind="map",
        drop_path_rate=0.1,)

    x_tok, (Hf, Wf) = pyr(x_map)

    assert x_tok.shape[0] == B, f"batch size changed: {tuple(x_tok.shape)}"
    assert x_tok.shape[2] == 384, f"unexpected channel dim: {tuple(x_tok.shape)}"
    assert Hf * Wf == x_tok.shape[1], f"grid {(Hf, Wf)} does not match {x_tok.shape[1]} tokens"

    # Report success only after every check has passed.
    print("[OK] Pyramid-map:", x_tok.shape, "grid", (Hf, Wf))

# Run the check right away so the cell output records the result.
test_volo_pyramid_map()

[OK] Pyramid-map: torch.Size([2, 16, 384]) grid (4, 4)


In [26]:
def test_volo_pyramid_token():
    """Shape-check VOLOPyramid with downsample_kind="token".

    Feeds a (2, 16, 16, 192) feature map through a three-level pyramid and
    checks that the returned token tensor keeps the batch size, ends with 384
    channels, and has exactly Hf * Wf tokens for the returned grid (Hf, Wf).

    Raises:
        AssertionError: if any of the shape checks fails.
    """
    torch.manual_seed(0)
    B = 2
    H = W = 16
    x_map = torch.randn(B, H, W, 192)

    pyr = VOLOPyramid(
        dims=(192, 256, 384),
        outlooker_depths=(2, 2, 0),
        outlooker_heads=(6, 8, 12),
        transformer_depths=(0, 2, 2),
        transformer_heads=(6, 8, 12),
        downsample_kind="token",
        drop_path_rate=0.1)

    x_tok, (Hf, Wf) = pyr(x_map)

    # Same checks as the "map" variant, including the batch dimension.
    assert x_tok.shape[0] == B, f"batch size changed: {tuple(x_tok.shape)}"
    assert x_tok.shape[2] == 384, f"unexpected channel dim: {tuple(x_tok.shape)}"
    assert Hf * Wf == x_tok.shape[1], f"grid {(Hf, Wf)} does not match {x_tok.shape[1]} tokens"

    # Report success only after every check has passed.
    print("[OK] Pyramid-token:", x_tok.shape, "grid", (Hf, Wf))

# Run the check right away so the cell output records the result.
test_volo_pyramid_token()

[OK] Pyramid-token: torch.Size([2, 16, 384]) grid (4, 4)


# VOLO

In [None]:
from model.VOLO import *

def test_volo_classifier_flat():
    """Smoke-test the flat (hierarchical=False) VOLOClassifier at 64x64 input.

    Builds a 100-class model with mean pooling, runs a batch of two random
    images through it and checks that the logits have shape (2, 100).

    Raises:
        AssertionError: if the logits shape is not (2, 100).
    """
    torch.manual_seed(0)
    model = VOLOClassifier(
        num_classes=100,
        img_size=64,
        patch_size=4,
        hierarchical=False,
        embed_dim=192,
        outlooker_depth=2,
        transformer_depth=2,
        outlooker_heads=6,
        transformer_heads=6,
        pooling="mean")

    x = torch.randn(2, 3, 64, 64)
    y = model(x)
    assert y.shape == (2, 100), f"Expected (2, 100), got {tuple(y.shape)}"

    # Report success only after the shape check has passed.
    print("[OK] flat logits:", y.shape)

def test_volo_classifier_hier():
    """Smoke-test the hierarchical VOLOClassifier (downsample_kind="map") at 64x64 input.

    Builds a 100-class, three-level model with mean pooling, runs a batch of
    two random images through it and checks that the logits have shape (2, 100).

    Raises:
        AssertionError: if the logits shape is not (2, 100).
    """
    torch.manual_seed(0)
    model = VOLOClassifier(
        num_classes=100,
        img_size=64,
        patch_size=4,
        hierarchical=True,
        downsample_kind="map",
        dims=(192, 256, 384),
        outlooker_depths=(2, 2, 0),
        outlooker_heads_list=(6, 8, 12),
        transformer_depths=(0, 2, 2),
        transformer_heads_list=(6, 8, 12),
        pooling="mean",)

    x = torch.randn(2, 3, 64, 64)
    y = model(x)
    assert y.shape == (2, 100), f"Expected (2, 100), got {tuple(y.shape)}"

    # Report success only after the shape check has passed.
    print("[OK] hier logits:", y.shape)

# Run both classifier smoke tests so the cell output records their results.
test_volo_classifier_flat()
test_volo_classifier_hier()

[OK] flat logits: torch.Size([2, 100])
[OK] hier logits: torch.Size([2, 100])


In [None]:
from model.model_utils import *

# Flat (hierarchical=False) VOLO at 64x64 input with positional embeddings.
# debug_forward_shapes runs it on CPU and prints the output shape of each
# traced module, ending with the (2, 100) logits.
model_flat64 = VOLOClassifier(
    num_classes=100,
    img_size=64,
    patch_size=4,
    hierarchical=False,
    embed_dim=192,
    outlooker_depth=2,
    outlooker_heads=6,
    transformer_depth=2,
    transformer_heads=6,
    pooling="mean",
    use_pos_embed=True,)

debug_forward_shapes(model_flat64, img_size=64, device="cpu")



=== Forward debug | img_size=64 | model=VOLOClassifier ===
patch_embed                         -> [(2, 16, 16, 192), 'tuple', (2, 256, 192), 'tuple']
local_stage (outlooker)             -> (2, 16, 16, 192)
global_block[0]                     -> (2, 256, 192)
global_block[1]                     -> (2, 256, 192)
norm                                -> (2, 256, 192)
head                                -> (2, 100)
OUTPUT logits                       -> (2, 100)


In [33]:
# Hierarchical VOLO at 64x64 input using downsample_kind="map": three pyramid
# levels with dims 192 -> 256 -> 384. debug_forward_shapes runs it on CPU and
# prints the output shape of each traced module, including the pyramid levels
# and the downsampling steps between them.
model_hier64 = VOLOClassifier(
    num_classes=100,
    img_size=64,
    patch_size=4,
    hierarchical=True,
    downsample_kind="map",
    dims=(192, 256, 384),
    outlooker_depths=(2, 2, 0),
    outlooker_heads_list=(6, 8, 12),
    transformer_depths=(0, 2, 2),
    transformer_heads_list=(6, 8, 12),
    pooling="mean",
    use_pos_embed=True,)

debug_forward_shapes(model_hier64, img_size=64, device="cpu")


=== Forward debug | img_size=64 | model=VOLOClassifier ===
patch_embed                         -> [(2, 16, 16, 192), 'tuple', (2, 256, 192), 'tuple']
pyr.level[0].local                  -> (2, 16, 16, 192)
pyr.down[0]                         -> (2, 8, 8, 256)
pyr.level[1].local                  -> (2, 8, 8, 256)
pyr.level[1].global                 -> (2, 64, 256)
pyr.down[1]                         -> (2, 4, 4, 384)
pyr.level[2].global                 -> (2, 16, 384)
pyramid (top)                       -> [(2, 16, 384), 'tuple']
norm                                -> (2, 16, 384)
head                                -> (2, 100)
OUTPUT logits                       -> (2, 100)


In [47]:
# Same hierarchical configuration as the downsample_kind="map" model, but with
# downsample_kind="token". The recorded trace shows the downsampling steps
# returning token tensors ((2, 64, 256) and (2, 16, 384)) instead of maps.
model_hier64_tok = VOLOClassifier(
    num_classes=100,
    img_size=64,
    patch_size=4,
    hierarchical=True,
    downsample_kind="token",
    dims=(192, 256, 384),
    outlooker_depths=(2, 2, 0),
    outlooker_heads_list=(6, 8, 12),
    transformer_depths=(0, 2, 2),
    transformer_heads_list=(6, 8, 12),
    pooling="mean",
    use_pos_embed=True,)

debug_forward_shapes(model_hier64_tok, img_size=64, device="cpu")


=== Forward debug | img_size=64 | model=VOLOClassifier ===
patch_embed                         -> [(2, 16, 16, 192), 'tuple', (2, 256, 192), 'tuple']
pyr.level[0].local                  -> (2, 16, 16, 192)
pyr.down[0]                         -> [(2, 64, 256), 'tuple']
pyr.level[1].local                  -> (2, 8, 8, 256)
pyr.level[1].global                 -> (2, 64, 256)
pyr.down[1]                         -> [(2, 16, 384), 'tuple']
pyr.level[2].global                 -> (2, 16, 384)
pyramid (top)                       -> [(2, 16, 384), 'tuple']
norm                                -> (2, 16, 384)
head                                -> (2, 100)
OUTPUT logits                       -> (2, 100)


---

In [None]:
# Launch main_training_ddp.py through torchrun with 2 worker processes on this node.
# NOTE(review): presumably this run writes the best_model.pt that the evaluation
# cell loads afterwards — confirm against main_training_ddp.py.
!torchrun --nproc_per_node=2 main_training_ddp.py

W1231 03:46:26.248000 2847 torch/distributed/run.py:792] 
W1231 03:46:26.248000 2847 torch/distributed/run.py:792] *****************************************
W1231 03:46:26.248000 2847 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1231 03:46:26.248000 2847 torch/distributed/run.py:792] *****************************************
[rank1]:[W1231 03:46:29.459079996 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank0]:[W1231 03:46:30.310899516 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perfor

In [None]:
from training.train_one_epoch import *

# Rebuild the architecture to load best_model.pt into. load_state_dict below is
# strict, so these hyper-parameters must match the ones used when the
# checkpoint was trained.
model = VOLOClassifier(
    num_classes=100,
    img_size=32,
    patch_size=4,
    hierarchical=False,
    embed_dim=320,
    outlooker_depth=5,
    outlooker_heads=10,
    transformer_depth=10,
    transformer_heads=10,
    kernel_size=3,
    mlp_ratio=4.0,
    dropout=0.12,
    attn_dropout=0.05,
    drop_path_rate=0.20,
    pooling="cls",
    cls_attn_depth=2,
    use_pos_embed=True,
    use_cls_pos=True,)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; consider passing
# weights_only=True if the checkpoint contains nothing but tensors.
state = torch.load("best_model.pt", map_location="cpu")

# The checkpoint is either a bare state_dict or a wrapper dict that stores the
# weights under "model" or "state_dict".
if isinstance(state, dict) and ("model" in state or "state_dict" in state):
    sd = state.get("model", state.get("state_dict"))
else:
    sd = state

# Checkpoints saved from a DistributedDataParallel wrapper prefix every key
# with "module.". Strip it only where it really is a prefix, so keys that
# merely contain "module." further inside the name are left untouched.
if any(k.startswith("module.") for k in sd):
    sd = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in sd.items()
    }

# strict=True raises on any key mismatch, so both lists are empty when this
# call returns; they are printed as an explicit confirmation.
missing, unexpected = model.load_state_dict(sd, strict=True)
print("missing:", missing)
print("unexpected:", unexpected)

# Use a single device definition for the model and the evaluator, and fall
# back to CPU when no GPU is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

model.eval()

test_loss, test_m = evaluate_one_epoch(
    model=model,
    dataloader=test_loader,
    device=device.type,  # "cuda" on GPU, as before; "cpu" otherwise
    use_amp=False,
    autocast_dtype="fp16")

print("[Test VOLO paper-like] loss", test_loss, "|", test_m)

missing: []
unexpected: []
[Test VOLO paper-like] loss 1.3181868873596192 | {'top1': 67.9, 'top3': 83.93, 'top5': 88.22}
