In [1]:
# Import necessary libraries
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from PIL import Image
import numpy as np


In [2]:
# Dynamic Class Fetching
def get_classes_from_dataset(dataset_path):
    plant_classes = []
    disease_classes = []

    for plant_folder in os.listdir(dataset_path):
        if os.path.isdir(os.path.join(dataset_path, plant_folder)):
            plant_classes.append(plant_folder)  # Add plant type
            for disease_folder in os.listdir(os.path.join(dataset_path, plant_folder)):
                if os.path.isdir(os.path.join(dataset_path, plant_folder, disease_folder)):
                    disease_classes.append(disease_folder)  # Add disease name
    
    # Remove duplicates
    disease_classes = list(set(disease_classes))

    return plant_classes, disease_classes

# Set the path to your dataset
dataset_path = 'dataset'
plant_classes, disease_classes = get_classes_from_dataset(dataset_path)

print(f'Plant Classes: {plant_classes}')
print(f'Disease Classes: {disease_classes}')

# Create a mapping: {plant_index: [valid_disease_indices]}
plant_disease_map = {
    i: [disease_classes.index(d) for d in os.listdir(os.path.join(dataset_path, plant_classes[i]))]
    for i in range(len(plant_classes))
}


Plant Classes: ['Eggplant', 'Potato', 'Tomato']
Disease Classes: ['Tomato___Early_blight', 'Tomato___Late_blight', 'Eggplant___Cercospora_Leaf_Spot', 'Eggplant___Flea_Beetles', 'Potato___healthy', 'Eggplant___Defect_Eggplant', 'Potato___Late_blight', 'Eggplant___Leaf_Wilt', 'Eggplant___Fresh_Eggplant_Leaf', 'Tomato___Bacterial_spot', 'Eggplant___Fresh_Eggplant', 'Potato___Early_blight', 'Eggplant___Tobacco_Mosaic_Virus', 'Eggplant___Phytophthora_Blight', 'Eggplant___Aphids', 'Eggplant___Powdery_Mildew']


In [3]:
# Custom Dataset Class for loading images
class PlantDiseaseDataset(Dataset):
    def __init__(self, dataset_path, plant_classes, disease_classes, transform=None):
        self.dataset_path = dataset_path
        self.plant_classes = plant_classes
        self.disease_classes = disease_classes
        self.transform = transform
        self.image_paths = []
        self.labels = []
        
        # Load image paths and labels
        for plant_idx, plant in enumerate(plant_classes):
            for disease_idx, disease in enumerate(disease_classes):
                folder_path = os.path.join(dataset_path, plant, disease)
                if os.path.exists(folder_path):
                    for filename in os.listdir(folder_path):
                        if filename.endswith(".JPG"):
                            self.image_paths.append(os.path.join(folder_path, filename))
                            self.labels.append((plant_idx, disease_idx))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load image
        image = Image.open(image_path).convert('RGB')

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        return image, label


In [4]:
# Data transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load dataset
dataset = PlantDiseaseDataset(dataset_path, plant_classes, disease_classes, transform)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)


In [41]:
print("Plant Classes:", plant_classes)
print("Disease Classes:", disease_classes)
print("Disease Mapping:", disease_mapping)


Plant Classes: ['Eggplant', 'Potato', 'Tomato']
Disease Classes: ['Tomato___Early_blight', 'Tomato___Late_blight', 'Eggplant___Cercospora_Leaf_Spot', 'Eggplant___Flea_Beetles', 'Potato___healthy', 'Eggplant___Defect_Eggplant', 'Potato___Late_blight', 'Eggplant___Leaf_Wilt', 'Eggplant___Fresh_Eggplant_Leaf', 'Tomato___Bacterial_spot', 'Eggplant___Fresh_Eggplant', 'Potato___Early_blight', 'Eggplant___Tobacco_Mosaic_Virus', 'Eggplant___Phytophthora_Blight', 'Eggplant___Aphids', 'Eggplant___Powdery_Mildew']


NameError: name 'disease_mapping' is not defined

In [5]:
# Define MultiOutputModel
class MultiOutputModel(nn.Module):
    def __init__(self, num_plants, num_diseases):
        super(MultiOutputModel, self).__init__()
        self.base_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        
        # Freeze pretrained layers
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        # Replace the final fully connected layer
        self.base_model.fc = nn.Linear(self.base_model.fc.in_features, 512)
        
        # Separate classification layers for plant and disease
        self.plant_fc = nn.Linear(512, num_plants)
        self.disease_fc = nn.Linear(512, num_diseases)

    def forward(self, x):
        features = self.base_model(x)
        plant_output = self.plant_fc(features)
        disease_output = self.disease_fc(features)
        return plant_output, disease_output


In [20]:
# Initialize model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiOutputModel(len(plant_classes), len(disease_classes)).to(device)

def compute_class_weights(labels, num_classes):
    class_counts = np.bincount(labels, minlength=num_classes)  # Ensure correct size
    class_weights = 1.0 / (class_counts + 1e-6)  # Avoid division by zero
    return torch.tensor(class_weights, dtype=torch.float)


plant_weights = compute_class_weights([p for p, d in dataset.labels], len(plant_classes)).to(device)
disease_weights = compute_class_weights([d for p, d in dataset.labels], len(disease_classes)).to(device)


plant_criterion = nn.CrossEntropyLoss(weight=plant_weights)
disease_criterion = nn.CrossEntropyLoss(weight=disease_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)


In [21]:
# Enforce plant-disease consistency
def enforce_consistency(plant_pred, disease_pred):
    valid_diseases = plant_disease_map[plant_pred]
    if disease_pred not in valid_diseases:
        disease_pred = min(valid_diseases, key=lambda d: abs(d - disease_pred))  # Pick closest valid disease
    return disease_pred


In [22]:
print(f"Plant Outputs Shape: {plant_outputs.shape}")  # Should be [batch_size, num_plant_classes]
print(f"Plant Labels Shape: {plant_labels.shape}")    # Should be [batch_size]

print(f"Disease Outputs Shape: {disease_outputs.shape}")  # Should be [batch_size, num_disease_classes]
print(f"Disease Labels Shape: {disease_labels.shape}")    # Should be [batch_size]


Plant Outputs Shape: torch.Size([16, 3])
Plant Labels Shape: torch.Size([16])
Disease Outputs Shape: torch.Size([16, 16])
Disease Labels Shape: torch.Size([16])


In [23]:
num_disease_classes = disease_outputs.shape[1]  # Should match dataset
print(f"Number of disease classes in the model: {num_disease_classes}")


Number of disease classes in the model: 16


In [24]:
disease_criterion = nn.CrossEntropyLoss(weight=torch.ones(num_disease_classes).to(device))


In [25]:
print(f"Plant Outputs Shape: {plant_outputs.shape}, Plant Labels Shape: {plant_labels.shape}")
print(f"Disease Outputs Shape: {disease_outputs.shape}, Disease Labels Shape: {disease_labels.shape}")


Plant Outputs Shape: torch.Size([16, 3]), Plant Labels Shape: torch.Size([16])
Disease Outputs Shape: torch.Size([16, 16]), Disease Labels Shape: torch.Size([16])


In [26]:
print(labels[:5])  # Print first 5 labels to check the structure


[tensor([1, 1, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1]), tensor([ 4, 11,  9, 11,  1,  0,  9,  9,  9,  6,  9,  0,  4,  9,  0, 11])]


In [28]:
num_epochs = 10  # Adjust as needed

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    total_loss = 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)

        # Extract labels correctly
        plant_labels = labels[0].to(torch.long).to(device)
        disease_labels = labels[1].to(torch.long).to(device)

        optimizer.zero_grad()

        # Forward pass
        plant_outputs, disease_outputs = model(inputs)

        # Compute losses
        plant_loss = plant_criterion(plant_outputs, plant_labels)
        disease_loss = disease_criterion(disease_outputs, disease_labels)
        loss = plant_loss + disease_loss

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Step the learning rate scheduler
    scheduler.step()

    # Print training status
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")


Epoch [1/10], Loss: 0.6381
Epoch [2/10], Loss: 0.5835
Epoch [3/10], Loss: 0.5201
Epoch [4/10], Loss: 0.5039
Epoch [5/10], Loss: 0.4348
Epoch [6/10], Loss: 0.4040
Epoch [7/10], Loss: 0.3666
Epoch [8/10], Loss: 0.3831
Epoch [9/10], Loss: 0.3726
Epoch [10/10], Loss: 0.3760


In [37]:
from torch.utils.data import random_split

# Define train-validation split (e.g., 80% train, 20% validation)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Split dataset
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [38]:
model.eval()
correct_plant = 0
correct_disease = 0
total_samples = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(device)
        plant_labels = labels[0].to(torch.long).to(device)
        disease_labels = labels[1].to(torch.long).to(device)

        # Forward pass
        plant_outputs, disease_outputs = model(inputs)

        # Predictions
        _, plant_preds = torch.max(plant_outputs, 1)
        _, disease_preds = torch.max(disease_outputs, 1)

        # Compute accuracy
        correct_plant += (plant_preds == plant_labels).sum().item()
        correct_disease += (disease_preds == disease_labels).sum().item()
        total_samples += plant_labels.size(0)

plant_acc = correct_plant / total_samples * 100
disease_acc = correct_disease / total_samples * 100

print(f"Validation Accuracy - Plant: {plant_acc:.2f}%, Disease: {disease_acc:.2f}%")


Validation Accuracy - Plant: 97.83%, Disease: 95.12%


In [None]:
model_path = "plant_disease_model.pth"
torch.save(model.state_dict(), model_path)
print(f"Model saved at {model_path}")


Model saved at plant_disease_model.pth


In [40]:
dummy_input = torch.randn(1, 3, 224, 224)  # Assuming input size is (3, 224, 224)
plant_output, disease_output = model(dummy_input)

print("Plant output shape:", plant_output.shape)  # Expected: (1, num_plants)
print("Disease output shape:", disease_output.shape)  # Expected: (1, num_diseases)


Plant output shape: torch.Size([1, 3])
Disease output shape: torch.Size([1, 16])


In [42]:
print("Plant Classes:", plant_classes)
print("Disease Classes:", disease_classes)
print("Disease Mapping:", disease_mapping)


Plant Classes: ['Eggplant', 'Potato', 'Tomato']
Disease Classes: ['Tomato___Early_blight', 'Tomato___Late_blight', 'Eggplant___Cercospora_Leaf_Spot', 'Eggplant___Flea_Beetles', 'Potato___healthy', 'Eggplant___Defect_Eggplant', 'Potato___Late_blight', 'Eggplant___Leaf_Wilt', 'Eggplant___Fresh_Eggplant_Leaf', 'Tomato___Bacterial_spot', 'Eggplant___Fresh_Eggplant', 'Potato___Early_blight', 'Eggplant___Tobacco_Mosaic_Virus', 'Eggplant___Phytophthora_Blight', 'Eggplant___Aphids', 'Eggplant___Powdery_Mildew']


NameError: name 'disease_mapping' is not defined