In [66]:
import os

import kaggle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision.transforms import v2
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import shap

seed = 67

# Ask cuDNN for deterministic kernels and disable its auto-tuner so that
# repeated runs pick the same algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed every random number generator this notebook relies on, in the same
# order as before: NumPy first, then PyTorch (CPU), then PyTorch (CUDA).
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

print("done")


done


In [67]:
# Load Dataset Metadata
# The CSV is expected to hold one row per image; the rest of the notebook
# reads its "label" and "filename" columns.
metadata_path = "./Garbage_data/Garbage_Dataset_Classification/metadata.csv"
metadata_df = pd.read_csv(metadata_path)

# Fail fast with a clear message if the columns used downstream are missing,
# instead of hitting an opaque KeyError later inside the dataset.
_missing_columns = {"label", "filename"} - set(metadata_df.columns)
if _missing_columns:
    raise KeyError(
        f"{metadata_path} is missing required column(s): {sorted(_missing_columns)}"
    )


def get_file_path(
    metadata,
    image_folder="./Garbage_data/Garbage_Dataset_Classification/images/",
):
    """Build the on-disk path of the image described by one metadata row.

    Args:
        metadata: Mapping-like row (e.g. a pandas Series or dict) providing
            the "label" and "filename" entries.
        image_folder: Root directory that contains one sub-folder per label.
            Defaults to the dataset location used throughout this notebook.

    Returns:
        The path ``<image_folder>/<label>/<filename>`` as a string.
    """
    return os.path.join(image_folder, metadata["label"], metadata["filename"])


class GarbageDataset(Dataset):
    """Dataset that loads garbage images listed in a metadata DataFrame.

    Each item is an ``(image, label_idx)`` pair. The image is read with
    OpenCV, converted from BGR to RGB and, if a transform was given, passed
    through it. The label index is looked up in the module-level
    ``label_to_idx`` mapping, which must be defined before items are fetched.
    """

    def __init__(self, metadata_df, transform=None):
        """Store the metadata rows and the optional per-image transform.

        Args:
            metadata_df: DataFrame with at least "label" and "filename" columns.
            transform: Optional callable applied to the RGB image array.
        """
        # Re-index from 0 so positional lookups line up after a train/val split.
        self.metadata_df = metadata_df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows of metadata)."""
        return len(self.metadata_df)

    def __getitem__(self, idx):
        """Load and return the ``(image, label_idx)`` pair at position ``idx``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        metadata = self.metadata_df.iloc[idx]
        img_path = get_file_path(metadata)

        # cv2.imread signals failure by returning None rather than raising;
        # surface that as a clear error that names the offending file.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            image = self.transform(image)

        return image, label_to_idx[metadata["label"]]


split_seed = 13

# Stratified train/validation split: both parts keep the label proportions of
# the full metadata table. The split size is left at train_test_split's default.
metadata_train_df, metadata_val_df = train_test_split(
    metadata_df,
    stratify=metadata_df["label"].values,
    random_state=split_seed,
)

# Class names in alphabetical order, plus lookup tables in both directions
# (name -> integer id and integer id -> name).
labels = sorted(metadata_df["label"].unique())
label_to_idx = dict(zip(labels, range(len(labels))))
idx_to_label = dict(enumerate(labels))

# Data Augmentation & Preprocessing
# Per-channel statistics shared by the training and evaluation pipelines.
# NOTE(review): these values match the widely used CIFAR-10 mean/std rather
# than statistics of this dataset -- confirm that this is intended.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2470, 0.2435, 0.2616)

# Training pipeline: convert to a float image in [0, 1], then apply a random
# resized crop to 256x256, random photometric distortion, random horizontal
# and vertical flips, and finally channel-wise normalization.
_train_steps = [
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.RandomResizedCrop(size=(256, 256), antialias=True),
    v2.RandomPhotometricDistort(p=0.5),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.Normalize(_norm_mean, _norm_std),
]
transform = v2.Compose(_train_steps)

# Evaluation pipeline: same conversion and normalization, but with a
# deterministic resize followed by a 256x256 centre crop and no augmentation.
_eval_steps = [
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Resize(256, antialias=True),
    v2.CenterCrop(256),
    v2.Normalize(_norm_mean, _norm_std),
]
test_transform = v2.Compose(_eval_steps)



# Wrap each metadata split in a dataset: augmented pipeline for training,
# deterministic pipeline for validation.
trainset = GarbageDataset(metadata_df=metadata_train_df, transform=transform)
testset = GarbageDataset(metadata_df=metadata_val_df, transform=test_transform)

# Batches of 64 samples; only the training loader shuffles.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
# Device selection.
# The original cell auto-detected an accelerator and then unconditionally reset
# the device to CPU, which turned the detection into dead code. The override is
# kept as the default but is now an explicit switch.
# NOTE(review): the reason for pinning the CPU is not visible in this notebook;
# confirm before setting FORCE_CPU to False.
FORCE_CPU = True

if FORCE_CPU:
    device = torch.device("cpu")
elif torch.backends.mps.is_available():
    # Apple-silicon GPU backend (for mac use mainly).
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cpu


In [68]:
# load model definitions 
class CustomCNN(nn.Module):
    """Small VGG-style CNN: five conv/batch-norm/ReLU units and three poolings.

    The classifier's first linear layer expects 128 feature maps of 32x32,
    which is what three 2x2 poolings produce from a 256x256 input.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=6):
        super().__init__()

        def conv_unit(in_channels, out_channels):
            """Return a 3x3 same-padding convolution followed by BN and ReLU."""
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        # Layers are kept in one flat Sequential, in the original order, so the
        # parameter names (features.0, features.1, ...) stay the same.
        feature_layers = [
            *conv_unit(3, 32),
            *conv_unit(32, 32),
            nn.MaxPool2d(2, 2),  # 256x256 -> 128x128
            *conv_unit(32, 64),
            *conv_unit(64, 64),
            nn.MaxPool2d(2, 2),  # 128x128 -> 64x64
            *conv_unit(64, 128),
            nn.MaxPool2d(2, 2),  # 64x64 -> 32x32
        ]
        self.features = nn.Sequential(*feature_layers)

        flat_features = 128 * 32 * 32
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Dropout(0.15),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.classifier(self.features(x))


class AxNet(nn.Module):
    """AlexNet-style CNN with batch normalisation and a six-layer MLP head.

    The first linear layer expects 256 feature maps of 7x7; the spatial sizes
    noted below are the ones obtained for a 256x256 input.

    Args:
        num_classes: Size of the output logit vector.
        dropout_rate: Dropout probability used before every linear layer.
    """

    def __init__(self, num_classes, dropout_rate=0):
        super().__init__()
        self.num_classes = num_classes

        def conv_unit(in_channels, out_channels, **conv_kwargs):
            """Return a convolution followed by batch norm and ReLU."""
            return [
                nn.Conv2d(in_channels, out_channels, **conv_kwargs),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        # One flat Sequential in the original layer order keeps the parameter
        # names (features.0, features.1, ...) unchanged.
        feature_layers = [
            *conv_unit(3, 96, kernel_size=11, stride=4, padding=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 63 -> 31
            *conv_unit(96, 256, kernel_size=5, padding=2, stride=1),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 31 -> 15
            *conv_unit(256, 384, kernel_size=3, padding=1, stride=1),  # 15 -> 15
            *conv_unit(384, 384, kernel_size=3, padding=1),  # 15 -> 15
            *conv_unit(384, 256, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 15 -> 7
        ]
        self.features = nn.Sequential(*feature_layers)

        # Hidden widths of the head; every hidden layer is Dropout -> Linear -> ReLU
        # and the output layer is Dropout -> Linear.
        widths = [256 * 7 * 7, 1024, 512, 256, 128, 64]
        head_layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            head_layers.append(nn.Dropout(dropout_rate))
            head_layers.append(nn.Linear(fan_in, fan_out))
            head_layers.append(nn.ReLU())
        head_layers.append(nn.Dropout(dropout_rate))
        head_layers.append(nn.Linear(widths[-1], num_classes))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps.flatten(1))


# Ensemble wrapper that combines two classifiers (used below with AxNet and
# CustomCNN) by averaging their logits.
class EnsembleCNN(nn.Module):
    """Average the logits of two models evaluated on the same input.

    Args:
        model1: First sub-model; registered as ``model1``.
        model2: Second sub-model; registered as ``model2``.
    """

    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2

    def forward(self, x):
        """Return the element-wise mean of both sub-models' outputs."""
        return (self.model1(x) + self.model2(x)) / 2  # Averaging logits


print("done")


done


In [69]:
# Rebuild the two sub-models and the ensemble, then restore the trained weights.
customcnn_model = CustomCNN(num_classes=6).to(device)
Ax_model = AxNet(num_classes=6, dropout_rate=0.15).to(device)
# Create the ensemble model (AxNet is model1, CustomCNN is model2; the order
# must match the checkpoint's parameter names).
ensemble_model = EnsembleCNN(Ax_model, customcnn_model).to(device)
PATH = "Ensemble_model_states.pth"
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded.
ensemble_state = torch.load(PATH, map_location=device, weights_only=True)
ensemble_model.load_state_dict(ensemble_state)



<All keys matched successfully>

In [None]:
# Evaluate the ensemble on the validation split and summarise the results.
ensemble_model.eval()
all_preds = []
all_labels = []

# get preds (no gradients are needed for evaluation)
with torch.no_grad():
    # The batch targets get their own name so the module-level `labels`
    # (list of class names) is not overwritten by the loop.
    for images, labels_batch in testloader:
        images, labels_batch = images.to(device), labels_batch.to(device)
        outputs = ensemble_model(images)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels_batch.cpu().numpy())

# Passing the full list of class ids keeps the matrix and the report aligned
# with the class names even if some class never occurs in this split.
class_ids = list(range(len(labels)))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

print(
    classification_report(
        all_labels, all_preds, labels=class_ids, target_names=labels
    )
)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d',
            cmap="Blues", xticklabels=labels,
            yticklabels=labels, ax=ax)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
plt.show()
    


In [None]:
# Per-class accuracy on the validation split. The images are kept as well so
# that misclassified examples can be displayed afterwards.
num_classes = len(labels)

all_images = []
all_labels = []
all_preds  = []

ensemble_model.eval()

with torch.no_grad():
    # preds
    for images, labels_batch in testloader:
        images, labels_batch = images.to(device), labels_batch.to(device)
        outputs = ensemble_model(images)
        _, preds = outputs.max(1)

        # store imgs and vals for later
        all_images.extend(images.cpu())
        all_labels.extend(labels_batch.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

# Count samples and hits per class in one vectorised pass.
true_arr = np.asarray(all_labels, dtype=np.int64)
pred_arr = np.asarray(all_preds, dtype=np.int64)
total = np.bincount(true_arr, minlength=num_classes).astype(float)
correct = np.bincount(
    true_arr[true_arr == pred_arr], minlength=num_classes
).astype(float)

# Accuracy in percent; a class without any sample stays NaN instead of
# triggering a divide-by-zero warning.
acc = np.divide(
    correct, total, out=np.full_like(total, np.nan), where=total > 0
) * 100



In [None]:
# Bar chart of the per-class accuracy (in percent) with value labels.
fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(x=list(labels), y=acc.tolist(), ax=ax)
for container in ax.containers:
    ax.bar_label(container, fmt='%.1f')

# The axis starts at 70 to emphasise differences between classes, but is
# lowered when a class scores below that so its bar is not cut off.
finite_acc = acc[np.isfinite(acc)]
lowest_acc = float(finite_acc.min()) if finite_acc.size else 100.0
y_min = 70 if lowest_acc >= 70 else max(0.0, float(np.floor(lowest_acc)) - 5)
ax.set_ylim(y_min, 100)

ax.set_ylabel("Accuracy")
ax.set_xlabel("Class")
ax.set_title("Per-Class Accuracy")
plt.show()


In [None]:
# Gather every validation sample whose prediction disagrees with its target.
misclassified = [
    (image, true_lbl, pred_lbl)
    for image, true_lbl, pred_lbl in zip(all_images, all_labels, all_preds)
    if true_lbl != pred_lbl
]

# Show at most 8 examples (4x2 grid); fewer if there are fewer mistakes, so the
# loop can no longer index past the end of the list.
num_show = min(8, len(misclassified))
plt.figure(figsize=(10, 14))
for i in range(num_show):

    img, true_lbl, pred_lbl = misclassified[i]

    # change to numpy to print with plt (channels-first -> channels-last)
    img = img.permute(1, 2, 0).numpy()

    # Min-max rescale for display; a constant image would divide by zero, so
    # it is shown as all zeros instead.
    value_range = img.max() - img.min()
    if value_range > 0:
        img = (img - img.min()) / value_range
    else:
        img = np.zeros_like(img)

    plt.subplot(4, 2, i + 1)
    plt.imshow(img)
    plt.title(f"True: {labels[true_lbl]}\nPred: {labels[pred_lbl]}")
    plt.axis("off")

plt.show()

In [None]:
# Prediction wrapper intended for SHAP
def predict(images):
    """Run the ensemble on a batch of image tensors and return NumPy logits.

    The ensemble is switched to evaluation mode, the batch is moved to the
    notebook's device, and the forward pass runs without gradient tracking.
    """
    ensemble_model.eval()
    batch = images.to(device)
    with torch.no_grad():
        logits = ensemble_model(batch)
    return logits.detach().cpu().numpy()
    
# Background batch for SHAP: the first 20 training samples (labels dropped),
# fetched in index order and stacked into one tensor on the active device.
background = torch.stack([trainset[i][0] for i in range(20)]).to(device)

# Samples to explain: the first 5 validation images and their label indices.
test_images, test_labels = [], []
for sample_idx in range(5):
    sample_image, sample_label = testset[sample_idx]
    test_images.append(sample_image)
    test_labels.append(sample_label)

test_images = torch.stack(test_images).to(device)

In [None]:
# Explain the ensemble's predictions with SHAP's gradient explainer.
# The explainer needs the differentiable nn.Module itself: a Python wrapper
# that runs under torch.no_grad() and returns NumPy arrays gives it nothing to
# take gradients through.
ensemble_model.eval()
explainer = shap.GradientExplainer(ensemble_model, background)

shap_values = explainer.shap_values(test_images)

# shap.image_plot expects one channels-last array per class. Depending on the
# installed shap version the explainer returns either a list of
# (N, C, H, W) arrays or one (N, C, H, W, num_classes) array; handle both.
if isinstance(shap_values, (list, tuple)):
    shap_per_class = [
        np.transpose(np.asarray(values), (0, 2, 3, 1)) for values in shap_values
    ]
else:
    shap_per_class = list(np.transpose(np.asarray(shap_values), (4, 0, 2, 3, 1)))

# Channels-last copy of the inputs, min-max rescaled per image to [0, 1]
# because the tensors are normalised and would otherwise not display properly.
pixels = np.transpose(test_images.detach().cpu().numpy(), (0, 2, 3, 1))
pixel_min = pixels.min(axis=(1, 2, 3), keepdims=True)
pixel_max = pixels.max(axis=(1, 2, 3), keepdims=True)
pixels = (pixels - pixel_min) / np.maximum(pixel_max - pixel_min, 1e-12)

shap.image_plot(shap_per_class, pixels)