In [27]:
from datasets import load_dataset
from pathlib import Path
from PIL import Image
import json

import base64
import imghdr
import requests
from io import BytesIO


In [22]:
# Hugging Face Hub dataset to export. 'diffusers/dog-example' is a smaller
# alternative that can be substituted here.
hf_dataset = "wraps/flux1_dev-small"

# Called without `split=`, load_dataset returns a mapping of split name -> dataset.
ds = load_dataset(hf_dataset)


In [25]:
def _open_image_bytes(data, error_message):
    """Open encoded image ``data`` with PIL, raising ``ValueError(error_message)`` if unreadable."""
    try:
        return Image.open(BytesIO(data))
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError (unknown or corrupt data) subclasses OSError.
        raise ValueError(error_message) from err


def get_image(image_input, timeout=30):
    """Return a PIL image built from one of several input representations.

    Accepted inputs:

    * ``PIL.Image.Image``: returned unchanged.
    * ``bytes``: raw encoded image data.
    * ``str`` starting with ``http://`` or ``https://``: downloaded with
      ``requests``.
    * ``str`` starting with ``data:image``: a Base64 data URI
      (``data:image/<fmt>;base64,<payload>``).

    Args:
        image_input: The image, in one of the representations above.
        timeout: Seconds to wait for the HTTP server when ``image_input`` is a
            URL. It is passed straight to ``requests.get``; ``None`` waits
            indefinitely.

    Returns:
        A ``PIL.Image.Image``. Images built from bytes are opened lazily, as
        ``Image.open`` does.

    Raises:
        ValueError: If the bytes, URL, or data URI does not yield a readable
            image, or if the string is in an unsupported format.
        TypeError: If ``image_input`` is not one of the supported types.
    """
    if isinstance(image_input, Image.Image):
        return image_input

    if isinstance(image_input, bytes):
        return _open_image_bytes(image_input, "Invalid image bytes")

    if not isinstance(image_input, str):
        raise TypeError("Input must be a URL, Base64 string, bytes, or PIL Image")

    if image_input.startswith(("http://", "https://")):
        try:
            # Read the whole (content-decoded) body instead of handing PIL the
            # raw socket stream. `raw` skips gzip/deflate decoding, and the
            # context manager releases the connection.
            with requests.get(image_input, timeout=timeout) as response:
                response.raise_for_status()
                content = response.content
        except requests.RequestException as err:
            raise ValueError("Invalid image URL") from err
        return _open_image_bytes(content, "Invalid image URL")

    if image_input.startswith("data:image"):
        try:
            _header, encoded = image_input.split(",", 1)  # drop "data:image/...;base64"
            decoded_bytes = base64.b64decode(encoded, validate=True)
        except (ValueError, TypeError) as err:  # binascii.Error subclasses ValueError
            raise ValueError("Invalid Base64 image data") from err
        # PIL validates the payload itself, replacing the deprecated imghdr check.
        return _open_image_bytes(decoded_bytes, "Invalid Base64 image data")

    raise ValueError("Unsupported image format")

In [None]:
# Root output folder, placed under the current working directory and created if missing.
ds_dir = Path.cwd() / "datasets"
ds_dir.mkdir(exist_ok=True)

def create_local_dataset(ds_dir, hf_dataset):
    """Export every split of a Hugging Face dataset to numbered image/prompt files.

    For each split, a sub-directory ``ds_dir/<split>`` is created. It holds
    ``<i>.jpg`` (the item's ``"image"``) and ``<i>.txt`` (the item's
    ``"prompt"``), where ``i`` is the item's index within that split.

    Args:
        ds_dir: Output root directory (``str`` or ``Path``). It is created if
            missing.
        hf_dataset: Either a dataset name to pass to ``load_dataset``, or an
            already-loaded mapping of split name -> dataset, such as a
            ``DatasetDict``.

    Raises:
        ValueError, TypeError: Propagated from ``get_image`` for images it
            cannot read.
    """
    dataset = load_dataset(hf_dataset) if isinstance(hf_dataset, str) else hf_dataset
    root = Path(ds_dir)

    for subset, rows in dataset.items():
        subset_dir = root / subset
        subset_dir.mkdir(parents=True, exist_ok=True)

        # Nested inside the split loop so every split is exported, not just the last.
        for i, item in enumerate(rows):
            image = get_image(item["image"])
            # JPEG cannot store alpha or palette modes (e.g. RGBA, P). Convert
            # them first so `save` does not fail with "cannot write mode ... as JPEG".
            if image.mode not in ("RGB", "L"):
                image = image.convert("RGB")
            image.save(subset_dir / f"{i}.jpg")
            (subset_dir / f"{i}.txt").write_text(item["prompt"], encoding="utf-8")

