In [None]:
# Install runtime dependencies.
# NOTE(review): `timm` is not imported anywhere in this notebook; only `transformers`
# appears to be needed. Consider pinning versions (and using `%pip` so the install
# targets the running kernel) for reproducible runs.
!pip install transformers timm

In [None]:
import os
import random
import shutil

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTImageProcessor

In [None]:
# Reproducibility: seed every RNG used later in the notebook. These are Python
# `random`, NumPy (used by pandas sampling), and torch (head initialisation,
# DataLoader shuffling, torchvision random transforms).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_dir = "/kaggle/input/plant-seedlings-classification/train"

# Each sub-directory of `train_dir` is one class. Sort the entries, because
# os.listdir order is filesystem-dependent, and skip any stray non-directory
# entries so they are not mistaken for classes.
labels = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)

# Build one (image_path, label) record per file.
data = []
for label in labels:
    class_dir = os.path.join(train_dir, label)
    for img_name in sorted(os.listdir(class_dir)):
        data.append((os.path.join(class_dir, img_name), label))

df = pd.DataFrame(data, columns=["image_path", "label"])


# Dataset Summary
print("Total images:", len(df))
print("Number of classes:", df['label'].nunique())
print("Classes:", df['label'].unique())

In [None]:
# Peek at a handful of random rows to sanity-check the paths and labels.
df.sample(5)

In [None]:
# Sanity check: count null entries in each column of the raw frame.
for caption, column in (("Missing image paths:", "image_path"), ("Missing labels:", "label")):
    print(caption, df[column].isnull().sum())

In [None]:
# Boolean masks: rows repeated across every column, and rows whose
# image_path repeats an earlier row.
full_row_mask = df.duplicated()
path_mask = df.duplicated(subset=['image_path'])

duplicates = df.loc[full_row_mask]
dup_images = df.loc[path_mask]

print(f"Total completely duplicated rows: {len(duplicates)}")
print(f"Duplicate image paths: {len(dup_images)}")


In [None]:
# Visual Inspection (1 sample per class)
# The grid is sized from the number of classes; a fixed 3x5 grid fails once
# there are more than 15 classes.
unique_labels = df['label'].unique()
n_cols = 5
n_rows = max(1, -(-len(unique_labels) // n_cols))  # ceiling division

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 8 * n_rows / 3), squeeze=False)
for ax in axes.flat:
    ax.axis('off')  # also blanks any unused cells in the last row

for ax, label in zip(axes.flat, unique_labels):
    sample_path = df.loc[df['label'] == label, 'image_path'].sample(1, random_state=42).iloc[0]
    # `with` closes the file; convert() returns a loaded copy that outlives it.
    with Image.open(sample_path) as img:
        ax.imshow(img.convert("RGB"))
    ax.set_title(label)

# Set the suptitle before tight_layout so the layout makes room for it.
fig.suptitle("Sample Image per Class", fontsize=16)
fig.tight_layout()
plt.show()


In [None]:
# Encode class names as integer ids (alphabetical order) and keep both lookups.
sorted_labels = sorted(df['label'].unique())
label2id = dict(zip(sorted_labels, range(len(sorted_labels))))
id2label = dict(enumerate(sorted_labels))
df["label_id"] = df["label"].map(label2id)

# Stratified 90/10 split so every class appears in validation proportionally.
train_df, val_df = train_test_split(
    df,
    test_size=0.1,
    stratify=df["label_id"],
    random_state=42,
)

In [None]:
# Re-check for nulls now that label_id exists; a label absent from label2id
# would appear here as a NaN label_id.
null_checks = [
    ("Missing image paths:", "image_path"),
    ("Missing labels:", "label"),
    ("Missing label_ids:", "label_id"),
]
for caption, column in null_checks:
    print(caption, df[column].isnull().sum())

In [None]:
# Per-class image counts on the full dataset, to show the class imbalance.
class_counts = df['label'].value_counts()
print(class_counts)

fig, ax = plt.subplots(figsize=(12, 5))
sns.barplot(x=class_counts.index, y=class_counts.values, ax=ax)
ax.set_title("Original Class Distribution")
ax.tick_params(axis='x', labelrotation=90)
plt.show()


In [None]:
# Random augmentations used to synthesise extra images for minority classes.
# The pipeline takes a PIL image and returns a PIL image (there is no ToTensor),
# so the result can be saved straight to disk.
_augment_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
]
augment_transform = transforms.Compose(_augment_steps)


In [None]:
# Balance the TRAINING split only, by oversampling minority classes with
# augmented copies.
# - This must run after the train/validation split. Augmenting the full
#   dataset leaks (augmented copies of) validation images into training.
# - The balanced frame must be what the training DataLoader consumes.
# - Re-running the cell is harmless: once balanced, no class is short of the
#   target, so nothing new is generated.
aug_dir = "./augmented_train"
os.makedirs(aug_dir, exist_ok=True)

train_class_counts = train_df['label'].value_counts()
target_count = train_class_counts.max()

balanced_records = []

for label, n_existing in train_class_counts.items():
    class_df = train_df[train_df['label'] == label]

    # Keep every original training image exactly once, at its original path.
    balanced_records.extend(
        class_df[["image_path", "label", "label_id"]].to_dict("records")
    )

    n_missing = int(target_count - n_existing)
    if n_missing <= 0:
        continue

    class_dir = os.path.join(aug_dir, label)
    os.makedirs(class_dir, exist_ok=True)

    # Draw source images (with replacement) and write one augmented variant
    # per draw.
    sources = class_df.sample(n=n_missing, replace=True, random_state=42)
    for k, row in enumerate(sources.itertuples(index=False)):
        with Image.open(row.image_path) as src:
            img = augment_transform(src.convert("RGB"))

        # Deterministic names: re-running overwrites files instead of piling
        # up random ones.
        save_path = os.path.join(class_dir, f"{label}_aug_{k:05d}.png")
        img.save(save_path)

        balanced_records.append({
            "image_path": save_path,
            "label": label,
            "label_id": row.label_id,
        })

# The balanced training frame is what the training dataset is built from.
train_df = pd.DataFrame(balanced_records)
# Downstream cells (distribution plot, class weights) read `df`, so point it at
# the balanced training split they describe.
df = train_df

In [None]:
# Class counts after balancing; every bar should now have the same height.
class_counts = df['label'].value_counts()
print(class_counts)

fig, ax = plt.subplots(figsize=(12, 6))
sns.countplot(data=df, x='label', ax=ax)
ax.tick_params(axis='x', labelrotation=90)
ax.set_title("Class Distribution")
plt.show()

In [None]:
class PlantSeedlingsDataset(Dataset):
    """Map-style dataset over a DataFrame of image paths and integer labels.

    Args:
        dataframe: frame with ``image_path`` and ``label_id`` columns.
        transform: optional callable applied to the RGB PIL image
            (e.g. a torchvision ``Compose``).

    ``__getitem__`` returns ``(image, label_id)``. ``image`` is whatever
    ``transform`` returns, or the RGB PIL image when no transform is given.
    ``label_id`` is a plain Python int.
    """

    def __init__(self, dataframe, transform=None):
        # Positional .iloc lookups below rely on a clean 0..n-1 index.
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        """Number of samples, i.e. rows in the frame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the image at row ``idx`` and return it with its label id."""
        row = self.dataframe.iloc[idx]
        # `with` closes the file handle promptly (DataLoader workers open many
        # files). convert() returns a loaded copy that outlives the context.
        with Image.open(row["image_path"]) as img:
            image = img.convert("RGB")
        label = int(row["label_id"])

        if self.transform is not None:
            image = self.transform(image)

        return image, label


In [None]:
# Normalisation statistics come from the pretrained checkpoint, so inputs match
# what the ViT backbone was trained with. ViTImageProcessor replaces the
# deprecated ViTFeatureExtractor and exposes the same image_mean / image_std /
# save_pretrained. The variable keeps its old name because later cells use it.
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

IMAGE_SIZE = 224  # input resolution of vit-base-patch16-224


def _resize_and_normalize():
    """Deterministic preprocessing: resize, convert to a tensor, and normalise."""
    return transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
    ])


# Train and validation share the same deterministic pipeline, because
# augmentation is done offline (augmented images are written to disk earlier).
train_transform = _resize_and_normalize()
val_transform = _resize_and_normalize()


In [None]:
# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=16, num_workers=2)

train_dataset = PlantSeedlingsDataset(train_df, transform=train_transform)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

val_dataset = PlantSeedlingsDataset(val_df, transform=val_transform)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from the pretrained ViT checkpoint with a classification head sized for
# our classes. ignore_mismatched_sizes=True lets the checkpoint load even
# though its head has a different number of outputs.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
).to(device)

# Inverse-frequency class weights, computed from the data the model is actually
# trained on (train_df). They are reindexed by label_id 0..num_labels-1 so that
# weight i lines up with logit i. If train_df has already been balanced, the
# weights come out uniform.
train_class_counts = (
    train_df['label_id'].value_counts()
    .reindex(range(len(label2id)), fill_value=0)
)
if (train_class_counts == 0).any():
    empty_classes = [id2label[i] for i in train_class_counts.index[train_class_counts == 0]]
    raise ValueError(f"No training samples for classes: {empty_classes}")

class_weights = 1.0 / train_class_counts
class_weights = class_weights / class_weights.sum()

weights_tensor = torch.tensor(class_weights.values, dtype=torch.float, device=device)

# Class-weighted cross-entropy loss and AdamW optimiser.
criterion = nn.CrossEntropyLoss(weight=weights_tensor)
optimizer = optim.AdamW(model.parameters(), lr=2e-5)


In [None]:
def _run_epoch(model, loader, criterion, optimizer=None, device=None):
    """Make one pass over `loader`: train when `optimizer` is given, else evaluate.

    The model's forward output must expose `.logits`.

    Args:
        model: module whose forward(images) returns an object with `.logits`.
        loader: iterable of (images, labels) batches.
        criterion: loss function applied as criterion(logits, labels).
        optimizer: when given, each batch takes an optimizer step
            (train mode, grads on). When None, runs in eval mode without grads.
        device: where to move batches. Defaults to the device of the model's
            first parameter, or CPU if the model has none.

    Returns:
        (avg_loss, accuracy_percent). avg_loss weights each batch's loss by its
        batch size, so a short final batch is not over-weighted.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    training = optimizer is not None
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.train(training)

    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            logits = model(images).logits
            loss = criterion(logits, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            # Predicted class = index of the highest logit.
            correct += logits.argmax(dim=1).eq(labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / total, 100.0 * correct / total


def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Train `model` for one epoch over `loader` and return (avg_loss, accuracy_percent)."""
    return _run_epoch(model, tqdm(loader), criterion, optimizer=optimizer, device=device)


def evaluate(model, loader, criterion, device=None):
    """Evaluate `model` on `loader` without gradients; return (avg_loss, accuracy_percent)."""
    return _run_epoch(model, loader, criterion, optimizer=None, device=device)


In [None]:
EPOCHS = 30

# Track per-epoch metrics and keep a snapshot of the weights from the epoch with
# the best validation accuracy. The cells below (sample predictions,
# save_pretrained) then use that model instead of whatever the last epoch left.
history = []
best_val_acc = float("-inf")
best_epoch = None
best_state = None

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val   Loss: {val_loss:.4f}, Val   Acc: {val_acc:.2f}%")

    history.append({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
    })

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        # Clone to CPU so later optimizer steps cannot mutate the snapshot.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

if best_state is not None:
    model.load_state_dict(best_state)
    print(f"\nRestored weights from epoch {best_epoch} (Val Acc: {best_val_acc:.2f}%)")


In [None]:
# Peek at the validation split rather than dumping the whole frame.
val_df.head()

In [None]:
# Qualitative check: classify a few validation images with the trained model.
model.eval()


def predict_image(image_path):
    """Classify one image file with the trained model.

    Preprocessing reuses `val_transform`, so it matches what the model saw
    during training.

    Returns:
        (predicted_class_id, image), where image is the RGB PIL image that was
        classified.
    """
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    input_tensor = val_transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = model(pixel_values=input_tensor).logits
    return logits.argmax(dim=1).item(), image


def show_predictions(df, n=12):
    """Plot the first `n` rows of `df` on a 2-row grid with actual vs. predicted labels.

    `n` is clamped to len(df). Titles of misclassified images are drawn in red.
    """
    n = min(n, len(df))
    if n == 0:
        return
    n_cols = -(-n // 2)  # ceil(n / 2): odd n and n == 1 no longer break the grid

    fig, axes = plt.subplots(2, n_cols, figsize=(16, 8), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")

    for ax, (_, row) in zip(axes.flat, df.head(n).iterrows()):
        pred_id, image = predict_image(row["image_path"])
        actual = row["label"]
        predicted = id2label[pred_id]

        ax.imshow(image)
        ax.set_title(
            f"Actual: {actual}\nPredicted: {predicted}",
            fontsize=10,
            color="black" if predicted == actual else "red",
        )

    fig.tight_layout()
    plt.show()


# Example usage
show_predictions(val_df, n=12)


In [None]:
# Write the fine-tuned model and its preprocessing config into the same folder
# so both can be reloaded together with from_pretrained().
save_path = "./vit-plant-seedlings"
for artifact in (model, feature_extractor):
    artifact.save_pretrained(save_path)

In [None]:
# from huggingface_hub import HfApi

# api = HfApi(token=os.getenv("HF_TOKEN"))  # set HF_TOKEN in the environment; never paste the token here
# api.upload_folder(
#     folder_path="./vit-plant-seedlings",
#     repo_id="sajeewa/vit-base-patch16-224",
#     repo_type="model",
# )