In [None]:
import os
import io
from tqdm import tqdm
from datetime import datetime

from PIL import Image
from datasets import Features, Value, ClassLabel, Dataset, concatenate_datasets, load_dataset, load_from_disk
from datasets import Image as DImage

In [None]:
# Index every (anchor, positive) image pair found on disk.
#
# Expected layout: <IMAGE_ROOT>/<category>/<post_id>/{anchor,positive}.jpg
# Each pair gets a sequential integer ``pair_id`` that keys the four dicts below.
# Only the paths are recorded here; the files themselves are not opened or checked.
IMAGE_ROOT = './kream_anchor_positive_images'

anchor_dict = {}     # pair_id -> path of the anchor image
positive_dict = {}   # pair_id -> path of the positive image
category_dict = {}   # pair_id -> integer category id (index into `labels`)
post_id_dict = {}    # pair_id -> post id (name of the pair's directory)

labels = ['bag', 'bottom', 'dress', 'hat', 'shoes', 'outer', 'top']
label2id = {l: i for i, l in enumerate(labels)}

pair_id = 0

# Sort the listings so pair_id assignment is reproducible across runs
# (os.listdir returns entries in arbitrary order), and ignore hidden entries
# such as `.DS_Store` at both directory levels.
categories = sorted(i for i in os.listdir(IMAGE_ROOT) if not i.startswith('.'))
for category in categories:
    # Raises KeyError for a directory whose name is not one of `labels`.
    category_id = label2id[category]
    post_ids = sorted(
        p for p in os.listdir(f'{IMAGE_ROOT}/{category}') if not p.startswith('.')
    )
    for post_id in tqdm(post_ids, desc=category):
        post_dir = f'{IMAGE_ROOT}/{category}/{post_id}'

        anchor_dict[pair_id] = f'{post_dir}/anchor.jpg'
        positive_dict[pair_id] = f'{post_dir}/positive.jpg'
        category_dict[pair_id] = category_id
        post_id_dict[pair_id] = post_id

        pair_id += 1

# All four indexes must describe exactly the same set of pairs.
assert len(anchor_dict) == len(positive_dict) == len(category_dict) == len(post_id_dict)
print(f'anchor: {len(anchor_dict)}')
print(f'positive: {len(positive_dict)}')
print(f'category: {len(category_dict)}')
print(f'post_id: {len(post_id_dict)}')

In [None]:
# Convert the indexed image pairs into on-disk dataset segments.
#
# Pairs are processed in ascending pair_id order, SEGMENT_SIZE at a time, and
# each batch is saved to ./kream_data_segs/seg_<n>. Pairs whose anchor or
# positive image cannot be opened are skipped and counted.
class_label = ClassLabel(names=labels)
features = Features({
    'anchor_image': DImage(decode=True),
    'positive_image': DImage(decode=True),
    'category': class_label,
    'post_id': Value('string'),
})

# Maximum number of pairs written into one segment.
SEGMENT_SIZE = 5000


def _load_as_jpeg(image_path):
    """Open ``image_path`` and return it re-encoded as an in-memory RGB JPEG.

    Returns ``None`` when the file cannot be opened or converted to RGB, so
    the caller can skip the pair instead of aborting the whole run.
    """
    try:
        # The context manager closes the source file even when decoding fails.
        with Image.open(image_path) as source_image:
            rgb_image = source_image.convert('RGB')
    except Exception:
        return None
    jpeg_buffer = io.BytesIO()
    rgb_image.save(jpeg_buffer, format='JPEG')
    jpeg_buffer.seek(0)
    # Re-opened from the buffer, so the returned image is backed by the JPEG
    # bytes rather than by the original file on disk.
    return Image.open(jpeg_buffer)


# Sort once; the key order does not change while segments are being written.
sorted_pair_ids = sorted(anchor_dict)
skipped_pair_ids = []

seg = 0
for i in range(0, len(sorted_pair_ids), SEGMENT_SIZE):
    data_list = []
    batch_pair_ids = sorted_pair_ids[i: i + SEGMENT_SIZE]
    for pair_id in tqdm(batch_pair_ids):
        anchor_image = _load_as_jpeg(anchor_dict[pair_id])
        if anchor_image is None:
            skipped_pair_ids.append(pair_id)
            continue

        positive_image = _load_as_jpeg(positive_dict[pair_id])
        if positive_image is None:
            skipped_pair_ids.append(pair_id)
            continue

        data_list.append(
            {
                'anchor_image': anchor_image,
                'positive_image': positive_image,
                'category': category_dict[pair_id],
                'post_id': post_id_dict[pair_id],
            }
        )

    if not data_list:
        # Every pair in this batch was unreadable; do not write an empty segment.
        continue

    dataset = Dataset.from_list(data_list, features=features)
    dataset.save_to_disk(f'./kream_data_segs/seg_{seg}')
    seg += 1

print(f'skipped pairs (unreadable image): {len(skipped_pair_ids)}')

In [None]:
# Merge the freshly written segments, append them to the most recent
# previously saved dataset (if any), and save the result under a new
# timestamped directory.
SEGMENT_PREFIX = 'seg_'
ROWS_PER_SHARD = 10000

# Keep only `seg_<number>` entries (ignores hidden files and other strays) and
# order them numerically so that seg_10 comes after seg_2.
seg_names = sorted(
    (
        name for name in os.listdir('./kream_data_segs')
        if name.startswith(SEGMENT_PREFIX) and name[len(SEGMENT_PREFIX):].isdigit()
    ),
    key=lambda name: int(name[len(SEGMENT_PREFIX):]),
)
sub_datasets = [load_from_disk(f'./kream_data_segs/{seg_name}') for seg_name in seg_names]

dataset = concatenate_datasets(sub_datasets)
print(f'new: {len(dataset)}')

# Earlier outputs are directories named kream_dataset_<YYYYmmddHHMMSS>, so the
# lexicographically last one is the newest. Plain files that happen to share
# the prefix are excluded because they cannot be loaded as a dataset.
prev_dirs = sorted(
    d for d in os.listdir('./')
    if d.startswith('kream_dataset') and os.path.isdir(d)
)
prev_dataset = load_from_disk(prev_dirs[-1]) if prev_dirs else None

# Explicit None check: do not rely on the truthiness of a dataset object.
if prev_dataset is not None:
    print(f'prev: {len(prev_dataset)}')
    dataset = concatenate_datasets([prev_dataset, dataset])

print(f'total: {len(dataset)}')
# Ceiling division, clamped so that an empty dataset still requests one shard
# instead of zero.
num_shards = max(1, (dataset.num_rows + ROWS_PER_SHARD - 1) // ROWS_PER_SHARD)
dataset.save_to_disk(f'./kream_dataset_{datetime.now().strftime("%Y%m%d%H%M%S")}', num_shards=num_shards)

In [None]:
import shutil

# Remove the intermediate segment directory and then the raw image directory.
# This is irreversible: run it only after the merged dataset has been saved.
for obsolete_dir in ('./kream_data_segs', './kream_anchor_positive_images'):
    shutil.rmtree(obsolete_dir)

In [None]:
# Upload the merged dataset to the Hub as a private repository.
hub_repo_id = 'yainage90/kream-fashion-anchor-positive-images'
dataset.push_to_hub(hub_repo_id, private=True)