In [1]:
import numpy as np
import pandas as pd
import json
import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import typing as tp

from PIL import Image

from parse_deepfashion import Crop, read_splits

In [2]:
# Root of the DeepFashion2 dataset: passed to read_splits() and used as the directory
# the triplets CSV is written to. Note the trailing slash. show_crops() does not use
# this constant; it reads "train/..." relative to the working directory.
DEEP_FASHION_DIR = "/mnt/data/datasets/deepfashion2/"

In [3]:
def show_crops(crops: tp.List[Crop], root: str = ""):
    """Plot up to nine crops' source images in a 3x3 grid, each with one bounding box.

    For every crop, the full image ``train/image/<image_id>.jpg`` and its annotation
    ``train/annos/<image_id>.json`` are read relative to ``root``. The box drawn is
    that of the first annotated item (``item1``, ``item2``, ...) whose ``style`` is
    non-zero. The box is coloured by style (1 red, 2 green, 3 blue, 4 cyan, any
    other style yellow). Images without such an item are shown without a box.

    Args:
        crops: crops to visualise; only the first nine are plotted.
        root: directory containing the ``train/`` folder. The default "" keeps the
            paths relative to the current working directory.
    """
    import os

    columns, rows = 3, 3
    # Unknown styles fall back to yellow instead of raising KeyError.
    colors = {1: 'r', 2: 'g', 3: 'b', 4: 'c'}
    fig = plt.figure(figsize=(8, 8))
    for i, crop in enumerate(crops[:columns * rows], start=1):
        id_ = crop.image_id
        # Context manager closes the underlying file once the pixels are copied.
        with Image.open(os.path.join(root, 'train', 'image', f'{id_}.jpg')) as pil_img:
            img = np.array(pil_img, dtype=np.uint8)
        with open(os.path.join(root, 'train', 'annos', f'{id_}.json'), "r") as f:
            ann = json.load(f)

        # Items are keyed item1, item2, ... consecutively; take the first styled one.
        styled_item = None
        bbox_id = 1
        while f"item{bbox_id}" in ann:
            item = ann[f"item{bbox_id}"]
            if item['style'] != 0:
                styled_item = item
                break
            bbox_id += 1

        fig.add_subplot(rows, columns, i)
        plt.imshow(img)
        if styled_item is not None:
            # bounding_box is read as [x1, y1, x2, y2].
            bbox = styled_item['bounding_box']
            rect = patches.Rectangle(
                (bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                linewidth=1,
                edgecolor=colors.get(styled_item['style'], 'y'),
                facecolor='none',
            )
            plt.gca().add_patch(rect)
        plt.axis('off')
    plt.show()

def show_triplet(*triplet):
    """Show the given image files side by side in a single row.

    Intended for an (anchor, positive, negative) triplet of crop paths, but accepts
    any number of paths. The row always has at least three slots, so a triplet keeps
    its original layout, and more than three paths no longer overflow the grid.

    Args:
        *triplet: paths of the images to display, left to right.
    """
    columns = max(3, len(triplet))
    fig = plt.figure(figsize=(8, 8))
    for position, path in enumerate(triplet, start=1):
        # Close the file handle PIL keeps open once the pixels are copied.
        with Image.open(path) as pil_img:
            img = np.array(pil_img, dtype=np.uint8)
        fig.add_subplot(1, columns, position)
        plt.imshow(img)
        plt.axis('off')
    plt.show()

In [4]:
def generate_triplets(item_data, num_triplets: int, min_crops: int = 5,
                      same_item_prob=0.0,
                      same_style_prob=0.0,
                      neg_same_category_prob=0.0,
                      seed: tp.Optional[int] = None,
                      max_attempts: int = 1000):
    """Sample (anchor, positive, negative) crop-file triplets.

    Only items with at least ``min_crops`` crops take part. Each triplet is drawn
    in one of two modes.

    * Same-item mode (probability ``same_item_prob``):
      - Anchor and positive are two different crops of the same item, sharing
        ``category_id``.
      - With probability ``same_style_prob`` the positive must also share the
        anchor's ``item_style``.
      - With probability ``neg_same_category_prob`` the negative comes from a
        different item of the anchor item's category (a hard negative); otherwise
        it comes from any other item.
    * Cross-item mode:
      - The positive is a crop of a different item whose first crop has the same
        ``category_id`` as the anchor item's first crop.
      - The negative comes from an item that is neither the anchor's nor the
        positive's.

    Draws that cannot satisfy their constraints are retried.

    Args:
        item_data: mapping of item id -> sequence of crops. Each crop must expose
            ``category_id``, ``item_style`` and ``crop_file``.
        num_triplets: number of triplets to produce.
        min_crops: minimum number of crops for an item to be eligible.
        same_item_prob, same_style_prob, neg_same_category_prob: see above.
        seed: if given, sampling uses a private ``np.random.RandomState(seed)``
            and is reproducible. By default the global NumPy RNG is used.
        max_attempts: sampling attempts allowed per triplet before giving up.

    Returns:
        Tuple ``(anchors, positives, negatives)`` of equally long lists of
        ``crop_file`` values.

    Raises:
        ValueError: if fewer than two items have at least ``min_crops`` crops.
        RuntimeError: if a triplet could not be formed within ``max_attempts`` draws.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    item_ids = [iid for iid, crops in item_data.items() if len(crops) >= min_crops]
    if len(item_ids) < 2:
        raise ValueError(
            f"need at least 2 items with >= {min_crops} crops, found {len(item_ids)}")

    # Index eligible items by category once, instead of rescanning every item for
    # every triplet.
    items_by_category = {}
    for iid in item_ids:
        items_by_category.setdefault(item_data[iid][0].category_id, []).append(iid)

    def pick(seq):
        # Index-based uniform choice: works for any element type, whereas
        # np.random.choice first converts the sequence into an array.
        return seq[rng.randint(len(seq))]

    def other_items_in_category(iid):
        category_id = item_data[iid][0].category_id
        return [other for other in items_by_category[category_id] if other != iid]

    def sample_once():
        """Make one sampling attempt; return (anch, pos, neg) crops, or None if rejected."""
        same_item = rng.random() < same_item_prob
        hard_negative = rng.random() < neg_same_category_prob
        if same_item and hard_negative:
            anch_item = pick(item_ids)
            negatives = other_items_in_category(anch_item)
            if not negatives:
                return None
            neg_item = pick(negatives)
        else:
            i, j = rng.choice(len(item_ids), size=2, replace=False)
            anch_item, neg_item = item_ids[i], item_ids[j]

        anch_crops = item_data[anch_item]
        anch = pick(anch_crops)
        if same_item:
            candidates = [c for c in anch_crops
                          if c.category_id == anch.category_id and c != anch]
            if rng.random() < same_style_prob:
                candidates = [c for c in candidates if c.item_style == anch.item_style]
            if not candidates:
                return None
            pos = pick(candidates)
        else:
            pos_items = other_items_in_category(anch_item)
            if not pos_items:
                return None
            pos_item = pick(pos_items)
            # A negative drawn from the positive's own item would label the same
            # item both positive and negative.
            if pos_item == neg_item:
                return None
            pos = pick(item_data[pos_item])
        neg = pick(item_data[neg_item])
        return anch, pos, neg

    anch_list, pos_list, neg_list = [], [], []
    for _ in tqdm.trange(num_triplets):
        for _attempt in range(max_attempts):
            triplet = sample_once()
            if triplet is not None:
                break
        else:
            raise RuntimeError(
                f"could not form a valid triplet in {max_attempts} attempts; "
                "check the sampling probabilities against the data")
        anch, pos, neg = triplet
        anch_list.append(anch.crop_file)
        pos_list.append(pos.crop_file)
        neg_list.append(neg.crop_file)
    return anch_list, pos_list, neg_list


def triplets_to_csv(triplets, path, tvt_split=None, seed: tp.Optional[int] = None):
    """Write triplets to CSV, randomly assigning each row to train/val(/test).

    Each row draws ``u ~ U[0, 1)``. The split boundaries are the cumulative sums of
    all fractions except the last. A row goes to the first split whose boundary is
    ``>= u``; rows beyond every boundary go to the last listed split. The last split
    therefore absorbs any remainder, and its own fraction is not used.

    Args:
        triplets: ``(anchors, positives, negatives)`` sequences of equal length, such
            as the return value of ``generate_triplets``.
        path: destination CSV path (written without the index).
        tvt_split: two or three fractions ``[train, val(, test)]``. Defaults to
            ``[0.7, 0.15, 0.15]``.
        seed: optional seed for a private RNG, giving a reproducible split. By
            default the global NumPy RNG is used.

    Raises:
        ValueError: if ``tvt_split`` does not have two or three entries.
    """
    if tvt_split is None:
        tvt_split = [0.7, 0.15, 0.15]
    if len(tvt_split) not in (2, 3):
        raise ValueError(f"tvt_split must have 2 or 3 fractions, got {len(tvt_split)}")
    rng = np.random if seed is None else np.random.RandomState(seed)
    df = pd.DataFrame({
        "anchor": triplets[0],
        "positive": triplets[1],
        "negative": triplets[2],
    })
    u = rng.random(size=len(df))
    split_names = np.array(["train", "val", "test"][:len(tvt_split)], dtype=object)
    boundaries = np.cumsum(tvt_split[:-1])
    # side="left" keeps "u <= boundary" in the earlier split, as the original "<=" tests did.
    df["split"] = split_names[np.searchsorted(boundaries, u, side="left")]
    df.to_csv(path, index=False)

In [5]:
# Load crop metadata for the "train" split only. The result is consumed by
# generate_triplets() as a mapping of item id -> list of crops.
# NOTE(review): the meaning of the third positional argument (False) is defined in
# parse_deepfashion.read_splits — confirm and consider passing it by keyword.
items_data = read_splits(DEEP_FASHION_DIR, ["train"], False)

Reading train split


  0%|          | 0/191961 [00:00<?, ?it/s]

100%|██████████| 191961/191961 [00:27<00:00, 6978.52it/s] 


In [7]:
# Sample 400k triplets from items with >= 5 crops:
# - 60% are same-item triplets, always style-matched;
# - 30% of those same-item triplets use a same-category hard negative.
# The triplets are saved with a 90/10 train/val split.
triplets = generate_triplets(
    items_data,
    num_triplets=400_000,
    min_crops=5,
    same_item_prob=0.6,
    same_style_prob=1.0,
    neg_same_category_prob=0.3,
)
triplets_to_csv(triplets, f"{DEEP_FASHION_DIR}/triplets_400k.csv", tvt_split=[0.9, 0.1])

100%|██████████| 400000/400000 [21:02<00:00, 316.83it/s]
