In [None]:
import sys, os

# Put the parent of the current working directory on the import path
# (only once, so re-running this cell does not add duplicate entries).
root_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(root_path) == 0:
    sys.path.append(root_path)

In [None]:
from data.load_cifrar100 import *

# Build the train / validation / test loaders in a single call.
# NOTE(review): "./data" is resolved against the current working directory,
# not against root_path — confirm this is the intended dataset location.
train_loader, val_loader, test_loader = get_cifar100_dataloaders(
    data_dir="./data",
    img_size=32,
    batch_size=64,
    num_workers=2,
    val_split=0.1,
    seed=7,
)

In [None]:
def count_trainable_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is truthy; parameters without the
    flag set are ignored. Returns 0 when nothing qualifies.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

---

## Train ConvNeXt-Tiny

In [None]:
import torch
import torch.nn as nn
import timm
from src.training.train_full_model import *
from src.training.metrics import *
from src.training.eval_one_epoch_logs import *


# ConvNeXt-Tiny with randomly initialised weights (pretrained=False) and a
# 100-way classification head.
model_convnext_tiny = timm.create_model(
    "convnext_tiny",
    pretrained=False,
    num_classes=100)

# Replace the stem convolution with a 2x2 / stride-2 one (presumably so the
# 32x32 inputs produced by the data loaders are not downsampled as hard as by
# the stock stem — confirm). The channel counts are read from the layer being
# replaced rather than hard-coded, so the new stem keeps matching the rest of
# the network if the model variant above is changed.
original_stem_conv = model_convnext_tiny.stem[0]
model_convnext_tiny.stem[0] = nn.Conv2d(
    in_channels=original_stem_conv.in_channels,
    out_channels=original_stem_conv.out_channels,
    kernel_size=2,
    stride=2,
    padding=0)

# Use the GPU when one is available; later cells reuse `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_convnext_tiny = model_convnext_tiny.to(device)

# Report the model size after the stem replacement.
n_params = count_trainable_parameters(model_convnext_tiny)
print(f"Trainable parameters: {n_params:,}")


Trainable parameters: 27,893,572


In [22]:
# Train ConvNeXt-Tiny for 100 epochs. The second return value of train_model
# is discarded; only the history is kept.
history, _ = train_model(
    model=model_convnext_tiny,
    train_loader=train_loader,
    epochs=100,
    val_loader=val_loader,
    device=device,

    # Optimiser settings
    lr=5e-4,
    weight_decay=0.05,

    # Mixed precision is only enabled on CUDA; on CPU fall back to fp32.
    autocast_dtype="fp16" if device == "cuda" else "fp32",
    use_amp=(device == "cuda"),
    grad_clip_norm=1.0,

    # Learning-rate schedule
    warmup_ratio=0.05,
    min_lr=1e-6,

    label_smoothing=0.1,
    # Checkpoint files are named after the model actually being trained.
    # They were previously "*_resnet18.pt" (copy-paste leftover), which could
    # overwrite a real ResNet-18 checkpoint in the same directory.
    save_path="best_model_convnext_tiny.pt",
    last_path="last_model_convnext_tiny.pt",

    print_every=400,
    # MixUp / CutMix augmentation settings
    mix_prob=0.5,
    mixup_alpha=0.8,
    cutmix_alpha=1.0,

    num_classes=100,
    channels_last=True)

=== Run config ===
device=cuda | amp=True | autocast_dtype=fp16 | channels_last=True
epochs=100 | steps/epoch=704 | total_steps=70400 | warmup_steps=3520
batch_size=64 | input_shape=(64, 3, 32, 32) | num_classes=100
opt=AdamW | lr=0.0005 | wd=0.05 | grad_clip_norm=1.0
aug: mix_prob=0.5 | mixup_alpha=0.8 | cutmix_alpha=1.0 | label_smoothing=0.1

=== Epoch 1/100 ===
[train step 400/704] loss 4.5984 | top1 2.67% | top3 7.17% | top5 10.84% | 1027.0 img/s | lr 5.68e-05 | gnorm 5.965 | clip 100.0% | oflow 0 | nonfinite 0 | scale 65536.0
[train step 704/704] loss 4.5072 | top1 3.51% | top3 9.04% | top5 13.44% | 1066.3 img/s | lr 1.00e-04 | gnorm 5.535 | clip 100.0% | oflow 0 | nonfinite 0 | scale 65536.0
[Train] loss 4.5072 | top1 3.51% | top3 9.04% | top5 13.44% | lr 1.00e-04 | grad_norm 5.535 | clip 100.0% | amp_overflows 0 | nonfinite_loss 0 | scale 65536.0
[Train] mem_peak alloc 0.78 GiB | reserved 0.88 GiB
[Val]   loss 4.2599 | top1 6.08% | top3 13.56% | top5 19.30%
[Val]   mem_peak allo

---

In [None]:
# Evaluate the in-memory model on the test split.
# Device and mixed-precision settings follow the `device` chosen when the
# model was created, so this cell also runs on a CPU-only machine instead of
# failing on a hard-coded "cuda".
avg_loss, metrics = evaluate_one_epoch(
    model=model_convnext_tiny,
    dataloader=test_loader,
    device=device,
    use_amp=(device == "cuda"),
    autocast_dtype="fp16" if device == "cuda" else "fp32",
    # NOTE(review): training ran with channels_last=True; confirm that
    # evaluating with channels_last=False is intentional.
    channels_last=False,
    measure_flops=True,
    flops_warmup_batches=1)

print(f"loss: {avg_loss:.4f}")
print(f"top1: {metrics['top1']:.2f} | top3: {metrics['top3']:.2f} | top5: {metrics['top5']:.2f}")

print(f"throughput: {metrics['imgs_per_sec']:.1f} imgs/s | epoch: {metrics['epoch_time_sec']:.2f}s | ms/batch: {metrics['ms_per_batch']:.2f}")
# GPU memory statistics are only reported when running on CUDA.
if device == "cuda":
    print(f"GPU mem: alloc={metrics['gpu_mem_allocated_mib']:.0f} MiB | reserved={metrics['gpu_mem_reserved_mib']:.0f} MiB | peak={metrics['gpu_mem_peak_allocated_mib']:.0f} MiB")
print(f"model: params={int(metrics['model_params']):,} | param_size={metrics['model_param_size_mib']:.1f} MiB")

# NOTE(review): the FLOPs / MACs values may be nan (they format as "nan"
# below) — presumably when FLOP measurement is unavailable; verify in
# evaluate_one_epoch.
print(f"FLOPs/forward: {metrics['flops_per_forward']:.3e} | MACs/forward: {metrics['macs_per_forward']:.3e}")

loss: 1.1613
top1: 72.60 | top3: 86.71 | top5: 90.92
throughput: 1832.8 imgs/s | epoch: 5.46s | ms/batch: 34.26
GPU mem: alloc=956 MiB | reserved=3232 MiB | peak=1125 MiB
model: params=27,895,012 | param_size=106.4 MiB
FLOPs/forward: nan | MACs/forward: nan
