<a href="https://colab.research.google.com/github/shama-llama/crop-mapping/blob/main/src/crop_mapping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Crop Classification using Convolutional Neural Networks

## 1. Initialization

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read by later cells
from google.colab import drive
drive.mount('/content/drive')

### 1.1. Import Libraries

In [None]:
# Importing essential libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import io
import numpy as np

# For displaying plots inline
%matplotlib inline

# Use seaborn's "whitegrid" style for the plots drawn in this notebook
sns.set(style="whitegrid")

### 1.2. Data Loading & Initial Inspection

In [None]:
# Location of the crop dataset on the mounted Google Drive
parquet_path = "/content/drive/MyDrive/projects/crop-mapping-with-deep-learning/crop-mapping/dataset/crop_dataset.parquet"

# Read the parquet file into a DataFrame
df = pd.read_parquet(parquet_path)
print("Dataset loaded successfully!")

# Preview the first five rows (head() defaults to 5)
df.head()

## 2. Exploratory Data Analysis

### 2.1. Basic Data Overview

In [None]:
# Report the frame's (rows, columns) shape, then its column index
print("Dataset Shape:", df.shape)
print("\nDataset Columns:", df.columns, sep="\n")

In [None]:
# Dimensions, info, and summary statistics
print("Dataset shape:", df.shape)
df.info()
print("\nSummary statistics:")
# describe(include='all') counts unique values for object columns, which requires
# hashable cells. The 'image' column appears to hold dict payloads (later cells
# read row['image']['bytes']), so it is left out of the summary; errors='ignore'
# keeps this working if the column is absent.
display(df.drop(columns=['image'], errors='ignore').describe(include='all'))

### 2.2. Missing Values & Data Types

In [None]:
# Number of null entries found in each column
print("Missing Values by Column:")
print(df.isna().sum())

# Storage type pandas assigned to each column
print("\nData Types:")
print(df.dtypes)

### 2.3. Map Numerical Labels to Crop Names

In [None]:
# Crop names listed in label order: the position in this tuple is the numerical label
_CROP_NAMES = (
    "BARLEY",
    "CANOLA",
    "CORN",
    "MIXEDWOOD",
    "OAT",
    "ORCHARD",
    "PASTURE",
    "POTATO",
    "SOYBEAN",
    "SPRING_WHEAT",
)

# Mapping from numerical label (0-9) to crop name
label_mapping = dict(enumerate(_CROP_NAMES))

# Create a new column 'crop' by mapping the 'label' column through label_mapping.
# Series.map leaves NaN wherever a label has no entry in the dictionary.
df['crop'] = df['label'].map(label_mapping)

# Verify the mapping by displaying a few rows
df[['label', 'crop']].head()

### 2.4. Analyze Class Distribution

In [None]:
# Examples per crop class, ordered alphabetically by crop name
class_counts = df['crop'].value_counts().sort_index()

print("Class Distribution:")
print(class_counts)

# Bar chart of the class distribution, drawn on an explicit Axes
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(
    x=class_counts.index,
    y=class_counts.values,
    hue=class_counts.index,
    palette="viridis",
    legend=False,
    ax=ax,
)
ax.set_title("Distribution of Crop Classes")
ax.set_xlabel("Crop Type")
ax.set_ylabel("Number of Examples")
ax.tick_params(axis="x", labelrotation=90)
plt.show()

### 2.5. Sample Images with Labels

In [None]:
# Function to decode and display an image

def display_image(image_data):
    """Decode an encoded image payload into a PIL image.

    Despite its name, this helper draws nothing: it returns the opened image
    and leaves plotting to the caller.

    Args:
        image_data: Raw encoded image bytes, or a dict holding those bytes
            under the key 'bytes'.

    Returns:
        The image object produced by ``Image.open``.

    Raises:
        ValueError: If a dict has no 'bytes' key, or if the payload is not
            a ``bytes`` object.
    """
    if isinstance(image_data, dict):
        if 'bytes' not in image_data:
            raise ValueError("Dictionary does not contain key 'bytes' for image content.")
        # Unwrap the dict so both input forms share the decoding path below.
        image_data = image_data['bytes']

    # This check must sit at function level. When it was nested inside the dict
    # branch, raw bytes fell through and the function silently returned None.
    if not isinstance(image_data, bytes):
        raise ValueError(f"Unsupported image data type: {type(image_data)}")

    return Image.open(io.BytesIO(image_data))

In [None]:
grouped_by_crop = df.groupby('crop')
num_images_per_crop = 5

# Draw one row of sample images per crop class
for crop_name, crop_group in grouped_by_crop:
    print(f"\nSample Images for {crop_name}:")
    # Never request more rows than the class has
    sample_images = crop_group.sample(min(num_images_per_crop, len(crop_group)), random_state=42)

    # squeeze=False always returns a 2-D array of Axes, so row 0 is a 1-D
    # sequence even when num_images_per_crop is 1.
    fig, axes_grid = plt.subplots(1, num_images_per_crop, figsize=(10, 5), squeeze=False)
    axes = axes_grid[0]

    # Blank every panel up front so slots left without an image (small classes,
    # or a failed decode) do not show an empty frame with ticks.
    for ax in axes:
        ax.axis("off")

    # Display the sampled images
    for ax, (idx, row) in zip(axes, sample_images.iterrows()):
        img = display_image(row['image'])

        if img is not None:
            ax.imshow(img)
            ax.set_title(f"Index: {idx}")

    plt.tight_layout()
    plt.show()

### 2.6. Image Dimension Analysis

In [None]:
# Analyze image dimensions for a random sample (e.g., 100 images)
def get_image_size(image_data):
    """Return the ``size`` of an encoded image, or (nan, nan) when unavailable.

    Args:
        image_data: Raw encoded image bytes, or a dict holding those bytes
            under the key 'bytes'.

    Returns:
        The ``size`` attribute of the opened image (read by the calling code
        as width first, height second), or ``(np.nan, np.nan)`` when the
        payload is missing, is not bytes, or cannot be opened.
    """
    unknown = (np.nan, np.nan)

    payload = image_data
    if isinstance(payload, dict):
        if 'bytes' not in payload:
            return unknown
        payload = payload['bytes']

    if not isinstance(payload, bytes):
        return unknown

    # Best effort: any failure while opening the image is reported as unknown.
    try:
        return Image.open(io.BytesIO(payload)).size
    except Exception:
        return unknown

# Sample up to 100 images to avoid processing the entire dataset.
# Capping at len(df) keeps df.sample from raising when the frame has fewer rows.
sample_df = df.sample(min(100, len(df)), random_state=42)
dimensions = sample_df['image'].apply(get_image_size)

# Convert the tuples into separate lists of widths and heights,
# skipping the (nan, nan) entries that mark unreadable images
widths = [dim[0] for dim in dimensions if not np.isnan(dim[0])]
heights = [dim[1] for dim in dimensions if not np.isnan(dim[1])]

# Display summary statistics side-by-side
width_stats = pd.Series(widths).describe()
height_stats = pd.Series(heights).describe()

# Concatenate the two series horizontally
stats_df = pd.concat([width_stats, height_stats], axis=1)
stats_df.columns = ['Width Statistics', 'Height Statistics']
print(stats_df)

## 3. Data Preparation

### 3.1. Import Libraries

In [None]:
%%capture
!pip install pytorch-lightning

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE

### 3.2. Data Cleaning

In [None]:
# Drop rows that lack an image payload or a label
df_clean = df.dropna(subset=['image', 'label']).copy()
print(f"Rows after dropping missing values: {df_clean.shape[0]}")

# Keep only rows whose label is one of the known class ids
valid_labels = set(label_mapping)
has_valid_label = df_clean['label'].isin(valid_labels)

df_clean = df_clean[has_valid_label]
print(f"Rows after filtering invalid labels: {df_clean.shape[0]}")

# Function to decode and check the image
def process_image_data(image_data):
    """Decode an image payload to RGB, returning None when that is not possible.

    Args:
        image_data: Raw encoded image bytes, or a dict holding those bytes
            under the key 'bytes'.

    Returns:
        The decoded image converted to "RGB" mode, or ``None`` if the payload
        is missing, is not bytes, or fails to decode.
    """
    payload = image_data
    if isinstance(payload, dict):
        # A dict without a 'bytes' entry yields None here and is rejected below.
        payload = payload.get('bytes')

    if not isinstance(payload, bytes):
        return None

    # Any decoding failure marks the payload as invalid rather than raising.
    try:
        return Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception:
        return None

# Decode every image once purely to test that it is readable. Only a boolean
# mask is kept: storing the decoded images in a temporary column would hold
# all of them in memory at the same time just to drop them again.
is_valid_image = (
    df_clean['image']
    .apply(lambda data: process_image_data(data) is not None)
    .astype(bool)
)

num_invalid = int((~is_valid_image).sum())
print(f"Number of invalid or corrupted images detected: {num_invalid}")

# .copy() gives later cells an independent frame rather than a filtered view
df_clean = df_clean[is_valid_image].copy()

print(f"Final cleaned dataset has {df_clean.shape[0]} rows.")

### 3.3. Create PyTorch Dataset

In [None]:
class CropDataset(Dataset):
    """Map-style dataset that decodes crop images from a DataFrame on demand.

    Args:
        df_clean: DataFrame with an 'image' column (encoded image bytes, or a
            dict holding them under 'bytes') and a 'label' column.
        transform: Optional callable applied to the decoded RGB image.
    """

    def __init__(self, df_clean, transform=None):
        self.df_clean = df_clean
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the underlying frame."""
        return len(self.df_clean)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is decoded to RGB and passed through ``transform`` when one
        was supplied; the label is returned as a plain Python ``int``.
        """
        row = self.df_clean.iloc[idx]

        # Accept both payload forms handled by the notebook's other helpers.
        img_data = row['image']
        if isinstance(img_data, dict):
            img_data = img_data['bytes']
        image = Image.open(io.BytesIO(img_data)).convert('RGB')

        # Cast to int so the batched target dtype does not depend on how the
        # label column happens to be stored (e.g. a narrower integer type);
        # class-index losses such as CrossEntropyLoss expect int64 targets.
        label = int(row['label'])

        if self.transform:
            image = self.transform(image)

        return image, label

### 3.4. Data Augmentation

In [None]:
# Side length fed to the network. CropClassifierCustom's first Linear layer
# expects 256 * 4 * 4 features, which its four 2x max-poolings only produce for
# inputs of 64-79 px (its layer comments assume 65). Resizing to 224 would give
# a 14 x 14 map and fail with a shape mismatch at that layer.
IMAGE_SIZE = 65

# Per-channel normalisation statistics (the standard ImageNet values)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation, then tensor conversion + normalisation
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)
])

# Validation/test pipeline: same preprocessing without the random augmentation
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)
])

### 3.5. Data Splitting

In [None]:
# Split dataset into train (70%), val (15%), test (15%).
# Stratify on the labels of the frame actually being split: using the
# uncleaned df fails with a length mismatch as soon as cleaning drops a row.
train_df, temp_df = train_test_split(
    df_clean,
    test_size=0.3,
    stratify=df_clean['label'],
    random_state=42
)

# Halve the held-out 30% into validation and test, again preserving class ratios
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df['label'],
    random_state=42
)

# Create datasets (augmentation is applied to the training split only)
train_dataset = CropDataset(train_df, transform=train_transform)
val_dataset = CropDataset(val_df, transform=val_transform)
test_dataset = CropDataset(test_df, transform=val_transform)

# Report how many samples landed in each split
print("Class distribution:")
print("Train samples:", train_df.shape[0])
print("Validation samples:", val_df.shape[0])
print("Test samples:", test_df.shape[0])

# Create dataloaders (only the training loader shuffles)
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

## 4. Model Selection & Training

### 4.1. Model Definition

In [None]:
class CropClassifierCustom(pl.LightningModule):
    """Four-stage convolutional classifier for square crop images.

    Each stage is Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2): the
    convolution keeps the spatial size and the pooling halves it (rounding
    down), while channels grow 3 -> 32 -> 64 -> 128 -> 256. The flattened map
    feeds a two-layer fully-connected head with dropout.

    Args:
        num_classes: Number of output logits.
        input_size: Side length in pixels of the square inputs the head is
            sized for. The default of 65 gives the 256 * 4 * 4 flattened width
            (any size from 64 to 79 pools down to the same 4 x 4 map).

    Raises:
        ValueError: If ``input_size`` is below 16, which would pool to nothing.
    """

    def __init__(self, num_classes=10, input_size=65):
        super().__init__()
        if input_size < 16:
            raise ValueError(f"input_size must be at least 16, got {input_size}")
        # Four MaxPool2d(2) layers each floor-halve the side length.
        pooled_size = input_size // 16

        # Shape comments below are for the default 65 x 65 input.
        self.conv_layers = nn.Sequential(
            # Input: 3 x 65 x 65
            # 32 x 65 x 65
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # 32 x 32 x 32
            nn.MaxPool2d(2),

            # 64 x 32 x 32
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # 64 x 16 x 16
            nn.MaxPool2d(2),

            # 128 x 16 x 16
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # 128 x 8 x 8
            nn.MaxPool2d(2),

            # 256 x 8 x 8
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 256 x 4 x 4
            nn.MaxPool2d(2)
        )
        self.fc_layers = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * pooled_size * pooled_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
        self.criterion = nn.CrossEntropyLoss()
        # Last convolutional feature map, refreshed on every forward pass;
        # None until the first call so early access fails clearly.
        self.features = None

    def forward(self, x):
        """Return class logits for a batch of images.

        As a side effect the final convolutional feature map is cached on
        ``self.features`` for later feature analysis.
        """
        x = self.conv_layers(x)
        # Detach so the cached copy does not keep the autograd graph of the
        # latest training batch alive.
        self.features = x.detach()
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log its epoch mean."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def _shared_eval_step(self, batch, stage):
        """Compute and log ``{stage}_loss`` / ``{stage}_acc`` for one batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        # Fraction of correct predictions, computed on the tensors' own device
        # instead of copying each batch to the CPU for scikit-learn.
        acc = (preds == y).float().mean()
        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_acc', acc, prog_bar=True)
        return {"loss": loss, "preds": preds, "targets": y}

    def validation_step(self, batch, batch_idx):
        """Log 'val_loss' and 'val_acc'; return loss, predictions and targets."""
        return self._shared_eval_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log 'test_loss' and 'test_acc'; return loss, predictions and targets."""
        return self._shared_eval_step(batch, 'test')

    def configure_optimizers(self):
        """Adam (lr 1e-3) with ReduceLROnPlateau driven by 'val_loss'."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

### 4.2. Model Training

In [None]:
# Seed the random number generators (and dataloader workers) before the model
# is built, so that deterministic=True below gives repeatable runs.
pl.seed_everything(42, workers=True)

# Stop once 'val_loss' has gone 5 consecutive epochs without improving
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='min'
)

trainer = pl.Trainer(
    max_epochs=20,
    callbacks=[early_stop],
    accelerator='auto',
    # One device of whichever accelerator 'auto' picks (GPU if present, else CPU).
    # NOTE(review): the previous CPU fallback of devices=None is believed to be
    # rejected by Lightning 2.x - confirm against the installed version.
    devices=1,
    deterministic=True
)

model = CropClassifierCustom(num_classes=10)
trainer.fit(model, train_loader, val_loader)

### 4.3. Model Evaluation

In [None]:
# Evaluate on validation set
# Runs the fitted model over val_loader and prints whatever trainer.validate returns.
val_results = trainer.validate(model, val_loader)
print("Validation results:", val_results)

## 5. Evaluation on Test Set with Visualization

In [None]:
# Evaluate on test set
test_results = trainer.test(model, test_loader)
print("Test results:", test_results)

# Confusion Matrix on Test Set
all_preds = []
all_targets = []

model.eval()
with torch.no_grad():
    for x, y in test_loader:
        # Batches leave the loader on the CPU; move them to the model's device
        # so inference works wherever the trainer left the model.
        logits = model(x.to(model.device))
        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(y.cpu().numpy())

# Pin the class order so the matrix is always 10 x 10 and lines up with the
# tick labels, even if a class never appears among targets or predictions.
class_ids = sorted(label_mapping)
class_names = [label_mapping[class_id] for class_id in class_ids]

cm = confusion_matrix(all_targets, all_preds, labels=class_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix on Test Set")
plt.show()

## 6. Feature Analysis using t-SNE

In [None]:
# Extract features from the whole test set for visualization
features_list = []
labels_list = []

model.eval()
with torch.no_grad():
    for x, y in test_loader:
        # The forward pass is run only for its side effect of refreshing
        # model.features; inputs are moved to the model's device first.
        model(x.to(model.device))

        # Flatten channels and spatial dimensions into one vector per sample
        features = model.features.reshape(x.size(0), -1)
        features_list.append(features.cpu())
        labels_list.append(y.cpu())

features_all = torch.cat(features_list).numpy()
labels_all = torch.cat(labels_list).numpy()

# Apply t-SNE to project the feature vectors down to two dimensions
tsne = TSNE(n_components=2, random_state=42)
features_2d = tsne.fit_transform(features_all)

# Plot t-SNE result, coloured by true class label
plt.figure(figsize=(10, 8))
scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels_all, cmap='tab10', alpha=0.7)
plt.legend(*scatter.legend_elements(), title="Classes")
plt.title("t-SNE Visualization of Deep Features")
plt.xlabel("t-SNE Component 1")
plt.ylabel("t-SNE Component 2")
plt.show()