In [37]:
import os
from collections import defaultdict
import numpy as np
import pandas as pd
DATADIR = "/home/robincheong/sbir/data/eitz2012/sketches/"
OUTPUTDIR = "/home/robincheong/sbir/data/eitz2012/"

In [41]:
def save_split_fps_txt(outputdir, split, data):
    ''' 
        Saves the train / val / test split filepaths into a .txt file for later use 
        Args:
            outputdir: path to the data folder in which the .txt files will be stored
            split: the split (train / val / test)
            data: the data to be stored
    '''
    with open(PREFIX + f"{split}.txt", 'w') as fp:
        for item in data:
            fp.write(f"{item}\n")
            
            
def get_categories(data_fps, category_map):
    '''
        Return a list of labels for each file in a list
        Args:
            data: list containing the files to match to categories
            category_map: a map matching the string of a category to an int representation
    '''
    labels = []
    for i, file in enumerate(data_fps):
        label = file.split('/')[0]
        label_val = category_map[label]
        labels.append(label_val)
    return labels


def save_csv(data_fps, labels, outputdir, name):
    '''
        Saves the data + labels into outputdir
    '''
    df = pd.DataFrame(data={"filepath": data_fps, "label": labels})
    df.to_csv(f"{outputdir}/{name}.csv", sep=',', index=False)
    

In [42]:
## Create test set
test_set_fps = []
np.random.seed(42)
for catdir in os.listdir(DATADIR):
    sketches = os.listdir(DATADIR + catdir)
    chosen = np.random.choice(sketches, size=10, replace=False)
    chosen = list(map(lambda x: catdir + "/" + x, chosen))
    test_set_fps += chosen

test_set_labels = get_categories(test_set_fps, labels_map)
save_csv(test_set_fps, test_set_labels, OUTPUTDIR, "test")
print(len(test_set_labels))

2500


In [46]:
## Construct validation set
val_set_fps = []
np.random.seed(42)
for catdir in os.listdir(DATADIR):    
    sketches = os.listdir(DATADIR + catdir)
    sketches = [x for x in sketches if catdir + "/" + x not in test_set_fps]
    chosen = np.random.choice(sketches, size=10, replace=False)
    chosen = list(map(lambda x: catdir + "/" + x, chosen))
    val_set_fps += chosen

print(len(val_set_fps))

2500


In [47]:
print(set(val_set_fps) & set(test_set_fps))

set()


In [48]:
val_set_labels = get_categories(val_set_fps, labels_map)
save_csv(val_set_fps, val_set_labels, OUTPUTDIR, "val")

In [50]:
## Construct train set
train_set_fps = []
np.random.seed(42)
for catdir in os.listdir(DATADIR):    
    sketches = os.listdir(DATADIR + catdir)
    sketches = [x for x in sketches if catdir + "/" + x not in test_set_fps and catdir + "/" + x not in val_set_fps]
    chosen = list(map(lambda x: catdir + "/" + x, sketches))
    train_set_fps += chosen

print(len(train_set_fps))

15000


In [51]:
print(set(val_set_fps) & set(test_set_fps) & set(train_set_fps))

set()


In [52]:
train_set_labels = get_categories(train_set_fps, labels_map)
save_csv(train_set_fps, train_set_labels, OUTPUTDIR, "train")

In [53]:
labels = list(set(map(lambda x: x.split('/')[0], test_set_fps)))

In [54]:
labels_map = {label: val for val, label in enumerate(labels)}

In [55]:
print(labels_map)

{'violin': 0, 'satellite': 1, 'beer-mug': 2, 'submarine': 3, 'bear (animal)': 4, 'dog': 5, 'diamond': 6, 'mailbox': 7, 'rainbow': 8, 'giraffe': 9, 'umbrella': 10, 'house': 11, 'ship': 12, 'fire hydrant': 13, 'hamburger': 14, 'bottle opener': 15, 'mouth': 16, 'octopus': 17, 'palm tree': 18, 'eye': 19, 'chair': 20, 'computer-mouse': 21, 'hot-dog': 22, 'arm': 23, 'pig': 24, 'cake': 25, 'saxophone': 26, 'tent': 27, 'spider': 28, 'tennis-racket': 29, 'nose': 30, 't-shirt': 31, 'comb': 32, 'scissors': 33, 'train': 34, 'door handle': 35, 'snail': 36, 'potted plant': 37, 'teacup': 38, 'monkey': 39, 'lightbulb': 40, 'crocodile': 41, 'chandelier': 42, 'cactus': 43, 'harp': 44, 'ear': 45, 'bicycle': 46, 'mug': 47, 'tomato': 48, 'hourglass': 49, 'butterfly': 50, 'head-phones': 51, 'computer monitor': 52, 'frying-pan': 53, 'windmill': 54, 'horse': 55, 'flower with stem': 56, 'angel': 57, 'hand': 58, 'floor lamp': 59, 'truck': 60, 'moon': 61, 'squirrel': 62, 'table': 63, 'foot': 64, 'parrot': 65, 'm