In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import json
from PIL import Image
import os

In [24]:
# Load the annotation file: a list of records, each with a "filename" and a "labels" list.
# JSON is UTF-8 by spec, so read it explicitly as UTF-8 instead of the platform default encoding.
with open("src/labeled_data.json", "r", encoding="utf-8") as f:
    labeled_data = json.load(f)

# Build a stable (sorted) label vocabulary and a label -> column-index mapping
# used to build the multi-hot target vectors.
all_labels = sorted({label for item in labeled_data for label in item["labels"]})
label_to_idx = {label: i for i, label in enumerate(all_labels)}
print(all_labels)

['burnt', 'casserole', 'coffee', 'cups', 'dirty', 'dutch', 'food', 'oven', 'pan', 'residue', 'stains', 'wok']


In [6]:
class DishDataset(Dataset):
    """Map-style dataset yielding ``(image, multi_hot_labels)`` pairs for dish photos.

    Args:
        labeled_data: Sequence of records, each a dict with a ``"filename"`` (resolved
            against ``image_dir``) and a ``"labels"`` list of label strings.
        image_dir: Directory the filenames are joined onto.
        transform: Optional callable applied to the RGB PIL image (e.g. a torchvision
            ``Compose``). When ``None`` the RGB PIL image is returned as-is.
        label_map: Optional ``{label: column_index}`` mapping used to build the target
            vector. When ``None`` the notebook-level ``label_to_idx`` is used (looked up
            at item-access time), so existing call sites keep working unchanged.

    Raises:
        KeyError: from ``__getitem__`` if a record carries a label missing from the map.
    """

    def __init__(self, labeled_data, image_dir, transform=None, label_map=None):
        self.labeled_data = labeled_data
        self.image_dir = image_dir
        self.transform = transform
        self.label_map = label_map

    def __len__(self):
        return len(self.labeled_data)

    def __getitem__(self, idx):
        item = self.labeled_data[idx]
        img_path = os.path.join(self.image_dir, item["filename"])
        # Open in a context manager so the file handle is released right away;
        # convert() returns a converted copy that does not depend on the open file.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        mapping = label_to_idx if self.label_map is None else self.label_map
        # Multi-hot target: 1.0 in the column of every label present on this record.
        label_vector = torch.zeros(len(mapping))
        for label in item["labels"]:
            label_vector[mapping[label]] = 1.0

        if self.transform is not None:
            image = self.transform(image)

        return image, label_vector

In [7]:
# Preprocessing for the pretrained EfficientNet backbone: fixed 224x224 input,
# a random horizontal flip as light augmentation, then tensor conversion and
# normalisation with the standard ImageNet channel statistics.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [13]:
# Wrap the annotations and image folder in the Dataset; every item comes back as
# (transformed image, multi-hot label vector).
image_dir = "src/dataset/dirty_dishes"
dataset = DishDataset(
    labeled_data=labeled_data,
    image_dir=image_dir,
    transform=transform,
)

In [14]:
# Seed torch's global RNG (also consumed by RandomHorizontalFlip and by the fresh
# nn.Linear head created in a later cell) and give the loader its own seeded generator
# so the shuffle order repeats under Restart & Run All.
# NOTE: some GPU/MPS kernels may still be non-deterministic.
SEED = 42
BATCH_SIZE = 16
torch.manual_seed(SEED)
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [16]:
# EfficientNet-B0 with ImageNet-1k weights. The boolean `pretrained=True` flag is deprecated
# since torchvision 0.13; IMAGENET1K_V1 is the weight set it used to select.
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /Users/pbanavara/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20.5M/20.5M [00:00<00:00, 33.3MB/s]


In [17]:
# Replace the final Linear layer (model.classifier[1]) so the network emits one
# logit per label in our vocabulary instead of the ImageNet classes.
num_classes = len(all_labels)
model.classifier[1] = nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=num_classes,
)

In [18]:
# Prefer CUDA, then Apple MPS, and fall back to CPU. The original code chose "mps"
# unconditionally whenever CUDA was absent, which fails on machines without MPS.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = model.to(device)


In [19]:
# Multi-label targets: every label is an independent yes/no decision, so the loss is
# binary cross-entropy applied to raw logits (sigmoid folded in), optimised with AdamW.
LEARNING_RATE = 1e-4
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [20]:
num_epochs = 10
for epoch in range(num_epochs):
    # Put dropout/batch-norm layers in training mode (predict() switches the model to eval).
    model.train()
    running_loss = 0.0
    samples_seen = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        # BCEWithLogitsLoss defaults to reduction="mean" over all batch x label elements.
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch doesn't skew the epoch average.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Per-sample mean loss for the epoch; max() guards against an empty loader.
    epoch_loss = running_loss / max(samples_seen, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

print("Training complete!")

Epoch 1/10, Loss: 0.3365
Epoch 2/10, Loss: 0.1568
Epoch 3/10, Loss: 0.1089
Epoch 4/10, Loss: 0.0928
Epoch 5/10, Loss: 0.0801
Epoch 6/10, Loss: 0.0755
Epoch 7/10, Loss: 0.0703
Epoch 8/10, Loss: 0.0668
Epoch 9/10, Loss: 0.0617
Epoch 10/10, Loss: 0.0613
Training complete!


In [21]:
# Persist only the learned parameters (the state_dict), not the whole module object.
checkpoint_path = "efficientnet_multilabel.pth"
weights = model.state_dict()
torch.save(weights, checkpoint_path)

In [22]:
def predict(image_path, model, transform, threshold=0.5, labels=None):
    """Return the label names whose predicted probability exceeds ``threshold``.

    Side effect: switches ``model`` to eval mode and leaves it there.

    Args:
        image_path: Path to the image file; it is converted to RGB before ``transform``.
        model: Multi-label classifier producing one logit per label for a batch of images.
        transform: Callable mapping the RGB PIL image to an image tensor (presumably CHW —
            TODO confirm). It should be deterministic (no random augmentation) so repeated
            calls give the same answer.
        threshold: Probability cut-off; a label is returned when its sigmoid probability
            is strictly greater than this value.
        labels: Optional sequence of label names indexed by output column. Defaults to
            the notebook-level ``all_labels``.

    Returns:
        List of predicted label names, in output-column order.
    """
    model.eval()
    # Run on whatever device holds the model's weights rather than relying on a global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    label_names = all_labels if labels is None else labels

    # Context manager releases the file handle; convert() returns an independent copy.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(target_device)  # add a batch dimension of 1

    with torch.no_grad():
        logits = model(batch)
        # One device->host transfer, then plain float comparisons below.
        probs = torch.sigmoid(logits).squeeze(0).cpu().tolist()

    return [label_names[i] for i, prob in enumerate(probs) if prob > threshold]

In [23]:
# Reload the saved weights. map_location lets a checkpoint written on one device load on
# another; weights_only=True restricts unpickling to tensors and plain containers.
state_dict = torch.load("efficientnet_multilabel.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)

# The training `transform` includes RandomHorizontalFlip. Inference should be deterministic,
# so reuse the same pipeline with the random augmentation removed.
eval_transform = transforms.Compose(
    [t for t in transform.transforms if not isinstance(t, transforms.RandomHorizontalFlip)]
)

# Example prediction
image_path = "test_sink_image.png"
predicted_labels = predict(image_path, model, eval_transform)
print("Predicted Labels:", predicted_labels)

Predicted Labels: ['stains']
