In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
IMAGE_SIZE = 224
BATCH_SIZE = 16
NUM_WORKERS = 2
MEAN = [0.485, 0.456, 0.406]
STD  = [0.229, 0.224, 0.225]
LEARNING_RATE = 1e-3
NUM_EPOCHS = 30

In [3]:
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

val_transforms = transforms.Compose([
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

In [4]:
train_dataset_path = '../../../dataset/Train'
val_dataset_path = '../../../dataset/val'

train_dataset = datasets.ImageFolder(train_dataset_path, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dataset_path, transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print(f"Classes detected: {train_dataset.classes}")

Classes detected: ['algal_spot', 'brown_blight', 'gray_blight', 'healthy', 'helopeltis', 'red-rust', 'red-spider-infested', 'red_spot', 'white-spot']


In [5]:
model = models.densenet169(weights=models.DenseNet169_Weights.IMAGENET1K_V1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [7]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels, in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += images.size(0)
    return running_loss/total, correct/total

In [8]:
def validate(model, loader, device, criterion):
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)
    return val_loss/total, correct/total

In [9]:
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, device, criterion)
    
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: "
          f"Train loss {train_loss:.4f}, acc {train_acc:.4f} | "
          f"Val loss {val_loss:.4f}, acc {train_acc:.4f}"
          )
    
    if(val_loss < best_val_loss):
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_densenet169.pth")

Total batches: 290
Epoch 1/30: Train loss 1.0298, acc 0.6626 | Val   loss 0.5441, acc 0.8444
Epoch 2/30: Train loss 0.6166, acc 0.7912 | Val   loss 0.4156, acc 0.8727
Epoch 3/30: Train loss 0.5384, acc 0.8173 | Val   loss 0.3794, acc 0.8768
Epoch 4/30: Train loss 0.5196, acc 0.8150 | Val   loss 0.3574, acc 0.8798
Epoch 5/30: Train loss 0.4929, acc 0.8262 | Val   loss 0.3035, acc 0.8949
Epoch 6/30: Train loss 0.4733, acc 0.8325 | Val   loss 0.2977, acc 0.8980
Epoch 7/30: Train loss 0.4800, acc 0.8271 | Val   loss 0.2870, acc 0.9081
Epoch 8/30: Train loss 0.4675, acc 0.8411 | Val   loss 0.2796, acc 0.9040
Epoch 9/30: Train loss 0.4512, acc 0.8413 | Val   loss 0.2634, acc 0.9061
Epoch 10/30: Train loss 0.4617, acc 0.8331 | Val   loss 0.2728, acc 0.9051
Epoch 11/30: Train loss 0.4520, acc 0.8398 | Val   loss 0.2730, acc 0.9091
Epoch 12/30: Train loss 0.4438, acc 0.8470 | Val   loss 0.2556, acc 0.9071
Epoch 13/30: Train loss 0.4522, acc 0.8420 | Val   loss 0.2813, acc 0.8970
Epoch 14/30: Tr