In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Traverse the Kaggle input tree. The per-file listing is commented out, so
# the walk produces no output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        continue  # print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd

import copy
import torch
import torch.nn as nn
from PIL import Image
import torch.optim as optim
from tqdm.notebook import tqdm
import torch.nn.functional as F
from ipywidgets import FileUpload
import torchvision.transforms as T
from IPython.display import display
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.io import read_image
from torchvision.datasets import ImageFolder

In [None]:
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset, random_split

# Directory (single folder for all data)
data_dir = "/kaggle/input/pumpkin-leaf-diseases-dataset-from-bangladesh/Pumpkin Leaf Diseases Dataset From Bangladesh/Original Dataset"

SPLIT_SEED = 42        # fixes which images land in train vs. validation
TRAIN_FRACTION = 0.75  # 75/25 train/validation split
BATCH_SIZE = 32

# Training transform: resize, random augmentation, then scale to roughly [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Evaluation transform: same resize/normalisation but NO random augmentation,
# so validation metrics are deterministic for a given set of weights.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Two views over the same folder: one augmented (training), one clean (evaluation).
# Both scan the same directory, so a sample index refers to the same image file
# in either view.
full_dataset = ImageFolder(data_dir, transform=transform)
eval_dataset = ImageFolder(data_dir, transform=eval_transform)

# Split into train and validation sets (75/25) with a seeded generator so the
# partition is reproducible across kernel restarts.
train_size = int(TRAIN_FRACTION * len(full_dataset))
val_size = len(full_dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_split = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)
# Re-point the validation indices at the un-augmented view of the data.
val_dataset = Subset(eval_dataset, val_split.indices)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# torch.cuda.is_available()

In [None]:
class AlexNet_simplified(nn.Module):
    """AlexNet-style CNN with three intermediate early-exit classifiers.

    The backbone is five conv stages followed by a three-layer fully connected
    classifier. Lightweight heads (global average pool + linear) are attached
    after conv2, conv3 and conv4 so that confident predictions can leave the
    network early at inference time.

    Args:
        num_classes: Number of output classes of every head.
        exit_threshold: Minimum softmax confidence (max class probability)
            required to return from an early exit in inference mode. Stored on
            ``self.exit_threshold`` and may be changed after construction.
    """

    def __init__(self, num_classes=9, exit_threshold=0.90):
        super().__init__()
        self.exit_threshold = exit_threshold

        # Conv Layer 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, 2)
        )

        # Conv Layer 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2)
        )

        # Early exit 1 (after conv2, 256 channels)
        self.exit1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, num_classes)
        )

        # Conv Layer 3
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU()
        )

        # Early exit 2 (after conv3, 384 channels)
        self.exit2 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(384, num_classes)
        )

        # Conv Layer 4
        self.conv4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU()
        )

        # Early exit 3 (after conv4, 384 channels)
        self.exit3 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(384, num_classes)
        )

        # Conv Layer 5
        self.conv5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # Forces a 6x6 feature map in front of the classifier so the first
        # Linear layer (256 * 6 * 6 inputs) works for any input resolution.
        # For 224x224 inputs conv5 already yields 6x6, so this is an identity.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # Final Classifier
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes)
        )

    def _is_confident(self, logits):
        """Return True when every sample's max softmax probability reaches the threshold.

        For a batch of one this is the plain per-sample confidence check; for a
        larger batch the exit is taken only if all samples are confident.
        """
        confidence = F.softmax(logits, dim=1).max(dim=1).values
        return bool((confidence >= self.exit_threshold).all())

    def forward(self, x, inference=False):
        """Run the network, optionally leaving at the first confident exit.

        Args:
            x: Image batch of shape (N, 3, H, W).
            inference: When False, all heads are evaluated and all logits are
                returned (used for training and full validation). When True,
                heads are checked in order and the first one whose confidence
                reaches ``self.exit_threshold`` is returned; later stages are
                not computed.

        Returns:
            ``(out1, out2, out3, out_final)`` when ``inference`` is False;
            ``(logits, exit_name)`` when ``inference`` is True, with
            ``exit_name`` one of ``"Exit1"``, ``"Exit2"``, ``"Exit3"``, ``"Final"``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        out1 = self.exit1(x)
        if inference and self._is_confident(out1):
            return out1, "Exit1"

        x = self.conv3(x)
        out2 = self.exit2(x)
        if inference and self._is_confident(out2):
            return out2, "Exit2"

        x = self.conv4(x)
        out3 = self.exit3(x)
        if inference and self._is_confident(out3):
            return out3, "Exit3"

        x = self.conv5(x)
        x = self.avgpool(x)
        out_final = self.classifier(x)

        if inference:
            return out_final, "Final"

        # During training, return all outputs
        return out1, out2, out3, out_final

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AlexNet_simplified(num_classes=5,exit_threshold=0.8).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Cut the learning rate by 10x when the monitored value (passed to
# scheduler.step) has not decreased for more than `patience` epochs.
# `verbose` is intentionally not passed: it is deprecated since PyTorch 2.2.
# Read the current rate from optimizer.param_groups[0]["lr"] if needed.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=2
)


In [None]:
# Per-epoch history, read by the plotting cell further down.
train_accuracies, val_accuracies = [], []
train_losses, val_losses = [], []

num_epochs = 30
patience = 5
best_val_acc = 0
early_stop_counter = 0
best_model_wts = copy.deepcopy(model.state_dict())

# Set weights for losses
α1, α2, α3, α4 = 0.2, 0.2, 0.2, 0.4  # You can tune this
exit_loss_weights = (α1, α2, α3, α4)  # order: exit1, exit2, exit3, final

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0
    train_correct = 0

    for images, labels in tqdm(train_loader, desc="training loop"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        exit_logits = model(images, inference=False)
        out_final = exit_logits[-1]

        # One cross-entropy term per head, combined with the weights above.
        weighted = [
            weight * criterion(logits, labels)
            for weight, logits in zip(exit_loss_weights, exit_logits)
        ]
        loss = weighted[0] + weighted[1] + weighted[2] + weighted[3]
        loss.backward()
        optimizer.step()

        # Accuracy is measured on the final head only.
        train_loss += loss.item() * images.size(0)
        train_correct += (out_final.argmax(1) == labels).sum().item()

    n_train = len(train_loader.dataset)
    train_loss /= n_train
    train_acc = train_correct / n_train

    # ---- validation pass ----
    # NOTE: val_loss is the final-head loss only, whereas train_loss above is
    # the weighted multi-exit objective, so the two curves are not the same quantity.
    model.eval()
    val_loss = 0
    val_correct = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="validation loop"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images, inference=False)[-1]
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            val_correct += (outputs.argmax(1) == labels).sum().item()

    n_val = len(val_loader.dataset)
    val_loss /= n_val
    val_acc = val_correct / n_val

    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}: Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Val Loss: {val_loss:.4f}")

    # Keep the best weights seen so far; stop after `patience` epochs without improvement.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        early_stop_counter = 0
        continue

    early_stop_counter += 1
    if early_stop_counter >= patience:
        print("Early stopping triggered.")
        break

# Restore the weights with the best validation accuracy.
model.load_state_dict(best_model_wts)


In [None]:
# model = AlexNet_simplified(num_classes=5).to(device)
# model.load_state_dict(torch.load("/kaggle/input/alexnet_simplified/pytorch/simpler/1/alexnet_simplified_ee1_0.5_30.pth", map_location=device))
# model.eval()

# exit_counts = {"Exit1": 0, "Exit2": 0, "Exit3": 0, "Final": 0}
# correct = 0
# total = 0

# with torch.no_grad():
#     for images, labels in tqdm(val_loader, desc="early exit inference"):
#         images, labels = images.to(device), labels.to(device)
#         for i in range(images.size(0)):
#             img = images[i].unsqueeze(0)  # (1, 3, H, W)
#             label = labels[i].unsqueeze(0)
#             output, exit_name = model(img, inference=True)
#             pred = torch.argmax(output, dim=1)

#             # Print prediction and exit point (exclude final if you want)
#             if exit_name != "Final":
#                 print(f"Predicted: {pred.item()} | Exited at {exit_name}")

#             correct += (pred == label).sum().item()
#             total += 1
#             exit_counts[exit_name] += 1

# print(f"\n✅ Early Exit Accuracy: {correct/total:.4f}")
# print(f"📊 Exit Distribution: {exit_counts}")

In [None]:
model.eval()

# Per-exit bookkeeping: how many samples left at each exit and how many of
# those were classified correctly.
EXIT_NAMES = ["Exit1", "Exit2", "Exit3", "Final"]
exit_counts = {name: 0 for name in EXIT_NAMES}
exit_correct = {name: 0 for name in EXIT_NAMES}
correct = 0
total = 0

with torch.no_grad():
    for images, labels in tqdm(val_loader, desc="early exit inference"):
        images, labels = images.to(device), labels.to(device)
        # Early exit is decided per sample, so feed images one at a time.
        for i in range(images.size(0)):
            img = images[i].unsqueeze(0)  # (1, 3, H, W)
            output, exit_name = model(img, inference=True)
            is_correct = int(output.argmax(dim=1).item() == labels[i].item())

            exit_correct[exit_name] += is_correct
            exit_counts[exit_name] += 1
            correct += is_correct
            total += 1

# An exit that received no samples has no defined accuracy; use NaN instead
# of dividing by zero.
exit_accuracy = {
    name: (exit_correct[name] / exit_counts[name]) if exit_counts[name] else float("nan")
    for name in EXIT_NAMES
}
overall_accuracy = correct / total if total else float("nan")

report_labels = {
    "Exit1": "Early Exit Block 1 Accuracy: ",
    "Exit2": "Early Exit Block 2 Accuracy: ",
    "Exit3": "Early Exit Block 3 Accuracy: ",
    "Final": "Final Early Exit Block Accuracy: ",
}

print("Pumpkin-0.8-final")
print(f"\n✅ Early Exit Accuracy: {overall_accuracy:.4f}")
for name in EXIT_NAMES:
    shown = exit_accuracy[name] if exit_counts[name] else "n/a (no samples)"
    print(f"\n✅ {report_labels[name]}", shown)
print(f"\n📊 Exit Distribution: {exit_counts}")

# Calculate accuracy percentages (NaN entries are simply not drawn)
accuracies = [exit_accuracy[name] * 100 for name in EXIT_NAMES]

plt.figure(figsize=(6, 4))
plt.bar(["exit_1", "exit_2", "exit_3", "exit_final"], accuracies, color='skyblue')
plt.xlabel("Exit Points")
plt.ylabel("Accuracy (%)")
plt.title("Accuracy of Each Block")
plt.ylim(0, 100)  # Set y-axis range from 0 to 100
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

In [None]:
# Overall number of wrong vs. correct predictions from the early-exit evaluation.
plt.figure(figsize=(6, 4))
plt.bar(["wrong","correct"],[total-correct,correct], color='skyblue')
plt.xlabel("Prediction Outcome")
plt.ylabel("Number of Samples")
plt.title("Correct vs. Wrong Predictions")
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

In [None]:
# Display the raw number of validation samples that left at each exit.
exit_counts

In [None]:
# Bar chart of how many samples left the network at each exit point.
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(list(exit_counts.keys()), list(exit_counts.values()), color='skyblue')
ax.set_xlabel("Exit Points")
ax.set_ylabel("Number of Samples")
ax.set_title("Samples Exiting at Each Point")
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()

In [None]:
# Three side-by-side panels: accuracy (fraction), accuracy (percent, fixed
# 0-100 axis) and loss, each with one train and one validation curve.
fig, (ax_acc, ax_pct, ax_loss) = plt.subplots(1, 3, figsize=(20, 5))

ax_acc.plot(train_accuracies)
ax_acc.plot(val_accuracies)
ax_acc.set_title('Model Accuracy')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.legend(['Train', 'Val'], loc='upper left')

ax_pct.plot([acc * 100 for acc in train_accuracies])
ax_pct.plot([acc * 100 for acc in val_accuracies])
ax_pct.set_title('Model Accuracy (0–100%)')
ax_pct.set_ylabel('Accuracy (%)')
ax_pct.set_xlabel('Epoch')
ax_pct.set_ylim(0, 100)
ax_pct.legend(['Train', 'Val'], loc='upper left')

ax_loss.plot(train_losses)
ax_loss.plot(val_losses)
ax_loss.set_title('Model Loss')
ax_loss.set_ylabel('Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.legend(['Train', 'Val'], loc='upper right')

plt.show()

In [None]:
# Deterministic preprocessing for single-image inference: resize to 224x224,
# convert to a tensor and normalise every channel with mean 0.5 / std 0.5.
# No random augmentation is applied here.
_inference_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
]
inference_transform = transforms.Compose(_inference_steps)

# Class names in label-index order, as discovered by ImageFolder.
class_names = full_dataset.classes

def predict_image(image_path, model, device, threshold=0.90):
    """Classify one image with early exit, then print and plot the result.

    Relies on the notebook-level ``inference_transform`` and ``class_names``.

    Args:
        image_path: Path of the image file to classify.
        model: Network whose ``forward(x, inference=True)`` returns
            ``(logits, exit_name)`` and which exposes ``exit_threshold``.
        device: Device the input tensor is moved to (must match the model's).
        threshold: Exit confidence threshold used for this call only; the
            model's own ``exit_threshold`` is restored afterwards, even if
            inference raises.

    Side effects:
        Puts ``model`` into eval mode, prints the prediction and shows a
        matplotlib figure of the image. Returns nothing.
    """
    # Load and preprocess image; the context manager closes the file handle
    # once the RGB copy has been made.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_tensor = inference_transform(image).unsqueeze(0).to(device)

    # Temporarily override threshold
    original_threshold = model.exit_threshold
    model.exit_threshold = threshold
    try:
        # Inference with early exit
        model.eval()
        with torch.no_grad():
            output, exit_point = model(image_tensor, inference=True)
            probabilities = F.softmax(output, dim=1)
            confidence, predicted_class_idx = torch.max(probabilities, 1)
    finally:
        # Restore original threshold, including on failure
        model.exit_threshold = original_threshold

    predicted_class = class_names[predicted_class_idx.item()]
    confidence_score = confidence.item()

    # Print and plot
    print(f"✅ Predicted Class: {predicted_class}")
    print(f"🔁 Exit Used: {exit_point}")
    print(f"📈 Confidence: {confidence_score:.4f} (Threshold: {threshold})")

    plt.figure(figsize=(5, 5))
    plt.imshow(image)
    plt.axis('off')
    plt.title(f"{predicted_class} | {exit_point} | {confidence_score:.2f}", fontsize=12)
    plt.show()

In [None]:
# Example usage: classify a single "Healthy Leaf" image from the dataset folder.
image_path = "/kaggle/input/pumpkin-leaf-diseases-dataset-from-bangladesh/Pumpkin Leaf Diseases Dataset From Bangladesh/Original Dataset/Healthy Leaf/Healthy Leaf (120).jpg"
predict_image(image_path=image_path, model=model, device=device, threshold=0.9)

In [None]:
# Persist the model weights (state dict only) to the working directory.
checkpoint_path = "pumpkin_0.80_final.pth"
torch.save(model.state_dict(), checkpoint_path)
print("saved")

In [None]:
from IPython.display import FileLink

# Write the weights file and render a clickable download link as the cell output.
name = "pumpkin_0.80_final.pth"
state = model.state_dict()
torch.save(state, name)
FileLink(name)