# Hierarchical Waste Classifier - Vision Transformer

This notebook trains a **Hierarchical Vision Transformer** for waste classification:
- **30 fine-grained classes** (specific waste types)  
- **7 super categories** (Metal/Aluminum, Cardboard/Paper, Glass, Plastic, Styrofoam, Organic Waste, Textiles/Clothing)

## Quick Start
**Just run all cells in sequence** - the notebook is designed to work smoothly from top to bottom.

### Key Features:
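- **Dual-head architecture** with a shared ViT backbone
- **Hierarchical loss function**: a weighted sum of focal losses on the fine-grained and super-class predictions
- **Live training progress** with tqdm progress bars
- **Automatic model saving**: the checkpoint with the best combined (fine + super) validation accuracy is saved to `best_hierarchical_model.pth`
- **Result analysis** with loss and accuracy curves (`confusion_matrix` and `classification_report` are imported for further analysis)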

### Requirements:
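Update `dataset_path` in the dataset-setup cell (the second code cell) to point to your dataset location. The dataset must be an `ImageFolder`-style directory with one sub-folder per class, and the folder names must match the class names listed in `SUPER_CLASSES`.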

In [None]:
# Import all required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision.transforms import v2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random
import timm
from tqdm.auto import tqdm
import gc
import time
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from timm import create_model
from sklearn.metrics import confusion_matrix, classification_report

# Reached only if every import above succeeded (a failing import raises first).
_IMPORT_OK_MESSAGE = "✅ All libraries imported successfully!"
print(_IMPORT_OK_MESSAGE)

In [None]:
# Dataset setup
dataset_path = '/kaggle/input/household-waste-30-classes/images/images'

# Fail fast with an actionable message: ImageFolder cannot load from a missing
# directory, so printing a warning and continuing only defers the error.
if not os.path.isdir(dataset_path):
    raise FileNotFoundError(
        f"Dataset path not found: {dataset_path}. "
        "Please update the dataset_path variable with the correct path to your dataset"
    )
print(f"Dataset path found: {dataset_path}")

# One sub-folder per class; class_to_idx maps folder name -> integer label.
full_dataset = ImageFolder(dataset_path)
class_map_dict = full_dataset.class_to_idx
class_names = list(class_map_dict.keys())
print(f'Number of classes: {len(class_names)}')

In [None]:
# Define hierarchical classification system
# Super-class id -> fine-class folder names (must match the dataset's class folders).
SUPER_CLASSES = {
    0: ['aerosol_cans', 'aluminum_food_cans', 'aluminum_soda_cans', 'steel_food_cans'],
    1: ['cardboard_boxes', 'cardboard_packaging', 'magazines', 'newspaper', 'office_paper', 'paper_cups'],
    2: ['glass_beverage_bottles', 'glass_cosmetic_containers', 'glass_food_jars'],
    3: ['disposable_plastic_cutlery', 'plastic_cup_lids', 'plastic_detergent_bottles', 'plastic_food_containers', 'plastic_shopping_bags', 'plastic_soda_bottles', 'plastic_straws', 'plastic_trash_bags', 'plastic_water_bottles'],
    4: ['styrofoam_cups', 'styrofoam_food_containers'],
    5: ['coffee_grounds', 'eggshells', 'food_waste', 'tea_bags'],
    6: ['clothing', 'shoes']
}

# Human-readable name for each super-class id.
SUPER_CLASS_NAMES = {
    0: 'Metal_Aluminum', 1: 'Cardboard_Paper', 2: 'Glass_Containers',
    3: 'Plastic_Items', 4: 'Styrofoam_Products', 5: 'Organic_Waste', 6: 'Textiles_Clothing'
}

# Build mappings
# Fine-class folder name -> super-class id.
CLASS_TO_SUPER = {
    class_name: super_id
    for super_id, class_list in SUPER_CLASSES.items()
    for class_name in class_list
}
# A name listed twice would be silently reassigned to the last super class seen.
if len(CLASS_TO_SUPER) != sum(len(class_list) for class_list in SUPER_CLASSES.values()):
    raise ValueError("A fine class appears more than once in SUPER_CLASSES")

# Every dataset class needs a super class; otherwise its super-class target is undefined.
_unmapped_classes = sorted(set(class_map_dict) - set(CLASS_TO_SUPER))
if _unmapped_classes:
    raise ValueError(f"Dataset classes missing from SUPER_CLASSES: {_unmapped_classes}")
# Names in SUPER_CLASSES with no matching dataset folder are harmless but likely typos.
_absent_classes = sorted(set(CLASS_TO_SUPER) - set(class_map_dict))
if _absent_classes:
    print(f"Warning: SUPER_CLASSES entries not found in the dataset: {_absent_classes}")

# Fine-class index (ImageFolder label) -> super-class id.
CLASS_IDX_TO_SUPER_IDX = {
    class_idx: CLASS_TO_SUPER[class_name]
    for class_name, class_idx in class_map_dict.items()
}

print("Hierarchical Classification Setup Complete:")
print(f"• Fine-grained classes: {len(class_names)}")
print(f"• Super classes: {len(SUPER_CLASS_NAMES)}")

In [None]:
# Data transforms and loaders
# ImageNet channel statistics used for normalization.
# NOTE(review): timm's default weights for vit_small_patch16_224 may expect
# mean/std of 0.5 instead — confirm via timm.data.resolve_model_data_config.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _normalized_tensor_steps():
    """Return fresh steps converting an image to a normalized float32 tensor in [0, 1] before normalization."""
    return [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]


# Training: random crop (70-100% of the area), horizontal flip and colour jitter, then tensor conversion.
train_transform = v2.Compose(
    [
        v2.RandomResizedCrop(224, scale=(0.7, 1.0)),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    ]
    + _normalized_tensor_steps()
)

# Evaluation: deterministic resize to 224x224, then the same tensor conversion.
test_transform = v2.Compose([v2.Resize((224, 224))] + _normalized_tensor_steps())

class SimpleTransformWrapper(Dataset):
    """Apply a transform to the images of an ``(image, target)`` dataset on the fly.

    Lets subsets of one underlying dataset use different transforms (e.g. train
    augmentation vs. deterministic test preprocessing).

    Args:
        dataset: indexable dataset whose items are ``(image, target)`` pairs.
        transform: callable applied to each image; ``None`` leaves images unchanged.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        image, target = self.dataset[index]
        # Compare with None explicitly: a transform object that happens to be
        # falsy (e.g. defines __len__ returning 0) must still be applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, target

    def __len__(self):
        return len(self.dataset)

# Split dataset
# Seed the global RNG so the 80/20 split drawn by random_split is reproducible.
torch.manual_seed(42)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Apply transforms: both subsets index the same ImageFolder, so each gets its own
# wrapper rather than setting a single transform on full_dataset.
train_dataset_transformed = SimpleTransformWrapper(train_dataset, transform=train_transform)
test_dataset_transformed = SimpleTransformWrapper(test_dataset, transform=test_transform)

# Create data loaders
batch_size = 64
# pin_memory speeds up host->GPU copies when CUDA is present; persistent_workers
# keeps the worker processes alive between epochs instead of re-spawning them.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,
)
train_loader = DataLoader(train_dataset_transformed, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset_transformed, shuffle=False, **_loader_kwargs)

print(f"Dataset split: {len(train_dataset_transformed)} train, {len(test_dataset_transformed)} test samples")

In [None]:
# Hierarchical Vision Transformer Model
class HierarchicalViT(nn.Module):
    """Shared timm backbone feeding two MLP heads (fine-grained and super-class).

    Args:
        num_fine_classes: output size of the fine-grained head.
        num_super_classes: output size of the super-class head.
        model_name: timm model identifier, created with pretrained weights.
        dropout: dropout probability after the shared LayerNorm and inside both heads.

    ``forward`` returns ``(fine_logits, super_logits)``.
    """

    def __init__(self, num_fine_classes=30, num_super_classes=7, model_name='vit_small_patch16_224', dropout=0.1):
        super().__init__()
        # num_classes=0 removes timm's classification head; the backbone output is
        # then a feature vector of size `num_features` (see timm create_model docs).
        self.backbone = create_model(model_name, pretrained=True, num_classes=0)
        embed_dim = self.backbone.num_features

        # Attribute names and layer order below define the state_dict keys of
        # saved checkpoints, so they must stay stable.
        self.feature_processor = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Dropout(dropout),
        )
        self.fine_classifier = self._mlp_head(embed_dim, embed_dim // 2, num_fine_classes, dropout)
        self.super_classifier = self._mlp_head(embed_dim, embed_dim // 4, num_super_classes, dropout)

    @staticmethod
    def _mlp_head(in_dim, hidden_dim, out_dim, dropout):
        """Build a Linear -> GELU -> Dropout -> Linear head."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        shared = self.feature_processor(self.backbone(x))
        return self.fine_classifier(shared), self.super_classifier(shared)

# Setup device and model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Derive head sizes from the dataset and hierarchy instead of hard-coding 30/7,
# so the model stays consistent if classes are added or removed.
waste_classifier = HierarchicalViT(
    num_fine_classes=len(class_names),
    num_super_classes=len(SUPER_CLASSES),
    dropout=0.3,
).to(device)
print(f"Model created on {device}")

In [None]:
# Loss functions
class FocalLoss(nn.Module):
    """Multi-class focal loss, averaged over the batch.

    Per sample: ``alpha * (1 - p) ** gamma * CE``, where ``CE`` is the unreduced
    cross-entropy and ``p = exp(-CE)`` is the predicted probability of the true
    class. With ``gamma=0`` and ``alpha=1`` this is plain mean cross-entropy.
    """

    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce = F.cross_entropy(inputs, targets, reduction='none')
        # Down-weight confidently correct samples (p close to 1).
        modulation = (1 - ce.neg().exp()).pow(self.gamma)
        return (self.alpha * modulation * ce).mean()

class HierarchicalLoss(nn.Module):
    """Weighted sum of focal losses on the fine-grained and super-class heads.

    ``forward`` returns ``(total, fine_loss, super_loss)`` where
    ``total = fine_weight * fine_loss + super_weight * super_loss``.
    Targets are cast to ``long`` before being passed to the focal losses.
    """

    def __init__(self, fine_weight=0.7, super_weight=0.3, focal_gamma=2.0):
        super().__init__()
        self.fine_weight = fine_weight
        self.super_weight = super_weight
        # Separate instances, one per head, both with the same gamma.
        self.fine_loss = FocalLoss(gamma=focal_gamma)
        self.super_loss = FocalLoss(gamma=focal_gamma)

    def forward(self, fine_logits, super_logits, fine_targets, super_targets):
        per_head = (
            self.fine_loss(fine_logits, fine_targets.long()),
            self.super_loss(super_logits, super_targets.long()),
        )
        combined = self.fine_weight * per_head[0] + self.super_weight * per_head[1]
        return (combined, *per_head)

# Combined loss: 0.7 * fine focal loss + 0.3 * super focal loss, gamma = 2.
_loss_config = {"fine_weight": 0.7, "super_weight": 0.3, "focal_gamma": 2.0}
criterion_hierarchical = HierarchicalLoss(**_loss_config)
print("Loss functions initialized")

In [None]:
# Training functions

# Fine-class index -> super-class id as a tensor, so a whole batch of super-class
# targets is one on-device index operation (no per-sample Python loop and no
# GPU->CPU copy per batch).
_unmapped_fine = sorted(
    name for name, idx in class_map_dict.items() if idx not in CLASS_IDX_TO_SUPER_IDX
)
if _unmapped_fine:
    # Fail fast: an unmapped class would have no valid super-class target.
    raise ValueError(f"No super-class mapping for fine classes: {_unmapped_fine}")
# assumes class indices are contiguous 0..N-1 (ImageFolder's class_to_idx convention)
FINE_TO_SUPER = torch.tensor(
    [CLASS_IDX_TO_SUPER_IDX[idx] for idx in range(len(class_map_dict))],
    dtype=torch.long,
)


def _batch_to_device(images, fine_labels, lookup, device):
    """Move a batch to `device` and derive its super-class labels from `lookup`."""
    images = images.to(device)
    fine_labels = fine_labels.to(device)
    return images, fine_labels, lookup[fine_labels]


def _count_correct(logits, labels):
    """Return how many rows' arg-max prediction equals the label."""
    return (logits.argmax(dim=1) == labels).sum().item()


def _epoch_summary(loss_sum, fine_correct, super_correct, num_samples):
    """Return (mean per-sample loss, fine accuracy %, super accuracy %)."""
    if num_samples == 0:
        raise ValueError("DataLoader produced no samples")
    return (
        loss_sum / num_samples,
        100 * fine_correct / num_samples,
        100 * super_correct / num_samples,
    )


def train_epoch_hierarchical(model, train_loader, criterion, optimizer, device):
    """Run one training epoch over `train_loader`.

    Args:
        model: module returning ``(fine_logits, super_logits)``.
        train_loader: iterable of ``(images, fine_labels)`` batches.
        criterion: callable ``(fine_logits, super_logits, fine_targets, super_targets)``
            returning ``(loss, fine_loss, super_loss)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        ``(mean per-sample loss, fine accuracy %, super accuracy %)``.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    lookup = FINE_TO_SUPER.to(device)
    loss_sum, fine_correct, super_correct, total_samples = 0.0, 0, 0, 0

    pbar = tqdm(train_loader, desc="Training", leave=False)
    for images, fine_labels in pbar:
        images, fine_labels, super_labels = _batch_to_device(images, fine_labels, lookup, device)

        fine_logits, super_logits = model(images)
        loss, _, _ = criterion(fine_logits, super_logits, fine_labels, super_labels)

        optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_size = images.size(0)
        batch_loss = loss.item()
        fine_correct += _count_correct(fine_logits, fine_labels)
        super_correct += _count_correct(super_logits, super_labels)
        total_samples += batch_size
        # The loss is a batch mean; weight by batch size so a short final batch
        # does not count as much as a full one.
        loss_sum += batch_loss * batch_size

        pbar.set_postfix({'Loss': f'{batch_loss:.4f}', 'Fine': f'{100*fine_correct/total_samples:.1f}%'})

    return _epoch_summary(loss_sum, fine_correct, super_correct, total_samples)


def eval_epoch_hierarchical(model, test_loader, criterion, device):
    """Evaluate the model on `test_loader` without gradient tracking.

    Arguments match ``train_epoch_hierarchical`` (minus the optimizer).

    Returns:
        ``(mean per-sample loss, fine accuracy %, super accuracy %)``.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    lookup = FINE_TO_SUPER.to(device)
    loss_sum, fine_correct, super_correct, total_samples = 0.0, 0, 0, 0

    with torch.no_grad():
        for images, fine_labels in test_loader:
            images, fine_labels, super_labels = _batch_to_device(images, fine_labels, lookup, device)

            fine_logits, super_logits = model(images)
            loss, _, _ = criterion(fine_logits, super_logits, fine_labels, super_labels)

            batch_size = images.size(0)
            fine_correct += _count_correct(fine_logits, fine_labels)
            super_correct += _count_correct(super_logits, super_labels)
            total_samples += batch_size
            loss_sum += loss.item() * batch_size

    return _epoch_summary(loss_sum, fine_correct, super_correct, total_samples)


print("Training functions defined")

In [None]:
# Main training loop
EPOCHS = 15
# NOTE(review): lr=1e-3 applies to the pretrained backbone as well as the new
# heads; a lower backbone LR is common when fine-tuning ViTs — confirm intended.
optimizer = torch.optim.AdamW(waste_classifier.parameters(), lr=1e-3, weight_decay=1e-4)
# Cosine decay toward eta_min over EPOCHS scheduler steps (stepped once per epoch below).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-5)

# Per-epoch history, read by the visualization cell.
train_losses, val_losses = [], []
train_fine_accs, val_fine_accs = [], []
train_super_accs, val_super_accs = [], []
# Best (val_fine_acc + val_super_acc) so far; halved when printed as a mean.
best_val_acc = 0

print(f"Starting training for {EPOCHS} epochs...")
print("-" * 80)

for epoch in range(1, EPOCHS + 1):
    # Capture the LR used for THIS epoch; scheduler.step() below advances it to
    # next epoch's value, so reading it afterwards would report the wrong LR.
    lr = optimizer.param_groups[0]['lr']

    train_loss, train_fine_acc, train_super_acc = train_epoch_hierarchical(
        waste_classifier, train_loader, criterion_hierarchical, optimizer, device
    )

    # NOTE(review): test_loader doubles as the validation set for checkpoint
    # selection, so the best reported accuracy is optimistically biased.
    val_loss, val_fine_acc, val_super_acc = eval_epoch_hierarchical(
        waste_classifier, test_loader, criterion_hierarchical, device
    )

    scheduler.step()

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_fine_accs.append(train_fine_acc)
    val_fine_accs.append(val_fine_acc)
    train_super_accs.append(train_super_acc)
    val_super_accs.append(val_super_acc)

    print(f"Epoch {epoch:2d} | LR: {lr:.6f} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"         | Fine: {train_fine_acc:.1f}%/{val_fine_acc:.1f}% | Super: {train_super_acc:.1f}%/{val_super_acc:.1f}%")

    # Checkpoint on the equally weighted sum of both accuracies.
    current_acc = val_fine_acc + val_super_acc
    if current_acc > best_val_acc:
        best_val_acc = current_acc
        torch.save(waste_classifier.state_dict(), 'best_hierarchical_model.pth')
        print(f"         *** New best model saved! Combined accuracy: {current_acc/2:.1f}%")

print(f"\nTraining completed! Best combined accuracy: {best_val_acc/2:.1f}%")

In [None]:
# Visualization
def _plot_train_val(ax, x, train_vals, val_vals, train_fmt, val_fmt, title, ylabel):
    """Draw one training-vs-validation curve panel on `ax`."""
    ax.plot(x, train_vals, train_fmt, label='Training', linewidth=2)
    ax.plot(x, val_vals, val_fmt, label='Validation', linewidth=2)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)


if train_losses:
    epochs = range(1, len(train_losses) + 1)

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Hierarchical Training Results', fontsize=16, fontweight='bold')

    _plot_train_val(ax1, epochs, train_losses, val_losses, 'b-', 'r-', 'Loss', 'Loss')
    _plot_train_val(
        ax2, epochs, train_fine_accs, val_fine_accs, 'g-', 'orange',
        f'Fine-grained Accuracy ({len(class_names)} classes)', 'Accuracy (%)',
    )
    _plot_train_val(
        ax3, epochs, train_super_accs, val_super_accs, 'purple', 'brown',
        f'Super-class Accuracy ({len(SUPER_CLASSES)} categories)', 'Accuracy (%)',
    )

    # Last-epoch accuracies, training vs. validation, side by side.
    categories = ['Fine-grained', 'Super-class']
    final_train = [train_fine_accs[-1], train_super_accs[-1]]
    final_val = [val_fine_accs[-1], val_super_accs[-1]]

    x = np.arange(len(categories))
    width = 0.35

    ax4.bar(x - width/2, final_train, width, label='Training', alpha=0.8)
    ax4.bar(x + width/2, final_val, width, label='Validation', alpha=0.8)
    ax4.set_title('Final Accuracy Comparison')
    ax4.set_ylabel('Accuracy (%)')
    ax4.set_ylim(0, 100)
    ax4.set_xticks(x)
    ax4.set_xticklabels(categories)
    ax4.legend()

    plt.tight_layout()
    plt.show()

    # NOTE: these maxima may come from different epochs than each other and than
    # the saved checkpoint (which is selected on the combined accuracy).
    print(f"Best Fine-grained: {max(val_fine_accs):.1f}%")
    print(f"Best Super-class: {max(val_super_accs):.1f}%")
else:
    print("No training data to visualize")

In [None]:
# Final cleanup: collect Python garbage and, on GPU, release cached allocator blocks.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Only claim a saved model when the checkpoint file exists (it is written only
# when validation accuracy improves). A file left by an earlier run also counts.
_checkpoint_path = 'best_hierarchical_model.pth'
if os.path.exists(_checkpoint_path):
    print("Training complete! Model saved as 'best_hierarchical_model.pth'")
else:
    print(f"Training complete, but no checkpoint was found at '{_checkpoint_path}'")