In [1]:
import logging
import numpy as np
import tensorflow as tf
from node.core import get_node_function
from node.fix_grid import RKSolver
from node.utils.mnist import get_dataset, process_datum, accuracy


# for reproducibility
np.random.seed(42)
tf.random.set_seed(42)


@tf.function
def normalize(x, axis=None):
    M = tf.reduce_max(x, axis, keepdims=True)
    m = tf.reduce_min(x, axis, keepdims=True)
    return (x - m) / (M - m + 1e-8)


class PrintLayer(tf.keras.layers.Layer):
    
    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.print_name = name
    
    def call(self, x):
        tf.print(self.print_name, x)
        return x


class HopfieldLayer(tf.keras.layers.Layer):

    def __init__(self, network, solver, t, axis, **kwargs):
        super().__init__(**kwargs)
        self.network = network
        self.solver = solver
        self.t = t

        @tf.function
        def pvf(t, x):
            f = self.network(x)
            with tf.GradientTape() as g:
                g.watch(x)
                r = normalize(x, axis)
            return g.gradient(r, x, f)

        self._pvf = pvf
        self._node_fn = get_node_function(solver, 0., pvf)

    def call(self, x):
        y = self._node_fn(self.t, x)
        return y


def get_train_fn(model, optimizer,
                 num_epochs=5,
                 batch_size=128,
                 skip_step=50):
    loss_fn = tf.losses.CategoricalCrossentropy()

    @tf.function
    def train_one_step(X, y):
        with tf.GradientTape() as tape:
            outputs = model(X)
            loss = loss_fn(y, outputs)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        acc = accuracy(y, outputs)
        return loss, acc

    @tf.function
    def train(X_train, y_train, X_test, y_test):
        num_steps_per_epoch = int(len(X_train) / batch_size)

        train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
        train_dataset = (train_dataset.map(process_datum)
                          .shuffle(10000)
                          .repeat(num_epochs)
                          .batch(batch_size))
        test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
        test_dataset = test_dataset.map(process_datum).batch(batch_size)

        def evaluate():
            step = 0
            total_accuracy = 0.
            for X, y in test_dataset:
                outputs = model(X)
                total_accuracy += accuracy(y, outputs)
                step += 1
            return total_accuracy / tf.cast(step, total_accuracy.dtype)

        step = 0
        loss = float('inf')
        reg = float('inf')
        acc = 0.
        for X, y in train_dataset:
            loss, acc = train_one_step(X, y)

            if step % skip_step == 0:
                tf.print(step, loss, acc)

            if step % num_steps_per_epoch == 0:
                tf.print('testing')
                test_acc = evaluate()
                tf.print('test accuracy:', test_acc)

            step += 1
        return loss, acc

    return train

In [2]:
num_filters = 4
kernel_size = 3
dense_network = tf.keras.layers.Dense(28 * 28)
cnn_network = tf.keras.layers.Conv2D(
    num_filters, kernel_size, activation='relu', padding='same')
hopfield_layer = HopfieldLayer(
    cnn_network, RKSolver(1e-1), t=2e-1, axis=[1, 2, 3])
output_layer = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model = tf.keras.Sequential([
    tf.keras.layers.Input([28 * 28]),
    tf.keras.layers.Reshape([28, 28, 1]),
    tf.keras.layers.Conv2D(num_filters, 3, activation='relu'),
#     PrintLayer('\n\n----- 1 -----\n'),
    hopfield_layer,
#     PrintLayer('\n\n----- 2 -----\n'),
    tf.keras.layers.Flatten(),
    output_layer,
#     PrintLayer('\n\n----- 3 -----\n'),
])
model.build([28, 28, 1])

In [3]:
optimizer = tf.compat.v1.train.AdamOptimizer()

train = get_train_fn(model, optimizer, num_epochs=2, skip_step=1)
train(*get_dataset())

0 2.31090355 0.09375
testing
test accuracy: 0.182258695
1 2.2707448 0.1640625
2 2.1775043 0.2421875
3 2.10810375 0.453125
4 2.11546803 0.3828125
5 2.02227068 0.46875
6 1.93333089 0.5546875
7 1.83904839 0.59375
8 1.75376856 0.71875
9 1.603935 0.71875
10 1.54928029 0.7421875
11 1.5500598 0.640625
12 1.40588772 0.71875
13 1.35288072 0.765625
14 1.24567497 0.703125
15 1.1108619 0.8046875
16 1.10524321 0.75
17 0.903495908 0.859375
18 0.932240605 0.7890625
19 0.884675443 0.8203125
20 0.849267066 0.8359375
21 0.829985321 0.7890625
22 0.801091194 0.828125
23 0.707480788 0.84375
24 0.791820168 0.796875
25 0.7052809 0.78125
26 0.714248836 0.7890625
27 0.691588283 0.828125
28 0.714313567 0.7578125
29 0.58737731 0.8046875
30 0.708268285 0.7734375
31 0.527247429 0.890625
32 0.698205233 0.8046875
33 0.574311435 0.84375
34 0.400269538 0.8984375
35 0.482905507 0.8984375
36 0.713258207 0.7734375
37 0.578715682 0.84375
38 0.639100611 0.765625
39 0.538882494 0.8203125
40 0.361855626 0.8828125
41 0.477208

332 0.164339423 0.953125
333 0.167418182 0.9609375
334 0.201494351 0.9296875
335 0.189416826 0.9453125
336 0.213712603 0.9296875
337 0.234512031 0.9453125
338 0.17585817 0.9296875
339 0.284295082 0.921875
340 0.196337402 0.9453125
341 0.233299077 0.9375
342 0.259139091 0.8984375
343 0.207843781 0.9296875
344 0.282982677 0.9296875
345 0.0962316692 0.9765625
346 0.182559758 0.9453125
347 0.23384434 0.921875
348 0.135792106 0.9765625
349 0.268658072 0.9453125
350 0.190073028 0.9609375
351 0.352261841 0.8828125
352 0.154116914 0.953125
353 0.258489728 0.953125
354 0.250685215 0.9453125
355 0.147524387 0.953125
356 0.160539076 0.96875
357 0.242909461 0.9375
358 0.138357237 0.9375
359 0.2119461 0.9375
360 0.0871962607 0.96875
361 0.151824027 0.9609375
362 0.306748092 0.90625
363 0.186482668 0.96875
364 0.132378966 0.9609375
365 0.151092857 0.96875
366 0.204672635 0.953125
367 0.122073159 0.953125
368 0.145967 0.9453125
369 0.18893519 0.953125
370 0.291836202 0.90625
371 0.29195416 0.921875
3

658 0.142651498 0.9609375
659 0.136457279 0.9609375
660 0.107973181 0.9765625
661 0.0835875273 0.9765625
662 0.133338407 0.953125
663 0.141779184 0.9609375
664 0.103650764 0.9609375
665 0.198821425 0.9296875
666 0.087598145 0.96875
667 0.276417464 0.921875
668 0.162674189 0.9609375
669 0.0727767795 0.984375
670 0.17851615 0.953125
671 0.0958905518 0.9765625
672 0.268470526 0.9453125
673 0.142146304 0.96875
674 0.109662 0.96875
675 0.094406724 0.96875
676 0.103998408 0.9609375
677 0.0855795518 0.96875
678 0.149347126 0.953125
679 0.124232203 0.953125
680 0.147462666 0.9609375
681 0.116745338 0.9609375
682 0.0874092 0.9765625
683 0.166691393 0.96875
684 0.135437936 0.953125
685 0.0751885772 0.9765625
686 0.106976122 0.96875
687 0.0987169594 0.953125
688 0.132881775 0.96875
689 0.162355751 0.953125
690 0.108278662 0.953125
691 0.0529849418 0.9765625
692 0.127574205 0.9609375
693 0.191858798 0.9609375
694 0.139503419 0.953125
695 0.096016936 0.9765625
696 0.116314113 0.9609375
697 0.187090

(<tf.Tensor: id=2126, shape=(), dtype=float32, numpy=0.12096748>,
 <tf.Tensor: id=2127, shape=(), dtype=float32, numpy=0.953125>)