In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

train_tf = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])


val_tf = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])


In [3]:
from torchvision.datasets import ImageFolder

train_ds = ImageFolder("dataset/train", transform=train_tf)
val_ds   = ImageFolder("dataset/valid", transform=val_tf)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=32, shuffle=False)
print(train_ds.classes)

['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Sp

In [4]:
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x)
        x = x.flatten(1)
        return self.fc(x)


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model = CNN(num_classes=len(train_ds.classes)).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


cuda


In [6]:
for epoch in range(10):
    # ---- TRAIN ----
    model.train()
    train_loss = 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # ---- VALIDATION ----
    model.eval()
    val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)

            val_loss += loss.item()
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    val_acc = correct / total

    print(
        f"Epoch {epoch+1} | "
        f"train_loss={train_loss/len(train_loader):.4f} | "
        f"val_loss={val_loss/len(val_loader):.4f} | "
        f"val_acc={val_acc:.3f}"
    )


Epoch 1 | train_loss=1.8938 | val_loss=1.3466 | val_acc=0.595
Epoch 2 | train_loss=1.1335 | val_loss=1.0579 | val_acc=0.685
Epoch 3 | train_loss=0.9113 | val_loss=0.8552 | val_acc=0.747
Epoch 4 | train_loss=0.7838 | val_loss=0.7735 | val_acc=0.770
Epoch 5 | train_loss=0.6948 | val_loss=0.6821 | val_acc=0.796
Epoch 6 | train_loss=0.6211 | val_loss=0.5890 | val_acc=0.828
Epoch 7 | train_loss=0.5603 | val_loss=0.5662 | val_acc=0.829
Epoch 8 | train_loss=0.5091 | val_loss=0.5170 | val_acc=0.845
Epoch 9 | train_loss=0.4710 | val_loss=0.4563 | val_acc=0.867
Epoch 10 | train_loss=0.4413 | val_loss=0.4500 | val_acc=0.867


In [8]:
torch.save(model, "planNN.pkl")
torch.save(model, "planNN.pth")