Para mejorar el rendimiento del modelo (actualmente 41% daños, 31% partes, 67% sugerencias), podemos implementar varias mejoras:

1. Optimización de hiperparámetros:
    - Ajustar learning rate, batch size y número de épocas
    - Probar diferentes optimizadores (Adam, SGD con momentum)

2. Arquitectura del modelo:
    - Usar un backbone más potente (ResNet50 en lugar de ResNet18)
    - Añadir más capas a las cabezas de clasificación
    - Implementar atención espacial

3. Datos:
    - Aumentar el dataset con más imágenes
    - Mejorar el balanceo de clases
    - Aplicar más transformaciones de data augmentation

4. Técnicas de entrenamiento:
    - Usar learning rate scheduling
    - Implementar early stopping
    - Añadir regularización (dropout, weight decay)

5. Evaluación:
    - Analizar matriz de confusión para entender errores
    - Visualizar casos donde falla el modelo

In [2]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models, transforms
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import os
from collections import Counter

# Configuration 2. Optimización del entrenamiento:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
NUM_EPOCHS = 100
MIN_SAMPLES_PER_CLASS = 20
LR = 1e-4
WEIGHT_DECAY = 1e-4

## Diccionario para Piezas del Vehículo
label_to_cls_piezas = {
    1: "Antiniebla delantero derecho",
    2: "Antiniebla delantero izquierdo",
    3: "Capó",
    4: "Cerradura capo",
    5: "Cerradura maletero",
    6: "Cerradura puerta",
    7: "Espejo lateral derecho",
    8: "Espejo lateral izquierdo",
    9: "Faros derecho",
    10: "Faros izquierdo",
    11: "Guardabarros delantero derecho",
    12: "Guardabarros delantero izquierdo",
    13: "Guardabarros trasero derecho",
    14: "Guardabarros trasero izquierdo",
    15: "Luz indicadora delantera derecha",
    16: "Luz indicadora delantera izquierda",
    17: "Luz indicadora trasera derecha",
    18: "Luz indicadora trasera izquierda",
    19: "Luz trasera derecho",
    20: "Luz trasera izquierdo",
    21: "Maletero",
    22: "Manija derecha",
    23: "Manija izquierda",
    24: "Marco de la ventana",
    25: "Marco de las puertas",
    26: "Moldura capó",
    27: "Moldura puerta delantera derecha",
    28: "Moldura puerta delantera izquierda",
    29: "Moldura puerta trasera derecha",
    30: "Moldura puerta trasera izquierda",
    31: "Parabrisas delantero",
    32: "Parabrisas trasero",
    33: "Parachoques delantero",
    34: "Parachoques trasero",
    35: "Puerta delantera derecha",
    36: "Puerta delantera izquierda",
    37: "Puerta trasera derecha",
    38: "Puerta trasera izquierda",
    39: "Rejilla, parrilla",
    40: "Rueda",
    41: "Tapa de combustible",
    42: "Tapa de rueda",
    43: "Techo",
    44: "Techo corredizo",
    45: "Ventana delantera derecha",
    46: "Ventana delantera izquierda",
    47: "Ventana trasera derecha",
    48: "Ventana trasera izquierda",
    49: "Ventanilla delantera derecha",
    50: "Ventanilla delantera izquierda",
    51: "Ventanilla trasera derecha",
    52: "Ventanilla trasera izquierda"
}

# Clases que seran eliminadas del set de datos
# 2: "Corrosión",
# Diccionario para Tipos de Daño (completo)
label_to_cls_danos = {
    1: "Abolladura",
    2: "Deformación",
    3: "Desprendimiento",
    4: "Fractura",
    5: "Rayón",
    6: "Rotura"
}

# Diccionario para Sugerencia (completo)
label_to_cls_sugerencia = {
    1: "Reparar",
    2: "Reemplazar"
}

# Cargar los datasets
train_expanded = pd.read_csv('data/fotos_siniestros/datasets/train.csv', sep='|')
val_expanded = pd.read_csv('data/fotos_siniestros/datasets/val.csv', sep='|')
test_expanded = pd.read_csv('data/fotos_siniestros/datasets/test.csv', sep='|')

# Helper functions
# 3. Sobremuestreo (Oversampling) Mejorado
# Modifica tu WeightedRandomSampler para priorizar clases minoritarias:
def get_sample_weights(dataset, task='Tipos de Daño'):  # Cambiado de 'damage' a 'Tipos de Daño'
    """Calcula pesos para oversampling basado en frecuencia inversa"""
    class_counts = Counter(dataset.data[task])  # Ahora usa el nombre correcto de la columna
    total_samples = len(dataset)
    num_classes = len(class_counts)
    
    weights = [total_samples / (num_classes * class_counts[cls]) for cls in dataset.data[task]]
    return torch.tensor(weights, dtype=torch.float32)

In [3]:
# Dataset Class
class EnhancedVehicleDamageDataset(Dataset):
    def __init__(self, csv_path, img_dir, transform=None):
        self.data = pd.read_csv(csv_path, sep='|')
        self.img_dir = img_dir
        self.transform = transform
        self._filter_and_group_classes()

    def _filter_and_group_classes(self):
        """Filter rare classes and group similar vehicle parts"""
        def group_parts(part_id):
            # Group all parts not in our main dictionary
            if part_id not in label_to_cls_piezas:
                return 99
            return part_id
            
        self.data['Piezas del Vehículo'] = self.data['Piezas del Vehículo'].apply(group_parts)
        
        # Filter classes with insufficient samples
        for task in ['Tipos de Daño', 'Piezas del Vehículo', 'Sugerencia']:
            class_counts = self.data[task].value_counts()
            valid_classes = class_counts[class_counts >= MIN_SAMPLES_PER_CLASS].index
            self.data = self.data[self.data[task].isin(valid_classes)]
            
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.data.iloc[idx, 0])
        image = Image.open(img_path).convert('RGB')
        
        labels = {
            'damage': torch.tensor(self.data.iloc[idx, 1] - 1, dtype=torch.long),
            'part': torch.tensor(self.data.iloc[idx, 2] - 1, dtype=torch.long),
            'suggestion': torch.tensor(self.data.iloc[idx, 3] - 1, dtype=torch.long)
        }
        
        if self.transform:
            image = self.transform(image)
            
        return image, labels

### Cambios Clave Resumidos:
    - Pesos de clase en las funciones de pérdida.
    - Focal Loss para enfoque en clases difíciles.
    - Sampler mejorado para oversampling de clases minoritarias.
    - Early Stopping para evitar overfitting.
    - Métricas F1-score (mejor que accuracy para desbalanceo).

### Estas modificaciones son mínimas pero impactantes. Más mejoras, explorar técnicas avanzadas como SMOTE o arquitecturas jerárquicas.

In [4]:
# Model Architecture
class EnhancedDamageClassifier(nn.Module):
    def __init__(self, num_damage_types, num_parts, num_suggestions):
        super().__init__()
        # Backbone ResNet50
        self.backbone = models.resnet50(pretrained=True)
        in_features = self.backbone.fc.in_features  # 2048 features
        self.backbone.fc = nn.Identity()

        # Capa de atención
        self.attention = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Softmax(dim=1)
        )
        
        # Capa compartida con dimensiones correctas
        self.shared = nn.Sequential(
            nn.Linear(in_features, 1024),  # Ahora recibe 2048 y sale 1024
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),  # Reducción gradual
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        
        # Cabezales de clasificación con dimensiones correctas
        self.damage_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_damage_types)
        )

        self.part_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_parts)
        )
        
        self.suggestion_head = nn.Sequential(
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_suggestions)
        )
        
    def forward(self, x):
        features = self.backbone(x)  # 2048 features
        attended = self.attention(features) * features  # Atención aplicada
        shared = self.shared(attended)  # 512 features
        
        return {
            'damage': self.damage_head(shared),
            'part': self.part_head(shared),
            'suggestion': self.suggestion_head(shared)
        }

# Training Function
# 1. Balanceo con Pesos de Clase (Solución más simple)
# Explicación:
# Los pesos penalizan más los errores en clases minoritarias (ej: clase 4 en daños tiene peso 4.77).
# Para Piezas del Vehículo (45 clases), no usamos pesos por simplicidad (pero podrías agregarlos si tienes memoria suficiente).
# 2. Focal Loss (Alternativa a CrossEntropy)
# Para manejar mejor las clases extremadamente desbalanceadas (como las piezas del vehículo), añade esta clase y modifica el criterio:
# Explicación:
# gamma=2 enfoca más en muestras difíciles (útil para clases raras).
# Combínalo con los pesos (alpha) para mejores resultados.
# Modifica la función de pérdida para incluir los pesos sugeridos:
# 3. Sobremuestreo (Oversampling) Mejorado
# Modifica tu WeightedRandomSampler para priorizar clases minoritarias:
# Explicación:
# Ahora el sampler prioriza las clases de daño minoritarias (como "Desprendimiento" o "Fractura").
# Puedes cambiar task a 'part' o 'suggestion' según necesites.
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt)**self.gamma * ce_loss
        return focal_loss.mean()  # Asegura que sea un escalar
    
class EarlyStopper:
    def __init__(self, patience=3, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_loss = float('inf')

    def __call__(self, val_loss):
        if val_loss < self.min_loss - self.min_delta:
            self.min_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
    
def evaluate_loss(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(DEVICE)
            labels = {k: v.to(DEVICE) for k, v in labels.items()}
            outputs = model(inputs)
            
            loss = 0.4 * criterion['damage'](outputs['damage'], labels['damage']) + \
                   0.4 * criterion['part'](outputs['part'], labels['part']) + \
                   0.2 * criterion['suggestion'](outputs['suggestion'], labels['suggestion'])
            total_loss += loss.item()
    
    return total_loss / len(loader)
    
def train_enhanced_model(model, train_loader, val_loader, num_epochs):
    damage_weights = torch.tensor([0.51, 2.65, 4.01, 4.77, 0.75, 0.52], dtype=torch.float32, device=DEVICE)
    suggestion_weights = torch.tensor([0.72, 1.62], dtype=torch.float32, device=DEVICE)
    
    criterion = {
        'damage': FocalLoss(alpha=damage_weights, gamma=2.0),
        'part': FocalLoss(gamma=2.0),
        'suggestion': FocalLoss(alpha=suggestion_weights, gamma=2.0)
    }
    
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    early_stopper = EarlyStopper(patience=5)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        
        for inputs, labels in train_loader:
            inputs = inputs.to(DEVICE)
            labels = {k: v.to(DEVICE) for k, v in labels.items()}
            
            optimizer.zero_grad()
            outputs = model(inputs)
            
            loss = 0.4 * criterion['damage'](outputs['damage'], labels['damage']) + \
                   0.4 * criterion['part'](outputs['part'], labels['part']) + \
                   0.2 * criterion['suggestion'](outputs['suggestion'], labels['suggestion'])
            
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Validation
        val_loss = evaluate_loss(model, val_loader, criterion)
        val_metrics = evaluate_enhanced_model(model, val_loader)
        
        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {running_loss/len(train_loader):.4f} | Val Loss: {val_loss:.4f}')
        for task in val_metrics:
            print(f'{task} - Accuracy: {val_metrics[task]["accuracy"]:.4f} | F1: {val_metrics[task]["f1_macro"]:.4f}')
        
        if early_stopper(val_loss):
            print("Early stopping triggered!")
            break
    
    return model

# Evaluation Function
# 5. Métricas Mejoradas (F1-score)
# Modifica evaluate_enhanced_model para mostrar métricas por clase:

def evaluate_enhanced_model(model, loader):
    model.eval()
    metrics = {}
    
    with torch.no_grad():
        for task in ['damage', 'part', 'suggestion']:
            all_preds = []
            all_labels = []
            
            for inputs, labels in loader:
                inputs = inputs.to(DEVICE)
                outputs = model(inputs)
                _, preds = torch.max(outputs[task], 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels[task].cpu().numpy())
            
            # Añade F1-score (macro para clases desbalanceadas)
            metrics[task] = {
                'accuracy': accuracy_score(all_labels, all_preds),
                'f1_macro': f1_score(all_labels, all_preds, average='macro', zero_division=0)
            }
            
            print(f"\n{task} - F1 Macro: {metrics[task]['f1_macro']:.4f}")
            print(classification_report(all_labels, all_preds, zero_division=0))
    
    return metrics

In [None]:
# Last Execution 11:10:29 PM
# Execution Time 594m 57.8s
# Overhead Time 12m 6.1s
# Render Times
# VS Code Builtin Notebook Output Renderer 5ms

if __name__ == '__main__':
    # Data transforms 3. Data augmentation mejorada:
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomRotation(10),
            transforms.RandomAffine(0, shear=10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    
    # Create datasets and data loaders
    train_dataset = EnhancedVehicleDamageDataset(
        'data/fotos_siniestros/datasets/train.csv',
        'data/fotos_siniestros/',
        data_transforms['train']
    )
    
    val_dataset = EnhancedVehicleDamageDataset(
        'data/fotos_siniestros/datasets/val.csv',
        'data/fotos_siniestros/',
        data_transforms['val']
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=WeightedRandomSampler(
            weights=get_sample_weights(train_dataset, task='Tipos de Daño'),  # Usa 'Tipos de Daño' en lugar de 'damage'
            num_samples=len(train_dataset),
            replacement=True
        ),
        num_workers=4
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4
    )

    model = EnhancedDamageClassifier(
        num_damage_types=len(label_to_cls_danos),
        num_parts=len(label_to_cls_piezas),  # Now includes all parts
        num_suggestions=len(label_to_cls_sugerencia)
    ).to(DEVICE)
    
    trained_model = train_enhanced_model(model, train_loader, val_loader, NUM_EPOCHS)
    
    # Save model
    torch.save(trained_model.state_dict(), 'enhanced_damage_classifier_by_blackboxAI.pth')




damage - F1 Macro: 0.1352
              precision    recall  f1-score   support

           0       0.38      0.05      0.08       303
           1       0.14      0.23      0.18       103
           2       0.15      0.69      0.24        80
           3       0.04      0.43      0.08        30
           4       0.43      0.14      0.21       114
           5       0.27      0.01      0.02       303

    accuracy                           0.13       933
   macro avg       0.24      0.26      0.14       933
weighted avg       0.29      0.13      0.10       933


part - F1 Macro: 0.0416
              precision    recall  f1-score   support

           1       0.00      0.00      0.00         0
           2       0.06      0.04      0.05        93
           8       0.00      0.00      0.00        47
           9       0.00      0.00      0.00        63
          10       0.00      0.00      0.00        49
          11       0.10      0.12      0.11       104
          12       0.00   

The model has been trained with the following final metrics:

   Epoch 1/100

   Loss: 2.5068
   damage Accuracy: 0.1840
   part Accuracy: 0.0983
   suggestion Accuracy: 0.5801

Classification Report for damage:
              precision    recall  f1-score   support

           0       0.71      0.05      0.10       222
           2       0.09      0.14      0.11        37
           3       0.14      0.39      0.21        76
           4       0.00      0.00      0.00        26
           5       0.33      0.80      0.47       128
           6       0.52      0.16      0.25       223

    accuracy                           0.26       712
   macro avg       0.30      0.26      0.19       712
weighted avg       0.46      0.26      0.22       712



Classification Report for part:
              precision    recall  f1-score   support

           5       0.11      0.11      0.11        55
          12       0.03      0.06      0.04        36
          13       0.00      0.00      0.00        50
          14       0.11      0.18      0.13        57
          15       0.26      0.22      0.24        77
          16       0.00      0.00      0.00         0
          17       0.00      0.00      0.00        28
          23       0.00      0.00      0.00        20
          24       0.00      0.00      0.00         0
          29       0.00      0.00      0.00        20
          37       0.50      0.48      0.49        33
          39       0.28      0.03      0.05       168
          40       0.27      0.60      0.37        73
          43       0.00      0.00      0.00        30
          45       0.00      0.00      0.00        32
          47       0.00      0.00      0.00         0
          48       0.17      0.73      0.28        33

    accuracy                           0.17       712
   macro avg       0.10      0.14      0.10       712
weighted avg       0.17      0.17      0.13       712



Classification Report for suggestion:
              precision    recall  f1-score   support

           0       0.67      0.84      0.75       445
           1       0.55      0.31      0.40       267

    accuracy                           0.64       712
   macro avg       0.61      0.58      0.57       712
weighted avg       0.62      0.64      0.62       712


   Epoch 82/100

   Loss: 0.8887
   damage Accuracy: 0.4115
   part Accuracy: 0.3062
   suggestion Accuracy: 0.6404

Classification Report for damage:
              precision    recall  f1-score   support

           0       0.42      0.50      0.45       222
           2       0.08      0.05      0.07        37
           3       0.21      0.32      0.25        76
           4       0.17      0.08      0.11        26
           5       0.55      0.45      0.50       128
           6       0.49      0.43      0.46       223

    accuracy                           0.41       712
   macro avg       0.32      0.30      0.31       712
weighted avg       0.42      0.41      0.41       712

Classification Report for part:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           5       0.50      0.07      0.13        55
          12       0.16      0.44      0.23        36
          13       0.12      0.02      0.03        50
          14       0.29      0.04      0.06        57
          15       0.33      0.03      0.05        77
          16       0.00      0.00      0.00         0
          17       0.00      0.00      0.00        28
          23       0.40      0.10      0.16        20
          24       0.00      0.00      0.00         0
          29       0.12      0.05      0.07        20
          37       0.89      0.94      0.91        33
          39       0.33      0.58      0.42       168
          40       0.46      0.63      0.53        73
          42       0.00      0.00      0.00         0
          43       0.40      0.07      0.11        30
          44       0.00      0.00      0.00         0
          45       0.33      0.06      0.11        32
          48       0.25      0.45      0.32        33

    accuracy                           0.31       712
   macro avg       0.24      0.18      0.17       712
weighted avg       0.34      0.31      0.26       712

Classification Report for suggestion:
              precision    recall  f1-score   support

           0       0.71      0.72      0.72       445
           1       0.52      0.51      0.52       267

    accuracy                           0.64       712
   macro avg       0.62      0.62      0.62       712
weighted avg       0.64      0.64      0.64       712

In [6]:
if __name__ == '__main__':
    # Data transforms 3. Data augmentation mejorada:
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomRotation(10),
            transforms.RandomAffine(0, shear=10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    
    # Create datasets and data loaders
    train_dataset = EnhancedVehicleDamageDataset(
        'data/fotos_siniestros/datasets/train.csv',
        'data/fotos_siniestros/',
        data_transforms['train']
    )
    
    val_dataset = EnhancedVehicleDamageDataset(
        'data/fotos_siniestros/datasets/val.csv',
        'data/fotos_siniestros/',
        data_transforms['val']
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=WeightedRandomSampler(
            weights=get_sample_weights(train_dataset),
            num_samples=len(train_dataset),
            replacement=True
        ),
        num_workers=4
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4
    )
    
    # Initialize and train model
    # model = SimplifiedDamageClassifier(
    #     num_damage_types=len(label_to_cls_danos),
    #     num_parts=len(label_to_cls_piezas),  # Now includes all parts
    #     num_suggestions=len(label_to_cls_sugerencia)
    # ).to(DEVICE)

    model = EnhancedDamageClassifier(
        num_damage_types=len(label_to_cls_danos),
        num_parts=len(label_to_cls_piezas),  # Now includes all parts
        num_suggestions=len(label_to_cls_sugerencia)
    ).to(DEVICE)
# Load the trained model
model.load_state_dict(torch.load('enhanced_damage_classifier_by_blackboxAI.pth'))
model.eval()



EnhancedDamageClassifier(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequ

In [7]:
# Prediction function
def predict_damage(image_path):
    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    image = data_transforms['val'](image).unsqueeze(0).to(DEVICE)
    
    # Get predictions
    with torch.no_grad():
        outputs = model(image)
    
    # Convert predictions to labels
    damage_pred = torch.argmax(outputs['damage'], 1).item()
    part_pred = torch.argmax(outputs['part'], 1).item()
    suggestion_pred = torch.argmax(outputs['suggestion'], 1).item()
    
    return {
        'damage': label_to_cls_danos[damage_pred + 1],
        'part': label_to_cls_piezas[part_pred + 1],
        'suggestion': label_to_cls_sugerencia[suggestion_pred + 1]
    }

In [8]:
# Test on sample images
test_images = [
    'predecir/golpe_01.jpg',
    'predecir/golpe_02.jpg',
    'predecir/golpe_03.jpg',
    'predecir/Siniestro_01.jpg',
    'predecir/Siniestro_02.jpg',
    'predecir/Siniestro_03.jpg',
    'predecir/Siniestro_04.jpg',
    'predecir/rayon_01.jpg',
    'predecir/rayon_02.jpg',
    'predecir/mica_rota_01.jpg',
    'predecir/mica_rota-rayon_01.jpg'
]

for img_path in test_images:
    print(f"\nPredictions for {img_path}:")
    predictions = predict_damage(img_path)
    for task, pred in predictions.items():
        print(f"{task.capitalize()}: {pred}")


Predictions for predecir/golpe_01.jpg:
Damage: Deformación
Part: Puerta delantera izquierda
Suggestion: Reparar

Predictions for predecir/golpe_02.jpg:
Damage: Rayón
Part: Parachoques trasero
Suggestion: Reemplazar

Predictions for predecir/golpe_03.jpg:
Damage: Rayón
Part: Parachoques delantero
Suggestion: Reparar

Predictions for predecir/Siniestro_01.jpg:
Damage: Desprendimiento
Part: Parachoques delantero
Suggestion: Reemplazar

Predictions for predecir/Siniestro_02.jpg:
Damage: Rayón
Part: Parachoques delantero
Suggestion: Reemplazar

Predictions for predecir/Siniestro_03.jpg:
Damage: Desprendimiento
Part: Parachoques trasero
Suggestion: Reemplazar

Predictions for predecir/Siniestro_04.jpg:
Damage: Rayón
Part: Parachoques trasero
Suggestion: Reparar

Predictions for predecir/rayon_01.jpg:
Damage: Rayón
Part: Puerta delantera izquierda
Suggestion: Reparar

Predictions for predecir/rayon_02.jpg:
Damage: Rayón
Part: Parachoques trasero
Suggestion: Reparar

Predictions for predecir/

---