In [None]:
class EnhancedVehicleDamageDataset(Dataset):
    """Multi-task, multi-label image dataset for vehicle damage assessment.

    Rows come from a pipe-separated CSV. For every sample the column at
    position 0 is joined onto ``img_dir`` to locate the image, and the
    columns at positions 1, 2 and 3 are parsed into the damage, part and
    suggestion targets respectively. Each of those cells may hold one class
    id or a comma-separated list of ids.

    Relies on module-level names defined outside this class:
    ``label_to_cls_danos``, ``label_to_cls_piezas``,
    ``label_to_cls_sugerencia`` and ``MIN_SAMPLES_PER_CLASS``.

    Args:
        csv_path: Path of the ``|``-separated annotation file.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, csv_path, img_dir, transform=None):
        self.data = pd.read_csv(csv_path, sep='|')
        self.img_dir = img_dir
        self.transform = transform
        self._filter_and_group_classes()
        # Only the *lengths* of these lists are used below (they size the
        # multi-hot target vectors).
        self.damage_classes = sorted(label_to_cls_danos.keys())
        self.part_classes = sorted(label_to_cls_piezas.keys())
        self.suggestion_classes = sorted(label_to_cls_sugerencia.keys())

    def _filter_and_group_classes(self):
        """Bucket unknown part ids as 99, then drop rows of rare classes."""

        def group_parts(part_id):
            # Any part id missing from the label map goes to the 99 bucket.
            if part_id not in label_to_cls_piezas:
                return 99
            return part_id

        self.data['Piezas del Vehículo'] = (
            self.data['Piezas del Vehículo'].apply(group_parts)
        )
        # The tasks are filtered one after another, so each pass counts only
        # the rows that survived the previous one.
        for task in ['Tipos de Daño', 'Piezas del Vehículo', 'Sugerencia']:
            class_counts = self.data[task].value_counts()
            valid_classes = class_counts[class_counts >= MIN_SAMPLES_PER_CLASS].index
            self.data = self.data[self.data[task].isin(valid_classes)]

    @staticmethod
    def _encode_multi_hot(raw_value, num_classes):
        """Turn a cell holding class ids into a float32 multi-hot vector.

        Args:
            raw_value: A single id or a comma-separated string of ids. Ids
                written as whole-number floats (``"2.0"``) are accepted,
                because pandas yields those when a column was read as float.
            num_classes: Length of the returned vector.

        Returns:
            A 1-D float32 tensor in which position ``id - 1`` is 1.0 for
            every parsed id that lies in ``1..num_classes``.

        Raises:
            ValueError: If a token is not a whole number.
        """
        target = torch.zeros(num_classes, dtype=torch.float32)
        # str.split already returns a one-element list when no comma exists.
        for token in str(raw_value).split(','):
            token = token.strip()
            number = float(token)
            if not number.is_integer():
                raise ValueError(f'invalid class id {token!r}')
            position = int(number) - 1
            # NOTE(review): ids outside 1..num_classes are silently ignored.
            # That includes the 99 bucket from _filter_and_group_classes
            # unless the part label map has at least 99 entries - confirm
            # this is intended.
            if 0 <= position < num_classes:
                target[position] = 1.0
        return target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for the row at position ``idx``.

        ``labels`` is a dict with keys ``'damage'``, ``'part'`` and
        ``'suggestion'``, each a float32 multi-hot tensor.
        """
        img_path = os.path.join(self.img_dir, self.data.iloc[idx, 0])
        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the conversion is done.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')

        labels = {
            'damage': self._encode_multi_hot(
                self.data.iloc[idx, 1], len(self.damage_classes)
            ),
            'part': self._encode_multi_hot(
                self.data.iloc[idx, 2], len(self.part_classes)
            ),
            'suggestion': self._encode_multi_hot(
                self.data.iloc[idx, 3], len(self.suggestion_classes)
            ),
        }

        if self.transform:
            image = self.transform(image)
        return image, labels


In [None]:
class EnhancedDamageClassifier(nn.Module):
    """ResNet-50 based network with three classification heads.

    The backbone's final ``fc`` layer is replaced by an identity, so the
    backbone emits its pooled feature vector. A small gating network scales
    that vector, a shared MLP reduces it to 512 units, and three independent
    heads produce raw logits for damage type, vehicle part and suggestion.

    Args:
        num_damage_types: Number of outputs of the damage head.
        num_parts: Number of outputs of the part head.
        num_suggestions: Number of outputs of the suggestion head.
    """

    def __init__(self, num_damage_types, num_parts, num_suggestions):
        super().__init__()

        self.backbone = models.resnet50(pretrained=True)
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # The gate ends in Linear(512, 1) + Sigmoid, i.e. one value in (0, 1)
        # per sample; forward() broadcasts it over the feature vector.
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

        self.shared = nn.Sequential(
            *self._dense_block(feature_dim, 1024, 0.5),
            *self._dense_block(1024, 512, 0.3),
        )

        self.damage_head = self._make_head(512, 256, 0.2, num_damage_types)
        self.part_head = self._make_head(512, 256, 0.2, num_parts)
        self.suggestion_head = self._make_head(512, 128, 0.1, num_suggestions)

    @staticmethod
    def _dense_block(in_dim, out_dim, drop_prob):
        """Return the layers Linear -> BatchNorm1d -> ReLU -> Dropout."""
        return [
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Dropout(drop_prob),
        ]

    @classmethod
    def _make_head(cls, in_dim, hidden_dim, drop_prob, num_outputs):
        """Build one task head: a dense block followed by a logit layer."""
        return nn.Sequential(
            *cls._dense_block(in_dim, hidden_dim, drop_prob),
            nn.Linear(hidden_dim, num_outputs),
        )

    def forward(self, x):
        """Return a dict of raw logits keyed by 'damage', 'part', 'suggestion'."""
        features = self.backbone(x)
        gate = self.attention(features)
        trunk = self.shared(gate * features)

        logits = {}
        logits['damage'] = self.damage_head(trunk)
        logits['part'] = self.part_head(trunk)
        logits['suggestion'] = self.suggestion_head(trunk)
        return logits


In [None]:
def train_enhanced_model(model, train_loader, val_loader, num_epochs):
    """Train the multi-task model and return it.

    Every batch is scored with ``BCEWithLogitsLoss`` on each of the three
    outputs; the per-task losses are mixed as 0.4 (damage), 0.4 (part) and
    0.2 (suggestion). After each epoch the validation loss and metrics are
    computed and printed, and training stops early when the
    ``EarlyStopper`` instance returns a truthy value for the validation loss.

    Uses names defined outside this function: ``LR``, ``WEIGHT_DECAY``,
    ``DEVICE``, ``EarlyStopper``, ``evaluate_loss`` and
    ``evaluate_enhanced_model``.

    Args:
        model: Network returning a dict with 'damage', 'part', 'suggestion'.
        train_loader: Iterable of ``(inputs, labels_dict)`` training batches.
        val_loader: Loader handed to the validation helpers.
        num_epochs: Maximum number of epochs to run.

    Returns:
        The same ``model`` object after training.
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    early_stopper = EarlyStopper(patience=5)

    for epoch_number in range(1, num_epochs + 1):
        # The validation helpers run at the end of every epoch, so training
        # mode is re-enabled at the start of each one.
        model.train()
        epoch_loss_sum = 0.0

        for batch_inputs, batch_labels in train_loader:
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = {
                name: tensor.to(DEVICE) for name, tensor in batch_labels.items()
            }

            optimizer.zero_grad()
            predictions = model(batch_inputs)

            task_loss = {
                name: criterion(predictions[name], batch_labels[name])
                for name in ('damage', 'part', 'suggestion')
            }
            batch_loss = (
                0.4 * task_loss['damage']
                + 0.4 * task_loss['part']
                + 0.2 * task_loss['suggestion']
            )

            batch_loss.backward()
            optimizer.step()
            epoch_loss_sum += batch_loss.item()

        val_loss = evaluate_loss(model, val_loader, criterion)
        val_metrics = evaluate_enhanced_model(model, val_loader)

        mean_train_loss = epoch_loss_sum / len(train_loader)
        print(f'Epoch {epoch_number}/{num_epochs}')
        print(f'Train Loss: {mean_train_loss:.4f} | Val Loss: {val_loss:.4f}')
        for task, scores in val_metrics.items():
            print(f'{task} - Accuracy: {scores["accuracy"]:.4f} | F1: {scores["f1_macro"]:.4f}')

        if early_stopper(val_loss):
            print("Early stopping triggered!")
            break

    return model


In [None]:
def evaluate_enhanced_model(model, loader, threshold=0.5):
    """Compute accuracy and macro-F1 for each task over ``loader``.

    The loader is traversed once and every batch gets a single forward pass
    whose outputs feed all three tasks. (Running one full pass per task
    would triple the inference cost without changing the results.)

    Side effects: puts ``model`` in eval mode and prints, per task, the
    macro-F1 and a scikit-learn classification report.

    Args:
        model: Network returning a dict of logits keyed by 'damage', 'part'
            and 'suggestion'.
        loader: Iterable of ``(inputs, labels_dict)`` batches.
        threshold: A label is predicted when its sigmoid probability is
            strictly greater than this value.

    Returns:
        ``{task: {'accuracy': ..., 'f1_macro': ...}}``. The targets are
        multi-hot rows, for which scikit-learn's ``accuracy_score`` reports
        exact-match (subset) accuracy.
    """
    tasks = ('damage', 'part', 'suggestion')
    preds_by_task = {task: [] for task in tasks}
    labels_by_task = {task: [] for task in tasks}

    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            for task in tasks:
                probs = torch.sigmoid(outputs[task])
                preds = (probs > threshold).int()
                # extend() appends one row (one sample) at a time.
                preds_by_task[task].extend(preds.cpu().numpy())
                labels_by_task[task].extend(labels[task].cpu().numpy())

    metrics = {}
    for task in tasks:
        task_labels = labels_by_task[task]
        task_preds = preds_by_task[task]
        metrics[task] = {
            'accuracy': accuracy_score(task_labels, task_preds),
            'f1_macro': f1_score(task_labels, task_preds, average='macro', zero_division=0)
        }
        print(f"\n{task} - F1 Macro: {metrics[task]['f1_macro']:.4f}")
        print(classification_report(task_labels, task_preds, zero_division=0))
    return metrics


In [None]:
def predict_damage(image_path, threshold=0.5):
    """Predict damage, part and suggestion labels for one image file.

    Uses the module-level ``model``, ``data_transforms['val']``, ``DEVICE``
    and the three ``label_to_cls_*`` dictionaries.

    The model is switched to eval mode for the forward pass and returned to
    its previous mode afterwards. In training mode dropout would randomise
    the output, and ``BatchNorm1d`` cannot compute batch statistics from the
    single-image batch used here.

    Args:
        image_path: Path of the image to classify.
        threshold: A label is reported when its sigmoid probability is
            strictly greater than this value.

    Returns:
        Dict with keys 'damage', 'part' and 'suggestion', each a list of the
        label-map values for the selected classes. Output position ``i`` is
        looked up under key ``i + 1``.
    """
    # convert() returns a loaded copy, so the file can be closed right away.
    with Image.open(image_path) as handle:
        image = handle.convert('RGB')
    image = data_transforms['val'](image).unsqueeze(0).to(DEVICE)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(image)
    finally:
        # Leave the shared model in whatever mode the caller had it in.
        model.train(was_training)

    label_maps = {
        'damage': label_to_cls_danos,
        'part': label_to_cls_piezas,
        'suggestion': label_to_cls_sugerencia,
    }
    predictions = {}
    for task, id_to_name in label_maps.items():
        # [0] picks the only sample of the batch built above.
        selected = (torch.sigmoid(outputs[task])[0] > threshold).tolist()
        predictions[task] = [
            id_to_name[position + 1]
            for position, is_selected in enumerate(selected)
            if is_selected
        ]
    return predictions
