In [None]:
pip install datasets joblib


Verify Dataset

In [None]:
from datasets import load_dataset

# Stream from the Parquet branch so this check inspects the same revision
# that process_dataset loads further down in the notebook.
dataset = load_dataset(
    "takara-ai/sangyo_no_yume_industrial_dreams",
    revision="refs/convert/parquet",
    split="train",
    streaming=True,
)

# Pull a single example to inspect the structure
example = next(iter(dataset))

# Print the structure of the dataset
print(example)

# Print the available columns
print("Columns:", list(example.keys()))


- read gallery_config.yaml
- clear existing images
- stream examples from the configured datasets
- put images in gallery
- record URLs and prompts in gallery_images.yaml

# IMAGE DATASET RETRIEVER

In [None]:
import logging
import os
import random
import shutil
from functools import partial
from itertools import islice
from multiprocessing import Pool, cpu_count

import yaml
from datasets import load_dataset

# Configure the root logger: INFO and above, one "<time> - <level> - <message>" line per record
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def get_base_dir():
    """Return the directory that paths in this script are resolved against.

    When ``__file__`` is defined, the directory containing that file is returned.
    When it is not defined (looking it up raises ``NameError``), the current
    working directory is returned instead.
    """
    try:
        script_path = __file__
    except NameError:
        # No __file__ available, so fall back to the working directory.
        return os.getcwd()
    return os.path.dirname(os.path.abspath(script_path))

def clear_directory(directory):
    """Delete everything inside ``directory`` while keeping the directory itself.

    Regular files and symlinks are unlinked; sub-directories are removed
    recursively. A failure on one entry is logged as an error and the
    remaining entries are still processed.

    Args:
        directory: Path of an existing directory whose contents are removed.
    """
    logging.info(f"Clearing directory: {directory}")
    for entry_name in os.listdir(directory):
        file_path = os.path.join(directory, entry_name)
        try:
            is_file_or_link = os.path.isfile(file_path) or os.path.islink(file_path)
            if is_file_or_link:
                os.unlink(file_path)
                logging.debug(f"Deleted file: {file_path}")
                continue
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
                logging.debug(f"Deleted directory: {file_path}")
        except Exception as e:
            # Best effort: report the failure and carry on with the next entry.
            logging.error(f'Failed to delete {file_path}. Reason: {e}')

def process_dataset(dataset_config, gallery_dir, base_dir):
    """Save a random selection of images from one streaming dataset into the gallery.

    The folder ``<gallery_dir>/<dataset name>`` is emptied (or created), the first
    ``min(num_images * 10, 1000)`` examples of the stream are buffered, and
    ``num_images`` of them are picked at random and written as ``<seed>.png``.

    Args:
        dataset_config: Mapping with the keys ``'name'`` (identifier passed to
            ``load_dataset``) and ``'num_images'`` (how many images to keep).
        gallery_dir: Directory under which the per-dataset folder lives.
        base_dir: Not used by this function; kept so existing callers keep working.

    Returns:
        A list with one dict per saved image, holding the keys ``'image_url'``
        and ``'positive_prompt'``.
    """
    dataset_name = dataset_config['name']
    num_images = dataset_config['num_images']
    logging.info(f"Processing dataset: {dataset_name}, images to select: {num_images}")

    output_dir = os.path.join(gallery_dir, dataset_name)
    logging.info(f"Output directory: {output_dir}")

    if os.path.exists(output_dir):
        clear_directory(output_dir)
    else:
        os.makedirs(output_dir)
        logging.info(f"Created output directory: {output_dir}")

    logging.info(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name, revision="refs/convert/parquet", split="train", streaming=True)

    buffer_size = min(num_images * 10, 1000)
    logging.info(f"Buffer size: {buffer_size}")

    # Only the head of the stream is read: the previous loop stopped as soon as the
    # buffer was full, so its replacement branch never ran. islice states that
    # behaviour directly and never pulls more examples than needed.
    logging.info("Filling buffer with examples")
    buffer = list(islice(dataset, buffer_size))
    if len(buffer) >= buffer_size:
        logging.info(f"Buffer filled after {len(buffer)} iterations")
    else:
        logging.warning(
            f"Dataset {dataset_name} ended after {len(buffer)} examples; "
            f"buffer size was {buffer_size}"
        )

    logging.info(f"Selecting {num_images} images from buffer")
    selected_images = random.sample(buffer, min(num_images, len(buffer)))

    gallery_images = []
    for i, example in enumerate(selected_images, 1):
        # Files are named after the example's 'seed' value.
        # NOTE(review): two examples with the same seed would overwrite each other - confirm seeds are unique.
        image_filename = f"{example['seed']}.png"
        image_path = os.path.join(output_dir, image_filename)
        # presumably a PIL image (it exposes .save) - verify against the dataset schema
        example['image'].save(image_path)
        logging.info(f"Saved image {i}/{len(selected_images)}: {image_path}")

        gallery_images.append({
            'image_url': f'/assets/images/gallery/{dataset_name}/{image_filename}',
            'positive_prompt': example['positive_prompt'],
        })

    return gallery_images

def main():
    """Rebuild the gallery from the datasets listed in ``gallery_config.yaml``.

    Reads ``gallery_config.yaml`` from the base directory, processes every entry
    under its ``datasets`` key in a multiprocessing pool (one task per dataset),
    and writes the collected image records to ``gallery_images.yaml`` in the base
    directory under the key ``images``.

    If the config file is empty or lists no datasets, no pool is started and an
    empty ``images`` list is written.
    """
    logging.info("Script started")

    base_dir = get_base_dir()
    logging.info(f"Base directory: {base_dir}")

    gallery_dir = os.path.abspath(os.path.join(base_dir, '..', 'env', 'assets', 'images', 'gallery'))
    logging.info(f"Gallery directory: {gallery_dir}")

    config_path = os.path.join(base_dir, 'gallery_config.yaml')
    logging.info(f"Loading config from: {config_path}")
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # An empty YAML file loads as None, and 'datasets' may be absent or empty.
    dataset_configs = (config or {}).get('datasets') or []

    if dataset_configs:
        # Use multiprocessing to process datasets in parallel
        num_processes = min(cpu_count(), len(dataset_configs))
        logging.info(f"Using {num_processes} processes for parallel processing")

        with Pool(num_processes) as pool:
            process_func = partial(process_dataset, gallery_dir=gallery_dir, base_dir=base_dir)
            results = pool.map(process_func, dataset_configs)
    else:
        # Pool(0) raises ValueError, so skip the pool when there is nothing to do.
        logging.warning(f"No datasets configured in {config_path}; nothing to download")
        results = []

    # Flatten the list of gallery images
    gallery_images = [image for sublist in results for image in sublist]

    output_file = os.path.join(base_dir, 'gallery_images.yaml')
    logging.info(f"Saving gallery images to: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as file:
        yaml.dump({'images': gallery_images}, file)

    logging.info(f"Generated {output_file} with {len(gallery_images)} randomly selected images from the streaming datasets.")
    logging.info("Script completed")

# Entry point: run main() only when __name__ is "__main__", i.e. not when imported as a module.
if __name__ == "__main__":
    main()
