In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SmallCNN(nn.Module):
    """
    Input:  [B, 3, 224, 224]
    Output: [B, 2] (logits for ['dog','cat'])
    """
    def __init__(self, num_classes: int = 2, p_drop: float = 0.2):
        super().__init__()
        def block(cin, cout):
            return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
                nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2)  # downsample by 2
            )

        self.features = nn.Sequential(
            block(3,   32),   # 224 -> 112
            nn.Dropout(p_drop),
            block(32,  64),   # 112 -> 56
            nn.Dropout(p_drop),
            block(64, 128),   # 56  -> 28
            nn.Dropout(p_drop),
            block(128, 256),  # 28  -> 14
            nn.Dropout(p_drop),
        )
        # global average pooling to 1×1
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(256, num_classes)

        # Kaiming init for convs
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.features(x)          # [B, 256, 14, 14]
        x = self.gap(x)               # [B, 256, 1, 1]
        x = torch.flatten(x, 1)       # [B, 256]
        x = self.classifier(x)        # [B, 2] (logits)
        return x

In [5]:
import torch, os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision import transforms

# from models.cnn import SmallCNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Match FastAPI preprocessing: ToTensor -> [-1, 1]

train_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

val_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Directory layout:
# data/
#   train/
#     dog/*.jpg
#     cat/*.jpg
#   val/
#     dog/*.jpg
#     cat/*.jpg
train_ds = datasets.ImageFolder('data/train', transform=train_tf)
val_ds   = datasets.ImageFolder('data/val',   transform=val_tf)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

model = SmallCNN(num_classes=2).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

best_acc, epochs = 0.0, 20
for epoch in range(1, epochs+1):
    model.train()
    running = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running += loss.item() * x.size(0)

    # validation
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
    val_acc = correct / total
    print(f"Epoch {epoch:02d} | train_loss={running/len(train_ds):.4f} | val_acc={val_acc:.4f}")

    # keep best
    if val_acc > best_acc:
        best_acc = val_acc
        os.makedirs('pts', exist_ok=True)
        # torch.save(model.state_dict(), 'pts/model.pt') 
        # torch.save(model, './pts/model.pt')
print(f"Best val_acc = {best_acc:.4f}")



Epoch 01 | train_loss=0.7056 | val_acc=0.5000
Epoch 02 | train_loss=0.5527 | val_acc=0.6000
Epoch 03 | train_loss=0.5290 | val_acc=0.5000
Epoch 04 | train_loss=0.5125 | val_acc=0.5000
Epoch 05 | train_loss=0.4686 | val_acc=0.5000
Epoch 06 | train_loss=0.4505 | val_acc=0.5000
Epoch 07 | train_loss=0.3688 | val_acc=0.5000
Epoch 08 | train_loss=0.3661 | val_acc=0.5000
Epoch 09 | train_loss=0.3525 | val_acc=0.5000
Epoch 10 | train_loss=0.3184 | val_acc=0.5000
Epoch 11 | train_loss=0.3839 | val_acc=0.5000
Epoch 12 | train_loss=0.3422 | val_acc=0.5000
Epoch 13 | train_loss=0.3023 | val_acc=0.5000
Epoch 14 | train_loss=0.2680 | val_acc=0.6000
Epoch 15 | train_loss=0.2188 | val_acc=0.6000
Epoch 16 | train_loss=0.2910 | val_acc=0.6000
Epoch 17 | train_loss=0.2999 | val_acc=0.6000
Epoch 18 | train_loss=0.1917 | val_acc=0.6000
Epoch 19 | train_loss=0.1980 | val_acc=0.6000
Epoch 20 | train_loss=0.2278 | val_acc=0.6000
Best val_acc = 0.6000


In [6]:
!pip install onnxscript
!pip install onnxruntime



In [7]:
import torch.onnx

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "./pts/model_from_torch.onnx")

[torch.onnx] Obtain model graph for `SmallCNN([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `SmallCNN([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 16 of general pattern rewrite rules.


ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'': 20},
            producer_name='pytorch',
            producer_version='2.9.0',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT,[1,3,224,224]>
            ),
            outputs=(
                %"linear"<FLOAT,[1,2]>
            ),
            initializers=(
                %"features.0.0.weight"<FLOAT,[32,3,3,3]>{Tensor(...)},
                %"features.0.3.weight"<FLOAT,[32,32,3,3]>{Tensor(...)},
                %"features.2.0.weight"<FLOAT,[64,32,3,3]>{Tensor(...)},
                %"features.2.3.weight"<FLOAT,[64,64,3,3]>{Tensor(...)},
                %"features.4.0.weight"<FLOAT,[128,64,3,3]>{Tensor(...)},
                %"features.4.3.weight"<FLOAT,[128,128,3,3]>{Tensor(...)},
                %"features.6.0.weight"<FLOAT,[256,128,3,3]>{Tensor(...)},
                %"features.6