In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
from datasets import load_dataset
import wandb

In [None]:
# Load the TrashNet dataset by its Hugging Face Hub identifier.
ds = load_dataset(path="garythung/trashnet")

In [None]:
# Show the loaded dataset object; as the cell's last expression it is displayed
# by Jupyter without an explicit print().
ds

In [None]:
# Authenticate with Weights & Biases; a later cell calls wandb.init.
# NOTE(review): presumably uses WANDB_API_KEY / cached credentials and otherwise
# prompts interactively — confirm this does not block a non-interactive "Run All".
wandb.login()

In [None]:
# Export the "train" split to ./data_trash/<label_name>/image_<i>.jpg so the images
# can be read back from disk with one folder per class.
save_dir = './data_trash'
os.makedirs(save_dir, exist_ok=True)

label_names = ds['train'].features['label'].names
for label_name in label_names:
    os.makedirs(os.path.join(save_dir, label_name), exist_ok=True)

for i, example in enumerate(ds['train']):
    label = label_names[example['label']]
    image_path = os.path.join(save_dir, label, f"image_{i}.jpg")

    # Skip files written by a previous run so re-executing this cell is cheap.
    if os.path.exists(image_path):
        continue

    image = example['image']
    # Pillow's JPEG writer rejects modes such as RGBA and P; normalize to RGB
    # (images are converted to RGB again when loaded for training anyway).
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image.save(image_path)

In [None]:
# Define the dataset class
class TrashDataset(Dataset):
    """Image-folder dataset with one sub-directory per class under ``image_dir``.

    Every regular, non-hidden file in ``image_dir/<class_name>/`` becomes one
    sample whose label is the index of ``class_name`` in ``self.classes``.

    Args:
        image_dir: root directory containing one folder per class name.
        transform: optional callable applied to each RGB PIL image in
            ``__getitem__`` (e.g. a torchvision ``Compose``).
        classes: optional ordered sequence of class folder names. Defaults to
            ``DEFAULT_CLASSES``; the order defines the integer labels.

    Raises:
        FileNotFoundError: if a class folder is missing under ``image_dir``
            (raised by ``os.listdir``).
    """

    DEFAULT_CLASSES = ('cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash')

    def __init__(self, image_dir, transform=None, classes=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.classes = list(self.DEFAULT_CLASSES if classes is None else classes)

        # Load images and labels
        for label_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(image_dir, class_name)
            # Sort because os.listdir order is arbitrary; a fixed order keeps
            # sample indices reproducible across machines and runs.
            for image_name in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, image_name)
                # Hidden files (e.g. .DS_Store) and sub-directories are not images
                # and would make Image.open fail later.
                if image_name.startswith('.') or not os.path.isfile(path):
                    continue
                self.image_paths.append(path)
                self.labels.append(label_idx)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; image is transformed if a transform was given."""
        # Open inside a context manager so the file handle is closed promptly;
        # convert() produces a new, fully loaded image that outlives the handle.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

In [None]:
class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier for square RGB images.

    Architecture:
        [Conv3x3(3->16) -> ReLU -> MaxPool2x2] ->
        [Conv3x3(16->32) -> ReLU -> MaxPool2x2] ->
        Flatten -> Linear(->128) -> ReLU -> Linear(128->num_classes)

    Args:
        num_classes: number of output logits. Defaults to 6 (cardboard, glass,
            metal, paper, plastic, trash).
        image_size: height/width of the square input images. The 3x3 convs use
            padding=1 and preserve spatial size; each 2x2/stride-2 max-pool halves
            it (rounding down), so the flattened feature length is
            ``32 * (image_size // 4) ** 2``. The default 224 reproduces the
            original ``32 * 56 * 56``. Inputs must actually have this size.

    Raises:
        ValueError: if ``image_size`` < 4 (no spatial features would survive pooling).
    """

    def __init__(self, num_classes=6, image_size=224):
        super().__init__()
        if image_size < 4:
            raise ValueError(f"image_size must be >= 4, got {image_size}")

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        # Fully connected layers; in_features follows from two halvings of image_size.
        pooled_size = image_size // 4
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * pooled_size * pooled_size, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a (batch, 3, image_size, image_size) tensor to (batch, num_classes) raw logits."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.fc(x)
        return x

In [None]:
# Hyperparameters used by this cell; they are also recorded in the wandb run config.
SEED = 42
BATCH_SIZE = 32
LEARNING_RATE = 0.001
IMAGE_SIZE = 224

# Seed torch so model initialization and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

# Initialize wandb
wandb.init(
    project="simple-trash-classification",
    config={
        "seed": SEED,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "image_size": IMAGE_SIZE,
    },
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data transforms: resize to the fixed square size the model expects, then
# normalize per channel (these mean/std values appear to be the standard
# ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Create dataset and dataloader
dataset = TrashDataset(image_dir="./data_trash", transform=transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model, loss function, and optimizer
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Prediction function
def predict_image(image_path, model, device, classes=None, transform=None):
    """Classify a single image file and print the per-class probabilities.

    Args:
        image_path: path to an image file readable by PIL.
        model: classifier whose output for a batch of one is a (1, len(classes))
            tensor of logits.
        device: torch device the model lives on; the input is moved there.
        classes: optional class names in model-output order. Defaults to the
            six categories cardboard, glass, metal, paper, plastic, trash.
        transform: optional preprocessing callable. Defaults to resize to
            224x224, ToTensor, and the same Normalize statistics used for training.

    Returns:
        ``(predicted_class, confidence)`` where confidence is the softmax
        probability of the predicted class in percent.

    The model is switched to eval mode for inference and restored to its
    previous train/eval mode afterwards.
    """
    if classes is None:
        classes = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    # Close the file handle promptly; convert() yields an independent, loaded image.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Remember the caller's mode so inference does not silently disable training mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(image_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
            predicted_idx = int(outputs.argmax(dim=1).item())
    finally:
        model.train(was_training)

    predicted_class = classes[predicted_idx]
    confidence = probabilities[predicted_idx].item() * 100
    print(f"\nPredicted class: {predicted_class}")
    print(f"Confidence: {confidence:.2f}%\n")
    print("Class Probabilities:")
    for i, prob in enumerate(probabilities):
        print(f"{classes[i]}: {prob.item() * 100:.2f}%")

    return predicted_class, confidence

In [None]:
# Main script: train SimpleCNN on the exported images, then classify one test image.

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: module mapping an input batch to class logits.
        loader: iterable of ``(inputs, labels)`` batches with integer labels.
        criterion: loss taking ``(logits, labels)``, averaged over the batch.
        optimizer: optimizer over ``model``'s parameters.
        device: device that inputs and labels are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen, with per-batch
        losses weighted by batch size. Returns ``(0.0, 0.0)`` for an empty loader.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion returns a batch mean; re-weight so the epoch mean is per sample.
        total_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size

    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, 100.0 * correct / seen


if __name__ == "__main__":
    torch.manual_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = TrashDataset(image_dir="./data_trash", transform=transform)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 10
    history = {'loss': [], 'accuracy': []}

    for epoch in range(num_epochs):
        loss, acc = train_epoch(model, train_loader, criterion, optimizer, device)
        history['loss'].append(loss)
        history['accuracy'].append(acc)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.4f}, Accuracy: {acc:.2f}%")
        # Log only when a run is active (wandb.init is called in an earlier cell);
        # wandb.log raises if no run has been initialized.
        if wandb.run is not None:
            wandb.log({"epoch": epoch + 1, "loss": loss, "accuracy": acc})

    print("\nModel training complete. Ready to predict.")

    # Prediction
    test_image_path = "test_image.jpg"  # Replace with your test image path
    if os.path.exists(test_image_path):
        predict_image(test_image_path, model, device)
    else:
        print(f"Test image not found at {test_image_path!r}; skipping prediction.")