In [1]:
%matplotlib inline
from matplotlib.pyplot import imshow
import numpy as np
import os
from PIL import Image
import pickle
import random

In [2]:
dir1 = '/barleyhome/sgutstei/101_ObjectCategories_32x32'
dir2 = '/barleyhome/sgutstei/101_ObjectCategories_32x32_b'
os.makedirs(os.path.join(dir1,'cifar_style_datasets'), exist_ok=True)
os.makedirs(os.path.join(dir2,'cifar_style_datasets'), exist_ok=True)

In [3]:
categories = sorted([x for x in os.listdir(dir1) if len(os.listdir(os.path.join(dir1,x)))>0])
categories_dict = {x:ctr for ctr,x in enumerate(categories)}
inv_categories_dict = {v: k for k, v in categories_dict.items()}
#inv_categories_dict
#pickle.dump(inv_categories_dict,open('caltech101_dicts_all.pkl','wb'))

In [4]:
def get_data_elems(im_root, ims, category, fine_label):

    cl_list = []
    fl_list = []
    fn_list = []
    im_list = []

    for curr_im in ims:
        with Image.open(os.path.join(im_root, category, curr_im)) as z:
            z=z.convert('RGB')
            z=z.resize((32,32),Image.ANTIALIAS)
            zz=np.asarray(z)
            zz=zz.transpose(2,0,1)
            zz=zz.reshape(1,3*32*32)
        
        coarse_label=fine_label
        cl_list.append(coarse_label)
        fl_list.append(fine_label)
        fn_list.append("_".join([category,curr_im]))
        im_list.append(zz)

    if len(im_list) < 1:
        import pdb
        pdb.set_trace()
        temp=0
    im_array = np.concatenate(im_list)
    return [cl_list, fl_list, fn_list,  im_array]


In [8]:
def make_datasets(src_dir):
    #categories = sorted([x for x in os.listdir(dir1) if len(os.listdir(os.path.join(dir1,x)))>0])
    #categories_dict = {x:ctr for ctr,x in enumerate(categories)}
    #inv_categories_dict = {v: k for k, v in categories_dict.items()}

    train_cl_list = []
    train_fl_list = []
    train_fn_list = []
    train_im_list = []

    test_cl_list = []
    test_fl_list = []
    test_fn_list = []
    test_im_list = []

    for curr_cat in categories:
        all_ims = [x for x in os.listdir(os.path.join(dir1,curr_cat)) if x[-4:] == '.jpg']
        num_ims = len(all_ims)
        random.shuffle(all_ims)
        tr_ims = all_ims[0:int(.83*num_ims)]
        te_ims = all_ims[int(.83*num_ims):]

        new_tr_info = get_data_elems(dir1, tr_ims, curr_cat, 
                                     categories_dict[curr_cat])

        train_cl_list += new_tr_info[0]
        train_fl_list += new_tr_info[1]
        train_fn_list += new_tr_info[2]
        train_im_list.append(new_tr_info[3])

        new_te_info = get_data_elems(dir1, te_ims, curr_cat, 
                                     categories_dict[curr_cat])

        test_cl_list += new_te_info[0]
        test_fl_list += new_te_info[1]
        test_fn_list += new_te_info[2]
        test_im_list.append(new_te_info[3])

    import pdb
    pdb.set_trace()
    train_im_array = np.concatenate(train_im_list)
    test_im_array = np.concatenate(test_im_list)
    
    return [[train_cl_list, train_fl_list, train_fn_list, train_im_array],
            [test_cl_list, test_fl_list, test_fn_list, test_im_array]]

In [9]:
def shuffle_dataset(cl_list, fl_list, fn_list, im_array):

    shuff_array = np.zeros(im_array.shape, dtype=im_array.dtype)
    num_images = im_array.shape[0]
    shuff_list = [x for x in range(num_images)]
    random.shuffle(shuff_list)

    shuff_cl = num_images * [None]
    shuff_fl = num_images * [None]
    shuff_fn = num_images * [None]

    for new, old in enumerate(shuff_list):
        print("Shuffling ",new,"of",len(shuff_list))
        shuff_cl[new] = cl_list[old]
        shuff_fl[new] = fl_list[old]
        shuff_fn[new] = fn_list[old]
        shuff_array[new,:] = im_array[old,:]

    return [shuff_cl, shuff_fl, shuff_fn, shuff_array]

In [None]:
tr1, te1 = make_datasets(dir1)

> [0;32m<ipython-input-8-f935c7177a41>[0m(41)[0;36mmake_datasets[0;34m()[0m
[0;32m     39 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 41 [0;31m    [0mtrain_im_array[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0mconcatenate[0m[0;34m([0m[0mtrain_im_list[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     42 [0;31m    [0mtest_im_array[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0mconcatenate[0m[0;34m([0m[0mtest_im_list[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m[0;34m[0m[0m
[0m
ipdb> len(train_im_list)
0
ipdb> categories
[]


In [None]:
shtr1 = shuffle_dataset(tr1[0], tr1[1], tr1[2], tr1[3])
shte1 = shuffle_dataset(te1[0], te1[1], te1[2], te1[3])

In [None]:
def show_image(samp, data_dict):
    test_im = data_dict[3][samp]
    zz=test_im.reshape(1,3,32,32)
    zz=zz[0,:]
    zz=zz.transpose(1,2,0)
    im_name = inv_categories_dict[data_dict[1][samp]]
    file_name = data_dict[2][samp]
    print(im_name, file_name)
    imshow(zz)
    

In [None]:
show_image(5500, shtr1)

In [None]:
shtr_dict = {'coarse_labels': shtr1[0], 'fine_labels':shtr1[1], 
             'filenames':shtr1[2],'data':shtr1[3], 'batch_label':"N/A"}
shte_dict = {'coarse_labels': shte1[0], 'fine_labels':shte1[1], 
             'filenames':shte1[2],'data':shte1[3], 'batch_label':"N/A"}

meta_list = [inv_categories_dict[x] for x in inv_categories_dict]
meta_dict = {'coarse_label_names':meta_list, 'fine_label_names':meta_list}

In [None]:
pickle.dump(shtr_dict,open("train","wb"))
pickle.dump(shte_dict,open("test","wb"))
pickle.dump(meta_dict,open("meta","wb"))
os.getcwd()

In [None]:
pickle.dump(categories_dict, 
            open('/home/smgutstein/Projects/opt-tfer-2/dataset_info/caltech101_dicts_all.pkl','wb'))

In [None]:
liv=['sunflower', 'scorpion', 'dolphin', 'stegosaurus', 'hawksbill',
        'water_lilly', 'dragonfly', 'crayfish', 'Leopards', 'cannon',
        'flamingo_head', 'tick', 'Faces', 'cougar_body', 'flamingo',
        'crocodile', 'bonsai', 'gerenuk', 'emu', 'panda', 'ant',
        'butterfly', 'ibis', 'hedgehog', 'pigeon', 'beaver',
        'platypus', 'lotus', 'wild_cat', 'crab', 'strawberry',
        'rooster', 'sea_horse', 'llama', 'trilobite', 'brontosaurus',
        'nautilus', 'rhino', 'mayfly', 'airplanes', 'lobster',
        'okapi', 'dalmatian', 'crocodile_head', 'bass', 'joshua_tree',
        'kangaroo', 'cougar_face', 'octopus', 'elephant', 'starfish']

In [None]:
nl=['Motorbikes', 'accordion', 'anchor', 'barrel', 'binocular',
        'brain', 'buddha', 'camera', 'car_side', 'ceiling_fan',
        'cellphone', 'chair', 'chandelier', 'cup', 'dollar_bill',
        'electric_guitar', 'euphonium', 'ewer', 'ferry', 'garfield',
        'gramophone', 'grand_piano', 'headphone', 'helicopter', 'inline_skate',
        'ketch', 'lamp', 'laptop', 'mandolin', 'menorah', 
        'metronome', 'minaret', 'pagoda', 'pizza', 'pyramid',
        'revolver', 'saxophone', 'schooner', 'scissors', 'snoopy',
        'soccer_ball', 'stapler', 'stop_sign', 'umbrella', 'watch',
        'wheelchair', 'windsor_chair', 'wrench', 'yin_yang']

In [None]:
def make_subset_datasets(src_dir, categories):
    #categories_dict = {x:ctr for ctr,x in enumerate(categories)}
    #inv_categories_dict = {v: k for k, v in categories_dict.items()}

    train_cl_list = []
    train_fl_list = []
    train_fn_list = []
    train_im_list = []

    test_cl_list = []
    test_fl_list = []
    test_fn_list = []
    test_im_list = []

    for curr_cat in categories:
        all_ims = [x for x in os.listdir(os.path.join(dir1,curr_cat)) if x[-4:] == '.jpg']
        num_ims = len(all_ims)
        random.shuffle(all_ims)
        #print(curr_cat,num_ims)
        tr_ims = all_ims[0:int(.83*num_ims)]
        te_ims = all_ims[int(.83*num_ims):]

        new_tr_info = get_data_elems(dir1, tr_ims, curr_cat, 
                                     categories_dict[curr_cat])

        train_cl_list += new_tr_info[0]
        train_fl_list += new_tr_info[1]
        train_fn_list += new_tr_info[2]
        train_im_list.append(new_tr_info[3])

        new_te_info = get_data_elems(dir1, te_ims, curr_cat, 
                                     categories_dict[curr_cat])

        test_cl_list += new_te_info[0]
        test_fl_list += new_te_info[1]
        test_fn_list += new_te_info[2]
        test_im_list.append(new_te_info[3])

    train_im_array = np.concatenate(train_im_list)
    test_im_array = np.concatenate(test_im_list)
    
    return [[train_cl_list, train_fl_list, train_fn_list, train_im_array],
            [test_cl_list, test_fl_list, test_fn_list, test_im_array]]

In [None]:
liv_tr, liv_te = make_subset_datasets(dir1,liv)

In [None]:
liv_shtr = shuffle_dataset(liv_tr[0], liv_tr[1], liv_tr[2], liv_tr[3])
liv_shte = shuffle_dataset(liv_te[0], liv_te[1], liv_te[2], liv_te[3])

liv_shtr_dict = {'coarse_labels': liv_shtr[0], 'fine_labels':liv_shtr[1], 
             'filenames':liv_shtr[2],'data':liv_shtr[3], 
             'batch_label':"N/A"}
liv_shte_dict = {'coarse_labels': liv_shte[0], 'fine_labels':liv_shte[1], 
             'filenames':liv_shte[2],'data':liv_shte[3], 
             'batch_label':"N/A"}
meta_dict = {'coarse_label_names':liv, 'fine_label_names':liv}


In [None]:
max(liv_shte_dict['fine_labels'])

In [None]:
nliv_tr, nliv_te = make_subset_datasets(dir1,nl)

In [None]:
nliv_shtr = shuffle_dataset(nliv_tr[0], nliv_tr[1], nliv_tr[2], nliv_tr[3])
nliv_shte = shuffle_dataset(nliv_te[0], nliv_te[1], nliv_te[2], nliv_te[3])
meta_dict = {'coarse_label_names':nl, 'fine_label_names':nl}

nliv_shtr_dict = {'coarse_labels': nliv_shtr[0], 'fine_labels':nliv_shtr[1], 
             'filenames':nliv_shtr[2],'data':nliv_shtr[3], 
             'batch_label':"N/A"}
nliv_shte_dict = {'coarse_labels': nliv_shte[0], 'fine_labels':nliv_shte[1], 
             'filenames':nliv_shte[2],'data':nliv_shte[3], 
             'batch_label':"N/A"}
meta_dict = {'coarse_label_names':nl, 'fine_label_names':nl}


In [None]:
pickle.dump(liv_shtr_dict,open("train_src","wb"))
pickle.dump(liv_shte_dict,open("test_src","wb"))
pickle.dump(meta_dict,open("meta_src","wb"))

pickle.dump(nliv_shtr_dict,open("train_trgt","wb"))
pickle.dump(nliv_shte_dict,open("test_trgt","wb"))
pickle.dump(meta_dict,open("meta_trgt","wb"))