In [None]:
import boto3
import numpy as np
import pandas as pd
from skimage.color import rgb2lab, lab2rgb
import os
from io import BytesIO
from tqdm import tqdm
from PIL import Image
import pickle
from sklearn.cluster import KMeans
import itertools

# List all images in the Miro S3 bucket

In [None]:
# Assume the read role through STS and keep the temporary credentials it returns.
sts = boto3.client("sts")
assumed_role_object = sts.assume_role(
    RoleSessionName="AssumeRoleSession1",
    RoleArn="arn:aws:iam::760097843905:role/calm-assumable_read_role",
)
credentials = assumed_role_object["Credentials"]

# Map each boto3 client keyword onto the matching field of the STS response,
# then build the S3 client that acts under the assumed role.
session_credentials = {
    client_kwarg: credentials[response_field]
    for client_kwarg, response_field in (
        ("aws_access_key_id", "AccessKeyId"),
        ("aws_secret_access_key", "SecretAccessKey"),
        ("aws_session_token", "SessionToken"),
    )
}
s3_platform = boto3.client("s3", **session_credentials)

In [None]:
# Second S3 client, created without explicit credentials (unlike s3_platform);
# used below to read and write the "model-core-data" bucket.
s3_data_science = boto3.client("s3")

In [None]:
def get_s3_keys_as_generator(bucket):
    """Generate all the keys in an S3 bucket.

    Pages through ``list_objects_v2`` on the module-level ``s3_platform``
    client, following ``NextContinuationToken`` until no further page is
    reported.

    Args:
        bucket: Name of the bucket to list.

    Yields:
        The ``"Key"`` of every object, in the order the pages return them.
        An empty bucket yields nothing.
    """
    kwargs = {"Bucket": bucket}
    while True:
        resp = s3_platform.list_objects_v2(**kwargs)
        # A page with no objects carries no "Contents" entry at all, so
        # indexing it directly would raise KeyError on an empty bucket.
        for obj in resp.get("Contents", []):
            yield obj["Key"]

        try:
            kwargs["ContinuationToken"] = resp["NextContinuationToken"]
        except KeyError:
            # No continuation token means this was the last page.
            break

In [None]:
# Bucket whose image objects are listed here (and downloaded later by get_image).
bucket_name = "wellcomecollection-miro-images-public"
# Drain the generator so the full key listing is held in memory as a list.
all_keys = list(get_s3_keys_as_generator(bucket_name))

In [None]:
# Number of object keys listed from the bucket.
len(all_keys)

# get the ids that have already been processed

In [None]:
# Upper bound (inclusive) on the chunk indices read back from
# "palette_similarity/palette_dict_{i}.pkl", and the offset added to the index
# of every newly written chunk.
# NOTE(review): presumably the highest chunk index already stored in the
# bucket, entered by hand — confirm before each run.
n_items_in_bucket = 164

In [None]:
# Read back every palette chunk with index 0..n_items_in_bucket (inclusive).
# NOTE(review): unpickling runs arbitrary code embedded in the payload; this
# is only acceptable if the "model-core-data" bucket is trusted — confirm.
palette_dicts = []
for i in tqdm(range(n_items_in_bucket + 1)):
    chunk_key = "palette_similarity/palette_dict_{}.pkl".format(i)
    try:
        binary_data = s3_data_science.get_object(
            Bucket="model-core-data",
            Key=chunk_key,
        )["Body"].read()
        palette_dict = pickle.loads(binary_data)
        palette_dicts.append(palette_dict)
    except Exception as error:
        # Still best effort: a missing or unreadable chunk is skipped, but it
        # is reported instead of vanishing silently. Catching Exception rather
        # than using a bare except lets Ctrl-C interrupt the loop.
        tqdm.write("skipping {}: {!r}".format(chunk_key, error))

In [None]:
# Flatten the per-chunk dictionaries into a single mapping; when an id appears
# in more than one chunk, the value from the later chunk is kept.
palette_dict = dict(
    itertools.chain.from_iterable(chunk.items() for chunk in palette_dicts)
)

len(palette_dict)

In [None]:
# Ids that already have a stored palette; a set gives constant-time membership tests.
already_processed_ids = set(palette_dict)

In [None]:
def id_from_object_key(object_key):
    """Return the file name of an object key with its last extension removed.

    Any leading path components are dropped; only the final extension is
    stripped (``"a/b.tar.gz"`` becomes ``"b.tar"``).
    """
    file_name = os.path.basename(object_key)
    return os.path.splitext(file_name)[0]

In [None]:
# Keep only the keys whose image id has no stored palette yet, preserving the
# order of all_keys.
not_yet_processed_keys = list(
    filter(
        lambda key: id_from_object_key(key) not in already_processed_ids,
        all_keys,
    )
)

In [None]:
# Number of images that still need a palette.
len(not_yet_processed_keys)

# Compute colour palettes for the images not yet processed

In [None]:
def get_image(object_key):
    """Download one image and return it as a 75x75 RGB PIL image.

    The object is fetched from ``bucket_name`` through the module-level
    ``s3_platform`` client, converted to RGB when it is in any other mode,
    and resized with bilinear resampling.
    """
    response = s3_platform.get_object(Bucket=bucket_name, Key=object_key)
    raw_bytes = response["Body"].read()
    picture = Image.open(BytesIO(raw_bytes))
    rgb_picture = picture if picture.mode == "RGB" else picture.convert("RGB")
    return rgb_picture.resize((75, 75), resample=Image.BILINEAR)


def get_palette(image, palette_size=5, random_state=None):
    """Extract a colour palette from an image by k-means clustering in Lab space.

    Args:
        image: RGB image accepted by ``np.array`` (e.g. the PIL image
            returned by ``get_image``).
        palette_size: Number of clusters, i.e. colours in the palette.
        random_state: Forwarded to ``KMeans``. The default ``None`` keeps the
            original unseeded behaviour; pass an int to make the palette
            reproducible between runs.

    Returns:
        Array of shape ``(palette_size, 3)`` holding the cluster centres in
        Lab coordinates.
    """
    # Flatten the image to one row per pixel, three Lab channels per row.
    lab_image = rgb2lab(np.array(image)).reshape(-1, 3)
    clusters = KMeans(n_clusters=palette_size, random_state=random_state).fit(
        lab_image
    )
    return clusters.cluster_centers_

In [None]:
# Compute a palette for every outstanding image, uploading the accumulated
# results as a pickle each time the loop index reaches a multiple of
# chunk_size. Palettes gathered after the last such upload remain in
# palette_dict when the loop ends; this loop does not upload them.
chunk_size = 1000
palette_dict = {}
# (object_key, repr(error)) for every image that could not be processed.
failed_keys = []

for i, object_key in enumerate(tqdm(not_yet_processed_keys)):
    try:
        image = get_image(object_key)
        image_id = id_from_object_key(object_key)
        palette_dict[image_id] = get_palette(image)
    except Exception as error:
        # Still best effort: one bad image must not stop the run, but the
        # failure is recorded rather than discarded. Catching Exception rather
        # than using a bare except lets Ctrl-C interrupt the loop.
        failed_keys.append((object_key, repr(error)))

    if (i % chunk_size == 0) and (i != 0):
        s3_data_science = boto3.client("s3")
        s3_data_science.put_object(
            Bucket="model-core-data",
            # Offset by n_items_in_bucket so new chunk names continue after
            # the indices that were read back earlier.
            Key="palette_similarity/palette_dict_{}.pkl".format(
                (i // chunk_size) + n_items_in_bucket
            ),
            Body=pickle.dumps(palette_dict),
        )
        # Start a fresh chunk so each upload holds only new palettes.
        palette_dict = {}

if failed_keys:
    tqdm.write(
        "{} of {} images could not be processed; see failed_keys".format(
            len(failed_keys), len(not_yet_processed_keys)
        )
    )

# save the data