In [None]:
import sys, os

# Resolve the parent of the current working directory and make it importable,
# so that project packages living one level above the notebook can be imported.
root_path = os.path.abspath(os.path.join(os.getcwd(), ".."))
if sys.path.count(root_path) == 0:
    # Appended (not prepended): already-importable modules keep precedence.
    sys.path.append(root_path)

In [None]:
from src.data.load_cifrar100 import *

# Create the train / validation / test loaders returned by get_cifar100_dataloaders.
# NOTE(review): val_split=0.1 is presumably the fraction of the training set held
# out for validation — confirm against the loader implementation.
(train_loader, val_loader, test_loader) = get_cifar100_dataloaders(
    data_dir="./data/cifar100",
    img_size=32,
    val_split=0.1,
    batch_size=64,
    num_workers=2,
)

In [None]:
from src.data.data_utils import *

# Three-component mean / std constants for CIFAR-100.
# NOTE(review): presumably per-channel (RGB) normalisation statistics; they are not
# referenced by any other visible cell — confirm whether they are still needed.
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# Print a summary of the training loader; statistics are limited to 50 batches.
describe_loader(
    train_loader,
    "train_loader",
    max_batches_for_stats=50,
)


TRAIN_LOADER SUMMARY
Dataset type        : Subset
  ↳ Wrapped dataset  : CIFAR100 (Subset-like)
  ↳ Subset size      : 45000
Num samples         : 45000
Batch size          : 128
Num workers         : 2
Pin memory          : True
Drop last           : False
Sampler             : RandomSampler
len(loader) (#batches): 352 (≈ ceil(45000/128) = 352)

First batch:
  x.shape           : (128, 3, 32, 32)
  y.shape           : (128,)
  x.dtype           : torch.float32
  y.dtype           : torch.int64
  x.min/max         : -3.9943 / 3.9723
  y.min/max         : 0 / 99
  unique labels (batch): 71

Quick stats over up to 50 batches:
  Approx mean        : -0.281566
  Approx std         : 1.129613
  Seen label counts  : 100 classes (in sampled batches)
  Top-5 labels       : [(45, 88), (27, 83), (42, 78), (49, 77), (64, 77)]

Full dataset label distribution:
  #classes detected  : 100
  min/max per class  : 436 / 463
  first 10 classes   : [(0, 457), (1, 439), (2, 448), (3, 455), (4, 446), (5, 

---

# Training

In [None]:
from src.training.train_full_model import *
from stage_config import * 
from Model_A_OutGridNet import * 
from Model_B_OutGridNet import *

def cifar32_stages_t4_tinyplus(drop_path=0.08):
    """Return the four-stage ``StageCfg`` list for the "t4 tiny-plus" layout.

    Args:
        drop_path: value forwarded unchanged as ``drop_path`` to every stage.

    Returns:
        A list of four ``StageCfg`` objects, ordered from first to last stage.
    """
    # Author's note on resolutions: 64 -> 32 -> 16 -> 8
    # NOTE(review): the notebook feeds 32x32 images — confirm these figures.
    # Columns: (dim, depth, heads, grid_size). outlook_heads always equals num_heads.
    layout = (
        (80, 2, 2, 4),
        (160, 3, 5, 4),
        (320, 4, 10, 2),
        (448, 2, 8, 1),
    )
    return [
        StageCfg(
            dim=dim,
            depth=depth,
            num_heads=heads,
            grid_size=grid,
            outlook_heads=heads,
            drop_path=drop_path,
        )
        for dim, depth, heads, grid in layout
    ]

# Explicit import: `torch` is used below but was previously only reachable as a
# side effect of the wildcard imports in this notebook.
import torch

# Stage configuration for the backbone, with a per-stage drop_path of 0.10.
stages = cifar32_stages_t4_tinyplus(drop_path=0.10)

# NOTE(review): both the per-stage drop_path (0.10) and the model-level dpr_max (0.12)
# are set; how they interact is decided inside MaxOutNet — confirm which one applies.
model = MaxOutNet(
    num_classes=100,
    stages=stages,
    stem_dim=64,
    dpr_max=0.12)

# Kept as a plain string: later cells compare it against "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

In [30]:
def count_trainable_parameters(model):
    """Count the scalar elements of all parameters that require gradients.

    Args:
        model: object exposing ``parameters()``, an iterable of tensors with
            ``requires_grad`` and ``numel()``.

    Returns:
        Total element count over parameters with ``requires_grad`` set
        (0 when there are none).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable-parameter count of `model`, printed with thousands separators.
n_params = count_trainable_parameters(model)
print(f"Trainable parameters: {n_params:,}")

Trainable parameters: 32,974,583


In [31]:
import random, numpy as np

# Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) random generators.
# NOTE(review): `model` and the data loaders are created in earlier cells, before this
# seed is set, so weight initialisation and the train/val split are not covered by it.
seed = 7
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Enable the cuDNN autotuner.
# NOTE(review): autotuning may select different kernels between runs, which can work
# against bit-exact reproducibility — confirm this trade-off is intended.
torch.backends.cudnn.benchmark = True


# Launch training. `model` is rebound to the object returned by train_model, so
# re-running this cell starts from that returned model rather than a fresh one.
history, model = train_model(
    # Core inputs
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=50,
    num_classes=100,

    # Optimisation and learning-rate schedule
    lr=5e-4,
    min_lr=1e-6,
    warmup_ratio=0.05,
    weight_decay=0.05,
    grad_clip_norm=1.0,
    label_smoothing=0.0,

    # Mixed precision is only requested when running on CUDA
    use_amp=(device == "cuda"),
    autocast_dtype="fp16" if device == "cuda" else "fp32",
    channels_last=True,

    # Augmentations
    # NOTE(review): mixup_alpha=0.0 with cutmix_alpha=1.0 presumably means CutMix only,
    # applied with probability mix_prob — confirm against train_model.
    mix_prob=0.5,
    mixup_alpha=0.0,
    cutmix_alpha=1.0,

    # Logging and checkpoints
    print_every=100,
    save_path="best_maxout_medium.pt",
    last_path="last_maxout_medium.pt",
    resume_path=None,
)

=== Run config ===
device=cuda | amp=True | autocast_dtype=fp16 | channels_last=True
epochs=50 | steps/epoch=704 | total_steps=35200 | warmup_steps=1760
batch_size=64 | input_shape=(64, 3, 32, 32) | num_classes=100
opt=AdamW | lr=0.0005 | wd=0.05 | grad_clip_norm=1.0
aug: mix_prob=0.5 | mixup_alpha=0.0 | cutmix_alpha=1.0 | label_smoothing=0.0

=== Epoch 1/50 ===
[train step 100/704] loss 4.5565 | top1 2.69% | top3 6.70% | top5 10.14% | 114.2 img/s | lr 2.84e-05 | gnorm 8.440 | clip 100.0% | oflow 0 | nonfinite 0 | scale 65536.0
[train step 200/704] loss 4.4513 | top1 3.63% | top3 9.41% | top5 13.98% | 143.1 img/s | lr 5.68e-05 | gnorm 7.938 | clip 100.0% | oflow 0 | nonfinite 0 | scale 65536.0
[train step 300/704] loss 4.3773 | top1 4.54% | top3 11.35% | top5 16.70% | 156.3 img/s | lr 8.52e-05 | gnorm 7.437 | clip 100.0% | oflow 0 | nonfinite 0 | scale 65536.0
[train step 400/704] loss 4.3407 | top1 5.00% | top3 12.44% | top5 18.14% | 163.5 img/s | lr 1.14e-04 | gnorm 6.877 | clip 100.

In [32]:
# Evaluate the trained model on the test loader; left as the last expression so the
# returned value is displayed as the cell output.
evaluate_one_epoch(
    model=model,
    dataloader=test_loader,
)

(0.7801645481109619, {'top1': 78.42, 'top3': 92.07, 'top5': 95.22})