In [None]:
# Install the notebook's dependencies into the running kernel's environment.
# NOTE(review): versions are unpinned, and matplotlib is not used by any cell
# visible in this notebook — consider pinning versions / trimming the list.
%pip install torch torchvision matplotlib

In [28]:
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import os
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [29]:
# --- Current model: ResNet-50 with a 7-way head --------------------------
# Predicted indices are used to look labels up in this list, so its order
# has to agree with the label order the checkpoint was trained with
# (presumably alphabetical folder order — confirm against the training run).
class_names = ['battery', 'cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# Derive the head width from the label list so the two can never drift apart.
num_classes = len(class_names)

model = models.resnet50()
# The head is a Sequential(Dropout, Linear); the checkpoint loaded below is
# expected to use the matching parameter names. Dropout is inactive in
# eval mode.
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2),
    torch.nn.Linear(model.fc.in_features, num_classes),
)

# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot run arbitrary code while being loaded.
state_dict = torch.load("resnet50_model4.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # Set model to evaluation mode

# --- Previous model: ResNet-50 with a plain 6-way linear head -------------
# Predicted indices are used to look labels up in this list, so its order
# has to agree with the label order the old checkpoint was trained with
# (presumably alphabetical folder order — confirm against the training run).
class_names_old = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

# Derive the head width from the label list so the two can never drift apart.
num_classes_old = len(class_names_old)

model_old = models.resnet50()
model_old.fc = torch.nn.Linear(model_old.fc.in_features, num_classes_old)

# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot run arbitrary code while being loaded.
state_dict_old = torch.load("resnet50_model_old.pth", map_location=device, weights_only=True)
model_old.load_state_dict(state_dict_old)
model_old.to(device)
model_old.eval()  # Set model to evaluation mode

# Preprocessing shared by both models: resize to a fixed square, convert to
# a tensor, then normalise each channel. The mean/std values are the ImageNet
# channel statistics commonly used with torchvision's ResNet models.
INPUT_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)

In [30]:
def _predict_label(image_path, net, labels):
    """Return the label that `net` assigns to the image at `image_path`.

    The image is decoded as RGB, passed through the notebook-level
    `transform`, and classified by taking the arg-max of the network output.

    Args:
        image_path: Path to an image file that PIL can open.
        net: Model to run; it is called on a batch of one image on `device`.
        labels: Sequence of class names, indexed by the predicted class index.

    Returns:
        The entry of `labels` at the predicted index.
    """
    # Open through a context manager so the file handle is released as soon
    # as convert() has copied the pixels out, instead of lingering until GC.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')

    batch = transform(rgb_image).unsqueeze(0).to(device)  # Add batch dimension

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = net(batch)
    class_idx = torch.argmax(logits, dim=1).item()
    return labels[class_idx]


def classify_image(image_path):
    """Classify an image with the current model (`model` / `class_names`).

    Prints ``Predicted Class: <name> | `` with no trailing newline, so later
    prints in the same cell continue on the same output line.

    Args:
        image_path: Path to the image file to classify.

    Returns:
        The predicted class name (the same value that is printed), so callers
        can compare it against an expected label.
    """
    class_name = _predict_label(image_path, model, class_names)
    print(f"Predicted Class: {class_name}", end=' | ')
    return class_name


def classify_image_old(image_path):
    """Classify an image with the previous model (`model_old` / `class_names_old`).

    Prints ``Old Predicted Class: <name> | `` with no trailing newline, so
    later prints in the same cell continue on the same output line.

    Args:
        image_path: Path to the image file to classify.

    Returns:
        The predicted class name (the same value that is printed), so callers
        can compare it against an expected label.
    """
    class_name = _predict_label(image_path, model_old, class_names_old)
    print(f"Old Predicted Class: {class_name}", end=' | ')
    return class_name

In [31]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/amazon-cardboard-box-03.jpg"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: cardboard')


Predicted Class: cardboard | Old Predicted Class: cardboard | expected: cardboard


In [32]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/coke2.webp"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: metal')


Predicted Class: battery | Old Predicted Class: metal | expected: metal


In [33]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/dietcoke.jpg"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: metal')


Predicted Class: metal | Old Predicted Class: cardboard | expected: metal


In [34]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/dietcoke-removebg-preview.png"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: metal')


Predicted Class: battery | Old Predicted Class: metal | expected: metal


In [35]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/mugs.jpg"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: glass')


Predicted Class: battery | Old Predicted Class: glass | expected: glass


In [36]:
# Only the current model is queried here: the old model's label list
# (class_names_old) has no 'battery' entry, which is presumably why the
# old classifier is skipped for this image.
image_path = "./custom-test-images/Duracell_AA__49529.jpg"
classify_image(image_path)
print('expected: battery')

Predicted Class: battery | expected: battery


In [37]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/wine-glass.jpg"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: glass')

Predicted Class: glass | Old Predicted Class: glass | expected: glass


In [38]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/glass-container.jpg"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: glass')

Predicted Class: metal | Old Predicted Class: metal | expected: glass


In [39]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/sprite.png"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: glass')

Predicted Class: glass | Old Predicted Class: glass | expected: glass


In [40]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/fanta.webp"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: metal')

Predicted Class: battery | Old Predicted Class: metal | expected: metal


In [41]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/pie.jpg"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: metal')

Predicted Class: metal | Old Predicted Class: metal | expected: metal


In [42]:
# Query both classifiers on the same image; each prints without a newline,
# so the expected label printed last lands on the same output line.
image_path = "./custom-test-images/root-beer.webp"
for classify in (classify_image, classify_image_old):
    classify(image_path)
print('expected: glass')

Predicted Class: glass | Old Predicted Class: glass | expected: glass
