In [1]:
# Vehicle Damage Detection System - Improved Version V7
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models, transforms
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import os
from collections import Counter

# Helper function for sample weights
def get_sample_weights(dataset, device=None, label_column='Tipos de Daño'):
    """Calculate per-sample weights for class-balanced sampling.

    Each sample is weighted by the inverse of its class frequency, then the
    weights are normalised so that they sum to 1.

    Args:
        dataset: Object exposing ``dataset.data[label_column]`` as an iterable
            of hashable class labels, one per sample.
        device: Device the returned tensor is moved to. When ``None`` the
            module-level ``DEVICE`` is used if it has been defined; otherwise
            the weights stay on the CPU.
        label_column: Name of the column in ``dataset.data`` holding the class
            labels to balance on.

    Returns:
        1-D ``torch.float`` tensor with one weight per sample, in the same
        order as ``dataset.data[label_column]``.
    """
    if device is None:
        # DEVICE lives in the configuration cell; fall back to the CPU rather
        # than raising NameError when that cell has not been executed yet.
        try:
            device = DEVICE
        except NameError:
            device = torch.device("cpu")

    # Materialise once so counting and weighting see exactly the same sequence.
    labels = list(dataset.data[label_column])
    class_counts = Counter(labels)

    weights = 1.0 / torch.tensor([class_counts[label] for label in labels], dtype=torch.float)
    return (weights / weights.sum()).to(device)

In [2]:
# Configuration
# Select the CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
NUM_EPOCHS = 50
# Minimum number of rows a class must have to survive the dataset's rare-class filter.
MIN_SAMPLES_PER_CLASS = 20  # Increased from 15 to 20

# Label mappings (unchanged from previous version)
# NOTE(review): as written, each `{...}` is a set literal whose only element is
# Ellipsis, not a dict — presumably a placeholder for mappings elided from this
# cell; restore the real mappings before relying on them.
label_to_cls_piezas = {...}
label_to_cls_danos = {...}
label_to_cls_sugerencia = {...}

## Enhanced Dataset Class with Improved Class Filtering

In [3]:
class EnhancedVehicleDamageDataset(Dataset):
    """Multi-task vehicle damage dataset backed by a pipe-separated CSV file.

    On construction the CSV is loaded, rare vehicle-part ids are merged into a
    single catch-all id, and rows belonging to under-represented classes are
    dropped.

    Each item is ``(image, labels)`` where ``labels`` is a dict with the keys
    ``'damage'``, ``'part'`` and ``'suggestion'``. The image file name is read
    from CSV column 0 and the three labels from columns 1, 2 and 3, each
    shifted by -1.

    NOTE(review): labels are read by column position while filtering uses the
    column names in ``TASK_COLUMNS`` — confirm the CSV column order matches.
    NOTE(review): grouped parts get id 99, so their ``'part'`` label is 98 and
    the label ids are not contiguous — confirm the model's output size and any
    label mapping account for that.
    """

    # Columns checked by the rare-class filter.
    TASK_COLUMNS = ('Tipos de Daño', 'Piezas del Vehículo', 'Sugerencia')
    PART_COLUMN = 'Piezas del Vehículo'
    # Vehicle-part ids that are merged into GROUPED_PART_ID.
    RARE_PART_IDS = (4, 5, 7, 8, 9, 19, 20, 21, 22, 23, 24, 25, 27, 28, 29, 30,
                     31, 32, 33, 34, 35, 36, 37, 51, 52, 54, 59, 60, 61, 62)
    GROUPED_PART_ID = 99

    def __init__(self, csv_path, img_dir, transform=None, min_samples_per_class=None):
        """Load the CSV and apply class grouping and filtering.

        Args:
            csv_path: Path to the ``|``-separated CSV file.
            img_dir: Directory that the image file names in column 0 are
                relative to.
            transform: Optional callable applied to each loaded RGB image.
            min_samples_per_class: Minimum number of rows a class needs in
                every task column to be kept. Defaults to the module-level
                ``MIN_SAMPLES_PER_CLASS``.
        """
        self.data = pd.read_csv(csv_path, sep='|')
        self.img_dir = img_dir
        self.transform = transform
        self.min_samples_per_class = (
            MIN_SAMPLES_PER_CLASS if min_samples_per_class is None else min_samples_per_class
        )

        # Filter rare classes and group some vehicle parts
        self._filter_and_group_classes()

    def _filter_and_group_classes(self):
        """Group rare vehicle parts, then drop rows of under-represented classes."""
        # Group rare vehicle parts into one broader category (vectorised
        # replacement for a per-row membership test).
        parts = self.data[self.PART_COLUMN]
        self.data[self.PART_COLUMN] = parts.where(
            ~parts.isin(self.RARE_PART_IDS), self.GROUPED_PART_ID
        )

        # Dropping rows for one task can push a class of another task below
        # the threshold, so repeat the sweep until no more rows are removed.
        while True:
            rows_before = len(self.data)
            for task in self.TASK_COLUMNS:
                class_counts = self.data[task].value_counts()
                valid_classes = class_counts[class_counts >= self.min_samples_per_class].index
                self.data = self.data[self.data[task].isin(valid_classes)]
            if len(self.data) == rows_before:
                break

        # Keep the index aligned with positional access after filtering.
        self.data = self.data.reset_index(drop=True)

    def __len__(self):
        """Return the number of rows remaining after filtering."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for the row at position ``idx``."""
        img_path = os.path.join(self.img_dir, self.data.iloc[idx, 0])
        # Context manager closes the underlying file once the RGB copy exists.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Stored ids are shifted by -1 before being used as label tensors.
        labels = {
            'damage': torch.tensor(self.data.iloc[idx, 1] - 1, dtype=torch.long),
            'part': torch.tensor(self.data.iloc[idx, 2] - 1, dtype=torch.long),
            'suggestion': torch.tensor(self.data.iloc[idx, 3] - 1, dtype=torch.long)
        }

        if self.transform:
            image = self.transform(image)

        return image, labels