In [None]:
import sys, os

# Put the parent of the current working directory on the import path so that
# the `data` and `src` packages imported in later cells can be resolved.
root_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Only append once so re-running this cell does not duplicate the entry.
if sys.path.count(root_path) == 0:
    sys.path.append(root_path)

In [None]:
from data.load_cifrar100 import *

# Build the CIFAR-100 train / validation / test loaders used by every
# experiment below: 10% of the training data is held out for validation,
# images stay at their native 32x32 size, and the seed is fixed.
train_loader, val_loader, test_loader = get_cifar100_dataloaders(
    data_dir="./data",
    img_size=32,
    batch_size=64,
    num_workers=2,
    val_split=0.1,
    seed=7,
)

In [11]:
def count_trainable_parameters(model):
    """Return the total number of scalar elements in the model's trainable parameters.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors with
            ``requires_grad`` and ``numel()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: Sum of ``numel()`` over every parameter whose ``requires_grad``
        is truthy; ``0`` when there are none.
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters are excluded from the count.
        if param.requires_grad:
            total += param.numel()
    return total

---

## MaxViT Nano

In [None]:
import timm
import torch.nn as nn
from src.training.train_full_model import * 
from src.training.metrics import * 
from src.training.eval_one_epoch_logs import *

import torch

# Use the GPU when one is visible, otherwise fall back to CPU so the cell still
# runs; the training cell already disables AMP when `device` is not "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Start from timm's "maxvit_tiny_tf_224" recipe, override the per-stage widths
# via `embed_dim`, and build it for 32x32 inputs with a 100-class head.
# No pretrained weights are loaded.
model_maxvit_nano = timm.create_model(
    "maxvit_tiny_tf_224",
    pretrained=False,
    num_classes=100,
    img_size=32,
    embed_dim=[64, 96, 192, 384],
)

# --- Stem surgery -----------------------------------------------------------
# Replace the stem's first conv with a stride-1, padding-1 3x3 conv (3 -> 64)
# so it does not reduce the spatial size of the 32x32 input.
model_maxvit_nano.stem.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)

# NOTE(review): in timm the stem's `norm1` appears to be a fused norm+activation
# layer; swapping in a plain BatchNorm2d would remove the non-linearity between
# conv1 and conv2. Kept as-is so existing checkpoints and reported numbers stay
# valid -- confirm whether dropping the activation is intended.
model_maxvit_nano.stem.norm1 = nn.BatchNorm2d(
    num_features=64,
    eps=1e-3,
    momentum=0.1,
)

# Second stem conv: 64 -> 64, stride 1, bias-free.
model_maxvit_nano.stem.conv2 = nn.Conv2d(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)

model_maxvit_nano = model_maxvit_nano.to(device)

# Reuse the notebook's helper instead of repeating the counting expression.
n_params = count_trainable_parameters(model_maxvit_nano)
print(f"MaxViT-Nano Surgery Successful. Params: {n_params/1e6:.2f}M")

MaxViT-Nano Surgery Successful. Params: 17.38M


In [26]:
# Train the MaxViT model for 100 epochs on the CIFAR-100 training loader,
# validating on the held-out split.
history, _ = train_model(
    # What to train, and on which data / device
    model=model_maxvit_nano,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_classes=100,
    epochs=100,

    # Optimisation hyper-parameters
    lr=5e-4,
    min_lr=1e-6,
    warmup_ratio=0.05,
    weight_decay=0.05,
    grad_clip_norm=1.0,
    label_smoothing=0.1,

    # Mixed precision is only requested when running on CUDA
    use_amp=(device == "cuda"),
    autocast_dtype="fp16" if device == "cuda" else "fp32",
    channels_last=True,

    # MixUp / CutMix settings
    mix_prob=0.5,
    mixup_alpha=0.8,
    cutmix_alpha=1.0,

    # Checkpoint files and logging cadence
    save_path="best_model_vit1.pt",
    last_path="last_model_vit1.pt",
    print_every=400,
)

=== Run config ===
device=cuda | amp=True | autocast_dtype=fp16 | channels_last=True
epochs=100 | steps/epoch=704 | total_steps=70400 | warmup_steps=3520
batch_size=64 | input_shape=(64, 3, 32, 32) | num_classes=100
opt=AdamW | lr=0.0005 | wd=0.05 | grad_clip_norm=1.0
aug: mix_prob=0.5 | mixup_alpha=0.8 | cutmix_alpha=1.0 | label_smoothing=0.1

=== Epoch 1/100 ===
[train step 400/704] loss 4.4453 | top1 4.00% | top3 10.15% | top5 15.09% | 472.7 img/s | lr 5.68e-05 | gnorm 2.919 | clip 100.0% | oflow 0 | nonfinite 0 | scale 65536.0
[train step 704/704] loss 4.3081 | top1 5.88% | top3 14.03% | top5 20.19% | 459.2 img/s | lr 1.00e-04 | gnorm 3.474 | clip 100.0% | oflow 0 | nonfinite 0 | scale 65536.0
[Train] loss 4.3081 | top1 5.88% | top3 14.03% | top5 20.19% | lr 1.00e-04 | grad_norm 3.474 | clip 100.0% | amp_overflows 0 | nonfinite_loss 0 | scale 65536.0
[Train] mem_peak alloc 1.77 GiB | reserved 2.47 GiB
[Val]   loss 3.9383 | top1 10.38% | top3 22.20% | top5 30.50%
[Val]   mem_peak al

---

In [None]:
# Evaluate on the held-out test loader and report accuracy, speed, memory and
# model-size metrics.
# Device and AMP settings are derived from the notebook's `device` exactly as
# in the training cell, instead of being hard-coded to CUDA/fp16.
# (Assumes evaluate_one_epoch_logs accepts "fp32" like train_model -- confirm.)
# NOTE(review): this evaluates the weights currently held in memory; if
# train_model does not restore the best checkpoint, load it first -- confirm.
avg_loss, metrics = evaluate_one_epoch_logs(
    model=model_maxvit_nano,
    dataloader=test_loader,
    device=device,
    use_amp=(device == "cuda"),
    autocast_dtype="fp16" if device == "cuda" else "fp32",
    channels_last=False,
    measure_flops=True,
    flops_warmup_batches=1)

print(f"loss: {avg_loss:.4f}")
print(f"top1: {metrics['top1']:.2f} | top3: {metrics['top3']:.2f} | top5: {metrics['top5']:.2f}")

print(f"throughput: {metrics['imgs_per_sec']:.1f} imgs/s | epoch: {metrics['epoch_time_sec']:.2f}s | ms/batch: {metrics['ms_per_batch']:.2f}")
print(f"GPU mem: alloc={metrics['gpu_mem_allocated_mib']:.0f} MiB | reserved={metrics['gpu_mem_reserved_mib']:.0f} MiB | peak={metrics['gpu_mem_peak_allocated_mib']:.0f} MiB")
print(f"model: params={int(metrics['model_params']):,} | param_size={metrics['model_param_size_mib']:.1f} MiB")

# FLOPs/MACs may come out as nan if fvcore/thop is not installed or if the profiler fails
print(f"FLOPs/forward: {metrics['flops_per_forward']:.3e} | MACs/forward: {metrics['macs_per_forward']:.3e}")

loss: 0.9906
top1: 75.41 | top3: 88.87 | top5: 92.34
throughput: 1253.9 imgs/s | epoch: 7.97s | ms/batch: 45.15
GPU mem: alloc=1094 MiB | reserved=2526 MiB | peak=1201 MiB
model: params=17,379,140 | param_size=66.3 MiB
FLOPs/forward: nan | MACs/forward: nan
