# In [1]:
import os
from PIL import Image
import json
import torch


class ImageJSONDataset(torch.utils.data.Dataset):
    """
    Custom PyTorch dataset for image and JSON pairs.

    Every immediate subdirectory ``<name>`` of ``root_dir`` is one sample. It
    must contain an image file ``<name>.png`` and an ``offsets`` subdirectory
    holding zero or more ``.json`` files::

        root_dir/
            <name>/
                <name>.png
                offsets/
                    *.json

    Samples are indexed in sorted order of their folder names. The JSON files
    of a sample are loaded in sorted filename order. This makes indexing
    reproducible across runs and platforms, since ``os.listdir`` order is
    arbitrary.

    Args:
        root_dir (str): Path to the root directory containing image folders.
        transform (callable, optional): A function/transform to apply to the
            RGB ``PIL.Image`` before it is returned.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # Sorted so that the index -> folder mapping is deterministic;
        # plain files directly under root_dir are ignored.
        self.image_dirs = [
            os.path.join(root_dir, img_folder)
            for img_folder in sorted(os.listdir(root_dir))
            if os.path.isdir(os.path.join(root_dir, img_folder))
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per image folder)."""
        return len(self.image_dirs)

    @staticmethod
    def _load_offsets(offsets_dir):
        """Parse every ``.json`` file in ``offsets_dir``, in sorted filename order.

        Args:
            offsets_dir (str): Directory to scan (not recursive).

        Returns:
            list: The decoded JSON values, one per ``.json`` file; empty if
            the directory contains no such files.

        Raises:
            FileNotFoundError: If ``offsets_dir`` does not exist.
        """
        offsets = []
        for name in sorted(os.listdir(offsets_dir)):
            if not name.endswith(".json"):
                continue
            # Explicit encoding: JSON is UTF-8; don't depend on the locale.
            with open(os.path.join(offsets_dir, name), "r", encoding="utf-8") as f:
                offsets.append(json.load(f))
        return offsets

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, offsets)``. ``image`` is the RGB ``PIL.Image``,
            or whatever ``transform`` returns for it. ``offsets`` is a list of
            the decoded JSON values from the sample's ``offsets`` directory.
            Its length varies per sample, so the default DataLoader collate
            may need a custom ``collate_fn`` (NOTE(review): confirm against
            how the loader is used).

        Raises:
            FileNotFoundError: If the ``<name>.png`` image or the ``offsets``
                directory is missing.
        """
        image_dir = self.image_dirs[idx]
        # basename is separator-agnostic; split("/") breaks on Windows paths.
        folder_name = os.path.basename(image_dir)
        image_path = os.path.join(image_dir, folder_name + ".png")

        # Close the file handle deterministically; convert() returns a copy,
        # so the image stays usable after the with-block. Forces 3-channel RGB.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        offsets = self._load_offsets(os.path.join(image_dir, "offsets"))

        if self.transform is not None:
            image = self.transform(image)

        return image, offsets