# Rice Variety Classifier (Stage 2)

This notebook implements the second stage of the rice analysis pipeline: **Variety Classification**.

**Goal:** Fine-tune an ImageNet-pretrained ResNet50 to classify images of individual rice grains into one of 8 varieties. The actual class list is read from the dataset's folder names:
1. Arborio
2. Basmati
3. Ipsala
4. Jasmine
5. Karacadag
6. Jhili
7. HMT (Sona Masuri)
8. Masuri

**Prerequisites:**
*   `Milled Rice Dataset.7z` must be uploaded to the root of your Google Drive (`MyDrive/`). If it is stored elsewhere, update `ARCHIVE_PATH` in the extraction cell.

In [None]:
# 1. Setup Environment
import glob
import os
import shutil
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from google.colab import drive
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed torch's CPU and CUDA generators so that head initialisation, DataLoader
# shuffling/worker seeds and augmentation are repeatable across Restart & Run All.
# (Some cuDNN kernels can still be nondeterministic on GPU.)
SEED = 42
torch.manual_seed(SEED)

# Mount Drive
drive.mount('/content/drive')

In [None]:
# 2. Extract Dataset
# Extract the archive from Drive onto the Colab VM's local disk: reading many
# small images from local disk during training is much faster than from Drive.

ARCHIVE_PATH = '/content/drive/MyDrive/Milled Rice Dataset.7z'  # Update if path differs
EXTRACT_PATH = '/content/rice_dataset'

# An empty directory counts as "not extracted", so a failed earlier attempt does
# not cause extraction to be skipped forever.
already_extracted = os.path.isdir(EXTRACT_PATH) and bool(os.listdir(EXTRACT_PATH))

if not already_extracted:
    if not os.path.isfile(ARCHIVE_PATH):
        raise FileNotFoundError(
            f"Archive not found at {ARCHIVE_PATH!r}; upload it to Drive or update ARCHIVE_PATH."
        )
    print("Copying and extracting dataset... (This may take a few minutes)")
    os.makedirs(EXTRACT_PATH, exist_ok=True)
    # Only stdout is silenced; 7z error messages on stderr stay visible.
    # -y answers any interactive prompt so the cell can never hang.
    extraction = subprocess.run(
        ['7z', 'x', ARCHIVE_PATH, f'-o{EXTRACT_PATH}', '-y'],
        stdout=subprocess.DEVNULL,
    )
    # 7-Zip exit codes: 0 = success, 1 = non-fatal warning, >= 2 = fatal error.
    if extraction.returncode >= 2:
        # Drop partial output so the next run retries from a clean state.
        shutil.rmtree(EXTRACT_PATH, ignore_errors=True)
        raise RuntimeError(f"7z extraction failed with exit code {extraction.returncode}")
    print("Extraction complete!")
else:
    print("Dataset already extracted.")

In [None]:
# 3. Verify Structure & Classes
# Let's find where the image folders are exactly
import glob

print("Looking for dataset structure...")
# Try to find the root folder containing the class subdirectories
possible_roots = glob.glob(f"{EXTRACT_PATH}/**/Basmati", recursive=True)

if not possible_roots:
    # Fallback search for any directory to help debug
    print("Could not verify 'Basmati' folder immediately. Printing directory tree:")
    !find "$EXTRACT_PATH" -maxdepth 2 -type d
    DATA_DIR = EXTRACT_PATH
else:
    # The parent of 'Basmati' is our data root
    DATA_DIR = os.path.dirname(possible_roots[0])
    print(f"Dataset root found at: {DATA_DIR}")

print(f"\nClasses found:")
!ls "$DATA_DIR"

In [None]:
# 4. Prepare Data Loaders

SPLIT_SEED = 42  # fixed so the train/val partition is identical on every run
BATCH_SIZE = 32
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # must match the pretrained weights' preprocessing
IMAGENET_STD = [0.229, 0.224, 0.225]

# Rice grains have no canonical orientation, so flips and arbitrary rotations
# are label-preserving augmentations.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(180),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation: deterministic preprocessing only, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Two ImageFolder views of the same directory. They share the file order and the
# class-to-index mapping but apply different transforms. Indexing them with
# disjoint index lists therefore yields a train/val split with per-split transforms.
full_dataset = datasets.ImageFolder(DATA_DIR, transform=train_transform)
val_view = datasets.ImageFolder(DATA_DIR, transform=val_transform)
classes = full_dataset.classes
print(f"Detected {len(classes)} classes: {classes}")
if len(classes) < 2:
    raise ValueError(
        f"Expected one subfolder per rice variety under {DATA_DIR!r}, found {classes}; "
        "check DATA_DIR from the previous cell."
    )

# Split Train/Val (80/20) with a seeded permutation
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(full_dataset), generator=split_generator).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]

train_subset = torch.utils.data.Subset(full_dataset, train_indices)
val_subset = torch.utils.data.Subset(val_view, val_indices)

# Aliases kept for any later cell that refers to the earlier split names.
train_dataset, val_dataset = train_subset, val_subset

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_subset)}, Val samples: {len(val_subset)}")

In [None]:
# 5. Define Model (ResNet50)

LEARNING_RATE = 1e-4

# ImageNet-pretrained backbone. `weights=` replaces the deprecated `pretrained=True`;
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` used to load.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# All layers stay trainable (full fine-tuning). To train only the new head
# instead, set `requires_grad = False` on model.parameters() before replacing `fc`.

# Replace the 1000-way ImageNet classifier with one output per detected class.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(classes))

model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
# 6. Training Loop

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over `loader` and return (mean loss, accuracy) as floats.

    Trains when `optimizer` is given: train mode, backward pass, optimizer step.
    Otherwise evaluates in eval mode with gradients disabled.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = inputs.size(0)
            # criterion averages over the batch, so re-weight by batch size
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_size
    if total_seen == 0:
        raise ValueError("DataLoader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


def train_model(model, train_loader, val_loader, epochs=10, criterion=None,
                optimizer=None, device=None,
                save_path='/content/drive/MyDrive/rice_resnet50_best.pth',
                restore_best=True):
    """Train `model`, validating after every epoch and checkpointing the best.

    Args:
        model: network to train (updated in place).
        train_loader, val_loader: DataLoaders yielding (inputs, labels) batches.
        epochs: number of passes over the training data.
        criterion: loss function; defaults to the notebook-level `criterion`.
        optimizer: optimizer over model's parameters; defaults to the
            notebook-level `optimizer`.
        device: where batches are moved; defaults to the device of model's parameters.
        save_path: file that receives the best state_dict (CPU tensors).
        restore_best: if True, load the best-validation weights back into `model`
            before returning, so later cells use the best model, not the last epoch.

    Returns:
        dict with per-epoch lists 'train_loss', 'train_acc', 'val_loss' and
        'val_acc', plus 'best_val_acc' (None if epochs == 0).
    """
    # Fall back to notebook globals so existing calls keep working unchanged.
    if criterion is None:
        criterion = globals()['criterion']
    if optimizer is None:
        optimizer = globals()['optimizer']
    if device is None:
        device = next(model.parameters()).device

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_acc = float('-inf')  # guarantees a checkpoint after the first epoch
    best_state = None

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        print("-" * 10)

        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}")

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            # state_dict() tensors alias the live parameters; take a CPU copy
            # so later epochs don't overwrite the snapshot.
            best_state = {k: v.detach().to('cpu', copy=True)
                          for k, v in model.state_dict().items()}
            torch.save(best_state, save_path)
            print("Saved new best model!")

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    history['best_val_acc'] = best_acc if best_state is not None else None
    return history

NUM_EPOCHS = 5  # start small; increase once the pipeline is verified end to end
history = train_model(model, train_loader, val_loader, epochs=NUM_EPOCHS)

In [None]:
# 7. Verify on Example Images (Optional)
# Takes a batch of validation images and shows the model's predictions


def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized image tensor with matplotlib.

    Undoes `transforms.Normalize(mean, std)` (the defaults are the ImageNet
    statistics used in the data-loading cell), clips to [0, 1] and shows the result.

    Args:
        inp: tensor laid out as (channels, height, width), e.g. a make_grid output.
            It may live on any device or require grad.
        title: optional title; a list or tuple is joined with ", ".
        mean, std: per-channel statistics that were used for normalization.
    """
    image = inp.detach().cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for imshow
    image = np.clip(np.asarray(std) * image + np.asarray(mean), 0, 1)
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    if title:
        if isinstance(title, (list, tuple)):
            title = ", ".join(map(str, title))
        ax.set_title(title)
    plt.show()

model.eval()
inputs, classes_idx = next(iter(val_loader))
with torch.no_grad():  # inference only: no autograd graph needed
    outputs = model(inputs.to(device))
preds = outputs.argmax(dim=1).cpu()

# Show up to the first 4 images, titled "predicted (true: actual)"
n_show = min(4, inputs.size(0))
titles = [
    f"{classes[p]} (true: {classes[t]})"
    for p, t in zip(preds[:n_show].tolist(), classes_idx[:n_show].tolist())
]
imshow(torchvision.utils.make_grid(inputs[:n_show]), title=", ".join(titles))