In [None]:
import os

In [None]:
# Generate the small paired (image, target) transforms module that the training
# cell imports as `T`. An existing transforms.py is never overwritten.
if not os.path.exists("transforms.py"):
    with open("transforms.py", "w", encoding="utf-8") as f:
        f.write('''
import random
import torch
from torchvision.transforms import functional as F


class Compose:
    """Apply each transform in order to the (image, target) pair."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class ToTensor:
    """Convert the image with torchvision's to_tensor; the target is unchanged."""

    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target


class RandomHorizontalFlip:
    """With probability ``prob``, mirror the image tensor and its boxes left-right."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            width = image.shape[-1]
            image = image.flip(-1)
            bbox = target["boxes"]
            # Only flip if there are boxes
            if len(bbox) > 0:
                bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
                target["boxes"] = bbox
        return image, target
''')

In [None]:
import os
import numpy as np
import torch
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import transforms as T
from tqdm import tqdm
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import random
from sklearn.model_selection import train_test_split
import cv2

class YOLOtoRCNNDataset(Dataset):
    """Serve YOLO-annotated images as (image, target) pairs for torchvision detectors.

    ``target`` is a dict with ``boxes`` (float32, ``[x_min, y_min, x_max, y_max]``
    in pixels), ``labels`` (int64, YOLO class id + 1 so that 0 stays free for the
    R-CNN background class), ``image_id``, ``area`` and ``iscrowd``.

    Args:
        img_dir: Directory containing the image files.
        label_dir: Directory with one YOLO ``.txt`` file per image, named after
            the image's stem. Images without a label file are ignored.
        transforms: Optional callable ``(image, target) -> (image, target)``,
            applied to the PIL image and target dict.
        class_names: Optional mapping from R-CNN label id to display name; a
            built-in default is used when omitted.
    """

    def __init__(self, img_dir, label_dir, transforms=None, class_names=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transforms = transforms

        # Keep only images that have a matching label file. Sorting makes the
        # index -> file mapping deterministic, so separate instances built over
        # the same directories (e.g. a train view and a val view sharing one
        # index split) agree on which file every index refers to.
        self.imgs = sorted(
            f for f in os.listdir(img_dir)
            if os.path.exists(os.path.join(label_dir, os.path.splitext(f)[0] + '.txt'))
        )

        # Default class names
        if class_names is None:
            self.class_names = {
                0: 'background',  # Required for RCNN
                1: 'biker',
                2: 'car',
                3: 'pedestrian',
                4: 'traffic-light',
                5: 'traffic-light-green',
                6: 'traffic-light-greenleft',
                7: 'traffic-light-red',
                8: 'traffic-light-redleft',
                9: 'traffic-light-yellow',
                10: 'traffic-light-yellowleft',
                11: 'truck',
                12: 'arret',
            }
        else:
            self.class_names = class_names

    def __len__(self):
        return len(self.imgs)

    @staticmethod
    def _parse_yolo_labels(label_path, width, height):
        """Parse a YOLO label file into pixel-space boxes and shifted labels.

        Each usable line is ``class_id x_center y_center width height`` with
        coordinates normalized to [0, 1]. Fields may be separated by any
        whitespace. Lines with fewer than five fields or non-numeric values are
        skipped. Boxes are clipped to the image, and boxes left with zero width
        or height are dropped, because torchvision detectors reject them.

        Args:
            label_path: Path to the ``.txt`` file; a missing file yields no boxes.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            ``(boxes, labels)``: a list of ``[x_min, y_min, x_max, y_max]`` float
            lists, and a list of ints equal to ``class_id + 1``.
        """
        boxes, labels = [], []
        if not os.path.exists(label_path):
            return boxes, labels

        with open(label_path, 'r') as f:
            for line in f:
                data = line.split()
                if len(data) < 5:  # blank or malformed line
                    continue
                try:
                    class_id = int(data[0])
                    x_center = float(data[1]) * width
                    y_center = float(data[2]) * height
                    box_width = float(data[3]) * width
                    box_height = float(data[4]) * height
                except ValueError:
                    continue

                # Centre/size -> corner format, clipped to the image bounds.
                x_min = max(0.0, x_center - (box_width / 2))
                y_min = max(0.0, y_center - (box_height / 2))
                x_max = min(float(width), x_center + (box_width / 2))
                y_max = min(float(height), y_center + (box_height / 2))
                if x_max <= x_min or y_max <= y_min:
                    continue

                boxes.append([x_min, y_min, x_max, y_max])
                # Shift by one: label 0 is the R-CNN background class.
                labels.append(class_id + 1)
        return boxes, labels

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        width, height = img.size

        label_path = os.path.join(self.label_dir, os.path.splitext(img_name)[0] + '.txt')
        boxes, labels = self._parse_yolo_labels(label_path, width, height)

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        else:
            # No usable annotations: emit correctly shaped empty tensors.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            area = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx]),
            'area': area,
            'iscrowd': iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

# Transforms for data augmentation and normalization
def get_transform(train):
    """Build the paired (image, target) transform pipeline.

    Args:
        train: When truthy, a random horizontal flip (p=0.5) is applied after
            the tensor conversion.

    Returns:
        A ``T.Compose`` over the selected steps.
    """
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

# Function to get the Faster R-CNN model
def get_model(num_classes):
    """Return Faster R-CNN (ResNet-50 FPN) with its box predictor resized.

    The backbone and heads load torchvision's default pretrained weights. Only
    the final box predictor is replaced, so that it outputs ``num_classes``
    classes (background included).
    """
    detector = fasterrcnn_resnet50_fpn(weights='DEFAULT')
    roi_heads = detector.roi_heads

    # Keep the input feature size of the existing head; only the class count changes.
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return detector

# Custom collate function to handle empty annotations
def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)``.

    Samples are regrouped into tuples instead of being stacked, so images of
    different sizes and targets with different box counts can share a batch.
    """
    columns = zip(*batch)
    return tuple(columns)

# Training function
def train_one_epoch(model, optimizer, data_loader, device):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    A batch is skipped without stepping the optimizer in these cases:
    - every target in it has no boxes;
    - its forward or backward pass raises (the error is printed);
    - its summed loss is NaN or infinite, which would otherwise corrupt the weights.

    Args:
        model: Detection model that returns a dict of loss tensors when called
            as ``model(images, targets)`` in train mode.
        optimizer: Optimizer over the model's trainable parameters.
        data_loader: Iterable of ``(images, targets)`` batches.
        device: Device each batch is moved to.

    Returns:
        Mean summed loss over the batches that were stepped, or ``0.0`` if none were.
    """
    model.train()

    total_loss = 0.0
    num_batches = 0
    num_skipped = 0

    for images, targets in tqdm(data_loader):
        # Check for all-empty batches before paying for the device transfer.
        if all(len(t["boxes"]) == 0 for t in targets):
            continue

        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        try:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            if not torch.isfinite(losses):
                print(f"Skipping batch with non-finite loss: {losses.item()}")
                num_skipped += 1
                continue

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
        except Exception as e:
            # Best effort: one bad batch should not abort the whole epoch.
            print(f"Error in batch: {e}")
            num_skipped += 1
            continue

        total_loss += losses.item()
        num_batches += 1

    if num_skipped:
        print(f"Skipped {num_skipped} batch(es) due to errors or non-finite losses")
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

# Evaluation function
def evaluate(model, data_loader, device):
    """Compute mean average precision of ``model`` over ``data_loader``.

    Images whose target has no boxes are excluded. A batch whose inference or
    metric update raises is reported and skipped.

    Args:
        model: Detection model that returns a list of dicts with ``boxes``,
            ``scores`` and ``labels`` when called on a list of images in eval mode.
        data_loader: Iterable of ``(images, targets)`` batches.
        device: Device on which inference runs.

    Returns:
        The result of ``MeanAveragePrecision.compute()``.
    """
    model.eval()
    metric = MeanAveragePrecision()

    with torch.no_grad():
        for images, targets in tqdm(data_loader):
            # Drop images without ground truth before moving anything to the device.
            pairs = [(img, tgt) for img, tgt in zip(images, targets) if len(tgt["boxes"]) > 0]
            if not pairs:
                continue

            valid_images = [img.to(device) for img, _ in pairs]
            # The metric is never moved to `device`, so feed it CPU tensors.
            valid_targets = [{k: v.cpu() for k, v in tgt.items()} for _, tgt in pairs]

            try:
                outputs = model(valid_images)
                preds = [
                    {
                        'boxes': output['boxes'].cpu(),
                        'scores': output['scores'].cpu(),
                        'labels': output['labels'].cpu(),
                    }
                    for output in outputs
                ]
                metric.update(preds, valid_targets)
            except Exception as e:
                print(f"Error during evaluation: {e}")
                continue

    return metric.compute()

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed every RNG this script touches and make cuDNN deterministic.

    Args:
        seed: Value given to Python's ``random``, NumPy, and PyTorch's CPU and
            CUDA generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Main function
def main():
    """Train and evaluate Faster R-CNN on a YOLO-annotated dataset.

    Steps:
    1. Read ``images/`` and ``labels/`` under the hard-coded ``base_dir``.
    2. Train for ``num_epochs`` epochs.
    3. Save the weights to ``faster_rcnn_model.pth`` and print validation metrics.
    4. Write a figure of sample predictions with ``visualize_predictions``.

    Returns early, after printing an error, when the directories are missing or
    fewer than two images carry usable annotations.
    """
    # Set random seed
    set_seed()

    # Define paths
    base_dir = "/kaggle/input/yolo-dataset/dataset_1/dataset_1/train"
    img_dir = os.path.join(base_dir, "images")
    label_dir = os.path.join(base_dir, "labels")

    print("Starting RCNN training with YOLO annotations...")
    print("Note: YOLO class IDs will be shifted by +1 to accommodate RCNN's background class")

    # Check if directories exist
    if not os.path.exists(img_dir) or not os.path.exists(label_dir):
        print(f"Error: Directories not found - images: {os.path.exists(img_dir)}, labels: {os.path.exists(label_dir)}")
        return

    # Start from the 12 classes this project expects (YOLO ids 0-11). Add any
    # other id that actually appears in the labels, so that no label the dataset
    # emits exceeds the size of the classifier head.
    print("Scanning labels to determine classes...")
    unique_class_ids = set(range(12)) | _scan_yolo_class_ids(label_dir)
    print(f"Found {len(unique_class_ids)} unique classes in YOLO annotations: {sorted(unique_class_ids)}")

    # Background is 0 and YOLO id k becomes R-CNN label k + 1. Every id up to
    # the largest one gets an entry, so num_classes always covers the largest
    # label even when some ids are unused.
    class_names = {0: 'background'}
    for yolo_class_id in range(max(unique_class_ids) + 1):
        class_names[yolo_class_id + 1] = f'class{yolo_class_id}'
    print(f"Mapped to RCNN classes: {class_names}")

    # Number of classes (including background)
    num_classes = len(class_names)

    # Create dataset
    full_dataset = YOLOtoRCNNDataset(img_dir, label_dir, get_transform(train=True), class_names)

    # Keep only images with at least one box for training/validation.
    valid_indices = []
    print("Checking for valid annotations...")
    for idx in tqdm(range(len(full_dataset))):
        _, target = full_dataset[idx]
        if len(target["boxes"]) > 0:
            valid_indices.append(idx)

    print(f"Found {len(valid_indices)}/{len(full_dataset)} images with valid annotations")

    # train_test_split needs at least one sample on each side of the split.
    if len(valid_indices) < 2:
        print("Error: need at least 2 annotated images to build a train/validation split")
        return

    train_indices, val_indices = train_test_split(valid_indices, test_size=0.2, random_state=42)
    print(f"Training on {len(train_indices)} images, validating on {len(val_indices)} images")

    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

    train_loader = DataLoader(
        full_dataset,
        batch_size=4,
        sampler=train_sampler,
        collate_fn=collate_fn
    )

    # Same files and indices as full_dataset, but without training augmentation.
    val_dataset = YOLOtoRCNNDataset(img_dir, label_dir, get_transform(train=False), class_names)
    val_loader = DataLoader(
        val_dataset,
        batch_size=2,
        sampler=val_sampler,
        collate_fn=collate_fn
    )

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Using device: {device}")

    model = get_model(num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    num_epochs = 1

    print("Starting training...")
    metrics = None
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, optimizer, train_loader, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}")

        lr_scheduler.step()

        # Evaluate every second epoch and always after the last one.
        if (epoch + 1) % 2 == 0 or (epoch + 1) == num_epochs:
            print("Evaluating...")
            metrics = evaluate(model, val_loader, device)
            print(f"Validation mAP: {metrics['map']:.4f}")
            print(f"mAP@50: {metrics['map_50']:.4f}")
            print(f"mAP@75: {metrics['map_75']:.4f}")

    torch.save(model.state_dict(), "faster_rcnn_model.pth")
    print("Model saved to faster_rcnn_model.pth")

    print("Performing final evaluation...")
    # The loop always evaluates after its last epoch and the model has not
    # changed since, so reuse that result; only evaluate here when no epoch ran.
    final_metrics = metrics if metrics is not None else evaluate(model, val_loader, device)
    print("Final Evaluation Metrics:")
    for k, v in final_metrics.items():
        print(f"{k}: {_format_metric(v)}")

    visualize_predictions(model, val_dataset, val_indices[:5], device, class_names)


def _scan_yolo_class_ids(label_dir):
    """Collect the class ids used in the YOLO ``.txt`` files directly inside ``label_dir``.

    Lines with fewer than five whitespace-separated fields, a non-integer class
    id, or a negative id are ignored. Unreadable files are reported and skipped.
    """
    class_ids = set()
    for name in os.listdir(label_dir):
        path = os.path.join(label_dir, name)
        if not name.endswith('.txt') or not os.path.isfile(path):
            continue
        try:
            with open(path, 'r') as f:
                for line in f:
                    fields = line.split()
                    if len(fields) < 5:
                        continue
                    try:
                        class_id = int(fields[0])
                    except ValueError:
                        continue
                    if class_id >= 0:
                        class_ids.add(class_id)
        except (OSError, UnicodeDecodeError) as e:
            print(f"Could not read {path}: {e}")
    return class_ids


def _format_metric(value):
    """Format one metric value for printing.

    Single-element tensors and plain numbers get 4 decimals. Multi-element
    tensors (e.g. the per-class ``classes`` entry) are printed as a list, because
    ``format(tensor, '.4f')`` raises ``TypeError`` for them.
    """
    if isinstance(value, torch.Tensor):
        if value.numel() == 1:
            return f"{value.item():.4f}"
        return str(value.tolist())
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        return str(value)

def visualize_predictions(model, dataset, indices, device, class_names, score_threshold=0.5):
    """Save ground truth and predictions side by side to ``predictions.png``.

    Each index gets one figure row: ground-truth boxes in red on the left,
    predicted boxes in blue on the right.

    Args:
        model: Detection model; switched to eval mode and called as ``model([image])``.
        dataset: Dataset returning ``(image_tensor, target)``, where ``target``
            has ``boxes`` and ``labels``.
        indices: Dataset indices to draw.
        device: Device on which inference runs.
        class_names: Mapping from label id to display name; ids that are missing
            are shown as the raw number.
        score_threshold: Only predictions scoring strictly above this are drawn.
    """
    model.eval()

    if not indices:
        print("No valid images to visualize")
        return

    def _draw(ax, box, text, color):
        # box is [x_min, y_min, x_max, y_max] in pixels.
        rect = plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                             linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        ax.text(box[0], box[1], text, color='white', fontsize=10,
                bbox=dict(facecolor=color, alpha=0.5))

    # squeeze=False keeps `axs` 2-D (rows x 2) even for a single index, so
    # axs[i, 0] / axs[i, 1] always address the two panels of row i.
    fig, axs = plt.subplots(len(indices), 2, figsize=(15, 5 * len(indices)), squeeze=False)

    for i, idx in enumerate(indices):
        img, target = dataset[idx]
        img_np = img.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
        ax_gt, ax_pred = axs[i, 0], axs[i, 1]

        ax_gt.imshow(img_np)
        ax_gt.set_title('Ground Truth')
        if len(target['boxes']) > 0:
            for box, label in zip(target['boxes'], target['labels']):
                label_id = label.item()
                _draw(ax_gt, box.cpu().numpy(), class_names.get(label_id, str(label_id)), 'r')
        else:
            ax_gt.text(10, 10, "No annotations", color='red', fontsize=12)

        ax_pred.imshow(img_np)
        ax_pred.set_title('Prediction')
        try:
            with torch.no_grad():
                prediction = model([img.to(device)])
            prediction = {k: v.cpu() for k, v in prediction[0].items()}

            shown = 0
            for box, score, label in zip(prediction['boxes'], prediction['scores'], prediction['labels']):
                score_value = score.item()
                if score_value <= score_threshold:
                    continue
                label_id = label.item()
                _draw(ax_pred, box.numpy(),
                      f"{class_names.get(label_id, str(label_id))}: {score_value:.2f}", 'b')
                shown += 1
            if shown == 0:
                ax_pred.text(10, 10, "No predictions", color='blue', fontsize=12)
        except Exception as e:
            print(f"Error visualizing predictions for image {idx}: {e}")
            ax_pred.text(10, 10, "Error in prediction", color='red', fontsize=12)

    fig.tight_layout()
    fig.savefig('predictions.png')
    plt.close(fig)
    print("Predictions visualization saved to predictions.png")

if __name__ == "__main__":
    # transforms.py is generated by the earlier cell and is already imported
    # above as `T`; that import fails before reaching this point if the file is
    # missing. So no second (divergent) copy of its source is kept here.
    main()