In [18]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [19]:
IMAGE_SIZE = 224                 # square input resolution fed to the network
BATCH_SIZE = 16
NUM_WORKERS = 2                  # DataLoader worker processes
# Standard ImageNet RGB normalization statistics, matching the ImageNet-pretrained backbone.
MEAN = [0.485, 0.456, 0.406]
STD  = [0.229, 0.224, 0.225]

# Seed torch's global RNG so DataLoader shuffling and the random augmentations are
# repeatable across runs (bit-exact GPU results may additionally need cuDNN determinism).
SEED = 42
torch.manual_seed(SEED)

In [20]:
# Training-time augmentation: random crop/scale, flips, small rotations and
# brightness/contrast jitter, then conversion to a normalized C×H×W tensor.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),  # random crop (80-100% of area), resized to IMAGE_SIZE
    transforms.RandomHorizontalFlip(),                           # horizontal flip
    transforms.RandomVerticalFlip(),                             # vertical flip
    transforms.RandomRotation(15),                               # small rotations (±15°)
    transforms.ColorJitter(brightness=0.2, contrast=0.2),        # color jitter
    transforms.ToTensor(),                                       # to C×H×W tensor
    transforms.Normalize(MEAN, STD)                              # normalize
])

# Deterministic evaluation pipeline. Resize uses IMAGE_SIZE (not a literal) so that
# resize and crop stay in sync if IMAGE_SIZE is changed in the config cell.
val_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),              # shorter side -> IMAGE_SIZE
    transforms.CenterCrop(IMAGE_SIZE),          # crop to IMAGE_SIZE×IMAGE_SIZE
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

In [21]:
# Dataset root relative to this notebook; each split is read with ImageFolder,
# i.e. one sub-folder per class inside <root>/train/ and <root>/val/.
DATA_ROOT = "../../dataset-dapa"
train_dataset = datasets.ImageFolder(root=f"{DATA_ROOT}/train/", transform=train_transforms)
val_dataset   = datasets.ImageFolder(root=f"{DATA_ROOT}/val/",   transform=val_transforms)

# Each ImageFolder derives its label indices from its own class folders. If the two
# splits disagree, the same integer label would mean different classes in train
# and val, silently corrupting validation metrics.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    f"train/val class folders differ: {train_dataset.classes} vs {val_dataset.classes}"
)

# Pinned host memory only benefits host->GPU copies, so enable it only with CUDA.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,               # reshuffle training samples every epoch
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)


val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,              # fixed order for evaluation
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

In [22]:
if __name__ == "__main__":
    # Sanity-check one training batch. Inside a notebook __name__ is always
    # "__main__", so this always runs there; it is skipped only on module import.
    images, labels = next(iter(train_loader))
    print(f"Batch shape: {images.shape}")  # expected [BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE]
    print(f"Labels shape: {labels.shape}")  # expected [BATCH_SIZE]
    assert images.shape[1:] == (3, IMAGE_SIZE, IMAGE_SIZE), images.shape

# ImageNet-pretrained MobileNetV3-Small; its feature extractor is frozen so that
# only the newly attached head receives gradient updates.
model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)
backbone = model.features
# NOTE(review): requires_grad_(False) stops weight updates only. Any BatchNorm layers
# in the backbone still update their running statistics whenever model.train() is
# active. Put the backbone in eval() during training if truly frozen features are wanted.
backbone.requires_grad_(False)

# Size the head from the dataset's class folders instead of hard-coding it
# (expected: the nine tea disease classes).
num_classes = len(train_dataset.classes)
classifier = nn.Sequential(
    nn.Dropout(0.3),
    # Reuse the input width of the pretrained head's first layer so the new
    # Linear accepts the same features the original head did.
    nn.Linear(model.classifier[0].in_features, num_classes)
)
model.classifier = classifier

Batch shape: torch.Size([16, 3, 224, 224])
Labels shape: torch.Size([16])


In [None]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)  # nn.Module.to moves the parameters in place

# Multi-class loss computed on the raw logits produced by the head.
criterion = nn.CrossEntropyLoss()
# Only the freshly attached head is optimized, with a relatively high learning rate.
head_params = model.classifier.parameters()
optimizer = torch.optim.Adam(head_params, lr=1e-3)

In [24]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: the ``nn.Module`` being trained; it is switched to train mode here.
        loader: iterable of ``(images, labels)`` batches, e.g. a ``DataLoader``.
        optimizer: optimizer holding the parameters that should be updated.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to return
            the batch mean (the ``nn.CrossEntropyLoss`` default).
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(loss, acc)``. ``loss`` is the per-sample mean of the batch losses, with
        each batch weighted by its size and measured before that batch's update.
        ``acc`` is the fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples (instead of a bare
            ZeroDivisionError from the final averaging).
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        # non_blocking lets copies from pinned host memory overlap with compute.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        # Undo the criterion's batch-mean so the epoch loss is a true per-sample mean.
        running_loss += loss.item() * batch_size
        # assumes outputs are (batch, num_classes) logits — TODO confirm for other models
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return running_loss / total, correct / total

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean_loss, accuracy)``.

    Args:
        model: the ``nn.Module`` to evaluate; it is switched to eval mode here.
        loader: iterable of ``(images, labels)`` batches, e.g. a ``DataLoader``.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to return
            the batch mean (the ``nn.CrossEntropyLoss`` default).
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(loss, acc)``. ``loss`` is the per-sample mean loss, with each batch
        weighted by its size. ``acc`` is the fraction of samples whose arg-max
        prediction equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples (instead of a bare
            ZeroDivisionError from the final averaging).
    """
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            # non_blocking lets copies from pinned host memory overlap with compute.
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = images.size(0)
            # Undo the criterion's batch-mean so the result is a true per-sample mean.
            val_loss += loss.item() * batch_size
            # assumes outputs are (batch, num_classes) logits — TODO confirm for other models
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("validate: loader yielded no samples")
    return val_loss / total, correct / total


In [None]:
num_epochs = 30
best_val_loss = float('inf')
checkpoint_path = "best_mobilenetv3_teadiseases.pth"

# Alternate one training pass with one validation pass per epoch. The weights with
# the lowest validation loss are written to disk; `model` itself keeps the
# final-epoch weights when the loop ends.
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train loss {train_loss:.4f}, acc {train_acc:.4f} | "
          f"Val   loss {val_loss:.4f}, acc {val_acc:.4f}")

    improved = val_loss < best_val_loss
    if not improved:
        continue
    # New best on validation loss: remember it and checkpoint the weights.
    best_val_loss = val_loss
    torch.save(model.state_dict(), checkpoint_path)


Epoch 1/30: Train loss 1.2385, acc 0.6042 | Val   loss 0.7443, acc 0.7889
Epoch 2/30: Train loss 0.8084, acc 0.7391 | Val   loss 0.5687, acc 0.8313
Epoch 3/30: Train loss 0.7129, acc 0.7585 | Val   loss 0.5218, acc 0.8253
Epoch 4/30: Train loss 0.6464, acc 0.7812 | Val   loss 0.4665, acc 0.8465
Epoch 5/30: Train loss 0.6276, acc 0.7817 | Val   loss 0.4285, acc 0.8586
Epoch 6/30: Train loss 0.6010, acc 0.7877 | Val   loss 0.4396, acc 0.8485
Epoch 7/30: Train loss 0.6053, acc 0.7884 | Val   loss 0.4380, acc 0.8434
Epoch 8/30: Train loss 0.5847, acc 0.7940 | Val   loss 0.4111, acc 0.8475
Epoch 9/30: Train loss 0.5684, acc 0.8018 | Val   loss 0.4061, acc 0.8495
Epoch 10/30: Train loss 0.5552, acc 0.8046 | Val   loss 0.4007, acc 0.8515
Epoch 11/30: Train loss 0.5586, acc 0.8044 | Val   loss 0.3749, acc 0.8717
Epoch 12/30: Train loss 0.5412, acc 0.8085 | Val   loss 0.3674, acc 0.8707
Epoch 13/30: Train loss 0.5429, acc 0.8141 | Val   loss 0.3811, acc 0.8646
Epoch 14/30: Train loss 0.5566, ac