In [1]:
import logging
import os

import numpy as np
import tensorflow as tf

from node.core import get_node_function
from node.solvers import RK4Solver
from node.utils.initializers import GlorotUniform


# for reproducibility
SEED = 15
np.random.seed(SEED)
tf.random.set_seed(SEED)


@tf.function
def normalize_v1(x, axis=None):
    """Min-max rescale `x` along `axis`.

    Args:
        x: Tensor to rescale.
        axis: Axis or axes over which the minimum and maximum are taken;
            forwarded unchanged to `tf.reduce_min` / `tf.reduce_max`.

    Returns:
        `(x - min) / (max - min + 1e-8)`, where min and max are reduced with
        `keepdims=True` so that they broadcast against `x`.
    """
    eps = 1e-8  # keeps the division finite when max == min
    lo = tf.reduce_min(x, axis=axis, keepdims=True)
    hi = tf.reduce_max(x, axis=axis, keepdims=True)
    return (x - lo) / (hi - lo + eps)


@tf.function
def normalize_v2(x, axis=None):
    """Standardize `x` along `axis` (subtract the mean, divide by the std).

    Args:
        x: Tensor to standardize.
        axis: Axis or axes over which the moments are computed; forwarded
            unchanged to `tf.nn.moments`.

    Returns:
        `(x - mean) / (std + 1e-8)`, where the moments are computed with
        `keepdims=True` so that they broadcast against `x`.
    """
    eps = 1e-8  # keeps the division finite when the variance is zero
    mu, var = tf.nn.moments(x, axes=axis, keepdims=True)
    sigma = tf.sqrt(var)
    centered = x - mu
    return centered / (sigma + eps)


class HopfieldLayer(tf.keras.layers.Layer):
    """Keras layer whose forward pass is delegated to a node function.

    The node function is obtained from `get_node_function(solver, 0., pvf)`,
    where `pvf` is the closure defined in `__init__`. `call` evaluates that
    node function at `self.t` on the layer input.

    NOTE(review): `get_node_function` is a project helper not shown here;
    presumably it integrates `pvf` from time 0. to `t` with `solver` --
    confirm against `node.core`.
    """

    def __init__(self, filters, kernel_size, solver, t,
                 kernel_initializer=GlorotUniform(1e-1),
                 **kwargs):
        # Args:
        #   filters, kernel_size: forwarded to the internal `Conv2D`.
        #   solver: forwarded to `get_node_function`.
        #   t: value passed as the first argument of the node function in
        #      `call`.
        #   kernel_initializer: initializer of the internal `Conv2D` kernel.
        #      The default instance is created once, when this `def` is
        #      evaluated, and is therefore shared by every layer that relies
        #      on the default.
        #   **kwargs: forwarded to `tf.keras.layers.Layer.__init__`.
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.solver = solver
        self.t = t

        # 'same' padding, so the convolution does not change the spatial size.
        self.cnn = tf.keras.layers.Conv2D(
            filters, kernel_size, padding='same',
            kernel_initializer=kernel_initializer)

        @tf.function
        def pvf(t, x):
            # The time argument `t` is accepted but not used in this body.
            # `z` is computed outside the tape: the tape below only records
            # the normalization, not the convolution.
            z = self.cnn(x)
            with tf.GradientTape() as g:
                # `x` is watched explicitly so the tape tracks it regardless
                # of whether it is a variable.
                g.watch(x)
                # Min-max normalization over the third- and second-to-last
                # axes (presumably height and width of a channels-last
                # input -- TODO confirm).
                r = normalize_v1(x, axis=[-3, -2])
            # The third positional argument of `GradientTape.gradient` is
            # `output_gradients`, so this returns the vector-Jacobian product
            # of `normalize_v1` at `x` with `z`, not a plain gradient.
            return g.gradient(r, x, z)

        self._pvf = pvf
        self._node_fn = get_node_function(solver, 0., pvf)

    def call(self, x):
        # Evaluate the node function at `self.t` on the input `x`.
        y = self._node_fn(self.t, x)
        return y


class RepeatLayer(tf.keras.layers.Layer):
    """Layer that appends a new last axis holding `n` copies of its input."""

    def __init__(self, n, **kwargs):
        """Store the number of copies `n`; `kwargs` go to the Keras base layer."""
        super().__init__(**kwargs)
        self.n = n

    def call(self, x):
        """Stack `n` references to `x` along a new trailing axis."""
        copies = [x for _ in range(self.n)]
        return tf.stack(copies, axis=-1)

In [2]:
def get_compiled_model(num_filters, kernel_size, t, save_path=None,
                       input_shape=(14, 14)):
    """Build and compile the model, optionally restoring saved weights.

    The model maps a 2-D input of shape `input_shape` through
    `RepeatLayer(num_filters)`, a `HopfieldLayer`, and a 1x1 convolution with
    a single output channel, then reshapes the result back to `input_shape`.
    It is compiled with MSE loss, the Adam optimizer and a `BinaryAccuracy`
    metric.

    Args:
        num_filters: Number of copies made by `RepeatLayer` and number of
            filters of the `HopfieldLayer`.
        kernel_size: Kernel size of the `HopfieldLayer`.
        t: Passed as `t` to the `HopfieldLayer`.
        save_path: Optional weight file (prefix) to load. Loading is best
            effort: if it fails, the error message is printed and the model
            keeps its freshly initialized weights.
        input_shape: Per-sample input shape (height, width). Defaults to
            `(14, 14)`, the value that used to be hard-coded.

    Returns:
        The compiled `tf.keras.Sequential` model.
    """
    # The same shape is used for the input and for the final reshape, so
    # keep a single list to guarantee the two stay in agreement.
    input_shape = list(input_shape)

    model = tf.keras.Sequential([
        tf.keras.layers.Input(input_shape),
        RepeatLayer(num_filters),
        HopfieldLayer(num_filters, kernel_size, RK4Solver(0.1), t),
        tf.keras.layers.Conv2D(1, 1),
        tf.keras.layers.Reshape(input_shape)
    ])

    accuracy = tf.keras.metrics.BinaryAccuracy()
    model.compile(loss='mse', optimizer='adam', metrics=[accuracy])

    if save_path is not None:
        # Deliberately tolerant: a missing or unreadable checkpoint must not
        # prevent the caller from getting a usable (untrained) model.
        try:
            model.load_weights(save_path)
        except Exception as e:
            print(str(e))

    return model

In [3]:
# Hyperparameters of the model used in the cells below.
num_filters = 1
kernel_size = 5
# Weight file (prefix) that `get_compiled_model` tries to load; a failed
# load is only reported, not raised.
save_path = '../dat/tmp_weights/model_3'
# Passed as `t` to the HopfieldLayer inside the model.
t = 0.3

model = get_compiled_model(num_filters, kernel_size, t, save_path)
model.summary()

Unsuccessful TensorSliceReader constructor: Failed to find any matching files for ../dat/tmp_weights/model_3
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
repeat_layer (RepeatLayer)   (None, 14, 14, 1)         0         
_________________________________________________________________
hopfield_layer (HopfieldLaye (None, 14, 14, 1)         26        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 1)         2         
_________________________________________________________________
reshape (Reshape)            (None, 14, 14)            0         
Total params: 28
Trainable params: 28
Non-trainable params: 0
_________________________________________________________________


In [4]:
def add_noise(noise_scale, x):
    """Flip entries of `x` independently with probability `noise_scale`.

    A flipped entry becomes `1 - x`, so for data made of zeros and ones this
    swaps the two values.

    Args:
        noise_scale: Probability with which each entry is flipped.
        x: Tensor (or array-like) to corrupt.

    Returns:
        A tensor shaped like `x` with the selected entries flipped.
    """
    # Use the dynamic shape rather than `x.shape`: the static shape may have
    # unknown dimensions (e.g. inside a `tf.function` or a `Dataset.map`),
    # which `tf.random.uniform` cannot accept.
    flip = tf.random.uniform(tf.shape(x)) < noise_scale
    return tf.where(flip, tf.ones_like(x) - x, x)

In [5]:
# Load MNIST. Only the images are used in the cells shown here; the labels
# `y_train` / `y_test` are not referenced again.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

def process(x):
    """Scale, downsample and binarize a batch of images.

    Steps: cast to float32 and divide by 255, apply a 2x2 max-pool with
    stride 2 and 'VALID' padding (a channel axis is added for the pooling op
    and removed again afterwards), then map values above 0.5 to 1 and all
    other values to 0.

    Args:
        x: Batch of single-channel images without a channel axis
           (presumably `(batch, height, width)` -- TODO confirm).

    Returns:
        A float32 tensor of zeros and ones.
    """
    scaled = tf.cast(x, tf.float32) / 255.
    with_channel = tf.expand_dims(scaled, -1)
    pooled = tf.nn.max_pool2d(with_channel, ksize=2, strides=2,
                              padding='VALID')
    pooled = tf.squeeze(pooled, -1)
    # Strictly-greater comparison: a value of exactly 0.5 maps to 0.
    return tf.cast(pooled > 0.5, tf.float32)

# Replace the raw images with their processed versions. The raw arrays are
# no longer reachable under these names after this point.
x_train = process(x_train)
x_test = process(x_test)

In [6]:
epochs = 10
noise_scale = 0.1

# Score of the current weights on noisy test data, before any training here.
baseline = model.evaluate(add_noise(noise_scale, x_test), x_test, verbose=2)
print(f'baseline val_loss: {baseline}')

# Make sure the checkpoint directory exists before the first save.
if save_path is not None:
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

for i in range(epochs):
    print(f'Epoch {i}/{epochs}')
    # One `fit` call per epoch so that fresh noise is drawn every time.
    model.fit(add_noise(noise_scale, x_train), x_train,
              validation_data=(add_noise(noise_scale, x_test), x_test))
    # Persist after every epoch: later cells rebuild the model and load the
    # weights from `save_path`, and an interrupted run should still leave
    # the last completed epoch on disk.
    if save_path is not None:
        model.save_weights(save_path)

10000/1 - 5s - loss: 0.2358 - binary_accuracy: 0.8001
baseline val_loss: [0.22011935119628906, 0.8001311]
Epoch 0/10
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
Train on 60000 samples, validate on 10000 samples
Epoch 2/10
Train on 60000 samples, validate on 10000 samples
Epoch 3/10
Train on 60000 samples, validate on 10000 samples
Epoch 4/10
Train on 60000 samples, validate on 10000 samples
Epoch 5/10
Train on 60000 samples, validate on 10000 samples
Epoch 6/10
Train on 60000 samples, validate on 10000 samples
Epoch 7/10
Train on 60000 samples, validate on 10000 samples
Epoch 8/10
Train on 60000 samples, validate on 10000 samples
10240/60000 [====>.........................] - ETA: 1:21 - loss: 0.1386 - binary_accuracy: 0.8122

KeyboardInterrupt: 

In [7]:
# (t, noise level) pairs, evaluated in this order; each model is rebuilt and
# loads its weights from `save_path`.
eval_settings = [(0.5, 0.3), (5., 0.3), (0.5, 0.1), (0., 0.1)]

for eval_t, eval_noise in eval_settings:
    eval_model = get_compiled_model(num_filters, kernel_size, eval_t, save_path)
    print(eval_model.evaluate(add_noise(eval_noise, x_test), x_test, verbose=2))

10000/1 - 8s - loss: 0.1862 - binary_accuracy: 0.7907
[0.17881290814876558, 0.79070574]
10000/1 - 63s - loss: 5.0363 - binary_accuracy: 0.5519
[4.994714653778076, 0.55191267]
10000/1 - 8s - loss: 0.1559 - binary_accuracy: 0.8305
[0.14982584664821624, 0.8304916]

Two checkpoint references resolved to different objects (<tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7f4a72486690> and <__main__.HopfieldLayer object at 0x7f4a72486e10>).
10000/1 - 0s - loss: 0.1615 - binary_accuracy: 0.8001
[0.1531829715013504, 0.8001311]


In [None]:
# Evaluate with t = 0. on noise-free data: inputs and targets are both
# `x_test`.
eval_model = get_compiled_model(num_filters, kernel_size, 0., save_path)
print(eval_model.evaluate(x_test, x_test, verbose=2))