In [None]:
from data.load_cifrar100 import *

# Build the CIFAR-100 loaders once; val_split=0.1 holds out a tenth of the
# training images for validation.
loaders = get_cifar100_dataloaders(
    batch_size=128,
    data_dir="./data/cifar100",
    num_workers=2,
    val_split=0.1,
)
train_loader, val_loader, test_loader = loaders

100%|██████████| 169M/169M [00:05<00:00, 29.6MB/s]


In [None]:
from data.data_utils import *

# Print a summary of the training loader (loader settings, first-batch tensor
# stats, running stats over at most 50 batches, and the label distribution).
describe_loader(train_loader, "train_loader", max_batches_for_stats=50)


TRAIN_LOADER SUMMARY
Dataset type        : Subset
  ↳ Wrapped dataset  : CIFAR100 (Subset-like)
  ↳ Subset size      : 45000
Num samples         : 45000
Batch size          : 128
Num workers         : 2
Pin memory          : True
Drop last           : False
Sampler             : RandomSampler
len(loader) (#batches): 352 (≈ ceil(45000/128) = 352)

First batch:
  x.shape           : (128, 3, 32, 32)
  y.shape           : (128,)
  x.dtype           : torch.float32
  y.dtype           : torch.int64
  x.min/max         : -3.3577 / 3.3945
  y.min/max         : 0 / 98
  unique labels (batch): 75

Quick stats over up to 50 batches:
  Approx mean        : -0.290280
  Approx std         : 1.133057
  Seen label counts  : 100 classes (in sampled batches)
  Top-5 labels       : [(35, 80), (63, 80), (73, 79), (96, 78), (44, 76)]

Full dataset label distribution:
  #classes detected  : 100
  min/max per class  : 436 / 463
  first 10 classes   : [(0, 457), (1, 439), (2, 448), (3, 455), (4, 446), (5, 

In [None]:
# Same summary for the validation loader, again sampling at most 50 batches
# for the running statistics.
describe_loader(val_loader, "val_loader", max_batches_for_stats=50)


VAL_LOADER SUMMARY
Dataset type        : Subset
  ↳ Wrapped dataset  : CIFAR100 (Subset-like)
  ↳ Subset size      : 5000
Num samples         : 5000
Batch size          : 128
Num workers         : 2
Pin memory          : True
Drop last           : False
Sampler             : SequentialSampler
len(loader) (#batches): 40 (≈ ceil(5000/128) = 40)

First batch:
  x.shape           : (128, 3, 32, 32)
  y.shape           : (128,)
  x.dtype           : torch.float32
  y.dtype           : torch.int64
  x.min/max         : -3.9538 / 3.9404
  y.min/max         : 0 / 99
  unique labels (batch): 75

Quick stats over up to 50 batches:
  Approx mean        : -0.277020
  Approx std         : 1.130208
  Seen label counts  : 100 classes (in sampled batches)
  Top-5 labels       : [(38, 64), (1, 61), (90, 61), (78, 60), (40, 60)]

Full dataset label distribution:
  #classes detected  : 100
  min/max per class  : 37 / 64
  first 10 classes   : [(0, 43), (1, 61), (2, 52), (3, 45), (4, 54), (5, 53), (6, 49

In [None]:
from model.max_vit_stem import *

# Smoke-check both stem variants on a CIFAR-sized batch and print the
# resulting feature-map shapes.
x = torch.randn(4, 3, 32, 32)

stem_a = MaxViTStem(StemConfig(stem_type="A", out_ch=64))
yA = stem_a(x)
print("A:", yA.shape)

stem_b = MaxViTStem(StemConfig(stem_type="B", out_ch=64))
yB = stem_b(x)
print("B:", yB.shape)

A: torch.Size([4, 64, 32, 32])
B: torch.Size([4, 64, 32, 32])


In [10]:
def test_stem_A_shape():
    """Stem variant "A" must map a (2, 3, 32, 32) batch to (2, 64, 32, 32)."""
    stem_cfg = StemConfig(stem_type="A", out_ch=64)
    batch = torch.randn(2, 3, 32, 32)
    features = MaxViTStem(stem_cfg)(batch)
    assert features.shape == (2, 64, 32, 32)


def test_stem_B_shape():
    """Stem variant "B" must map a (2, 3, 32, 32) batch to (2, 64, 32, 32)."""
    stem_cfg = StemConfig(stem_type="B", out_ch=64)
    batch = torch.randn(2, 3, 32, 32)
    features = MaxViTStem(stem_cfg)(batch)
    assert features.shape == (2, 64, 32, 32)

# Run the stem shape checks right away; a failed check raises AssertionError.
test_stem_A_shape()
test_stem_B_shape()

In [None]:
from model.downsample import *

# Default-configured downsample from 64 to 128 channels; print the output shape.
x = torch.randn(2, 64, 32, 32)
down = Downsample(64, 128)
y = down(x)
print(y.shape)

torch.Size([2, 128, 16, 16])


In [12]:
def test_downsample_conv_shape():
    """The "conv" downsample must turn (2, 64, 32, 32) into (2, 128, 16, 16)."""
    features = torch.randn(2, 64, 32, 32)
    layer = Downsample(64, 128, DownsampleConfig(kind="conv"))
    reduced = layer(features)
    assert reduced.shape == (2, 128, 16, 16)


def test_downsample_pool_shape():
    """The "pool" downsample must turn (2, 64, 32, 32) into (2, 128, 16, 16)."""
    features = torch.randn(2, 64, 32, 32)
    layer = Downsample(64, 128, DownsampleConfig(kind="pool"))
    reduced = layer(features)
    assert reduced.shape == (2, 128, 16, 16)

# Run both downsample shape checks; a failed check raises AssertionError.
test_downsample_conv_shape()
test_downsample_pool_shape()

In [None]:
from model.MBConv import *

x = torch.randn(2, 64, 32, 32)

# Stride 1 with equal in/out channels, drop_path set to 0.1.
cfg_stride1 = MBConvConfig(expand_ratio=4.0, se_ratio=0.25, drop_path=0.1)
b1 = MBConv(64, 64, stride=1, cfg=cfg_stride1)
y1 = b1(x)
print("stride1:", y1.shape)

# Stride 2 with 64 -> 128 channels.
cfg_stride2 = MBConvConfig(expand_ratio=4.0, se_ratio=0.25)
b2 = MBConv(64, 128, stride=2, cfg=cfg_stride2)
y2 = b2(x)
print("stride2:", y2.shape)

stride1: torch.Size([2, 64, 32, 32])
stride2: torch.Size([2, 128, 16, 16])


In [14]:

def test_mbconv_stride1_residual_shape():
    """MBConv with stride 1 and 64 -> 64 channels must keep the (2, 64, 32, 32) shape."""
    x = torch.randn(2, 64, 32, 32)
    m = MBConv(64, 64, stride=1, cfg=MBConvConfig(expand_ratio=4.0, se_ratio=0.25, drop_path=0.0))
    y = m(x)
    assert y.shape == (2, 64, 32, 32)


def test_mbconv_stride2_downsample_shape():
    """MBConv with stride 2 and 64 -> 128 channels must produce (2, 128, 16, 16)."""
    x = torch.randn(2, 64, 32, 32)
    m = MBConv(64, 128, stride=2, cfg=MBConvConfig(expand_ratio=4.0, se_ratio=0.25))
    y = m(x)
    assert y.shape == (2, 128, 16, 16)


def test_mbconv_no_expand():
    """MBConv with expand_ratio=1.0 must still keep the (2, 64, 32, 32) shape."""
    x = torch.randn(2, 64, 32, 32)
    m = MBConv(64, 64, stride=1, cfg=MBConvConfig(expand_ratio=1.0, se_ratio=0.25))
    y = m(x)
    assert y.shape == (2, 64, 32, 32)


# Run the checks when the cell executes, like the other test cells in this
# notebook; without these calls the functions are only defined and nothing
# is actually verified.
test_mbconv_stride1_residual_shape()
test_mbconv_stride2_downsample_shape()
test_mbconv_no_expand()

In [None]:
from model.window_partition import * 
from model.grid_partition import *

def test_window_roundtrip_exact():
    """window_partition followed by window_unpartition must return the input exactly."""
    torch.manual_seed(0)
    batch, height, width, channels = 2, 32, 32, 64
    window = 4
    original = torch.randn(batch, height, width, channels)

    tiles = window_partition(original, window)
    restored = window_unpartition(tiles, window, H=height, W=width, B=batch)

    # Pure reshuffling: demand bitwise equality, not just closeness.
    assert restored.shape == original.shape
    assert torch.equal(restored, original)


def test_grid_roundtrip_exact():
    """grid_partition followed by grid_unpartition must return the input exactly."""
    torch.manual_seed(0)
    batch, height, width, channels = 2, 32, 32, 64
    grid = 4
    original = torch.randn(batch, height, width, channels)

    cells, layout = grid_partition(original, grid)
    restored = grid_unpartition(cells, layout)

    # Pure reshuffling: demand bitwise equality, not just closeness.
    assert restored.shape == original.shape
    assert torch.equal(restored, original)


def test_window_invalid_divisibility_raises():
    """window_partition must raise ValueError when a spatial size is not a multiple of the window.

    The input has a second dimension of 30, which is not divisible by the
    window size 4.

    Raises:
        AssertionError: if window_partition returns without raising ValueError.
    """
    x = torch.randn(1, 30, 32, 8)
    try:
        window_partition(x, 4)
    except ValueError:
        return
    # Raised explicitly rather than via `assert False`, which is stripped
    # under `python -O` and would let this test pass without checking anything.
    raise AssertionError("Expected ValueError")


def test_grid_invalid_divisibility_raises():
    """grid_partition must raise ValueError when a spatial size is not a multiple of the grid.

    The input has a third dimension of 30, which is not divisible by the
    grid size 4.

    Raises:
        AssertionError: if grid_partition returns without raising ValueError.
    """
    x = torch.randn(1, 32, 30, 8)
    try:
        grid_partition(x, 4)
    except ValueError:
        return
    # Raised explicitly rather than via `assert False`, which is stripped
    # under `python -O` and would let this test pass without checking anything.
    raise AssertionError("Expected ValueError")

In [18]:
# Run the partition checks: the two exact round-trips first, then the two
# cases that must reject non-divisible spatial sizes.
test_window_roundtrip_exact()
test_grid_roundtrip_exact()
test_window_invalid_divisibility_raises()
test_grid_invalid_divisibility_raises()

In [None]:
from model.attention import *

def test_mhsa_shape():
    """Self-attention applied to an (8, 16, 64) tensor must return the same shape."""
    torch.manual_seed(0)
    batch, tokens, channels = 8, 16, 64
    sequence = torch.randn(batch, tokens, channels)
    layer = MultiHeadSelfAttention(AttentionConfig(dim=channels, num_heads=8))
    assert layer(sequence).shape == sequence.shape


def test_mhsa_backward():
    """Backward through attention (non-zero dropout rates configured) must populate the input gradient."""
    torch.manual_seed(0)
    batch, tokens, channels = 4, 49, 96
    sequence = torch.randn(batch, tokens, channels, requires_grad=True)
    attn_cfg = AttentionConfig(dim=channels, num_heads=8, attn_drop=0.1, proj_drop=0.1)
    layer = MultiHeadSelfAttention(attn_cfg)
    layer(sequence).sum().backward()
    assert sequence.grad is not None

# Run the attention shape and gradient checks; a failed check raises AssertionError.
test_mhsa_shape()
test_mhsa_backward()

In [None]:
from model.local_attention import *

def test_local_attention_window_shape():
    """LocalAttention2D in "window" mode must return a tensor shaped like its input."""
    torch.manual_seed(0)
    batch, height, width, channels = 2, 32, 32, 64
    feature_map = torch.randn(batch, height, width, channels)
    attn_cfg = LocalAttention2DConfig(mode="window", dim=channels, num_heads=8, window_size=4)
    layer = LocalAttention2D(attn_cfg)
    assert layer(feature_map).shape == feature_map.shape


def test_local_attention_grid_shape():
    """LocalAttention2D in "grid" mode must return a tensor shaped like its input."""
    torch.manual_seed(0)
    batch, height, width, channels = 2, 32, 32, 64
    feature_map = torch.randn(batch, height, width, channels)
    attn_cfg = LocalAttention2DConfig(mode="grid", dim=channels, num_heads=8, grid_size=4)
    layer = LocalAttention2D(attn_cfg)
    assert layer(feature_map).shape == feature_map.shape


def test_local_attention_backward():
    """Backward through window-mode LocalAttention2D must populate the input gradient."""
    torch.manual_seed(0)
    batch, height, width, channels = 2, 16, 16, 32
    feature_map = torch.randn(batch, height, width, channels, requires_grad=True)
    attn_cfg = LocalAttention2DConfig(mode="window", dim=channels, num_heads=4, window_size=4)
    layer = LocalAttention2D(attn_cfg)
    layer(feature_map).sum().backward()
    assert feature_map.grad is not None

# Run the local-attention checks: both partition modes, then the gradient check.
test_local_attention_window_shape()
test_local_attention_grid_shape()
test_local_attention_backward()

In [None]:
from model.MaxViT_block import *

def test_maxvit_block_shape():
    """A MaxViTBlock applied to a (2, 64, 32, 32) tensor must return the same shape."""
    torch.manual_seed(0)
    batch, channels, height, width = 2, 64, 32, 32
    feature_map = torch.randn(batch, channels, height, width)

    block_cfg = MaxViTBlockConfig(
        dim=channels,
        num_heads=8,
        window_size=4,
        grid_size=4,
        drop_path=0.1,
    )
    block = MaxViTBlock(block_cfg)

    assert block(feature_map).shape == feature_map.shape


def test_maxvit_block_backward():
    """Backward through a MaxViTBlock (drop_path=0.0) must populate the input gradient."""
    torch.manual_seed(0)
    batch, channels, height, width = 2, 32, 16, 16
    feature_map = torch.randn(batch, channels, height, width, requires_grad=True)

    block_cfg = MaxViTBlockConfig(
        dim=channels,
        num_heads=4,
        window_size=4,
        grid_size=4,
        drop_path=0.0,
    )
    block = MaxViTBlock(block_cfg)

    block(feature_map).sum().backward()
    assert feature_map.grad is not None

# Run the MaxViT block shape and gradient checks; a failed check raises AssertionError.
test_maxvit_block_shape()
test_maxvit_block_backward()

In [None]:
from model_configurations import *


# maxvit_cifar100_tiny() returns a MaxViTConfig, not a network, so bind it to
# `cfg` (the name the later cells use for configs) instead of `model`.
cfg = maxvit_cifar100_tiny()
print(cfg)

MaxViTConfig(num_classes=100, in_chans=3, stem_type='A', stem_out_ch=64, stem_act='silu', stem_use_bn=True, stem_mid_ch=None, dims=(64, 128, 256, 512), depths=(2, 2, 3, 2), heads=(2, 4, 8, 16), window_size=4, grid_size=4, drop_path_rate=0.1, attn_drop=0.0, proj_drop=0.0, ffn_drop=0.0, mbconv_expand_ratio=4.0, mbconv_se_ratio=0.25, mbconv_act='silu', use_bn=True, mlp_ratio=4.0, mlp_act='gelu', downsample_kind='conv', downsample_act='silu', downsample_use_bn=True)


In [None]:

def test_maxvit_forward_shape():
    """The tiny MaxViT (stem "A", drop_path_rate=0.0) must emit (2, 100) logits for two 32x32 RGB images."""
    torch.manual_seed(0)
    net = MaxViT(maxvit_cifar100_tiny(stem_type="A", drop_path_rate=0.0))

    images = torch.randn(2, 3, 32, 32)
    logits = net(images)
    assert logits.shape == (2, 100)


def test_maxvit_backward():
    """Backward through the tiny MaxViT (stem "B", drop_path_rate=0.1) must populate the input gradient."""
    torch.manual_seed(0)
    net = MaxViT(maxvit_cifar100_tiny(stem_type="B", drop_path_rate=0.1))

    images = torch.randn(2, 3, 32, 32, requires_grad=True)
    net(images).sum().backward()
    assert images.grad is not None

# Run the end-to-end model checks (logit shape, then gradient flow).
test_maxvit_forward_shape()
test_maxvit_backward()

In [None]:
from model.model_utils import *

# Build a tiny MaxViT (stem "A", drop_path_rate=0.0) for shape tracing.
cfg = maxvit_cifar100_tiny(stem_type="A", drop_path_rate=0.0)
model = MaxViT(cfg)
# Keep the returned handles: the next cell calls .remove() on each one after
# the traced forward pass.
hooks = attach_shape_hooks(model)


In [None]:
# One gradient-free forward pass with the shape hooks attached; the per-stage
# shapes are printed while the model runs.
x = torch.randn(2, 3, 32, 32)

try:
    with torch.no_grad():
        y = model(x)
finally:
    # Always detach the hooks, even if the forward pass raises, so a failed
    # run does not leave them attached to `model` and firing on later calls.
    for h in hooks:
        h.remove()

print("\nFinal logits:", y.shape)


stem                 -> (2, 64, 32, 32)
stage0               -> (2, 64, 32, 32)
down0                -> (2, 128, 16, 16)
stage1               -> (2, 128, 16, 16)
down1                -> (2, 256, 8, 8)
stage2               -> (2, 256, 8, 8)
down2                -> (2, 512, 4, 4)
stage3               -> (2, 512, 4, 4)
pool                 -> (2, 512, 1, 1)
head                 -> (2, 100)

Final logits: torch.Size([2, 100])


---

## Training 

In [None]:
from training.train_MaxViT import *

# Fresh tiny MaxViT (stem "A", drop_path_rate=0.1) for the actual training run.
cfg = maxvit_cifar100_tiny(stem_type="A", drop_path_rate=0.1)
model = MaxViT(cfg)

# A single flag decides the device and whether mixed precision is requested.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

history, model = train_model(
    model=model,
    train_loader=train_loader,
    epochs=20,
    val_loader=val_loader,
    device=device,
    # Optimisation settings.
    lr=5e-4,
    weight_decay=0.05,
    # fp16 autocast on GPU only; plain fp32 otherwise.
    autocast_dtype="fp16" if use_cuda else "fp32",
    use_amp=use_cuda,
    grad_clip_norm=1.0,
    # Learning-rate schedule bounds.
    warmup_ratio=0.05,
    min_lr=1e-6,
    label_smoothing=0.0,
    # Logging and checkpoint files.
    print_every=100,
    save_path="best_maxvit_cifar100.pt",
    last_path="last_maxvit_cifar100.pt",
    resume_path=None,
    # Weight EMA settings.
    ema_decay=0.999,
    ema_device=None,
    # Batch-mixing settings; presumably CutMix only, since mixup_alpha is 0.0
    # -- confirm against train_model.
    mix_prob=0.5,
    mixup_alpha=0.0,
    cutmix_alpha=1.0,
)



=== Epoch 1/20 ===
[train step 100/352] loss 4.5360 | top1 3.62% | top3 8.85% | top5 12.80% | 315.0 img/s | lr 1.42e-04
[train step 200/352] loss 4.3959 | top1 5.18% | top3 12.29% | top5 17.45% | 339.3 img/s | lr 2.84e-04
[train step 300/352] loss 4.2949 | top1 6.47% | top3 14.70% | top5 20.60% | 350.0 img/s | lr 4.26e-04
[Train] loss 4.2447 | top1 7.14% | top3 15.98% | top5 22.22%
[Val] loss 3.9775 | top1 10.70% | top3 22.84% | top5 30.76%
 Best saved to best_maxvit_cifar100.pt (val top1 10.70%)
Epoch time: 2.40 min

=== Epoch 2/20 ===
[train step 100/352] loss 3.7871 | top1 13.84% | top3 28.51% | top5 37.67% | 366.7 img/s | lr 5.00e-04
[train step 200/352] loss 3.7107 | top1 15.27% | top3 30.54% | top5 39.96% | 369.4 img/s | lr 4.99e-04
[train step 300/352] loss 3.6326 | top1 16.90% | top3 32.64% | top5 42.11% | 370.3 img/s | lr 4.98e-04
[Train] loss 3.5937 | top1 17.56% | top3 33.64% | top5 43.20%
[Val] loss 3.0162 | top1 24.52% | top3 44.54% | top5 56.02%
 Best saved to best_maxvi

In [48]:
# List the metric series that train_model returned in `history`.
history.keys()

dict_keys(['train_loss', 'train_top1', 'train_top3', 'train_top5', 'val_loss', 'val_top1', 'val_top3', 'val_top5'])