# lets train the model


## we have used the dataset from the kaggle 
link:- `https://www.kaggle.com/datasets/abdallahalidev/plantvillage-dataset/data`



In [2]:
import torch
print(torch.__version__)
print(torch.cuda.is_available())        # Should print: True
print(torch.cuda.get_device_name(0))    # Should print: NVIDIA GeForce RTX 3050


2.5.1+cu121
True
NVIDIA GeForce RTX 3050 6GB Laptop GPU


In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split


In [8]:
#. Device Config + Dataset Path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

DATASET_PATH = '../plantvillagedataset/color'


Using: cuda


In [9]:

#  Transform and Dataset Loading

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

full_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
num_classes = len(full_dataset.classes)

print(f"Total images: {len(full_dataset)}")
print(f"Classes: {full_dataset.classes}")


Total images: 54305
Classes: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spid

In [10]:
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)


In [11]:
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)



In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [13]:
EPOCHS = 5  # change to 10 or 15 later if needed

for epoch in range(EPOCHS):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_acc = 100 * correct / total
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {running_loss:.4f}, Train Accuracy: {train_acc:.2f}%")


Epoch [1/5], Loss: 389.5754, Train Accuracy: 91.02%
Epoch [2/5], Loss: 163.9958, Train Accuracy: 96.11%
Epoch [3/5], Loss: 116.1546, Train Accuracy: 97.15%
Epoch [4/5], Loss: 96.3594, Train Accuracy: 97.74%
Epoch [5/5], Loss: 83.0326, Train Accuracy: 98.08%


In [14]:
print("Device being used:", device)


Device being used: cuda


In [15]:
model.eval()
val_correct, val_total = 0, 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        val_total += labels.size(0)
        val_correct += (predicted == labels).sum().item()

val_acc = 100 * val_correct / val_total
print(f"\n✅ Validation Accuracy: {val_acc:.2f}%")



✅ Validation Accuracy: 97.54%


In [16]:
from sklearn.metrics import classification_report

all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

print(classification_report(all_labels, all_preds, target_names=full_dataset.classes))


                                                    precision    recall  f1-score   support

                                Apple___Apple_scab       0.94      0.98      0.96       118
                                 Apple___Black_rot       0.97      0.98      0.97       118
                          Apple___Cedar_apple_rust       0.98      1.00      0.99        59
                                   Apple___healthy       0.99      0.99      0.99       317
                               Blueberry___healthy       0.99      1.00      1.00       306
          Cherry_(including_sour)___Powdery_mildew       1.00      1.00      1.00       209
                 Cherry_(including_sour)___healthy       0.97      0.99      0.98       149
Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot       0.87      0.94      0.90        85
                       Corn_(maize)___Common_rust_       0.99      1.00      0.99       241
               Corn_(maize)___Northern_Leaf_Blight       0.98      0.94      0.

In [17]:

# for saving
if val_acc >= 90:  # You can set your own threshold
    torch.save(model.state_dict(), 'plant_disease_model_final.pth')
    print("✅ Model saved successfully!")
else:
    print("❌ Accuracy not good enough. Try improving the model.")


✅ Model saved successfully!
