In [8]:
# Import necessary libraries
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torchvision.models import ResNet18_Weights


In [2]:
# Dynamic Class Fetching
def get_classes_from_dataset(dataset_path):
    plant_classes = []
    disease_classes = []
    
    for plant_folder in os.listdir(dataset_path):
        if os.path.isdir(os.path.join(dataset_path, plant_folder)):
            plant_classes.append(plant_folder)  # Add plant type (e.g., Potato, Tomato)
            for disease_folder in os.listdir(os.path.join(dataset_path, plant_folder)):
                if os.path.isdir(os.path.join(dataset_path, plant_folder, disease_folder)):
                    disease_classes.append(disease_folder)  # Add disease (e.g., Early_blight, Healthy)
    
    # Removing duplicates, since disease names might repeat across plants
    disease_classes = list(set(disease_classes))
    
    return plant_classes, disease_classes


# Set the path to your dataset
dataset_path = 'dataset'

# Get plant and disease classes dynamically from dataset
plant_classes, disease_classes = get_classes_from_dataset(dataset_path)

# Print out the classes for reference
print(f'Plant Classes: {plant_classes}')
print(f'Disease Classes: {disease_classes}')


Plant Classes: ['Potato', 'Tomato']
Disease Classes: ['Potato___Early_blight', 'Tomato___Bacterial_spot', 'Tomato___Late_blight', 'Potato___Late_blight', 'Tomato___Early_blight', 'Potato___healthy']


In [3]:
# Custom Dataset Class for loading images
class PlantDiseaseDataset(Dataset):
    def __init__(self, dataset_path, plant_classes, disease_classes, transform=None):
        self.dataset_path = dataset_path
        self.plant_classes = plant_classes
        self.disease_classes = disease_classes
        self.transform = transform
        self.image_paths = []
        self.labels = []
        
        # Load image paths and labels
        for plant_idx, plant in enumerate(plant_classes):
            for disease_idx, disease in enumerate(disease_classes):
                folder_path = os.path.join(dataset_path, plant, disease)
                if os.path.exists(folder_path):
                    for filename in os.listdir(folder_path):
                        if filename.endswith(".JPG"):
                            self.image_paths.append(os.path.join(folder_path, filename))
                            self.labels.append((plant_idx, disease_idx))
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        # Load image
        image = Image.open(image_path)
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        
        return image, label


In [4]:
# Define transformations for input images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [5]:
# Split dataset into training and validation sets
dataset = PlantDiseaseDataset(dataset_path, plant_classes, disease_classes, transform)

train_data, val_data = train_test_split(dataset, test_size=0.2, stratify=dataset.labels)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)


In [9]:
# Define the MultiOutputModel
class MultiOutputModel(nn.Module):
    def __init__(self, num_plants, num_diseases):
        super(MultiOutputModel, self).__init__()
        self.base_model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.base_model.fc = nn.Linear(self.base_model.fc.in_features, 512)
        self.plant_fc = nn.Linear(512, num_plants)
        self.disease_fc = nn.Linear(512, num_diseases)

    def forward(self, x):
        features = self.base_model(x)
        plant_output = self.plant_fc(features)
        disease_output = self.disease_fc(features)
        return plant_output, disease_output


In [10]:
# Initialize model, loss function, and optimizer
model = MultiOutputModel(len(plant_classes), len(disease_classes))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [13]:
# Train the model
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for inputs, labels in train_loader:
        # Assuming labels is already a tuple (plant_labels, disease_labels)
        plant_labels, disease_labels = labels
        plant_labels = torch.tensor(plant_labels)
        disease_labels = torch.tensor(disease_labels)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        plant_outputs, disease_outputs = model(inputs)

        # Calculate loss
        plant_loss = criterion(plant_outputs, plant_labels)
        disease_loss = criterion(disease_outputs, disease_labels)
        loss = plant_loss + disease_loss

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}')


  plant_labels = torch.tensor(plant_labels)
  disease_labels = torch.tensor(disease_labels)


Epoch [1/10], Loss: 0.7574359308396067
Epoch [2/10], Loss: 0.3867612212896347
Epoch [3/10], Loss: 0.41671855699803145
Epoch [4/10], Loss: 0.3766050634639604
Epoch [5/10], Loss: 0.41221226890172274
Epoch [6/10], Loss: 0.33398997145039694
Epoch [7/10], Loss: 0.3807034127946411
Epoch [8/10], Loss: 0.246326643275097
Epoch [9/10], Loss: 0.09768552472482302
Epoch [10/10], Loss: 0.3495805774283196


In [None]:
# Validate the model
model.eval()
correct_plant = 0                       //USE THIS NEXT TIME 
correct_disease = 0
total = 0

with torch.no_grad():
    total_batches = len(val_loader)  # Total number of batches
    for batch_idx, (inputs, labels) in enumerate(val_loader):
        # Print every 100th batch for progress monitoring
        if batch_idx % 100 == 0:
            remaining_batches = total_batches - batch_idx
            print(f"Processing batch {batch_idx + 1}/{total_batches} - {remaining_batches} batches remaining")

        # Assuming labels is a tuple (plant_labels, disease_labels)
        plant_labels, disease_labels = labels
        plant_labels = plant_labels.to(torch.long)  # Ensure the correct tensor type
        disease_labels = disease_labels.to(torch.long)  # Ensure the correct tensor type

        # Forward pass
        plant_outputs, disease_outputs = model(inputs)

        # Get predictions
        _, plant_pred = torch.max(plant_outputs, 1)
        _, disease_pred = torch.max(disease_outputs, 1)

        # Update the counters
        correct_plant += (plant_pred == plant_labels).sum().item()
        correct_disease += (disease_pred == disease_labels).sum().item()
        total += inputs.size(0)  # Use batch size for total count

# Print accuracy
print(f'Accuracy on validation set: Plant: {100 * correct_plant / total:.2f}%, Disease: {100 * correct_disease / total:.2f}%')


In [14]:
# Validate the model
model.eval()
correct_plant = 0
correct_disease = 0
total = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        # Assuming labels is a tuple (plant_labels, disease_labels)
        plant_labels, disease_labels = labels
        plant_labels = plant_labels.to(torch.long)  # Ensure the correct tensor type
        disease_labels = disease_labels.to(torch.long)  # Ensure the correct tensor type

        # Forward pass
        plant_outputs, disease_outputs = model(inputs)

        # Get predictions
        _, plant_pred = torch.max(plant_outputs, 1)
        _, disease_pred = torch.max(disease_outputs, 1)

        # Update the counters
        correct_plant += (plant_pred == plant_labels).sum().item()
        correct_disease += (disease_pred == disease_labels).sum().item()
        total += inputs.size(0)  # Use batch size for total count

# Print accuracy
print(f'Accuracy on validation set: Plant: {100 * correct_plant / total:.2f}%, Disease: {100 * correct_disease / total:.2f}%')


Accuracy on validation set: Plant: 97.11%, Disease: 91.68%


In [15]:
# Save the model
torch.save(model.state_dict(), 'plant_disease_model.pth')
