In [2]:
# Attach the user's Google Drive to the Colab filesystem under /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models


In [4]:
!pip install kaggle split-folders


Collecting split-folders
  Downloading split_folders-0.5.1-py3-none-any.whl.metadata (6.2 kB)
Downloading split_folders-0.5.1-py3-none-any.whl (8.4 kB)
Installing collected packages: split-folders
Successfully installed split-folders-0.5.1


In [5]:
from google.colab import files

# Opens Colab's upload widget; pick kaggle.json when prompted.
# The return value maps each uploaded filename to its raw bytes, so it is bound
# to a name instead of being left as the cell's last expression: otherwise the
# notebook echoes the file contents (i.e. the Kaggle API key) into the saved
# cell output.
uploaded_files = files.upload()   # upload kaggle.json


Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"abhimani22","key":"<REDACTED>"}'}

In [6]:
# Put the uploaded API token into ~/.kaggle/ (created if missing) and restrict
# it to owner read/write only (mode 600) so the credentials file is not
# readable by other users.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json


In [7]:
pip install -q kaggle

In [8]:
# List datasets through the Kaggle CLI. Nothing later reads this output;
# presumably it is here as a quick check that the API credentials are accepted.
!kaggle datasets list

ref                                                                title                                                    size  lastUpdated                 downloadCount  voteCount  usabilityRating  
-----------------------------------------------------------------  -------------------------------------------------  ----------  --------------------------  -------------  ---------  ---------------  
neurocipher/heartdisease                                           Heart Disease                                            3491  2025-12-11 15:29:14.327000           2114        318  1.0              
mabubakrsiddiq/retail-store-product-sales-simulation-dataset       🏪 Retail Store Product Sales Simulation Dataset       1383545  2026-01-16 13:12:07.310000              0         26  1.0              
saidaminsaidaxmadov/chocolate-sales                                Chocolate Sales                                        468320  2026-01-04 14:23:35.490000              0         59  1.0  

In [9]:
# Download the `emmarex/plantdisease` dataset archive (plantdisease.zip) into
# /content.
!kaggle datasets download -d emmarex/plantdisease -p /content


Dataset URL: https://www.kaggle.com/datasets/emmarex/plantdisease
License(s): unknown
Downloading plantdisease.zip to /content
 99% 654M/658M [00:06<00:00, 161MB/s]
100% 658M/658M [00:06<00:00, 110MB/s]


In [10]:
!unzip -q /content/plantdisease.zip -d /content/plantvillage_raw


In [11]:
!pip install split-folders




In [12]:
import splitfolders

# Copy the per-class image folders into an 80% / 20% split under
# /content/plantvillage_split. The fixed seed makes the assignment of files to
# each side repeatable between runs.
_split_kwargs = {
    "output": "/content/plantvillage_split",
    "seed": 42,
    "ratio": (0.8, 0.2),
}
splitfolders.ratio("/content/plantvillage_raw/PlantVillage", **_split_kwargs)


Copying files: 20639 files [00:04, 4783.62 files/s]


In [13]:
IMG_SIZE = 224


def _make_pipeline(augment):
    """Build the preprocessing pipeline applied to every image.

    Both variants resize to 256, centre-crop to ``IMG_SIZE``, convert to a
    tensor and normalise each channel with the fixed mean/std below (the usual
    ImageNet statistics). When ``augment`` is true, a random horizontal flip and
    a random rotation of up to 15 degrees are inserted after the crop.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(IMG_SIZE),
    ]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.RandomRotation(15))
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
    )
    return transforms.Compose(steps)


# Training images get the random augmentations; validation images do not, so
# validation accuracy is measured on deterministic inputs.
train_tfms = _make_pipeline(augment=True)
val_tfms = _make_pipeline(augment=False)


In [14]:
train_dir = "/content/plantvillage_split/train"
val_dir = "/content/plantvillage_split/val"

# One ImageFolder per split: each sub-directory name becomes a class label, and
# each split is paired with its own transform pipeline.
train_data, val_data = (
    datasets.ImageFolder(root=split_dir, transform=split_tfms)
    for split_dir, split_tfms in ((train_dir, train_tfms), (val_dir, val_tfms))
)

# Quick sanity summary of what was discovered on disk.
print(
    f"Classes: {train_data.classes}",
    f"Train images: {len(train_data)}",
    f"Val images: {len(val_data)}",
    sep="\n",
)


Classes: ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']
Train images: 16504
Val images: 4134


In [15]:
BATCH_SIZE = 32

# Decode and augment images in background worker processes. With the DataLoader
# default (num_workers=0) all of that happens in the main process and the model
# sits idle waiting for each batch. Capped at 4 so small VMs are not
# oversubscribed.
NUM_WORKERS = min(4, os.cpu_count() or 1)

# Page-locked host memory speeds up host-to-GPU copies; it has no benefit when
# training on CPU, so only enable it when CUDA is available.
PIN_MEMORY = torch.cuda.is_available()

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Validation order does not matter for accuracy, so keep it fixed.
val_loader = DataLoader(
    val_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)


In [16]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# One output unit per class folder found in the training set.
num_classes = len(train_data.classes)

# Start from a ResNet-50 with the IMAGENET1K_V1 pretrained weights and swap its
# final fully connected layer for a fresh one sized to our classes.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=num_classes,
)
model = model.to(device)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 120MB/s]


In [17]:
import torch.nn as nn


In [18]:

from collections import Counter

# Inverse-frequency class weights: a class with fewer training images gets a
# proportionally larger weight in the loss.
counts = Counter(train_data.targets)   # class index -> number of training images
total = sum(counts.values())

class_weights = [total / counts[class_idx] for class_idx in range(num_classes)]

# The loss expects its weight tensor on the same device as the model outputs.
weights = torch.FloatTensor(class_weights)
weights = weights.to(device)

criterion = nn.CrossEntropyLoss(weight=weights)

# All parameters are handed to the optimizer, so the whole network is
# fine-tuned, not just the replaced classification layer.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)


In [19]:
# Re-select the compute device and (re)move the model onto it. If the model is
# already on that device this changes nothing.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
model = model.to(device)


In [20]:
EPOCHS = 20

# Each epoch: one full optimisation pass over the training loader, followed by
# an accuracy measurement over the validation loader.
for epoch in range(EPOCHS):
    # ---- training pass -------------------------------------------------
    model.train()   # enable training-mode behaviour (batch-norm updates, etc.)
    train_loss = 0.0

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()               # clear gradients from the previous step
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()           # running sum of per-batch losses

    # Report the mean per-batch loss rather than the raw sum: the sum grows
    # with the number of batches, so it cannot be compared across batch sizes
    # or dataset sizes. max(..., 1) avoids dividing by zero on an empty loader.
    avg_train_loss = train_loss / max(len(train_loader), 1)

    # ---- validation pass -----------------------------------------------
    model.eval()    # switch to inference-mode behaviour
    correct = 0
    total = 0

    with torch.no_grad():                   # no gradients needed for evaluation
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)    # index of the highest logit per sample
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Guard against an empty validation set instead of raising ZeroDivisionError.
    acc = 100 * correct / total if total else 0.0
    print(f"Epoch [{epoch+1}/{EPOCHS}] | Loss: {avg_train_loss:.4f} | Val Acc: {acc:.2f}%")


KeyboardInterrupt: 

In [None]:
# Destination on the mounted Google Drive, so the checkpoint outlives the
# Colab runtime.
MODEL_PATH = "/content/drive/MyDrive/plant_disease_resnet50.pth"

# Store the class names next to the weights so a later inference script can map
# predicted indices back to labels without needing the dataset folders.
torch.save({
    "model_state_dict": model.state_dict(),
    "class_names": train_data.classes
}, MODEL_PATH)

# The check-mark was stored as mis-decoded bytes ("‚úÖ"); restored here.
print("✅ Model saved at:", MODEL_PATH)
