In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SmallCNN(nn.Module):
    """
    Compact four-stage convolutional classifier.

    Each stage is (conv3x3 -> BN -> ReLU) x 2 followed by a 2x2 max-pool and a
    dropout layer; channel widths grow 3 -> 32 -> 64 -> 128 -> 256. Global
    average pooling collapses the spatial grid before a single linear layer.

    Input:  [B, 3, 224, 224]
    Output: [B, num_classes] raw logits (2 by default)
    """
    def __init__(self, num_classes: int = 2, p_drop: float = 0.2):
        super().__init__()

        def conv_bn_relu(in_ch, out_ch):
            # 3x3 conv with padding=1 keeps H and W; bias is dropped because
            # the BatchNorm right after it supplies its own shift.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        def stage(in_ch, out_ch):
            # Two conv units, then halve the spatial resolution.
            layers = conv_bn_relu(in_ch, out_ch) + conv_bn_relu(out_ch, out_ch)
            layers.append(nn.MaxPool2d(2))
            return nn.Sequential(*layers)

        # For a 224x224 input the grid shrinks 224 -> 112 -> 56 -> 28 -> 14.
        widths = (3, 32, 64, 128, 256)
        stack = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stack.append(stage(in_ch, out_ch))
            stack.append(nn.Dropout(p_drop))
        self.features = nn.Sequential(*stack)

        # Global average pooling down to a 1x1 map per channel.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(widths[-1], num_classes)

        # Kaiming-normal weights for every conv; identity affine for every BN.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map an image batch to per-class logits."""
        feature_maps = self.features(x)        # [B, 256, H/16, W/16]
        pooled = self.gap(feature_maps)        # [B, 256, 1, 1]
        flat = torch.flatten(pooled, 1)        # [B, 256]
        return self.classifier(flat)           # [B, num_classes]

In [6]:
import torch, os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision import transforms

# from models.cnn import SmallCNN

# Seed torch's global RNG so weight init, shuffling order, dropout masks and the
# random flips are repeatable from a fresh kernel (GPU kernels may still add
# small nondeterminism).
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pinned (page-locked) host memory only helps host->GPU copies. On a CPU-only
# machine it buys nothing and makes DataLoader emit a warning, so enable it
# only when training on CUDA.
use_pinned_memory = device.type == 'cuda'

# Preprocessing must match what the FastAPI service applies at inference time:
# resize to 224x224, ToTensor (-> [0, 1]), then Normalize with mean=std=0.5,
# which maps every channel to [-1, 1].
train_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),   # augmentation: training split only
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

val_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Directory layout:
# data/
#   train/
#     dog/*.jpg
#     cat/*.jpg
#   val/
#     dog/*.jpg
#     cat/*.jpg
#
# ImageFolder assigns label indices by sorted folder name, so with this layout
# index 0 = 'cat' and index 1 = 'dog'. Inference code must use the same order.
train_ds = datasets.ImageFolder('data/train', transform=train_tf)
val_ds   = datasets.ImageFolder('data/val',   transform=val_tf)

# Validation accuracy is meaningless if the two splits map labels differently.
assert train_ds.classes == val_ds.classes, (
    f"train/val class folders differ: {train_ds.classes} vs {val_ds.classes}"
)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4,
                          pin_memory=use_pinned_memory)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False, num_workers=4,
                          pin_memory=use_pinned_memory)

model = SmallCNN(num_classes=2).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

# Train for a fixed number of epochs, validating after each one and
# checkpointing the weights whenever validation accuracy improves.
best_acc, epochs = 0.0, 20
# Create the checkpoint directory once up front instead of on every improvement.
os.makedirs('pts', exist_ok=True)
for epoch in range(1, epochs+1):
    model.train()
    running = 0.0   # sum of per-sample losses over the epoch
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight by batch size so the epoch
        # average stays correct when the last batch is smaller.
        running += loss.item() * x.size(0)

    # validation
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
    val_acc = correct / total
    print(f"Epoch {epoch:02d} | train_loss={running/len(train_ds):.4f} | val_acc={val_acc:.4f}")

    # keep best: persist the weights, not just the score, so the best epoch can
    # be restored later with model.load_state_dict(torch.load('pts/model.pt')).
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'pts/model.pt')
print(f"Best val_acc = {best_acc:.4f}")



Epoch 01 | train_loss=0.7680 | val_acc=0.5000
Epoch 02 | train_loss=0.5878 | val_acc=0.6000
Epoch 03 | train_loss=0.5763 | val_acc=0.5000
Epoch 04 | train_loss=0.5178 | val_acc=0.5000
Epoch 05 | train_loss=0.4887 | val_acc=0.5000
Epoch 06 | train_loss=0.4563 | val_acc=0.5000
Epoch 07 | train_loss=0.4105 | val_acc=0.5000
Epoch 08 | train_loss=0.3838 | val_acc=0.5000


KeyboardInterrupt: 

In [5]:
!pip install onnxscript
!pip install onnxruntime

Collecting onnxscript
  Downloading onnxscript-0.5.6-py3-none-any.whl.metadata (13 kB)
Collecting ml_dtypes (from onnxscript)
  Downloading ml_dtypes-0.5.3-cp313-cp313-macosx_10_13_universal2.whl.metadata (8.9 kB)
Collecting onnx_ir<2,>=0.1.12 (from onnxscript)
  Downloading onnx_ir-0.1.12-py3-none-any.whl.metadata (3.2 kB)
Collecting onnx>=1.16 (from onnxscript)
  Downloading onnx-1.19.1-cp313-cp313-macosx_12_0_universal2.whl.metadata (7.0 kB)
Collecting protobuf>=4.25.1 (from onnx>=1.16->onnxscript)
  Using cached protobuf-6.33.0-cp39-abi3-macosx_10_9_universal2.whl.metadata (593 bytes)
Downloading onnxscript-0.5.6-py3-none-any.whl (683 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m683.0/683.0 kB[0m [31m8.2 MB/s[0m  [33m0:00:00[0m
[?25hDownloading onnx_ir-0.1.12-py3-none-any.whl (129 kB)
Downloading ml_dtypes-0.5.3-cp313-cp313-macosx_10_13_universal2.whl (663 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m663.8/663.8 kB[0m [31m15.7 MB/s

In [None]:
import torch.onnx

# Export the in-memory `model` to ONNX using one 224x224 RGB example input.
os.makedirs('pts', exist_ok=True)   # output directory may not exist yet

# Export inference behaviour (dropout off, BatchNorm using running statistics).
# An interrupted training run can leave the model in train mode.
model.eval()

# Build the example on the same device as the weights; feeding a CPU tensor to
# a model that lives on CUDA would fail during export.
dummy_input = torch.randn(1, 3, 224, 224, device=next(model.parameters()).device)
torch.onnx.export(model, dummy_input, "./pts/model_from_torch.onnx")