In [1]:
import os
import wandb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision
from torchvision import transforms, datasets, models

from sklearn.utils.class_weight import compute_class_weight

from collections import Counter

# Check if GPU is available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)

cuda:0


In [2]:
trainset_dir = 'data/enel645_2024f/garbage_data/CVPR_2024_dataset_Train'
valset_dir = 'data/enel645_2024f/garbage_data/CVPR_2024_dataset_Val'
testset_dir = 'data/enel645_2024f/garbage_data/CVPR_2024_dataset_Test'


In [4]:

# Function to build vocabulary from text descriptions
def build_vocab(texts, min_freq=1):
    counter = Counter()
    for text in texts:
        tokens = text.split()
        counter.update(tokens)
    # Keep tokens with frequency >= min_freq
    vocab = {word for word, freq in counter.items() if freq >= min_freq}
    # Build word_to_idx mapping, reserve indices for PAD and UNK tokens
    word_to_idx = {'<PAD>': 0, '<UNK>': 1}
    for idx, word in enumerate(sorted(vocab), start=2):
        word_to_idx[word] = idx
    idx_to_word = {idx: word for word, idx in word_to_idx.items()}
    return word_to_idx, idx_to_word

# Extract images, labels, and text descriptions from a given folder
def extract_data_from_folders(base_dir):
    data = []

    # Traverse through each subfolder
    for label_folder in os.listdir(base_dir):
        folder_path = os.path.join(base_dir, label_folder)

        # Check if it's a directory
        if os.path.isdir(folder_path):
            # Loop through each image file in the subfolder
            for filename in os.listdir(folder_path):
                if filename.endswith(('.jpg', '.png', '.jpeg')):  # Filter image files
                    image_path = os.path.join(folder_path, filename)

                    # Extract text from filename (remove file extension)
                    text_description = os.path.splitext(filename)[0]

                    # Append image path, text, and label to the data list
                    data.append({
                        'image_path': image_path,
                        'text_description': text_description,
                        'label': label_folder  # The subfolder name represents the label (bin)
                    })

    # Convert to DataFrame for easy manipulation
    return pd.DataFrame(data)

class GarbageDataset(Dataset):
    def __init__(self, dataframe, image_transform=None, max_len=32, word_to_idx=None, class_to_idx=None):
        self.dataframe = dataframe
        self.image_transform = image_transform
        self.max_len = max_len
        self.word_to_idx = word_to_idx
        self.class_to_idx = class_to_idx

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        # Get image path, text description, and label from the dataframe
        img_path = self.dataframe.iloc[idx]['image_path']
        text_desc = self.dataframe.iloc[idx]['text_description']
        label = self.dataframe.iloc[idx]['label']  

        # Load and preprocess the image
        image = Image.open(img_path).convert("RGB")
        if self.image_transform:
            image = self.image_transform(image)

        # Tokenize the text description
        tokens = text_desc.split()
        token_ids = [self.word_to_idx.get(token, self.word_to_idx['<UNK>']) for token in tokens]
        # Pad or truncate to max_len
        if len(token_ids) < self.max_len:
            token_ids += [self.word_to_idx['<PAD>']] * (self.max_len - len(token_ids))
        else:
            token_ids = token_ids[:self.max_len]
        token_ids = torch.tensor(token_ids, dtype=torch.long)

        # Convert string label to numeric label using the class mapping
        numeric_label = self.class_to_idx[label]

        # Return the image, token_ids, and numeric label
        return {
            'image': image,
            'token_ids': token_ids,  
            'label': torch.tensor(numeric_label, dtype=torch.long)  
        }

# Define the image model using MobileNetV3-Large
class ImageModel(nn.Module):
    def __init__(self):
        super(ImageModel, self).__init__()
        self.model = models.mobilenet_v3_large(pretrained=True)
        # Remove the last classification layer
        self.model.classifier = nn.Identity()
        # Add a new classifier for feature extraction
        self.feature_extractor = nn.Sequential(
            nn.Linear(960, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )

    def forward(self, x):
        x = self.model(x)
        x = self.feature_extractor(x)
        return x  # Output feature vector of size 256

# Define the text model using Embedding and LSTM
class TextModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=300, hidden_dim=128, padding_idx=0):
        super(TextModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 256)

    def forward(self, x):
        # x is of shape (batch_size, sequence_length)
        x = self.embedding(x)
        # x is now of shape (batch_size, sequence_length, embedding_dim)
        _, (h_n, _) = self.lstm(x)
        # h_n is of shape (1, batch_size, hidden_dim)
        x = h_n.squeeze(0)  # Shape: (batch_size, hidden_dim)
        x = self.fc(x)      # Shape: (batch_size, 256)
        return x  # Output feature vector of size 256

class GarbageClassifier(nn.Module):
    def __init__(self, num_classes=4, vocab_size=None, embedding_dim=300, hidden_dim=128, padding_idx=0):
        super(GarbageClassifier, self).__init__()
        # Image feature extraction with MobileNetV3-Large
        self.image_model = ImageModel()
        # Text feature extraction with Embedding + LSTM
        self.text_model = TextModel(vocab_size, embedding_dim, hidden_dim, padding_idx)
        # Fusion and classification layers
        self.fusion = nn.Sequential(
            nn.Linear(256 + 256, 128),  # Combine features
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, image, token_ids):
        # Get image features
        image_features = self.image_model(image)
        # Get text features
        text_features = self.text_model(token_ids)
        # Combined features from both image and text
        combined_features = torch.cat((image_features, text_features), dim=1)
        combined_output = self.fusion(combined_features)

        return combined_output


In [5]:

run = wandb.init(project='garbage-collection')

# Define classes and map them to indices
class_names = ['Green', 'Blue', 'Black', 'TTR']  
class_to_idx = {class_name: idx for idx, class_name in enumerate(class_names)}
idx_to_class = {idx: class_name for idx, class_name in enumerate(class_names)}

# Extract the data
trainset_df = extract_data_from_folders(trainset_dir)
valset_df = extract_data_from_folders(valset_dir)
testset_df = extract_data_from_folders(testset_dir)

# Build vocabulary from training text descriptions
train_texts = trainset_df['text_description'].tolist()
word_to_idx, idx_to_word = build_vocab(train_texts, min_freq=1)
vocab_size = len(word_to_idx)
print(f'Vocabulary size: {vocab_size}')

# Image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Create datasets
trainset = GarbageDataset(trainset_df, image_transform=transform, class_to_idx=class_to_idx, word_to_idx=word_to_idx)
valset = GarbageDataset(valset_df, image_transform=transform, class_to_idx=class_to_idx, word_to_idx=word_to_idx)
testset = GarbageDataset(testset_df, image_transform=transform, class_to_idx=class_to_idx, word_to_idx=word_to_idx)

batch_size = 64

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=6)
valloader = DataLoader(valset, batch_size=batch_size, shuffle=True, num_workers=6)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=6)

class_weights = compute_class_weight('balanced', classes=np.unique(trainset_df['label']), y=trainset_df['label'])
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Initialize the model, loss function, and optimizer
model = GarbageClassifier(num_classes=len(class_names), vocab_size=vocab_size, padding_idx=word_to_idx['<PAD>']).to(device)

# Freeze all parameters in MobileNetV3-Large
for param in model.image_model.parameters():
    param.requires_grad = False

# Unfreeze the last layers of MobileNetV3-Large
for param in model.image_model.model.features[-1].parameters():
    param.requires_grad = True
for param in model.image_model.feature_extractor.parameters():
    param.requires_grad = True

# Optionally, unfreeze BatchNorm layers
for module in model.image_model.modules():
    if isinstance(module, nn.BatchNorm2d):
        for param in module.parameters():
            param.requires_grad = True

# You can decide whether to freeze any layers in the TextModel
# For now, we keep all parameters trainable
# For large datasets, you might want to freeze the embedding layer

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)


wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: tahmidkazi829 (tahmidkazi829-university-of-calgary-in-alberta). Use `wandb login --relogin` to force relogin


Vocabulary size: 10121


Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to C:\Users\tahmi/.cache\torch\hub\checkpoints\mobilenet_v3_large-8738ca79.pth
100%|██████████| 21.1M/21.1M [00:02<00:00, 9.44MB/s]


In [None]:

num_epochs = 10
wandb.config = {"epochs": num_epochs, "batch_size": batch_size, "learning_rate": 0.001}

best_val_acc = 0.0  # Variable to track the best validation accuracy

for epoch in range(wandb.config['epochs']):
    print(f'Epoch {epoch + 1}/{num_epochs}')
    print('-' * 10)
    
    # Training phase
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for i, batch in enumerate(trainloader, 0):
        images = batch['image'].to(device)
        token_ids = batch['token_ids'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images, token_ids)
        
        # Compute loss
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Update running loss and accuracy
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels.data)

    # Compute epoch loss and accuracy
    epoch_loss = running_loss / len(trainset)
    epoch_acc = running_corrects.double() / len(trainset)

    print(f'Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
    wandb.log({"Training Loss": epoch_loss, "Training Accuracy": epoch_acc})
    
    # Validation phase
    model.eval()
    val_running_loss = 0.0
    val_running_corrects = 0

    with torch.no_grad():
        for batch in valloader:
            images = batch['image'].to(device)
            token_ids = batch['token_ids'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            outputs = model(images, token_ids)
            
            # Compute loss
            loss = criterion(outputs, labels)
            
            # Update running loss and accuracy
            val_running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_running_corrects += torch.sum(preds == labels.data)

    # Compute validation loss and accuracy
    val_loss = val_running_loss / len(valset)
    val_acc = val_running_corrects.double() / len(valset)

    print(f'Validation Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
    wandb.log({"Validation Loss": val_loss, "Validation Accuracy": val_acc})
    
    # Save the best model based on validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print(f"New best model found! Saving model with validation accuracy: {best_val_acc:.4f}")
        torch.save(model.state_dict(), 'best_garbage_model.pth')

# Load the best model and evaluate on the test set
model.load_state_dict(torch.load('best_garbage_model.pth'))
model.eval()
test_running_corrects = 0

with torch.no_grad():
    for batch in testloader:
        images = batch['image'].to(device)
        token_ids = batch['token_ids'].to(device)
        labels = batch['label'].to(device)

        # Forward pass
        outputs = model(images, token_ids)

        # Predictions
        _, preds = torch.max(outputs, 1)

        test_running_corrects += torch.sum(preds == labels.data)

test_acc = test_running_corrects.double() / len(testset)
print(f'Test Accuracy: {test_acc:.4f}')
wandb.log({"Test Accuracy": test_acc})

wandb.finish()


Epoch 1/10
----------
