In [None]:
import os
import pandas as pd
import subprocess
import glob
import datetime

In [None]:
# Input: one parquet shard of the full LAION metadata (machine-specific path).
FULL_LAION_PATH = "/Users/yavuz/data/part-00000-5b54c5d5-bbcf-484d-a2ce-0d6f73df1a36-c000.snappy.parquet"

# Number of leading rows taken from the input shard before any filtering.
ENTITY_COUNT = 10000
# Output directory. Derived from ENTITY_COUNT so the directory name cannot
# drift out of sync with the sample size when the count is changed.
PREP_DATASET_PATH = f"/Users/yavuz/data/LAION-{ENTITY_COUNT}/"

# Create the output directory on first run; only warn when it is already there
# so that existing artifacts are never removed by this cell.
if os.path.exists(PREP_DATASET_PATH):
    print(f"Warning: {PREP_DATASET_PATH} exists!")
else:
    os.makedirs(PREP_DATASET_PATH)

# Artifacts written below the output directory.
IMAGES_PATH = os.path.join(PREP_DATASET_PATH, "images")  # downloaded images
URLS_PATH = os.path.join(PREP_DATASET_PATH, "urls.txt")  # URLs handed to the downloader
SUCCEEDED_URLS_PATH = os.path.join(PREP_DATASET_PATH, "succeeded-urls.txt")  # URLs that have an image on disk
DATA_PATH = os.path.join(PREP_DATASET_PATH, "metadata.parquet")  # metadata of the rows that have an image

In [None]:
def read_safe_data(path: str, count:int) -> pd.DataFrame:
    """
    Return non-nsfw entries with usable URLs from the full LAION dataset.

    The parquet file at `path` is read in full and truncated to its first
    `count` rows. Rows are then kept only if their "NSFW" value equals
    "UNLIKELY" and their "URL" is present and contains no comma.

    Args:
        path: Location of the parquet file to read.
        count: Number of leading rows to consider before filtering.

    Returns:
        The filtered rows. The index of the source frame is preserved
        (it is not reset), so it may contain gaps.
    """
    print(f"Reading {count} items from full LAION dataset...")
    df = pd.read_parquet(path)[:count]

    nsfw_removed_data = df[df["NSFW"]=="UNLIKELY"]
    print("Size after removing NSFW:", len(nsfw_removed_data))

    urls = nsfw_removed_data["URL"]
    # regex=False: the comma is a literal character, not a pattern.
    # na=False: a missing URL yields False here instead of NaN, which would
    # make the boolean negation below raise.
    has_comma = urls.str.contains(",", regex=False, na=False)
    # Missing URLs are dropped explicitly; they cannot be downloaded.
    clean_url_data = nsfw_removed_data[urls.notna() & ~has_comma]
    print("Size after removing URLs with commas:", len(clean_url_data))

    return clean_url_data

In [None]:
# Take the first ENTITY_COUNT rows of the LAION shard and filter them.
data = read_safe_data(path=FULL_LAION_PATH, count=ENTITY_COUNT)
data

In [None]:
def write_urls(data: pd.DataFrame, path: str) -> None:
    """
    Writes the URLs found in the dataframe to a file in the given path.

    One URL is written per line, in row order. An existing file at `path`
    is overwritten.

    Args:
        data: Frame with a "URL" column of strings.
        path: Destination text file.
    """
    # Write-only mode is sufficient, and the explicit UTF-8 encoding keeps the
    # file contents independent of the platform's default locale.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(url + "\n" for url in data["URL"])
    print(f"Finished writing {len(data)} URLs to {path}")

# Persist the filtered URLs to URLS_PATH, one per line.
write_urls(data=data, path=URLS_PATH)

In [None]:
def download_images(url_path: str, images_path: str) -> None:
    """
    Download images from a text file with a list of urls.

    Runs the `img2dataset` command-line tool on `url_path`, writing into
    `images_path` with 64 threads and an image size of 256. If `images_path`
    already exists it is first renamed by appending the current timestamp,
    so earlier downloads are kept.

    Args:
        url_path: Text file with the URLs to download.
        images_path: Directory that receives the downloaded images.
    """
    if os.path.exists(images_path):
        print(f"Warning: {images_path} exists - renaming it...!")
        # Rename the directory that was passed in, not the notebook-level
        # IMAGES_PATH global, so the function works for any target directory.
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        os.rename(images_path, images_path + timestamp)

    exit_code = subprocess.call([
        "img2dataset",
        "--url_list=" + url_path,
        "--output_folder=" + images_path,
        "--thread_count=64",
        "--image_size=256",
    ])
    # Report a failed run instead of discarding the exit status silently.
    if exit_code != 0:
        print(f"Warning: img2dataset exited with status {exit_code}")

In [None]:
# Fetch the images for every URL listed in URLS_PATH into IMAGES_PATH.
download_images(url_path=URLS_PATH, images_path=IMAGES_PATH)

In [None]:
# Collect every .jpg located exactly one directory level below IMAGES_PATH and
# reduce each path to a [parent_directory_name, file_name] pair.
# os.path helpers are used instead of splitting on "/" so the pairs are correct
# regardless of the platform's path separator.
files = [
    [os.path.basename(os.path.dirname(image_path)), os.path.basename(image_path)]
    for image_path in glob.glob(os.path.join(IMAGES_PATH, "*", "*.jpg"))
]
len(files)

In [None]:
# Integer value of each file name's text before its first ".", in ascending order.
ids = sorted(int(entry[1].partition('.')[0]) for entry in files)

In [None]:
# Positional row lookup: assumes each image's numeric file name equals the
# position of its URL in `data` -- TODO confirm against the downloader's naming.
data_with_images = data.iloc[ids, :]
data_with_images

In [None]:
# Show the last row of the selection as a quick sanity check.
data_with_images.iloc[-1]

In [None]:
# Record the URLs of the rows that were selected by image id.
write_urls(data=data_with_images, path=SUCCEEDED_URLS_PATH)

In [None]:
# Save the metadata of the selected rows (data_with_images) to DATA_PATH in parquet format.
data_with_images.to_parquet(DATA_PATH)