In [None]:
%load_ext autoreload  
%autoreload 2

import warnings  
# Suppress every warning category for the rest of the session so library
# warnings do not clutter the notebook output. This also hides deprecation
# notices, so disable it when debugging.
warnings.filterwarnings('ignore')

In [3]:
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
from skimage.draw import disk, rectangle, polygon
import json
import os

class DatasetGenerator:
    """Generate synthetic grayscale images of simple geometric shapes.

    Every image is a ``float32`` array of shape ``image_size`` with values in
    ``[0, 1]``. It is paired with a list of dicts, one per drawn shape,
    recording the sampled ``shape``, ``scale``, ``orientation``,
    ``position_x`` and ``position_y``.
    """

    def __init__(self, image_size=(64, 64), num_images=1000):
        """Store the canvas size ``(rows, cols)`` and the batch size."""
        self.image_size = image_size
        self.num_images = num_images

    def generate(self, properties_distribution):
        """Generate ``self.num_images`` images with their labels.

        Args:
            properties_distribution: dict with the keys ``"shapes"``,
                ``"scales"``, ``"orientations"``, ``"positions"``,
                ``"max_shapes"`` and ``"occlusion_mode"`` (see
                ``generate_image``).

        Returns:
            ``(images, labels)`` where ``images`` is a float32 array of shape
            ``(num_images, *image_size)`` and ``labels`` is a list holding one
            list of shape dicts per image.
        """
        images = []
        labels = []
        for _ in range(self.num_images):
            img, label = self.generate_image(properties_distribution)
            images.append(img)
            labels.append(label)

        if not images:
            # Keep shape and dtype consistent with non-empty batches.
            return np.empty((0, *self.image_size), dtype=np.float32), labels
        return np.array(images), labels  # Note: labels is a list of lists

    def generate_image(self, properties_distribution):
        """Draw one image containing between 1 and ``max_shapes`` shapes.

        ``"scales"``, ``"orientations"`` and ``"positions"`` are ``(low,
        high)`` pairs sampled uniformly; the same ``"positions"`` range is
        used for both x and y. ``"occlusion_mode"`` selects how a new shape is
        merged into the canvas:

        * ``"no_occlusion"``: the new shape is only painted on pixels that
          are still exactly 0, so earlier shapes are never overwritten.
        * ``"crop_boundary"``: the shape is binarised (every pixel > 0
          becomes 1) before merging, removing the soft interpolated edge.
        * ``"allow_occlusion"`` or any other value: the shape is merged
          unchanged with a pixel-wise maximum.

        Returns:
            ``(img, shapes)``: the float32 image clipped to ``[0, 1]`` and the
            list of parameter dicts for the shapes drawn.

        Raises:
            ValueError: if ``"shapes"`` is empty or ``"max_shapes"`` < 1.
        """
        shape_choices = properties_distribution["shapes"]
        if len(shape_choices) == 0:
            raise ValueError(
                "properties_distribution['shapes'] must contain at least one shape"
            )
        max_shapes = properties_distribution["max_shapes"]
        if max_shapes < 1:
            raise ValueError(
                "properties_distribution['max_shapes'] must be at least 1"
            )

        num_shapes = np.random.randint(1, max_shapes + 1)
        img = np.zeros(self.image_size, dtype=np.float32)
        shapes = []
        occlusion_mode = properties_distribution["occlusion_mode"]

        for _ in range(num_shapes):
            # np.random.choice yields a NumPy string; store a plain str so the
            # label is made of built-in types only.
            shape = str(np.random.choice(shape_choices))
            scale = np.random.uniform(*properties_distribution["scales"])
            orientation = np.random.uniform(*properties_distribution["orientations"])
            position_x = np.random.uniform(*properties_distribution["positions"])
            position_y = np.random.uniform(*properties_distribution["positions"])

            shape_img = self.draw_shape(shape, scale, orientation, position_x, position_y)

            if occlusion_mode == "no_occlusion":
                # Blank the new shape wherever the canvas is already painted.
                shape_img = np.where(img == 0, shape_img, 0)
            elif occlusion_mode == "crop_boundary":
                # Binarise, but stay in the canvas dtype: an integer mask would
                # silently promote the whole image to float64.
                shape_img = (shape_img > 0).astype(img.dtype)

            # Single merge step shared by every occlusion mode.
            img = np.maximum(img, shape_img)
            shapes.append({
                "shape": shape,
                "scale": scale,
                "orientation": orientation,
                "position_x": position_x,
                "position_y": position_y,
            })

        img = np.clip(img, 0, 1)  # Ensure pixel values are in [0, 1]
        return img, shapes

    def draw_shape(self, shape, scale, orientation, position_x, position_y):
        """Render a single shape on an otherwise empty canvas.

        Args:
            shape: ``"circle"``, ``"square"`` or ``"triangle"``; any other
                value leaves the canvas blank.
            scale: size factor relative to the smaller image dimension. The
                circle radius and the square/triangle side are all
                ``scale * min(image_size) / 2``.
            orientation: rotation angle passed to ``scipy.ndimage.rotate``.
            position_x, position_y: shape centre as fractions of the image
                width and height.

        Returns:
            A float32 array of shape ``image_size``. Interpolation during the
            rotation can produce values slightly outside ``[0, 1]``.

        NOTE(review): the rotation is applied to the whole canvas about the
        image centre, not about the shape's own centre, so for a non-zero
        ``orientation`` the shape no longer sits at
        ``(position_x, position_y)``. Confirm whether that is intended.
        """
        img = np.zeros(self.image_size, dtype=np.float32)
        height, width = self.image_size[0], self.image_size[1]
        extent = int(scale * min(self.image_size) / 2)
        center_y = int(position_y * height)
        center_x = int(position_x * width)

        if shape == "circle":
            rr, cc = disk((center_y, center_x), extent, shape=img.shape)
            img[rr, cc] = 1

        elif shape == "square":
            start_x = int(position_x * width - extent / 2)
            start_y = int(position_y * height - extent / 2)
            rr, cc = rectangle(
                start=(start_y, start_x),
                end=(start_y + extent, start_x + extent),
                shape=img.shape,
            )
            img[rr, cc] = 1

        elif shape == "triangle":
            half_side = extent // 2
            triangle = np.array(
                [
                    [center_y - half_side, center_x - half_side],
                    [center_y - half_side, center_x + half_side],
                    [center_y + half_side, center_x],
                ]
            )
            # The vertices are deliberately not clipped to the canvas:
            # clipping them warps the triangle, whereas `shape=` crops the
            # filled pixels, matching how the circle and square are handled.
            rr, cc = polygon(triangle[:, 0], triangle[:, 1], shape=img.shape)
            img[rr, cc] = 1

        img = rotate(img, angle=orientation, reshape=False)
        return img

    def visualize_sample(self, images, labels, num_samples=5):
        """Plot up to ``num_samples`` images in one row.

        Each subplot is titled with the ``shape`` of the first entry in that
        image's label list. Fewer panels are drawn when fewer images are
        available, and nothing is plotted for an empty batch.
        """
        count = min(num_samples, len(images), len(labels))
        if count <= 0:
            return

        plt.figure(figsize=(10, 2))
        for i in range(count):
            plt.subplot(1, count, i + 1)
            plt.imshow(images[i], cmap="gray")
            if labels[i]:
                plt.title(f"{labels[i][0]['shape']}")
            plt.axis("off")
        plt.show()


def generate_and_display_images(shapes, min_scale, max_scale, min_orientation, max_orientation, min_position, max_position, max_shapes, occlusion_mode, num_images):
    """Generate a dataset, write it to disk and return a preview.

    Builds the properties dict from the UI values, generates
    ``int(num_images)`` 64x64 images, and writes each one to
    ``generated_images/image_<i>.png`` with its label list in
    ``generated_images/metadata_<i>.json``. A row of up to five samples is
    also plotted with matplotlib.

    Args:
        shapes: list of shape names to sample from.
        min_scale, max_scale: bounds of the uniform scale range.
        min_orientation, max_orientation: bounds of the rotation angle range.
        min_position, max_position: bounds of the relative position range,
            used for both x and y.
        max_shapes: maximum number of shapes per image (converted with
            ``int``).
        occlusion_mode: occlusion mode name forwarded to the generator.
        num_images: number of images to generate (converted with ``int``).

    Returns:
        A list with the first ten generated images as PIL images.
    """
    properties_distribution = {
        "shapes": shapes,
        "scales": (min_scale, max_scale),
        "orientations": (min_orientation, max_orientation),
        "positions": (min_position, max_position),
        "max_shapes": int(max_shapes),
        "occlusion_mode": occlusion_mode
    }

    generator = DatasetGenerator(image_size=(64, 64), num_images=int(num_images))
    images, labels = generator.generate(properties_distribution)

    # Save images and metadata
    save_dir = "generated_images"
    os.makedirs(save_dir, exist_ok=True)
    max_previews = 10  # Limit the gallery to 10 samples
    previews = []

    for i, (img, label) in enumerate(zip(images, labels)):
        img_pil = Image.fromarray((img * 255).astype(np.uint8))
        img_pil.save(os.path.join(save_dir, f"image_{i}.png"))

        label_filename = os.path.join(save_dir, f"metadata_{i}.json")
        with open(label_filename, "w", encoding="utf-8") as f:
            json.dump(label, f, indent=2)

        # Reuse the in-memory image for the gallery instead of re-opening the
        # PNG with Image.open, which would leave a file handle open per image.
        if len(previews) < max_previews:
            previews.append(img_pil)

    # Never request more panels than there are images; skip plotting entirely
    # for an empty batch.
    sample_count = min(5, len(images))
    if sample_count:
        generator.visualize_sample(images, labels, num_samples=sample_count)

    return previews


# Slider definitions as (minimum, maximum, label, initial value), listed in
# the order the callback expects its scale / orientation / position arguments.
_slider_specs = [
    (0.1, 1.0, "Min Scale", 0.2),
    (0.1, 1.0, "Max Scale", 0.8),
    (0, 360, "Min Orientation", 0),
    (0, 360, "Max Orientation", 360),
    (0.0, 1.0, "Min Position", 0.2),
    (0.0, 1.0, "Max Position", 0.8),
]

# Input widgets; their order must match the positional parameters of
# generate_and_display_images.
_input_components = [
    gr.CheckboxGroup(
        choices=["circle", "square", "triangle"],
        label="Shapes",
        value=["circle", "square", "triangle"],
    ),
    *[
        gr.Slider(low, high, label=caption, value=initial)
        for low, high, caption, initial in _slider_specs
    ],
    gr.Number(label="Max Shapes", value=5),
    gr.Radio(
        ["allow_occlusion", "no_occlusion", "crop_boundary"],
        label="Occlusion Mode",
        value="allow_occlusion",
    ),
    gr.Textbox(label="Number of Images", value="1000"),
]

# Gallery that shows the preview images returned by the callback.
_output_gallery = gr.Gallery(
    label="Generated Images with Metadata",
    show_label=False,
    show_download_button=False,
)

interface = gr.Interface(
    fn=generate_and_display_images,
    inputs=_input_components,
    outputs=_output_gallery,
    title="Dataset Generator",
    description="Generate images with different shapes and properties along with metadata.",
)

if __name__ == "__main__":
    # share=True additionally requests a temporary public URL for the app.
    interface.launch(share=True)


Running on local URL:  http://127.0.0.1:7863
Running on public URL: https://d9a87d16a295bc413a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
