In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# --- Configuration: every value a reader might want to tune lives here. ---
IMAGE_SIZE = 224                 # side length of the square crops fed to the network
BATCH_SIZE = 16
NUM_WORKERS = 2                  # DataLoader worker processes
MEAN = [0.485, 0.456, 0.406]     # standard ImageNet per-channel mean (RGB)
STD  = [0.229, 0.224, 0.225]     # standard ImageNet per-channel std (RGB)
LEARNING_RATE = 1e-3
NUM_EPOCHS = 30
SEED = 42

# Seed torch's global RNG so loader shuffling, torchvision random augmentations
# and freshly initialised layers repeat across runs. Bitwise-identical GPU
# results would additionally need deterministic cuDNN settings.
torch.manual_seed(SEED)

In [3]:
# Training pipeline: random crop / flips / rotation / colour jitter add variety,
# then tensors are normalised with MEAN/STD (the standard ImageNet statistics,
# matching the ImageNet-pretrained backbone).
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

# Evaluation pipeline: deterministic, and producing IMAGE_SIZE x IMAGE_SIZE
# inputs like training. Resize the shorter side FIRST, then take the centre
# crop. Cropping first would grab a fixed-pixel patch of the raw image and
# then upscale it to a larger size than the training crops.
# Keep the conventional 256/224 resize-to-crop ratio so changing IMAGE_SIZE
# never asks CenterCrop for more pixels than the resize produced.
val_resize = round(IMAGE_SIZE * 256 / 224)
val_transforms = transforms.Compose([
    transforms.Resize(val_resize),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

In [4]:
# NOTE(review): 'Train' vs 'val' capitalisation differs. Confirm both match
# the on-disk folder names (this matters on case-sensitive filesystems).
train_dataset_path = '../../../dataset/Train'
val_dataset_path = '../../../dataset/val'

train_dataset = datasets.ImageFolder(train_dataset_path, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dataset_path, transform=val_transforms)

# ImageFolder assigns label indices from each split's own sub-folder list.
# If the splits' class sets differ, labels silently mean different things.
assert val_dataset.classes == train_dataset.classes, (
    f"class mismatch: train={train_dataset.classes} val={val_dataset.classes}"
)

# Pinned host memory only helps (and only avoids a warning) when copying to a GPU.
use_pinned_memory = torch.cuda.is_available()

# Shuffle training batches: ImageFolder yields samples grouped by class, so an
# unshuffled loader would feed the optimizer single-class runs every epoch.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=use_pinned_memory)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=use_pinned_memory)

print(f"Classes detected: {train_dataset.classes}")

Classes detected: ['algal_spot', 'brown_blight', 'gray_blight', 'healthy', 'helopeltis', 'red-rust', 'red-spider-infested', 'red_spot', 'white-spot']


In [5]:
# DenseNet-169 backbone initialised from ImageNet-pretrained weights.
model = models.densenet169(weights=models.DenseNet169_Weights.IMAGENET1K_V1)

# The pretrained head is a 1000-way ImageNet classifier. Replace it with a
# linear layer sized to this dataset's classes so the logits (and softmax) only
# cover real labels.
num_classes = len(train_dataset.classes)
model.classifier = nn.Linear(model.classifier.in_features, num_classes)

# Move the model (including the new head) to the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [6]:
# Adam over all of the network's parameters (full fine-tuning, not a frozen backbone).
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Multi-class classification loss computed on the raw logits.
criterion = nn.CrossEntropyLoss()

In [7]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train; switched to training mode here.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen this epoch.
        Raises ZeroDivisionError if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        batch_size = batch_images.size(0)
        # Weight by batch size so the epoch figure is a per-sample mean
        # (assumes the criterion averages over the batch, CrossEntropyLoss's default).
        loss_sum += batch_loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
        n_seen += batch_size

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

In [8]:
def validate(model, loader, device, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: network to evaluate; switched to eval mode here and left there.
        loader: iterable yielding ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.
        criterion: loss function called as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy)`` over every sample in ``loader``.
        Raises ZeroDivisionError if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_size = batch_images.size(0)

            # Weight by batch size so the result is a per-sample mean
            # (assumes the criterion averages over the batch, CrossEntropyLoss's default).
            loss_sum += criterion(logits, batch_labels).item() * batch_size
            n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
            n_seen += batch_size

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

In [9]:
# Train for NUM_EPOCHS, validating after each epoch and checkpointing the
# weights whenever the validation loss reaches a new minimum.
checkpoint_path = "best_densenet169.pth"
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, device, criterion)

    # Report validation accuracy (previously the training accuracy was printed here by mistake).
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: "
          f"Train loss {train_loss:.4f}, acc {train_acc:.4f} | "
          f"Val loss {val_loss:.4f}, acc {val_acc:.4f}"
          )

    # Strict improvement only: ties keep the earlier (already saved) checkpoint.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), checkpoint_path)

KeyboardInterrupt: 