<a href="https://colab.research.google.com/github/psmouli14/final_project/blob/main/omniglot_prototype_combo_resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import glob
from PIL import Image

import numpy as np
import tensorflow as tf

In [2]:
from google.colab import files
uploaded = files.upload()

Saving omniglot.zip to omniglot.zip


In [3]:
!unzip omniglot.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_05.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_11.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_10.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_04.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_12.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_06.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_07.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_13.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_17.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_03.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_02.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/0277_16.png  
  inflating: omniglot/data/Syriac_(Estrangelo)/character05/

In [4]:
root_dir = 'omniglot/'
train_split_path = os.path.join(root_dir, 'splits', 'vinyals', 'train.txt')
print(train_split_path)

with open(train_split_path, 'r') as train_split:
    train_classes = [line.rstrip() for line in train_split.readlines()]
print(train_classes[2])

val_split_path = os.path.join(root_dir, 'splits', 'vinyals', 'val.txt')
print(val_split_path)

with open(val_split_path, 'r') as val_split:
    val_classes = [line.rstrip() for line in val_split.readlines()]
print(val_classes[2])

test_split_path = os.path.join(root_dir, 'splits', 'vinyals', 'test.txt')
print(test_split_path)

with open(test_split_path, 'r') as test_split:
    test_classes = [line.rstrip() for line in test_split.readlines()]
print(test_classes[2])

omniglot/splits/vinyals/train.txt
Angelic/character01/rot180
omniglot/splits/vinyals/val.txt
Hebrew/character01/rot180
omniglot/splits/vinyals/test.txt
Gurmukhi/character42/rot180


In [5]:
no_train_classes = len(train_classes)
print(no_train_classes)

no_val_classes = len(val_classes)
print(no_val_classes)

no_test_classes = len(test_classes)
print(no_test_classes)

4112
688
1692


In [6]:
#number of examples per class
num_examples = 20

#image width
img_width = 32

#image height
img_height = 32

#channels
channels = 1

In [7]:
train_dataset = np.zeros([no_train_classes, num_examples, img_height, img_width], dtype=np.float32)
print(train_dataset.shape)

val_dataset = np.zeros([no_val_classes, num_examples, img_height, img_width], dtype=np.float32)
print(val_dataset.shape)

test_dataset = np.zeros([no_test_classes, num_examples, img_height, img_width], dtype=np.float32)
print(test_dataset.shape)


(4112, 20, 32, 32)
(688, 20, 32, 32)
(1692, 20, 32, 32)


In [8]:
def populate_dataset(classes, dataset):
  for label, name in enumerate(classes):
    alphabet, character, rotation = name.split('/')
    rotation = float(rotation[3:])
    img_dir = os.path.join(root_dir, 'data', alphabet, character)
    img_files = sorted(glob.glob(os.path.join(img_dir, '*.png')))
    
    for index, img_file in enumerate(img_files):
      
      img = Image.open(img_file).resize((img_height, img_width)).rotate(rotation)
      img = np.asarray(img)
      img = 1 - img
      dataset[label, index] = img

  return dataset


train_dataset = populate_dataset(train_classes, train_dataset)
val_dataset = populate_dataset(val_classes, val_dataset)
test_dataset = populate_dataset(test_classes, test_dataset)


In [9]:
#Increase the axis by 1 to accomodate channel
train_dataset = np.expand_dims(train_dataset, axis=-1)
val_dataset = np.expand_dims(val_dataset, axis=-1)
test_dataset = np.expand_dims(test_dataset, axis=-1)

print(train_dataset.shape)
print(val_dataset.shape)
print(test_dataset.shape)

(4112, 20, 32, 32, 1)
(688, 20, 32, 32, 1)
(1692, 20, 32, 32, 1)


In [10]:
# This will be used to get the next set of support and query dataset for the episode
def get_next_episode(dataset, num_way, num_shot, num_query, no_of_classes):
  support = np.zeros([num_way, num_shot, img_height, img_width, channels], dtype=np.float32)
  query = np.zeros([num_way, num_query, img_height, img_width, channels], dtype=np.float32)
  episodic_classes = np.random.permutation(no_of_classes)[:num_way]

  for index, class_ in enumerate(episodic_classes):
    selected = np.random.permutation(num_examples)[:num_shot + num_query]
    
    support[index] = dataset[class_][selected[:num_shot]]
    query[index] = dataset[class_][selected[num_shot:]]
    
  return support, query

def euclidean_distance(a, b):

    N, D = tf.shape(a)[0], tf.shape(a)[1]
    M = tf.shape(b)[0]
    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
    return tf.reduce_mean(tf.square(a - b), axis=2)


In [11]:
from tensorflow.keras.layers import Dense, Flatten, Conv2D, BatchNormalization, Dropout, GlobalMaxPooling2D
from tensorflow.keras import Model

class Prototypical(Model):
   
    def __init__(self, n_support, n_query, w, h, c):
       
        super(Prototypical, self).__init__()
        self.w, self.h, self.c = w, h, c

        # Encoder of CNN with 4 blocks
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPool2D((2, 2)),

            tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPool2D((2, 2)),

            tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPool2D((2, 2)),

            tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.MaxPool2D((2, 2)), Flatten()]
        )

    def call(self, support, query):
        n_class = support.shape[0]
        n_support = support.shape[1]
        n_query = query.shape[1]
        y = np.tile(np.arange(n_class)[:, np.newaxis], (1, n_query))
        y_onehot = tf.cast(tf.one_hot(y, n_class), tf.float32)

        # correct indices of support samples (just natural order)
        target_inds = tf.reshape(tf.range(n_class), [n_class, 1])
        target_inds = tf.tile(target_inds, [1, n_query])

        # merge support and query to forward through encoder
        cat = tf.concat([
            tf.reshape(support, [n_class * n_support,
                                 self.w, self.h, self.c]),
            tf.reshape(query, [n_class * n_query,
                               self.w, self.h, self.c])], axis=0)
        z = self.encoder(cat)

        # Divide embedding into support and query
        z_prototypes = tf.reshape(z[:n_class * n_support],
                                  [n_class, n_support, z.shape[-1]])
        # Prototypes
        z_prototypes = tf.math.reduce_mean(z_prototypes, axis=1)
        z_query = z[n_class * n_support:]

        # Calculate distances between query and prototypes
        dists = euclidean_distance(z_query, z_prototypes)

        # log softmax 
        log_p_y = tf.nn.log_softmax(-dists, axis=-1)
        log_p_y = tf.reshape(log_p_y, [n_class, n_query, -1])
        
        loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_onehot, log_p_y), axis=-1), [-1]))
        eq = tf.cast(tf.equal(
            tf.cast(tf.argmax(log_p_y, axis=-1), tf.int32), 
            tf.cast(y, tf.int32)), tf.float32)
        acc = tf.reduce_mean(eq)
        return loss, acc

    def save(self, model_path):
        
        self.encoder.save(model_path)

    def load(self, model_path):
        
        self.encoder(tf.zeros([1, self.w, self.h, self.c]))
        self.encoder.load_weights(model_path)

In [12]:
# Lists to hold values for N-way k-shots experiments
train_num_ways = [60, 60, 40, 40]
test_num_ways = [5, 5, 20, 20]
num_shots = [5, 1, 5, 1]
learning_rate = 0.001

In [13]:
#Run prototypical model with only training set
# This will be using the first values from the lists for the experiment
num_epochs = 80
num_episodes = 100
save_path = "./results/models/omniglot_train0.h5"

train_loss = tf.metrics.Mean(name='train_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')


#number of classes
num_way = train_num_ways[0] 

#number of examples per class for support set
num_shot = num_shots[0]  

#number of query points
num_query = num_shots[0] 

least_loss = {'least_loss': 100.00}

support = np.zeros([num_way, num_shot, img_width, img_height, channels], dtype=np.float32)
query = np.zeros([num_way, num_query, img_width, img_height, channels], dtype=np.float32)
model = Prototypical(support, query, img_width, img_height, channels)
optimizer = tf.keras.optimizers.Adam(learning_rate)

@tf.function
def loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

least_loss = {'least_loss': 100.00}

@tf.function
def train_step(support, query):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  train_loss(loss)
  train_acc(acc)

for epoch in range(num_epochs):
  train_loss.reset_states()
  train_acc.reset_states()

  for episode in range(num_episodes):
    train_support, train_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
    train_step(train_support, train_query)

  cur_loss = train_loss.result().numpy()
  if cur_loss < least_loss['least_loss']:
      print("Saving new best model with loss: ", cur_loss)
      least_loss['least_loss'] = cur_loss
      model.save(save_path)
  
  template = 'Epoch {}, Loss: {}, Accuracy: {}'
  print(template.format(epoch + 1, train_loss.result(), train_acc.result() * 100))



Saving new best model with loss:  1.8223703
Epoch 1, Loss: 1.8223702907562256, Accuracy: 59.45664596557617




Saving new best model with loss:  0.6934691
Epoch 2, Loss: 0.6934691071510315, Accuracy: 80.90666198730469




Saving new best model with loss:  0.39470193
Epoch 3, Loss: 0.39470192790031433, Accuracy: 88.8933334350586




Saving new best model with loss:  0.26275072
Epoch 4, Loss: 0.2627507150173187, Accuracy: 92.49332427978516




Saving new best model with loss:  0.21018419
Epoch 5, Loss: 0.21018418669700623, Accuracy: 93.8800048828125




Saving new best model with loss:  0.19405496
Epoch 6, Loss: 0.1940549612045288, Accuracy: 94.44332885742188




Saving new best model with loss:  0.1841898
Epoch 7, Loss: 0.1841897964477539, Accuracy: 95.08998107910156




Saving new best model with loss:  0.16692977
Epoch 8, Loss: 0.16692976653575897, Accuracy: 95.0433120727539




Saving new best model with loss:  0.14996357
Epoch 9, Loss: 0.14996357262134552, Accuracy: 95.72666931152344
Epoch 10, Loss: 0.1519554853439331, Accuracy: 95.62998962402344




Saving new best model with loss:  0.13310142
Epoch 11, Loss: 0.1331014186143875, Accuracy: 96.25334930419922
Epoch 12, Loss: 0.13790582120418549, Accuracy: 95.98333740234375




Saving new best model with loss:  0.12336545
Epoch 13, Loss: 0.12336544692516327, Accuracy: 96.4699935913086
Epoch 14, Loss: 0.12367140501737595, Accuracy: 96.43665313720703




Saving new best model with loss:  0.11403987
Epoch 15, Loss: 0.11403986811637878, Accuracy: 96.75335693359375




Saving new best model with loss:  0.11082628
Epoch 16, Loss: 0.1108262836933136, Accuracy: 96.60667419433594




Saving new best model with loss:  0.10802282
Epoch 17, Loss: 0.10802281647920609, Accuracy: 96.9800033569336
Epoch 18, Loss: 0.10821503400802612, Accuracy: 96.89334869384766




Saving new best model with loss:  0.09590104
Epoch 19, Loss: 0.09590104222297668, Accuracy: 97.11670684814453




Saving new best model with loss:  0.089777514
Epoch 20, Loss: 0.08977751433849335, Accuracy: 97.18999481201172
Epoch 21, Loss: 0.0924503356218338, Accuracy: 97.23333740234375




Saving new best model with loss:  0.08829202
Epoch 22, Loss: 0.08829201757907867, Accuracy: 97.33670043945312




Saving new best model with loss:  0.08785015
Epoch 23, Loss: 0.08785015344619751, Accuracy: 97.46332550048828




Saving new best model with loss:  0.08562172
Epoch 24, Loss: 0.08562172204256058, Accuracy: 97.46669006347656




Saving new best model with loss:  0.081191294
Epoch 25, Loss: 0.0811912938952446, Accuracy: 97.52999114990234




Saving new best model with loss:  0.07741179
Epoch 26, Loss: 0.07741179317235947, Accuracy: 97.72000885009766
Epoch 27, Loss: 0.0844140350818634, Accuracy: 97.3466796875
Epoch 28, Loss: 0.08009614795446396, Accuracy: 97.60002136230469




Saving new best model with loss:  0.0759286
Epoch 29, Loss: 0.07592859864234924, Accuracy: 97.77333068847656
Epoch 30, Loss: 0.07720301300287247, Accuracy: 97.61665344238281




Saving new best model with loss:  0.07560368
Epoch 31, Loss: 0.0756036788225174, Accuracy: 97.68999481201172




Saving new best model with loss:  0.07480076
Epoch 32, Loss: 0.0748007595539093, Accuracy: 97.78331756591797




Saving new best model with loss:  0.07120441
Epoch 33, Loss: 0.07120440900325775, Accuracy: 97.84333801269531




Saving new best model with loss:  0.066211425
Epoch 34, Loss: 0.06621142476797104, Accuracy: 97.87662506103516
Epoch 35, Loss: 0.07591696083545685, Accuracy: 97.79329681396484
Epoch 36, Loss: 0.06903688609600067, Accuracy: 97.90000915527344
Epoch 37, Loss: 0.06786845624446869, Accuracy: 97.89999389648438




Saving new best model with loss:  0.06406434
Epoch 38, Loss: 0.06406433880329132, Accuracy: 98.00334167480469
Epoch 39, Loss: 0.06704043596982956, Accuracy: 97.95332336425781




Saving new best model with loss:  0.06285743
Epoch 40, Loss: 0.06285742670297623, Accuracy: 97.95331573486328




Saving new best model with loss:  0.06086492
Epoch 41, Loss: 0.060864921659231186, Accuracy: 98.11331939697266
Epoch 42, Loss: 0.0624455064535141, Accuracy: 98.11997985839844
Epoch 43, Loss: 0.06483400613069534, Accuracy: 97.99998474121094




Saving new best model with loss:  0.057385873
Epoch 44, Loss: 0.057385873049497604, Accuracy: 98.22663879394531




Saving new best model with loss:  0.055491377
Epoch 45, Loss: 0.0554913766682148, Accuracy: 98.17667388916016
Epoch 46, Loss: 0.055811136960983276, Accuracy: 98.25665283203125
Epoch 47, Loss: 0.05992266535758972, Accuracy: 98.09332275390625
Epoch 48, Loss: 0.05678478255867958, Accuracy: 98.26664733886719




Saving new best model with loss:  0.05422054
Epoch 49, Loss: 0.0542205385863781, Accuracy: 98.28331756591797




Saving new best model with loss:  0.04798604
Epoch 50, Loss: 0.04798604175448418, Accuracy: 98.41331481933594
Epoch 51, Loss: 0.05636012554168701, Accuracy: 98.16332244873047
Epoch 52, Loss: 0.05462675914168358, Accuracy: 98.22999572753906
Epoch 53, Loss: 0.050800736993551254, Accuracy: 98.38664245605469
Epoch 54, Loss: 0.04990335926413536, Accuracy: 98.38997650146484
Epoch 55, Loss: 0.05164864659309387, Accuracy: 98.29330444335938
Epoch 56, Loss: 0.04850419983267784, Accuracy: 98.4066390991211
Epoch 57, Loss: 0.05317254364490509, Accuracy: 98.17333221435547
Epoch 58, Loss: 0.0535234659910202, Accuracy: 98.19329071044922




Saving new best model with loss:  0.04409799
Epoch 59, Loss: 0.04409798979759216, Accuracy: 98.56663513183594
Epoch 60, Loss: 0.050138920545578, Accuracy: 98.37664031982422




Saving new best model with loss:  0.043723036
Epoch 61, Loss: 0.04372303560376167, Accuracy: 98.51663208007812
Epoch 62, Loss: 0.046944886445999146, Accuracy: 98.50330352783203
Epoch 63, Loss: 0.04587273672223091, Accuracy: 98.4832992553711
Epoch 64, Loss: 0.049494363367557526, Accuracy: 98.31330871582031




Saving new best model with loss:  0.04034471
Epoch 65, Loss: 0.040344711393117905, Accuracy: 98.6032943725586
Epoch 66, Loss: 0.041141342371702194, Accuracy: 98.5533218383789
Epoch 67, Loss: 0.04525251314043999, Accuracy: 98.47332000732422
Epoch 68, Loss: 0.04823112487792969, Accuracy: 98.35997009277344




Saving new best model with loss:  0.038025767
Epoch 69, Loss: 0.03802576661109924, Accuracy: 98.75662994384766
Epoch 70, Loss: 0.04090440645813942, Accuracy: 98.61662292480469
Epoch 71, Loss: 0.04670336842536926, Accuracy: 98.43667602539062
Epoch 72, Loss: 0.0420367605984211, Accuracy: 98.60327911376953
Epoch 73, Loss: 0.04615261033177376, Accuracy: 98.45331573486328
Epoch 74, Loss: 0.04139649495482445, Accuracy: 98.61663818359375
Epoch 75, Loss: 0.04341845586895943, Accuracy: 98.6133041381836
Epoch 76, Loss: 0.0422544851899147, Accuracy: 98.52996063232422




Saving new best model with loss:  0.03801883
Epoch 77, Loss: 0.03801883012056351, Accuracy: 98.65997314453125




Saving new best model with loss:  0.036466975
Epoch 78, Loss: 0.03646697476506233, Accuracy: 98.76663208007812
Epoch 79, Loss: 0.03946603462100029, Accuracy: 98.6099624633789
Epoch 80, Loss: 0.04077218100428581, Accuracy: 98.5699691772461


In [14]:
#Run prototypical model with only training set
# This will be using the second values from the lists for the experiment
num_epochs = 80
num_episodes = 100
save_path = "./results/models/omniglot_train1.h5"

train_loss = tf.metrics.Mean(name='train_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')


#number of classes
num_way = train_num_ways[1] 

#number of examples per class for support set
num_shot = num_shots[1]  

#number of query points
num_query = num_shots[1] 

least_loss = {'least_loss': 100.00}

support = np.zeros([num_way, num_shot, img_width, img_height, channels], dtype=np.float32)
query = np.zeros([num_way, num_query, img_width, img_height, channels], dtype=np.float32)
model = Prototypical(support, query, img_width, img_height, channels)
optimizer = tf.keras.optimizers.Adam(learning_rate)

@tf.function
def loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

least_loss = {'least_loss': 100.00}

@tf.function
def train_step(support, query):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  train_loss(loss)
  train_acc(acc)

for epoch in range(num_epochs):
  train_loss.reset_states()
  train_acc.reset_states()

  for episode in range(num_episodes):
    train_support, train_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
    train_step(train_support, train_query)

  cur_loss = train_loss.result().numpy()
  if cur_loss < least_loss['least_loss']:
      print("Saving new best model with loss: ", cur_loss)
      least_loss['least_loss'] = cur_loss
      model.save(save_path)
  
  template = 'Epoch {}, Loss: {}, Accuracy: {}'
  print(template.format(epoch + 1, train_loss.result(), train_acc.result() * 100))



Saving new best model with loss:  2.725025
Epoch 1, Loss: 2.725024938583374, Accuracy: 38.58333206176758




Saving new best model with loss:  1.6814504
Epoch 2, Loss: 1.681450366973877, Accuracy: 56.53331756591797




Saving new best model with loss:  1.2530684
Epoch 3, Loss: 1.253068447113037, Accuracy: 67.20002746582031




Saving new best model with loss:  0.897805
Epoch 4, Loss: 0.8978049755096436, Accuracy: 74.76666259765625




Saving new best model with loss:  0.75266457
Epoch 5, Loss: 0.7526645660400391, Accuracy: 79.08333587646484




Saving new best model with loss:  0.6636819
Epoch 6, Loss: 0.6636819243431091, Accuracy: 81.61664581298828




Saving new best model with loss:  0.6004457
Epoch 7, Loss: 0.6004456877708435, Accuracy: 83.08332061767578




Saving new best model with loss:  0.5677421
Epoch 8, Loss: 0.567742109298706, Accuracy: 83.59999084472656




Saving new best model with loss:  0.5464982
Epoch 9, Loss: 0.54649817943573, Accuracy: 84.76666259765625




Saving new best model with loss:  0.51183605
Epoch 10, Loss: 0.511836051940918, Accuracy: 85.39997863769531




Saving new best model with loss:  0.49727196
Epoch 11, Loss: 0.49727195501327515, Accuracy: 86.09999084472656




Saving new best model with loss:  0.4672421
Epoch 12, Loss: 0.4672420918941498, Accuracy: 86.15000915527344




Saving new best model with loss:  0.4445419
Epoch 13, Loss: 0.44454190135002136, Accuracy: 86.433349609375




Saving new best model with loss:  0.40536222
Epoch 14, Loss: 0.40536221861839294, Accuracy: 87.84998321533203
Epoch 15, Loss: 0.41128113865852356, Accuracy: 88.36665344238281




Saving new best model with loss:  0.3874309
Epoch 16, Loss: 0.38743090629577637, Accuracy: 89.33333587646484
Epoch 17, Loss: 0.39466017484664917, Accuracy: 88.81664276123047




Saving new best model with loss:  0.3787293
Epoch 18, Loss: 0.37872931361198425, Accuracy: 88.76667022705078




Saving new best model with loss:  0.37445432
Epoch 19, Loss: 0.3744543194770813, Accuracy: 89.1166763305664




Saving new best model with loss:  0.33138332
Epoch 20, Loss: 0.3313833177089691, Accuracy: 89.98334503173828
Epoch 21, Loss: 0.33626270294189453, Accuracy: 90.30000305175781
Epoch 22, Loss: 0.3342527747154236, Accuracy: 90.29998779296875




Saving new best model with loss:  0.30722943
Epoch 23, Loss: 0.3072294294834137, Accuracy: 91.1166763305664
Epoch 24, Loss: 0.33041003346443176, Accuracy: 90.48334503173828
Epoch 25, Loss: 0.32819080352783203, Accuracy: 90.19999694824219
Epoch 26, Loss: 0.31369906663894653, Accuracy: 91.21668243408203
Epoch 27, Loss: 0.31752869486808777, Accuracy: 90.88334655761719
Epoch 28, Loss: 0.3087031841278076, Accuracy: 91.21666717529297




Saving new best model with loss:  0.2926168
Epoch 29, Loss: 0.2926168143749237, Accuracy: 91.0833511352539
Epoch 30, Loss: 0.2935243844985962, Accuracy: 91.31668853759766




Saving new best model with loss:  0.27271786
Epoch 31, Loss: 0.2727178633213043, Accuracy: 92.08333587646484
Epoch 32, Loss: 0.27623119950294495, Accuracy: 91.71666717529297




Saving new best model with loss:  0.26507512
Epoch 33, Loss: 0.26507511734962463, Accuracy: 91.9000015258789
Epoch 34, Loss: 0.27261221408843994, Accuracy: 92.23335266113281




Saving new best model with loss:  0.24351121
Epoch 35, Loss: 0.24351121485233307, Accuracy: 92.54998779296875
Epoch 36, Loss: 0.27500736713409424, Accuracy: 92.2667007446289
Epoch 37, Loss: 0.272095263004303, Accuracy: 91.81666564941406
Epoch 38, Loss: 0.26221275329589844, Accuracy: 92.14999389648438
Epoch 39, Loss: 0.24995751678943634, Accuracy: 92.86666107177734




Saving new best model with loss:  0.24329633
Epoch 40, Loss: 0.2432963252067566, Accuracy: 92.5999755859375
Epoch 41, Loss: 0.2574007213115692, Accuracy: 92.9666748046875
Epoch 42, Loss: 0.27080008387565613, Accuracy: 92.18333435058594
Epoch 43, Loss: 0.24451416730880737, Accuracy: 92.91665649414062




Saving new best model with loss:  0.2365124
Epoch 44, Loss: 0.23651239275932312, Accuracy: 93.19998931884766




Saving new best model with loss:  0.22734477
Epoch 45, Loss: 0.22734476625919342, Accuracy: 92.99999237060547
Epoch 46, Loss: 0.23016513884067535, Accuracy: 93.48333740234375
Epoch 47, Loss: 0.2321457713842392, Accuracy: 92.8333511352539




Saving new best model with loss:  0.21407971
Epoch 48, Loss: 0.21407970786094666, Accuracy: 93.60000610351562
Epoch 49, Loss: 0.21843059360980988, Accuracy: 93.04998779296875
Epoch 50, Loss: 0.24128396809101105, Accuracy: 93.18333435058594




Saving new best model with loss:  0.20946911
Epoch 51, Loss: 0.20946910977363586, Accuracy: 93.5833511352539




Saving new best model with loss:  0.20709206
Epoch 52, Loss: 0.2070920616388321, Accuracy: 94.01666259765625
Epoch 53, Loss: 0.22769029438495636, Accuracy: 93.38333129882812
Epoch 54, Loss: 0.21028442680835724, Accuracy: 94.04998016357422
Epoch 55, Loss: 0.2184450775384903, Accuracy: 93.55001068115234




Saving new best model with loss:  0.19155377
Epoch 56, Loss: 0.1915537714958191, Accuracy: 94.11666107177734
Epoch 57, Loss: 0.21993903815746307, Accuracy: 93.16667938232422
Epoch 58, Loss: 0.1945711374282837, Accuracy: 94.3833236694336
Epoch 59, Loss: 0.19814616441726685, Accuracy: 93.93331146240234
Epoch 60, Loss: 0.19905954599380493, Accuracy: 93.94999694824219
Epoch 61, Loss: 0.20597705245018005, Accuracy: 94.13331604003906
Epoch 62, Loss: 0.21292129158973694, Accuracy: 93.96666717529297
Epoch 63, Loss: 0.1921580731868744, Accuracy: 94.24998474121094




Saving new best model with loss:  0.18370558
Epoch 64, Loss: 0.18370558321475983, Accuracy: 94.36666870117188
Epoch 65, Loss: 0.2319098263978958, Accuracy: 93.04999542236328
Epoch 66, Loss: 0.1956406980752945, Accuracy: 93.95001220703125
Epoch 67, Loss: 0.19223757088184357, Accuracy: 93.94998168945312
Epoch 68, Loss: 0.18465524911880493, Accuracy: 94.24998474121094
Epoch 69, Loss: 0.18447345495224, Accuracy: 94.06665802001953




Saving new best model with loss:  0.1738064
Epoch 70, Loss: 0.17380639910697937, Accuracy: 94.6333236694336
Epoch 71, Loss: 0.19733457267284393, Accuracy: 94.23332977294922
Epoch 72, Loss: 0.1998845338821411, Accuracy: 94.13331604003906
Epoch 73, Loss: 0.180770143866539, Accuracy: 94.68331909179688
Epoch 74, Loss: 0.1760677546262741, Accuracy: 94.43331146240234
Epoch 75, Loss: 0.18751540780067444, Accuracy: 94.19998931884766
Epoch 76, Loss: 0.1813250184059143, Accuracy: 94.46664428710938
Epoch 77, Loss: 0.1927918940782547, Accuracy: 94.04999542236328
Epoch 78, Loss: 0.18336202204227448, Accuracy: 94.38330841064453
Epoch 79, Loss: 0.19591639935970306, Accuracy: 93.93334197998047




Saving new best model with loss:  0.16431047
Epoch 80, Loss: 0.16431047022342682, Accuracy: 95.36666107177734


In [15]:
#Run prototypical model with only training set
# This will be using the third values from the lists for the experiment
num_epochs = 80
num_episodes = 100
save_path = "./results/models/omniglot_train2.h5"

train_loss = tf.metrics.Mean(name='train_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')


#number of classes
num_way = train_num_ways[2] 

#number of examples per class for support set
num_shot = num_shots[2]  

#number of query points
num_query = num_shots[2] 

least_loss = {'least_loss': 100.00}

support = np.zeros([num_way, num_shot, img_width, img_height, channels], dtype=np.float32)
query = np.zeros([num_way, num_query, img_width, img_height, channels], dtype=np.float32)
model = Prototypical(support, query, img_width, img_height, channels)
optimizer = tf.keras.optimizers.Adam(learning_rate)

@tf.function
def loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

least_loss = {'least_loss': 100.00}

@tf.function
def train_step(support, query):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  train_loss(loss)
  train_acc(acc)

for epoch in range(num_epochs):
  train_loss.reset_states()
  train_acc.reset_states()

  for episode in range(num_episodes):
    train_support, train_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
    train_step(train_support, train_query)

  cur_loss = train_loss.result().numpy()
  if cur_loss < least_loss['least_loss']:
      print("Saving new best model with loss: ", cur_loss)
      least_loss['least_loss'] = cur_loss
      model.save(save_path)
  
  template = 'Epoch {}, Loss: {}, Accuracy: {}'
  print(template.format(epoch + 1, train_loss.result(), train_acc.result() * 100))



Saving new best model with loss:  1.6227574
Epoch 1, Loss: 1.6227574348449707, Accuracy: 62.99502182006836




Saving new best model with loss:  0.61009717
Epoch 2, Loss: 0.6100971698760986, Accuracy: 83.37998962402344




Saving new best model with loss:  0.40875885
Epoch 3, Loss: 0.40875884890556335, Accuracy: 88.61997985839844




Saving new best model with loss:  0.2541982
Epoch 4, Loss: 0.25419819355010986, Accuracy: 92.61499786376953




Saving new best model with loss:  0.2009862
Epoch 5, Loss: 0.20098620653152466, Accuracy: 94.19498443603516




Saving new best model with loss:  0.17222165
Epoch 6, Loss: 0.17222164571285248, Accuracy: 95.19499969482422




Saving new best model with loss:  0.15148762
Epoch 7, Loss: 0.15148761868476868, Accuracy: 95.51998901367188




Saving new best model with loss:  0.13440822
Epoch 8, Loss: 0.13440822064876556, Accuracy: 95.95999908447266




Saving new best model with loss:  0.1269744
Epoch 9, Loss: 0.12697440385818481, Accuracy: 96.48501586914062
Epoch 10, Loss: 0.13051262497901917, Accuracy: 96.43999481201172
Epoch 11, Loss: 0.1296151578426361, Accuracy: 96.38500213623047




Saving new best model with loss:  0.110336736
Epoch 12, Loss: 0.11033673584461212, Accuracy: 96.8399887084961
Epoch 13, Loss: 0.1169830709695816, Accuracy: 96.63497161865234




Saving new best model with loss:  0.10468605
Epoch 14, Loss: 0.10468605160713196, Accuracy: 96.9999771118164
Epoch 15, Loss: 0.10516892373561859, Accuracy: 96.9749755859375




Saving new best model with loss:  0.095324956
Epoch 16, Loss: 0.09532495588064194, Accuracy: 97.24000549316406




Saving new best model with loss:  0.094640724
Epoch 17, Loss: 0.09464072436094284, Accuracy: 97.19998931884766




Saving new best model with loss:  0.09175509
Epoch 18, Loss: 0.09175509214401245, Accuracy: 97.33002471923828
Epoch 19, Loss: 0.09726627171039581, Accuracy: 97.13497161865234




Saving new best model with loss:  0.08630772
Epoch 20, Loss: 0.08630771934986115, Accuracy: 97.48001861572266




Saving new best model with loss:  0.08023305
Epoch 21, Loss: 0.08023305237293243, Accuracy: 97.60999298095703
Epoch 22, Loss: 0.09092296659946442, Accuracy: 97.37499237060547
Epoch 23, Loss: 0.08091194182634354, Accuracy: 97.6449966430664
Epoch 24, Loss: 0.08320652693510056, Accuracy: 97.59502410888672




Saving new best model with loss:  0.076925345
Epoch 25, Loss: 0.07692534476518631, Accuracy: 97.79000854492188




Saving new best model with loss:  0.069000535
Epoch 26, Loss: 0.06900053471326828, Accuracy: 97.93502807617188




Saving new best model with loss:  0.06650753
Epoch 27, Loss: 0.06650753319263458, Accuracy: 98.09000396728516
Epoch 28, Loss: 0.07391633093357086, Accuracy: 97.7750244140625




Saving new best model with loss:  0.065626495
Epoch 29, Loss: 0.06562649458646774, Accuracy: 98.10501861572266
Epoch 30, Loss: 0.0714334174990654, Accuracy: 97.90502166748047




Saving new best model with loss:  0.06554047
Epoch 31, Loss: 0.06554047018289566, Accuracy: 98.1250228881836
Epoch 32, Loss: 0.06598136574029922, Accuracy: 98.02501678466797




Saving new best model with loss:  0.06062045
Epoch 33, Loss: 0.06062044948339462, Accuracy: 98.2600326538086
Epoch 34, Loss: 0.0613446943461895, Accuracy: 98.23001098632812
Epoch 35, Loss: 0.06438524276018143, Accuracy: 98.24504089355469
Epoch 36, Loss: 0.062764011323452, Accuracy: 98.14000701904297




Saving new best model with loss:  0.056746468
Epoch 37, Loss: 0.0567464679479599, Accuracy: 98.24996948242188
Epoch 38, Loss: 0.05694327875971794, Accuracy: 98.28002166748047
Epoch 39, Loss: 0.06123363599181175, Accuracy: 98.27503204345703
Epoch 40, Loss: 0.060527220368385315, Accuracy: 98.1600341796875




Saving new best model with loss:  0.056584787
Epoch 41, Loss: 0.056584786623716354, Accuracy: 98.2300033569336
Epoch 42, Loss: 0.06263496726751328, Accuracy: 98.12500762939453
Epoch 43, Loss: 0.057165659964084625, Accuracy: 98.30001068115234




Saving new best model with loss:  0.056099433
Epoch 44, Loss: 0.056099433451890945, Accuracy: 98.14501190185547




Saving new best model with loss:  0.051375803
Epoch 45, Loss: 0.05137580260634422, Accuracy: 98.48001098632812
Epoch 46, Loss: 0.057831112295389175, Accuracy: 98.40003204345703




Saving new best model with loss:  0.04980312
Epoch 47, Loss: 0.04980311915278435, Accuracy: 98.36998748779297
Epoch 48, Loss: 0.052008915692567825, Accuracy: 98.33001708984375




Saving new best model with loss:  0.04547982
Epoch 49, Loss: 0.04547981917858124, Accuracy: 98.56999969482422
Epoch 50, Loss: 0.04967110604047775, Accuracy: 98.50501251220703
Epoch 51, Loss: 0.04874469339847565, Accuracy: 98.5250244140625
Epoch 52, Loss: 0.04835718125104904, Accuracy: 98.4300308227539




Saving new best model with loss:  0.043321237
Epoch 53, Loss: 0.043321236968040466, Accuracy: 98.59502410888672
Epoch 54, Loss: 0.04388178884983063, Accuracy: 98.64500427246094




Saving new best model with loss:  0.04258868
Epoch 55, Loss: 0.04258868098258972, Accuracy: 98.73004913330078




Saving new best model with loss:  0.03957807
Epoch 56, Loss: 0.039578069001436234, Accuracy: 98.7550277709961
Epoch 57, Loss: 0.04682217910885811, Accuracy: 98.50003814697266
Epoch 58, Loss: 0.04339703172445297, Accuracy: 98.73003387451172
Epoch 59, Loss: 0.044392503798007965, Accuracy: 98.5150375366211
Epoch 60, Loss: 0.042854052037000656, Accuracy: 98.67002868652344




Saving new best model with loss:  0.037011
Epoch 61, Loss: 0.037011001259088516, Accuracy: 98.73001098632812
Epoch 62, Loss: 0.0456576868891716, Accuracy: 98.59501647949219
Epoch 63, Loss: 0.039380885660648346, Accuracy: 98.69503021240234
Epoch 64, Loss: 0.04404246434569359, Accuracy: 98.6750259399414
Epoch 65, Loss: 0.04271294176578522, Accuracy: 98.61502838134766




Saving new best model with loss:  0.03624097
Epoch 66, Loss: 0.036240968853235245, Accuracy: 98.7650146484375
Epoch 67, Loss: 0.03726862370967865, Accuracy: 98.78002166748047
Epoch 68, Loss: 0.03950682654976845, Accuracy: 98.66002655029297




Saving new best model with loss:  0.036107734
Epoch 69, Loss: 0.036107733845710754, Accuracy: 98.85003662109375
Epoch 70, Loss: 0.0373772457242012, Accuracy: 98.635009765625
Epoch 71, Loss: 0.037399400025606155, Accuracy: 98.81500244140625




Saving new best model with loss:  0.034875084
Epoch 72, Loss: 0.03487508371472359, Accuracy: 98.88504028320312
Epoch 73, Loss: 0.04054630175232887, Accuracy: 98.74002838134766




Saving new best model with loss:  0.032696187
Epoch 74, Loss: 0.0326961874961853, Accuracy: 98.95502471923828
Epoch 75, Loss: 0.035186897963285446, Accuracy: 98.81002807617188
Epoch 76, Loss: 0.034691277891397476, Accuracy: 98.7550277709961
Epoch 77, Loss: 0.034321557730436325, Accuracy: 98.96003723144531




Saving new best model with loss:  0.031905357
Epoch 78, Loss: 0.03190535679459572, Accuracy: 98.91503143310547
Epoch 79, Loss: 0.03580688312649727, Accuracy: 98.86001586914062
Epoch 80, Loss: 0.03340233862400055, Accuracy: 98.90503692626953


In [16]:
#Run prototypical model with only training set
# This will be using the fourth values from the lists for the experiment
num_epochs = 80
num_episodes = 100
save_path = "./results/models/omniglot_train3.h5"

train_loss = tf.metrics.Mean(name='train_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')


#number of classes
num_way = train_num_ways[3] 

#number of examples per class for support set
num_shot = num_shots[3]  

#number of query points
num_query = num_shots[3] 

least_loss = {'least_loss': 100.00}

support = np.zeros([num_way, num_shot, img_width, img_height, channels], dtype=np.float32)
query = np.zeros([num_way, num_query, img_width, img_height, channels], dtype=np.float32)
model = Prototypical(support, query, img_width, img_height, channels)
optimizer = tf.keras.optimizers.Adam(learning_rate)

@tf.function
def loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

least_loss = {'least_loss': 100.00}

@tf.function
def train_step(support, query):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  train_loss(loss)
  train_acc(acc)

for epoch in range(num_epochs):
  train_loss.reset_states()
  train_acc.reset_states()

  for episode in range(num_episodes):
    train_support, train_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
    train_step(train_support, train_query)

  cur_loss = train_loss.result().numpy()
  if cur_loss < least_loss['least_loss']:
      print("Saving new best model with loss: ", cur_loss)
      least_loss['least_loss'] = cur_loss
      model.save(save_path)
  
  template = 'Epoch {}, Loss: {}, Accuracy: {}'
  print(template.format(epoch + 1, train_loss.result(), train_acc.result() * 100))



Saving new best model with loss:  2.4195518
Epoch 1, Loss: 2.4195518493652344, Accuracy: 41.549991607666016




Saving new best model with loss:  1.5646489
Epoch 2, Loss: 1.5646488666534424, Accuracy: 58.124996185302734




Saving new best model with loss:  1.197028
Epoch 3, Loss: 1.1970280408859253, Accuracy: 68.25001525878906




Saving new best model with loss:  0.93336415
Epoch 4, Loss: 0.9333641529083252, Accuracy: 74.8499984741211




Saving new best model with loss:  0.7554164
Epoch 5, Loss: 0.7554163932800293, Accuracy: 78.92499542236328




Saving new best model with loss:  0.6486669
Epoch 6, Loss: 0.6486669182777405, Accuracy: 81.69998931884766




Saving new best model with loss:  0.53964937
Epoch 7, Loss: 0.5396493673324585, Accuracy: 83.95000457763672
Epoch 8, Loss: 0.5504009127616882, Accuracy: 84.42501068115234




Saving new best model with loss:  0.53839815
Epoch 9, Loss: 0.5383981466293335, Accuracy: 84.67500305175781




Saving new best model with loss:  0.4961293
Epoch 10, Loss: 0.4961293041706085, Accuracy: 85.39999389648438




Saving new best model with loss:  0.46690175
Epoch 11, Loss: 0.4669017493724823, Accuracy: 85.87499237060547




Saving new best model with loss:  0.4231772
Epoch 12, Loss: 0.42317721247673035, Accuracy: 87.47498321533203
Epoch 13, Loss: 0.4266640841960907, Accuracy: 87.5




Saving new best model with loss:  0.41178444
Epoch 14, Loss: 0.41178444027900696, Accuracy: 88.15000915527344
Epoch 15, Loss: 0.42782458662986755, Accuracy: 87.82500457763672




Saving new best model with loss:  0.39471427
Epoch 16, Loss: 0.39471426606178284, Accuracy: 88.34999084472656
Epoch 17, Loss: 0.40809595584869385, Accuracy: 88.42501068115234




Saving new best model with loss:  0.3613637
Epoch 18, Loss: 0.3613637089729309, Accuracy: 89.12499237060547
Epoch 19, Loss: 0.36413004994392395, Accuracy: 90.12500762939453




Saving new best model with loss:  0.34734336
Epoch 20, Loss: 0.3473433554172516, Accuracy: 89.67500305175781




Saving new best model with loss:  0.32184768
Epoch 21, Loss: 0.32184767723083496, Accuracy: 90.07500457763672
Epoch 22, Loss: 0.3282147943973541, Accuracy: 90.32500457763672
Epoch 23, Loss: 0.3235490024089813, Accuracy: 90.85002136230469




Saving new best model with loss:  0.32103541
Epoch 24, Loss: 0.3210354149341583, Accuracy: 91.17499542236328
Epoch 25, Loss: 0.330936998128891, Accuracy: 90.6500015258789
Epoch 26, Loss: 0.32598984241485596, Accuracy: 91.30000305175781




Saving new best model with loss:  0.2809907
Epoch 27, Loss: 0.28099068999290466, Accuracy: 91.85000610351562




Saving new best model with loss:  0.2762402
Epoch 28, Loss: 0.27624019980430603, Accuracy: 91.92495727539062
Epoch 29, Loss: 0.2905048131942749, Accuracy: 92.37501525878906
Epoch 30, Loss: 0.2983030378818512, Accuracy: 91.54997253417969
Epoch 31, Loss: 0.2829934358596802, Accuracy: 91.82502746582031
Epoch 32, Loss: 0.2817455530166626, Accuracy: 91.85002136230469




Saving new best model with loss:  0.2520608
Epoch 33, Loss: 0.25206080079078674, Accuracy: 92.67498779296875
Epoch 34, Loss: 0.25306397676467896, Accuracy: 92.30001068115234




Saving new best model with loss:  0.24240406
Epoch 35, Loss: 0.2424040585756302, Accuracy: 92.45002746582031
Epoch 36, Loss: 0.2451450526714325, Accuracy: 92.50003051757812
Epoch 37, Loss: 0.2536834180355072, Accuracy: 92.94998931884766
Epoch 38, Loss: 0.2831035256385803, Accuracy: 91.64998626708984
Epoch 39, Loss: 0.24772199988365173, Accuracy: 92.87498474121094




Saving new best model with loss:  0.23566636
Epoch 40, Loss: 0.23566636443138123, Accuracy: 92.59998321533203
Epoch 41, Loss: 0.25991758704185486, Accuracy: 92.15000915527344
Epoch 42, Loss: 0.24799737334251404, Accuracy: 92.625




Saving new best model with loss:  0.22769514
Epoch 43, Loss: 0.22769513726234436, Accuracy: 93.72500610351562
Epoch 44, Loss: 0.2522781491279602, Accuracy: 92.39997863769531




Saving new best model with loss:  0.2263587
Epoch 45, Loss: 0.22635869681835175, Accuracy: 93.7750015258789




Saving new best model with loss:  0.21088253
Epoch 46, Loss: 0.2108825296163559, Accuracy: 93.12498474121094
Epoch 47, Loss: 0.21524576842784882, Accuracy: 93.52497100830078
Epoch 48, Loss: 0.24230070412158966, Accuracy: 92.64997863769531
Epoch 49, Loss: 0.2222149521112442, Accuracy: 93.12499237060547
Epoch 50, Loss: 0.2392786145210266, Accuracy: 92.75001525878906




Saving new best model with loss:  0.19183627
Epoch 51, Loss: 0.19183626770973206, Accuracy: 94.32498931884766
Epoch 52, Loss: 0.21516595780849457, Accuracy: 93.57498168945312
Epoch 53, Loss: 0.22133274376392365, Accuracy: 92.94996643066406




Saving new best model with loss:  0.18847893
Epoch 54, Loss: 0.18847893178462982, Accuracy: 94.22498321533203
Epoch 55, Loss: 0.21808242797851562, Accuracy: 93.49995422363281
Epoch 56, Loss: 0.20452998578548431, Accuracy: 93.72500610351562
Epoch 57, Loss: 0.20025140047073364, Accuracy: 94.07499694824219
Epoch 58, Loss: 0.20591992139816284, Accuracy: 93.64998626708984
Epoch 59, Loss: 0.20623888075351715, Accuracy: 93.9749984741211
Epoch 60, Loss: 0.19467797875404358, Accuracy: 94.39998626708984




Saving new best model with loss:  0.18176337
Epoch 61, Loss: 0.18176336586475372, Accuracy: 94.02497100830078
Epoch 62, Loss: 0.18789023160934448, Accuracy: 94.60000610351562
Epoch 63, Loss: 0.1919507533311844, Accuracy: 94.59999084472656
Epoch 64, Loss: 0.18456722795963287, Accuracy: 94.67495727539062




Saving new best model with loss:  0.1750999
Epoch 65, Loss: 0.17509989440441132, Accuracy: 94.5999755859375




Saving new best model with loss:  0.15712206
Epoch 66, Loss: 0.15712206065654755, Accuracy: 95.57496643066406
Epoch 67, Loss: 0.17420987784862518, Accuracy: 94.52496337890625
Epoch 68, Loss: 0.18387925624847412, Accuracy: 94.87496948242188
Epoch 69, Loss: 0.17550747096538544, Accuracy: 94.5999755859375
Epoch 70, Loss: 0.17219987511634827, Accuracy: 94.47496795654297
Epoch 71, Loss: 0.18452712893486023, Accuracy: 94.57498168945312
Epoch 72, Loss: 0.19801875948905945, Accuracy: 94.59998321533203
Epoch 73, Loss: 0.17818252742290497, Accuracy: 94.50000762939453
Epoch 74, Loss: 0.19622284173965454, Accuracy: 93.90000915527344
Epoch 75, Loss: 0.17716480791568756, Accuracy: 94.5999755859375
Epoch 76, Loss: 0.1686609983444214, Accuracy: 94.69996643066406
Epoch 77, Loss: 0.19023028016090393, Accuracy: 94.74999237060547
Epoch 78, Loss: 0.16630807518959045, Accuracy: 95.19995880126953
Epoch 79, Loss: 0.1677059680223465, Accuracy: 94.94998168945312
Epoch 80, Loss: 0.16286741197109222, Accuracy: 95

In [17]:
save_paths = ["./results/models/omniglot_train0.h5", "./results/models/omniglot_train1.h5", "./results/models/omniglot_train2.h5", "./results/models/omniglot_train3.h5"]

In [18]:
def calc_loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

In [19]:
accuracies_proto_0 = {}
for save_path in save_paths:
  model_path = save_path
  model.load(model_path)
  print("Model loaded.")
  num_episodes = 1000

  #number of classes
  num_way = test_num_ways[0]

  #number of examples per class for support set
  num_shot = num_shots[0]  

  #number of query points
  num_query = num_shots[0] 

  # Metrics to gather
  test_loss = tf.metrics.Mean(name='test_loss')
  test_acc = tf.metrics.Mean(name='test_accuracy')


  for i_episode in range(num_episodes):
    test_support, test_query = get_next_episode(test_dataset, num_way, num_shot, num_shot, no_test_classes)
    loss, acc = calc_loss(test_support, test_query)
    test_loss(loss)
    test_acc(acc)

  loss = test_loss.result().numpy()
  accuracy = test_acc.result().numpy() * 100
  print("Loss: ", loss)
  print("Accuracy: ", accuracy)
  accuracies_proto_0[model_path] = accuracy


Model loaded.
Loss:  0.01443269
Accuracy:  99.65207576751709
Model loaded.
Loss:  0.017238224
Accuracy:  99.6001124382019
Model loaded.
Loss:  0.022710524
Accuracy:  99.52009320259094
Model loaded.
Loss:  0.017339597
Accuracy:  99.56807494163513


In [20]:
accuracies_proto_1 = {}
for save_path in save_paths:
  model_path = save_path
  model.load(model_path)
  print("Model with path {} loaded.".format(save_path))
  num_episodes = 1000

  #number of classes
  num_way = test_num_ways[1]

  #number of examples per class for support set
  num_shot = num_shots[1]  

  #number of query points
  num_query = num_shots[1] 

  # Metrics to gather
  test_loss = tf.metrics.Mean(name='test_loss')
  test_acc = tf.metrics.Mean(name='test_accuracy')


  for i_episode in range(num_episodes):
    test_support, test_query = get_next_episode(test_dataset, num_way, num_shot, num_shot, no_test_classes)
    loss, acc = calc_loss(test_support, test_query)
    test_loss(loss)
    test_acc(acc)

  loss = test_loss.result().numpy()
  accuracy = test_acc.result().numpy() * 100
  print("Loss: ", loss)
  print("Accuracy: ", accuracy)
  accuracies_proto_1[model_path] = accuracy


Model with path ./results/models/omniglot_train0.h5 loaded.
Loss:  0.10468974
Accuracy:  97.63989448547363
Model with path ./results/models/omniglot_train1.h5 loaded.
Loss:  0.064977564
Accuracy:  98.13992977142334
Model with path ./results/models/omniglot_train2.h5 loaded.
Loss:  0.122516006
Accuracy:  97.03988432884216
Model with path ./results/models/omniglot_train3.h5 loaded.
Loss:  0.06688529
Accuracy:  98.11992049217224


In [21]:
accuracies_proto_2 = {}
for save_path in save_paths:
  model_path = save_path
  model.load(model_path)

  num_episodes = 1000

  #number of classes
  num_way = test_num_ways[2]

  #number of examples per class for support set
  num_shot = num_shots[2]  

  #number of query points
  num_query = num_shots[2] 

  # Metrics to gather
  test_loss = tf.metrics.Mean(name='test_loss')
  test_acc = tf.metrics.Mean(name='test_accuracy')

  print("Testing {} way {} shot".format(num_way, num_shot))
  print("Model with path {} loaded.".format(save_path))

  for i_episode in range(num_episodes):
    test_support, test_query = get_next_episode(test_dataset, num_way, num_shot, num_shot, no_test_classes)
    loss, acc = calc_loss(test_support, test_query)
    test_loss(loss)
    test_acc(acc)

  loss = test_loss.result().numpy()
  accuracy = test_acc.result().numpy() * 100
  print("Loss: ", loss)
  print("Accuracy: ", accuracy)
  accuracies_proto_2[model_path] = accuracy

Model with path ./results/models/omniglot_train0.h5 loaded.
Loss:  0.049333546
Accuracy:  98.63952994346619
Model with path ./results/models/omniglot_train1.h5 loaded.
Loss:  0.050810993
Accuracy:  98.72457981109619
Model with path ./results/models/omniglot_train2.h5 loaded.
Loss:  0.060099393
Accuracy:  98.33952188491821
Model with path ./results/models/omniglot_train3.h5 loaded.
Loss:  0.059270512
Accuracy:  98.42256903648376


In [22]:
accuracies_proto_3 = {}

def calc_loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

for save_path in save_paths:
  model_path = save_path
  model.load(model_path)
  print("Model with path {} loaded.".format(save_path))
  num_episodes = 1000

  #number of classes
  num_way = test_num_ways[3]

  #number of examples per class for support set
  num_shot = num_shots[3]  

  #number of query points
  num_query = num_shots[3] 

  # Metrics to gather
  test_loss = tf.metrics.Mean(name='test_loss')
  test_acc = tf.metrics.Mean(name='test_accuracy')


  for i_episode in range(num_episodes):
    test_support, test_query = get_next_episode(test_dataset, num_way, num_shot, num_shot, no_test_classes)
    loss, acc = calc_loss(test_support, test_query)
    test_loss(loss)
    test_acc(acc)

  loss = test_loss.result().numpy()
  accuracy = test_acc.result().numpy() * 100
  print("Loss: ", loss)
  print("Accuracy: ", accuracy)
  accuracies_proto_3[model_path] = accuracy

Model with path ./results/models/omniglot_train0.h5 loaded.
Loss:  0.3221269
Accuracy:  92.59039759635925
Model with path ./results/models/omniglot_train1.h5 loaded.
Loss:  0.16637978
Accuracy:  95.2004075050354
Model with path ./results/models/omniglot_train2.h5 loaded.
Loss:  0.3385203
Accuracy:  91.7953372001648
Model with path ./results/models/omniglot_train3.h5 loaded.
Loss:  0.23058069
Accuracy:  93.64041090011597


In [23]:
#Combi Proto + Reptile for first values from the list ie 60-way 5-shot experiment

meta_step_size = 0.25

#Interval between running SGD on the validation dataset
eval_interval = 4

#number of classes
num_way = train_num_ways[0]

#number of examples per class for support set
num_shot = num_shots[0]  

#number of query points
num_query = num_shots[0] 


train_loss = tf.metrics.Mean(name='train_loss')
val_loss = tf.metrics.Mean(name='val_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')
val_acc = tf.metrics.Mean(name='val_accuracy')
support = np.zeros([num_way, num_shot, img_height, img_width, channels], dtype=np.float32)
query = np.zeros([num_way, num_shot, img_height, img_width, channels], dtype=np.float32)
model = Prototypical(support, query, img_height, img_width, channels)
optimizer_adam = tf.keras.optimizers.Adam(learning_rate, beta_1=0)
optimizer_sgd = tf.keras.optimizers.SGD(learning_rate)

num_epochs = 241
num_episodes = 100
save_path = "./results/models/omniglot_train_reptile0.h5"

@tf.function
def loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

@tf.function
def train_step(support, query, optimizer):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  
  train_loss(loss)
  train_acc(acc)
  

@tf.function
def val_step(support, query, optimizer):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  
  val_loss(loss)
  val_acc(acc)


least_loss = {'least_loss': 100.00}
for epoch in range(num_epochs):
  train_loss.reset_states()
  val_loss.reset_states()
  train_acc.reset_states()
  val_acc.reset_states()

  for episode in range(num_episodes):
    frac_done = episode / num_episodes
    cur_meta_step_size = (1 - frac_done) * meta_step_size
    train_support, train_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
    
    model = Prototypical(train_support, train_query, img_width, img_height, channels)
    model.call(train_support, train_query)
    old_weights = model.get_weights()
    train_step(train_support, train_query, optimizer_sgd)
    new_weights = model.get_weights()

    for part_weight in range(len(new_weights)):
        new_weights[part_weight] = old_weights[part_weight] + (
            (new_weights[part_weight] - old_weights[part_weight]) * cur_meta_step_size
        )

    model.set_weights(new_weights)
    if epoch % eval_interval == 0:
      eval_support, eval_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
      old_vars = model.get_weights()
      val_step(eval_support, eval_query, optimizer_adam)
      model.set_weights(old_vars)
      
  cur_loss = val_loss.result().numpy()

  if (epoch % eval_interval == 0):
    template = 'Epoch {}, Loss: {}, Accuracy: {}, ' \
                  'Val Loss: {}, Val Accuracy: {}'
    print(template.format(epoch + 1, train_loss.result(), train_acc.result() * 100, val_loss.result(),
                            val_acc.result() * 100))
    if cur_loss < least_loss['least_loss']:
      print("Saving new best model with loss: ", cur_loss)
      least_loss['least_loss'] = cur_loss
      model.save(save_path)



Epoch 1, Loss: 1.5062423944473267, Accuracy: 65.39667510986328, Val Loss: 1.534399390220642, Val Accuracy: 65.55000305175781
Saving new best model with loss:  1.5343994




Epoch 5, Loss: 0.5645511746406555, Accuracy: 84.56001281738281, Val Loss: 0.5675460696220398, Val Accuracy: 84.68998718261719
Saving new best model with loss:  0.56754607




Epoch 9, Loss: 0.3253828287124634, Accuracy: 90.81998443603516, Val Loss: 0.33602410554885864, Val Accuracy: 90.38001251220703
Saving new best model with loss:  0.3360241




Epoch 13, Loss: 0.25109007954597473, Accuracy: 92.83332061767578, Val Loss: 0.2403283715248108, Val Accuracy: 93.19664764404297
Saving new best model with loss:  0.24032837




Epoch 17, Loss: 0.21814675629138947, Accuracy: 93.75997924804688, Val Loss: 0.2108759880065918, Val Accuracy: 93.95664978027344
Saving new best model with loss:  0.21087599




Epoch 21, Loss: 0.1934673935174942, Accuracy: 94.67998504638672, Val Loss: 0.18680429458618164, Val Accuracy: 94.61332702636719
Saving new best model with loss:  0.1868043




Epoch 25, Loss: 0.17535510659217834, Accuracy: 94.96666717529297, Val Loss: 0.17445963621139526, Val Accuracy: 95.13663482666016
Saving new best model with loss:  0.17445964




Epoch 29, Loss: 0.17094850540161133, Accuracy: 95.0799560546875, Val Loss: 0.17354947328567505, Val Accuracy: 94.94664001464844
Saving new best model with loss:  0.17354947




Epoch 33, Loss: 0.14992772042751312, Accuracy: 95.50666046142578, Val Loss: 0.15987858176231384, Val Accuracy: 95.49665832519531
Saving new best model with loss:  0.15987858




Epoch 37, Loss: 0.14858967065811157, Accuracy: 95.72665405273438, Val Loss: 0.1533307433128357, Val Accuracy: 95.40998077392578
Saving new best model with loss:  0.15333074




Epoch 41, Loss: 0.1381675750017166, Accuracy: 95.9333267211914, Val Loss: 0.1428695172071457, Val Accuracy: 95.99998474121094
Saving new best model with loss:  0.14286952




Epoch 45, Loss: 0.14077039062976837, Accuracy: 95.90333557128906, Val Loss: 0.13679207861423492, Val Accuracy: 96.17334747314453
Saving new best model with loss:  0.13679208




Epoch 49, Loss: 0.12903937697410583, Accuracy: 96.43331909179688, Val Loss: 0.13245835900306702, Val Accuracy: 96.07999420166016
Saving new best model with loss:  0.13245836




Epoch 53, Loss: 0.13278600573539734, Accuracy: 96.12664794921875, Val Loss: 0.12212850898504257, Val Accuracy: 96.36666107177734
Saving new best model with loss:  0.12212851




Epoch 57, Loss: 0.11669649928808212, Accuracy: 96.58000946044922, Val Loss: 0.12099897116422653, Val Accuracy: 96.39334869384766
Saving new best model with loss:  0.12099897
Epoch 61, Loss: 0.10465839505195618, Accuracy: 96.72999572753906, Val Loss: 0.1215125247836113, Val Accuracy: 96.58667755126953




Epoch 65, Loss: 0.11495956778526306, Accuracy: 96.56665802001953, Val Loss: 0.11471110582351685, Val Accuracy: 96.72001647949219
Saving new best model with loss:  0.114711106




Epoch 69, Loss: 0.10859204083681107, Accuracy: 96.9466781616211, Val Loss: 0.11390692740678787, Val Accuracy: 96.75334167480469
Saving new best model with loss:  0.11390693




Epoch 73, Loss: 0.10960789024829865, Accuracy: 96.836669921875, Val Loss: 0.10245170444250107, Val Accuracy: 96.96332550048828
Saving new best model with loss:  0.102451704
Epoch 77, Loss: 0.10231249034404755, Accuracy: 96.9767074584961, Val Loss: 0.10789843648672104, Val Accuracy: 97.00000762939453
Epoch 81, Loss: 0.10448313504457474, Accuracy: 97.00001525878906, Val Loss: 0.10534705966711044, Val Accuracy: 96.86000061035156




Epoch 85, Loss: 0.0963558480143547, Accuracy: 97.16666412353516, Val Loss: 0.09785986691713333, Val Accuracy: 97.16333770751953
Saving new best model with loss:  0.09785987




Epoch 89, Loss: 0.09173659980297089, Accuracy: 97.36668395996094, Val Loss: 0.08927443623542786, Val Accuracy: 97.3933334350586
Saving new best model with loss:  0.089274436




Epoch 93, Loss: 0.08714421093463898, Accuracy: 97.47665405273438, Val Loss: 0.08654215931892395, Val Accuracy: 97.42667388916016
Saving new best model with loss:  0.08654216
Epoch 97, Loss: 0.09014296531677246, Accuracy: 97.35668182373047, Val Loss: 0.09032967686653137, Val Accuracy: 97.3166732788086
Epoch 101, Loss: 0.09032999724149704, Accuracy: 97.34000396728516, Val Loss: 0.09109177440404892, Val Accuracy: 97.27002716064453




Epoch 105, Loss: 0.07244934886693954, Accuracy: 97.81000518798828, Val Loss: 0.085663340985775, Val Accuracy: 97.5300064086914
Saving new best model with loss:  0.08566334




Epoch 109, Loss: 0.07231861352920532, Accuracy: 97.816650390625, Val Loss: 0.07462868839502335, Val Accuracy: 97.85668182373047
Saving new best model with loss:  0.07462869
Epoch 113, Loss: 0.08488763123750687, Accuracy: 97.39998626708984, Val Loss: 0.08205202221870422, Val Accuracy: 97.61998748779297
Epoch 117, Loss: 0.07680276781320572, Accuracy: 97.71333312988281, Val Loss: 0.08500378578901291, Val Accuracy: 97.41665649414062




Epoch 121, Loss: 0.07930541038513184, Accuracy: 97.77332305908203, Val Loss: 0.0705096423625946, Val Accuracy: 97.89002990722656
Saving new best model with loss:  0.07050964
Epoch 125, Loss: 0.06691918522119522, Accuracy: 97.9732894897461, Val Loss: 0.07690373063087463, Val Accuracy: 97.75666046142578
Epoch 129, Loss: 0.07669711858034134, Accuracy: 97.586669921875, Val Loss: 0.07658062130212784, Val Accuracy: 97.69666290283203
Epoch 133, Loss: 0.07282193005084991, Accuracy: 97.82664489746094, Val Loss: 0.07250259071588516, Val Accuracy: 97.72334289550781
Epoch 137, Loss: 0.07340239733457565, Accuracy: 97.74667358398438, Val Loss: 0.07449744641780853, Val Accuracy: 97.66666412353516
Epoch 141, Loss: 0.07206247001886368, Accuracy: 97.87334442138672, Val Loss: 0.07100342214107513, Val Accuracy: 97.87999725341797




Epoch 145, Loss: 0.07218322157859802, Accuracy: 97.91999816894531, Val Loss: 0.06482996046543121, Val Accuracy: 97.9666748046875
Saving new best model with loss:  0.06482996
Epoch 149, Loss: 0.06644316017627716, Accuracy: 97.90666198730469, Val Loss: 0.06819895654916763, Val Accuracy: 97.97669982910156




Epoch 153, Loss: 0.07064341753721237, Accuracy: 97.89998626708984, Val Loss: 0.06364583969116211, Val Accuracy: 97.93668365478516
Saving new best model with loss:  0.06364584
Epoch 157, Loss: 0.06660594791173935, Accuracy: 97.8266830444336, Val Loss: 0.06763284653425217, Val Accuracy: 97.90003204345703




Epoch 161, Loss: 0.06180991977453232, Accuracy: 98.05999755859375, Val Loss: 0.059995800256729126, Val Accuracy: 98.20999908447266
Saving new best model with loss:  0.0599958




Epoch 165, Loss: 0.05534041300415993, Accuracy: 98.32000732421875, Val Loss: 0.05855554714798927, Val Accuracy: 98.13998413085938
Saving new best model with loss:  0.058555547
Epoch 169, Loss: 0.05830049142241478, Accuracy: 98.12996673583984, Val Loss: 0.05952426418662071, Val Accuracy: 98.10997009277344




Epoch 173, Loss: 0.06476714462041855, Accuracy: 98.04664611816406, Val Loss: 0.05787741765379906, Val Accuracy: 98.11000061035156
Saving new best model with loss:  0.057877418




Epoch 177, Loss: 0.05526799336075783, Accuracy: 98.25332641601562, Val Loss: 0.0575527586042881, Val Accuracy: 98.23998260498047
Saving new best model with loss:  0.05755276
Epoch 181, Loss: 0.05736425891518593, Accuracy: 98.15664672851562, Val Loss: 0.06108621135354042, Val Accuracy: 98.10002136230469
Epoch 185, Loss: 0.0550750195980072, Accuracy: 98.20999145507812, Val Loss: 0.06354694813489914, Val Accuracy: 97.97332000732422




Epoch 189, Loss: 0.054399359971284866, Accuracy: 98.22333526611328, Val Loss: 0.05400128662586212, Val Accuracy: 98.35999298095703
Saving new best model with loss:  0.054001287




Epoch 193, Loss: 0.054537512362003326, Accuracy: 98.37665557861328, Val Loss: 0.0524386502802372, Val Accuracy: 98.36331176757812
Saving new best model with loss:  0.05243865
Epoch 197, Loss: 0.053780246526002884, Accuracy: 98.28997802734375, Val Loss: 0.05883496254682541, Val Accuracy: 98.07667541503906
Epoch 201, Loss: 0.05303064361214638, Accuracy: 98.3299560546875, Val Loss: 0.055257946252822876, Val Accuracy: 98.29666137695312




Epoch 205, Loss: 0.052976205945014954, Accuracy: 98.39665985107422, Val Loss: 0.04552629962563515, Val Accuracy: 98.51995849609375
Saving new best model with loss:  0.0455263
Epoch 209, Loss: 0.051373306661844254, Accuracy: 98.35332489013672, Val Loss: 0.05278395116329193, Val Accuracy: 98.2933349609375
Epoch 213, Loss: 0.05315212160348892, Accuracy: 98.23663330078125, Val Loss: 0.05116524547338486, Val Accuracy: 98.29663848876953
Epoch 217, Loss: 0.04985051229596138, Accuracy: 98.48330688476562, Val Loss: 0.04787299782037735, Val Accuracy: 98.36997985839844
Epoch 221, Loss: 0.048713792115449905, Accuracy: 98.30663299560547, Val Loss: 0.0480974055826664, Val Accuracy: 98.42664337158203
Epoch 225, Loss: 0.049977127462625504, Accuracy: 98.37332153320312, Val Loss: 0.05241459235548973, Val Accuracy: 98.35330200195312
Epoch 229, Loss: 0.050369992852211, Accuracy: 98.35665893554688, Val Loss: 0.046472545713186264, Val Accuracy: 98.47332763671875
Epoch 233, Loss: 0.04515768587589264, Accurac

In [24]:
#Combi Proto + Reptile for second values from the list ie 60-way 1-shot experiment

meta_step_size = 0.25

#Interval between running SGD on the validation dataset
eval_interval = 4

#number of classes
num_way = train_num_ways[1]

#number of examples per class for support set
num_shot = num_shots[1]  

#number of query points
num_query = num_shots[1] 


train_loss = tf.metrics.Mean(name='train_loss')
val_loss = tf.metrics.Mean(name='val_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')
val_acc = tf.metrics.Mean(name='val_accuracy')
support = np.zeros([num_way, num_shot, img_height, img_width, channels], dtype=np.float32)
query = np.zeros([num_way, num_shot, img_height, img_width, channels], dtype=np.float32)
model = Prototypical(support, query, img_height, img_width, channels)
optimizer_adam = tf.keras.optimizers.Adam(learning_rate, beta_1=0)
optimizer_sgd = tf.keras.optimizers.SGD(learning_rate)

num_epochs = 241
num_episodes = 100
save_path = "./results/models/omniglot_train_reptile1.h5"

@tf.function
def loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

@tf.function
def train_step(support, query, optimizer):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  
  train_loss(loss)
  train_acc(acc)
  

@tf.function
def val_step(support, query, optimizer):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  
  val_loss(loss)
  val_acc(acc)


least_loss = {'least_loss': 100.00}
for epoch in range(num_epochs):
  train_loss.reset_states()
  val_loss.reset_states()
  train_acc.reset_states()
  val_acc.reset_states()

  for episode in range(num_episodes):
    frac_done = episode / num_episodes
    cur_meta_step_size = (1 - frac_done) * meta_step_size
    train_support, train_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
    
    model = Prototypical(train_support, train_query, img_width, img_height, channels)
    model.call(train_support, train_query)
    old_weights = model.get_weights()
    train_step(train_support, train_query, optimizer_sgd)
    new_weights = model.get_weights()

    for part_weight in range(len(new_weights)):
        new_weights[part_weight] = old_weights[part_weight] + (
            (new_weights[part_weight] - old_weights[part_weight]) * cur_meta_step_size
        )

    model.set_weights(new_weights)
    if epoch % eval_interval == 0:
      eval_support, eval_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
      old_vars = model.get_weights()
      val_step(eval_support, eval_query, optimizer_adam)
      model.set_weights(old_vars)
      
  cur_loss = val_loss.result().numpy()

  if (epoch % eval_interval == 0):
    template = 'Epoch {}, Loss: {}, Accuracy: {}, ' \
                  'Val Loss: {}, Val Accuracy: {}'
    print(template.format(epoch + 1, train_loss.result(), train_acc.result() * 100, val_loss.result(),
                            val_acc.result() * 100))
    if cur_loss < least_loss['least_loss']:
      print("Saving new best model with loss: ", cur_loss)
      least_loss['least_loss'] = cur_loss
      model.save(save_path)



Epoch 1, Loss: 2.6033387184143066, Accuracy: 40.98332977294922, Val Loss: 2.576746940612793, Val Accuracy: 41.04999923706055
Saving new best model with loss:  2.576747




Epoch 5, Loss: 1.7560092210769653, Accuracy: 56.23332214355469, Val Loss: 1.7480398416519165, Val Accuracy: 55.94999313354492
Saving new best model with loss:  1.7480398




Epoch 9, Loss: 1.218807339668274, Accuracy: 68.01668548583984, Val Loss: 1.1998348236083984, Val Accuracy: 68.35002136230469
Saving new best model with loss:  1.1998348




Epoch 13, Loss: 0.8729893565177917, Accuracy: 75.8833236694336, Val Loss: 0.9005581140518188, Val Accuracy: 75.55000305175781
Saving new best model with loss:  0.9005581




Epoch 17, Loss: 0.7501782774925232, Accuracy: 79.44998931884766, Val Loss: 0.7413337826728821, Val Accuracy: 79.44998931884766
Saving new best model with loss:  0.7413338




Epoch 21, Loss: 0.6730606555938721, Accuracy: 81.41667175292969, Val Loss: 0.6669648885726929, Val Accuracy: 81.41665649414062
Saving new best model with loss:  0.6669649




Epoch 25, Loss: 0.6462421417236328, Accuracy: 81.28330993652344, Val Loss: 0.6190152168273926, Val Accuracy: 82.61666107177734
Saving new best model with loss:  0.6190152




Epoch 29, Loss: 0.5937665700912476, Accuracy: 83.73331451416016, Val Loss: 0.6171714663505554, Val Accuracy: 82.84998321533203
Saving new best model with loss:  0.61717147




Epoch 33, Loss: 0.5191134214401245, Accuracy: 84.6166763305664, Val Loss: 0.567471981048584, Val Accuracy: 84.23333740234375
Saving new best model with loss:  0.567472
Epoch 37, Loss: 0.5327808260917664, Accuracy: 84.91667175292969, Val Loss: 0.569898247718811, Val Accuracy: 84.14999389648438




Epoch 41, Loss: 0.4995056390762329, Accuracy: 85.54998779296875, Val Loss: 0.4828377068042755, Val Accuracy: 86.03335571289062
Saving new best model with loss:  0.4828377
Epoch 45, Loss: 0.4843531548976898, Accuracy: 85.91665649414062, Val Loss: 0.48674947023391724, Val Accuracy: 86.36669921875




Epoch 49, Loss: 0.4435310661792755, Accuracy: 86.8499984741211, Val Loss: 0.4413152039051056, Val Accuracy: 86.58334350585938
Saving new best model with loss:  0.4413152




Epoch 53, Loss: 0.44662415981292725, Accuracy: 87.4999771118164, Val Loss: 0.4246196448802948, Val Accuracy: 87.80000305175781
Saving new best model with loss:  0.42461964
Epoch 57, Loss: 0.43235546350479126, Accuracy: 87.33333587646484, Val Loss: 0.46410396695137024, Val Accuracy: 86.61666870117188




Epoch 61, Loss: 0.395089328289032, Accuracy: 88.56666564941406, Val Loss: 0.4057287275791168, Val Accuracy: 88.23332977294922
Saving new best model with loss:  0.40572873
Epoch 65, Loss: 0.4130925238132477, Accuracy: 88.1833267211914, Val Loss: 0.4181597828865051, Val Accuracy: 88.28334045410156




Epoch 69, Loss: 0.4024318754673004, Accuracy: 88.71665954589844, Val Loss: 0.40200015902519226, Val Accuracy: 88.43331909179688
Saving new best model with loss:  0.40200016




Epoch 73, Loss: 0.3813064992427826, Accuracy: 88.69998168945312, Val Loss: 0.37829458713531494, Val Accuracy: 89.0500259399414
Saving new best model with loss:  0.3782946




Epoch 77, Loss: 0.3885118067264557, Accuracy: 89.29998016357422, Val Loss: 0.35086002945899963, Val Accuracy: 89.5
Saving new best model with loss:  0.35086003
Epoch 81, Loss: 0.38140255212783813, Accuracy: 89.48332977294922, Val Loss: 0.3570939004421234, Val Accuracy: 89.25001525878906




Epoch 85, Loss: 0.3444972634315491, Accuracy: 90.06666564941406, Val Loss: 0.34407129883766174, Val Accuracy: 89.76666259765625
Saving new best model with loss:  0.3440713
Epoch 89, Loss: 0.34334155917167664, Accuracy: 90.06669616699219, Val Loss: 0.3639553487300873, Val Accuracy: 89.433349609375




Epoch 93, Loss: 0.34783488512039185, Accuracy: 89.95000457763672, Val Loss: 0.32922694087028503, Val Accuracy: 89.76666259765625
Saving new best model with loss:  0.32922694
Epoch 97, Loss: 0.3368595242500305, Accuracy: 90.61666107177734, Val Loss: 0.3538930416107178, Val Accuracy: 89.7833251953125
Epoch 101, Loss: 0.3091639280319214, Accuracy: 91.0, Val Loss: 0.3430728018283844, Val Accuracy: 90.08334350585938




Epoch 105, Loss: 0.33324486017227173, Accuracy: 90.4000244140625, Val Loss: 0.32130300998687744, Val Accuracy: 91.35002136230469
Saving new best model with loss:  0.321303
Epoch 109, Loss: 0.31750863790512085, Accuracy: 90.51668548583984, Val Loss: 0.32648587226867676, Val Accuracy: 90.75001525878906




Epoch 113, Loss: 0.3268447518348694, Accuracy: 90.53335571289062, Val Loss: 0.29836785793304443, Val Accuracy: 91.36666870117188
Saving new best model with loss:  0.29836786
Epoch 117, Loss: 0.2909415364265442, Accuracy: 91.98332214355469, Val Loss: 0.3043321669101715, Val Accuracy: 91.16666412353516




Epoch 121, Loss: 0.30339571833610535, Accuracy: 91.23333740234375, Val Loss: 0.29539692401885986, Val Accuracy: 91.51667785644531
Saving new best model with loss:  0.29539692




Epoch 125, Loss: 0.295055627822876, Accuracy: 91.36666870117188, Val Loss: 0.28974902629852295, Val Accuracy: 91.05001068115234
Saving new best model with loss:  0.28974903




Epoch 129, Loss: 0.2984561622142792, Accuracy: 91.14999389648438, Val Loss: 0.287800133228302, Val Accuracy: 91.55001831054688
Saving new best model with loss:  0.28780013




Epoch 133, Loss: 0.2784733474254608, Accuracy: 91.88333892822266, Val Loss: 0.2821562886238098, Val Accuracy: 92.25001525878906
Saving new best model with loss:  0.2821563




Epoch 137, Loss: 0.2901659309864044, Accuracy: 91.34998321533203, Val Loss: 0.2797023057937622, Val Accuracy: 91.98331451416016
Saving new best model with loss:  0.2797023




Epoch 141, Loss: 0.2647792100906372, Accuracy: 92.25, Val Loss: 0.2779799699783325, Val Accuracy: 92.26665496826172
Saving new best model with loss:  0.27797997




Epoch 145, Loss: 0.2638787031173706, Accuracy: 92.2166519165039, Val Loss: 0.27233457565307617, Val Accuracy: 92.23335266113281
Saving new best model with loss:  0.27233458
Epoch 149, Loss: 0.2719774544239044, Accuracy: 92.25, Val Loss: 0.2743369936943054, Val Accuracy: 91.91666412353516
Epoch 153, Loss: 0.26623570919036865, Accuracy: 92.5333251953125, Val Loss: 0.27610480785369873, Val Accuracy: 91.46666717529297




Epoch 157, Loss: 0.26589515805244446, Accuracy: 92.4333267211914, Val Loss: 0.2687501013278961, Val Accuracy: 92.20001220703125
Saving new best model with loss:  0.2687501




Epoch 161, Loss: 0.24144817888736725, Accuracy: 92.64998626708984, Val Loss: 0.25260043144226074, Val Accuracy: 93.11665344238281
Saving new best model with loss:  0.25260043
Epoch 165, Loss: 0.2553168535232544, Accuracy: 92.49998474121094, Val Loss: 0.2614561915397644, Val Accuracy: 92.36666107177734




Epoch 169, Loss: 0.2714557945728302, Accuracy: 92.09999084472656, Val Loss: 0.24429307878017426, Val Accuracy: 92.98333740234375
Saving new best model with loss:  0.24429308




Epoch 173, Loss: 0.2563086748123169, Accuracy: 93.04999542236328, Val Loss: 0.23052774369716644, Val Accuracy: 93.23332977294922
Saving new best model with loss:  0.23052774
Epoch 177, Loss: 0.24664954841136932, Accuracy: 93.04998016357422, Val Loss: 0.2368330955505371, Val Accuracy: 92.8499984741211
Epoch 181, Loss: 0.22733066976070404, Accuracy: 92.91666412353516, Val Loss: 0.24814777076244354, Val Accuracy: 92.48332214355469
Epoch 185, Loss: 0.2377336621284485, Accuracy: 92.39999389648438, Val Loss: 0.2308473438024521, Val Accuracy: 93.1333236694336
Epoch 189, Loss: 0.24759028851985931, Accuracy: 93.1500015258789, Val Loss: 0.23694837093353271, Val Accuracy: 93.01665496826172
Epoch 193, Loss: 0.22300229966640472, Accuracy: 93.45001220703125, Val Loss: 0.24687905609607697, Val Accuracy: 92.6166763305664




Epoch 197, Loss: 0.2479996234178543, Accuracy: 93.48332977294922, Val Loss: 0.21971751749515533, Val Accuracy: 93.7166519165039
Saving new best model with loss:  0.21971752
Epoch 201, Loss: 0.23086856305599213, Accuracy: 93.25000762939453, Val Loss: 0.2370212972164154, Val Accuracy: 93.24998474121094
Epoch 205, Loss: 0.22050286829471588, Accuracy: 93.58332061767578, Val Loss: 0.22989103198051453, Val Accuracy: 93.43331146240234




Epoch 209, Loss: 0.22516338527202606, Accuracy: 92.98333740234375, Val Loss: 0.2149752825498581, Val Accuracy: 93.36665344238281
Saving new best model with loss:  0.21497528




Epoch 213, Loss: 0.20225901901721954, Accuracy: 93.86664581298828, Val Loss: 0.21099480986595154, Val Accuracy: 93.80001068115234
Saving new best model with loss:  0.21099481
Epoch 217, Loss: 0.24796682596206665, Accuracy: 92.93333435058594, Val Loss: 0.23946760594844818, Val Accuracy: 92.98332214355469




Epoch 221, Loss: 0.22618892788887024, Accuracy: 93.33332061767578, Val Loss: 0.20968016982078552, Val Accuracy: 93.78334045410156
Saving new best model with loss:  0.20968017
Epoch 225, Loss: 0.20451462268829346, Accuracy: 93.4666519165039, Val Loss: 0.2385098785161972, Val Accuracy: 93.05000305175781




Epoch 229, Loss: 0.1802801936864853, Accuracy: 94.3499984741211, Val Loss: 0.2032293826341629, Val Accuracy: 94.04998016357422
Saving new best model with loss:  0.20322938
Epoch 233, Loss: 0.21104231476783752, Accuracy: 93.81666564941406, Val Loss: 0.2121756225824356, Val Accuracy: 93.78331756591797
Epoch 237, Loss: 0.22431673109531403, Accuracy: 93.73332977294922, Val Loss: 0.21172069013118744, Val Accuracy: 93.55001068115234




Epoch 241, Loss: 0.2029552310705185, Accuracy: 94.0333251953125, Val Loss: 0.20284609496593475, Val Accuracy: 93.88333129882812
Saving new best model with loss:  0.2028461


In [25]:
#Combi Proto + Reptile for first values from the list ie 40-way 5-shot experiment

meta_step_size = 0.25

#Interval between running SGD on the validation dataset
eval_interval = 4

#number of classes
num_way = train_num_ways[2]

#number of examples per class for support set
num_shot = num_shots[2]  

#number of query points
num_query = num_shots[2] 


train_loss = tf.metrics.Mean(name='train_loss')
val_loss = tf.metrics.Mean(name='val_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')
val_acc = tf.metrics.Mean(name='val_accuracy')
support = np.zeros([num_way, num_shot, img_height, img_width, channels], dtype=np.float32)
query = np.zeros([num_way, num_shot, img_height, img_width, channels], dtype=np.float32)
model = Prototypical(support, query, img_height, img_width, channels)
optimizer_adam = tf.keras.optimizers.Adam(learning_rate, beta_1=0)
optimizer_sgd = tf.keras.optimizers.SGD(learning_rate)

num_epochs = 241
num_episodes = 100
save_path = "./results/models/omniglot_train_reptile2.h5"

@tf.function
def loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

@tf.function
def train_step(support, query, optimizer):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  
  train_loss(loss)
  train_acc(acc)
  

@tf.function
def val_step(support, query, optimizer):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  
  val_loss(loss)
  val_acc(acc)


least_loss = {'least_loss': 100.00}
for epoch in range(num_epochs):
  train_loss.reset_states()
  val_loss.reset_states()
  train_acc.reset_states()
  val_acc.reset_states()

  for episode in range(num_episodes):
    frac_done = episode / num_episodes
    cur_meta_step_size = (1 - frac_done) * meta_step_size
    train_support, train_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
    
    model = Prototypical(train_support, train_query, img_width, img_height, channels)
    model.call(train_support, train_query)
    old_weights = model.get_weights()
    train_step(train_support, train_query, optimizer_sgd)
    new_weights = model.get_weights()

    for part_weight in range(len(new_weights)):
        new_weights[part_weight] = old_weights[part_weight] + (
            (new_weights[part_weight] - old_weights[part_weight]) * cur_meta_step_size
        )

    model.set_weights(new_weights)
    if epoch % eval_interval == 0:
      eval_support, eval_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
      old_vars = model.get_weights()
      val_step(eval_support, eval_query, optimizer_adam)
      model.set_weights(old_vars)
      
  cur_loss = val_loss.result().numpy()

  if (epoch % eval_interval == 0):
    template = 'Epoch {}, Loss: {}, Accuracy: {}, ' \
                  'Val Loss: {}, Val Accuracy: {}'
    print(template.format(epoch + 1, train_loss.result(), train_acc.result() * 100, val_loss.result(),
                            val_acc.result() * 100))
    if cur_loss < least_loss['least_loss']:
      print("Saving new best model with loss: ", cur_loss)
      least_loss['least_loss'] = cur_loss
      model.save(save_path)



Epoch 1, Loss: 1.2778325080871582, Accuracy: 70.7650146484375, Val Loss: 1.2833958864212036, Val Accuracy: 70.57499694824219
Saving new best model with loss:  1.2833959




Epoch 5, Loss: 0.531635046005249, Accuracy: 85.2649917602539, Val Loss: 0.5366384387016296, Val Accuracy: 85.49000549316406
Saving new best model with loss:  0.53663844




Epoch 9, Loss: 0.3411201536655426, Accuracy: 90.114990234375, Val Loss: 0.3406173288822174, Val Accuracy: 90.59001922607422
Saving new best model with loss:  0.34061733




Epoch 13, Loss: 0.24660103023052216, Accuracy: 92.8499755859375, Val Loss: 0.2399563193321228, Val Accuracy: 93.18499755859375
Saving new best model with loss:  0.23995632




Epoch 17, Loss: 0.19253690540790558, Accuracy: 94.36499786376953, Val Loss: 0.18862922489643097, Val Accuracy: 94.62498474121094
Saving new best model with loss:  0.18862922




Epoch 21, Loss: 0.16575631499290466, Accuracy: 95.18496704101562, Val Loss: 0.16646651923656464, Val Accuracy: 95.26000213623047
Saving new best model with loss:  0.16646652




Epoch 25, Loss: 0.14937224984169006, Accuracy: 95.65497589111328, Val Loss: 0.162489116191864, Val Accuracy: 95.49998474121094
Saving new best model with loss:  0.16248912




Epoch 29, Loss: 0.1354098618030548, Accuracy: 96.05501556396484, Val Loss: 0.15481416881084442, Val Accuracy: 95.65499877929688
Saving new best model with loss:  0.15481417




Epoch 33, Loss: 0.13597017526626587, Accuracy: 95.89000701904297, Val Loss: 0.13422712683677673, Val Accuracy: 95.97000122070312
Saving new best model with loss:  0.13422713




Epoch 37, Loss: 0.12829512357711792, Accuracy: 96.1399917602539, Val Loss: 0.12371604144573212, Val Accuracy: 96.3899917602539
Saving new best model with loss:  0.12371604




Epoch 41, Loss: 0.11664925515651703, Accuracy: 96.6199951171875, Val Loss: 0.11615307629108429, Val Accuracy: 96.78997802734375
Saving new best model with loss:  0.11615308




Epoch 45, Loss: 0.1089835911989212, Accuracy: 96.7750015258789, Val Loss: 0.11496764421463013, Val Accuracy: 96.75000762939453
Saving new best model with loss:  0.114967644




Epoch 49, Loss: 0.10411446541547775, Accuracy: 96.87002563476562, Val Loss: 0.10971865057945251, Val Accuracy: 96.80998992919922
Saving new best model with loss:  0.10971865




Epoch 53, Loss: 0.10446824878454208, Accuracy: 97.11498260498047, Val Loss: 0.10837903618812561, Val Accuracy: 96.99000549316406
Saving new best model with loss:  0.108379036




Epoch 57, Loss: 0.1053667664527893, Accuracy: 96.9649887084961, Val Loss: 0.1079777255654335, Val Accuracy: 96.87500762939453
Saving new best model with loss:  0.107977726




Epoch 61, Loss: 0.09492173790931702, Accuracy: 97.37001037597656, Val Loss: 0.09298408776521683, Val Accuracy: 97.22500610351562
Saving new best model with loss:  0.09298409
Epoch 65, Loss: 0.08910074830055237, Accuracy: 97.41000366210938, Val Loss: 0.10031209141016006, Val Accuracy: 97.17500305175781




Epoch 69, Loss: 0.09257641434669495, Accuracy: 97.31503295898438, Val Loss: 0.09135586768388748, Val Accuracy: 97.31999969482422
Saving new best model with loss:  0.09135587




Epoch 73, Loss: 0.08452310413122177, Accuracy: 97.5400161743164, Val Loss: 0.08601638674736023, Val Accuracy: 97.45498657226562
Saving new best model with loss:  0.08601639
Epoch 77, Loss: 0.0793399065732956, Accuracy: 97.68501281738281, Val Loss: 0.08982040733098984, Val Accuracy: 97.22002410888672
Epoch 81, Loss: 0.08369364589452744, Accuracy: 97.61000061035156, Val Loss: 0.09161774814128876, Val Accuracy: 97.27999877929688
Epoch 85, Loss: 0.08576440811157227, Accuracy: 97.64498901367188, Val Loss: 0.08863837271928787, Val Accuracy: 97.39500427246094




Epoch 89, Loss: 0.0896897092461586, Accuracy: 97.44500732421875, Val Loss: 0.07857168465852737, Val Accuracy: 97.68999481201172
Saving new best model with loss:  0.078571685




Epoch 93, Loss: 0.0873657613992691, Accuracy: 97.51000213623047, Val Loss: 0.07182144373655319, Val Accuracy: 97.67000579833984
Saving new best model with loss:  0.071821444
Epoch 97, Loss: 0.07409047335386276, Accuracy: 97.91000366210938, Val Loss: 0.07194408029317856, Val Accuracy: 98.07003021240234
Epoch 101, Loss: 0.07202820479869843, Accuracy: 97.87001037597656, Val Loss: 0.078948974609375, Val Accuracy: 97.7600326538086
Epoch 105, Loss: 0.07309170067310333, Accuracy: 97.81001281738281, Val Loss: 0.07566153258085251, Val Accuracy: 97.905029296875
Epoch 109, Loss: 0.0657123327255249, Accuracy: 98.0199966430664, Val Loss: 0.07335759699344635, Val Accuracy: 97.90999603271484
Epoch 113, Loss: 0.07673897594213486, Accuracy: 97.80001068115234, Val Loss: 0.07334624975919724, Val Accuracy: 97.80500793457031
Epoch 117, Loss: 0.06711799651384354, Accuracy: 97.96500396728516, Val Loss: 0.0720212310552597, Val Accuracy: 97.83000183105469




Epoch 121, Loss: 0.07333584129810333, Accuracy: 97.76500701904297, Val Loss: 0.06762539595365524, Val Accuracy: 97.7950210571289
Saving new best model with loss:  0.067625396
Epoch 125, Loss: 0.06426732987165451, Accuracy: 98.10503387451172, Val Loss: 0.07328392565250397, Val Accuracy: 97.80001068115234




Epoch 129, Loss: 0.06347328424453735, Accuracy: 98.1300277709961, Val Loss: 0.06303703784942627, Val Accuracy: 98.21002197265625
Saving new best model with loss:  0.06303704
Epoch 133, Loss: 0.0629470944404602, Accuracy: 97.99999237060547, Val Loss: 0.07302965968847275, Val Accuracy: 97.63003540039062
Epoch 137, Loss: 0.07035204023122787, Accuracy: 98.06000518798828, Val Loss: 0.0642375573515892, Val Accuracy: 98.16500854492188
Epoch 141, Loss: 0.060604143887758255, Accuracy: 98.16503143310547, Val Loss: 0.06728161871433258, Val Accuracy: 98.11502075195312
Epoch 145, Loss: 0.05995490401983261, Accuracy: 98.30000305175781, Val Loss: 0.06557169556617737, Val Accuracy: 98.12498474121094




Epoch 149, Loss: 0.05880499258637428, Accuracy: 98.15503692626953, Val Loss: 0.06254205107688904, Val Accuracy: 98.17501831054688
Saving new best model with loss:  0.06254205




Epoch 153, Loss: 0.05897952616214752, Accuracy: 98.25001525878906, Val Loss: 0.0623171441257, Val Accuracy: 98.0350112915039
Saving new best model with loss:  0.062317144




Epoch 157, Loss: 0.06327632814645767, Accuracy: 97.98003387451172, Val Loss: 0.051997341215610504, Val Accuracy: 98.40001678466797
Saving new best model with loss:  0.05199734
Epoch 161, Loss: 0.05255850404500961, Accuracy: 98.2900161743164, Val Loss: 0.059324730187654495, Val Accuracy: 98.26001739501953
Epoch 165, Loss: 0.06512188166379929, Accuracy: 98.08502960205078, Val Loss: 0.05904696509242058, Val Accuracy: 98.09001922607422
Epoch 169, Loss: 0.05091585963964462, Accuracy: 98.32002258300781, Val Loss: 0.056442905217409134, Val Accuracy: 98.32501220703125




Epoch 173, Loss: 0.05552484840154648, Accuracy: 98.34504699707031, Val Loss: 0.046321965754032135, Val Accuracy: 98.59003448486328
Saving new best model with loss:  0.046321966
Epoch 177, Loss: 0.05484458804130554, Accuracy: 98.32502746582031, Val Loss: 0.05130770802497864, Val Accuracy: 98.33003997802734
Epoch 181, Loss: 0.05601535737514496, Accuracy: 98.27001953125, Val Loss: 0.05257531255483627, Val Accuracy: 98.25501251220703
Epoch 185, Loss: 0.05296824499964714, Accuracy: 98.3900375366211, Val Loss: 0.05188318341970444, Val Accuracy: 98.42002868652344
Epoch 189, Loss: 0.048764124512672424, Accuracy: 98.43000793457031, Val Loss: 0.04891091212630272, Val Accuracy: 98.54502868652344
Epoch 193, Loss: 0.05324413254857063, Accuracy: 98.46501159667969, Val Loss: 0.0491485558450222, Val Accuracy: 98.5250015258789




Epoch 197, Loss: 0.05165988579392433, Accuracy: 98.42501068115234, Val Loss: 0.04530874267220497, Val Accuracy: 98.48002624511719
Saving new best model with loss:  0.045308743
Epoch 201, Loss: 0.04475121200084686, Accuracy: 98.55001831054688, Val Loss: 0.04991777241230011, Val Accuracy: 98.3900146484375




Epoch 205, Loss: 0.05244136229157448, Accuracy: 98.34001922607422, Val Loss: 0.04506022483110428, Val Accuracy: 98.54503631591797
Saving new best model with loss:  0.045060225
Epoch 209, Loss: 0.053685903549194336, Accuracy: 98.44002532958984, Val Loss: 0.04774003475904465, Val Accuracy: 98.53001403808594
Epoch 213, Loss: 0.04466458410024643, Accuracy: 98.59001922607422, Val Loss: 0.051700323820114136, Val Accuracy: 98.6250228881836
Epoch 217, Loss: 0.04466061666607857, Accuracy: 98.56502532958984, Val Loss: 0.046057119965553284, Val Accuracy: 98.54002380371094
Epoch 221, Loss: 0.04654303193092346, Accuracy: 98.51502227783203, Val Loss: 0.04824216291308403, Val Accuracy: 98.49504089355469




Epoch 225, Loss: 0.04849007725715637, Accuracy: 98.60501861572266, Val Loss: 0.04121449589729309, Val Accuracy: 98.7650375366211
Saving new best model with loss:  0.041214496




Epoch 229, Loss: 0.045249711722135544, Accuracy: 98.55001831054688, Val Loss: 0.039442431181669235, Val Accuracy: 98.68502044677734
Saving new best model with loss:  0.03944243
Epoch 233, Loss: 0.040863752365112305, Accuracy: 98.63504028320312, Val Loss: 0.04529554396867752, Val Accuracy: 98.59504699707031
Epoch 237, Loss: 0.047326575964689255, Accuracy: 98.42503356933594, Val Loss: 0.04339262843132019, Val Accuracy: 98.6600341796875
Epoch 241, Loss: 0.04493933171033859, Accuracy: 98.44500732421875, Val Loss: 0.04194182530045509, Val Accuracy: 98.68501281738281


In [26]:
#Combi Proto + Reptile for first values from the list ie 40-way 1-shot experiment

meta_step_size = 0.25

#Interval between running SGD on the validation dataset
eval_interval = 4

#number of classes
num_way = train_num_ways[3]

#number of examples per class for support set
num_shot = num_shots[3]  

#number of query points
num_query = num_shots[3] 


train_loss = tf.metrics.Mean(name='train_loss')
val_loss = tf.metrics.Mean(name='val_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')
val_acc = tf.metrics.Mean(name='val_accuracy')
support = np.zeros([num_way, num_shot, img_height, img_width, channels], dtype=np.float32)
query = np.zeros([num_way, num_shot, img_height, img_width, channels], dtype=np.float32)
model = Prototypical(support, query, img_height, img_width, channels)
optimizer_adam = tf.keras.optimizers.Adam(learning_rate, beta_1=0)
optimizer_sgd = tf.keras.optimizers.SGD(learning_rate)

num_epochs = 241
num_episodes = 100
save_path = "./results/models/omniglot_train_reptile3.h5"

@tf.function
def loss(support, query):
  loss, acc = model(support, query)
  return loss, acc

@tf.function
def train_step(support, query, optimizer):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  
  train_loss(loss)
  train_acc(acc)
  

@tf.function
def val_step(support, query, optimizer):
  with tf.GradientTape() as tape:
    loss, acc = model(support, query)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))
  
  val_loss(loss)
  val_acc(acc)


least_loss = {'least_loss': 100.00}
for epoch in range(num_epochs):
  train_loss.reset_states()
  val_loss.reset_states()
  train_acc.reset_states()
  val_acc.reset_states()

  for episode in range(num_episodes):
    frac_done = episode / num_episodes
    cur_meta_step_size = (1 - frac_done) * meta_step_size
    train_support, train_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
    
    model = Prototypical(train_support, train_query, img_width, img_height, channels)
    model.call(train_support, train_query)
    old_weights = model.get_weights()
    train_step(train_support, train_query, optimizer_sgd)
    new_weights = model.get_weights()

    for part_weight in range(len(new_weights)):
        new_weights[part_weight] = old_weights[part_weight] + (
            (new_weights[part_weight] - old_weights[part_weight]) * cur_meta_step_size
        )

    model.set_weights(new_weights)
    if epoch % eval_interval == 0:
      eval_support, eval_query = get_next_episode(train_dataset, num_way, num_shot, num_query, no_train_classes)
      old_vars = model.get_weights()
      val_step(eval_support, eval_query, optimizer_adam)
      model.set_weights(old_vars)
      
  cur_loss = val_loss.result().numpy()

  if (epoch % eval_interval == 0):
    template = 'Epoch {}, Loss: {}, Accuracy: {}, ' \
                  'Val Loss: {}, Val Accuracy: {}'
    print(template.format(epoch + 1, train_loss.result(), train_acc.result() * 100, val_loss.result(),
                            val_acc.result() * 100))
    if cur_loss < least_loss['least_loss']:
      print("Saving new best model with loss: ", cur_loss)
      least_loss['least_loss'] = cur_loss
      model.save(save_path)



Epoch 1, Loss: 2.2832324504852295, Accuracy: 45.125, Val Loss: 2.2794322967529297, Val Accuracy: 46.64999771118164
Saving new best model with loss:  2.2794323




Epoch 5, Loss: 1.5152398347854614, Accuracy: 61.07499694824219, Val Loss: 1.4932222366333008, Val Accuracy: 62.350006103515625
Saving new best model with loss:  1.4932222




Epoch 9, Loss: 1.1236075162887573, Accuracy: 70.52499389648438, Val Loss: 1.1055552959442139, Val Accuracy: 69.07499694824219
Saving new best model with loss:  1.1055553




Epoch 13, Loss: 0.7904539704322815, Accuracy: 78.07499694824219, Val Loss: 0.7906769514083862, Val Accuracy: 77.92498779296875
Saving new best model with loss:  0.79067695




Epoch 17, Loss: 0.7175846695899963, Accuracy: 79.79999542236328, Val Loss: 0.6741983294487, Val Accuracy: 80.9000015258789
Saving new best model with loss:  0.6741983




Epoch 21, Loss: 0.596682608127594, Accuracy: 83.30000305175781, Val Loss: 0.5750295519828796, Val Accuracy: 84.04998779296875
Saving new best model with loss:  0.57502955
Epoch 25, Loss: 0.5701245665550232, Accuracy: 83.4999771118164, Val Loss: 0.5775517821311951, Val Accuracy: 84.4749984741211




Epoch 29, Loss: 0.5589112639427185, Accuracy: 84.8750228881836, Val Loss: 0.509955883026123, Val Accuracy: 85.37499237060547
Saving new best model with loss:  0.5099559




Epoch 33, Loss: 0.5344743132591248, Accuracy: 84.47498321533203, Val Loss: 0.4585547149181366, Val Accuracy: 86.54998779296875
Saving new best model with loss:  0.45855471
Epoch 37, Loss: 0.45885294675827026, Accuracy: 86.29998779296875, Val Loss: 0.47332248091697693, Val Accuracy: 86.39998626708984




Epoch 41, Loss: 0.4515961706638336, Accuracy: 86.77500915527344, Val Loss: 0.44246596097946167, Val Accuracy: 87.20001220703125
Saving new best model with loss:  0.44246596




Epoch 45, Loss: 0.473781555891037, Accuracy: 86.32500457763672, Val Loss: 0.42289459705352783, Val Accuracy: 86.64999389648438
Saving new best model with loss:  0.4228946




Epoch 49, Loss: 0.4183744490146637, Accuracy: 87.95000457763672, Val Loss: 0.4212239384651184, Val Accuracy: 88.77501678466797
Saving new best model with loss:  0.42122394




Epoch 53, Loss: 0.39514362812042236, Accuracy: 88.02500915527344, Val Loss: 0.39987286925315857, Val Accuracy: 87.95000457763672
Saving new best model with loss:  0.39987287
Epoch 57, Loss: 0.39457637071609497, Accuracy: 88.69999694824219, Val Loss: 0.4061111807823181, Val Accuracy: 87.82498931884766




Epoch 61, Loss: 0.42657870054244995, Accuracy: 87.79999542236328, Val Loss: 0.38755154609680176, Val Accuracy: 88.54999542236328
Saving new best model with loss:  0.38755155
Epoch 65, Loss: 0.36409518122673035, Accuracy: 89.57503509521484, Val Loss: 0.39198046922683716, Val Accuracy: 88.27499389648438




Epoch 69, Loss: 0.38136303424835205, Accuracy: 89.12499237060547, Val Loss: 0.36226391792297363, Val Accuracy: 89.70001220703125
Saving new best model with loss:  0.36226392
Epoch 73, Loss: 0.37955912947654724, Accuracy: 88.72501373291016, Val Loss: 0.38923558592796326, Val Accuracy: 88.97499084472656




Epoch 77, Loss: 0.3851100504398346, Accuracy: 89.64998626708984, Val Loss: 0.3371191918849945, Val Accuracy: 90.14998626708984
Saving new best model with loss:  0.3371192




Epoch 81, Loss: 0.3253534436225891, Accuracy: 89.55000305175781, Val Loss: 0.3144795894622803, Val Accuracy: 90.70003509521484
Saving new best model with loss:  0.3144796
Epoch 85, Loss: 0.3460928797721863, Accuracy: 90.4000015258789, Val Loss: 0.33219459652900696, Val Accuracy: 90.37499237060547




Epoch 89, Loss: 0.32944121956825256, Accuracy: 91.20001983642578, Val Loss: 0.31071996688842773, Val Accuracy: 91.32498931884766
Saving new best model with loss:  0.31071997
Epoch 93, Loss: 0.33780524134635925, Accuracy: 89.85001373291016, Val Loss: 0.32075339555740356, Val Accuracy: 90.7249984741211




Epoch 97, Loss: 0.3245494067668915, Accuracy: 91.15001678466797, Val Loss: 0.2985217571258545, Val Accuracy: 91.74998474121094
Saving new best model with loss:  0.29852176




Epoch 101, Loss: 0.28559476137161255, Accuracy: 91.7750015258789, Val Loss: 0.2969822585582733, Val Accuracy: 90.92498779296875
Saving new best model with loss:  0.29698226




Epoch 105, Loss: 0.30014467239379883, Accuracy: 91.55000305175781, Val Loss: 0.2726583182811737, Val Accuracy: 92.24998474121094
Saving new best model with loss:  0.27265832
Epoch 109, Loss: 0.2911600172519684, Accuracy: 91.57501220703125, Val Loss: 0.30165091156959534, Val Accuracy: 91.02501678466797
Epoch 113, Loss: 0.2899489104747772, Accuracy: 91.625, Val Loss: 0.2847962975502014, Val Accuracy: 91.99998474121094




Epoch 117, Loss: 0.27020856738090515, Accuracy: 91.6500015258789, Val Loss: 0.2704337239265442, Val Accuracy: 91.80001831054688
Saving new best model with loss:  0.27043372
Epoch 121, Loss: 0.24964532256126404, Accuracy: 92.29999542236328, Val Loss: 0.28163671493530273, Val Accuracy: 92.02499389648438
Epoch 125, Loss: 0.2722587287425995, Accuracy: 92.14998626708984, Val Loss: 0.270659863948822, Val Accuracy: 91.92498779296875
Epoch 129, Loss: 0.2793567180633545, Accuracy: 91.8499984741211, Val Loss: 0.2874007821083069, Val Accuracy: 91.5250015258789
Epoch 133, Loss: 0.2729196846485138, Accuracy: 91.95001983642578, Val Loss: 0.28019216656684875, Val Accuracy: 91.3499984741211




Epoch 137, Loss: 0.2395811527967453, Accuracy: 92.875, Val Loss: 0.2593238055706024, Val Accuracy: 92.77499389648438
Saving new best model with loss:  0.2593238




Epoch 141, Loss: 0.2539675235748291, Accuracy: 92.62500762939453, Val Loss: 0.2395191192626953, Val Accuracy: 92.64998626708984
Saving new best model with loss:  0.23951912




Epoch 145, Loss: 0.2812284529209137, Accuracy: 91.92499542236328, Val Loss: 0.23839332163333893, Val Accuracy: 92.57501983642578
Saving new best model with loss:  0.23839332




Epoch 149, Loss: 0.2630496919155121, Accuracy: 92.29998779296875, Val Loss: 0.22634939849376678, Val Accuracy: 93.2249984741211
Saving new best model with loss:  0.2263494
Epoch 153, Loss: 0.2675328850746155, Accuracy: 92.42500305175781, Val Loss: 0.2674693763256073, Val Accuracy: 92.02501678466797
Epoch 157, Loss: 0.251404345035553, Accuracy: 92.82499694824219, Val Loss: 0.2337638884782791, Val Accuracy: 93.12498474121094




Epoch 161, Loss: 0.25193464756011963, Accuracy: 92.79998779296875, Val Loss: 0.2179429680109024, Val Accuracy: 93.27497863769531
Saving new best model with loss:  0.21794297




Epoch 165, Loss: 0.22168253362178802, Accuracy: 93.45000457763672, Val Loss: 0.20432928204536438, Val Accuracy: 93.60000610351562
Saving new best model with loss:  0.20432928
Epoch 169, Loss: 0.23721417784690857, Accuracy: 92.84998321533203, Val Loss: 0.20978718996047974, Val Accuracy: 93.67501068115234
Epoch 173, Loss: 0.2343026101589203, Accuracy: 93.32502746582031, Val Loss: 0.23101498186588287, Val Accuracy: 93.29999542236328
Epoch 177, Loss: 0.23195543885231018, Accuracy: 93.09999084472656, Val Loss: 0.22966672480106354, Val Accuracy: 93.80001068115234
Epoch 181, Loss: 0.21852423250675201, Accuracy: 93.84996795654297, Val Loss: 0.2323383390903473, Val Accuracy: 92.87500762939453
Epoch 185, Loss: 0.22868211567401886, Accuracy: 93.57498168945312, Val Loss: 0.24987754225730896, Val Accuracy: 92.89997863769531
Epoch 189, Loss: 0.21965712308883667, Accuracy: 93.04997253417969, Val Loss: 0.2370077520608902, Val Accuracy: 92.5999984741211
Epoch 193, Loss: 0.2252950221300125, Accuracy: 93



Epoch 197, Loss: 0.23177404701709747, Accuracy: 93.77497863769531, Val Loss: 0.19102539122104645, Val Accuracy: 94.3249740600586
Saving new best model with loss:  0.19102539




Epoch 201, Loss: 0.22074775397777557, Accuracy: 93.6500015258789, Val Loss: 0.18636654317378998, Val Accuracy: 94.52497100830078
Saving new best model with loss:  0.18636654
Epoch 205, Loss: 0.2062712460756302, Accuracy: 93.77498626708984, Val Loss: 0.2193373143672943, Val Accuracy: 93.49998474121094
Epoch 209, Loss: 0.19275163114070892, Accuracy: 94.29998779296875, Val Loss: 0.1966104954481125, Val Accuracy: 94.27497863769531
Epoch 213, Loss: 0.22827354073524475, Accuracy: 93.57501220703125, Val Loss: 0.2046828418970108, Val Accuracy: 93.97500610351562
Epoch 217, Loss: 0.22473737597465515, Accuracy: 93.79998779296875, Val Loss: 0.2085060477256775, Val Accuracy: 93.82498168945312
Epoch 221, Loss: 0.18876248598098755, Accuracy: 94.44995880126953, Val Loss: 0.19851471483707428, Val Accuracy: 93.90001678466797
Epoch 225, Loss: 0.20458747446537018, Accuracy: 94.22498321533203, Val Loss: 0.19220851361751556, Val Accuracy: 94.22494506835938
Epoch 229, Loss: 0.16747373342514038, Accuracy: 95.



Epoch 241, Loss: 0.18552295863628387, Accuracy: 94.40000915527344, Val Loss: 0.18530941009521484, Val Accuracy: 94.57499694824219
Saving new best model with loss:  0.18530941


In [27]:
reptile_save_paths = ["./results/models/omniglot_train_reptile0.h5", "./results/models/omniglot_train_reptile1.h5", \
                      "./results/models/omniglot_train_reptile2.h5",\
                      "./results/models/omniglot_train_reptile3.h5"]

In [34]:
accuracies_reptile_0 = {}
support = np.zeros([num_way, num_shot, img_width, img_height, channels], dtype=np.float32)
query = np.zeros([num_way, num_query, img_width, img_height, channels], dtype=np.float32)
model = Prototypical(support, query, img_width, img_height, channels)

def calc_loss(model, support, query):
  loss, acc = model.call(support, query)
  return loss, acc

for save_path in reptile_save_paths:
  model_path = save_path
  model.load(model_path)
  print("Model with path {} loaded.".format(save_path))
  num_episodes = 1000

  #number of classes
  num_way = test_num_ways[0]

  #number of examples per class for support set
  num_shot = num_shots[0]  

  #number of query points
  num_query = num_shots[0] 

  # Metrics to gather
  test_loss = tf.metrics.Mean(name='test_loss')
  test_acc = tf.metrics.Mean(name='test_accuracy')


  for i_episode in range(num_episodes):
    test_support, test_query = get_next_episode(test_dataset, num_way, num_shot, num_shot, no_test_classes)
    loss, acc = calc_loss(model, test_support, test_query)
    test_loss(loss)
    test_acc(acc)

  loss = test_loss.result().numpy()
  accuracy = test_acc.result().numpy() * 100
  print("Loss: ", loss)
  print("Accuracy: ", accuracy)
  accuracies_reptile_0[model_path] = accuracy

Model with path ./results/models/omniglot_train_reptile0.h5 loaded.
Loss:  1.6092929
Accuracy:  82.28806853294373
Model with path ./results/models/omniglot_train_reptile1.h5 loaded.
Loss:  1.6091865
Accuracy:  81.18806481361389
Model with path ./results/models/omniglot_train_reptile2.h5 loaded.
Loss:  1.6092588
Accuracy:  82.27205872535706
Model with path ./results/models/omniglot_train_reptile3.h5 loaded.
Loss:  1.6092482
Accuracy:  81.26807808876038


In [35]:
accuracies_reptile_1 = {}

def calc_loss(model, support, query):
  loss, acc = model.call(support, query)
  return loss, acc

for save_path in reptile_save_paths:
  model_path = save_path
  model.load(model_path)
  print("Model with path {} loaded.".format(save_path))
  num_episodes = 1000

  #number of classes
  num_way = test_num_ways[1]

  #number of examples per class for support set
  num_shot = num_shots[1]  

  #number of query points
  num_query = num_shots[1] 

  # Metrics to gather
  test_loss = tf.metrics.Mean(name='test_loss')
  test_acc = tf.metrics.Mean(name='test_accuracy')

  print("Testing {} way {} shot".format(num_way, num_shot))

  for i_episode in range(num_episodes):
    test_support, test_query = get_next_episode(test_dataset, num_way, num_shot, num_shot, no_test_classes)
    loss, acc = calc_loss(model, test_support, test_query)
    test_loss(loss)
    test_acc(acc)

  loss = test_loss.result().numpy()
  accuracy = test_acc.result().numpy() * 100
  print("Loss: ", loss)
  print("Accuracy: ", accuracy)
  accuracies_reptile_1[model_path] = accuracy

Model with path ./results/models/omniglot_train_reptile0.h5 loaded.
Testing 5 way 1 shot
Loss:  1.6092882
Accuracy:  63.75981569290161
Model with path ./results/models/omniglot_train_reptile1.h5 loaded.
Testing 5 way 1 shot
Loss:  1.6091793
Accuracy:  64.97978568077087
Model with path ./results/models/omniglot_train_reptile2.h5 loaded.
Testing 5 way 1 shot
Loss:  1.6092618
Accuracy:  65.17989039421082
Model with path ./results/models/omniglot_train_reptile3.h5 loaded.
Testing 5 way 1 shot
Loss:  1.6092479
Accuracy:  63.45980763435364


In [36]:
accuracies_reptile_2 = {}

def calc_loss(model, support, query):
  loss, acc = model.call(support, query)
  return loss, acc

for save_path in reptile_save_paths:
  model_path = save_path
  model.load(model_path)
  print("Model with path {} loaded.".format(save_path))
  num_episodes = 1000

  #number of classes
  num_way = test_num_ways[2]

  #number of examples per class for support set
  num_shot = num_shots[2]  

  #number of query points
  num_query = num_shots[2] 

  # Metrics to gather
  test_loss = tf.metrics.Mean(name='test_loss')
  test_acc = tf.metrics.Mean(name='test_accuracy')

  print("Testing {} way {} shot".format(num_way, num_shot))

  for i_episode in range(num_episodes):
    test_support, test_query = get_next_episode(test_dataset, num_way, num_shot, num_shot, no_test_classes)
    loss, acc = calc_loss(model, test_support, test_query)
    test_loss(loss)
    test_acc(acc)

  loss = test_loss.result().numpy()
  accuracy = test_acc.result().numpy() * 100
  print("Loss: ", loss)
  print("Accuracy: ", accuracy)
  accuracies_reptile_2[model_path] = accuracy

Model with path ./results/models/omniglot_train_reptile0.h5 loaded.
Testing 20 way 5 shot
Loss:  2.9955788
Accuracy:  64.20199871063232
Model with path ./results/models/omniglot_train_reptile1.h5 loaded.
Testing 20 way 5 shot
Loss:  2.9954238
Accuracy:  61.94004416465759
Model with path ./results/models/omniglot_train_reptile2.h5 loaded.
Testing 20 way 5 shot
Loss:  2.995539
Accuracy:  63.39397430419922
Model with path ./results/models/omniglot_train_reptile3.h5 loaded.
Testing 20 way 5 shot
Loss:  2.9955156
Accuracy:  61.36692762374878


In [37]:
accuracies_reptile_3 = {}

def calc_loss(model, support, query):
  loss, acc = model.call(support, query)
  return loss, acc

for save_path in reptile_save_paths:
  model_path = save_path
  model.load(model_path)
  print("Model with path {} loaded.".format(save_path))
  num_episodes = 1000

  #number of classes
  num_way = test_num_ways[3]

  #number of examples per class for support set
  num_shot = num_shots[3]  

  #number of query points
  num_query = num_shots[3] 

  # Metrics to gather
  test_loss = tf.metrics.Mean(name='test_loss')
  test_acc = tf.metrics.Mean(name='test_accuracy')

  print("Testing {} way {} shot".format(num_way, num_shot))

  for i_episode in range(num_episodes):
    test_support, test_query = get_next_episode(test_dataset, num_way, num_shot, num_shot, no_test_classes)
    loss, acc = calc_loss(model, test_support, test_query)
    test_loss(loss)
    test_acc(acc)

  loss = test_loss.result().numpy()
  accuracy = test_acc.result().numpy() * 100
  print("Loss: ", loss)
  print("Accuracy: ", accuracy)
  accuracies_reptile_3[model_path] = accuracy

Model with path ./results/models/omniglot_train_reptile0.h5 loaded.
Testing 20 way 1 shot
Loss:  2.9955764
Accuracy:  41.069960594177246
Model with path ./results/models/omniglot_train_reptile1.h5 loaded.
Testing 20 way 1 shot
Loss:  2.9954288
Accuracy:  40.87496995925903
Model with path ./results/models/omniglot_train_reptile2.h5 loaded.
Testing 20 way 1 shot
Loss:  2.99553
Accuracy:  42.574989795684814
Model with path ./results/models/omniglot_train_reptile3.h5 loaded.
Testing 20 way 1 shot
Loss:  2.9955115
Accuracy:  40.43997526168823


In [None]:
def build_combined_dataset(train_dataset, val_dataset, test_dataset):
  total_rows = train_dataset.shape[0] + val_dataset.shape[0] + test_dataset.shape[0]
  train_dataset_images = int(train_dataset.shape[1]/2)
  test_dataset_images = int(train_dataset.shape[1]/4)
  val_dataset_images = int(train_dataset.shape[1]/4)
  combined_train_dataset = np.zeros([total_rows, train_dataset_images, img_height, img_width, channels], dtype=np.float32)
  combined_val_dataset = np.zeros([total_rows, val_dataset_images, img_height, img_width, channels], dtype=np.float32)
  combined_test_dataset = np.zeros([total_rows, test_dataset_images, img_height, img_width, channels], dtype=np.float32)

  for i in range(train_dataset.shape[0]):
    for j in range(train_dataset_images):
      combined_train_dataset[i][j] = train_dataset[i][j]
    for k in range(val_dataset_images):
      combined_val_dataset[i][k] = train_dataset[i][train_dataset_images + k]
      combined_test_dataset[i][k] = train_dataset[i][train_dataset_images + val_dataset_images + k]

  for i in range (val_dataset.shape[0]):
    for j in range(train_dataset_images):
      combined_train_dataset[train_dataset.shape[0] + i][j] = val_dataset[i][j]
    for k in range(val_dataset_images):
      combined_val_dataset[train_dataset.shape[0] + i][k] = val_dataset[i][train_dataset_images + k]
      combined_test_dataset[train_dataset.shape[0] + i][k] = train_dataset[i][train_dataset_images + val_dataset_images + k]

  for i in range (test_dataset.shape[0]):
    for j in range(train_dataset_images):
      combined_train_dataset[train_dataset.shape[0] + val_dataset.shape[0] + i][j] = test_dataset[i][j]
    for k in range(val_dataset_images):
      combined_val_dataset[train_dataset.shape[0] + val_dataset.shape[0] + i][k] = test_dataset[i][train_dataset_images + k]
      combined_test_dataset[train_dataset.shape[0] + val_dataset.shape[0] + i][k] = train_dataset[i][train_dataset_images + val_dataset_images + k]

  return combined_train_dataset, combined_val_dataset, combined_test_dataset

In [None]:
combined_train_dataset, combined_val_dataset, combined_test_dataset = build_combined_dataset(train_dataset, val_dataset, test_dataset)
 

In [None]:
print(combined_train_dataset.shape)
print(combined_val_dataset.shape)
print(combined_test_dataset.shape)

(6492, 10, 32, 32, 1)
(6492, 5, 32, 32, 1)
(6492, 5, 32, 32, 1)


In [None]:
combined_train_dataset_reshaped = combined_train_dataset.reshape(64920, 32, 32, 1)
combined_val_dataset_reshaped = combined_val_dataset.reshape(32460, 32, 32, 1)
combined_test_dataset_reshaped = combined_test_dataset.reshape(32460, 32, 32, 1)

In [None]:
train_labels = np.zeros(combined_train_dataset_reshaped.shape[0])
val_labels = np.zeros(combined_val_dataset_reshaped.shape[0])
test_labels = np.zeros(combined_test_dataset_reshaped.shape[0])
val = 0
for i in range(combined_train_dataset.shape[0]):
  for j in range(combined_train_dataset.shape[1]):
    train_labels[(i * 10) + j] = val
  val = val + 1

val = 0
for i in range(combined_val_dataset.shape[0]):
  for j in range(combined_val_dataset.shape[1]):
    val_labels[(i * 5) + j] = val
    test_labels[(i * 5) + j] = val
  val = val + 1

train_labels[64500:]

array([6450., 6450., 6450., 6450., 6450., 6450., 6450., 6450., 6450.,
       6450., 6451., 6451., 6451., 6451., 6451., 6451., 6451., 6451.,
       6451., 6451., 6452., 6452., 6452., 6452., 6452., 6452., 6452.,
       6452., 6452., 6452., 6453., 6453., 6453., 6453., 6453., 6453.,
       6453., 6453., 6453., 6453., 6454., 6454., 6454., 6454., 6454.,
       6454., 6454., 6454., 6454., 6454., 6455., 6455., 6455., 6455.,
       6455., 6455., 6455., 6455., 6455., 6455., 6456., 6456., 6456.,
       6456., 6456., 6456., 6456., 6456., 6456., 6456., 6457., 6457.,
       6457., 6457., 6457., 6457., 6457., 6457., 6457., 6457., 6458.,
       6458., 6458., 6458., 6458., 6458., 6458., 6458., 6458., 6458.,
       6459., 6459., 6459., 6459., 6459., 6459., 6459., 6459., 6459.,
       6459., 6460., 6460., 6460., 6460., 6460., 6460., 6460., 6460.,
       6460., 6460., 6461., 6461., 6461., 6461., 6461., 6461., 6461.,
       6461., 6461., 6461., 6462., 6462., 6462., 6462., 6462., 6462.,
       6462., 6462.,

In [None]:
from numpy import argmax
from tensorflow.keras.utils import to_categorical
# one hot encode
train_labels_encoded = to_categorical(train_labels)
val_labels_encoded = to_categorical(val_labels)
test_labels_encoded = to_categorical(test_labels)
print(train_labels_encoded.shape)
print(val_labels_encoded.shape)
print(test_labels_encoded.shape)

(64920, 6492)
(32460, 6492)
(32460, 6492)


In [None]:
from tensorflow.keras import applications
from tensorflow.keras.models import Sequential

base_model = applications.resnet50.ResNet50(weights= None, include_top=False, input_shape= (img_height,img_width,1))
x = base_model.output
x = GlobalMaxPooling2D()(x)
x = Dropout(0.7)(x)
predictions = Dense(6492, activation= 'softmax')(x)
model = Model(inputs = base_model.input, outputs = predictions)
optimizer = tf.keras.optimizers.Adam(learning_rate)
model.compile(optimizer= optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(combined_train_dataset_reshaped, train_labels_encoded, epochs = 100, batch_size = 64, validation_data=(combined_val_dataset_reshaped, val_labels_encoded))


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<keras.callbacks.History at 0x7fa6601f39d0>

In [None]:
preds = model.evaluate(combined_test_dataset_reshaped, test_labels_encoded)
print ("Loss = " + str(preds[0]))
print ("Test Accuracy = " + str(preds[1]))

Loss = 14.665081024169922
Test Accuracy = 0.014725816436111927
