<a href="https://colab.research.google.com/github/prathameshdv/Grainspace-Project/blob/main/Grainspace%3AEnsemble.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
from pathlib import Path
import shutil
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder

from sklearn.metrics import f1_score, confusion_matrix, classification_report

# ---------- CONFIG ----------
# Dataset locations (all relative to the current working directory).
ZIP_NAME = "WHEAT_R19-22_G600_TRAIN.zip"   # change to your uploaded zip filename
UNZIP_DIR = Path("wheat_R19_22_G600/train")   # extraction target; its sub-folders are treated as classes
SPLIT_DIR = Path("wheat_R19_22_G600_split")  # output folder holding train/ val/ test/ sub-folders
RANDOM_SEED = 42  # seed applied to the RNGs before the split is drawn

# Training options
PRETRAINED = True   # If True -> ImageNet-pretrained weights for both backbones, else random init (scratch)
MODEL_NAME = "convnext_tiny"  # NOTE(review): not referenced by the cells shown here; the builders are hard-coded
BATCH_SIZE = 32  # used for the train, val and test DataLoaders alike
EPOCHS = 25
LR = 0.015  # initial SGD learning rate, shared by both models
WEIGHT_DECAY = 1e-4
MILESTONES = [15, 25]  # NOTE(review): unused in the cells shown here; train_model uses StepLR(step_size=10) instead
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = 4  # DataLoader worker processes
SAVE_BEST = True  # NOTE(review): not referenced by the cells shown here
BEST_MODEL_PATH = "best_convnext.pth"  # NOTE(review): not referenced by the cells shown here

print("Device:", DEVICE)


In [None]:
from zipfile import ZipFile

# Extract the dataset archive once; later runs reuse the already-extracted folder.
if not UNZIP_DIR.exists():
    print("Unzipping:", ZIP_NAME)
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not Path(ZIP_NAME).exists():
        raise FileNotFoundError(f"{ZIP_NAME} not found in current directory!")
    with ZipFile(ZIP_NAME, 'r') as zip_ref:
        zip_ref.extractall(UNZIP_DIR)
    print("Extracted to", UNZIP_DIR)
else:
    print("Unzip target already exists:", UNZIP_DIR)

# quick listing: every sub-folder of UNZIP_DIR is treated as one class
print("Top-level folders (classes):", sorted(p.name for p in UNZIP_DIR.iterdir() if p.is_dir()))


In [None]:
# Seed every RNG the notebook draws from so runs are as repeatable as possible:
# `random` drives the split shuffle, while torch drives layer initialisation,
# DataLoader shuffling and the random training augmentations.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

def make_splits(src_dir: Path, out_dir: Path, train_frac=0.8, val_frac=0.1):
    """Copy a class-per-folder image dataset into train/val/test sub-folders.

    Every sub-directory of `src_dir` is treated as one class. Its images are
    shuffled with the global `random` module and copied (originals are left
    untouched) to `out_dir/<split>/<class>/`.

    Args:
        src_dir: Folder containing one sub-folder per class.
        out_dir: Destination root. If it already exists it is DELETED and
            rebuilt from scratch.
        train_frac: Fraction of each class copied to the train split.
        val_frac: Fraction of each class copied to the val split; whatever
            remains after train and val goes to the test split.
    """
    if out_dir.exists():
        print("Split dir exists, removing and recreating:", out_dir)
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for split in ["train", "val", "test"]:
        (out_dir / split).mkdir(parents=True, exist_ok=True)

    # Compared case-insensitively so files such as "IMG_001.JPG" are not dropped.
    image_suffixes = {".jpg", ".jpeg", ".png"}

    for class_dir in sorted(src_dir.iterdir()):
        if not class_dir.is_dir():
            continue
        # Sort before shuffling: directory listing order is filesystem-dependent,
        # so without this the same seed could yield different splits.
        imgs = sorted(
            p for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        )
        if not imgs:
            print(f"Warning: no images found in {class_dir}")
            continue
        random.shuffle(imgs)
        n_train = int(train_frac * len(imgs))
        n_val = int(val_frac * len(imgs))
        splits = {
            "train": imgs[:n_train],
            "val": imgs[n_train:n_train + n_val],
            "test": imgs[n_train + n_val:],  # remainder, so rounding never loses a file
        }
        for sp, files in splits.items():
            tgt_dir = out_dir / sp / class_dir.name
            tgt_dir.mkdir(parents=True, exist_ok=True)
            for f in files:
                # copy (keeps original intact)
                shutil.copy(f, tgt_dir / f.name)
    print("Done creating splits at:", out_dir)

# Build the train/val/test folders; an existing SPLIT_DIR is deleted and rebuilt by make_splits.
make_splits(UNZIP_DIR, SPLIT_DIR)

In [None]:
def counts(split_dir: Path):
    """Count the files in each class folder of the train/val/test splits.

    Args:
        split_dir: Root folder expected to contain `train`, `val` and `test`
            sub-folders, each holding one sub-folder per class.

    Returns:
        A dict keyed by split name. Every value has the same shape,
        `{"total": int, "per_class": {class_name: int}}`; a split whose folder
        is missing is reported with a total of 0 and no classes, so callers
        can index `["total"]` unconditionally.
    """
    info = {}
    for split in ["train", "val", "test"]:
        split_path = split_dir / split
        per_class = {}
        if split_path.exists():
            for cls in sorted(d for d in split_path.iterdir() if d.is_dir()):
                # "*.*" matches any entry whose name contains a dot.
                per_class[cls.name] = sum(1 for _ in cls.glob("*.*"))
        info[split] = {"total": sum(per_class.values()), "per_class": per_class}
    return info

cnts = counts(SPLIT_DIR)

# Overall image totals, one line per split.
for title, split_key in (("Train", "train"), ("Val", "val"), ("Test", "test")):
    print(f"{title} total:", cnts[split_key]["total"])

# Class balance of the training split.
print("Per-class counts (train):")
for cls_name, n_images in cnts["train"]["per_class"].items():
    print(f"  {cls_name}: {n_images}")

In [None]:
# Channel-wise normalisation applied at the end of both pipelines.
imagenet_norm = transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])

# Training pipeline: random crop + horizontal flip as augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    imagenet_norm,
])
# Evaluation pipeline: deterministic resize + centre crop, no augmentation.
val_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    imagenet_norm,
])

# One ImageFolder per split directory.
train_dataset = ImageFolder(str(SPLIT_DIR / "train"), transform=train_transform)
val_dataset = ImageFolder(str(SPLIT_DIR / "val"), transform=val_test_transform)
test_dataset = ImageFolder(str(SPLIT_DIR / "test"), transform=val_test_transform)

# Settings common to every loader; only the training loader shuffles.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Class list (and therefore label indices) is taken from the training split.
class_names = train_dataset.classes
num_classes = len(class_names)
print("Num classes:", num_classes, "Class names:", class_names)
print("Train/Val/Test sizes:", len(train_dataset), len(val_dataset), len(test_dataset))


In [None]:
def build_resnet50(num_classes, pretrained=True):
    """Create a torchvision ResNet-50 whose head outputs `num_classes` logits.

    Args:
        num_classes: Output size of the replacement classification layer.
        pretrained: Load ImageNet weights when True, random init otherwise.

    Returns:
        The ResNet-50 model with a freshly initialised `nn.Linear` as `fc`.
    """
    weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
    if weights_enum is not None:
        # Newer torchvision: request the checkpoint explicitly instead of
        # relying on the deprecated `pretrained` flag. IMAGENET1K_V1 is the
        # checkpoint that flag maps to, so the loaded weights stay the same.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = torchvision.models.resnet50(weights=weights)
    else:
        # Older torchvision without weight enums only understands the flag.
        model = torchvision.models.resnet50(pretrained=pretrained)
    # Genuine failures (e.g. a failed download) now propagate instead of being
    # swallowed by a bare `except` and retried with a different configuration.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def build_convnext(num_classes, pretrained=True):
    """Create a torchvision ConvNeXt-Tiny whose head outputs `num_classes` logits.

    Args:
        num_classes: Output size of the replacement classification layer.
        pretrained: Load ImageNet weights when True, random init otherwise.

    Returns:
        The ConvNeXt-Tiny model with a freshly initialised `nn.Linear` as the
        last element (`classifier[2]`) of its classifier.
    """
    weights_enum = getattr(torchvision.models, "ConvNeXt_Tiny_Weights", None)
    if weights_enum is not None:
        # Newer torchvision: request the checkpoint explicitly instead of
        # relying on the deprecated `pretrained` flag.
        weights = weights_enum.DEFAULT if pretrained else None
        model = torchvision.models.convnext_tiny(weights=weights)
    else:
        # Older torchvision without weight enums only understands the flag.
        model = torchvision.models.convnext_tiny(pretrained=pretrained)
    # Genuine failures (e.g. a failed download) now propagate instead of being
    # swallowed by a bare `except` and retried with a different configuration.
    model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
    return model

# Instantiate both ensemble members with a `num_classes`-way head and move them to DEVICE.
resnet = build_resnet50(num_classes, pretrained=PRETRAINED).to(DEVICE)
convnext = build_convnext(num_classes, pretrained=PRETRAINED).to(DEVICE)

In [None]:
def train_model(model, train_loader, val_loader, epochs, lr):
    """Train `model` and return it loaded with its best-validation weights.

    Optimises cross-entropy with SGD (momentum 0.9, weight decay
    `WEIGHT_DECAY`); the learning rate is multiplied by 0.1 every 10 epochs.
    After each epoch the macro F1 on `val_loader` is printed and, when it
    improves, the weights are snapshotted.

    Args:
        model: Network already placed on `DEVICE`.
        train_loader: Iterable of `(images, labels)` training batches.
        val_loader: Iterable of `(images, labels)` validation batches.
        epochs: Number of passes over `train_loader`.
        lr: Initial SGD learning rate.

    Returns:
        The same `model` object, restored to the epoch with the highest
        validation macro F1 (left as-is when `epochs` is 0).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Start below any attainable F1 so the first epoch is always snapshotted,
    # even when its score is exactly 0.0.
    best_val_f1 = -1.0
    best_state = None
    for epoch in range(epochs):
        model.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        scheduler.step()

        # validation
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
                outputs = model(imgs)
                preds = outputs.argmax(1)
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
        val_f1 = f1_score(all_labels, all_preds, average="macro")
        print(f"Epoch {epoch+1}/{epochs}, Val Macro F1: {val_f1:.4f}")
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            # Clone every tensor: state_dict() values share storage with the
            # live parameters, so a shallow dict copy would keep changing as
            # training continues and always end up holding the last epoch.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

In [None]:
# Train the two ensemble members one after the other, on the same loaders
# and with the same EPOCHS and LR.
print("Training ResNet-50...")
resnet = train_model(resnet, train_loader, val_loader, epochs=EPOCHS, lr=LR)

print("\nTraining ConvNeXt-Tiny...")
convnext = train_model(convnext, train_loader, val_loader, epochs=EPOCHS, lr=LR)


In [None]:
def evaluate_ensemble(models, loader):
    """Score an ensemble that averages the raw logits of its members.

    Args:
        models: Sequence of networks already placed on `DEVICE`; each is
            switched to eval mode.
        loader: Iterable of `(images, labels)` batches.

    Returns:
        A tuple `(macro_f1, predictions, labels)` where the two lists hold one
        entry per sample in the order the loader produced them.
    """
    for member in models:
        member.eval()

    y_pred, y_true = [], []
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs = batch_imgs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            # One logits tensor per member, averaged element-wise.
            member_logits = [member(batch_imgs) for member in models]
            mean_logits = sum(member_logits) / len(member_logits)
            y_pred.extend(mean_logits.argmax(1).cpu().numpy())
            y_true.extend(batch_labels.cpu().numpy())

    return f1_score(y_true, y_pred, average="macro"), y_pred, y_true

# Evaluate the logit-averaging ensemble on the held-out test split.
test_f1, preds, labels = evaluate_ensemble([resnet, convnext], test_loader)
print(f" Ensemble (ResNet+ConvNeXt) Test Macro F1: {test_f1:.4f}")
# NOTE(review): class_names comes from the train split; this assumes the test
# split has the same class folders so label indices line up — confirm.
print(classification_report(labels, preds, target_names=class_names, digits=4))