# 0. Import libraries

In [1]:
import numpy as np
import os
import cv2
from collections import defaultdict

# 1. Define the visualization effects

In [2]:
# Palette of bounding-box colors (10 entries; lookups wrap around with a modulo).
# NOTE: tuples are in OpenCV's BGR channel order, not RGB, so the color names
# below describe what is actually drawn on an image loaded with cv2.imread.
bbox_colors = [
    (255, 0, 0),    # Class 0: Blue
    (0, 255, 0),    # Class 1: Green
    (0, 0, 255),    # Class 2: Red
    (255, 255, 0),  # Class 3: Cyan
    (255, 0, 255),  # Class 4: Magenta
    (0, 255, 255),  # Class 5: Yellow
    (128, 0, 128),  # Class 6: Purple
    (128, 128, 0),  # Class 7: Teal (B=128, G=128, R=0; olive would be (0, 128, 128) in BGR)
    (128, 128, 128),# Class 8: Gray
    (0, 128, 255)   # Class 9: Orange
]


# 2. Group the labels and images by basenames

In [3]:
def find_files_with_same_basename(directory):
    """Group files under ``directory`` by basename and report shared basenames.

    The tree is walked recursively with ``os.walk``; the containing folder of
    each file is ignored, so files with the same basename in different
    sub-folders are merged into one group.

    Args:
        directory: Root folder to scan.

    Returns:
        A list of ``{"basename": str, "extensions": list[str]}`` dicts, one per
        basename that occurs with more than one distinct extension. Extensions
        are stored without the leading dot. Both the groups (by basename) and
        each extension list are sorted so the output is deterministic across
        runs; an empty or missing directory yields ``[]``.
    """
    # basename -> set of extensions seen for that basename
    extensions_by_basename = defaultdict(set)

    for _, _, files in os.walk(directory):
        for file in files:
            basename, ext = os.path.splitext(file)
            # Drop the dot that splitext keeps in front of the extension.
            extensions_by_basename[basename].add(ext.lstrip("."))

    # Sets and directory listings have no stable order; sort so that callers
    # (and anything that picks "the first matching extension") behave the
    # same way on every run.
    return [
        {"basename": basename, "extensions": sorted(extensions)}
        for basename, extensions in sorted(extensions_by_basename.items())
        if len(extensions) > 1
    ]

# 3. Draw with ground-truth

In [4]:
def draw_box_on_image(image, ground_truth: str, index: int):
    """Draw one bounding box, given as a normalized label line, onto ``image``.

    Args:
        image: Image array as returned by ``cv2.imread``; it is drawn on in
            place and also returned.
        ground_truth: One label line of the form ``"class cx cy w h"`` where
            ``cx, cy`` is the box center and ``w, h`` its size, all expressed
            as fractions of the image width/height. Fields may be separated by
            any run of whitespace; fields after the fifth are ignored.
        index: Selects the box color as ``bbox_colors[index % len(bbox_colors)]``.
            Note that the class id parsed from the line is NOT used for the
            color.

    Returns:
        The same ``image`` object with the rectangle drawn on it.

    Raises:
        ValueError: If the line has fewer than five fields or a field cannot
            be parsed as a number.
    """
    # str -> class, cx, cy, w, h  (split() tolerates tabs and repeated spaces)
    fields = ground_truth.split()
    if len(fields) < 5:
        raise ValueError(
            f"expected 'class cx cy w h' (5 fields), got {len(fields)}: {ground_truth!r}"
        )

    # The class id is parsed only to validate the line; it is otherwise unused.
    int(fields[0])
    cx, cy, w, h = (float(value) for value in fields[1:5])
    width, height = image.shape[1], image.shape[0]

    # Convert the normalized center/size box to pixel corner coordinates.
    lx = int((cx - w / 2) * width)
    ly = int((cy - h / 2) * height)
    rx = int((cx + w / 2) * width)
    ry = int((cy + h / 2) * height)

    # Select the color from the caller-supplied index, wrapping around the palette.
    box_color = bbox_colors[index % len(bbox_colors)]

    # Draw the bounding box with a 2-pixel outline.
    cv2.rectangle(image, (lx, ly), (rx, ry), box_color, 2)

    return image

# 4. Search the image by basename

In [5]:
def find_image_file(extensions, src_directory, basename):
    """Return the path of the first existing image file for ``basename``.

    Args:
        extensions: Iterable of extensions (without the leading dot) that were
            seen for this basename.
        src_directory: Folder in which the image is expected to live.
        basename: File name without extension.

    Returns:
        ``os.path.join(src_directory, f"{basename}.{ext}")`` for the first
        extension, in iteration order, that is a supported image type and whose
        file exists; ``None`` if there is no such file.
    """
    # Image formats this tool is assumed to support.
    supported_image_extensions = ("jpg", "jpeg", "png", "bmp")

    for ext in extensions:
        # Compare case-insensitively so "IMG.JPG" is recognised, but build the
        # path with the original spelling so it matches the file on disk.
        if ext.lower() in supported_image_extensions:
            image_file_path = os.path.join(src_directory, f"{basename}.{ext}")
            if os.path.exists(image_file_path):
                return image_file_path
    return None

# 5. Main

In [None]:
# Ask the user for the source folder (images + labels) and the output folder.
src_directory = input("请输入文件夹的路径: ")
dst_directory = input("请输入目标文件夹的路径: ")

# Only proceed when the source is an existing directory; a regular file or a
# missing path is reported instead of being silently scanned.
if os.path.isdir(src_directory):
    # Create the destination lazily so an invalid source leaves no stray folder.
    os.makedirs(dst_directory, exist_ok=True)

    # Group files with the same basename (e.g. "0001.jpg" + "0001.txt").
    file_groups = find_files_with_same_basename(src_directory)

    # Draw the bounding boxes on the images, one group at a time.
    for file_group in file_groups:
        basename = file_group["basename"]

        # Locate and load the image file of this group.
        image_file_path = find_image_file(file_group["extensions"], src_directory, basename)
        if not image_file_path:
            print(f"没有找到对应的图片文件: {basename}")
            continue

        image = cv2.imread(image_file_path)
        if image is None:
            print(f"无法加载图片: {image_file_path}")
            continue

        # Read the label file and draw one box per non-empty line.
        txt_file_path = os.path.join(src_directory, f"{basename}.txt")
        if not os.path.exists(txt_file_path):
            print(f"没有找到对应的标签文件: {txt_file_path}")
            continue

        with open(txt_file_path, "r") as file:
            try:
                for idx, line in enumerate(file):
                    label = line.strip()
                    # Blank lines (typically a trailing newline) carry no box;
                    # skip them instead of rejecting the whole image.
                    if not label:
                        continue
                    image = draw_box_on_image(image, label, idx)
            except Exception as e:
                # Best effort: report the malformed label file and move on
                # without saving a partially annotated image.
                print(f"数据格式无效: {txt_file_path} - {e}")
                continue

        # Save the image with bounding boxes; imwrite returns False on failure.
        output_image_path = os.path.join(dst_directory, f"{basename}_bbox.jpg")
        if not cv2.imwrite(output_image_path, image):
            print(f"无法保存图片: {output_image_path}")

else:
    print("指定的文件夹不存在")