In [3]:
import os
import tensorflow as tf
import tensorflow_datasets as tfds
import pandas as pd
import numpy as np
import shutil
import xml.etree.ElementTree as ET
from PIL import Image
import csv

2025-03-12 12:02:53.592938: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741777373.607581   25295 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741777373.611631   25295 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1741777373.625489   25295 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1741777373.625511   25295 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1741777373.625513   25295 computation_placer.cc:177] computation placer alr

In [4]:
def download_voc_dataset(output_dir, dataset_name="voc", dataset_version="2007"):
    """
    Downloads the Pascal VOC object detection dataset using TensorFlow Datasets,
    organizes it into train/val/test/all folders, and creates a CSV file with labels.
    Uses original image filenames instead of renaming them.
    
    If the dataset is already downloaded, it will skip the download step.
    Also generates VOC-format XML annotation files.
    
    Args:
        output_dir (str): Directory to save the dataset
        dataset_name (str): Dataset name ('voc')
        dataset_version (str): Dataset version ('2007' or '2012')
    
    Returns:
        dict: Paths to the organized dataset directories and label CSV
    """
    import xml.etree.ElementTree as ET
    from xml.dom import minidom
    import time
    
    # Create output directory structure
    os.makedirs(output_dir, exist_ok=True)
    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")
    test_dir = os.path.join(output_dir, "test")
    all_dir = os.path.join(output_dir, "all")
    
    for directory in [train_dir, val_dir, test_dir, all_dir]:
        os.makedirs(directory, exist_ok=True)
        # Create subdirectories for images and annotations
        os.makedirs(os.path.join(directory, "images"), exist_ok=True)
        os.makedirs(os.path.join(directory, "annotations"), exist_ok=True)
    
    # Define the dataset name with version
    full_dataset_name = f"{dataset_name}/{dataset_version}"
    
    # Set a custom download directory to avoid permission issues
    temp_data_dir = os.path.join(os.path.expanduser("~"), "tfds_temp")
    os.makedirs(temp_data_dir, exist_ok=True)
    
    # Check if the dataset is already downloaded
    builder = tfds.builder(full_dataset_name, data_dir=temp_data_dir)
    
    try:
        # Check if dataset info exists, which indicates the dataset is downloaded
        if builder.info_path.exists():
            print(f"Dataset {full_dataset_name} already exists, skipping download...")
            download = False
        else:
            print(f"Downloading {full_dataset_name} dataset...")
            download = True
    except Exception as e:
        # If any error occurs during checking, attempt to download
        print(f"Checking dataset status failed ({str(e)}). Attempting to download {full_dataset_name}...")
        download = True
    
    # Download and prepare the dataset
    try:
        dataset, info = tfds.load(
            name=full_dataset_name,
            with_info=True,
            split=['train', 'validation', 'test'],
            download=download,
            data_dir=temp_data_dir
        )
        print(f"Dataset loaded successfully")
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        print("Please check your internet connection and try again.")
        return None
    
    # Create a list to store label information for CSV
    label_data = []
    
    # VOC class names
    class_names = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]
    
    # Function to create VOC XML annotation file
    def create_voc_xml(filename, width, height, objects_list, folder, output_path):
        # Create the root element
        root = ET.Element("annotation")
        
        # Add basic elements
        ET.SubElement(root, "folder").text = folder
        ET.SubElement(root, "filename").text = filename
        ET.SubElement(root, "path").text = filename  # Add path element
        
        # Add size information
        size = ET.SubElement(root, "size")
        ET.SubElement(size, "width").text = str(width)
        ET.SubElement(size, "height").text = str(height)
        ET.SubElement(size, "depth").text = "3"  # Assuming RGB
        
        # Add segmented (always 0 for object detection)
        ET.SubElement(root, "segmented").text = "0"
        
        # Add object elements
        for obj in objects_list:
            object_elem = ET.SubElement(root, "object")
            ET.SubElement(object_elem, "name").text = obj['class_name']
            ET.SubElement(object_elem, "pose").text = "Unspecified"
            ET.SubElement(object_elem, "truncated").text = "0"
            ET.SubElement(object_elem, "difficult").text = "0"
            
            bbox = ET.SubElement(object_elem, "bndbox")
            ET.SubElement(bbox, "xmin").text = str(obj['xmin'])
            ET.SubElement(bbox, "ymin").text = str(obj['ymin'])
            ET.SubElement(bbox, "xmax").text = str(obj['xmax'])
            ET.SubElement(bbox, "ymax").text = str(obj['ymax'])
        
        # Create pretty XML string - fix the bytes to string conversion
        xml_str = minidom.parseString(ET.tostring(root).decode('utf-8')).toprettyxml(indent="  ")
        
        # Save to file
        try:
            with open(output_path, "w") as f:
                f.write(xml_str)
            return True
        except Exception as e:
            print(f"Error writing XML file {output_path}: {str(e)}")
            return False
    
    # Process each split
    split_dirs = {
        0: train_dir,  # train
        1: val_dir,    # validation
        2: test_dir    # test
    }
    
    for split_idx, split_dataset in enumerate(dataset):
        split_dir = split_dirs[split_idx]
        split_name = ['train', 'validation', 'test'][split_idx]
        print(f"Processing {split_name} split...")
        
        # Check if this split has already been processed
        split_images_dir = os.path.join(split_dir, "images")
        split_annot_dir = os.path.join(split_dir, "annotations")
        existing_images = len(os.listdir(split_images_dir)) if os.path.exists(split_images_dir) else 0
        existing_annots = len(os.listdir(split_annot_dir)) if os.path.exists(split_annot_dir) else 0
        
        if existing_images > 0 and existing_annots > 0:
            print(f"Found {existing_images} existing images and {existing_annots} annotations in {split_name} split. Checking if complete...")
            
            # Quick check: count approximate number of examples in this split
            sample_count = 0
            for _ in split_dataset:
                sample_count += 1
                if sample_count > 100:  # Just check the first 100 to save time
                    break
            
            if existing_images >= sample_count * 0.9 and existing_annots >= sample_count * 0.9:
                print(f"Split {split_name} appears to be already processed. Skipping...")
                continue
        
        # Materialize the dataset into a list to avoid streaming it twice
        try:
            print(f"Loading {split_name} split data into memory...")
            split_data = list(split_dataset)
            print(f"Loaded {len(split_data)} examples for {split_name} split")
        except Exception as e:
            print(f"Error materializing dataset: {str(e)}")
            print("Trying to process streaming instead...")
            split_data = split_dataset
        
        # Dictionary to keep track of objects for each image, to create XML annotations
        image_objects = {}
        
        # First pass: collect all objects for each image
        print(f"Collecting object information for {split_name} split...")
        start_time = time.time()
        for i, example in enumerate(split_data):
            if i % 100 == 0:
                elapsed = time.time() - start_time
                print(f"  Scanning objects for image {i} in {split_name} split ({elapsed:.2f} seconds elapsed)")
            
            try:
                # Get original filename directly
                original_filename = example['image/filename'].numpy().decode('utf-8')
                image = example['image'].numpy()
                objects = example['objects']
                
                # Initialize objects list for this image if not already done
                if original_filename not in image_objects:
                    image_objects[original_filename] = {
                        'width': image.shape[1],
                        'height': image.shape[0],
                        'objects': []
                    }
                
                # Process bounding boxes and labels
                if 'bbox' in objects:
                    bboxes = objects['bbox'].numpy()  # [ymin, xmin, ymax, xmax] format
                    labels = objects['label'].numpy()
                    
                    # Add each object to the list
                    for j, (bbox, label_idx) in enumerate(zip(bboxes, labels)):
                        ymin, xmin, ymax, xmax = bbox
                        
                        # Ensure label_idx is valid
                        if label_idx < 0 or label_idx >= len(class_names):
                            print(f"  Warning: Invalid label index {label_idx} for image {original_filename}")
                            continue
                            
                        class_name = class_names[label_idx]
                        
                        # Convert normalized coordinates to pixel coordinates
                        xmin_px = max(0, int(xmin * image.shape[1]))
                        ymin_px = max(0, int(ymin * image.shape[0]))
                        xmax_px = min(image.shape[1], int(xmax * image.shape[1]))
                        ymax_px = min(image.shape[0], int(ymax * image.shape[0]))
                        
                        # Skip invalid boxes
                        if xmin_px >= xmax_px or ymin_px >= ymax_px:
                            print(f"  Warning: Invalid box dimensions for {original_filename}, object {j}")
                            continue
                        
                        # Add to objects list for this image
                        image_objects[original_filename]['objects'].append({
                            'class_name': class_name,
                            'class_id': int(label_idx),
                            'xmin': xmin_px,
                            'ymin': ymin_px,
                            'xmax': xmax_px,
                            'ymax': ymax_px
                        })
                        
                        # Add to label data for CSV
                        label_data.append({
                            'image_filename': original_filename,
                            'split': split_name,
                            'class_name': class_name,
                            'class_id': int(label_idx),
                            'xmin': xmin_px,
                            'ymin': ymin_px,
                            'xmax': xmax_px,
                            'ymax': ymax_px,
                            'width': image.shape[1],
                            'height': image.shape[0]
                        })
            except Exception as e:
                print(f"  Error processing example {i} in {split_name} split: {str(e)}")
                continue
        
        # Second pass: process images and create annotations
        processed_images = set()  # Initialize set to keep track of processed images
        print(f"Saving images and creating annotations for {split_name} split...")
        start_time = time.time()
        
        for i, example in enumerate(split_data):
            if i % 100 == 0:
                elapsed = time.time() - start_time
                print(f"  Processing image {i} in {split_name} split ({elapsed:.2f} seconds elapsed)")
            
            try:
                # Extract data
                image = example['image'].numpy()
                original_filename = example['image/filename'].numpy().decode('utf-8')
                
                # Skip if already processed this image
                if original_filename in processed_images:
                    continue
                
                # Mark as processed
                processed_images.add(original_filename)
                
                # Save image to both split directory and all directory
                img_path = os.path.join(split_dir, "images", original_filename)
                all_img_path = os.path.join(all_dir, "images", original_filename)
                
                if not os.path.exists(img_path) or not os.path.exists(all_img_path):
                    try:
                        img = Image.fromarray(image)
                        img.save(img_path)
                        img.save(all_img_path)
                    except Exception as e:
                        print(f"  Error saving image {original_filename}: {str(e)}")
                
                # Create annotation XML files
                if original_filename in image_objects:
                    # Create filename for annotation (change extension to .xml)
                    xml_filename = os.path.splitext(original_filename)[0] + '.xml'
                    
                    # Paths for both split dir and all dir
                    xml_path = os.path.join(split_dir, "annotations", xml_filename)
                    all_xml_path = os.path.join(all_dir, "annotations", xml_filename)
                    
                    # Only create if doesn't exist
                    if not os.path.exists(xml_path):
                        create_voc_xml(
                            filename=original_filename,
                            width=image_objects[original_filename]['width'],
                            height=image_objects[original_filename]['height'],
                            objects_list=image_objects[original_filename]['objects'],
                            folder=split_name,
                            output_path=xml_path
                        )
                    
                    if not os.path.exists(all_xml_path):
                        create_voc_xml(
                            filename=original_filename,
                            width=image_objects[original_filename]['width'],
                            height=image_objects[original_filename]['height'],
                            objects_list=image_objects[original_filename]['objects'],
                            folder="all",
                            output_path=all_xml_path
                        )
            except Exception as e:
                print(f"  Error in second pass for image {i} in {split_name} split: {str(e)}")
                continue
    
    # Check if label CSV already exists
    csv_path = os.path.join(output_dir, "voc_labels.csv")
    if os.path.exists(csv_path) and len(label_data) > 0:
        print(f"Label CSV file already exists at: {csv_path}")
        # Optionally, you could append new data or update existing data
        try:
            existing_df = pd.read_csv(csv_path)
            # Combine existing data with new data
            combined_df = pd.concat([existing_df, pd.DataFrame(label_data)]).drop_duplicates()
            combined_df.to_csv(csv_path, index=False)
            print(f"Updated existing CSV with new data.")
        except Exception as e:
            print(f"Error updating existing CSV: {str(e)}")
            print(f"Creating new CSV file...")
            pd.DataFrame(label_data).to_csv(csv_path, index=False)
    elif len(label_data) > 0:
        # Create new CSV file with all labels
        pd.DataFrame(label_data).to_csv(csv_path, index=False)
        print(f"Label CSV file created at: {csv_path}")
    else:
        print("No label data collected. CSV file not created.")
    
    return {
        'train_dir': train_dir,
        'val_dir': val_dir,
        'test_dir': test_dir,
        'all_dir': all_dir,
        'labels_csv': csv_path if os.path.exists(csv_path) else None
    }

In [9]:
#output_paths = download_voc_dataset(output_dir="Data/voc2007", dataset_version="2007")
print(f"VOC dataset saved to: {output_paths}")
print("\nDataset organization complete!")
print(f"Train directory: {output_paths['train_dir']}")
print(f"Validation directory: {output_paths['val_dir']}")
print(f"Test directory: {output_paths['test_dir']}")
print(f"All data directory: {output_paths['all_dir']}")
print(f"Labels CSV file: {output_paths['labels_csv']}")

# Print CSV stats
df = pd.read_csv(output_paths['labels_csv'])
print(f"\nTotal objects: {len(df)}")
print(f"Total images: {df['image_filename'].nunique()}")
print("\nClass distribution:")
print(df['class_name'].value_counts())
print('num classes',len(df['class_name'].unique()))

VOC dataset saved to: {'train_dir': 'Data/voc2007/train', 'val_dir': 'Data/voc2007/val', 'test_dir': 'Data/voc2007/test', 'all_dir': 'Data/voc2007/all', 'labels_csv': 'Data/voc2007/voc_labels.csv'}

Dataset organization complete!
Train directory: Data/voc2007/train
Validation directory: Data/voc2007/val
Test directory: Data/voc2007/test
All data directory: Data/voc2007/all
Labels CSV file: Data/voc2007/voc_labels.csv

Total objects: 30638
Total images: 9963

Class distribution:
class_name
person         10674
car             3185
chair           2806
bottle          1291
pottedplant     1217
bird            1175
dog             1068
sofa             821
bicycle          807
horse            801
boat             791
cat              759
motorbike        759
tvmonitor        728
cow              685
sheep            664
aeroplane        642
train            630
diningtable      609
bus              526
Name: count, dtype: int64
num classes 20
