In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import optimizer


class Lookahead(optimizer.Optimizer):
    '''Tensorflow implementation of the lookahead wrapper.
    Lookahead Optimizer: https://arxiv.org/abs/1907.08610
    '''

    def __init__(self, optimizer, la_steps=5, la_alpha=0.8, use_locking=False, name="Lookahead"):
        """optimizer: inner optimizer
        la_steps (int): number of lookahead steps
        la_alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        """
        super(Lookahead, self).__init__(use_locking, name)
        self.optimizer = optimizer
        self._la_step = 0
        self._la_alpha = la_alpha
        self._total_la_steps = la_steps

    def _create_slots(self, var_list):
        self.optimizer._create_slots(var_list)

        self._var_list = var_list
        first_var = min(var_list, key=lambda x: x.name)
        self._create_non_slot_variable(initial_value=self._la_step,
                                       name="la_step",
                                       colocate_with=first_var)

        # Create slots for the cached parameters.
        for v in var_list:
            self._zeros_slot(v, "cached_params", self._name)

    def _prepare(self):
        self.optimizer._prepare()

        la_alpha = self._call_if_callable(self._la_alpha)
        total_la_steps = self._call_if_callable(self._total_la_steps)

        self._la_alpha_t = ops.convert_to_tensor(la_alpha, name="la_alpha")
        self._total_la_steps_t = ops.convert_to_tensor(total_la_steps, name="total_la_steps")

    def _get_la_step_accumulators(self):
        with ops.init_scope():
            if context.executing_eagerly():
                graph = None
            else:
                graph = ops.get_default_graph()
            return self._get_non_slot_variable("la_step", graph=graph)

    def _apply_dense(self, grad, var):
        return self.optimizer._apply_dense(grad, var)

    def _resource_apply_dense(self, grad, var):
        return self.optimizer._resource_apply_dense(grad, var)

    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        return self.optimizer._apply_sparse_shared(grad, var, indices, scatter_add)

    def _apply_sparse(self, grad, var):
        return self.optimizer._apply_sparse(grad, var)

    def _resource_scatter_add(self, x, i, v):
        return self.optimizer._resource_scatter_add(x, i, v)

    def _resource_apply_sparse(self, grad, var, indices):
        return self.optimizer._resource_apply_sparse(grad, var, indices)

    def _finish(self, update_ops, name_scope):
        inner_finish_op = self.optimizer._finish(update_ops, name_scope)

        with ops.control_dependencies([inner_finish_op, ]):
            la_step = self._get_la_step_accumulators()
            with ops.colocate_with(la_step):
                def update_la_step_func():
                    # update the la_step
                    return control_flow_ops.group([la_step.assign(
                        la_step + 1, use_locking=self._use_locking), ])

                def pull_back_func():
                    # update the la_step
                    update_la_step = la_step.assign(
                        0, use_locking=self._use_locking)
                    # interpolate the variables
                    interpolation = [v.assign(
                        self.get_slot(v, "cached_params") + self._la_alpha_t * (v - self.get_slot(v, "cached_params")))
                                     for v in self._var_list]

                    # update the cached params
                    with ops.control_dependencies(interpolation):
                        update_cached_params = [self.get_slot(v, "cached_params").assign(updated_v) for v, updated_v in
                                                zip(self._var_list, interpolation)]
                    return control_flow_ops.group([update_la_step, ] + interpolation + update_cached_params)

                # condition for when to pull back the params
                condition = tf.greater_equal(la_step, self._total_la_steps_t)
                update_lookahead_states = tf.cond(condition,
                                                  pull_back_func,
                                                  update_la_step_func,
                                                  )

        return control_flow_ops.group([inner_finish_op, update_lookahead_states],
                                      name=name_scope)

    def _call_if_callable(self, param):
        """Call the function if param is callable."""
        return param() if callable(param) else param

In [3]:
import numpy as np
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [4]:
print(x_train.shape)
print(y_train.shape)

(60000, 28, 28, 1)
(60000,)


In [0]:
def create_model():
  model = tf.keras.models.Sequential()
  model.add(tf.keras.layers.Conv2D(64, (5, 5), padding='same', activation='elu'))
  model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2,2)))

  model.add(tf.keras.layers.Conv2D(128, (5, 5), padding='same', activation='elu'))
  model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))

  model.add(tf.keras.layers.Conv2D(256, (5, 5), padding='same', activation='elu'))
  model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2,2)))

  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(256))
  model.add(tf.keras.layers.Activation('elu'))
  model.add(tf.keras.layers.Dense(10))
  model.add(tf.keras.layers.Activation('softmax'))
  return model

In [0]:
model1 = create_model()

optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)

model1.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy']
)

In [0]:
model2 = create_model()

optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)

model2.compile(
    optimizer=Lookahead(optimizer, la_steps=5, la_alpha=0.8),
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy']
)

In [0]:
model3 = create_model()

optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)

model3.compile(
    optimizer=Lookahead(optimizer, la_steps=10, la_alpha=0.8),
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy']
)

In [0]:
model4 = create_model()

optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)

model4.compile(
    optimizer=Lookahead(optimizer, la_steps=20, la_alpha=0.8),
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy']
)

In [0]:
model5 = create_model()

optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)

model5.compile(
    optimizer=Lookahead(optimizer, la_steps=5, la_alpha=0.6),
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy']
)

In [0]:
model6 = create_model()

optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)

model6.compile(
    optimizer=Lookahead(optimizer, la_steps=10, la_alpha=0.6),
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy']
)

In [12]:
history1 = model1.fit(
    x_train.astype(np.float32), y_train.astype(np.float32),
    epochs=10,
    steps_per_epoch=600,
    validation_data=(x_test.astype(np.float32), y_test.astype(np.float32)),
    validation_freq=17
)

print(history1.history)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
{'loss': [1.0904664993286133, 0.2836657464504242, 0.24310573935508728, 0.21381902694702148, 0.19565501809120178, 0.18596601486206055, 0.1722569763660431, 0.16535115242004395, 0.15617439150810242, 0.15621936321258545], 'sparse_categorical_accuracy': [0.821566641330719, 0.895633339881897, 0.9109500050544739, 0.9203333258628845, 0.9273999929428101, 0.930899977684021, 0.9352999925613403, 0.9373499751091003, 0.9412999749183655, 0.9419500231742859]}


In [13]:
history2 = model2.fit(
    x_train.astype(np.float32), y_train.astype(np.float32),
    epochs=10,
    steps_per_epoch=600,
    validation_data=(x_test.astype(np.float32), y_test.astype(np.float32)),
    validation_freq=17
)

print(history2.history)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
{'loss': [0.8968866467475891, 0.26313987374305725, 0.22648416459560394, 0.19670893251895905, 0.17874263226985931, 0.162642240524292, 0.1553630530834198, 0.14239048957824707, 0.12760722637176514, 0.1262667030096054], 'sparse_categorical_accuracy': [0.8300166726112366, 0.9029333591461182, 0.9158333539962769, 0.9264166951179504, 0.9328666925430298, 0.9396666884422302, 0.94118332862854, 0.9459166526794434, 0.9512333273887634, 0.9528499841690063]}


In [14]:
history3 = model3.fit(
    x_train.astype(np.float32), y_train.astype(np.float32),
    epochs=10,
    steps_per_epoch=600,
    validation_data=(x_test.astype(np.float32), y_test.astype(np.float32)),
    validation_freq=17
)

print(history3.history)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
{'loss': [0.958195686340332, 0.2710147500038147, 0.228478342294693, 0.19928444921970367, 0.18250231444835663, 0.1677701622247696, 0.15127858519554138, 0.14517052471637726, 0.14106224477291107, 0.12215832620859146], 'sparse_categorical_accuracy': [0.8274000287055969, 0.9004999995231628, 0.9138666391372681, 0.9250166416168213, 0.9319166541099548, 0.9374666810035706, 0.9434999823570251, 0.9458833336830139, 0.9480833411216736, 0.9531833529472351]}


In [15]:
history4 = model4.fit(
    x_train.astype(np.float32), y_train.astype(np.float32),
    epochs=10,
    steps_per_epoch=600,
    validation_data=(x_test.astype(np.float32), y_test.astype(np.float32)),
    validation_freq=17
)

print(history4.history)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
{'loss': [1.0629549026489258, 0.27821293473243713, 0.23568783700466156, 0.20998528599739075, 0.1900835782289505, 0.17038510739803314, 0.1604107916355133, 0.15150566399097443, 0.13729621469974518, 0.1343255341053009], 'sparse_categorical_accuracy': [0.8172833323478699, 0.895550012588501, 0.9120000004768372, 0.920366644859314, 0.9287833571434021, 0.9361166954040527, 0.9397666454315186, 0.9436833262443542, 0.9483000040054321, 0.94964998960495]}


In [16]:
history5 = model5.fit(
    x_train.astype(np.float32), y_train.astype(np.float32),
    epochs=10,
    steps_per_epoch=600,
    validation_data=(x_test.astype(np.float32), y_test.astype(np.float32)),
    validation_freq=17
)

print(history5.history)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
{'loss': [0.8717005848884583, 0.26819178462028503, 0.22519956529140472, 0.19333207607269287, 0.17053692042827606, 0.1519627422094345, 0.13496793806552887, 0.12807711958885193, 0.11399805545806885, 0.10266362875699997], 'sparse_categorical_accuracy': [0.8213666677474976, 0.9006500244140625, 0.9165833592414856, 0.9285833239555359, 0.9355166554450989, 0.9430000185966492, 0.9491000175476074, 0.9517166614532471, 0.9564999938011169, 0.9615499973297119]}


In [17]:
history6 = model6.fit(
    x_train.astype(np.float32), y_train.astype(np.float32),
    epochs=10,
    steps_per_epoch=600,
    validation_data=(x_test.astype(np.float32), y_test.astype(np.float32)),
    validation_freq=17
)

print(history6.history)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
{'loss': [1.0656099319458008, 0.27384957671165466, 0.22450201213359833, 0.19315072894096375, 0.17140942811965942, 0.14778201282024384, 0.13298730552196503, 0.12418126314878464, 0.10795404016971588, 0.10386539995670319], 'sparse_categorical_accuracy': [0.8129333257675171, 0.8990499973297119, 0.9157500267028809, 0.9283833503723145, 0.9354666471481323, 0.9430666565895081, 0.9484333395957947, 0.9527833461761475, 0.9591666460037231, 0.9603666663169861]}


In [0]:
from google.colab import files
files.download('final_ans_2.txt') 


In [0]:
with open("final_ans_2.txt","w+") as fi:
  fi.write(', '.join([str(num) for num in history1.history["loss"]]))
  fi.write('\n')
  fi.write(', '.join([str(num) for num in history2.history["loss"]]))
  fi.write('\n')
  fi.write(', '.join([str(num) for num in history3.history["loss"]]))
  fi.write('\n')
  fi.write(', '.join([str(num) for num in history4.history["loss"]]))
  fi.write('\n')
  fi.write(', '.join([str(num) for num in history5.history["loss"]]))
  fi.write('\n')
  fi.write(', '.join([str(num) for num in history6.history["loss"]]))
  fi.write('\n')