# Train ResNet for Pest Classification

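This template assumes your images live under `datasets/pest/images/{train,val}`, with one subfolder per class inside each split (the layout `torchvision.datasets.ImageFolder` expects). Both splits must contain the same set of class subfolders, because label indices are assigned from each split's sorted folder names independently.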


In [None]:
import os
from pathlib import Path

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

DATA_ROOT = Path('datasets/pest/images')
BATCH_SIZE = 32
EPOCHS = 5
LR = 1e-3
NUM_WORKERS = 2
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Per-channel ImageNet statistics; the pretrained ResNet-50 weights loaded
# below expect inputs normalized with these values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transforms: random resized crop + horizontal flip for augmentation.
train_tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Validation transforms: deterministic resize + center crop so scores are
# comparable from epoch to epoch.
val_tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_ds = datasets.ImageFolder(DATA_ROOT / 'train', transform=train_tfms)
val_ds = datasets.ImageFolder(DATA_ROOT / 'val', transform=val_tfms)

# ImageFolder derives label indices from each split's own sorted subfolder
# names. If val is missing a class (or has an extra one), every later label
# shifts and validation accuracy becomes meaningless -- fail fast instead.
if val_ds.class_to_idx != train_ds.class_to_idx:
    raise ValueError(
        'train/val class folders differ: '
        f'train={train_ds.classes} val={val_ds.classes}'
    )

num_classes = len(train_ds.classes)

# Pinned host memory speeds up host-to-GPU copies; it brings nothing on CPU.
PIN_MEMORY = DEVICE.type == 'cuda'
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# Start from pretrained ResNet-50 weights and replace the final fully
# connected layer with a freshly initialized head sized for our classes.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)



In [None]:
def evaluate(model, dataloader, *, loss_fn=None, device=None):
    """Return the per-sample mean loss and top-1 accuracy of ``model``.

    The model is switched to eval mode (and left in it). Batches run under
    ``torch.inference_mode()``, so no autograd graph is recorded.

    Args:
        model: Classifier mapping a batch of inputs to per-class logits.
        dataloader: Iterable of ``(inputs, targets)`` batches, where targets
            are integer class indices.
        loss_fn: Loss applied to ``(logits, targets)``. It is assumed to
            return the batch *mean*, as the default ``nn.CrossEntropyLoss()``
            does. Defaults to the module-level ``criterion``.
        device: Device each batch is moved to. Defaults to the module-level
            ``DEVICE``.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged over samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion
    if device is None:
        device = DEVICE

    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.inference_mode():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            batch_size = int(y.size(0))
            # loss_fn returns a batch mean; re-weight by batch size so a short
            # final batch does not skew the overall average.
            loss_sum += float(loss_fn(logits, y).item()) * batch_size
            correct += int((logits.argmax(dim=1) == y).sum().item())
            total += batch_size

    if total == 0:
        raise ValueError('evaluate() received a dataloader that yielded no samples')
    return loss_sum / total, correct / total

# Output locations for the best checkpoint and its matching class list.
MODEL_DIR = Path('backend/models')
CHECKPOINT_PATH = MODEL_DIR / 'resnet_best_state_dict.pkl'
CLASSES_PATH = MODEL_DIR / 'classes.txt'
# torch.save() and file writes do not create missing parent directories.
MODEL_DIR.mkdir(parents=True, exist_ok=True)

best_acc = 0.0
best_epoch = None  # stays None until the first checkpoint is written
for epoch in range(1, EPOCHS + 1):
    model.train()
    # Accumulate the summed loss on-device to avoid a host sync every step.
    running_loss = torch.zeros((), device=DEVICE)
    seen = 0
    for x, y in train_dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.detach() * y.size(0)
        seen += int(y.size(0))
    train_loss = running_loss.item() / max(seen, 1)

    val_loss, val_acc = evaluate(model, val_dl)
    print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f} val_acc={val_acc:.4f}")
    # Always save after the first epoch (even at 0.0 accuracy) so that a
    # checkpoint exists; afterwards only save on strict improvement.
    if best_epoch is None or val_acc > best_acc:
        best_acc, best_epoch = val_acc, epoch
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        # One class name per line, in label-index order, so line i names
        # output logit i.
        CLASSES_PATH.write_text(
            ''.join(f'{cls}\n' for cls in train_ds.classes), encoding='utf-8'
        )

print('Best Val Acc:', best_acc)

