In [1]:
%pip install torch torchvision torchaudio
%pip install matplotlib scikit-learn opencv-python tqdm


Collecting torchaudio
  Downloading torchaudio-2.6.0-cp312-cp312-win_amd64.whl.metadata (6.7 kB)
Downloading torchaudio-2.6.0-cp312-cp312-win_amd64.whl (2.4 MB)
   ---------------------------------------- 0.0/2.4 MB ? eta -:--:--
   ---- ----------------------------------- 0.3/2.4 MB ? eta -:--:--
   -------- ------------------------------- 0.5/2.4 MB 1.7 MB/s eta 0:00:02
   ----------------- ---------------------- 1.0/2.4 MB 1.7 MB/s eta 0:00:01
   ------------------------- -------------- 1.6/2.4 MB 2.0 MB/s eta 0:00:01
   ---------------------------------- ----- 2.1/2.4 MB 2.1 MB/s eta 0:00:01
   ---------------------------------------- 2.4/2.4 MB 2.3 MB/s eta 0:00:00
Installing collected packages: torchaudio
Successfully installed torchaudio-2.6.0
Note: you may need to restart the kernel to use updated packages.




Note: you may need to restart the kernel to use updated packages.




In [1]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, random_split
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

# Helper function to convert numpy array to PIL Image
def numpy_to_pil(img):
    return Image.fromarray(img)

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data transformations
# Augment images with crop, flip, color changes to improve generalization.
data_transforms = {
    'train': transforms.Compose([
        transforms.Lambda(numpy_to_pil), 
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Lambda(numpy_to_pil),  # Use named function instead of lambda
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

class PlantDiseaseDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.samples = []
        
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    img_path = os.path.join(class_dir, img_name)
                    self.samples.append((img_path, self.class_to_idx[class_name]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        try:
            img_path, label = self.samples[idx]
            img = cv2.imread(img_path)
            if img is None:
                raise ValueError(f"Could not read image {img_path}")
            
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            
            if self.transform:
                img = self.transform(img)
                
            return img, label
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # Return a placeholder image and label
            placeholder = torch.zeros(3, 224, 224)
            return placeholder, 0

def create_model(num_classes):
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    
    # Freeze all layers
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace the final fully connected layer
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_features, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )
    return model

def train():
    # Create models directory if it doesn't exist
    os.makedirs('models', exist_ok=True)
    
    # Load dataset
    dataset = PlantDiseaseDataset("data/train", transform=data_transforms['train'])
    
    # Split dataset
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    # Create data loaders - set num_workers=0 for Windows or when getting pickle errors
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
    
    # Initialize model
    model = create_model(len(dataset.classes))
    model = model.to(device)
    
    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }
    
    best_acc = 0.0
    
    # Training loop
    for epoch in range(20):
        print(f'Epoch {epoch+1}/20')
        print('-' * 10)
        
        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0
        
        for inputs, labels in tqdm(train_loader, desc='Training'):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        
        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects.double() / len(train_dataset)
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc.item())
        
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc='Validating'):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                _, preds = torch.max(outputs, 1)
                val_loss += loss.item() * inputs.size(0)
                val_corrects += torch.sum(preds == labels.data)
        
        val_loss = val_loss / len(val_dataset)
        val_acc = val_corrects.double() / len(val_dataset)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc.item())
        
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
        
        # Update learning rate
        scheduler.step(val_loss)
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'class_names': dataset.classes,
                'model_name': 'resnet50'
            }, 'models/plant_disease_model.pth')
            print("Saved new best model")
    
    # Plot training history
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('models/training_history.png')
    plt.show()

if __name__ == '__main__':
    torch.multiprocessing.freeze_support()
    train()

Using device: cpu
Epoch 1/20
----------


Training: 100%|██████████| 191/191 [09:06<00:00,  2.86s/it]


Train Loss: 0.3600 Acc: 0.8668


Validating: 100%|██████████| 48/48 [02:04<00:00,  2.59s/it]


Val Loss: 0.2202 Acc: 0.9047
Saved new best model
Epoch 2/20
----------


Training: 100%|██████████| 191/191 [19:04<00:00,  5.99s/it]  


Train Loss: 0.2409 Acc: 0.9053


Validating: 100%|██████████| 48/48 [02:52<00:00,  3.60s/it]


Val Loss: 0.2213 Acc: 0.9093
Saved new best model
Epoch 3/20
----------


Training: 100%|██████████| 191/191 [13:36<00:00,  4.27s/it]


Train Loss: 0.2154 Acc: 0.9191


Validating: 100%|██████████| 48/48 [03:00<00:00,  3.77s/it]


Val Loss: 0.1479 Acc: 0.9474
Saved new best model
Epoch 4/20
----------


Training:  19%|█▉        | 36/191 [02:28<10:39,  4.13s/it]


KeyboardInterrupt: 