# Setup

In [1]:
import os
import PIL.Image
import numpy as np
import pickle

In [2]:
def unpickle(file):
    """Load and return the object pickled in the file at path ``file``.

    ``encoding='bytes'`` makes 8-bit string instances pickled by Python 2
    come back as ``bytes`` objects.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(file, mode='rb') as handle:
        return pickle.load(handle, encoding='bytes')

In [3]:
def repickle(res, file):
    """Pickle ``res`` and write it to the path ``file``, overwriting any existing file.

    Counterpart of ``unpickle``. The file is opened in a ``with`` block so it is
    flushed and closed deterministically, even if pickling fails part-way
    (previously the handle was left for the garbage collector to close).
    """
    with open(file, "wb") as fo:
        pickle.dump(res, fo)

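# Training set

Build a class-balanced subset of the classes 0–9 training list (the same number of images per class) and pickle it in two batches.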

In [4]:
# Training list (classes 0-9 per its file name). Only the first space-separated
# field of each line -- the image id, e.g. "flickr/q0001/2680197397.jpg" -- is used.
cluster_list = './curriculum_clustering/input/webvision_cls0-9.txt'  # webvision train list
with open(cluster_list) as f:
    metadata = [line.strip().split(' ') for line in f]

file_list = [fields[0] for fields in metadata]
file_set = set(file_list)  # fast membership test when scanning the image folders

print(file_list[0])
print("# file_set:", len(file_set))

flickr/q0001/2680197397.jpg
# file_set: 25744


In [5]:
# Scan the resized Flickr and Google folders and load every image whose id
# ("<source>/<query>/<file>") appears in the training list `file_set`.
source_folders = ["./flickr_resized_256/flickr/", "./google_resized_256/google/"]
X = []
img_ids = []
for folder in source_folders:
    # os.walk yields `folder` itself first; the images live in its sub-directories.
    # Directories and files are visited in sorted order because os.walk/os.listdir
    # order is filesystem-dependent, and the seeded sampling done later in the
    # notebook should not depend on it.
    sub_dirs = [dirpath for dirpath, _, _ in os.walk(folder)][1:]
    for fld in sorted(sub_dirs):
        for fname in sorted(os.listdir(fld)):
            path = os.path.join(fld, fname)
            # The id is the last three path components; normalise separators so
            # this also works where os.sep is not "/".
            img_id = "/".join(path.replace(os.sep, "/").split("/")[-3:])
            if img_id not in file_set:
                continue
            # NOTE(review): images keep their stored mode, so grayscale/RGBA files
            # yield arrays with a different channel layout -- confirm downstream
            # code tolerates that.
            # The context manager closes the file handle once the pixels are copied.
            with PIL.Image.open(path) as img:
                X.append(np.array(img))
            img_ids.append(img_id)

    print("# X:", len(X))

# X: 14502
# X: 25744


In [10]:
# Query ids (the "qNNNN" directory in each image id) grouped by class label '0'-'9'.
class_to_query = {
    '0': ['q0001', 'q0002'],
    '1': ['q0003', 'q0004'],
    '2': ['q0005', 'q0006', 'q0007'],
    '3': ['q0008', 'q0009'],
    '4': ['q0010'],
    '5': ['q0011', 'q0012', 'q0013'],
    '6': ['q0014'],
    '7': ['q0015', 'q0016'],
    '8': ['q0017'],
    '9': ['q0018', 'q0019'],
}

# Inverse mapping query id -> class label (each query is listed under exactly one class).
query_to_class = {
    query_id: label
    for label, query_ids in class_to_query.items()
    for query_id in query_ids
}

print(query_to_class)

{'q0001': '0', 'q0002': '0', 'q0003': '1', 'q0004': '1', 'q0005': '2', 'q0006': '2', 'q0007': '2', 'q0008': '3', 'q0009': '3', 'q0010': '4', 'q0011': '5', 'q0012': '5', 'q0013': '5', 'q0014': '6', 'q0015': '7', 'q0016': '7', 'q0017': '8', 'q0018': '9', 'q0019': '9'}


In [12]:
# Group the loaded image ids by class label. The query directory is the second
# component of "<source>/<query>/<file>" and determines the class.
class_to_imgid = {}
for img_id in img_ids:
    query_dir = img_id.split('/')[1]
    class_to_imgid.setdefault(query_to_class[query_dir], []).append(img_id)

In [20]:
import random

# Draw the same number of ids from every class so the sample is class-balanced.
# 1372 equals the smallest per-class count this cell printed for the current data.
SAMPLES_PER_CLASS = 1372

random.seed(42)
imgid_set = set()
# Iterate classes in sorted order so the sequence of RNG draws does not depend on
# the order in which images happened to be loaded.
for clss in sorted(class_to_imgid):
    val = class_to_imgid[clss]
    print("# val:", len(val))
    assert len(val) >= SAMPLES_PER_CLASS, "class %s has only %d images" % (clss, len(val))
    # Shuffle a sorted *copy*: sorting makes the seeded draw independent of the
    # filesystem listing order, and leaving class_to_imgid untouched means that
    # re-running this cell reproduces the same sample instead of reshuffling
    # already-shuffled lists.
    candidates = sorted(val)
    random.shuffle(candidates)
    imgid_set.update(candidates[:SAMPLES_PER_CLASS])

print("# imgid_set:", len(imgid_set))

# val: 1738
# val: 2612
# val: 3589
# val: 2602
# val: 1372
# val: 2006
# val: 2131
# val: 4187
# val: 2147
# val: 3360
# imgid_set: 13720


In [26]:
# Keep only the images whose id was drawn by the class-balanced sample, preserving
# load order, then make the filtered lists the working X / img_ids.
keep_idx = [i for i, img_id in enumerate(img_ids) if img_id in imgid_set]
X_rand = [X[i] for i in keep_idx]
img_ids_rand = [img_ids[i] for i in keep_idx]

X, img_ids = X_rand, img_ids_rand

print("# X_rand:", len(X))
print("# img_ids_rand:", len(img_ids))

# X_rand: 13720
# img_ids_rand: 13720


In [25]:
# Spot-check the first two ids kept by the sampling step.
for position in (0, 1):
    print(img_ids_rand[position])

flickr/q0001/10153619346.jpg
flickr/q0001/1058887763.jpg


In [27]:
# Split the sampled images into two pickles of (near-)equal size, each a dict with
# 'images' (list of arrays) and 'id' (matching image ids).
# NOTE(review): the split follows load order (all Flickr ids before Google ids,
# grouped by query directory), so the two batches are not class-stratified --
# confirm this is intended.
assert len(X) == len(img_ids), "images and ids must stay aligned"
split_at = len(X) // 2  # 6860 for the 13720 sampled images

batches = [
    ('./curriculum_clustering/input/webvision_cls0-9_batch1', slice(None, split_at)),
    ('./curriculum_clustering/input/webvision_cls0-9_batch2', slice(split_at, None)),
]
for out_path, part in batches:
    data = dict()
    data['images'] = X[part]
    data['id'] = img_ids[part]
    repickle(data, out_path)

In [46]:
# NOTE(review): none of the visible cells write '..._data' (they write '_batch1',
# '_batch2' and '_val'); this raises FileNotFoundError on a fresh Restart & Run All
# unless that pickle was produced elsewhere.
X_p = unpickle(
    file='./curriculum_clustering/input/webvision_cls0-9_data',
)

In [None]:
# Sanity check of the loaded object: its length, then its element at index 0.
n_loaded = len(X_p)
print(n_loaded)
first_loaded = X_p[0]
print(first_loaded)

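# Validation set

Load the first 500 entries of the validation list and pickle their images, file names and class labels together.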

In [20]:
# Validation list: each line is "<filename> <class>", e.g. "val000001.jpg 0".
cluster_list = './info/val_filelist.txt'  # webvision validation list
with open(cluster_list) as f:
    # Skip blank lines, which would otherwise become [''] and break fields[1].
    metadata = [line.strip().split(' ') for line in f if line.strip()]

# Keep only the first 500 entries.
# NOTE(review): presumably these are exactly the validation images of classes 0-9
# (matching the "cls0-9" output name); the assertion below checks the labels.
metadata = metadata[:500]
file_list = [fields[0] for fields in metadata]
class_list = [fields[1] for fields in metadata]
file_set = set(file_list)
assert set(class_list) <= {str(c) for c in range(10)}, "expected only classes 0-9"

print(class_list[0])
print(file_list[0])
print("# file_set:", len(file_set))

0
val000001.jpg
# file_set: 500


In [21]:
# Load the validation images in exactly the order of `file_list`, so X[i] lines up
# with file_list[i] and class_list[i] when they are pickled together. (Iterating
# os.listdir instead gives an arbitrary order and also picks up files that are
# not in the list, silently mislabelling images.)
fld = "./val_images_256/"
X = []

for fname in file_list:
    # The context manager closes the file handle once the pixels are copied.
    with PIL.Image.open(os.path.join(fld, fname)) as img:
        X.append(np.array(img))

print("# X:", len(X))

# X: 500


In [22]:
# Pickle the validation images with their file names and class labels. The three
# lists are matched by position, so check that they have the same length first.
assert len(X) == len(file_list) == len(class_list), "images, ids and labels must align"
data = {
    'images': X,
    'id': file_list,
    'classes': class_list,
}
repickle(data, './curriculum_clustering/input/webvision_cls0-9_val')