# In [None]:
import collections
import os
import random
import xml.etree.ElementTree as ET

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# -------------------------------------------------------------------
# 1) Matplotlib settings: render figures at high resolution
# -------------------------------------------------------------------
matplotlib.rcParams.update({
    'figure.dpi': 300,
    'font.size': 12,
})

# -------------------------------------------------------------------
# 2) Paths
# -------------------------------------------------------------------
# Root of the extracted dataset -- change this if it lives elsewhere.
dataset_path = "/content/dataset/Indoor Object Detection Dataset"
# XML annotations are read from, and augmented XMLs written to, this folder.
annotation_folder = os.path.join(dataset_path, "annotation")

# -------------------------------------------------------------------
# 3) Parse XML Annotations
# -------------------------------------------------------------------
def parse_annotations(xml_files, annotation_dir=None):
    """
    Parse XML annotation files into one dictionary plus per-class box counts.

    Each file may contain any number of ``<image file="...">`` elements
    (searched at any depth). Each image holds ``<box left= top= width= height=>``
    children with an optional ``<label>`` element. The result looks like::

        annotations[img_name] = {
            "boxes": [[xmin, ymin, xmax, ymax], ...],
            "labels": [label1, label2, ...]
        }

    where ``img_name`` is the basename of the ``file`` attribute.

    NOTE: if the same basename appears more than once, the later entry
    replaces the earlier one in ``annotations``. ``class_counts`` still
    counts the boxes of both.

    Args:
        xml_files: File names resolved against ``annotation_dir``, or absolute
            paths. ``os.path.join`` drops the prefix for an absolute component,
            so absolute paths are used as-is. Files that do not exist are
            skipped silently.
        annotation_dir: Folder that relative names are resolved against.
            Defaults to the module-level ``annotation_folder``, which keeps
            existing callers working.

    Returns:
        ``(annotations, class_counts)``. ``class_counts`` is a
        ``collections.Counter`` mapping each label to its number of boxes.
    """
    if annotation_dir is None:
        annotation_dir = annotation_folder

    annotations = {}
    class_counts = collections.Counter()

    for xml_file in xml_files:
        full_path = os.path.join(annotation_dir, xml_file)
        if not os.path.isfile(full_path):
            continue

        root = ET.parse(full_path).getroot()

        for image_el in root.findall(".//image"):
            file_attr = image_el.get("file")
            if not file_attr:
                # Without a file name the boxes cannot be tied to an image
                # (and os.path.basename(None) would raise).
                continue
            img_name = os.path.basename(file_attr)
            boxes = []
            labels = []

            for box_el in image_el.findall("box"):
                # Go through float() so values written as e.g. "12.0" parse too;
                # a plain int("12.0") raises ValueError.
                xmin = int(float(box_el.get("left")))
                ymin = int(float(box_el.get("top")))
                w = int(float(box_el.get("width")))
                h = int(float(box_el.get("height")))

                label_el = box_el.find("label")
                # An empty <label/> has text None; treat it like a missing label.
                label = (label_el.text or "").strip() if label_el is not None else ""
                label = label or "unknown"

                boxes.append([xmin, ymin, xmin + w, ymin + h])
                labels.append(label)
                class_counts[label] += 1

            annotations[img_name] = {
                "boxes": boxes,
                "labels": labels
            }

    return annotations, class_counts

# -------------------------------------------------------------------
# 4) Save Annotation & Update Dictionary
# -------------------------------------------------------------------
def save_annotation_for_augmented_image(
    img_filename, boxes, labels, annotation_folder, annotations_dict
):
    """
    Write an XML annotation for one augmented image and register it in memory.

    The XML uses the same layout that ``parse_annotations`` reads:
    ``<annotations><images><image file=...>`` followed by one
    ``<box left= top= width= height=><label>...</label></box>`` per box.
    It is saved as ``<image stem>.xml`` inside ``annotation_folder``.

    ``annotations_dict`` is then updated in place with copies of the boxes
    and labels. That lets later steps in the same run find the new image
    without re-parsing XML, which avoids "No annotations found..." messages.
    """
    root = ET.Element("annotations")
    image_el = ET.SubElement(
        ET.SubElement(root, "images"), "image", file=img_filename
    )

    for box, lab in zip(boxes, labels):
        left, top, right, bottom = box
        # Store corner boxes as origin + size, mirroring the input format.
        box_el = ET.SubElement(
            image_el,
            "box",
            left=str(left),
            top=str(top),
            width=str(right - left),
            height=str(bottom - top),
        )
        ET.SubElement(box_el, "label").text = lab

    stem = os.path.splitext(img_filename)[0]
    xml_path = os.path.join(annotation_folder, stem + ".xml")
    ET.ElementTree(root).write(xml_path)
    print(f"Saved annotation for augmented image: {xml_path}")

    # Fresh lists so later edits to the caller's boxes/labels do not leak in.
    annotations_dict[img_filename] = {
        "boxes": [[x0, y0, x1, y1] for x0, y0, x1, y1 in boxes],
        "labels": labels[:],
    }

# -------------------------------------------------------------------
# 5) Data Augmentation Functions
# -------------------------------------------------------------------
def rotate_image(image, boxes, angle=15):
    """
    Rotate an image about its centre and re-fit its bounding boxes.

    The output keeps the input size, so content rotated past the edges is
    cropped. ``cv2.warpAffine`` by default fills the uncovered area with a
    constant 0 border.

    Each box is transformed by rotating its four corners with the same affine
    matrix and taking their axis-aligned envelope. That envelope is somewhat
    larger than the rotated object itself.

    Args:
        image: Image array; only ``image.shape[:2]`` (height, width) is read
            for sizing.
        boxes: Iterable of ``[xmin, ymin, xmax, ymax]`` in pixels.
        angle: Rotation in degrees. Positive is counter-clockwise, following
            the ``cv2.getRotationMatrix2D`` convention.

    Returns:
        ``(rotated_image, rotated_boxes)``. The boxes are lists of Python ints,
        clipped to ``[0, w] x [0, h]`` so no coordinate falls outside the frame.
        A box rotated entirely out of frame collapses to zero width or height.
        It is still returned, so boxes stay aligned with the caller's labels.
    """
    h, w = image.shape[:2]
    center = (w // 2, h // 2)
    rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)

    rotated_image = cv2.warpAffine(image, rotation_matrix, (w, h))

    linear = rotation_matrix[:, :2]
    offset = rotation_matrix[:, 2]
    rotated_boxes = []
    for xmin, ymin, xmax, ymax in boxes:
        corners = np.array(
            [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]],
            dtype=np.float64,
        )
        moved = corners @ linear.T + offset
        # floor/ceil keep every rotated corner inside the envelope. Plain int()
        # truncates toward zero, which shrinks the box and mis-rounds negatives.
        x0, y0 = np.floor(moved.min(axis=0))
        x1, y1 = np.ceil(moved.max(axis=0))
        # The canvas is not enlarged, so clip to it to keep annotations valid.
        rotated_boxes.append([
            int(np.clip(x0, 0, w)),
            int(np.clip(y0, 0, h)),
            int(np.clip(x1, 0, w)),
            int(np.clip(y1, 0, h)),
        ])

    return rotated_image, rotated_boxes

def flip_image(image, boxes, w):
    """
    Mirror an image left-to-right and mirror its boxes to match.

    Args:
        image: Image array, flipped with ``cv2.flip`` using flipCode 1
            (around the vertical axis).
        boxes: Iterable of ``[xmin, ymin, xmax, ymax]``.
        w: Image width used to mirror x-coordinates (``x -> w - x``).

    Returns:
        ``(flipped_image, flipped_boxes)``. y-coordinates are unchanged. The
        mirrored old xmax becomes the new xmin (and vice versa), so each box
        stays ordered.
    """
    mirrored = cv2.flip(image, 1)
    mirrored_boxes = [
        [w - right, top, w - left, bottom]
        for left, top, right, bottom in boxes
    ]
    return mirrored, mirrored_boxes

def adjust_brightness(image, factor=1.3):
    """
    Scale pixel intensities by ``factor`` and saturate to the uint8 range.

    The product is computed in float32 so values above 255 are clipped
    instead of wrapping around. The result is then cast back to uint8, which
    drops any fractional part.
    """
    scaled = image.astype(np.float32) * factor
    saturated = np.clip(scaled, 0, 255)
    return saturated.astype(np.uint8)

def apply_gaussian_blur(image, ksize=(5, 5), sigma=0):
    """
    Blur an image with a Gaussian kernel.

    Args:
        image: Image array accepted by ``cv2.GaussianBlur``.
        ksize: Kernel size ``(width, height)``. OpenCV requires each value to
            be positive and odd, or zero to derive it from ``sigma``. Defaults
            to the previously hard-coded 5x5 kernel.
        sigma: Gaussian standard deviation in X; Y uses the same value.
            ``0`` (the default, as before) makes OpenCV derive it from ``ksize``.

    Returns:
        A new blurred array; ``image`` itself is left unchanged.
    """
    return cv2.GaussianBlur(image, ksize, sigma)

# -------------------------------------------------------------------
# 6) Apply Augmentations & Save
# -------------------------------------------------------------------
def augment_and_save_images(image_path, annotations_dict):
    """
    Create four augmented copies of one annotated image and save them.

    Augmentations:
      * "rotated": random choice of -15 or +15 degrees
      * "flipped": horizontal flip
      * "bright": brightness scaled with ``adjust_brightness``'s default factor
      * "blurred": Gaussian blur with ``apply_gaussian_blur``'s defaults

    Each result is written next to the source image as
    ``<stem>_<aug_type><ext>``. It also gets its own XML annotation in the
    module-level ``annotation_folder``. ``annotations_dict`` is extended in
    place with the new images.

    Args:
        image_path: Path to the source image.
        annotations_dict: Mapping ``basename -> {"boxes": [...], "labels": [...]}``,
            used to look up the source boxes.

    Returns:
        ``{aug_type: (aug_image_path, aug_boxes, labels)}`` for every
        augmentation whose image was written successfully. Returns ``{}`` when
        the image cannot be read or has no boxes.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"Error loading {image_path}")
        return {}

    img_name = os.path.basename(image_path)
    ann = annotations_dict.get(img_name, {"boxes": [], "labels": []})
    boxes, labels = ann["boxes"], ann["labels"]

    if not boxes:
        print(f"No bounding boxes for {img_name}, skipping augmentation.")
        return {}

    width = img.shape[1]

    # The helpers all return new arrays and never modify `img`, so the same
    # source array can be shared without defensive copies.
    angle_choice = random.choice([-15, 15])
    rotated_img, rotated_boxes = rotate_image(img, boxes, angle_choice)
    flipped_img, flipped_boxes = flip_image(img, boxes, width)

    augmentations = {
        "rotated": (rotated_img, rotated_boxes),
        "flipped": (flipped_img, flipped_boxes),
        "bright": (adjust_brightness(img), boxes),      # photometric: same boxes
        "blurred": (apply_gaussian_blur(img), boxes),   # photometric: same boxes
    }

    dir_name = os.path.dirname(image_path)
    base, ext = os.path.splitext(img_name)
    aug_results = {}

    for aug_type, (aug_img, aug_boxes) in augmentations.items():
        aug_img_name = f"{base}_{aug_type}{ext}"
        aug_img_path = os.path.join(dir_name, aug_img_name)

        # cv2.imwrite can report failure by returning False rather than
        # raising. Without this check an XML would point at a missing image.
        if not cv2.imwrite(aug_img_path, aug_img):
            print(f"Failed to write {aug_img_path}, skipping its annotation.")
            continue

        save_annotation_for_augmented_image(
            img_filename=aug_img_name,
            boxes=aug_boxes,
            labels=labels,
            annotation_folder=annotation_folder,
            annotations_dict=annotations_dict,
        )

        aug_results[aug_type] = (aug_img_path, aug_boxes, labels)

    return aug_results

# -------------------------------------------------------------------
# 7) Visualization (Optional)
# -------------------------------------------------------------------
def draw_bounding_boxes(image, boxes, labels):
    """
    Draw each box and its label onto ``image``, modifying it in place.

    Boxes are drawn in red and labels in green (OpenCV colour tuples are BGR).
    Coordinates are converted to ``int`` first because OpenCV's drawing
    functions reject float points. A label normally sits 5 px above its box.
    When that would push it past the top of the image (e.g. boxes touching
    the top edge), it is drawn just inside the box instead.

    Args:
        image: Image array to draw on.
        boxes: Iterable of ``[xmin, ymin, xmax, ymax]``.
        labels: One label per box; ``zip`` ignores any extras.

    Returns:
        The same ``image`` object, for convenient chaining.
    """
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
    for box, lab in zip(boxes, labels):
        xmin, ymin, xmax, ymax = (int(v) for v in box)
        text = str(lab)
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 0, 255), thickness)
        # putText's y is the text baseline; glyphs extend text_h pixels above it.
        (_, text_h), _ = cv2.getTextSize(text, font, scale, thickness)
        text_y = ymin - 5 if ymin - 5 >= text_h else ymin + text_h + 5
        cv2.putText(image, text, (xmin, text_y), font, scale,
                    (0, 255, 0), thickness, cv2.LINE_AA)
    return image

def visualize_augmented_samples(original_img_path, aug_img_path, aug_boxes,
                                aug_labels, annotations_dict=None):
    """
    Show the original image and one augmented version side by side, each
    with its bounding boxes drawn.

    Args:
        original_img_path: Path of the source image. Its boxes are looked up
            by basename in ``annotations_dict``.
        aug_img_path: Path of the augmented image.
        aug_boxes: Boxes to draw on the augmented image.
        aug_labels: Labels matching ``aug_boxes``.
        annotations_dict: Mapping to read the original boxes from. Defaults
            to the module-level ``annotations`` dictionary, as before.
    """
    orig_img = cv2.imread(original_img_path)
    aug_img = cv2.imread(aug_img_path)
    if orig_img is None or aug_img is None:
        print(f"Could not load {original_img_path} or {aug_img_path}; "
              "skipping visualization.")
        return

    if annotations_dict is None:
        # Read-only access to the module global needs no `global` statement.
        annotations_dict = annotations

    orig_name = os.path.basename(original_img_path)
    ann = annotations_dict.get(orig_name, {"boxes": [], "labels": []})

    panels = (
        ("Original", draw_bounding_boxes(orig_img.copy(), ann["boxes"], ann["labels"])),
        ("Augmented", draw_bounding_boxes(aug_img.copy(), aug_boxes, aug_labels)),
    )

    fig, axes = plt.subplots(1, 2, figsize=(8, 4), dpi=300)
    for ax, (title, drawn) in zip(axes, panels):
        # OpenCV images are BGR; matplotlib expects RGB.
        ax.imshow(cv2.cvtColor(drawn, cv2.COLOR_BGR2RGB))
        ax.set_title(title)
        ax.axis("off")

    fig.tight_layout()
    plt.show()

# -------------------------------------------------------------------
# 8) Count Images & Class Distribution
# -------------------------------------------------------------------
def count_images_and_class_distribution(dataset_path, annotation_folder):
    """
    Print a dataset summary: image count and bounding boxes per class.

    1) Counts every .jpg/.jpeg/.png file (case-insensitive) inside the
       ``sequence_*`` sub-folders of ``dataset_path``. This includes any
       augmented images written there.
    2) Parses every .xml file in ``annotation_folder`` (original and
       augmented) and tallies boxes per label. Classes are printed from most
       to least frequent.

    Returns:
        ``(total_images, class_counts)``. Callers that ignore the return
        value are unaffected.
    """
    image_exts = ('.jpg', '.png', '.jpeg')

    total_images = 0
    for seq in os.listdir(dataset_path):
        seq_path = os.path.join(dataset_path, seq)
        if seq.startswith("sequence_") and os.path.isdir(seq_path):
            total_images += sum(
                1 for f in os.listdir(seq_path) if f.lower().endswith(image_exts)
            )

    # parse_annotations joins names onto the module-level annotation folder
    # by default, which would silently ignore this function's argument.
    # Absolute paths pass through os.path.join unchanged. Sorting fixes the
    # parse order.
    xml_paths = [
        os.path.abspath(os.path.join(annotation_folder, f))
        for f in sorted(os.listdir(annotation_folder))
        if f.lower().endswith(".xml")
    ]
    _, class_counts = parse_annotations(xml_paths)

    print("\n--- Dataset Summary After Augmentation ---")
    print(f"Total images in dataset: {total_images}")
    print("Class distribution (bounding boxes):")
    for cls, count in class_counts.most_common():
        print(f"  {cls}: {count}")

    return total_images, class_counts

# -------------------------------------------------------------------
# 9) Main: Process All Sequences
# -------------------------------------------------------------------
def process_all_sequences():
    """
    Run the augmentation pipeline over every ``sequence_*`` folder.

    1) Parse the original XML files (annotation_s1.xml ... annotation_s6.xml)
       into the module-level ``annotations`` dictionary. Other functions,
       e.g. ``visualize_augmented_samples``, read it from there by default.
    2) For each ``sequence_*`` folder, randomly pick up to 3 images, augment
       them, and show one randomly chosen augmentation per picked image.
       Outputs of earlier runs are not picked.
    3) Re-parse every XML in the annotation folder and print the final image
       count and class distribution.
    """
    global annotations
    original_xmls = [f"annotation_s{i}.xml" for i in range(1, 7)]
    annotations, _ = parse_annotations(original_xmls)

    # augment_and_save_images names its outputs "<stem>_<aug_type><ext>".
    # Previous runs' outputs have no entry in the original XMLs parsed above,
    # so augmenting them would only print "No bounding boxes" and waste a pick.
    aug_suffixes = ("_rotated", "_flipped", "_bright", "_blurred")
    image_exts = ('.jpg', '.png', '.jpeg')

    for seq in sorted(os.listdir(dataset_path)):
        seq_path = os.path.join(dataset_path, seq)
        if not os.path.isdir(seq_path) or not seq.startswith("sequence_"):
            continue

        print(f"\nProcessing folder: {seq}...")
        # os.listdir order is arbitrary. Sorting makes random.sample
        # reproducible under a fixed random seed.
        image_files = sorted(
            f for f in os.listdir(seq_path)
            if f.lower().endswith(image_exts)
            and not os.path.splitext(f)[0].endswith(aug_suffixes)
        )
        if not image_files:
            print(f"No images found in {seq}")
            continue

        chosen_files = random.sample(image_files, min(3, len(image_files)))

        # (original path, augmented path, boxes, labels) for visualization.
        aug_visual_list = []
        for img_file in chosen_files:
            img_path = os.path.join(seq_path, img_file)
            aug_results = augment_and_save_images(img_path, annotations)
            if aug_results:
                chosen_aug_type = random.choice(list(aug_results))
                aug_img_path, aug_boxes, aug_labels = aug_results[chosen_aug_type]
                aug_visual_list.append((img_path, aug_img_path, aug_boxes, aug_labels))

        # At most 3 images were sampled, so this shows up to 3 comparisons.
        for orig_path, a_path, a_boxes, a_labels in aug_visual_list[:3]:
            visualize_augmented_samples(orig_path, a_path, a_boxes, a_labels)

    count_images_and_class_distribution(dataset_path, annotation_folder)

# -------------------------------------------------------------------
# 10) Run
# -------------------------------------------------------------------
# Entry point: run the full pipeline (parse, augment, visualize, summarize).
if __name__ == "__main__":
    process_all_sequences()