# 🌿 GreenTwin: Universal Plant Disease Model Training

This notebook trains a **MobileNetV2** model on the **New Plant Diseases Dataset** to recognize **38 different classes** of plants/diseases.

## 1. Setup
Ensure you have added the **[New Plant Diseases Dataset](https://www.kaggle.com/vipoooool/new-plant-diseases-dataset)** to your Kaggle Notebook input.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import os
import json
import time
import copy

## 2. Configuration & Data Loading (with Auto-Discovery)

In [None]:
# Robust Data Path Finding
# Kaggle datasets are mounted under /kaggle/input, but the exact nesting depends on
# how the dataset was uploaded. The train and validation folders are located
# together: the first directory (in top-down walk order) that contains BOTH a
# non-empty training split and a non-empty validation split wins. This way the two
# splits never come from different copies of the data.
DATA_ROOT = "/kaggle/input"
# Accepted split-folder names, in order of preference. Compared case-insensitively.
TRAIN_DIR_NAMES = ("train",)
VALID_DIR_NAMES = ("valid", "val", "validation")


def _has_entries(path):
    """Return True if directory *path* contains at least one entry."""
    with os.scandir(path) as entries:
        return next(entries, None) is not None


def _pick_split(root, dirs, names):
    """Return the first non-empty child of *root* whose name matches *names*, else None.

    *dirs* are the child directory names of *root* (as yielded by os.walk);
    matching against *names* is case-insensitive and follows *names* order.
    """
    by_lower = {d.lower(): d for d in dirs}
    for name in names:
        actual = by_lower.get(name)
        if actual is not None:
            candidate = os.path.join(root, actual)
            if _has_entries(candidate):
                return candidate
    return None


def _find_split_dirs(base_dir, train_names, valid_names):
    """Walk *base_dir* top-down and return ``(train_dir, valid_dir)``.

    Prefers the first directory that holds both splits side by side. If none
    does, falls back to the first train folder and the first validation folder
    found anywhere. Either may be None when nothing matched (including when
    *base_dir* does not exist).
    """
    split_names = {n.lower() for n in (*train_names, *valid_names)}
    first_train = first_valid = None
    for root, dirs, _files in os.walk(base_dir):
        train = _pick_split(root, dirs, train_names)
        valid = _pick_split(root, dirs, valid_names)
        if train and valid:
            return train, valid
        first_train = first_train or train
        first_valid = first_valid or valid
        # Don't descend into split folders. They hold class folders of images,
        # and listing every image file only slows the search down.
        dirs[:] = [d for d in dirs if d.lower() not in split_names]
    return first_train, first_valid


print(f"Searching for dataset in {DATA_ROOT}...")
TRAIN_DIR, VALID_DIR = _find_split_dirs(DATA_ROOT, TRAIN_DIR_NAMES, VALID_DIR_NAMES)

if not TRAIN_DIR:
    raise FileNotFoundError("CRITICAL: Could not find 'train' directory in /kaggle/input. Did you add the dataset?")
if not VALID_DIR:
    raise FileNotFoundError("CRITICAL: Could not find 'valid' directory in /kaggle/input.")
if os.path.dirname(TRAIN_DIR) != os.path.dirname(VALID_DIR):
    # Only reachable through the fallback path of _find_split_dirs.
    print("⚠️ Train and valid folders do not share a parent directory; check they belong to the same dataset copy.")

print(f"✅ Found Train Data: {TRAIN_DIR}")
print(f"✅ Found Valid Data: {VALID_DIR}")

BATCH_SIZE = 32
IMG_SIZE = 224
EPOCHS = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

In [None]:
# Data Transforms
# ImageNet channel statistics, shared by both pipelines. The pretrained
# MobileNetV2 weights expect inputs normalized this way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Validation images are resized with the same 256:224 ratio as the classic
# Resize(256) / CenterCrop(224) pair. For IMG_SIZE = 224 this is exactly 256.
VALID_RESIZE = round(IMG_SIZE * 256 / 224)
NUM_WORKERS = 2
# Pinned host memory speeds up host-to-GPU copies. It is useless on CPU.
PIN_MEMORY = device.type == "cuda"

data_transforms = {
    'train': transforms.Compose([
        # Random crop + flip augmentation, training only.
        transforms.RandomResizedCrop(IMG_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    'valid': transforms.Compose([
        # Deterministic resize + centre crop for evaluation.
        transforms.Resize(VALID_RESIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}

image_datasets = {
    'train': datasets.ImageFolder(TRAIN_DIR, data_transforms['train']),
    'valid': datasets.ImageFolder(VALID_DIR, data_transforms['valid']),
}

# ImageFolder derives each split's label indices independently from that split's
# sub-folder names. If the splits disagree, the same index would mean different
# classes and validation accuracy would be silently wrong, so fail loudly instead.
train_map = image_datasets['train'].class_to_idx
valid_map = image_datasets['valid'].class_to_idx
if train_map != valid_map:
    only_in_one = sorted(set(train_map) ^ set(valid_map))
    raise ValueError(
        "Train and valid class folders differ; label indices would not line up. "
        f"Classes present in only one split: {only_in_one}"
    )

dataloaders = {
    'train': DataLoader(image_datasets['train'], batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY),
    'valid': DataLoader(image_datasets['valid'], batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY),
}

class_names = image_datasets['train'].classes
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}

print(f"Classes ({len(class_names)}): {class_names}")
print(f"Training images: {dataset_sizes['train']}")
print(f"Validation images: {dataset_sizes['valid']}")

## 3. Save Classes JSON
The backend needs this file to map predicted class indices back to human-readable class names.

In [None]:
# Persist the class list in label order: ImageFolder assigns label i to
# class_names[i], so the backend can turn a predicted index back into a name.
with open('classes.json', 'w', encoding='utf-8') as f:
    json.dump(class_names, f)

print("Saved classes.json ✅")

## 4. Build Model (MobileNetV2)

In [None]:
# Load MobileNetV2 with ImageNet-pretrained weights. torchvision >= 0.13 replaced
# the deprecated `pretrained=True` flag with the `weights=` enum. The legacy flag
# is used only on older installs that lack the enum (IMAGENET1K_V1 is what
# `pretrained=True` maps to).
_mobilenet_weights = getattr(models, "MobileNet_V2_Weights", None)
if _mobilenet_weights is not None:
    model = models.mobilenet_v2(weights=_mobilenet_weights.IMAGENET1K_V1)
else:
    model = models.mobilenet_v2(pretrained=True)

# Set to True to freeze the convolutional feature extractor and train only the
# new classifier head (speeds up training).
FREEZE_BACKBONE = False
if FREEZE_BACKBONE:
    for param in model.features.parameters():
        param.requires_grad = False

# Replace Classifier Head: swap the final Linear layer for one with a single
# output per dataset class.
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, len(class_names))

model = model.to(device)

LEARNING_RATE = 1e-4
criterion = nn.CrossEntropyLoss()
# Hand only trainable parameters to the optimizer, so a frozen backbone is skipped.
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=LEARNING_RATE)

## 5. Train Model

In [None]:
def train_model(model, criterion, optimizer, num_epochs=10, *, loaders=None):
    """Fine-tune *model* and return it with its best validation weights loaded.

    Each epoch runs a 'train' phase (forward, backward, optimizer step), then a
    'valid' phase (forward only), and prints the mean loss and accuracy of each.
    The weights from the epoch with the highest validation accuracy are
    snapshotted and restored into *model* before it is returned.

    Args:
        model: network to train. It is modified in place and also returned.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over *model*'s trainable parameters.
        num_epochs: number of full train + valid passes.
        loaders: optional mapping with 'train' and 'valid' iterables of
            ``(inputs, labels)`` batches. Defaults to the notebook-level
            ``dataloaders``.

    Returns:
        The same *model* object, with the best-validation weights loaded. If no
        epoch ran or validation accuracy never exceeded 0, these are its
        initial weights.
    """
    if loaders is None:
        loaders = dataloaders  # notebook-level global from the data-loading cell
    # Move each batch to wherever the model's weights live.
    run_device = next(model.parameters()).device

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ('train', 'valid'):
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is exactly model.eval()

            running_loss = 0.0
            running_corrects = 0
            n_seen = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(run_device, non_blocking=True)
                labels = labels.to(run_device, non_blocking=True)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                preds = outputs.argmax(dim=1)
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                # Accumulate as a Python int instead of a device tensor.
                running_corrects += (preds == labels).sum().item()
                n_seen += batch_size

            # Normalize by the samples actually seen. This needs no global
            # dataset sizes and does not divide by zero on an empty loader.
            denom = max(n_seen, 1)
            epoch_loss = running_loss / denom
            epoch_acc = running_corrects / denom

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model

# Fine-tune the network. train_model hands back the same model object with the
# weights from its best validation epoch loaded.
model_ft = train_model(
    model,
    criterion,
    optimizer,
    num_epochs=EPOCHS,
)

## 6. Save Model

In [None]:
# Save only the weights (state_dict). Tensors are copied to CPU first so the
# checkpoint loads on a CPU-only backend without needing torch.load(map_location=...).
cpu_state = {name: tensor.detach().cpu() for name, tensor in model_ft.state_dict().items()}
torch.save(cpu_state, 'universal_model.pth')
print("Saved universal_model.pth ✅")