# 🚲 DelftBikes Ultimate Defect Detection - Complete Notebook

**Complete all-in-one notebook for Kaggle dual T4 training**

## Features:
- Multi-GPU training (2x T4)
- WandB logging with ALL metrics per epoch
- Advanced augmentation (10+ techniques)
- Class-balanced sampling (3x for damaged)
- Cosine annealing with warmup
- Mixed precision (2x faster)
- Layer-wise learning rates
- Focal loss for class imbalance



## Setup:
1. Set GPU to **T4 x2** in Kaggle settings
2. Add dataset: `/kaggle/input/delftbikes/`
3. Run all cells!

---
## 1️⃣ Check GPU & Install Packages

In [None]:
import torch
import sys

# Report the PyTorch build and every CUDA device visible to this session.
banner = "=" * 80
print(banner)
print("üî• GPU SETUP")
print(banner)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
gpu_count = torch.cuda.device_count()
print(f"Number of GPUs: {gpu_count}")

for gpu_index in range(gpu_count):
    print(f"\nGPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")
    total_gb = torch.cuda.get_device_properties(gpu_index).total_memory / 1e9
    print(f"  Memory: {total_gb:.2f} GB")

# The notebook is written for two GPUs, so warn early when fewer are attached.
if gpu_count >= 2:
    print("\n‚úÖ Dual GPU setup confirmed!")
else:
    print("\n‚ö†Ô∏è WARNING: Only 1 GPU detected!")
    print("   Please set notebook to T4 x2 in Kaggle settings!")

In [None]:
# Install required packages
# albumentations supplies the augmentation pipeline and wandb the experiment
# logging used by later cells; -q keeps pip's output short.
!pip install -q albumentations wandb
# Headless OpenCV build: provides cv2 for image loading in the dataset cell.
!pip install -q opencv-python-headless

# NOTE: printed unconditionally -- pip's exit status is not checked here.
print("‚úÖ Packages installed!")

---
## 2️⃣ WandB Setup

In [None]:
import os
import wandb

# The API key is read from the WANDB_API_KEY environment variable.
# NOTE(review): Kaggle notebook secrets are normally fetched with
# kaggle_secrets.UserSecretsClient rather than exported as environment
# variables -- confirm WANDB_API_KEY is actually set before this cell runs.
wandb_api_key = os.environ.get("WANDB_API_KEY")

if wandb_api_key:
    # Only claim success when wandb reports that credentials are configured.
    if wandb.login(key=wandb_api_key):
        print("‚úÖ WandB login successful via Kaggle secret!")
    else:
        print("WARNING: WandB login with WANDB_API_KEY did not succeed.")
else:
    # Without a key wandb falls back to cached credentials or a prompt; say so
    # instead of reporting a login "via Kaggle secret" that never happened.
    print("WARNING: WANDB_API_KEY is not set; using wandb's default login flow.")
    wandb.login()


---
## 3️⃣ Define Model Architecture

In [None]:
# Complete Model Definition
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign

class StateClassificationHead(nn.Module):
    """Small MLP that scores part states (intact/damaged/occluded/absent).

    Takes a batch of per-box feature vectors of width ``in_channels`` and
    returns unnormalised logits of width ``num_states``. Two hidden layers
    (512 and 256 units) are each followed by ReLU and dropout (p=0.3).
    """

    def __init__(self, in_channels, num_states=4):
        super().__init__()
        wide_units, narrow_units = 512, 256
        # Attribute names are part of the checkpoint (state_dict) layout.
        self.fc1 = nn.Linear(in_channels, wide_units)
        self.fc2 = nn.Linear(wide_units, narrow_units)
        self.fc3 = nn.Linear(narrow_units, num_states)
        self.dropout = nn.Dropout(0.3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return state logits of shape ``(x.shape[0], num_states)``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(self.relu(layer(hidden)))
        # No activation on the output layer: callers apply softmax / CE loss.
        return self.fc3(hidden)


class BikeDefectDetector(nn.Module):
    """Multi-task model for bike defect detection.

    A torchvision ``FasterRCNN`` built on a ResNet backbone (final stage only,
    a single feature map named ``'0'``) detects bike parts. The RoI-pooled box
    features of the detector are additionally fed to a
    ``StateClassificationHead`` that predicts the state of every part.

    Args:
        num_classes: Number of part classes *excluding* background; the
            detector is built with ``num_classes + 1`` outputs.
        num_states: Number of part states predicted by the state head.
        backbone_name: ``'resnet50'`` or ``'resnet101'``.
        pretrained_backbone: Load ``'IMAGENET1K_V1'`` weights when true.
        trainable_backbone_layers: How many entries of
            ``['layer4', 'layer3', 'layer2', 'layer1', 'conv1']`` (in that
            order) stay trainable; values of 5 or more train everything.
        min_size: Minimum image side passed to ``FasterRCNN``.
        max_size: Maximum image side passed to ``FasterRCNN``.

    Raises:
        ValueError: If ``backbone_name`` is not one of the supported names.
    """

    def __init__(
        self,
        num_classes=22,
        num_states=4,
        backbone_name='resnet101',
        pretrained_backbone=True,
        trainable_backbone_layers=4,
        min_size=896,
        max_size=1344,
    ):
        super().__init__()

        # +1 for the background class that Faster R-CNN reserves at index 0.
        self.num_classes = num_classes + 1
        self.num_states = num_states

        # Create backbone
        weights = 'IMAGENET1K_V1' if pretrained_backbone else None
        if backbone_name == 'resnet50':
            backbone = torchvision.models.resnet50(weights=weights)
        elif backbone_name == 'resnet101':
            backbone = torchvision.models.resnet101(weights=weights)
        else:
            raise ValueError(f"Unknown backbone: {backbone_name}")

        # Freeze early layers: anything whose parameter name does not contain
        # one of the selected stage names stops receiving gradients.
        if trainable_backbone_layers < 5:
            layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_backbone_layers]
            for name, parameter in backbone.named_parameters():
                if not any(layer in name for layer in layers_to_train):
                    parameter.requires_grad_(False)

        # Keep only the convolutional trunk (drop avgpool / fc).
        backbone = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        )
        backbone.out_channels = 2048

        # One feature map -> exactly one tuple of sizes and one tuple of aspect
        # ratios. (The previous five aspect-ratio tuples did not match the
        # single size tuple.)
        anchor_generator = AnchorGenerator(
            sizes=((32, 64, 128, 256, 512),),
            aspect_ratios=((0.5, 1.0, 2.0),)
        )

        # ROI pooler over the single feature map '0'.
        roi_pooler = MultiScaleRoIAlign(
            featmap_names=['0'],
            output_size=7,
            sampling_ratio=2
        )

        # Build Faster R-CNN
        self.detector = FasterRCNN(
            backbone=backbone,
            num_classes=self.num_classes,
            rpn_anchor_generator=anchor_generator,
            box_roi_pool=roi_pooler,
            min_size=min_size,
            max_size=max_size,
            image_mean=[0.485, 0.456, 0.406],
            image_std=[0.229, 0.224, 0.225]
        )

        # Replace box predictor
        in_features = self.detector.roi_heads.box_predictor.cls_score.in_features
        self.detector.roi_heads.box_predictor = FastRCNNPredictor(in_features, self.num_classes)

        # State classification head. It is also attached to roi_heads, which
        # keeps the existing checkpoint key layout unchanged.
        self.state_head = StateClassificationHead(in_channels=in_features, num_states=num_states)
        self.detector.roi_heads.state_head = self.state_head

    def _box_state_logits(self, image_list, boxes):
        """Return state logits for ``boxes`` (one tensor per image).

        ``boxes`` must be expressed in the coordinate frame of the transformed
        images in ``image_list``. Logits are concatenated over images in order.
        """
        features = self.detector.backbone(image_list.tensors)
        # The nn.Sequential backbone returns a bare tensor, but the RoI pooler
        # looks features up by name, so wrap it under the configured key '0'.
        if isinstance(features, torch.Tensor):
            features = {'0': features}
        roi_heads = self.detector.roi_heads
        pooled = roi_heads.box_roi_pool(features, boxes, image_list.image_sizes)
        return self.state_head(roi_heads.box_head(pooled))

    def _state_loss(self, images, targets):
        """Cross-entropy of the state head on the ground-truth boxes.

        Returns ``None`` when any target lacks a ``'states'`` entry or when the
        batch contains no annotated boxes, so that no loss term is added.
        """
        image_list, resized_targets = self.detector.transform(images, targets)
        if not all('states' in target for target in resized_targets):
            return None
        gt_states = torch.cat([target['states'] for target in resized_targets], dim=0)
        if gt_states.numel() == 0:
            return None
        # Pool features at the (resized) ground-truth boxes so every logit row
        # corresponds to exactly one annotated state. Pairing RPN proposals
        # with states by position, as before, trained the head on noise.
        gt_boxes = [target['boxes'] for target in resized_targets]
        state_logits = self._box_state_logits(image_list, gt_boxes)
        # NOTE(review): ignore_index=5 is kept from the original code; no state
        # id 5 is produced by the visible dataset code -- confirm it is needed.
        return F.cross_entropy(state_logits, gt_states, ignore_index=5)

    def _attach_state_predictions(self, images, detections):
        """Add ``'state_scores'`` and ``'states'`` to each detection dict in place."""
        if not detections:
            return
        image_list, _ = self.detector.transform(images, None)
        scaled_boxes = []
        for image, detection, (new_h, new_w) in zip(images, detections, image_list.image_sizes):
            # Detections are reported in original-image coordinates, while the
            # features live in the resized frame: rescale before pooling.
            orig_h, orig_w = image.shape[-2:]
            boxes = detection['boxes']
            ratios = boxes.new_tensor(
                [new_w / orig_w, new_h / orig_h, new_w / orig_w, new_h / orig_h]
            )
            scaled_boxes.append(boxes * ratios)

        counts = [len(boxes) for boxes in scaled_boxes]
        if sum(counts) > 0:
            # One backbone pass for the whole batch; each box list is pooled
            # from its own image's features.
            state_probs = F.softmax(self._box_state_logits(image_list, scaled_boxes), dim=-1)
        else:
            state_probs = torch.zeros(
                (0, self.num_states), device=image_list.tensors.device
            )

        for detection, image_probs in zip(detections, state_probs.split(counts, dim=0)):
            detection['state_scores'] = image_probs
            detection['states'] = torch.argmax(image_probs, dim=-1)

    def forward(self, images, targets=None):
        """Run the detector and the state head.

        Args:
            images: List of image tensors, as accepted by torchvision's
                ``FasterRCNN``.
            targets: List of target dicts (required in training mode). A
                ``'states'`` entry per target enables the state loss.

        Returns:
            Training mode: the detector's loss dict, plus ``'loss_state'``
            (weighted by 0.5) when state labels are available.
            Eval mode: the detector's list of detection dicts, each extended
            with ``'state_scores'`` (softmax probabilities) and ``'states'``
            (argmax state id).

        Raises:
            ValueError: If called in training mode without ``targets``.
        """
        if self.training:
            if targets is None:
                raise ValueError("Targets required during training")

            # Standard Faster R-CNN losses.
            loss_dict = self.detector(images, targets)

            # NOTE: the state loss runs the backbone a second time; sharing the
            # features with the detector pass would need a custom forward.
            state_loss = self._state_loss(images, targets)
            if state_loss is not None:
                loss_dict['loss_state'] = state_loss * 0.5

            return loss_dict

        # Inference
        detections = self.detector(images)
        self._attach_state_predictions(images, detections)
        return detections


def build_model(config):
    """Instantiate ``BikeDefectDetector`` from a config mapping.

    Missing keys fall back to the defaults listed below. Note that the config
    key ``'backbone'`` feeds the constructor argument ``backbone_name``.
    """
    # (config key, constructor argument, fallback value)
    arg_specs = (
        ('num_classes', 'num_classes', 22),
        ('num_states', 'num_states', 4),
        ('backbone', 'backbone_name', 'resnet101'),
        ('pretrained_backbone', 'pretrained_backbone', True),
        ('trainable_backbone_layers', 'trainable_backbone_layers', 4),
        ('min_size', 'min_size', 896),
        ('max_size', 'max_size', 1344),
    )
    kwargs = {
        arg_name: config.get(config_key, fallback)
        for config_key, arg_name, fallback in arg_specs
    }
    return BikeDefectDetector(**kwargs)

# Cell-completion marker: shows that the model definitions above were executed.
print("‚úÖ Model architecture defined!")

---
## 4️⃣ Define Dataset & Augmentation

In [None]:
# Complete Dataset Definition
from torch.utils.data import Dataset
import json
import cv2
import numpy as np
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2

class DelftBikesDataset(Dataset):
    """DelftBikes dataset for bike part detection and defect classification.

    The annotation file is a JSON object keyed by image file name; every entry
    holds a ``'parts'`` mapping from part name to a dict with an optional
    ``'absolute_bounding_box'`` (``left``/``top``/``width``/``height``) and an
    ``'object_state_class'``.

    Each item is ``(image, target)`` where ``target`` contains ``boxes``
    (``[x1, y1, x2, y2]``), ``labels``, ``states``, ``image_id``, ``area``,
    ``iscrowd`` and ``has_defect``.

    Labels are **1-based** (``PART_TO_IDX[name] + 1``): the detector reserves
    label 0 for background, so a 0-based label would turn the first part into
    background. Map a label back to its name with ``PARTS[label - 1]``.

    Args:
        annotation_path: Path to the JSON annotation file.
        image_dir: Directory containing the images named in the annotations.
        transform: Optional callable invoked as
            ``transform(image=..., bboxes=..., labels=..., states=...)`` and
            returning a dict with the same keys.
        filter_invalid_boxes: Drop boxes that are degenerate or smaller than
            ``min_box_size`` in either dimension.
        min_box_size: Minimum box width/height kept when filtering.
    """

    PARTS = [
        'back_hand_break', 'back_handle', 'back_light', 'back_mudguard',
        'back_pedal', 'back_reflector', 'back_wheel', 'bell', 'chain',
        'dress_guard', 'dynamo', 'front_handbreak', 'front_handle',
        'front_light', 'front_mudguard', 'front_pedal', 'front_wheel',
        'gear_case', 'kickstand', 'lock', 'saddle', 'steer'
    ]

    # 0-based position of each part in PARTS (labels add 1 for background).
    PART_TO_IDX = {part: idx for idx, part in enumerate(PARTS)}

    def __init__(self, annotation_path, image_dir, transform=None, filter_invalid_boxes=True, min_box_size=5):
        self.annotation_path = Path(annotation_path)
        self.image_dir = Path(image_dir)
        self.transform = transform
        self.filter_invalid_boxes = filter_invalid_boxes
        self.min_box_size = min_box_size

        with open(self.annotation_path, 'r') as f:
            self.annotations = json.load(f)

        self.image_ids = list(self.annotations.keys())

    def __len__(self):
        return len(self.image_ids)

    @staticmethod
    def _ensure_three_channels(image):
        """Return ``image`` with exactly three channels.

        Handles HWC numpy arrays and CHW torch tensors: grayscale is
        replicated, two channels get the last one duplicated, and extra
        channels are cut off. Anything else is returned unchanged.
        """
        if isinstance(image, torch.Tensor):
            if image.dim() == 2:
                return image.unsqueeze(0).repeat(3, 1, 1)
            if image.dim() == 3:
                channels = image.shape[0]
                if channels == 1:
                    return image.repeat(3, 1, 1)
                if channels == 2:
                    return torch.cat([image, image[-1:, :, :]], dim=0)
                if channels > 3:
                    return image[:3, :, :]
            return image
        if isinstance(image, np.ndarray):
            if image.ndim == 2:
                return np.stack([image] * 3, axis=-1)
            if image.ndim == 3:
                channels = image.shape[2]
                if channels == 1:
                    return np.repeat(image, 3, axis=2)
                if channels == 2:
                    return np.concatenate([image, image[:, :, -1:]], axis=2)
                if channels > 3:
                    return image[:, :, :3]
        return image

    def _parse_annotations(self, img_data):
        """Convert one annotation entry into ``(boxes, labels, states)`` arrays.

        Parts without a bounding box are skipped. ``boxes`` always has shape
        ``(N, 4)``, including ``(0, 4)`` when nothing is left.

        Raises:
            KeyError: If a part name is not listed in ``PARTS``.
        """
        boxes = []
        labels = []
        states = []

        for part_name, part_data in img_data['parts'].items():
            if 'absolute_bounding_box' not in part_data:
                continue  # part has no localisation; nothing to detect
            bbox = part_data['absolute_bounding_box']
            x1 = bbox['left']
            y1 = bbox['top']
            x2 = x1 + bbox['width']
            y2 = y1 + bbox['height']

            if self.filter_invalid_boxes:
                if bbox['width'] < self.min_box_size or bbox['height'] < self.min_box_size:
                    continue
                if x1 >= x2 or y1 >= y2:
                    continue

            boxes.append([x1, y1, x2, y2])
            # +1: label 0 is the detector's background class.
            labels.append(self.PART_TO_IDX[part_name] + 1)
            states.append(part_data['object_state_class'])

        return (
            np.array(boxes, dtype=np.float32).reshape(-1, 4),
            np.array(labels, dtype=np.int64),
            np.array(states, dtype=np.int64),
        )

    def __getitem__(self, idx):
        """Load image ``idx`` and build its detection target.

        Returns:
            ``(image, target)``. After a transform ending in a tensor
            conversion, ``image`` is a float32 CHW tensor scaled to [0, 1];
            without such a transform it stays an RGB HWC numpy array.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        img_name = self.image_ids[idx]
        img_data = self.annotations[img_name]

        # Load image (OpenCV reads BGR; the model expects RGB).
        img_path = self.image_dir / img_name
        image = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self._ensure_three_channels(image)

        boxes, labels, states = self._parse_annotations(img_data)

        # Apply transforms
        if self.transform is not None:
            transformed = self.transform(
                image=image,
                bboxes=boxes,
                labels=labels,
                states=states
            )
            image = transformed['image']
            # reshape keeps the (N, 4) layout even when every box was dropped.
            boxes = np.array(transformed['bboxes'], dtype=np.float32).reshape(-1, 4)
            labels = np.array(transformed['labels'], dtype=np.int64)
            states = np.array(transformed['states'], dtype=np.int64)

        # A transform may have changed the channel layout; normalise it once.
        image = self._ensure_three_channels(image)

        # Convert to tensors
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        states = torch.as_tensor(states, dtype=torch.int64)

        # Box areas; yields an empty tensor when there are no boxes.
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        # Determine if bike has defect (any state other than 0 among 1..3).
        has_defect = torch.any((states == 1) | (states == 2) | (states == 3)).item() if len(states) > 0 else False

        # FasterRCNN expects float images in [0, 1]; the tensor conversion in
        # the transform pipeline may leave uint8 values in [0, 255].
        if isinstance(image, torch.Tensor):
            if image.dtype == torch.uint8:
                image = image.float() / 255.0
            elif image.dtype != torch.float32:
                image = image.float()
            if image.max() > 1.0:
                image = image / 255.0

        target = {
            'boxes': boxes,
            'labels': labels,
            'states': states,
            'image_id': torch.tensor([idx]),
            'area': area,
            'iscrowd': torch.zeros((len(boxes),), dtype=torch.int64),
            'has_defect': torch.tensor([1 if has_defect else 0], dtype=torch.int64)
        }

        return image, target
    

def collate_fn(batch):
    """Split a batch of ``(image, target)`` pairs into two parallel lists.

    Images are kept as a list rather than stacked, so they may differ in size.
    """
    images = [image for image, _ in batch]
    targets = [target for _, target in batch]
    return images, targets


# Custom transform to ensure 3 channels (multiprocessing-safe)
class Ensure3Channels(A.BasicTransform):
    """Albumentations transform that coerces an HWC image to three channels.

    Grayscale input is replicated, a two-channel image gets its last channel
    duplicated, and channels beyond the third are dropped.
    """

    def __init__(self, always_apply=True, p=1.0):
        super().__init__(always_apply=always_apply, p=p)

    @property
    def targets(self):
        # Only the image target is registered with a handler here.
        return {"image": self.apply}

    def apply(self, image, **params):
        """Return ``image`` with exactly three channels where it can be fixed."""
        if image.ndim == 2:
            return np.stack([image] * 3, axis=-1)
        if image.ndim != 3:
            return image
        depth = image.shape[2]
        if depth == 1:
            return np.repeat(image, 3, axis=2)
        if depth == 2:
            return np.concatenate([image, image[:, :, -1:]], axis=2)
        if depth > 3:
            return image[:, :, :3]
        return image


# Advanced Augmentation
def get_strong_train_transform(img_size=896):
    """Strong augmentation pipeline for training.

    The image is resized so its longest side equals ``img_size`` and padded to
    an ``img_size`` square, then geometric, colour, noise, blur and dropout
    augmentations are applied at random. Bounding boxes are given in
    ``pascal_voc`` format with ``labels`` and ``states`` as label fields; boxes
    below 25 px area or 30% visibility after augmentation are dropped.

    Args:
        img_size: Side length of the square output image.

    Returns:
        An ``A.Compose`` pipeline ending in ``ToTensorV2``.
    """
    return A.Compose([
        A.LongestMaxSize(max_size=img_size),
        # Constant-border padding uses the library's default fill (black).
        # `border_value` is not a PadIfNeeded argument, so it is not passed.
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT),

        A.OneOf([
            A.HorizontalFlip(p=1.0),
            A.VerticalFlip(p=0.3),
            A.RandomRotate90(p=0.5),
        ], p=0.7),

        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=20, border_mode=cv2.BORDER_CONSTANT, p=0.7),
        A.Perspective(scale=(0.05, 0.1), p=0.3),

        A.OneOf([
            A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=1.0),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1.0),
            A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, p=1.0),
        ], p=0.8),

        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0),
            A.RandomGamma(gamma_limit=(80, 120), p=1.0),
            A.CLAHE(clip_limit=4.0, p=1.0),
        ], p=0.5),

        A.OneOf([
            # p=1.0 (instead of the deprecated always_apply=True) so all three
            # noise options carry the same selection weight inside OneOf.
            A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
            A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=1.0),
            A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=1.0),
        ], p=0.4),

        A.OneOf([
            A.Blur(blur_limit=5, p=1.0),
            A.MotionBlur(blur_limit=5, p=1.0),
            A.MedianBlur(blur_limit=5, p=1.0),
            A.GaussianBlur(blur_limit=(3, 5), p=1.0),
        ], p=0.3),

        A.OneOf([
            A.CoarseDropout(max_holes=8, max_height=32, max_width=32, min_holes=1, min_height=8, min_width=8, fill_value=0, p=1.0),
            A.GridDropout(ratio=0.3, p=1.0),
        ], p=0.2),

        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
        A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.3),

        Ensure3Channels(),  # Ensure 3 channels before converting to tensor
        # NOTE: no normalization here - FasterRCNN normalizes internally.
        # ToTensorV2 only converts HWC numpy to a CHW tensor; it does not
        # rescale values, so the dataset divides by 255 afterwards.
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels', 'states'], min_area=25, min_visibility=0.3))


def get_val_transform(img_size=896):
    """Deterministic validation pipeline: resize, pad, convert to tensor.

    The longest side is resized to ``img_size`` and the image is padded to an
    ``img_size`` square. Boxes use ``pascal_voc`` format with ``labels`` and
    ``states`` as label fields.

    Args:
        img_size: Side length of the square output image.

    Returns:
        An ``A.Compose`` pipeline ending in ``ToTensorV2``.
    """
    return A.Compose([
        A.LongestMaxSize(max_size=img_size),
        # Constant-border padding uses the library's default fill (black).
        # `border_value` is not a PadIfNeeded argument, so it is not passed.
        A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT),
        Ensure3Channels(),  # Ensure 3 channels before converting to tensor
        # NOTE: no normalization here - FasterRCNN normalizes internally.
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels', 'states'], min_area=1, min_visibility=0.0))

# Cell-completion marker: shows that the dataset/augmentation code was executed.
print("‚úÖ Dataset and augmentation defined!")

---
## 5️⃣ Define Training System

In [None]:
# Complete Training System
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.amp import autocast, GradScaler
from collections import defaultdict
from tqdm import tqdm
import time
from datetime import datetime

class FocalLoss(nn.Module):
    """Focal loss for handling class imbalance.

    Computes ``alpha * (1 - p_t) ** gamma * CE`` per sample, where ``CE`` is
    the cross-entropy and ``p_t = exp(-CE)`` is the probability assigned to
    the true class.

    Args:
        alpha: Constant scaling factor applied to every sample.
        gamma: Focusing exponent; larger values down-weight easy samples.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (``None`` is accepted
            as an alias for ``'none'``).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none', None):
            raise ValueError(
                f"Unsupported reduction: {reduction!r} (expected 'mean', 'sum' or 'none')"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets``.

        The result is a scalar for ``'mean'``/``'sum'`` and a per-sample tensor
        otherwise.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            # Previously 'sum' fell through and returned the unreduced tensor.
            return focal_loss.sum()
        return focal_loss


class UltimateTrainer:
    """Ultimate trainer with all optimizations"""
    
    def __init__(self, config):
        """Set up logging, model, optimizer, scheduler and AMP state.

        Reads ``config['wandb']['project']``, ``config['wandb']['run_name']``,
        ``config['model']`` and ``config['training']['use_amp']`` here; the
        optimizer and scheduler setup read further keys.

        Args:
            config: Nested configuration mapping for the whole run.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_gpus = torch.cuda.device_count()

        print(f"\nüî• Using {self.num_gpus} GPU(s)")

        # Initialize WandB
        wandb.init(
            project=config['wandb']['project'],
            name=config['wandb']['run_name'],
            config=config,
            tags=['delftbikes', 'faster-rcnn', 'dual-gpu']
        )

        # Build model
        print("üèóÔ∏è  Building model...")
        self.model = build_model(config['model'])

        # Multi-GPU: nn.DataParallel is deliberately NOT used. It scatters
        # every tensor it finds inside list/dict inputs along dim 0, so each
        # CHW image in the list is split across GPUs on its channel axis (the
        # origin of the "2-channel image" errors) and the box/label tensors
        # are split as well. Detection models taking lists of images need
        # DistributedDataParallel (one process per GPU) for multi-GPU training.
        if self.num_gpus > 1:
            print(
                f"WARNING: {self.num_gpus} GPUs detected, but nn.DataParallel is "
                "incompatible with list-of-image detection inputs; training on "
                f"{self.device} only."
            )

        self.model.to(self.device)

        # Setup optimizer and LR schedule
        self.setup_optimizer()
        self.setup_scheduler()

        # Mixed precision
        self.use_amp = config['training']['use_amp']
        self.scaler = GradScaler('cuda') if self.use_amp else None

        # Training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')
        self.patience_counter = 0
    
    def setup_optimizer(self):
        """Fixed optimizer setup - no parameter overlap"""
        model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        
        # Simple approach: all parameters with same LR (still very effective!)
        params = [p for p in model.parameters() if p.requires_grad]
        
        total_params = sum(p.numel() for p in params)
        print(f"   Configuring optimizer for {len(params)} parameter tensors")
        print(f"   Total trainable parameters: {total_params:,}")
        
        self.optimizer = torch.optim.AdamW(
            params,
            lr=self.config['optimizer']['lr'],
            weight_decay=self.config['optimizer']['weight_decay']
        )
        
        print(f"‚úÖ Optimizer: AdamW (LR={self.config['optimizer']['lr']}, WD={self.config['optimizer']['weight_decay']})")
    
    def setup_scheduler(self):
        total_steps = self.config['training']['epochs'] * 100
        warmup_steps = self.config['training'].get('warmup_epochs', 5) * 100
        
        def lr_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            else:
                progress = (step - warmup_steps) / (total_steps - warmup_steps)
                return 0.5 * (1 + np.cos(np.pi * progress))
        
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda)
    
    def prepare_data(self):
        """Build the datasets and create ``self.train_loader`` / ``self.val_loader``.

        Paths and image size come from ``config['data']``; batch size, worker
        count and the ``balanced_sampling`` switch (default on) come from
        ``config['training']``. With balanced sampling the training loader
        draws samples with replacement according to ``compute_sample_weights``;
        otherwise it shuffles.
        """
        print("\nüìä Preparing datasets...")

        data_cfg = self.config['data']
        train_cfg = self.config['training']

        train_dataset = DelftBikesDataset(
            annotation_path=data_cfg['train_annotations'],
            image_dir=data_cfg['train_images'],
            transform=get_strong_train_transform(data_cfg['img_size']),
            filter_invalid_boxes=True,
        )
        val_dataset = DelftBikesDataset(
            annotation_path=data_cfg['val_annotations'],
            image_dir=data_cfg['val_images'],
            transform=get_val_transform(data_cfg['img_size']),
            filter_invalid_boxes=True,
        )

        print(f"   Train samples: {len(train_dataset)}")
        print(f"   Val samples: {len(val_dataset)}")

        # Class-balanced sampling: a sampler and shuffle=True are mutually
        # exclusive in DataLoader, so shuffling is only used without a sampler.
        sampler = None
        if train_cfg.get('balanced_sampling', True):
            sample_weights = self.compute_sample_weights(train_dataset)
            sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=train_cfg['batch_size'],
            sampler=sampler,
            shuffle=sampler is None,
            num_workers=train_cfg['num_workers'],
            collate_fn=collate_fn,
            pin_memory=True,
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=train_cfg['batch_size'],
            shuffle=False,
            num_workers=train_cfg['num_workers'],
            collate_fn=collate_fn,
            pin_memory=True,
        )
    
    def compute_sample_weights(self, dataset):
        """Return one sampling weight per dataset item.

        A sample weighs 3.0 if it contains a part in state 1, otherwise 1.5 if
        it contains a part in state 3, otherwise 1.0.

        When the dataset exposes ``annotations`` and ``image_ids`` (as
        ``DelftBikesDataset`` does), states are read straight from the
        annotations. This avoids decoding and randomly augmenting every image
        just to inspect its labels, and keeps the weights independent of boxes
        that augmentation happens to drop. Other datasets are indexed item by
        item and must yield ``(image, target)`` with ``target['states']``.

        Args:
            dataset: The training dataset.

        Returns:
            List of float weights, aligned with dataset indices.
        """
        def weight_for(has_state_1, has_state_3):
            if has_state_1:
                return 3.0
            if has_state_3:
                return 1.5
            return 1.0

        annotations = getattr(dataset, 'annotations', None)
        image_ids = getattr(dataset, 'image_ids', None)

        if annotations is not None and image_ids is not None:
            # Mirror the dataset's own box filtering so that weights reflect
            # the parts that can actually appear in a target.
            filter_boxes = getattr(dataset, 'filter_invalid_boxes', False)
            min_size = getattr(dataset, 'min_box_size', 0)
            weights = []
            for img_name in image_ids:
                states = set()
                for part in annotations[img_name]['parts'].values():
                    bbox = part.get('absolute_bounding_box')
                    if bbox is None:
                        continue
                    if filter_boxes and (
                        bbox['width'] < min_size or bbox['height'] < min_size
                        or bbox['width'] <= 0 or bbox['height'] <= 0
                    ):
                        continue
                    states.add(part['object_state_class'])
                weights.append(weight_for(1 in states, 3 in states))
            return weights

        # Generic fallback: inspect the states of each loaded item.
        weights = []
        for idx in range(len(dataset)):
            _, target = dataset[idx]
            states = target['states']
            num_state_1 = (states == 1).sum().item()
            num_state_3 = (states == 3).sum().item()
            weights.append(weight_for(num_state_1 > 0, num_state_3 > 0))
        return weights
    
    def ensure_3_channels(self, img):
        """Return ``img`` as a float tensor with three channels.

        uint8 input is divided by 255; other non-float32 dtypes are cast to
        float. A 2-D tensor is replicated to three channels, a single channel
        is repeated, two channels get the last one duplicated, and extra
        channels are dropped. If the result still exceeds 1.0 it is divided by
        255.
        """
        # Dtype first, so channel fixes operate on float data.
        if img.dtype == torch.uint8:
            img = img.float() / 255.0
        elif img.dtype != torch.float32:
            img = img.float()

        rank = img.dim()
        if rank == 2:
            img = img.unsqueeze(0).repeat(3, 1, 1)
        elif rank == 3:
            channel_count = img.shape[0]
            if channel_count == 1:
                img = img.repeat(3, 1, 1)
            elif channel_count == 2:
                last_channel = img[-1:, :, :]
                img = torch.cat([img, last_channel], dim=0)
            elif channel_count > 3:
                img = img[:3, :, :]

        # Float input that still looks like 0..255 data is rescaled.
        needs_rescale = img.max() > 1.0
        if needs_rescale:
            img = img / 255.0

        return img
    
    def train_epoch(self):
        """Train for one pass over ``self.train_loader``.

        For every batch: forward pass (under autocast when AMP is enabled),
        backward pass, optional gradient-norm clipping
        (``config['training']['clip_grad_norm']``), optimizer step and one
        scheduler step. Losses are logged to WandB every 10 batches.

        Returns:
            Dict of each loss term, plus ``'total'``, averaged over the number
            of batches in the loader.
        """
        self.model.train()
        running = defaultdict(float)
        train_cfg = self.config['training']
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch+1}")

        for batch_idx, (images, targets) in enumerate(progress_bar):
            # Normalise channels/dtype on the CPU, then move to the device.
            images = [self.ensure_3_channels(img) for img in images]
            images = [img.to(self.device) for img in images]
            targets = [
                {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
                for t in targets
            ]

            self.optimizer.zero_grad()
            max_norm = train_cfg.get('clip_grad_norm', 0)

            if self.use_amp:
                with autocast('cuda'):
                    loss_dict = self.model(images, targets)
                    losses = sum(loss_dict.values())
                self.scaler.scale(losses).backward()
                if max_norm > 0:
                    # Gradients must be unscaled before their norm is clipped.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss_dict = self.model(images, targets)
                losses = sum(loss_dict.values())
                losses.backward()
                if max_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)
                self.optimizer.step()

            # The schedule is defined per batch, not per epoch.
            self.scheduler.step()

            total_value = losses.item()
            term_values = {k: v.item() for k, v in loss_dict.items()}
            current_lr = self.optimizer.param_groups[0]['lr']

            # Log to WandB
            if batch_idx % 10 == 0:
                wandb.log({
                    'train/total_loss': total_value,
                    'train/learning_rate': current_lr,
                    'train/epoch': self.current_epoch,
                    **{f'train/{k}': value for k, value in term_values.items()}
                })

            for k, value in term_values.items():
                running[k] += value
            running['total'] += total_value

            progress_bar.set_postfix({'loss': total_value, 'lr': current_lr})
            self.global_step += 1

        num_batches = len(self.train_loader)
        return {k: v / num_batches for k, v in running.items()}
    
    @torch.no_grad()
    def validate(self):
        """Evaluate the model on ``self.val_loader`` and log the results.

        Each batch is run through the model twice: once in eval mode to
        collect predictions, and once with the model switched to train mode
        to obtain the loss dict (presumably because the model only returns
        losses in training mode, as torchvision detection models do --
        confirm against the model definition).

        During the loss pass, BatchNorm and Dropout layers are kept in eval
        mode. ``torch.no_grad`` does not stop train-mode BatchNorm from
        updating its running statistics, so without this the validation data
        would leak into the model, and Dropout would add noise to the
        reported loss.

        Returns:
            dict: loss components and ``'total'`` averaged over
            ``len(self.val_loader)`` batches, updated with the entries
            returned by ``self.compute_metrics``.
        """
        self.model.eval()
        epoch_losses = defaultdict(float)
        all_predictions = []
        all_targets = []

        # Layers whose forward pass behaves differently, or mutates buffers,
        # in train mode.
        mode_sensitive_layers = [
            module for module in self.model.modules()
            if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.modules.dropout._DropoutNd))
        ]

        for images, targets in tqdm(self.val_loader, desc="Validation"):
            # Ensure all images have 3 channels before moving to device
            images = [self.ensure_3_channels(img) for img in images]
            images = [img.to(self.device) for img in images]
            targets = [{k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]

            predictions = self.model(images)
            all_predictions.extend([{k: v.cpu() for k, v in pred.items()} for pred in predictions])
            all_targets.extend([{k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets])

            # Second pass in train mode for the losses, with the
            # mode-sensitive layers pinned to eval behaviour.
            self.model.train()
            for layer in mode_sensitive_layers:
                layer.eval()
            try:
                loss_dict = self.model(images, targets)
            finally:
                # Never leave the model in train mode, even if the pass fails.
                self.model.eval()
            losses = sum(loss for loss in loss_dict.values())

            for k, v in loss_dict.items():
                epoch_losses[k] += v.item()
            epoch_losses['total'] += losses.item()

        epoch_losses = {k: v / len(self.val_loader) for k, v in epoch_losses.items()}
        metrics = self.compute_metrics(all_predictions, all_targets)
        epoch_losses.update(metrics)

        wandb.log({
            'val/total_loss': epoch_losses['total'],
            'val/state_accuracy': metrics.get('state_accuracy', 0),
            'val/damaged_recall': metrics.get('damaged_recall', 0),
            'val/damaged_f1': metrics.get('damaged_f1', 0),
            'val/epoch': self.current_epoch,
        })

        return epoch_losses
    
    def compute_metrics(self, predictions, targets, iou_threshold=0.5):
        """Compute state-classification metrics over a set of images.

        Detections and ground-truth objects are paired per image, and the
        ``'states'`` values of each pair are compared. State value ``1`` is
        treated as the "damaged" class for precision / recall / F1.

        Pairing: when both dicts carry a ``'boxes'`` tensor with one row per
        state, pairs are formed greedily by descending IoU, one-to-one, and
        only pairs with IoU >= ``iou_threshold`` are kept. Detector output
        order is unrelated to annotation order, so comparing by list position
        would score unrelated objects against each other. Boxes are assumed
        to be ``(x1, y1, x2, y2)`` as in torchvision detection models --
        confirm against the dataset. When boxes are missing or their length
        does not match the states, the first ``min(len(pred), len(gt))``
        entries are compared by position instead.

        Images where either side lacks ``'states'`` or has none are skipped.
        Unpaired detections and unpaired ground-truth objects do not
        contribute to any metric.

        Args:
            predictions: list of per-image prediction dicts.
            targets: list of per-image ground-truth dicts, aligned with
                ``predictions``.
            iou_threshold: minimum IoU for a detection / ground-truth pair.

        Returns:
            dict with ``state_accuracy``, ``damaged_precision``,
            ``damaged_recall`` and ``damaged_f1``; each is ``0.0`` when its
            denominator is zero.
        """

        def pair_indices(pred, target):
            """Return parallel index lists pairing detections with ground truth."""
            n_pred, n_gt = len(pred['states']), len(target['states'])
            pred_boxes, gt_boxes = pred.get('boxes'), target.get('boxes')
            boxes_usable = (
                pred_boxes is not None
                and gt_boxes is not None
                and len(pred_boxes) == n_pred
                and len(gt_boxes) == n_gt
            )
            if not boxes_usable:
                # No geometry to match on: fall back to positional pairing.
                shared = list(range(min(n_pred, n_gt)))
                return shared, list(shared)

            pred_boxes = pred_boxes.float()
            gt_boxes = gt_boxes.float()

            # Pairwise IoU matrix of shape (n_pred, n_gt).
            top_left = torch.max(pred_boxes[:, None, 0:2], gt_boxes[None, :, 0:2])
            bottom_right = torch.min(pred_boxes[:, None, 2:4], gt_boxes[None, :, 2:4])
            overlap_wh = (bottom_right - top_left).clamp(min=0)
            intersection = overlap_wh[..., 0] * overlap_wh[..., 1]
            pred_area = ((pred_boxes[:, 2] - pred_boxes[:, 0]).clamp(min=0)
                         * (pred_boxes[:, 3] - pred_boxes[:, 1]).clamp(min=0))
            gt_area = ((gt_boxes[:, 2] - gt_boxes[:, 0]).clamp(min=0)
                       * (gt_boxes[:, 3] - gt_boxes[:, 1]).clamp(min=0))
            union = pred_area[:, None] + gt_area[None, :] - intersection
            # Clamp avoids 0/0 for degenerate boxes; any NaN becomes "no overlap".
            iou = torch.nan_to_num(intersection / union.clamp(min=1e-9), nan=0.0)

            flat_iou = iou.flatten()
            ranked_positions = torch.argsort(flat_iou, descending=True).tolist()
            flat_values = flat_iou.tolist()

            pred_idx, gt_idx = [], []
            used_pred, used_gt = set(), set()
            for position in ranked_positions:
                if flat_values[position] < iou_threshold:
                    break  # sorted descending: nothing further can qualify
                p, g = divmod(position, n_gt)
                if p in used_pred or g in used_gt:
                    continue
                used_pred.add(p)
                used_gt.add(g)
                pred_idx.append(p)
                gt_idx.append(g)
            return pred_idx, gt_idx

        state_correct = 0
        state_total = 0
        damaged_tp = damaged_fp = damaged_fn = 0

        for pred, target in zip(predictions, targets):
            if 'states' not in pred or 'states' not in target:
                continue
            if len(pred['states']) == 0 or len(target['states']) == 0:
                continue

            pred_idx, gt_idx = pair_indices(pred, target)
            if not pred_idx:
                continue

            pred_states = pred['states'][pred_idx]
            gt_states = target['states'][gt_idx]

            state_correct += (pred_states == gt_states).sum().item()
            state_total += len(pred_idx)

            pred_damaged = (pred_states == 1)
            gt_damaged = (gt_states == 1)
            damaged_tp += (pred_damaged & gt_damaged).sum().item()
            damaged_fp += (pred_damaged & ~gt_damaged).sum().item()
            damaged_fn += (~pred_damaged & gt_damaged).sum().item()

        state_accuracy = state_correct / state_total if state_total > 0 else 0.0
        damaged_precision = damaged_tp / (damaged_tp + damaged_fp) if (damaged_tp + damaged_fp) > 0 else 0.0
        damaged_recall = damaged_tp / (damaged_tp + damaged_fn) if (damaged_tp + damaged_fn) > 0 else 0.0
        damaged_f1 = 2 * damaged_precision * damaged_recall / (damaged_precision + damaged_recall) if (damaged_precision + damaged_recall) > 0 else 0.0

        return {
            'state_accuracy': state_accuracy,
            'damaged_precision': damaged_precision,
            'damaged_recall': damaged_recall,
            'damaged_f1': damaged_f1
        }
    
    def save_checkpoint(self, is_best=False):
        """Write ``best_model.pth`` when the current epoch is the best so far.

        The checkpoint holds the current epoch, the model and optimizer state
        dicts, ``self.best_val_loss`` and the training config. A
        ``DataParallel`` wrapper is unwrapped first, so the saved weights
        come from the wrapped module itself.

        Args:
            is_best: when false the call is a no-op; no state dicts are built
                and nothing is written.
        """
        # Only the best checkpoint is ever written, so skip building the
        # state dicts on every other epoch.
        if not is_best:
            return

        model_to_save = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.config
        }

        torch.save(checkpoint, 'best_model.pth')
        print("‚úÖ Saved best model!")
    
    def train(self):
        """Run the full training loop with early stopping.

        Calls ``self.prepare_data()``, then for each of
        ``config['training']['epochs']`` epochs runs ``train_epoch`` and
        ``validate``, prints a summary, and calls ``save_checkpoint`` with
        ``is_best`` set when the validation total loss improved on
        ``self.best_val_loss``. Training stops early once the loss has failed
        to improve for ``config['training']['early_stopping_patience']``
        consecutive epochs (default 20).

        The WandB run is always closed on the way out. It is closed with exit
        code 0 after a normal finish and with exit code 1 if an exception or
        interrupt ends training, so the run is never left open.
        """
        print("\n" + "="*80)
        print("üöÄ STARTING TRAINING")
        print("="*80)

        # Assume failure until the loop has finished without raising.
        exit_code = 1
        try:
            self.prepare_data()
            num_epochs = self.config['training']['epochs']
            patience = self.config['training'].get('early_stopping_patience', 20)

            for epoch in range(num_epochs):
                self.current_epoch = epoch
                start_time = time.time()

                print(f"\nüìç Epoch {epoch+1}/{num_epochs}")
                train_losses = self.train_epoch()
                val_losses = self.validate()

                epoch_time = time.time() - start_time
                print(f"\n‚è±Ô∏è  Epoch {epoch+1} completed in {epoch_time:.1f}s")
                print(f"   Train Loss: {train_losses['total']:.4f}")
                print(f"   Val Loss: {val_losses['total']:.4f}")
                print(f"   State Accuracy: {val_losses.get('state_accuracy', 0):.4f}")
                print(f"   Damaged F1: {val_losses.get('damaged_f1', 0):.4f}")

                is_best = val_losses['total'] < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_losses['total']
                    self.patience_counter = 0
                    print("   üèÜ New best model!")
                else:
                    self.patience_counter += 1

                self.save_checkpoint(is_best=is_best)

                if self.patience_counter >= patience:
                    print("\n‚ö†Ô∏è  Early stopping triggered")
                    break

            print("\n" + "="*80)
            print("üéâ TRAINING COMPLETED!")
            print("="*80)
            exit_code = 0
        finally:
            # Runs on success, on exceptions and on KeyboardInterrupt alike.
            wandb.finish(exit_code=exit_code)

# Confirmation message printed when this cell finishes executing, i.e. after
# the definitions above have been evaluated without error.
print("‚úÖ Training system defined!")

---
## 6️⃣ Configure Training

In [None]:
# Training configuration for the Kaggle dual-T4 runtime.
# NOTE(review): the setup notes at the top of the notebook mention
# /kaggle/input/delftbikes/ -- confirm this root matches the attached dataset.
_data_root = '/kaggle/input/dataset/processed_data'
# Timestamp suffix keeps WandB run names distinct between launches.
_run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")

config = {
    # Detector architecture and input resolution.
    'model': {
        'num_classes': 22,
        'num_states': 4,
        'backbone': 'resnet101',
        'pretrained_backbone': True,
        'trainable_backbone_layers': 4,
        'min_size': 896,
        'max_size': 1344,
    },
    # Annotation files and image folders for the train / val splits.
    'data': {
        'train_annotations': f'{_data_root}/train_split.json',
        'train_images': f'{_data_root}/train_split',
        'val_annotations': f'{_data_root}/val_split.json',
        'val_images': f'{_data_root}/val_split',
        'img_size': 896,
    },
    # Loop settings: batch size is the total across both GPUs (3 each).
    'training': {
        'epochs': 150,
        'batch_size': 6,
        'num_workers': 4,
        'use_amp': True,
        'clip_grad_norm': 10.0,
        'early_stopping_patience': 20,
        'warmup_epochs': 5,
        'balanced_sampling': True,
    },
    'optimizer': {
        'type': 'adamw',
        'lr': 2e-4,
        'weight_decay': 5e-5,
    },
    'wandb': {
        'project': 'delftbikes-defect-detection',
        'run_name': f'dual-t4-resnet101-{_run_stamp}',
    },
}

# Short summary of the settings most likely to be tweaked between runs.
_summary = {
    'Model': config['model']['backbone'],
    'Batch size': config['training']['batch_size'],
    'Image size': config['data']['img_size'],
    'Epochs': config['training']['epochs'],
    'Learning rate': config['optimizer']['lr'],
}

print("‚úÖ Configuration loaded")
print("\n".join(f"   {label}: {value}" for label, value in _summary.items()))

---
## 7️⃣ START TRAINING! 🚀

In [None]:
# Create trainer and start training
# `config` is the dictionary built in the configuration cell above.
trainer = UltimateTrainer(config)

# This will take 3-4 hours on dual T4 (author's estimate)
# Monitor progress on WandB dashboard!
# Presumably this is the train() method defined earlier in the notebook: it
# runs the epoch loop with early stopping and writes best_model.pth to the
# working directory whenever the validation loss improves.
trainer.train()

In [None]:
# Report whether training left a best-model checkpoint in the working directory
import os

checkpoint_file = 'best_model.pth'

if not os.path.exists(checkpoint_file):
    print("‚ùå Model file not found!")
else:
    # Size in mebibytes (bytes / 1024 / 1024).
    model_size = os.path.getsize(checkpoint_file) / (1024 * 1024)
    print("‚úÖ Best model saved!")
    print("   File: best_model.pth")
    print(f"   Size: {model_size:.1f} MB")
    print("\nüì• Download this file from Kaggle output!")

---
## 📊 View Results on WandB

Go to https://wandb.ai to see:
- Training/validation loss curves
- State accuracy over time
- Damaged class metrics
- Learning rate schedule
- GPU utilization
- All per-epoch metrics!