In [1]:
#essential imports
import os
import numpy as np
from tqdm import tqdm
import glob
import tensorflow as tf
import scipy.misc

In [2]:
train_location    = "../dataset/images_background"
evaluate_location = "../dataset/images_evaluation"
# dataset layout: <root>/<alphabet>/<character>/*.png
def read_omniglot():
    """Read all Omniglot character images, augment them with rotations, and cache to ``omniglot.npy``.

    Every image is read, resized to 28x28, and stored with its 0/90/180/270-degree
    rotations as float32 in [0, 1] with a trailing channel axis. All images of one
    character form one entry of the saved array, so the result is indexed
    ``[character, image, 28, 28, 1]``. Characters from both ``train_location`` and
    ``evaluate_location`` go into the same array. The four rotations of an image
    are stored consecutively.

    Directory listings are sorted so that the character order, and the order of
    images within a character, is the same on every filesystem (``glob`` returns
    entries in arbitrary order).
    """
    data = []
    for root in [train_location, evaluate_location]:
        alphabets = sorted(glob.glob(os.path.join(root, "*")))
        for alphabet in tqdm(alphabets):
            characters = sorted(glob.glob(os.path.join(alphabet, "*")))
            for character in characters:
                image_paths = sorted(glob.glob(os.path.join(character, "*")))
                raws = []
                for path in image_paths:
                    raw = scipy.misc.imread(path)
                    raw = scipy.misc.imresize(raw, (28, 28))
                    # data augmentation: four right-angle rotations
                    for k in range(4):
                        # np.rot90 rotates counter-clockwise by exactly k*90 degrees (no
                        # resampling); for these square images it should match
                        # scipy.misc.imrotate(raw, 90 * k) without the deprecated helper
                        raw_rot = np.rot90(raw, k)
                        raw_rot = raw_rot[:, :, np.newaxis]
                        raw_rot = raw_rot.astype(np.float32) / 255
                        raws.append(raw_rot)
                data.append(np.asarray(raws))
    np.save("omniglot.npy", np.asarray(data))
#     print("SHape of data " + str(np.asarray(data).shape))
# read_omniglot()

In [3]:
#build the dataloader class
class Data_loader():
    """Samples N-way K-shot episodes from the cached Omniglot array.

    ``omniglot.npy`` (built by ``read_omniglot`` if missing) is indexed
    ``[class, sample, H, W, C]``. The classes are shuffled with a seeded generator
    and split: the first 1200 classes form the training set, the remaining ones the
    evaluation set. Only the first 20 samples of every class are kept.
    """

    def __init__(self, batch_size, n_way=5, k_shot=1, train_mode=True, seed=0):
        """Load (building if needed) the cached array and select this loader's classes.

        Args:
            batch_size: number of episodes returned by each ``next_batch`` call.
            n_way: number of classes per episode.
            k_shot: number of support examples per class.
            train_mode: True selects the first 1200 classes, False the rest.
            seed: seed of the class shuffle that defines the train/eval split. Loaders
                built with the same seed (e.g. one train and one eval loader) see the
                same permutation, so their class sets are disjoint.
        """
        if not os.path.exists('omniglot.npy'):
            read_omniglot()
        self.batch_size = batch_size
        # number of classes per episode
        self.n_way      = n_way
        # number of support examples per class
        self.k_shot     = k_shot
        omniglot        = np.load('omniglot.npy')
        # Each loader loads and shuffles independently, so the shuffle must be seeded:
        # with the unseeded global RNG the train and eval loaders got different
        # permutations and their class splits overlapped.
        np.random.RandomState(seed).shuffle(omniglot)
        print("shape of omniglot is " + str(omniglot.shape))
        # NOTE(review): read_omniglot stores the 4 rotations of each image
        # consecutively, so [:20] keeps the first 5 images x 4 rotations -- confirm
        # that this is the intended sample selection.
        # .copy() so the loader does not keep the whole loaded array alive via a view
        if train_mode:
            self.images = omniglot[:1200, :20, :, :, :].copy()
        else:
            self.images = omniglot[1200:, :20, :, :, :].copy()
        self.num_classes = self.images.shape[0]
        self.num_samples = self.images.shape[1]
        # one "epoch" = one batch per class
        self.iters = self.num_classes

    def next_batch(self):
        """Sample ``batch_size`` episodes.

        Returns:
            x_set: float32, (batch_size, n_way * k_shot) + image shape; support images,
                k_shot consecutive images per episode class.
            y_set: int32, (batch_size, n_way * k_shot); episode-local labels in [0, n_way).
            x_hat: float32, (batch_size,) + image shape; one query image per episode.
            y_hat: int32, (batch_size,); the query's episode-local label. The query
                image is drawn from the same class but never appears in the support set.
        """
        x_set_batch = []
        y_set_batch = []
        x_hat_batch = []
        y_hat_batch = []
        for _ in range(self.batch_size):
            x_set = []
            y_set = []
            # n distinct classes; episode label i stands for classes[i]
            classes = np.random.permutation(self.num_classes)[:self.n_way]
            # episode label of the query
            target_class = np.random.randint(self.n_way)
            for i, c in enumerate(classes):
                # k distinct support samples plus one extra used as the query
                samples = np.random.permutation(self.num_samples)[:self.k_shot + 1]
                for s in samples[:-1]:
                    x_set.append(self.images[c][s])
                    y_set.append(i)
                if i == target_class:
                    x_hat_batch.append(self.images[c][samples[-1]])
                    y_hat_batch.append(i)
            x_set_batch.append(x_set)
            y_set_batch.append(y_set)
        return (np.asarray(x_set_batch).astype(np.float32),
                np.asarray(y_set_batch).astype(np.int32),
                np.asarray(x_hat_batch).astype(np.float32),
                np.asarray(y_hat_batch).astype(np.int32))

In [21]:
slim = tf.contrib.slim
rnn  = tf.contrib.rnn
class Matching_Nets():
    """Matching Networks for N-way K-shot image classification (TF1 graph, slim layers).

    One shared conv encoder embeds the query and every support image. The query is
    compared with each support embedding, the similarities are softmax-normalised
    into attention weights, and the class distribution is the attention-weighted sum
    of the support labels' one-hot vectors.
    """

    #get the model hyperparameters and create placeholders
    def __init__(self,  lr, n_way, k_shot, use_fce, batch_size=32):
        """Store hyperparameters and create the input placeholders.

        Args:
            lr: Adam learning rate.
            n_way: number of classes per episode.
            k_shot: support examples per class.
            use_fce: full-context-embedding flag; stored only.
                NOTE(review): build() does not read it yet.
            batch_size: stored; the placeholders accept any batch size.
        """
        self.lr                   = lr
        self.n_way                = n_way
        self.k_shot               = k_shot
        self.use_fce              = use_fce
        self.batch_size           = batch_size
        # number of attention-LSTM steps for the (not yet wired) FCE read-out
        self.processing_steps     = 10
        # support images (batch, n*k, 28, 28, 1) and their episode labels (batch, n*k)
        self.support_set_image_ph = tf.placeholder(tf.float32, [None, n_way*k_shot, 28, 28, 1])
        self.support_set_label_ph = tf.placeholder(tf.int32, [None, n_way*k_shot])
        # query images (batch, 28, 28, 1) and their labels (batch,)
        self.example_image_ph     = tf.placeholder(tf.float32, [None, 28, 28, 1])
        self.example_label_ph     = tf.placeholder(tf.int32, [None, ])

    def image_encoder(self, image):
        """Embed a batch of 28x28x1 images into 64-d vectors with a 4-block conv net.

        Every call shares ONE set of weights. Variables live under the
        'image_encoder' scope with AUTO_REUSE, and each layer has an explicit scope
        name. Without both, each slim.conv2d call creates fresh 'Conv', 'Conv_1', ...
        variables, giving the query and every support position their own encoder.

        NOTE(review): slim.batch_norm is used with its default is_training, and its
        moving-average update ops are not run by train(); confirm this is intended
        for evaluation.
        """
        with tf.variable_scope('image_encoder', reuse=tf.AUTO_REUSE):
            with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):
                net = slim.conv2d(image, scope='conv1')
                net = slim.max_pool2d(net, [2, 2])
                net = slim.conv2d(net, scope='conv2')
                net = slim.max_pool2d(net, [2, 2])
                net = slim.conv2d(net, scope='conv3')
                net = slim.max_pool2d(net, [2, 2])
                net = slim.conv2d(net, scope='conv4')
                net = slim.max_pool2d(net, [2, 2])
        # 28 -> 14 -> 7 -> 3 -> 1 spatially, so the final map is 1x1x64
        return tf.reshape(net, [-1, 1 * 1 * 64])

    def cosine_similarity(self, target, support_set):
        """One-sided cosine similarity between each query and each support embedding.

        Only the support embeddings are L2-normalised. The target is used as is, so
        this is a dot product with unit-length support vectors.

        Args:
            target: (batch, 64) query embeddings.
            support_set: (n*k, batch, 64) support embeddings.
        Returns:
            (batch, n*k) similarities.
        """
        sup_similarity = []
        for support_item in tf.unstack(support_set):
            # (batch, 64)
            support_normed = tf.nn.l2_normalize(support_item, 1)
            # (batch, 1, 64) x (batch, 64, 1) -> (batch, 1, 1)
            similarity = tf.matmul(tf.expand_dims(target, 1), tf.expand_dims(support_normed, 2))
            sup_similarity.append(similarity)
        # (batch, n*k, 1, 1) -> (batch, n*k); explicit axes so a batch or n*k of 1 survives
        return tf.squeeze(tf.stack(sup_similarity, axis=1), axis=[2, 3])

    def build(self, support_set_image, support_set_label, image):
        """Build the forward graph; sets ``self.probs``, ``self.logits`` and ``self.pred``.

        Args:
            support_set_image: (batch, n*k, 28, 28, 1) support images.
            support_set_label: (batch, n*k) int episode labels in [0, n_way).
            image: (batch, 28, 28, 1) query images.
        """
        # (batch, 64)
        image_encoded             = self.image_encoder(image)
        support_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]
        f_embedding               = image_encoded
        # (n*k, batch, 64)
        g_embedding               = tf.stack(support_set_image_encoded)
        # c(f(x_hat), g(x_i)): (batch, n*k)
        embedding_similarity      = self.cosine_similarity(f_embedding, g_embedding)
        # a(x_hat, x_i)
        attention                 = tf.nn.softmax(embedding_similarity)
        # (batch, 1, n*k) x (batch, n*k, n_way) -> (batch, 1, n_way)
        y_hat                     = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))
        # y_hat is already a probability distribution over the n_way classes
        self.probs                = tf.squeeze(y_hat, axis=1)
        # Expose log-probabilities as logits: softmax(log p) ~= p, so the softmax
        # cross-entropy in loss() is -log p[label]. Passing p itself as logits applied
        # a second softmax and flattened the loss. 1e-8 guards against log(0).
        self.logits               = tf.log(self.probs + 1e-8)
        self.pred                 = tf.argmax(self.logits, 1)

    def loss(self, label):
        """Set ``self.loss_op`` to the mean cross-entropy of ``label`` under ``self.logits``."""
        self.loss_op = tf.losses.sparse_softmax_cross_entropy(label, self.logits)

    def train(self):
        """Return an Adam op minimising ``self.loss_op``; call after ``loss``."""
        return tf.train.AdamOptimizer(self.lr).minimize(self.loss_op)

In [5]:
#hyperparameters:
lr     = 1e-3
epochs = 100
batch_size = 32
n_way = 20
k_shot = 1
# full context embeddings; NOTE(review): not read by Matching_Nets.build yet
use_fce = False
# Seed numpy (episode sampling) and the default TF graph (weight init) so that
# runs are reproducible.
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)

In [22]:
#main model function calling
def train():
    """Train Matching Networks on Omniglot episodes and report accuracy every epoch.

    Uses the module-level hyperparameters (lr, epochs, batch_size, n_way, k_shot,
    use_fce). Each epoch runs ``train_loader.iters`` optimisation steps, then
    ``eval_loader.iters`` prediction-only batches on the evaluation loader.
    """
    train_loader = Data_loader(batch_size, n_way, k_shot)
    eval_loader  = Data_loader(batch_size, n_way, k_shot, train_mode=False)
    model = Matching_Nets(lr, n_way, k_shot, use_fce, batch_size)
    model.build(model.support_set_image_ph, model.support_set_label_ph, model.example_image_ph)
    model.loss(model.example_label_ph)
    train_op = model.train()
    # `with` closes the session even if training is interrupted
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('Start Training ')
        print('batch size %d, epoch: %d, initial lr: %.3f' % (batch_size, epochs, lr))
        for epoch in range(epochs):
            correct = []
            for step in range(train_loader.iters):
                x_set, y_set, x_hat, y_hat = train_loader.next_batch()
                feed_dict = {model.support_set_image_ph: x_set,
                             model.support_set_label_ph: y_set,
                             model.example_image_ph: x_hat,
                             model.example_label_ph: y_hat}
                predictions, loss, _ = sess.run([model.pred, model.loss_op, train_op], feed_dict=feed_dict)
                correct.append(np.equal(predictions, y_hat))
                if step % 100 == 0:
                    # loss is a float; %d used to truncate it to an integer
                    print("epoch %3d, step %3d, loss %.4f, acc %.2f%%" % (epoch+1, step, loss, np.mean(np.equal(predictions, y_hat))*100))

            print('Training accuracy: %.2f%%' % (np.mean(np.stack(correct)) * 100))
            correct = []
            for step in range(eval_loader.iters):
                x_set, y_set, x_hat, y_hat = eval_loader.next_batch()
                # labels are not fed: only predictions are needed to score accuracy
                feed_dict = {model.support_set_image_ph: x_set,
                             model.support_set_label_ph: y_set,
                             model.example_image_ph: x_hat}
                prediction = sess.run(model.pred, feed_dict=feed_dict)
                correct.append(np.equal(prediction, y_hat))
            print('Evaluation accuracy: %.2f%%' % (np.mean(np.stack(correct)) * 100))
    print('Done')

train()

shape of omniglot is (1623, 80, 28, 28, 1)
shape of omniglot is (1623, 80, 28, 28, 1)
Start Training 
batch size 32, epoch: 100, initial lr: 0.001
(32, 20, 28, 28, 1)
(32, 20)
(32, 28, 28, 1)
(32,)
(20, 32, 1, 1)
(32, 1, 1)
(32, 20)
(32, 1, 20)
One hot (32, 20, 20)
encoded points is (32, 64)
epoints are (20, 32, 64)
epoints2 are (20, 32, 64)
Model similarity is (32, 20)
epoch   1, step   0, loss   2, acc 3.12%
(32, 20, 28, 28, 1)
(32, 20)
(32, 28, 28, 1)
(32,)
(20, 32, 1, 1)
(32, 1, 1)
(32, 20)
(32, 1, 20)
One hot (32, 20, 20)
encoded points is (32, 64)
epoints are (20, 32, 64)
epoints2 are (20, 32, 64)
Model similarity is (32, 20)
(32, 20, 28, 28, 1)
(32, 20)
(32, 28, 28, 1)
(32,)
(20, 32, 1, 1)
(32, 1, 1)
(32, 20)
(32, 1, 20)
One hot (32, 20, 20)
encoded points is (32, 64)
epoints are (20, 32, 64)
epoints2 are (20, 32, 64)
Model similarity is (32, 20)
(32, 20, 28, 28, 1)
(32, 20)
(32, 28, 28, 1)
(32,)
(20, 32, 1, 1)
(32, 1, 1)
(32, 20)
(32, 1, 20)
One hot (32, 20, 20)
encoded points 

KeyboardInterrupt: 

In [None]:
        # NOTE(review): snippet meant for Matching_Nets.build, after image_encoded /
        # support_set_image_encoded are computed. build() does not read use_fce yet,
        # and fce_g / fce_f in this cell are not defined on Matching_Nets.
        # When enabled, the FCE embeddings replace the plain encoder outputs.
        if self.use_fce:
            g_embedding = self.fce_g(support_set_image_encoded)     # (n * k, batch_size, 64)
            f_embedding = self.fce_f(image_encoded, g_embedding)    # (batch_size, 64)


def fce_g(self, encoded_x_i):
        """The fully conditional embedding function g: a bidirectional LSTM over the support set.

        Computes g(x_i, S) = [h_i(->); h_i(<-)] + g'(x_i), where g' is the image encoder.
        static_bidirectional_rnn depth-concatenates the forward and backward outputs
        (32 units each), which matches the 64-d encoder output the sum is taken with.
        For omniglot, this is not used.

        encoded_x_i: g'(x_i) in the equation; a length n * k list of (batch_size, 64)
            tensors. The LSTM runs over the n * k support positions.
        Returns: (n * k, batch_size, 64) tensor.
        """
        forward_cell = rnn.BasicLSTMCell(32)   # 32 + 32 = 64, the encoder's output width
        backward_cell = rnn.BasicLSTMCell(32)
        bi_outputs, _, _ = rnn.static_bidirectional_rnn(
            forward_cell, backward_cell, encoded_x_i, dtype=tf.float32)
        stacked_inputs = tf.stack(encoded_x_i)
        stacked_outputs = tf.stack(bi_outputs)
        return tf.add(stacked_inputs, stacked_outputs)

    def fce_f(self, encoded_x, g_embedding):
        """The fully conditional embedding function f: an LSTM with attention read-out.

        At every processing step the LSTM gets the same input f'(x_hat). Its hidden
        state for the next step is h_k + r_k, where r_k is an attention read-out over
        the support embeddings. For omniglot, this is not used.

        encoded_x: f'(x_hat) in equation (3) in paper appendix A.1.     (batch_size, 64)
        g_embedding: g(x_i) in equation (5), (6) in paper appendix A.1. (n * k, batch_size, 64)

        Returns h_K = LSTM output + f'(x_hat) after the last step, (batch_size, 64);
        f'(x_hat) itself if processing_steps is 0.
        """
        cell = rnn.BasicLSTMCell(64)
        prev_state = cell.zero_state(self.batch_size, tf.float32) # state[0] is c, state[1] is h
        h_k = encoded_x

        for _ in range(self.processing_steps):
            output, state = cell(encoded_x, prev_state) # output: (batch_size, 64)

            # skip connection: h_k = h^_k + f'(x_hat)
            h_k = tf.add(output, encoded_x) # (batch_size, 64)

            # content-based attention: one dot-product score per support item, with the
            # softmax taken across the n * k items (not across the 64 features)
            scores = tf.reduce_sum(tf.multiply(prev_state[1], g_embedding), axis=2)       # (n * k, batch_size)
            content_based_attention = tf.transpose(tf.nn.softmax(tf.transpose(scores)))   # (n * k, batch_size)
            r_k = tf.reduce_sum(tf.expand_dims(content_based_attention, 2) * g_embedding, axis=0)  # (batch_size, 64)

            prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))

        return h_k

