In [None]:

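'''Trains a small network on the MNIST dataset: a custom `NormalDensity`
layer followed by a 10-way softmax classifier.

NOTE(review): the figures below were written for a "simple deep NN";
confirm that they still hold for the custom-layer model built in this script.
Gets to 98.40% test accuracy after 20 epochs
(there is *a lot* of margin for parameter tuning).
2 seconds per epoch on a K520 GPU.
'''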

from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import RMSprop


from keras import backend as K
from keras import layers
import numpy as np

class NormalDensity(layers.Layer):
    """Square recurrent layer whose weights are modulated by a Hebbian trace.

    Keras port of the following PyTorch reference update::

        yout = F.tanh(yin.mm(w + torch.mul(alpha, hebb)) + input)
        hebb = (1 - eta) * hebb + eta * outer(yin, yout)

    where ``yin`` is the activation left behind by the previous call,
    ``torch.mul`` is an element-wise product and ``outer`` is the outer
    product of the previous and the new activation.  This port keeps the
    ReLU (``K.maximum(0.0, ...)``) that the layer already used in place of
    the reference ``tanh``.

    The previous activation ``y`` and the trace ``hebb`` live in
    non-trainable weights and are overwritten through ``add_update`` so the
    new values are written back to the backend variables instead of merely
    rebinding Python attributes.

    Args:
        output_dim: Number of units.  It must equal the size of the last
            input axis, because the layer input is added directly to the
            recurrent term.
        **kwargs: Forwarded unchanged to ``layers.Layer``.

    Raises:
        ValueError: From ``build`` when the input width differs from
            ``output_dim``.
    """

    # Step size with which the Hebbian trace moves toward the newest outer
    # product (``eta`` in the reference formula).
    HEBB_RATE = 0.01

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(NormalDensity, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the layer variables for a 2-D input ``(batch, input_dim)``."""
        input_dim = input_shape[1]
        if input_dim != self.output_dim:
            raise ValueError(
                'NormalDensity adds its input to its recurrent output, so '
                'the input width (%s) must equal output_dim (%s).'
                % (input_dim, self.output_dim))

        plastic_shape = (input_dim, self.output_dim)
        # Fixed (trained) part of the connection weights.
        self.w = self.add_weight(name='w',
                                 shape=plastic_shape,
                                 initializer='uniform',
                                 trainable=True)
        # Per-connection gain applied to the Hebbian trace.
        self.alpha = self.add_weight(name='alpha',
                                     shape=plastic_shape,
                                     initializer='uniform',
                                     trainable=False)
        # Activation of the previous call, stored as a single row so that
        # K.dot(y, weights) has shape (1, output_dim) and broadcasts over
        # the batch axis of the input.
        self.y = self.add_weight(name='y',
                                 shape=(1, self.output_dim),
                                 initializer='zeros',
                                 trainable=False)
        # Running Hebbian trace, one entry per connection.
        self.hebb = self.add_weight(name='hebb',
                                    shape=plastic_shape,
                                    initializer='zeros',
                                    trainable=False)
        # Allocated as before but not read by call(), which uses HEBB_RATE.
        self.eta = self.add_weight(name='eta',
                                   shape=(input_dim, 1),
                                   initializer='uniform',
                                   trainable=False)
        super(NormalDensity, self).build(input_shape)

    def call(self, x):
        """Return the new activation and register the state updates.

        Args:
            x: Input tensor of shape ``(batch, output_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        y_prev = self.y
        # Effective weights: fixed part plus the gated Hebbian trace
        # (element-wise product, as in the reference formula).
        plastic_w = self.w + self.alpha * self.hebb
        yout = K.maximum(0.0, K.dot(y_prev, plastic_w) + x)

        # The stored state is a single row; with more than one sample per
        # batch the batch mean is kept (identical to yout when batch == 1).
        y_next = K.mean(yout, axis=0, keepdims=True)
        # (output_dim, 1) . (1, output_dim): outer product of the previous
        # and the new activation.
        outer = K.dot(K.transpose(y_prev), y_next)
        hebb_next = ((1.0 - self.HEBB_RATE) * self.hebb
                     + self.HEBB_RATE * outer)

        # NOTE(review): the trace update reads the previous `y`; confirm the
        # backend evaluates it before `y` is overwritten in the same step.
        self.add_update([K.update(self.hebb, hebb_next),
                         K.update(self.y, y_next)], x)
        return yout

    def compute_output_shape(self, input_shape):
        """Output keeps the batch axis and has ``output_dim`` features."""
        return (input_shape[0], self.output_dim)




# Training configuration.
batch_size = 1
num_classes = 10
epochs = 20

# Load MNIST, already split into train and test sets.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten every sample to 784 values, cast to float32 and divide by 255.
x_train = x_train.reshape(60000, 784).astype('float32')
x_test = x_test.reshape(10000, 784).astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# One-hot encode the integer class labels.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Custom layer -> dense projection to one score per class -> softmax.
model = Sequential([
    NormalDensity(784, input_shape=(784,)),
    Dense(num_classes),
    Activation('softmax'),
])

model.summary()

model.compile(
    loss='categorical_crossentropy',
    optimizer=RMSprop(),
    metrics=['accuracy'],
)

# Train, using the test split as validation data after every epoch.
history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=(x_test, y_test),
)

# evaluate() returns the loss followed by the compiled metrics.
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

60000 train samples
10000 test samples
(784, 784)
(784, 784)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
normal_density_1 (NormalDens (None, 784)               2459408   
_________________________________________________________________
dense_1 (Dense)              (None, 10)                7850      
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0         
Total params: 2,467,258
Trainable params: 622,506
Non-trainable params: 1,844,752
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
