In [22]:
try:
    import torch
    import numpy as np
    from pathlib import Path
    import os
    import os.path as osp
    import importlib
    importlib.reload(dp)
    importlib.reload(Config)
except NameError: # It hasn't been imported yet
        import data_load.data_provider as dp
        import config.config_flags as Config

In [4]:
#Loads tieredImageNet embeddings by default. Set a different dataset in config_flags.py

#Pick from train, val, test. 
#Debug param keeps a copy of the pkl embeddings data
dataset_type_pkl = 'val' #This represents the original pkl dataset type to be loaded
dataProvider = dp.DataProvider(dataset_type_pkl, debug=True, verbose=True) 

embeddings_data = dataProvider.get_embeddings_data()

Path fetched: ../embeddings/tieredImageNet/center/val_embeddings.pkl


In [5]:
#Get first num elements in the dictionary - helper method
def head(dict_obj, num=5):
    count = 0
    for key,value in dict_obj.items():
        print("Key: " + str(key), "value: " + str(value))
        count+=1
        if(count == num):
            return

In [6]:
# Format of each element in "keys": _-classLabel-classLabel_filename.JPEG
# List of class labels in miniImageNet https://gist.github.com/kaixin96/ffb88bd025fc05deb2d7f1378e9b7282

head(embeddings_data, 3)

Key: labels value: [0 0 0 ... 0 0 0]
Key: embeddings value: [[3.2385639e-03 1.8942569e-04 1.3159506e-02 ... 5.6467261e-03
  2.8360044e-04 5.4967026e-03]
 [1.0305644e-06 5.7039782e-05 6.0153725e-03 ... 7.2171907e-03
  7.8330771e-04 2.3086595e-03]
 [3.9626868e-04 2.3760945e-03 3.5099715e-03 ... 1.6851516e-03
  2.4224888e-03 6.0567213e-04]
 ...
 [5.0420989e-04 1.1343597e-03 7.8602228e-03 ... 4.0134267e-04
  1.4575045e-03 3.5321803e-04]
 [0.0000000e+00 1.9685794e-03 1.7314308e-03 ... 1.8279791e-03
  1.4138945e-02 5.3597195e-04]
 [4.2886622e-04 2.1248795e-04 5.9927190e-03 ... 4.9523721e-03
  8.7791802e-03 7.8218954e-04]]
Key: keys value: ['1072646529445394375-n02099601-n02099601_2439.JPEG'
 '1113032556112010943-n02102480-n02102480_7854.JPEG'
 '1120287575005342714-n03496892-n03496892_17606.JPEG' ...
 '554146050667129952-n04067472-n04067472_14426.JPEG'
 '574482493408202056-n02930766-n02930766_16758.JPEG'
 '585246420666931655-n03930630-n03930630_1767.JPEG']


In [7]:
#Note the labels key in the dictionary does not represent class labels. The real class labels are inside the filename of each embedding as highlighted in the above cell.
print("===pkl embeddings file info===")
labels = np.array(embeddings_data['labels'])
print("labels shape:", labels.shape, "Unique elements: ", np.unique(labels))

embedding_values = np.array(embeddings_data['embeddings'])
print("embedding values shape: ", embedding_values.shape)

keys = np.array(embeddings_data['keys'])
print("keys shape:", keys.shape)

===pkl embeddings file info===
labels shape: (124000,) Unique elements:  [0]
embedding values shape:  (124000, 640)
keys shape: (124000,)


In [10]:
#Raw embeddings data (pkl) is indexed/organized into two dictionaries. Indexing helps us construct our n-way k shot problems
class_image_file_dict, image_file_embeddings_dict = dataProvider.get_indexed_data()
print("class label to image filenames dictionary:", len(class_image_file_dict))
print("image filename to embeddings data dictionary", len(image_file_embeddings_dict))

class label to image filenames dictionary: 97
image filename to embeddings data dictionary 124000


In [18]:
head(class_image_file_dict, 2)

Key: n02099601 value: ['n02099601_2439.JPEG' 'n02099601_1654.JPEG' 'n02099601_6124.JPEG' ...
 'n02099601_2460.JPEG' 'n02099601_3411.JPEG' 'n02099601_12990.JPEG']
Key: n02102480 value: ['n02102480_7854.JPEG' 'n02102480_8759.JPEG' 'n02102480_3326.JPEG' ...
 'n02102480_9483.JPEG' 'n02102480_6350.JPEG' 'n02102480_5534.JPEG']


In [19]:
head(image_file_embeddings_dict, 1)

Key: n02099601_2439.JPEG value: [3.23856389e-03 1.89425686e-04 1.31595060e-02 1.15267793e-03
 2.22466144e-04 4.02899343e-04 3.83161893e-03 0.00000000e+00
 7.68344710e-03 3.35025042e-02 1.18376885e-03 4.46909748e-04
 1.17685609e-02 0.00000000e+00 4.11629444e-04 2.45084683e-03
 3.62301944e-03 6.37231424e-05 8.03190633e-04 5.83129330e-03
 3.83506389e-03 4.93042928e-04 1.56671857e-03 1.67190167e-03
 9.08649701e-04 4.05392377e-04 7.74350483e-04 1.62679775e-04
 1.03507342e-03 3.15711647e-03 1.14170453e-02 1.16079105e-02
 2.23458093e-03 8.25890992e-03 4.96176617e-05 0.00000000e+00
 5.92266209e-04 0.00000000e+00 2.19947123e-03 1.53862906e-03
 2.45403801e-03 7.79392198e-03 5.06709237e-03 1.04401568e-02
 2.55171629e-03 0.00000000e+00 0.00000000e+00 3.33414832e-03
 4.36473219e-03 3.51644959e-03 0.00000000e+00 2.87072361e-03
 1.30566163e-03 6.66173524e-04 4.53151239e-04 5.39812667e-04
 6.82270504e-04 5.15033666e-04 3.33822286e-03 1.56134670e-03
 9.62726027e-03 2.58035073e-03 1.79423892e-04 1.23264

In [24]:
#Save embeddings data for validation (to be used in TASML)
#Note if we want, we can extract training + val from the embeddings validation data (as done in this cell)

db_title = Config.EMBEDDINGS_DATASET_NAME
sample_size = 1 #Kept small for demonstration. Authors use 30,000
tr_size = Config.TRAINING_NUM_OF_EXAMPLES
val_size = Config.VALIDATION_NUM_OF_EXAMPLES
num_classes = Config.NUM_OF_CLASSES
save_root = Config.SAVE_ROOT
save_path = osp.join(save_root, "%s_%s_%i_%i_%i" % (dataset_type_pkl, db_title, sample_size, tr_size, val_size))

dataProvider.create_db(sample_size, num_classes, tr_size, val_size)
dataProvider.save_db(save_path)

embedding_array:  (5, 16, 640)
task_sig (640,)
label_array (5, 16, 1)
path_array (5, 16)


In [33]:
#task_sig is a class-wise mean and then an example-wise mean of the normalized data
db = dataProvider.db 
db

[(array([1.22679362e-03, 1.66860817e-03, 1.98660997e-02, 1.08372749e-02,
         1.73841098e-02, 2.35794176e-03, 2.78341095e-03, 8.08146786e-03,
         5.84388759e-03, 2.14294225e-02, 7.23577253e-03, 2.44560636e-04,
         2.09411227e-02, 3.10056358e-03, 1.45171826e-03, 3.51813171e-02,
         8.04532692e-03, 4.87685775e-03, 6.96730206e-03, 2.57808108e-02,
         7.85200083e-03, 3.49718472e-03, 5.68181562e-03, 3.44916233e-03,
         2.36794069e-03, 4.33719530e-03, 5.83922545e-03, 2.10596067e-03,
         5.48093395e-04, 1.85515910e-02, 4.36749785e-03, 3.92307011e-02,
         9.88021849e-03, 9.65795597e-04, 1.32382499e-03, 3.29620792e-03,
         5.70353053e-03, 1.02233878e-03, 1.06422998e-02, 2.97966561e-02,
         5.10539088e-03, 1.02006691e-02, 8.28211141e-03, 5.97583935e-04,
         2.17622400e-02, 9.44034857e-03, 4.99323543e-04, 5.89582134e-03,
         8.71939460e-03, 1.59626293e-02, 3.39628236e-03, 1.41584946e-02,
         1.81927474e-03, 1.23228771e-03, 4.20945281

In [57]:
print('type(db):', type(db))
print('len(db):', len(db))
print()
print('type(db[0]):', type(db[0]))
print('len(db[0])', len(db[0]))
print()
print('type(db[0][0])', type(db[0][0]), 'db[0][0].shape:', db[0][0].shape)
print('type(db[0][1])', type(db[0][1]), 'db[0][1].shape:', db[0][1].shape)
print('type(db[0][2])', type(db[0][2]), 'db[0][2].shape:', db[0][2].shape)
print()
print('db[0][0][:3]:\n', db[0][0][:3])
print('db[0][1][:3]:\n', db[0][1][:3])
print('db[0][2][:3]:\n', db[0][2][:3])

type(db): <class 'list'>
len(db): 1

type(db[0]): <class 'tuple'>
len(db[0]) 3

type(db[0][0]) <class 'numpy.ndarray'> db[0][0].shape: (640,)
type(db[0][1]) <class 'numpy.ndarray'> db[0][1].shape: (5, 16, 1)
type(db[0][2]) <class 'numpy.ndarray'> db[0][2].shape: (5, 16)

db[0][0][:3]:
 [0.00122679 0.00166861 0.0198661 ]
db[0][1][:3]:
 [[[0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]
  [0]]

 [[1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]
  [1]]

 [[2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]
  [2]]]
db[0][2][:3]:
 [['n03642806_18680.JPEG' 'n03642806_16305.JPEG' 'n03642806_23213.JPEG'
  'n03642806_17504.JPEG' 'n03642806_3598.JPEG' 'n03642806_23854.JPEG'
  'n03642806_11583.JPEG' 'n03642806_3902.JPEG' 'n03642806_26089.JPEG'
  'n03642806_20128.JPEG' 'n03642806_22729.JPEG' 'n03642806_57.JPEG'
  'n03642806_24723.JPEG' 'n03642806_19125.JPEG' 'n03642806_17362.JPEG'
  'n