In [1]:
import ast
import os
from multiprocessing import Pool
from pathlib import Path

from dotenv import load_dotenv
from PIL import Image
from tqdm import tqdm

from utils.s3_download import list_all_objects, s3_client


# Load settings through python-dotenv before reading them from the environment.
load_dotenv(override=True)

# Bucket name and the key prefixes that are listed/downloaded further down,
# plus the local directory under which the bucket layout is mirrored.
BUCKET_NAME = os.getenv("BUCKET_NAME")
RAW_FOLDER = os.getenv("RAW_FOLDER")
CLEAN_FOLDER = os.getenv("CLEAN_FOLDER")
SELECTED_FOLDER = os.getenv("SELECTED_FOLDER")
LOCAL_BUCKET_FOLDER = os.getenv("LOCAL_BUCKET_FOLDER")
VIDEO_EXTRACTION_FOLDER = os.getenv("VIDEO_EXTRACTION_FOLDER")
EXCLUDE_FOLDER = os.getenv("EXCLUDE_FOLDER")

# IMAGE_SIZE is written in the environment as a Python literal and is later
# handed to PIL's ``resize`` (presumably a ``(width, height)`` pair — confirm
# against the .env file). ``ast.literal_eval`` only accepts literals, so text
# coming from the environment is parsed but never executed, unlike ``eval``.
# Indexing ``os.environ`` makes a missing variable fail with a clear KeyError.
IMAGE_SIZE = ast.literal_eval(os.environ["IMAGE_SIZE"])

In [2]:
def list_files_to_download(s3_folder:str):
    """Collect the object keys found under ``s3_folder`` in ``BUCKET_NAME``.

    The listing comes from ``list_all_objects``; its length is printed as-is,
    i.e. before any filtering. Keys that end with ``'/'`` (presumably folder
    placeholder entries) are then left out of the result.

    Args:
        s3_folder: Key prefix to list inside the bucket.

    Returns:
        A list with the ``'Key'`` of every listed entry not ending in ``'/'``.
    """
    listing = list_all_objects(s3_client, BUCKET_NAME, s3_folder)
    print(f"Number of files in {BUCKET_NAME}/{s3_folder}: {len(listing)}")

    object_keys = []
    for entry in listing:
        key = entry['Key']
        if key.endswith('/'):
            # Skip entries whose key denotes a "directory" rather than a file.
            continue
        object_keys.append(key)
    return object_keys

# Every object under the raw prefix.
pipeline_s3_files = list_files_to_download(s3_folder=RAW_FOLDER)
# Objects of one specific video-extraction run, addressed by its timestamped sub-prefix.
vid_s3_files = list_files_to_download(s3_folder=VIDEO_EXTRACTION_FOLDER + '/2023-10-31T13-38-28')
# From the clean prefix only the .json objects are kept.
clean_s3_files = [
    key
    for key in list_files_to_download(s3_folder=CLEAN_FOLDER)
    if key.endswith('.json')
]
# Every object under the exclude prefix.
exclude_s3_files = list_files_to_download(s3_folder=EXCLUDE_FOLDER)

listing s3 objects: 42it [00:07,  5.41it/s]


Number of files in sg-implement/prod_raw: 41201


listing s3 objects: 1it [00:00,  6.55it/s]


Number of files in sg-implement/video_extraction/2023-10-31T13-38-28: 822


listing s3 objects: 11it [00:02,  5.27it/s]


Number of files in sg-implement/prod_clean: 10515


listing s3 objects: 2it [00:00,  8.59it/s]

Number of files in sg-implement/prod_exclude: 1217





In [3]:
# One flat work list for the downloader: raw, video-extraction, clean (.json) and exclude keys, in that order.
s3_files = [*pipeline_s3_files, *vid_s3_files, *clean_s3_files, *exclude_s3_files]

In [None]:
def download_and_resize(file:str, s3_client=s3_client):
    """Download one S3 object into the local mirror, resizing it if it is a ``.jpg``.

    The object is stored at ``LOCAL_BUCKET_FOLDER/<file>``. Because an existing
    local file is treated as "already done", the file is first written under a
    temporary sibling name and only moved to its final path once the download
    (and, for ``.jpg`` files, the resize to ``IMAGE_SIZE``) has succeeded.
    A failure part-way therefore never leaves a full-size or truncated file
    that later runs would silently skip.

    Args:
        file: Object key inside ``BUCKET_NAME``; also used as the path relative
            to ``LOCAL_BUCKET_FOLDER``.
        s3_client: Object providing ``download_file(bucket, key, filename)``;
            defaults to the client imported at the top of the notebook.

    Returns:
        ``False`` if the local file already existed and nothing was done,
        ``True`` once the object has been downloaded (and resized if needed).

    Raises:
        Whatever ``s3_client.download_file`` or PIL raise; in that case the
        temporary file is removed and the final path is left untouched.
    """
    local_save_folder = Path(f"{LOCAL_BUCKET_FOLDER}")
    local_img_path = local_save_folder / file
    if local_img_path.exists():
        return False
    local_img_path.parent.mkdir(parents=True, exist_ok=True)

    # Temporary sibling in the same directory (so the final move stays on one
    # filesystem). The pid keeps parallel workers from colliding, and the
    # original suffix is kept last so PIL still infers the format from the name.
    staging_path = local_img_path.with_name(
        f"{local_img_path.stem}.part{os.getpid()}{local_img_path.suffix}"
    )
    try:
        s3_client.download_file(BUCKET_NAME, file, str(staging_path))
        # NOTE(review): the check is case-sensitive and ignores '.jpeg'/'.JPG';
        # kept as-is — confirm whether those should be resized as well.
        if local_img_path.suffix == '.jpg':
            # Close the source handle before overwriting the same file.
            with Image.open(staging_path) as img:
                resized = img.resize(IMAGE_SIZE)
            resized.save(staging_path)
        # Publish the finished file in a single step.
        os.replace(staging_path, local_img_path)
    finally:
        # No-op after a successful replace; cleans up after a failure.
        staging_path.unlink(missing_ok=True)
    return True

def download_images(s3_files:list[str], num_workers:int=32):
    """Run ``download_and_resize`` over ``s3_files`` in a process pool.

    Progress is shown with ``tqdm`` and the number of files that were actually
    downloaded (i.e. for which the worker returned ``True``) is printed.

    Args:
        s3_files: Object keys to fetch.
        num_workers: Upper bound on the number of worker processes. The pool
            is never larger than the number of files and never smaller than 1.

    Returns:
        The number of files downloaded in this call.
    """
    # Don't start more processes than there is work for; Pool needs at least one.
    pool_size = max(1, min(num_workers, len(s3_files)))
    with Pool(pool_size) as p:
        r = list(tqdm(p.imap(download_and_resize, s3_files), total=len(s3_files)))
    # Each result is a bool, so the sum counts the new downloads.
    num_downloaded = sum(r)
    print(f"Number of files downloaded: {num_downloaded}")
    return num_downloaded

# Fetch everything in the combined key list; the result counts only files that were not already present locally.
num_downloaded = download_images(s3_files=s3_files)

 92%|████████████████████████████████████████████████████████████████████▉      | 44548/48493 [00:09<00:03, 1236.18it/s]