In [21]:
import os
from collections import defaultdict
import numpy as np
PREFIX = "/home/robincheong/sbir/data/sketchy/"
FILE_NAMES=["invalid-ambiguous.txt", "invalid-error.txt", "invalid-pose.txt"]


In [79]:
## Remove bad files from dataset
invalid = set()

# function to prepare the file name for removal
def prep_fn(fn):
    fn = fn[:-1]
    fn += ".png"
    return fn
    
# read in text files and add to a set 
for file_name in FILE_NAMES:
    with open(PREFIX + 'info/' + file_name, 'r') as f:
        files = f.readlines()
        files = map(prep_fn, files)
        invalid |= set(files)

In [None]:
TRANSFORMS = ["tx_000000000010/", "tx_000000000110/", "tx_000000001010/", "tx_000000001110/", "tx_000100000000/"]
sketch_datadir = PREFIX + "sketch/"
for transform in TRANSFORMS:
    datadir = sketch_datadir + transform

    for sketchdir in os.listdir(datadir):
        print(f"Walking through {sketchdir}...")
        for file in os.listdir(datadir + sketchdir):
            if file in invalid:
                print(f"Removing file: {file}")
                os.remove(datadir + sketchdir + "/" + file)

In [64]:
## Create test set
with open(PREFIX + "testset.txt", 'r') as f:
    test_set = f.readlines()
    test_set = set(map(lambda x: x[:-5], test_set))

In [None]:
sketches_per_image = defaultdict(int)
## ensure there's > 5 sketches to test on for each image after removing invalids
datadir = PREFIX + "sketch/tx_000000000010/"
for sketchdir in os.listdir(datadir):
    print(f"Walking through {sketchdir}...")
    for file in os.listdir(datadir + sketchdir):
        image_for_sketch = sketchdir + "/" + file[:-6]
        if image_for_sketch in test_set:
            sketches_per_image[image_for_sketch] += 1
            
print(sketches_per_image)

In [14]:
total = 0
for key in sketches_per_image:
    total += sketches_per_image[key]
    if sketches_per_image[key] < 5:
        print(key)
print(total)

6340


In [15]:
max(sketches_per_image.keys(), key=(lambda key: sketches_per_image[key]))

'cannon/n02950826_5567'

In [112]:
sketches_per_image['cannon/n02950826_5567']

9

In [72]:
## construct cat => files mapping
img_per_cat = defaultdict(list)

for file in test_set:
    cat, fn = file.split("/")
    img_per_cat[cat].append(fn + ".jpg")

In [None]:
## Construct validation set
datadir = PREFIX + "photo/tx_000000000000/"
val_set = []
np.random.seed(42)
for photodir in os.listdir(datadir):
    print(f"Walking through {photodir}...")
    
    photos_in_cat = os.listdir(datadir + photodir)
    photos_in_cat = [x for x in photos_in_cat if x not in img_per_cat[photodir]]
    cat_photos = np.random.choice(photos_in_cat, size=10, replace=False)
    cat_photos = list(map(lambda x: photodir + "/" + x[:-4], cat_photos))
    val_set += cat_photos

print(val_set)

In [77]:
os.remove(PREFIX + "valset.txt")
with open(PREFIX + "valset.txt", 'w') as fp:
    for item in val_set:
        fp.write(f"{item}\n")

In [78]:
print(set(val_set) & set(test_set))

set()


In [None]:
print(val_set)

In [None]:
sketches_per_image_val = defaultdict(int)
## ensure there's > 5 sketches to test on for each image after removing invalids
datadir = PREFIX + "sketch/tx_000000000000/"
for sketchdir in os.listdir(datadir):
    print(f"Walking through {sketchdir}...")
    for file in os.listdir(datadir + sketchdir):
        image_for_sketch = sketchdir + "/" + file.split("-")[0].split("/")[0]
        if image_for_sketch in val_set:
            sketches_per_image_val[image_for_sketch] += 1
            
print(len(sketches_per_image_val))

In [84]:
total = 0
for key in sketches_per_image_val:
    total += sketches_per_image_val[key]
    if sketches_per_image_val[key] < 5:
        print(key)
print(total)

6615


In [None]:
img_per_cat_val = defaultdict(list)

for file in val_set:
    cat, fn = file.split("/")
    img_per_cat_val[cat].append(fn + ".jpg")
    
print(img_per_cat_val)

In [None]:
datadir = PREFIX + "photo/tx_000000000000/"
train_set = []
np.random.seed(42)
for photodir in os.listdir(datadir):
    print(f"Walking through {photodir}...")
    
    photos_in_cat = os.listdir(datadir + photodir)
    cat_photos = [x for x in photos_in_cat if x not in img_per_cat[photodir] and x not in img_per_cat_val[photodir]]
    cat_photos = list(map(lambda x: photodir + "/" + x[:-4], cat_photos))
    train_set += cat_photos

print(train_set)

In [96]:
print(set(val_set) & set(test_set) & set(train_set))

set()


In [97]:
print(len(train_set))

10001


In [94]:
sketches_per_image_tr = defaultdict(int)
## ensure there's > 5 sketches to test on for each image after removing invalids
datadir = PREFIX + "sketch/tx_000000000000/"
for sketchdir in os.listdir(datadir):
    print(f"Walking through {sketchdir}...")
    for file in os.listdir(datadir + sketchdir):
        image_for_sketch = sketchdir + "/" + file.split("-")[0].split("/")[0]
        if image_for_sketch in val_set:
            sketches_per_image_tr[image_for_sketch] += 1
            
print(len(sketches_per_image_tr))

Walking through pretzel...
Walking through bee...
Walking through jellyfish...
Walking through crab...
Walking through rifle...
Walking through bat...
Walking through cannon...
Walking through sea_turtle...
Walking through violin...
Walking through zebra...
Walking through turtle...
Walking through elephant...
Walking through horse...
Walking through scissors...
Walking through racket...
Walking through sheep...
Walking through wheelchair...
Walking through window...
Walking through frog...
Walking through tree...
Walking through fish...
Walking through bear...
Walking through parrot...
Walking through deer...
Walking through airplane...
Walking through mouse...
Walking through hamburger...
Walking through armor...
Walking through bicycle...
Walking through chicken...
Walking through flower...
Walking through couch...
Walking through kangaroo...
Walking through candle...
Walking through crocodilian...
Walking through songbird...
Walking through bell...
Walking through motorcycle...
Wal

In [101]:
total = 0
for key in sketches_per_image_tr:
    total += sketches_per_image_tr[key]
    if sketches_per_image_tr[key] < 5:
        print(key)
print(total)

6615


In [104]:
os.remove(PREFIX + "trainset.txt")
with open(PREFIX + "trainset.txt", 'w') as fp:
    for item in train_set:
        fp.write(f"{item}\n")

In [106]:
## Create test set
with open(PREFIX + "testset.txt", 'r') as f:
    testset = f.readlines()
    testset = set(map(lambda x: x[:-5], testset))

In [107]:
## Create test set
with open(PREFIX + "valset.txt", 'r') as f:
    valset = f.readlines()
    valset = set(map(lambda x: x[:-5], valset))

In [108]:
## Create test set
with open(PREFIX + "trainset.txt", 'r') as f:
    trainset = f.readlines()
    trainset = set(map(lambda x: x[:-5], trainset))

In [109]:
print(set(val_set) & set(test_set) & set(train_set))

set()


In [113]:
print(train_set)

['pretzel/n07695742_10567', 'pretzel/n07695742_9827', 'pretzel/n07695742_10000', 'pretzel/n07695742_1795', 'pretzel/n07695742_10766', 'pretzel/n07695742_8015', 'pretzel/n07695742_1602', 'pretzel/n07695742_6804', 'pretzel/n07695742_7948', 'pretzel/n07695742_4484', 'pretzel/n07695742_6230', 'pretzel/n07695742_4147', 'pretzel/n07695742_2788', 'pretzel/n07695742_1250', 'pretzel/n07695742_4492', 'pretzel/n07695742_10321', 'pretzel/n07695742_4375', 'pretzel/n07695742_3267', 'pretzel/n07695742_6663', 'pretzel/n07695742_8237', 'pretzel/n07695742_6580', 'pretzel/n07695742_10303', 'pretzel/n07695742_4616', 'pretzel/n07695742_2091', 'pretzel/n07695742_3371', 'pretzel/n07695742_1032', 'pretzel/n07695742_2935', 'pretzel/n07695742_6496', 'pretzel/n07695742_2899', 'pretzel/n07695742_967', 'pretzel/n07695742_4705', 'pretzel/n07695742_4075', 'pretzel/n07695742_8800', 'pretzel/n07695742_2052', 'pretzel/n07695742_11167', 'pretzel/n07695742_4065', 'pretzel/n07695742_4256', 'pretzel/n07695742_11573', 'pret