In [None]:
#%pip install torch torchvision matplotlib
# CHECK CLASS NAMES BEFORE RUNNING NEW MODELS
# push comment

In [89]:
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [92]:
num_classes = 8  # DANGER
model = models.resnet50()
# model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2),
    torch.nn.Linear(model.fc.in_features, num_classes)
) 
model.load_state_dict(torch.load("resnet50_model5.pth", map_location=device))
model.to(device)
model.eval()  # Set model to evaluation mode
class_names = ['battery', 'cardboard', 'glass', 'metal', 'paper', 'plastic', 'syringe', 'trash']

model_dense = models.densenet121()
# model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model_dense.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2),  # Apply dropout for regularization
    torch.nn.Linear(model_dense.classifier.in_features, num_classes)
)
model_dense.load_state_dict(torch.load("densenet121_model1.pth", map_location=device))
model_dense.to(device)
model_dense.eval()  # Set model to evaluation mode
class_names = ['battery', 'cardboard', 'glass', 'metal', 'paper', 'plastic', 'syringe', 'trash']

num_classes_old = 6  # DANGER
model_old = models.resnet50()
model_old.fc = torch.nn.Linear(model_old.fc.in_features, num_classes_old)
model_old.load_state_dict(torch.load("resnet50_model_old.pth", map_location=device))
model_old.to(device)
model_old.eval()  # Set model to evaluation mode
class_names_old = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])



In [94]:
def classify_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0).to(device)  # Add batch dimension

    # Predict
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)
        class_idx = predicted.item()
        class_name = class_names[class_idx]

    print(f"Predicted Class: {class_name}", end=' | ')

def classify_image_old(image_path):
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0).to(device)  # Add batch dimension

    # Predict
    with torch.no_grad():
        outputs = model_old(image)
        _, predicted = torch.max(outputs, 1)
        class_idx = predicted.item()
        class_name = class_names_old[class_idx]

    print(f"Old Predicted Class: {class_name}", end=' | ')

def classify_dense(image_path):
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0).to(device)  # Add batch dimension

    # Predict
    with torch.no_grad():
        outputs = model_old(image)
        _, predicted = torch.max(outputs, 1)
        class_idx = predicted.item()
        class_name = class_names_old[class_idx]

    print(f"Dense Predicted Class: {class_name}", end=' | ')

In [96]:
image_path = "./custom-test-images/amazon-cardboard-box-03.jpg"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: cardboard')


Predicted Class: cardboard | Old Predicted Class: cardboard | Dense Predicted Class: cardboard | expected: cardboard


In [97]:
image_path = "./custom-test-images/coke2.webp"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: metal')


Predicted Class: battery | Old Predicted Class: metal | Dense Predicted Class: metal | expected: metal


In [98]:
image_path = "./custom-test-images/dietcoke.jpg"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: metal')


Predicted Class: metal | Old Predicted Class: cardboard | Dense Predicted Class: cardboard | expected: metal


In [99]:
image_path = "./custom-test-images/dietcoke-removebg-preview.png"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: metal')


Predicted Class: battery | Old Predicted Class: metal | Dense Predicted Class: metal | expected: metal


In [100]:
image_path = "./custom-test-images/mugs.jpg"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: glass')


Predicted Class: battery | Old Predicted Class: glass | Dense Predicted Class: glass | expected: glass


In [101]:
image_path = "./custom-test-images/Duracell_AA__49529.jpg"
classify_image(image_path)
classify_dense(image_path)
print('expected: battery')

Predicted Class: battery | Dense Predicted Class: metal | expected: battery


In [102]:
image_path = "./custom-test-images/wine-glass.jpg"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: glass')

Predicted Class: glass | Old Predicted Class: glass | Dense Predicted Class: glass | expected: glass


In [103]:
image_path = "./custom-test-images/glass-container.jpg"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: glass')

Predicted Class: metal | Old Predicted Class: metal | Dense Predicted Class: metal | expected: glass


In [104]:
image_path = "./custom-test-images/sprite.png"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: glass')

Predicted Class: glass | Old Predicted Class: glass | Dense Predicted Class: glass | expected: glass


In [105]:
image_path = "./custom-test-images/fanta.webp"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: metal')

Predicted Class: battery | Old Predicted Class: metal | Dense Predicted Class: metal | expected: metal


In [106]:
image_path = "./custom-test-images/pie.jpg"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: metal')

Predicted Class: metal | Old Predicted Class: metal | Dense Predicted Class: metal | expected: metal


In [107]:
image_path = "./custom-test-images/root-beer.webp"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: glass')

Predicted Class: glass | Old Predicted Class: glass | Dense Predicted Class: glass | expected: glass


In [108]:
image_path = "./custom-test-images/syringe.webp"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: syringe')

Predicted Class: syringe | Old Predicted Class: glass | Dense Predicted Class: glass | expected: syringe


In [109]:
image_path = "./custom-test-images/syringe_with_bg.jpg"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: syringe')

Predicted Class: syringe | Old Predicted Class: cardboard | Dense Predicted Class: cardboard | expected: syringe


In [110]:
image_path = "./custom-test-images/water-glass-removebg-preview.png"
classify_image(image_path)
classify_image_old(image_path)
classify_dense(image_path)
print('expected: glass')

Predicted Class: plastic | Old Predicted Class: metal | Dense Predicted Class: metal | expected: glass
