In [None]:
import polars as pl
from pathlib import Path
import logging
from utils import set_up_logging, delete_corrupt_images

# Configure logging via the project helper (defined in utils, not shown here);
# it is handed the ../logs directory.
set_up_logging(Path("../logs"))

# Tunables for the download cell further down.
RETRY_COUNT = 10   # attempts per photo (range(RETRY_COUNT) in download_image)
WORKER_COUNT = 64  # max_workers of the ThreadPoolExecutor
HTTP_TIMEOUT = 30  # timeout passed to requests.get
# Destination directory, made absolute up front.
TARGET_PATH = Path("/bulk2/downloaded-unsplash").resolve()

# Stop right away if the destination directory is missing.
assert TARGET_PATH.exists()
# Hand every existing entry to the project helper before computing what is
# missing (presumably it removes unreadable images -- confirm in utils).
delete_corrupt_images(list(TARGET_PATH.glob("*")))

In [None]:
# Lazily scan the photo table, keep only featured photos and materialise just
# the columns the download step needs, ordered by photo_id.
photos = (
    pl.scan_csv(
        "../data/unsplash-full/photos.tsv000",
        separator="\t",
        infer_schema_length=100000,
    )
    .filter(pl.col("photo_featured") == "t")
    .sort("photo_id")
    # download_image() fetches row["photo_image_url"]; photo_url only appears
    # in its log messages. Without photo_image_url in the projection every
    # download attempt fails with a KeyError.
    .select("photo_id", "photo_url", "photo_image_url")
    .collect()
)

# Preview the first rows as the cell output.
photos.limit(10)

In [None]:
# Ids already present in TARGET_PATH: each entry's name up to its first dot.
keys = {entry.name.partition(".")[0] for entry in TARGET_PATH.glob("*")}

# Drop every photo whose id is already on disk; the rest still has to be fetched.
already_downloaded = pl.col("photo_id").is_in(keys)
photos = photos.filter(~already_downloaded)
logging.info(f"Found {len(photos)} missing photos")

In [None]:
import concurrent.futures
import requests
from tqdm import tqdm
from typing import List
from time import sleep

# Shared progress bar; worker threads advance it once per successful download.
progress = tqdm(total=len(photos))

def download_image(row):
    """Download a single photo into TARGET_PATH, retrying on any failure.

    Args:
        row: Mapping with the keys "photo_id" (used for the file name),
            "photo_image_url" (the URL that is fetched) and "photo_url"
            (only used in log messages).

    Returns:
        None. On success the response body is written to
        TARGET_PATH/<photo_id>.<extension> and the shared progress bar is
        advanced by one. If all RETRY_COUNT attempts fail, the photo is
        skipped and a final error is logged; no exception propagates for
        errors derived from Exception.
    """
    for retry_count in range(RETRY_COUNT):
        try:
            logging.debug(f"Downloading {row['photo_id']} from {row['photo_url']}")
            response = requests.get(row["photo_image_url"], timeout=HTTP_TIMEOUT)
            response.raise_for_status()
            # Content-Type may carry parameters ("image/jpeg; charset=binary");
            # drop them so only the subtype ends up in the file extension.
            media_type = response.headers["Content-Type"].split(";")[0].strip()
            extension = media_type.split("/")[-1]
            filename = TARGET_PATH / f"{row['photo_id']}.{extension}"
            with open(filename, "wb") as f:
                f.write(response.content)
            logging.debug(f"Downloaded {row['photo_id']} to {filename}")
            # The bar is shared between worker threads, so update it under its lock.
            with progress.get_lock():
                progress.update(1)
            return
        except Exception as e:
            logging.error(
                f"Error downloading {row['photo_id']} from {row['photo_url']} (retry {retry_count}): {e}"
            )
            # Linear back-off between attempts; sleeping after the final
            # attempt would only delay the worker, so skip it.
            if retry_count < RETRY_COUNT - 1:
                sleep(retry_count * 0.5)

    # Every attempt failed: make the skipped photo visible in the log.
    logging.error(
        f"Giving up on {row['photo_id']} after {RETRY_COUNT} failed attempts"
    )


# Submit one download task per remaining photo and block until all of them
# have finished; download_image advances the progress bar itself.
with concurrent.futures.ThreadPoolExecutor(max_workers=WORKER_COUNT) as executor:
    futures: List[concurrent.futures.Future] = [
        executor.submit(download_image, row) for row in photos.to_dicts()
    ]

    progress.display()
    concurrent.futures.wait(futures)
progress.close()

In [None]:
# Second sweep over the target directory after the download run, using the
# same project helper as the setup cell (presumably it removes files that
# are not valid images -- confirm in utils).
downloaded_files = list(TARGET_PATH.glob("*"))
delete_corrupt_images(downloaded_files)