In [None]:
from data.load_cifrar100 import *

# Build the three CIFAR-100 loaders (train / validation / test) with the project helper.
# All settings are passed by keyword, so the argument order here is irrelevant.
train_loader, val_loader, test_loader = get_cifar100_dataloaders(
    data_dir="./data",
    img_size=32,
    val_split=0.1,
    seed=7,
    batch_size=64,
    num_workers=2,
)

In [11]:
def count_trainable_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Args:
        model: Any object exposing ``parameters()`` that yields items with a
            ``requires_grad`` flag and a ``numel()`` method
            (e.g. a ``torch.nn.Module``).

    Returns:
        Sum of ``numel()`` over the parameters whose ``requires_grad`` is
        truthy; ``0`` when there are none.
    """
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad unset are skipped entirely.
        if param.requires_grad:
            total += param.numel()
    return total

---

In [None]:
import torch.nn as nn
import timm
from src.training.train_full_model import * 
from src.training.metrics import * 
from src.training.eval_one_epoch_logs import *

# `device` must exist before the model is moved onto it at the end of this cell.
# Previously it was only assigned in the later training cell, so a fresh
# "Restart & Run All" failed here with a NameError.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# MaxViT-Tiny created without pretrained weights, with a 100-class head and
# img_size=32 passed to timm.
model_maxvit = timm.create_model(
    "maxvit_tiny_tf_224",
    pretrained=False,
    num_classes=100,
    img_size=32,
)

# Replace both stem convolutions with plain 3x3, stride-1, padding-1, bias-free
# convs that keep the original in/out channel counts.
# NOTE(review): presumably done so the stem does not downsample 32x32 inputs — confirm.
in_ch_1 = model_maxvit.stem.conv1.in_channels
out_ch_1 = model_maxvit.stem.conv1.out_channels
model_maxvit.stem.conv1 = nn.Conv2d(
    in_ch_1, out_ch_1, kernel_size=3, stride=1, padding=1, bias=False)

in_ch_2 = model_maxvit.stem.conv2.in_channels
out_ch_2 = model_maxvit.stem.conv2.out_channels
model_maxvit.stem.conv2 = nn.Conv2d(
    in_ch_2, out_ch_2, kernel_size=3, stride=1, padding=1, bias=False)

model_maxvit = model_maxvit.to(device)

# Reuse the notebook's helper instead of repeating the same sum inline.
n_params = count_trainable_parameters(model_maxvit)
print(f"MaxViT-Tiny Trainable parameters: {n_params/1e6:.2f}M")

MaxViT-Tiny Trainable parameters: 30.43M


In [None]:
# Train MaxViT-Tiny with the project's `train_model` helper.
# The returned tuple's first element is kept as `history`; the second is discarded.
device = 'cuda'

history, _ = train_model(
        model=model_maxvit,
        train_loader=train_loader,
        epochs=100,
        val_loader=val_loader,
        device=device,

        lr=5e-4,
        weight_decay=0.05,

        # Mixed precision: fp16 autocast with AMP only when running on CUDA,
        # otherwise fp32 with AMP disabled.
        autocast_dtype="fp16" if device == "cuda" else "fp32",
        use_amp=(device == "cuda"),
        grad_clip_norm=1.0,

        warmup_ratio=0.05,
        min_lr=1e-6,

        label_smoothing=0.1,
        # Checkpoint names now match the model being trained. They previously
        # said "resnet18" (copy-paste leftover), which mislabels these weights
        # and can overwrite real ResNet-18 checkpoints in the same directory.
        save_path="best_model_maxvit_tiny.pt",
        last_path="last_model_maxvit_tiny.pt",

        print_every=400,
        # Batch-mixing augmentation settings forwarded to the trainer.
        mix_prob=0.5,
        mixup_alpha=0.8,
        cutmix_alpha=1.0,

        num_classes=100,
        channels_last=True)

=== Run config ===
device=cuda | amp=True | autocast_dtype=fp16 | channels_last=True
epochs=100 | steps/epoch=704 | total_steps=70400 | warmup_steps=3520
batch_size=64 | input_shape=(64, 3, 32, 32) | num_classes=100
opt=AdamW | lr=0.0005 | wd=0.05 | grad_clip_norm=1.0
aug: mix_prob=0.5 | mixup_alpha=0.8 | cutmix_alpha=1.0 | label_smoothing=0.1

=== Epoch 1/100 ===
[train step 400/704] loss 4.4338 | top1 3.93% | top3 10.18% | top5 15.13% | 417.4 img/s | lr 5.68e-05 | gnorm 3.045 | clip 100.0% | oflow 0 | nonfinite 0 | scale 65536.0
[train step 704/704] loss 4.2960 | top1 5.86% | top3 14.12% | top5 20.20% | 422.2 img/s | lr 1.00e-04 | gnorm 3.602 | clip 100.0% | oflow 0 | nonfinite 0 | scale 65536.0
[Train] loss 4.2960 | top1 5.86% | top3 14.12% | top5 20.20% | lr 1.00e-04 | grad_norm 3.602 | clip 100.0% | amp_overflows 0 | nonfinite_loss 0 | scale 65536.0
[Train] mem_peak alloc 1.79 GiB | reserved 2.88 GiB
[Val]   loss 3.9453 | top1 9.42% | top3 21.40% | top5 29.58%
[Val]   mem_peak all

In [18]:
# Run evaluate_one_epoch on model_maxvit over test_loader, leaving every other
# argument at the helper's default. The call stays the last expression of the
# cell so its return value is shown as the cell output.
evaluate_one_epoch(
    dataloader=test_loader,
    model=model_maxvit,
)

(0.9612789264678955, {'top1': 75.92, 'top3': 89.35, 'top5': 92.95})

## MaxViT

In [27]:
# Evaluate model_maxvit on the test loader with explicit runtime settings
# (CUDA, fp16 autocast, channels_last off) and FLOPs measurement requested,
# then print the returned loss and metric entries.
avg_loss, metrics = evaluate_one_epoch(
    dataloader=test_loader,
    model=model_maxvit,
    device="cuda",
    channels_last=False,
    use_amp=True,
    autocast_dtype="fp16",
    measure_flops=True,
    flops_warmup_batches=1,
)

# Accuracy summary.
print(f"loss: {avg_loss:.4f}")
print(
    f"top1: {metrics['top1']:.2f}"
    f" | top3: {metrics['top3']:.2f}"
    f" | top5: {metrics['top5']:.2f}"
)

# Speed, memory and model-size entries reported by the evaluation helper.
print(
    f"throughput: {metrics['imgs_per_sec']:.1f} imgs/s"
    f" | epoch: {metrics['epoch_time_sec']:.2f}s"
    f" | ms/batch: {metrics['ms_per_batch']:.2f}"
)
print(
    f"GPU mem: alloc={metrics['gpu_mem_allocated_mib']:.0f} MiB"
    f" | reserved={metrics['gpu_mem_reserved_mib']:.0f} MiB"
    f" | peak={metrics['gpu_mem_peak_allocated_mib']:.0f} MiB"
)
print(
    f"model: params={int(metrics['model_params']):,}"
    f" | param_size={metrics['model_param_size_mib']:.1f} MiB"
)

# FLOPs/MACs may come out as nan if fvcore/thop is not installed or if the profiler fails.
print(
    f"FLOPs/forward: {metrics['flops_per_forward']:.3e}"
    f" | MACs/forward: {metrics['macs_per_forward']:.3e}"
)

loss: 0.9610
top1: 75.88 | top3: 89.41 | top5: 93.03
throughput: 1352.5 imgs/s | epoch: 7.39s | ms/batch: 46.16
GPU mem: alloc=956 MiB | reserved=3232 MiB | peak=1064 MiB
model: params=30,426,476 | param_size=116.1 MiB
FLOPs/forward: nan | MACs/forward: nan
