In [1]:

'''Trains a simple deep NN on the MNIST dataset.
Gets to 98.40% test accuracy after 20 epochs
(there is *a lot* of margin for parameter tuning).
2 seconds per epoch on a K520 GPU.
'''

from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import RMSprop


from keras import backend as K
from keras import layers
import numpy as np
import tensorflow as tf

class NormalDensity(layers.Layer):

    def __init__(self, output_dim, batch_size, **kwargs):
        self.output_dim = output_dim
        self.batch_size = batch_size
        super(NormalDensity, self).__init__(**kwargs)
        
    def build(self, input_shape):
        
        #Define set of traditional weights
        self.w = self.add_weight(name='w', 
                                 shape=(input_shape[1], self.output_dim),
                                 initializer='uniform',
                                 trainable=True)
        
        #Define our plasticity coefficient
        self.alpha = self.add_weight(name='alpha', 
                              shape=(1, 1),
                              initializer='uniform',
                              trainable=False)
        
        
        #The Hebbian trace
        self.hebb = self.add_weight(name='hebb', 
                              shape=(input_shape[1], self.output_dim),
                              initializer='zeros',
                              trainable=False)
        
        #Step size will be optimized
        self.eta = self.add_weight(name='eta', 
                                      shape=(1, 1),
                                      initializer='uniform',
                                      trainable=False)
        super(NormalDensity, self).build(input_shape)

        
        #yout = F.tanh( yin.mm(self.w + torch.mul(self.alpha, hebb)) + input )
        #hebb = (1 - self.eta) * hebb + self.eta * torch.bmm(yin.unsqueeze(2), yout.unsqueeze(1))[0] # bmm here is used to implement an outer product between yin and yout, with the help of unsqueeze (i.e. added empty dimensions)
        #return yout, hebb

    def call(self, x):
        
        #X (layer input)     : shape(?, INPUT_DIM)
        #W                   : shape(INPUT_DIM, OUTPUT_DIM)
        #hebb                : shape(INPUT_DIM, OUTPUT_DIM)
        #Y (layer output)    : shape(?, OUTPUT_DIM)
        #ETA                 : scalar (one per layer)
        
        #yout = K.maximum(0.0, np.add((K.dot(self.y, np.add(K.dot(self.alpha, K.transpose(self.hebb)), self.w))), x))
        #hebb = (1 - 0.01) * self.hebb + 0.01 * K.dot(self.y, yout)
        #yout = K.maximum(0.0, np.add(self.y * np.add(self.alpha * self.hebb, self.w), x))
        
        y = K.dot(x, self.w)
        plastic_y = self.alpha * (K.dot(x, self.hebb))   
        model_out = K.maximum(0.0, y + plastic_y)
        
        #Hebbian update - option 1
        self.hebb = self.eta * K.dot(x, model_out) + (1 - self.eta) * self.hebb
        print(self.hebb)
        
        #Hebbian update - option 2
        #self.hebb +=self.eta * K.dot(model_out, (x - (K.dot(model_out, self.hebb))))

        return model_out
        #return K.maximum(0.0, y)
    
    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)




batch_size = 128
num_classes = 10
epochs = 16

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255.0
x_test /= 255.0
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(NormalDensity(50, batch_size, input_shape=(784,)))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

model.summary()

model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])

history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


60000 train samples
10000 test samples
Tensor("normal_density_1/add_1:0", shape=(784, 50), dtype=float32)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
normal_density_1 (NormalDens (None, 50)                78402     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                510       
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0         
Total params: 78,912
Trainable params: 39,710
Non-trainable params: 39,202
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/16


InternalError: Blas GEMM launch failed : a.shape=(128, 784), b.shape=(784, 50), m=128, n=50, k=784
	 [[Node: normal_density_1/MatMul_1 = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"](_arg_normal_density_1_input_0_2/_27, normal_density_1/hebb/read)]]
	 [[Node: training/RMSprop/gradients/AddN/_65 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_192_training/RMSprop/gradients/AddN", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'normal_density_1/MatMul_1', defined at:
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\ipykernel\kernelapp.py", line 486, in start
    self.io_loop.start()
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\tornado\platform\asyncio.py", line 127, in start
    self.asyncio_loop.run_forever()
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\asyncio\base_events.py", line 422, in run_forever
    self._run_once()
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\asyncio\base_events.py", line 1432, in _run_once
    handle._run()
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\asyncio\events.py", line 145, in _run
    self._callback(*self._args)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\tornado\platform\asyncio.py", line 117, in _handle_events
    handler_func(fileobj, events)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\tornado\stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\zmq\eventloop\zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\zmq\eventloop\zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\zmq\eventloop\zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\tornado\stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\ipykernel\kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\ipykernel\ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\ipykernel\zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\IPython\core\interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\IPython\core\interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\IPython\core\interactiveshell.py", line 2903, in run_ast_nodes
    if self.run_code(code, result):
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\IPython\core\interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-830b41d4ec97>", line 115, in <module>
    model.add(NormalDensity(50, batch_size, input_shape=(784,)))
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\keras\models.py", line 497, in add
    layer(x)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\keras\engine\topology.py", line 619, in __call__
    output = self.call(inputs, **kwargs)
  File "<ipython-input-1-830b41d4ec97>", line 75, in call
    plastic_y = self.alpha * (K.dot(x, self.hebb))
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\keras\backend\tensorflow_backend.py", line 1076, in dot
    out = tf.matmul(x, y)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\ops\math_ops.py", line 2122, in matmul
    a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 4567, in mat_mul
    name=name)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\framework\ops.py", line 3392, in create_op
    op_def=op_def)
  File "c:\users\bjarn\appdata\local\programs\python\python36\lib\site-packages\tensorflow\python\framework\ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InternalError (see above for traceback): Blas GEMM launch failed : a.shape=(128, 784), b.shape=(784, 50), m=128, n=50, k=784
	 [[Node: normal_density_1/MatMul_1 = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"](_arg_normal_density_1_input_0_2/_27, normal_density_1/hebb/read)]]
	 [[Node: training/RMSprop/gradients/AddN/_65 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_192_training/RMSprop/gradients/AddN", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
