In [None]:
"""
Prepare a subset of the LAION dataset: extract valid images and save corresponding metadata

We use a subset of a single parquet file (shard) of the LAION dataset
(1.8GB, 12933524 rows x 8 columns). Metadata columns:
SAMPLE_ID | URL | TEXT | LICENSE | NSFW | similarity | WIDTH | HEIGHT
"""
import os
import pandas as pd
import subprocess
import glob
import datetime
import shutil

In [None]:
# Number of rows taken from the head of the LAION shard (before NSFW/URL filtering).
ENTITY_COUNT = 500

# Base data directory. Set the LAION_DATA_DIR environment variable to point
# elsewhere instead of editing this hard-coded user path.
DATA_DIR = os.environ.get("LAION_DATA_DIR", "/Users/yavuz/data")

FULL_LAION_PATH = os.path.join(DATA_DIR, "part-00000-5b54c5d5-bbcf-484d-a2ce-0d6f73df1a36-c000.snappy.parquet")
# The trailing "" keeps a final separator, so the value still ends with "/" as before.
PREP_DATASET_PATH = os.path.join(DATA_DIR, f"LAION-{ENTITY_COUNT}", "")

if os.path.exists(PREP_DATASET_PATH):
    print(f"Warning: {PREP_DATASET_PATH} exists!")
else:
    os.makedirs(PREP_DATASET_PATH)

IMAGES_PATH = os.path.join(PREP_DATASET_PATH, "images")
URLS_PATH = os.path.join(PREP_DATASET_PATH, "urls.txt")
SUCCEEDED_URLS_PATH = os.path.join(PREP_DATASET_PATH, "succeeded-urls.txt")
DATA_PATH = os.path.join(PREP_DATASET_PATH, "metadata.parquet")

In [None]:
def read_safe_data(path: str, count: int) -> pd.DataFrame:
    """
    Return the non-NSFW entries among the first ``count`` rows of a LAION parquet file.

    The whole parquet file is loaded and then truncated to its first ``count``
    rows. Rows are kept only if ``NSFW == "UNLIKELY"`` and ``URL`` is present and
    contains no comma. The original index labels are preserved (not reset).

    :param path: path to the LAION parquet file.
    :param count: number of leading rows to consider.
    :return: the filtered DataFrame.
    """
    print(f"Reading {count} items from full LAION dataset...")
    df = pd.read_parquet(path).iloc[:count]

    nsfw_removed_data = df[df["NSFW"] == "UNLIKELY"]
    print("Size after removing NSFW:", len(nsfw_removed_data))

    urls = nsfw_removed_data["URL"]
    # Drop missing URLs explicitly. Without na=False, str.contains yields NaN
    # for them and the `~` negation raises TypeError. regex=False: literal match.
    has_comma = urls.str.contains(",", regex=False, na=False)
    clean_url_data = nsfw_removed_data[urls.notna() & ~has_comma]
    print("Size after removing URLs with commas:", len(clean_url_data))

    return clean_url_data

In [None]:
data = read_safe_data(path=FULL_LAION_PATH, count=ENTITY_COUNT)
data

In [None]:
def write_urls(data: pd.DataFrame, path: str) -> None:
    """
    Write the ``URL`` column of ``data`` to ``path``, one URL per line.

    Line k of the file holds the URL at row position k of ``data``. Later cells
    presumably rely on this to map downloaded-image ids back to rows with
    ``iloc`` (verify against img2dataset's key numbering).

    :param data: frame with a ``URL`` column of strings.
    :param path: output text file; overwritten if it already exists.
    """
    # Explicit UTF-8: URLs can contain non-ASCII characters, and the platform
    # default encoding is not guaranteed to handle them.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(url + "\n" for url in data["URL"])
    print(f"Finished writing {len(data)} URLs to {path}")

write_urls(data=data, path=URLS_PATH)

In [None]:
def download_images(url_path: str, images_path: str, thread_count: int = 64, image_size: int = 256) -> None:
    """
    Download images listed (one URL per line) in ``url_path`` into ``images_path`` using img2dataset.

    If ``images_path`` already exists, it is first renamed to a timestamped
    sibling so the previous download is not overwritten.

    :param url_path: text file containing one URL per line.
    :param images_path: output folder passed to img2dataset.
    :param thread_count: value for img2dataset's ``--thread_count``.
    :param image_size: value for img2dataset's ``--image_size``.
    :raises subprocess.CalledProcessError: if img2dataset exits with a non-zero status.
    """
    if os.path.exists(images_path):
        print(f"Warning: {images_path} exists - renaming it...!")
        # Rename the directory that was actually checked (the parameter, not a global).
        # normpath strips a trailing separator; otherwise the backup path would sit
        # inside the directory being renamed.
        backup_path = os.path.normpath(images_path) + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        os.rename(images_path, backup_path)

    # check=True surfaces a failed run instead of silently continuing with a
    # missing or partial output folder.
    subprocess.run(
        [
            "img2dataset",
            "--url_list=" + url_path,
            "--output_folder=" + images_path,
            f"--thread_count={thread_count}",
            f"--image_size={image_size}",
        ],
        check=True,
    )

In [None]:
download_images(url_path=URLS_PATH, images_path=IMAGES_PATH)

In [None]:
def get_valid_file_ids(path: str) -> list[int]:
    """
    Return the sorted integer ids of all ``.jpg`` files exactly one directory level below ``path``.

    Only ``<path>/<subdir>/<name>.jpg`` is matched; this is not a fully recursive
    search. The id is the file name without its extension, parsed as an integer,
    e.g. ``00001/000010003.jpg`` -> ``10003``.

    :param path: root folder whose sub-folders (shards) contain the images.
    :return: ids in ascending order (empty if nothing matches).
    """
    files = glob.glob(os.path.join(path, "*", "*.jpg"))
    print(f"Found {len(files)} files")

    # basename/splitext rather than split('/') so this also works on non-POSIX paths.
    return sorted(int(os.path.splitext(os.path.basename(file))[0]) for file in files)

In [None]:
ids = get_valid_file_ids(path=IMAGES_PATH)
ids

In [None]:
# Select rows by position; `ids` are presumably line numbers of urls.txt,
# i.e. row positions in `data` (verify against img2dataset key numbering).
data_with_images = data.take(ids)
data_with_images

In [None]:
# Rebuild from `data` and `ids` instead of calling reset_index() on
# data_with_images itself, so re-running this cell is idempotent (a repeated
# self-reset would add "level_0" and then fail). reset_index() keeps the
# previous row labels in an "index" column.
data_with_images = data.iloc[ids].reset_index()
data_with_images

In [None]:
# Record only the URLs whose image was actually downloaded.
write_urls(data_with_images, SUCCEEDED_URLS_PATH)

In [None]:
# Persist the metadata of the successfully downloaded rows as parquet.
data_with_images.to_parquet(path=DATA_PATH)

In [None]:
def move_files(images_path: str):
    """
    Rename (and move across shards) the downloaded images so their names form a continuous range 0..n-1.

    Each ``.jpg`` found at ``<images_path>/<shard>/<name>.jpg`` is moved together
    with its ``.json`` sidecar. Files are processed in sorted path order, and the
    i-th file becomes ``<images_path>/<S>/<S><I>.jpg``, where ``S`` is
    ``i // 10000`` zero-padded to 5 digits and ``I`` is ``i % 10000`` zero-padded
    to 4 digits. Sorted path order equals ascending id order only if the
    original names are equal-width and zero-padded (assumed here).

    :param images_path: folder containing the shard sub-folders.
    """
    # Use the parameter, not the global IMAGES_PATH.
    files = sorted(glob.glob(os.path.join(images_path, "*", "*.jpg")))

    for i, image_file in enumerate(files):
        shard = str(i // 10000).zfill(5)
        index = str(i % 10000).zfill(4)

        shard_dir = os.path.join(images_path, shard)
        # The destination shard folder may not exist (e.g. if no file was ever
        # written to it), so create it.
        os.makedirs(shard_dir, exist_ok=True)

        # splitext changes only the extension. str.replace(".jpg", ...) would also
        # rewrite any ".jpg" appearing in a directory name.
        json_file = os.path.splitext(image_file)[0] + ".json"
        new_stem = os.path.join(shard_dir, f"{shard}{index}")

        shutil.move(image_file, new_stem + ".jpg")
        shutil.move(json_file, new_stem + ".json")

In [None]:
move_files(images_path=IMAGES_PATH)