In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),  
    transforms.ToTensor(),      
])

### PlantVillage

In [4]:
dataset_dir = "W:/PlantVillage"
dataset = datasets.ImageFolder(root=dataset_dir, transform=transform)

In [5]:
train_indices, test_indices = train_test_split(range(len(dataset)), test_size=0.2, random_state=42, stratify=dataset.targets)

train_subset = Subset(dataset, train_indices)
test_subset = Subset(dataset, test_indices)

In [6]:
train_dataloader = DataLoader(train_subset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_subset, batch_size=64, shuffle=False)

### CNN

In [7]:
class CNN(nn.Module):
    def __init__(self, num_classes=5):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        
        # Global average pooling instead of large fully connected layers
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(0.5)
        
        # Smaller fully connected layers
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 128x128 -> 64x64
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 64x64 -> 32x32
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 32x32 -> 16x16
        x = self.pool(F.relu(self.bn4(self.conv4(x))))  # 16x16 -> 8x8
        
        # Global average pooling (reduces to 256x1x1)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        
        # Fully connected layers with dropout
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        
        return x

In [8]:
model = CNN().to(device)
print(model)

CNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (global_pool): AdaptiveAvgPool2d(output_size=1)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=5, bias=True)
)


In [9]:
learning_rate = 0.001
batch_size = 64
# momentum = 0.9
epochs = 60
patience = 10

In [10]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    acc = f"{(100*correct):>0.1f}"
    print(f"Test Error: \n Accuracy: {acc}%, Avg loss: {test_loss:>8f} \n")

    return test_loss, correct, acc

In [11]:
loss_fn = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [12]:
# Early stopping variables
best_loss = float('inf')
trigger_times = 0
best_model_weights = model.state_dict()
best_epoch = 0
best_acc = ""

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loss, correct, acc = test_loop(test_dataloader, model, loss_fn)

    if test_loss < best_loss:
        best_loss = test_loss
        trigger_times = 0
        best_model_weights = model.state_dict()
        best_epoch = t+1
        best_acc = acc
        # torch.save(best_model_weights, 'best_model.pth')
        # print(f"New best model saved at epoch {t+1}.")
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print(f"Early stopping at epoch {t+1}. No improvement in {patience} epochs.")
            break

# model.load_state_dict(torch.load('best_model.pth')) # load best model weights
print(f"Best epoch at {best_epoch} with accuracy {best_acc} and loss {best_loss}")

Epoch 1
-------------------------------
loss: 1.739413  [   64/ 5600]
Test Error: 
 Accuracy: 77.6%, Avg loss: 0.580688 

Epoch 2
-------------------------------
loss: 0.327140  [   64/ 5600]
Test Error: 
 Accuracy: 85.2%, Avg loss: 0.358371 

Epoch 3
-------------------------------
loss: 0.200563  [   64/ 5600]
Test Error: 
 Accuracy: 71.1%, Avg loss: 1.054172 

Epoch 4
-------------------------------
loss: 0.213917  [   64/ 5600]
Test Error: 
 Accuracy: 80.7%, Avg loss: 0.561026 

Epoch 5
-------------------------------
loss: 0.278625  [   64/ 5600]
Test Error: 
 Accuracy: 92.8%, Avg loss: 0.184925 

Epoch 6
-------------------------------
loss: 0.243889  [   64/ 5600]
Test Error: 
 Accuracy: 77.4%, Avg loss: 0.950621 

Epoch 7
-------------------------------
loss: 0.193263  [   64/ 5600]
Test Error: 
 Accuracy: 94.1%, Avg loss: 0.146496 

Epoch 8
-------------------------------
loss: 0.052567  [   64/ 5600]
Test Error: 
 Accuracy: 83.2%, Avg loss: 0.417445 

Epoch 9
----------------

### PlantDoc

In [13]:
dataset_dir = "W:/PlantDoc"
dataset = datasets.ImageFolder(root=dataset_dir, transform=transform)

In [14]:
train_indices, test_indices = train_test_split(range(len(dataset)), test_size=0.2, random_state=42, stratify=dataset.targets)

train_subset = Subset(dataset, train_indices)
test_subset = Subset(dataset, test_indices)

In [15]:
train_dataloader = DataLoader(train_subset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_subset, batch_size=64, shuffle=False)

In [16]:
# Early stopping variables
best_loss = float('inf')
trigger_times = 0
best_model_weights = model.state_dict()
best_epoch = 0
best_acc = ""

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loss, correct, acc = test_loop(test_dataloader, model, loss_fn)

    if test_loss < best_loss:
        best_loss = test_loss
        trigger_times = 0
        best_model_weights = model.state_dict()
        best_epoch = t+1
        best_acc = acc
        # torch.save(best_model_weights, 'best_model.pth')
        # print(f"New best model saved at epoch {t+1}.")
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print(f"Early stopping at epoch {t+1}. No improvement in {patience} epochs.")
            break

# model.load_state_dict(torch.load('best_model.pth')) # load best model weights
print(f"Best epoch at {best_epoch} with accuracy {best_acc} and loss {best_loss}")

Epoch 1
-------------------------------
loss: 6.593884  [   64/  304]
Test Error: 
 Accuracy: 32.9%, Avg loss: 8.682895 

Epoch 2
-------------------------------
loss: 2.289258  [   64/  304]
Test Error: 
 Accuracy: 25.0%, Avg loss: 2.581922 

Epoch 3
-------------------------------
loss: 1.445219  [   64/  304]
Test Error: 
 Accuracy: 26.3%, Avg loss: 1.619481 

Epoch 4
-------------------------------
loss: 1.472059  [   64/  304]
Test Error: 
 Accuracy: 30.3%, Avg loss: 1.602612 

Epoch 5
-------------------------------
loss: 1.458898  [   64/  304]
Test Error: 
 Accuracy: 31.6%, Avg loss: 1.552876 

Epoch 6
-------------------------------
loss: 1.379739  [   64/  304]
Test Error: 
 Accuracy: 27.6%, Avg loss: 1.546756 

Epoch 7
-------------------------------
loss: 1.384405  [   64/  304]
Test Error: 
 Accuracy: 27.6%, Avg loss: 1.531502 

Epoch 8
-------------------------------
loss: 1.459602  [   64/  304]
Test Error: 
 Accuracy: 23.7%, Avg loss: 1.506362 

Epoch 9
----------------