In [1]:
import pandas as pd
from collections import defaultdict
from tqdm import *
import numpy as np

def make_category_tables(df=None):
    """Build lookup tables between raw category ids and dense category indices.

    Args:
        df: DataFrame indexed by category_id whose 4th data column (field 4 of
            each ``itertuples`` row, after the index) holds the dense category
            index. Defaults to the module-level ``categories_df`` so existing
            zero-argument calls keep working.

    Returns:
        (cat2idx, idx2cat): a dict mapping category_id -> category_idx and the
        inverse dict mapping category_idx -> category_id.
    """
    if df is None:
        df = categories_df
    cat2idx = {}
    idx2cat = {}
    for ir in df.itertuples():
        category_id = ir[0]   # the DataFrame index
        category_idx = ir[4]  # 4th data column (positional)
        cat2idx[category_id] = category_idx
        idx2cat[category_idx] = category_id
    return cat2idx, idx2cat

def make_val_set(df, split_percentage=0.2, drop_percentage=0., category_map=None):
    """Split products into per-image train/validation rows, stratified by category.

    Whole products (not individual images) are assigned to one side, so every
    image of a product ends up in the same split. Sampling is done per category
    with NumPy's global RNG (``np.random.choice``); seed it beforehand if a
    reproducible split is needed.

    Args:
        df: DataFrame indexed by product_id with a ``num_imgs`` column and the
            category_id in field 4 of each ``itertuples`` row (positional;
            TODO confirm against the offsets file's column order).
        split_percentage: Fraction of the kept products in each category that
            go to the validation set (rounded down per category).
        drop_percentage: Fraction of products in each category to discard first,
            to shrink the dataset; ``int(n * (1 - drop_percentage))`` are kept.
        category_map: Mapping category_id -> category_idx. Defaults to the
            module-level ``cat2idx``.

    Returns:
        (train_df, val_df): DataFrames with columns
        ["product_id", "category_idx", "img_idx"], one row per image.
    """
    if category_map is None:
        category_map = cat2idx

    # Find the product_ids for each category.
    category_dict = defaultdict(list)
    for ir in tqdm(df.itertuples()):
        category_dict[ir[4]].append(ir[0])

    train_list = []
    val_list = []
    with tqdm(total=len(df)) as pbar:
        for category_id, all_ids in category_dict.items():
            category_idx = category_map[category_id]
            product_ids = all_ids

            # Randomly remove products to make the dataset smaller.
            keep_size = int(len(product_ids) * (1. - drop_percentage))
            if keep_size < len(product_ids):
                product_ids = np.random.choice(product_ids, keep_size, replace=False)

            # Randomly choose the products that become part of the validation set.
            # A set gives O(1) membership; `x in ndarray` scans the whole array.
            val_size = int(len(product_ids) * split_percentage)
            if val_size > 0:
                val_ids = set(np.random.choice(product_ids, val_size, replace=False))
            else:
                val_ids = set()

            # Create a new row for each image; the split is decided once per product.
            for product_id in product_ids:
                target = val_list if product_id in val_ids else train_list
                row = [product_id, category_idx]
                for img_idx in range(df.loc[product_id, "num_imgs"]):
                    target.append(row + [img_idx])

            # Count dropped products too, so the bar reaches its total of len(df).
            pbar.update(len(all_ids))

    columns = ["product_id", "category_idx", "img_idx"]
    train_df = pd.DataFrame(train_list, columns=columns)
    val_df = pd.DataFrame(val_list, columns=columns)
    return train_df, val_df

# Category lookup tables: category_id <-> dense category_idx.
categories_df = pd.read_csv("categories.csv", index_col=0)
cat2idx, idx2cat = make_category_tables()

# Spot-check the dense index of one known category id.
print(cat2idx[1000005633])

# Build a small dev split: keep roughly 10% of the products in each category,
# then hold out 20% of those for validation.
train_offsets_df = pd.read_csv("train_offsets.csv", index_col=0)
dev_train_images_df, dev_val_images_df = make_val_set(
    train_offsets_df,
    split_percentage=0.2,
    drop_percentage=0.9,
)

for _label, _frame in (
    ("Number of dev train images:", dev_train_images_df),
    ("Number of dev val images:", dev_val_images_df),
):
    print(_label, len(_frame))

for _frame, _path in (
    (dev_train_images_df, "dev_train_images.csv"),
    (dev_val_images_df, "dev_val_images.csv"),
):
    _frame.to_csv(_path)

619


  mask |= (ar1 == a)
7069896it [00:09, 769306.78it/s]
 10%|▉         | 704102/7069896 [00:31<04:43, 22436.43it/s]


Number of dev train images: 989127
Number of dev val images: 242091
