In [1]:
from hcmus.core import appconfig

[32m2025-06-09 16:09:31.877[0m | [1mINFO    [0m | [36mhcmus.core.appconfig[0m:[36m<module>[0m:[36m7[0m - [1mLoad DotEnv: True[0m


In [None]:
# COCO Dataset Merger and Multi-Model Training Pipeline
# Compatible with Kaggle environment

import os
import json
import shutil
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.models import detection
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import logging

# Configure root logging at INFO level and create the module-level logger
# that the classes and functions below use for their progress messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class COCODatasetMerger:
    """Merge two COCO-format datasets into one, subsampling each by a weight.

    Both inputs are read from ``<root>/<split>/annotations_<split>.json`` with
    the image files under ``<root>/<split>/images/``. The merged dataset is
    written to ``output_path`` using the same layout. Categories are unified
    by name and renumbered sequentially starting at 1.
    """

    # Splits that are looked for in both datasets.
    SPLITS = ('train', 'val', 'test')

    def __init__(self, dataset1_path, dataset2_path, output_path, weight1=0.5, weight2=0.5):
        """Store the dataset locations and normalise the sampling weights.

        Args:
            dataset1_path: Root directory of the first dataset.
            dataset2_path: Root directory of the second dataset.
            output_path: Root directory the merged dataset is written to.
            weight1: Relative weight of dataset 1 (non-negative).
            weight2: Relative weight of dataset 2 (non-negative).

        Raises:
            ValueError: If a weight is negative or both weights are zero.
        """
        self.dataset1_path = Path(dataset1_path)
        self.dataset2_path = Path(dataset2_path)
        self.output_path = Path(output_path)
        self.num_classes = 0  # Set by _analyze_categories()
        self.merged_categories = []  # Set by _analyze_categories()

        if weight1 < 0 or weight2 < 0:
            raise ValueError(f"Weights must be non-negative, got {weight1} and {weight2}")
        total_weight = weight1 + weight2
        if total_weight <= 0:
            raise ValueError("At least one of weight1/weight2 must be greater than zero")

        # Normalise so the two weights sum to 1.
        self.weight1 = weight1 / total_weight
        self.weight2 = weight2 / total_weight

    def load_coco_annotation(self, ann_path):
        """Load a COCO annotation JSON file and return the parsed object."""
        with open(ann_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def merge_datasets(self):
        """Merge every split of both datasets.

        Returns:
            A ``(num_classes, merged_categories)`` tuple, where ``num_classes``
            counts the merged categories plus one background class.
        """
        logger.info("Starting dataset merging...")

        # Categories must be known first: every split reuses the same mapping.
        self._analyze_categories()

        for split in self.SPLITS:
            (self.output_path / split / 'images').mkdir(parents=True, exist_ok=True)

        for split in self.SPLITS:
            self._merge_split(split)

        logger.info("Dataset merging completed!")
        logger.info(f"Total number of classes: {self.num_classes}")
        logger.info(f"Categories: {[cat['name'] for cat in self.merged_categories]}")

        return self.num_classes, self.merged_categories

    def _analyze_categories(self):
        """Collect categories from both datasets and derive the class count.

        For each dataset the first split that has a non-empty ``categories``
        list is used; later splits of that dataset are not inspected.
        """
        logger.info("Analyzing categories from both datasets...")

        all_categories = []

        for dataset_path in [self.dataset1_path, self.dataset2_path]:
            categories_found = False
            for split in self.SPLITS:
                ann_path = dataset_path / split / f'annotations_{split}.json'
                if not ann_path.exists():
                    continue
                ann_data = self.load_coco_annotation(ann_path)
                if ann_data.get('categories'):
                    all_categories.extend(ann_data['categories'])
                    categories_found = True
                    break

            if not categories_found:
                logger.warning(f"No categories found in dataset: {dataset_path}")

        self.merged_categories = self._merge_categories_list(all_categories)
        self.num_classes = len(self.merged_categories) + 1  # +1 for background class

        logger.info(f"Found {len(self.merged_categories)} unique categories")
        logger.info(f"Total classes (including background): {self.num_classes}")

    def _merge_categories_list(self, all_categories):
        """Deduplicate categories by name and renumber them from 1.

        The first occurrence of each name wins; the input dicts are copied,
        not modified.
        """
        merged_cats = []
        seen_names = set()

        for cat in all_categories:
            if cat['name'] not in seen_names:
                merged_cats.append(cat.copy())
                seen_names.add(cat['name'])

        for new_id, cat in enumerate(merged_cats, start=1):
            cat['id'] = new_id

        return merged_cats

    def _merge_split(self, split):
        """Merge one split (train/val/test) and write its annotation file.

        When both datasets have images, ``int(count * weight)`` images (at
        least one) are randomly sampled from each. When only one dataset has
        images for the split, all of its images are kept.
        """
        logger.info(f"Merging {split} split...")

        ann_name = f'annotations_{split}.json'
        ann1_path = self.dataset1_path / split / ann_name
        ann2_path = self.dataset2_path / split / ann_name
        has_ann1 = ann1_path.exists()
        has_ann2 = ann2_path.exists()

        if not has_ann1 and not has_ann2:
            logger.warning(f"Skipping {split} split - no annotation files found")
            return

        ann1 = self.load_coco_annotation(ann1_path) if has_ann1 else {}
        ann2 = self.load_coco_annotation(ann2_path) if has_ann2 else {}
        images1 = ann1.get('images', [])
        images2 = ann2.get('images', [])
        total_samples1 = len(images1)
        total_samples2 = len(images2)

        if total_samples1 == 0 and total_samples2 == 0:
            logger.warning(f"No images found in {split} split")
            return

        if total_samples1 > 0 and total_samples2 > 0:
            n_samples1 = max(1, int(total_samples1 * self.weight1))
            n_samples2 = max(1, int(total_samples2 * self.weight2))
        else:
            n_samples1 = total_samples1
            n_samples2 = total_samples2

        # Dataset 1 is sampled before dataset 2 so a seeded RNG stays reproducible.
        sampled_imgs1 = random.sample(images1, min(n_samples1, total_samples1))
        sampled_imgs2 = random.sample(images2, min(n_samples2, total_samples2))

        merged_ann = {
            'info': ann1.get('info', {}) if has_ann1 else ann2.get('info', {}),
            'licenses': ann1.get('licenses', []) if has_ann1 else ann2.get('licenses', []),
            'categories': self.merged_categories,  # Use pre-analyzed categories
            'images': [],
            'annotations': []
        }

        # Shared across both datasets so identical file names cannot overwrite
        # each other in the output image folder.
        used_file_names = set()
        self._append_dataset(merged_ann, ann1, sampled_imgs1, self.dataset1_path,
                             split, used_file_names, 'ds1')
        self._append_dataset(merged_ann, ann2, sampled_imgs2, self.dataset2_path,
                             split, used_file_names, 'ds2')

        output_ann_path = self.output_path / split / ann_name
        with open(output_ann_path, 'w') as f:
            json.dump(merged_ann, f)

        logger.info(f"{split} split merged: {len(merged_ann['images'])} images, {len(merged_ann['annotations'])} annotations")

    @staticmethod
    def _unique_file_name(file_name, used_names, prefix):
        """Return ``file_name``, or a prefixed variant if it is already used."""
        if file_name not in used_names:
            return file_name

        rel = Path(file_name)
        candidate = rel.with_name(f"{prefix}_{rel.name}").as_posix()
        counter = 1
        while candidate in used_names:
            candidate = rel.with_name(f"{prefix}_{counter}_{rel.name}").as_posix()
            counter += 1
        return candidate

    def _append_dataset(self, merged_ann, source_ann, sampled_images, source_root,
                        split, used_file_names, collision_prefix):
        """Append the sampled images of one dataset, and their annotations, to ``merged_ann``.

        Image and annotation ids are renumbered to continue the sequences
        already present in ``merged_ann``. Image files are copied into the
        output folder; an image whose file name is already taken is stored
        under a ``<collision_prefix>_`` name instead of overwriting the
        earlier file. Annotations whose category has no counterpart in the
        merged categories are dropped.
        """
        src_dir = source_root / split / 'images'
        dst_dir = self.output_path / split / 'images'
        dst_dir.mkdir(parents=True, exist_ok=True)

        cat_id_map = self._create_category_mapping(source_ann.get('categories', []))

        img_id_mapping = {}
        missing_files = 0
        renamed_files = 0

        for img in sampled_images:
            new_img_id = len(merged_ann['images']) + 1
            img_id_mapping[img['id']] = new_img_id

            out_name = self._unique_file_name(img['file_name'], used_file_names, collision_prefix)
            if out_name != img['file_name']:
                renamed_files += 1
            used_file_names.add(out_name)

            new_img = img.copy()
            new_img['id'] = new_img_id
            new_img['file_name'] = out_name
            merged_ann['images'].append(new_img)

            src_img = src_dir / img['file_name']
            if src_img.exists():
                dst_img = dst_dir / out_name
                if dst_img.parent != dst_dir:
                    # file_name contains sub-directories
                    dst_img.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(src_img, dst_img)
            else:
                missing_files += 1

        if renamed_files:
            logger.warning(f"{renamed_files} image(s) from {source_root} renamed with prefix '{collision_prefix}_' to avoid file name collisions in {split} split")
        if missing_files:
            logger.warning(f"{missing_files} image file(s) listed in {split} annotations not found under {src_dir}")

        for ann in source_ann.get('annotations', []):
            new_image_id = img_id_mapping.get(ann['image_id'])
            if new_image_id is None:
                continue  # image was not sampled
            new_cat_id = cat_id_map.get(ann['category_id'])
            if new_cat_id is None:
                continue  # category unknown to the merged category list

            new_ann = ann.copy()
            new_ann['id'] = len(merged_ann['annotations']) + 1
            new_ann['image_id'] = new_image_id
            new_ann['category_id'] = new_cat_id
            merged_ann['annotations'].append(new_ann)

    def _create_category_mapping(self, categories):
        """Map a dataset's original category ids to merged category ids.

        Categories are matched by name against ``self.merged_categories``;
        categories without a match are left out of the mapping.
        """
        name_to_merged_id = {}
        for merged_cat in self.merged_categories:
            # First occurrence wins if a name appears more than once.
            name_to_merged_id.setdefault(merged_cat['name'], merged_cat['id'])

        return {
            cat['id']: name_to_merged_id[cat['name']]
            for cat in categories
            if cat['name'] in name_to_merged_id
        }

    def _merge_categories(self, cats1, cats2):
        """Legacy method - kept for compatibility"""
        return self._merge_categories_list(cats1 + cats2)

class COCODataset(Dataset):
    """PyTorch dataset over one split of a COCO-style detection dataset.

    Annotations are read from ``<root_dir>/<split>/annotations_<split>.json``
    and images from ``<root_dir>/<split>/images/``. Each item is an
    ``(image, target)`` pair where ``target`` holds ``boxes`` (``[N, 4]`` in
    x1, y1, x2, y2 order), ``labels`` (``[N]``) and ``image_id``.
    """

    def __init__(self, root_dir, split='train', transforms=None):
        """Load the annotation file and index annotations by image id.

        Args:
            root_dir: Dataset root directory.
            split: Split name used for both the folder and the file name.
            transforms: Optional callable invoked as
                ``transforms(image=..., bboxes=..., class_labels=...)`` that
                returns a dict with at least an ``'image'`` entry.
        """
        self.root_dir = Path(root_dir)
        self.split = split
        self.transforms = transforms

        ann_path = self.root_dir / split / f'annotations_{split}.json'
        with open(ann_path, 'r') as f:
            self.coco_data = json.load(f)

        self.images = self.coco_data['images']
        self.annotations = self.coco_data['annotations']
        self.categories = {cat['id']: cat['name'] for cat in self.coco_data['categories']}

        # image_id -> list of annotation dicts for that image
        self.img_to_anns = {}
        for ann in self.annotations:
            self.img_to_anns.setdefault(ann['image_id'], []).append(ann)

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.images)

    @staticmethod
    def _build_target(anns, img_id):
        """Convert COCO annotations of one image into a detection target dict.

        COCO ``[x, y, w, h]`` boxes become ``[x1, y1, x2, y2]``. Boxes with a
        non-positive width or height are skipped, because torchvision
        detection models reject them during training. ``boxes`` always has
        shape ``[N, 4]``, including ``[0, 4]`` for an image without boxes.
        """
        boxes = []
        labels = []
        for ann in anns:
            x, y, w, h = ann['bbox']
            if w <= 0 or h <= 0:
                continue
            boxes.append([x, y, x + w, y + h])
            labels.append(ann['category_id'])

        return {
            # reshape keeps the [N, 4] layout even when there are no boxes
            'boxes': torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.as_tensor(labels, dtype=torch.int64),
            'image_id': torch.tensor([img_id])
        }

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair for the image at ``idx``.

        Raises:
            FileNotFoundError: If the image file is missing or unreadable.
        """
        img_info = self.images[idx]
        img_path = self.root_dir / self.split / 'images' / img_info['file_name']

        image = cv2.imread(str(img_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        img_id = img_info['id']
        target = self._build_target(self.img_to_anns.get(img_id, []), img_id)

        if self.transforms:
            transformed = self.transforms(
                image=image,
                bboxes=target['boxes'].numpy(),
                class_labels=target['labels'].numpy(),
            )
            image = transformed['image']
            # Take boxes AND labels from the transform output so they stay
            # aligned when the transform moves or removes boxes.
            if 'bboxes' in transformed:
                target['boxes'] = torch.as_tensor(
                    transformed['bboxes'], dtype=torch.float32
                ).reshape(-1, 4)
            if 'class_labels' in transformed:
                target['labels'] = torch.as_tensor(
                    transformed['class_labels'], dtype=torch.int64
                )
        else:
            image = transforms.ToTensor()(image)

        return image, target

class ModelFactory:
    """Factory for the torchvision detection models used by the pipeline.

    ``num_classes`` always includes the background class.
    """

    @staticmethod
    def create_faster_rcnn(num_classes, backbone='resnet50'):
        """Create a Faster R-CNN detector.

        Args:
            num_classes: Number of output classes, background included.
            backbone: ``'resnet50'`` (pretrained FPN detector with a new box
                predictor) or ``'resnet34'`` (single-feature-map backbone).

        Raises:
            ValueError: If ``backbone`` is not one of the supported names.
        """
        if backbone == 'resnet34':
            from torchvision.models import resnet34
            from torchvision.models.detection.rpn import AnchorGenerator
            from torchvision.ops import MultiScaleRoIAlign

            trunk = resnet34(pretrained=True)
            # Drop the average-pool and fully-connected layers; keep the conv trunk.
            backbone_model = nn.Sequential(*list(trunk.children())[:-2])
            # FasterRCNN reads the feature depth from `out_channels`;
            # the last ResNet34 stage outputs 512 channels.
            backbone_model.out_channels = 512

            # The trunk yields ONE feature map, so the anchor generator and the
            # RoI pooler must be configured for a single level (the defaults
            # assume a 5-level FPN).
            anchor_generator = AnchorGenerator(
                sizes=((32, 64, 128, 256, 512),),
                aspect_ratios=((0.5, 1.0, 2.0),),
            )
            roi_pooler = MultiScaleRoIAlign(
                featmap_names=['0'], output_size=7, sampling_ratio=2
            )
            model = detection.FasterRCNN(
                backbone_model,
                num_classes=num_classes,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler,
            )
        elif backbone == 'resnet50':
            model = detection.fasterrcnn_resnet50_fpn(pretrained=True)
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        return model

    @staticmethod
    def create_yolov8_alternative(num_classes):
        """Create the stand-in for YOLOv8: a pretrained RetinaNet with a new class head.

        YOLOv8 is not part of torchvision, so RetinaNet (ResNet50-FPN) is used
        and its classification logits layer is replaced for ``num_classes``.
        """
        import math

        model = detection.retinanet_resnet50_fpn(pretrained=True)
        cls_head = model.head.classification_head
        num_anchors = cls_head.num_anchors
        cls_head.num_classes = num_classes

        new_logits = nn.Conv2d(
            cls_head.cls_logits.in_channels,
            num_anchors * num_classes,
            kernel_size=3, stride=1, padding=1,
        )
        # Same initialisation torchvision applies to this layer: small weights
        # and a bias giving every anchor a ~0.01 prior foreground probability,
        # which keeps the focal loss stable at the start of training.
        prior_probability = 0.01
        nn.init.normal_(new_logits.weight, std=0.01)
        nn.init.constant_(new_logits.bias, -math.log((1 - prior_probability) / prior_probability))
        cls_head.cls_logits = new_logits

        return model

    @staticmethod
    def create_detr_alternative(num_classes, backbone='resnet50'):
        """Create the stand-in for DETR.

        DETR is not part of torchvision, so for ``backbone='resnet50'`` an
        SSD300 detector with a VGG16 backbone is used instead, with its
        classification head rebuilt for ``num_classes``. Any other backbone
        falls back to :meth:`create_faster_rcnn`.
        """
        if backbone == 'resnet50':
            from torchvision.models.detection.ssd import SSDClassificationHead

            model = detection.ssd300_vgg16(pretrained=True)
            # Setting `num_classes` on the existing head does not resize its
            # convolutions, so build a new head with the same input channels
            # and anchor counts as the pretrained one.
            in_channels = [conv.in_channels for conv in model.head.classification_head.module_list]
            num_anchors = model.anchor_generator.num_anchors_per_location()
            model.head.classification_head = SSDClassificationHead(
                in_channels, num_anchors, num_classes
            )
        else:
            # Fallback to Faster R-CNN for other backbones
            model = ModelFactory.create_faster_rcnn(num_classes, backbone)

        return model

class MultiModelTrainer:
    """Train a fixed set of detection models on one COCO-style dataset."""

    def __init__(self, dataset_path, num_classes=None, device='cuda'):
        """Set up the device, the class count and the augmentation pipelines.

        Args:
            dataset_path: Root of the dataset (``<root>/<split>/...`` layout).
            num_classes: Number of classes including background; inferred
                from the annotation files when ``None``.
            device: Requested device; ``'cpu'`` is used whenever CUDA is not
                available.
        """
        self.dataset_path = dataset_path
        self.device = device if torch.cuda.is_available() else 'cpu'

        if num_classes is None:
            self.num_classes = self._infer_num_classes()
        else:
            self.num_classes = num_classes

        logger.info(f"Training with {self.num_classes} classes (including background)")

        # Boxes are passed in pascal_voc (x1, y1, x2, y2) format with their
        # labels in the `class_labels` field.
        self.train_transforms = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

        self.val_transforms = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

    def _infer_num_classes(self):
        """Infer the class count (largest category id + 1) from the dataset.

        Splits are checked in train/val/test order. The first split with a
        non-empty ``categories`` list decides the result; splits without one
        contribute the largest ``category_id`` found in their annotations.
        """
        max_class_id = 0

        for split in ['train', 'val', 'test']:
            ann_path = Path(self.dataset_path) / split / f'annotations_{split}.json'
            if not ann_path.exists():
                continue

            with open(ann_path, 'r') as f:
                ann_data = json.load(f)

            if ann_data.get('categories'):
                for cat in ann_data['categories']:
                    max_class_id = max(max_class_id, cat['id'])
                break

            # Fallback: no category list, so look at the annotations themselves
            for ann in ann_data.get('annotations', []):
                max_class_id = max(max_class_id, ann['category_id'])

        # Add 1 for background class
        return max_class_id + 1

    @staticmethod
    def _collate_fn(batch):
        """Turn a list of ``(image, target)`` pairs into ``(images, targets)`` tuples.

        Defined on the class rather than as a local closure so DataLoader
        worker processes can pickle it when the ``spawn`` start method is used.
        """
        return tuple(zip(*batch))

    def create_data_loaders(self, batch_size=4):
        """Create the ``(train_loader, val_loader)`` pair for the dataset."""
        train_dataset = COCODataset(self.dataset_path, 'train', self.train_transforms)
        val_dataset = COCODataset(self.dataset_path, 'val', self.val_transforms)

        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            collate_fn=self._collate_fn, num_workers=2
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False,
            collate_fn=self._collate_fn, num_workers=2
        )

        return train_loader, val_loader

    def _batch_loss(self, model, images, targets):
        """Move one batch to the device and return the summed model losses."""
        images = [img.to(self.device) for img in images]
        targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        return sum(loss for loss in loss_dict.values())

    def train_model(self, model, train_loader, val_loader, epochs=10, lr=0.001,
                    checkpoint_prefix='best_model'):
        """Train a single model and return it.

        Args:
            model: Detection model that returns a dict of losses when called
                with ``(images, targets)`` in training mode.
            train_loader: Loader yielding ``(images, targets)`` batches.
            val_loader: Loader used to compute the validation loss.
            epochs: Number of epochs.
            lr: Adam learning rate.
            checkpoint_prefix: Prefix of the best-checkpoint files, written as
                ``<prefix>_epoch_<n>.pth`` whenever the validation loss improves.

        Returns:
            The trained model, left in eval mode.
        """
        model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

        best_loss = float('inf')

        for epoch in range(epochs):
            # Training phase
            model.train()
            train_loss = 0.0

            train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
            for images, targets in train_pbar:
                optimizer.zero_grad()
                losses = self._batch_loss(model, images, targets)
                losses.backward()
                optimizer.step()

                train_loss += losses.item()
                train_pbar.set_postfix({'loss': losses.item()})

            # Validation phase. The model deliberately stays in train mode:
            # torchvision detection models return detections instead of a loss
            # dict in eval mode. no_grad() prevents any weight update.
            # NOTE(review): backbones with regular BatchNorm layers still
            # update their running statistics here - confirm this is acceptable.
            val_loss = 0.0

            with torch.no_grad():
                val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]')
                for images, targets in val_pbar:
                    losses = self._batch_loss(model, images, targets)
                    val_loss += losses.item()
                    val_pbar.set_postfix({'loss': losses.item()})

            # max(..., 1) guards against an empty loader
            avg_train_loss = train_loss / max(len(train_loader), 1)
            avg_val_loss = val_loss / max(len(val_loader), 1)

            logger.info(f'Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

            # Save best model
            if avg_val_loss < best_loss:
                best_loss = avg_val_loss
                torch.save(model.state_dict(), f'{checkpoint_prefix}_epoch_{epoch+1}.pth')

            scheduler.step()

        model.eval()
        return model

    def train_all_models(self, epochs=10, batch_size=4):
        """Train every configured model and return the ones that succeeded.

        A failure in one configuration is logged with its traceback and does
        not stop the remaining configurations.

        Returns:
            Dict mapping ``'<model_type>_<backbone>'`` to the trained model.
        """
        train_loader, val_loader = self.create_data_loaders(batch_size)

        # Model configurations
        configs = [
            ('faster_rcnn', 'resnet34'),
            ('faster_rcnn', 'resnet50'),
            ('yolov8_alt', 'resnet50'),  # Using RetinaNet as alternative
            ('detr_alt', 'resnet50'),    # Using SSD as alternative
        ]

        results = {}

        for model_type, backbone in configs:
            logger.info(f"Training {model_type} with {backbone} backbone...")
            model_name = f'{model_type}_{backbone}'

            try:
                if model_type == 'faster_rcnn':
                    model = ModelFactory.create_faster_rcnn(self.num_classes, backbone)
                elif model_type == 'yolov8_alt':
                    model = ModelFactory.create_yolov8_alternative(self.num_classes)
                elif model_type == 'detr_alt':
                    model = ModelFactory.create_detr_alternative(self.num_classes, backbone)
                else:
                    raise ValueError(f"Unknown model type: {model_type}")

                # Per-model prefix so best checkpoints of different models
                # do not overwrite each other.
                trained_model = self.train_model(
                    model, train_loader, val_loader, epochs,
                    checkpoint_prefix=f'{model_name}_best',
                )

                # Save final model
                torch.save(trained_model.state_dict(), f'{model_name}_final.pth')
                results[model_name] = trained_model

                logger.info(f"Completed training {model_name}")

            except Exception as e:
                # Best-effort: keep training the remaining configurations.
                logger.exception(f"Error training {model_type} with {backbone}: {str(e)}")
                continue

        return results

def main():
    """Run the whole pipeline: merge the two datasets, then train every model.

    Returns:
        A ``(results, num_classes, categories)`` tuple: the trained models
        keyed by name, the class count reported by the merger, and the merged
        category list.
    """
    # Pipeline settings - update the dataset paths for your environment
    settings = {
        'dataset1': '/Volumes/Cucumber/Projects/datasets/sku110k',
        'dataset2': '/Volumes/Cucumber/Projects/datasets/hcmus-iid',
        'output': '/Volumes/Cucumber/Projects/datasets/merged-dataset',
        'weight1': 0.8,   # share given to dataset 1
        'weight2': 0.2,   # share given to dataset 2
        'epochs': 5,      # kept low for Kaggle time limits
        'batch_size': 4,  # kept low for memory constraints
    }

    # Stage 1: build the merged dataset and learn how many classes it has
    logger.info("Step 1: Merging datasets...")
    merger = COCODatasetMerger(
        settings['dataset1'],
        settings['dataset2'],
        settings['output'],
        settings['weight1'],
        settings['weight2'],
    )
    num_classes, categories = merger.merge_datasets()
    category_names = [cat['name'] for cat in categories]

    logger.info("Dataset merged successfully!")
    logger.info(f"Number of classes: {num_classes}")
    logger.info(f"Categories: {category_names}")

    # Stage 2: train all model configurations on the merged dataset
    logger.info("Step 2: Training models...")
    trainer = MultiModelTrainer(settings['output'], num_classes)
    results = trainer.train_all_models(
        epochs=settings['epochs'], batch_size=settings['batch_size']
    )

    # Stage 3: report what was trained
    logger.info("Training completed!")
    logger.info(f"Trained models: {list(results.keys())}")
    logger.info(f"Final number of classes used: {num_classes}")

    return results, num_classes, categories

# Example usage for Kaggle
if __name__ == "__main__":
    # Install required packages (uncomment if needed)
    # !pip install albumentations

    # Seed Python, NumPy and PyTorch so repeated runs are reproducible
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # Execute the merge-and-train pipeline
    results, num_classes, categories = main()

    print("Pipeline completed successfully!")
    print(f"Number of classes: {num_classes}")
    print(f"Categories found: {[cat['name'] for cat in categories]}")
    print(f"Available models: {list(results) if results else 'None'}")


In [None]:

# Additional utility functions for evaluation and inference

def evaluate_model(model, test_loader, device='cuda'):
    """Run a trained model over a loader and collect its predictions.

    Args:
        model: Detection model; it is moved to ``device`` and put in eval mode.
        test_loader: Loader yielding ``(images, targets)`` batches; the
            targets are not used.
        device: Device to run on. Falls back to ``'cpu'`` when CUDA is
            requested but not available.

    Returns:
        A list with one prediction per image. Tensors inside the predictions
        are moved to the CPU so results do not pile up in GPU memory.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    def _to_cpu(output):
        # Detection models return a dict of tensors per image; anything else
        # is passed through (tensors moved, other objects untouched).
        if isinstance(output, dict):
            return {k: (v.cpu() if torch.is_tensor(v) else v) for k, v in output.items()}
        return output.cpu() if torch.is_tensor(output) else output

    # Model and inputs must live on the same device.
    model.to(device)
    model.eval()
    predictions = []

    with torch.no_grad():
        for images, _targets in tqdm(test_loader, desc='Evaluating'):
            batch = [img.to(device) for img in images]
            outputs = model(batch)
            predictions.extend(_to_cpu(out) for out in outputs)

    return predictions

def visualize_predictions(model, dataset, device='cuda', num_samples=5):
    """Plot the first samples of a dataset with the model's predicted boxes.

    Boxes with a score above 0.5 are drawn in red. The figure is saved to
    ``predictions_visualization.png`` and shown.

    Args:
        model: Detection model; it is moved to ``device`` and put in eval mode.
        dataset: Indexable dataset returning ``(image, target)`` with the
            image as a channel-first tensor.
        device: Device to run on. Falls back to ``'cpu'`` when CUDA is
            requested but not available.
        num_samples: Number of samples to plot, capped at ``len(dataset)``.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        logger.warning("No samples available to visualize")
        return

    # Model and inputs must live on the same device.
    model.to(device)
    model.eval()

    # squeeze=False keeps `axes` a 2-D array even for a single sample, so
    # indexing works the same for num_samples == 1.
    fig, axes = plt.subplots(1, num_samples, figsize=(20, 4), squeeze=False)
    axes = axes[0]

    # NOTE(review): assumes the dataset applied ImageNet mean/std
    # normalisation (as the val transforms in this file do) - confirm for
    # datasets built without transforms.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    score_threshold = 0.5  # Confidence threshold

    for i, ax in enumerate(axes):
        image, target = dataset[i]

        with torch.no_grad():
            prediction = model([image.to(device)])

        # Undo the normalisation and move channels last for matplotlib
        img_np = image.permute(1, 2, 0).cpu().numpy()
        img_np = np.clip(img_np * std + mean, 0, 1)

        ax.imshow(img_np)
        ax.set_title(f'Sample {i+1}')
        ax.axis('off')

        boxes = prediction[0]['boxes'].cpu().numpy()
        scores = prediction[0]['scores'].cpu().numpy()

        for box, score in zip(boxes, scores):
            if score > score_threshold:
                x1, y1, x2, y2 = box
                ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1,
                                           fill=False, color='red', linewidth=2))

    plt.tight_layout()
    plt.savefig('predictions_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()

# Configuration helper
def create_kaggle_config():
    """Return the default pipeline settings for a Kaggle notebook.

    The ``device`` entry is ``'cuda'`` when a GPU is available and ``'cpu'``
    otherwise; every other entry is a fixed default.
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    return dict(
        dataset1_path='/kaggle/input/dataset1',
        dataset2_path='/kaggle/input/dataset2',
        output_path='/kaggle/working/merged_dataset',
        weight1=0.6,
        weight2=0.4,
        epochs=5,
        batch_size=2,
        learning_rate=0.001,
        device=device,
    )

def get_dataset_info(dataset_path):
    """Summarise a COCO-style dataset stored as ``<root>/<split>/annotations_<split>.json``.

    Returns a dict with per-split image/annotation/category counts under
    ``'splits'``, the totals across splits, and the category list taken from
    the first split that has one (``'num_classes'`` is its length plus one
    for the background class). Splits without an annotation file are skipped.
    """
    root = Path(dataset_path)
    summary = dict(
        num_classes=0,
        categories=[],
        splits={},
        total_images=0,
        total_annotations=0,
    )

    for split_name in ('train', 'val', 'test'):
        ann_file = root / split_name / f'annotations_{split_name}.json'
        if not ann_file.exists():
            continue

        with open(ann_file, 'r') as handle:
            payload = json.load(handle)

        n_images = len(payload.get('images', []))
        n_annotations = len(payload.get('annotations', []))
        n_categories = len(payload.get('categories', []))

        summary['splits'][split_name] = {
            'images': n_images,
            'annotations': n_annotations,
            'categories': n_categories,
        }
        summary['total_images'] += n_images
        summary['total_annotations'] += n_annotations

        # Categories come from the first split that provides any
        if summary['categories'] or not payload.get('categories'):
            continue
        summary['categories'] = payload['categories']
        summary['num_classes'] = len(payload['categories']) + 1  # +1 for background

    return summary