In [None]:
%pip install torch torchvision matplotlib

In [29]:
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import os
# Run inference on a CUDA GPU when PyTorch can see one, falling back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [30]:
# Output labels, indexed by the model's output unit.
# NOTE(review): this order must match the label order used when the checkpoint
# was trained — TODO confirm against the training script.
class_names = ['battery', 'cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
# Derive the head size from the label list so the two can never drift apart
# (this used to be a hard-coded 7 flagged "# DANGER").
num_classes = len(class_names)

MODEL_PATH = "resnet50_model.pth"

# Rebuild the ResNet-50 architecture with a classifier head sized for our labels,
# then load the fine-tuned weights into it.
model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code on load.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # Set model to evaluation mode (e.g. BatchNorm uses running statistics)

# Preprocessing applied to every input image: fixed 224x224 resize, conversion to a
# float tensor, and normalization with the standard ImageNet per-channel mean/std.
# NOTE(review): this should mirror the evaluation transform used during training
# (plain resize, no center crop) — confirm before changing it.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [31]:
def classify_image(image_path):
    """Classify a single image file with the loaded model.

    Prints ``"Predicted Class: <label> | "`` without a trailing newline, so the
    caller can append the expected label on the same line. It also returns the
    label so results can be checked programmatically.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        str: The predicted label, taken from ``class_names``.
    """
    # The context manager closes the underlying file handle promptly.
    # convert('RGB') loads the pixels and normalizes grayscale/RGBA/palette inputs.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')

    # transform() yields a (C, H, W) tensor; the model expects a batch dimension.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        logits = model(batch)
    class_idx = logits.argmax(dim=1).item()
    class_name = class_names[class_idx]

    print(f"Predicted Class: {class_name}", end=' | ')
    return class_name

In [32]:
# Spot-check the classifier on a few hand-picked images with known labels.
# Each entry is (file name inside TEST_IMAGE_DIR, expected label).
TEST_IMAGE_DIR = "./custom-test-images"  # Replace with your image folder
test_cases = [
    ("amazon-cardboard-box-03.jpg", "cardboard"),
    ("coke2.webp", "metal"),
    ("dietcoke.jpg", "metal"),
    ("dietcoke-removebg-preview.png", "metal"),
    ("mugs.jpg", "glass"),
    ("Duracell_AA__49529.jpg", "battery"),
]

for file_name, expected in test_cases:
    image_path = os.path.join(TEST_IMAGE_DIR, file_name)
    if os.path.exists(image_path):
        # classify_image prints "Predicted Class: <label> | " with no newline,
        # so the expected label lands on the same output line.
        classify_image(image_path)
        print(f'expected: {expected}')
    else:
        print(f"Image not found: {image_path}")

Predicted Class: cardboard | expected: cardboard
Predicted Class: battery | expected: metal
Predicted Class: metal | expected: metal
Predicted Class: metal | expected: metal
Predicted Class: battery | expected: glass
Predicted Class: battery | expected: battery
