<a href="https://colab.research.google.com/github/pchumphreys/Neural/blob/master/FeedbackAlignment_Investivations/FeedbackAlignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Testing feedback alignment approaches

Ok, so what is the plan?

1.) Load the CIFAR10 dataset
2.) Get standard BP working, with a standard FF net
3.) We need a duplicate set of weights for feedback. This could be straightforward if we just duplicate the graph. For FA, basically just swap out the gradient rule so that it uses these feedback weights as opposed to the forward weights. We also need to work out how to apply gradients to the other weights. Maybe instead of having a duplicate net, just duplicate the weight variables? That might be more elegant.
4.) I think that this should be sufficient for the first experiments. What do we want to do?

  a) Test FA vs BP, check matches
  b) Maybe look at the versions from the Ashok paper
  c) Learning FB weights as well
  d) Think carefully about the initialisation issue

### Util code

In [0]:
import copy
from sklearn.utils import shuffle

class Batcher():
    """Serve shuffled mini-batches from a data set, epoch after epoch.

    The inputs are deep-copied and shuffled (with ``shuffle``) on
    construction and again every time an epoch is completed, so the caller's
    arrays are never modified.

    Attributes:
        epoch: 1-based index of the epoch the *next* mini-batch belongs to.
        mini_batch_size: requested number of samples per mini-batch.
        nbatches: number of *full* mini-batches per epoch. When the data set
            size is not a multiple of ``mini_batch_size`` one extra, smaller
            batch is served at the end of each epoch.
        perfect_sample: 1 if the data set splits exactly into full
            mini-batches, 0 otherwise.
        batch: index of the next mini-batch within the current epoch.
    """

    def __init__(self,x_data,y_data,mini_batch_size):
        """Copy and shuffle the data.

        Args:
            x_data: sliceable sequence of samples.
            y_data: sliceable sequence of targets, aligned with ``x_data``.
            mini_batch_size: positive number of samples per mini-batch.

        Raises:
            ValueError: if ``mini_batch_size`` is not positive or the data
                set is empty.
        """
        if mini_batch_size < 1:
            raise ValueError('mini_batch_size must be positive, got {}'.format(mini_batch_size))
        if len(x_data) == 0:
            raise ValueError('Cannot build mini-batches from an empty data set')
        self.epoch = 1
        self.mini_batch_size = mini_batch_size
        self.nbatches = len(x_data) // mini_batch_size
        # Divisibility has to be tested against the batch size (not the batch
        # count) to know whether a smaller trailing batch exists.
        self.perfect_sample = 1 if len(x_data) % mini_batch_size == 0 else 0
        self.samples = copy.deepcopy(x_data)
        self.targets = copy.deepcopy(y_data)
        self.samples, self.targets = shuffle(self.samples,self.targets)
        self.batch = 0

    def get_mini_batch(self):
        """Return the next mini-batch.

        Returns:
            (samples, targets, last_batch) where ``last_batch`` is True when
            this batch was the final one of its epoch. In that case the data
            has already been reshuffled and ``self.epoch`` already points to
            the following epoch when this method returns.
        """
        start = self.batch*self.mini_batch_size
        stop = start + self.mini_batch_size
        # Slicing past the end simply yields the shorter trailing batch.
        samples = self.samples[start:stop]
        targets = self.targets[start:stop]
        last_batch = stop >= len(self.samples)
        if last_batch:
            # ``shuffle`` hands back new containers, so the slices taken above
            # keep referring to the pre-shuffle ordering.
            self.samples, self.targets = shuffle(self.samples,self.targets)
            self.epoch += 1
            # Restart at batch 0 so the first batch of the new epoch is served.
            self.batch = 0
        else:
            self.batch += 1
        return samples,targets,last_batch

## Custom gradient propagation

We need to be able to inject our own gradients!

Here is a code snippet for overriding gradient op:
```
@tf.custom_gradient
def clip_grad_layer(x):
  def grad(dy):
    return tf.clip_by_value(dy, -0.1, 0.1)
  return tf.identity(x), grad
```
So I guess this means we have to make a custom wrapper to define a layer, create the variables etc that will incorporate a custom matmul function.



In [0]:
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
import numpy as np

def feedback_matmul(x,weights,weights_fb,update_fb_weights = False, is_sparse = False):
  '''
  Matrix product ``x @ weights`` whose backward pass uses a separate set of feedback weights.

  Forward pass: identical to ``tf.matmul(x, weights)`` (or ``tf.sparse_matmul`` with
  ``b_is_sparse=True`` when ``is_sparse`` is set).
  Backward pass: the error sent to ``x`` is computed with ``weights_fb`` in place of
  ``weights``; the gradient reported for ``weights`` is the ordinary one.

  Args:
    x: 2-D tensor, ``[batch, n_in]``.
    weights: 2-D forward weights, ``[n_in, n_out]``.
    weights_fb: 2-D feedback weights with the same shape as ``weights``.
    update_fb_weights: if True the gradient computed for ``weights`` is also reported
      for ``weights_fb``; if False no gradient (None) is reported for ``weights_fb``.
    is_sparse: use the sparse-aware matmul kernel for the forward product.

  Returns:
    Tensor ``[batch, n_out]``.

  Note that, unlike ``tf.matmul``, there are no transpose flags: operands are used as given.
  '''

  def _forward_helper(x,weights):
    if is_sparse:
      return tf.sparse_matmul(x, weights, b_is_sparse=True)
    return tf.matmul(x,weights)

  @tf.custom_gradient
  def matmul_function(x,weights,weights_fb):
    def grad(dzdy):
      # dzdy is [batch, n_out]; going back to the input needs an [n_out, n_in]
      # matrix, i.e. the transposed feedback weights.
      input_grad = tf.matmul(dzdy,weights_fb,transpose_b=True)
      # Ordinary matmul weight gradient x^T . dzdy, shaped like `weights`.
      weight_grad = tf.matmul(x,dzdy,transpose_a=True)
      weight_grad_fb = weight_grad if update_fb_weights else None
      # One entry per argument of matmul_function, in the same order.
      return input_grad,weight_grad,weight_grad_fb
    return _forward_helper(x,weights), grad

  return matmul_function(x,weights,weights_fb)


# Here is our fully connected layer!
def feedback_fc(input_tensor,output_size,mode = 'BP',scope=None,reuse = None,weights_initializer = tf.initializers.glorot_normal(),weights_fb_initializer = tf.initializers.glorot_normal(),biases_initializer=init_ops.zeros_initializer(),activation_fn=nn.relu):
  '''
  Fully connected layer ``activation_fn(input_tensor @ weights + biases)`` with a
  selectable backward rule.

  Three model variables are always created inside the variable scope: ``weights`` and
  ``weights_fb`` (both ``[input_tensor.shape[1], output_size]``) and ``biases``.

  Args:
    input_tensor: 2-D tensor; its second dimension is the layer's input size.
    output_size: number of output units.
    mode: 'BP' uses plain ``tf.matmul``; 'FA' uses ``feedback_matmul`` with
      ``weights_fb``; 'FA_update_both' does the same with ``update_fb_weights=True``.
    scope: variable scope name; when None a unique scope based on 'feedback_fc' is opened.
    reuse: forwarded to ``tf.variable_scope``.
    weights_initializer, weights_fb_initializer, biases_initializer: variable initializers.
    activation_fn: callable applied to the pre-activation; None means identity.

  Raises:
    ValueError: if ``mode`` is not one of the three supported values.
  '''
  # Validate first so that a bad mode does not leave stray variables in the graph.
  if mode not in ('BP','FA','FA_update_both'):
    raise ValueError('Undefined mode {}'.format(mode))
  if activation_fn is None:
    activation_fn = tf.identity
  with tf.variable_scope(scope,default_name='feedback_fc',reuse=reuse):
    weight_shape = [input_tensor.shape[1],output_size]
    weights = slim.model_variable('weights',shape = weight_shape,initializer=weights_initializer)
    weights_fb = slim.model_variable('weights_fb',shape = weight_shape,initializer=weights_fb_initializer)
    biases = slim.model_variable('biases',shape = [output_size],initializer=biases_initializer)
    if mode == 'BP':
      pre_activation = tf.matmul(input_tensor,weights)
    else:
      pre_activation = feedback_matmul(input_tensor,weights,weights_fb,update_fb_weights = (mode == 'FA_update_both'))
    return activation_fn(pre_activation + biases)

Run some tests to check that this makes sense

In [0]:
# Sanity check of the custom gradient on a single 1 -> 1 linear layer in 'FA' mode:
# the forward pass must use `weights`, the gradient w.r.t. the input must use `weights_fb`.
tf.reset_default_graph() # THIS IS NECESSARY BEFORE MAKING NEW SESSION TO STOP IT ERRORING!!

# Close a session left over from a previously executed cell, if any.
if not(tf.get_default_session() is None):
    tf.get_default_session().close()

sess = tf.InteractiveSession()

inputs = tf.placeholder(tf.float32,shape=[None, 1],name='inputs')
# activation_fn=None -> identity, so x = inputs * weights + biases.
x = feedback_fc(inputs,1,scope='fc',mode='FA',activation_fn=None)

tf.global_variables_initializer().run()

test_inputs = [[1.0],[2.0]]
forward = sess.run(x,feed_dict = {inputs:test_inputs})

# Look the forward weight up by its graph name; [[0]] keeps it 2-D.
weight = [var for var in tf.global_variables() if var.name == 'fc/weights:0'][0].eval()[[0]]

# Biases default to zeros in feedback_fc, so output / input equals the forward weight.
assert (forward/test_inputs==weight).all()

# d(x)/d(inputs) must equal the feedback weight, not the forward weight.
# NOTE(review): with a 1x1 layer this cannot detect a wrongly oriented (transposed)
# matrix in the gradient; a non-square layer would be a stronger check.
grad = tf.gradients(x,inputs)[0].eval(feed_dict = {inputs:test_inputs})
fb_weight = [var for var in tf.global_variables() if var.name == 'fc/weights_fb:0'][0].eval()[[0]]
assert (grad == fb_weight).all()
        

## MNIST

Start with MNIST and check that it works

In [0]:
# Load MNIST through Keras and convert it for the TensorFlow graphs below.
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Cast the images to float32 and divide by 255.
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255
# Cast the labels to int32.
y_train = y_train.astype('int32')
y_test = y_test.astype('int32')

### Define the network to replicate DFA paper - arxiv 1609.01596

#### Test BP with RELU

In [0]:
# Build the MNIST classifier graph in plain back-propagation mode.
tf.reset_default_graph() # THIS IS NECESSARY BEFORE MAKING NEW SESSION TO STOP IT ERRORING!!

# Close a session left over from a previously executed cell, if any.
if not(tf.get_default_session() is None):
    tf.get_default_session().close()

sess = tf.InteractiveSession()

# 'BP': feedback_fc uses an ordinary tf.matmul, so gradients are the standard ones.
network_mode = 'BP'

inputs = tf.placeholder(tf.float32,shape=[None, 28,28],name='inputs')
# Flatten each 28x28 image into a 784-vector.
reshaped_inputs = tf.reshape(inputs,[-1,28*28])
# Fed as True while training and False while evaluating; controls dropout.
training = tf.placeholder(tf.bool)

n_categories = 10 

dropout_rate = 0.0
# Sizes of the hidden layers.
layers = [800,800]
activation = tf.nn.relu

x = reshaped_inputs
n_prev = x.shape[1].value

for i in range(len(layers)):
  if i == len(layers)-1:
    n_next = n_categories
  else:
    n_next = layers[i+1]
    
  # Uniform initialisation in +-1/sqrt(n_prev), n_prev being this layer's input size.
  weights_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_prev),maxval = 1/np.sqrt(n_prev))
  # NOTE(review): weights_fb_initializer is built but never passed to feedback_fc below,
  # so weights_fb is created with feedback_fc's default initializer - confirm intent.
  weights_fb_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_next),maxval = 1/np.sqrt(n_next))
  biases_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_prev),maxval = 1/np.sqrt(n_prev))

  x = feedback_fc(x,layers[i],scope='fc_' + str(i),mode=network_mode,activation_fn=activation,
                 weights_initializer = weights_initializer, biases_initializer = biases_initializer)
  x = tf.layers.dropout(x,rate = dropout_rate, training = training)
  
  n_prev = x.shape[1].value
  
# Linear read-out layer producing one logit per category (default initializers).
logits = feedback_fc(x,n_categories,scope='soft',mode=network_mode,activation_fn=None)

targets = tf.placeholder(tf.int64,shape=[None],name='targets')
# Mean softmax cross-entropy over the batch.
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.squeeze(targets),logits = logits),axis=-1)
# Fraction of samples whose arg-max logit differs from the target.
error = tf.reduce_mean(tf.cast(tf.logical_not(tf.equal(tf.argmax(logits,axis=-1),tf.squeeze(targets))),'float'))
train_op = tf.train.RMSPropOptimizer(learning_rate = 1e-3).minimize(loss)

tf.summary.scalar('loss',loss)
tf.summary.scalar('error',error)

merged = tf.summary.merge_all()
tf.global_variables_initializer().run()

Run

In [16]:
# Train the graph built in the previous cell, reporting errors every `test_epoch_int` epochs.
mini_batch_size = 64

b = Batcher(x_train,y_train,mini_batch_size)
max_epochs = 100
test_epoch_int = 10

print("Batches per epoch %d" %b.nbatches)
while b.epoch <= max_epochs:
    # new_epoch is the Batcher's last-batch flag: True for the batch that closes an epoch.
    batch_x,batch_y, new_epoch = b.get_mini_batch()
    sess.run([train_op],feed_dict = {inputs : batch_x, targets:batch_y, training : True})
    if new_epoch:
        # The reported train error is measured on this final mini-batch only.
        batch_err,summary = sess.run([error,merged],feed_dict = {inputs : batch_x, targets:batch_y, training : False})
#         train_writer.add_summary(summary)
        # The Batcher has already advanced its counter, so the finished epoch is b.epoch-1.
        if (b.epoch-1) % test_epoch_int == 0:
          print("Train error is {}% at epoch {}".format(batch_err*100,b.epoch-1))
          # Test error is computed on the whole test set in a single run.
          test_err,summary = sess.run([error,merged],feed_dict = {inputs : x_test, targets : y_test, training : False})
          print("Test error is {}% at epoch {}".format(test_err*100,b.epoch-1))

Batches per epoch 937
Train error is 0.0% at epoch 10
Test error is 1.8400000408291817% at epoch 10
Train error is 0.0% at epoch 20
Test error is 1.9200000911951065% at epoch 20
Train error is 0.0% at epoch 30
Test error is 1.8799999728798866% at epoch 30
Train error is 0.0% at epoch 40
Test error is 1.8400000408291817% at epoch 40
Train error is 0.0% at epoch 50
Test error is 1.5799999237060547% at epoch 50
Train error is 0.0% at epoch 60
Test error is 1.6200000420212746% at epoch 60
Train error is 0.0% at epoch 70
Test error is 1.489999983459711% at epoch 70


KeyboardInterrupt: ignored

```
Batches per epoch 937
Train error is 0.0% at epoch 10
Test error is 2.15000007301569% at epoch 10
Train error is 0.0% at epoch 20
Test error is 1.7100000753998756% at epoch 20
Train error is 0.0% at epoch 30
Test error is 1.7999999225139618% at epoch 30
Train error is 0.0% at epoch 40
Test error is 1.510000042617321% at epoch 40
Train error is 0.0% at epoch 50
Test error is 1.5300000086426735% at epoch 50
Train error is 0.0% at epoch 60
Test error is 1.489999983459711% at epoch 60
Train error is 0.0% at epoch 70
Test error is 1.4800000004470348% at epoch 70
Train error is 0.0% at epoch 80
Test error is 1.4800000004470348% at epoch 80
```

#### Test feedback alignment

In [0]:
# Build the MNIST classifier graph in feedback-alignment mode.
tf.reset_default_graph() # THIS IS NECESSARY BEFORE MAKING NEW SESSION TO STOP IT ERRORING!!

# Close a session left over from a previously executed cell, if any.
if not(tf.get_default_session() is None):
    tf.get_default_session().close()

sess = tf.InteractiveSession()

# 'FA': feedback_fc routes every layer through feedback_matmul, whose backward pass
# uses weights_fb; no gradient is reported for weights_fb itself.
network_mode = 'FA'

inputs = tf.placeholder(tf.float32,shape=[None, 28,28],name='inputs')
# Flatten each 28x28 image into a 784-vector.
reshaped_inputs = tf.reshape(inputs,[-1,28*28])
# Fed as True while training and False while evaluating; controls dropout.
training = tf.placeholder(tf.bool)

n_categories = 10 

dropout_rate = 0.0
# Sizes of the hidden layers.
layers = [800,800]
activation = tf.nn.relu

x = reshaped_inputs
n_prev = x.shape[1].value

for i in range(len(layers)):
  if i == len(layers)-1:
    n_next = n_categories
  else:
    n_next = layers[i+1]
    
  # Uniform initialisation in +-1/sqrt(n_prev), n_prev being this layer's input size.
  weights_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_prev),maxval = 1/np.sqrt(n_prev))
  # NOTE(review): weights_fb_initializer is built but never passed to feedback_fc below,
  # so the feedback weights are created with feedback_fc's default initializer - confirm intent.
  weights_fb_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_next),maxval = 1/np.sqrt(n_next))
  biases_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_prev),maxval = 1/np.sqrt(n_prev))

  x = feedback_fc(x,layers[i],scope='fc_' + str(i),mode=network_mode,activation_fn=activation,
                 weights_initializer = weights_initializer, biases_initializer = biases_initializer)
  x = tf.layers.dropout(x,rate = dropout_rate, training = training)
  
  n_prev = x.shape[1].value
  
# Linear read-out layer producing one logit per category (default initializers).
logits = feedback_fc(x,n_categories,scope='soft',mode=network_mode,activation_fn=None)

targets = tf.placeholder(tf.int64,shape=[None],name='targets')
# Mean softmax cross-entropy over the batch.
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.squeeze(targets),logits = logits),axis=-1)
# Fraction of samples whose arg-max logit differs from the target.
error = tf.reduce_mean(tf.cast(tf.logical_not(tf.equal(tf.argmax(logits,axis=-1),tf.squeeze(targets))),'float'))
# Learning rate is 1e-4 here.
train_op = tf.train.RMSPropOptimizer(learning_rate = 1e-4).minimize(loss)

tf.summary.scalar('loss',loss)
tf.summary.scalar('error',error)

merged = tf.summary.merge_all()
tf.global_variables_initializer().run()

Run

In [104]:
# Train the feedback-alignment graph, reporting errors every `test_epoch_int` epochs.
mini_batch_size = 64

b = Batcher(x_train,y_train,mini_batch_size)
max_epochs = 100
test_epoch_int = 10

print("Batches per epoch %d" %b.nbatches)
while b.epoch <= max_epochs:
    # new_epoch is the Batcher's last-batch flag: True for the batch that closes an epoch.
    batch_x,batch_y, new_epoch = b.get_mini_batch()
    sess.run([train_op],feed_dict = {inputs : batch_x, targets:batch_y, training : True})
    if new_epoch:
        # The reported train error is measured on this final mini-batch only.
        batch_err,summary = sess.run([error,merged],feed_dict = {inputs : batch_x, targets:batch_y, training : False})
#         train_writer.add_summary(summary)
        # The Batcher has already advanced its counter, so the finished epoch is b.epoch-1.
        if (b.epoch-1) % test_epoch_int == 0:
          print("Train error is {}% at epoch {}".format(batch_err*100,b.epoch-1))
          # Test error is computed on the whole test set in a single run.
          test_err,summary = sess.run([error,merged],feed_dict = {inputs : x_test, targets : y_test, training : False})
          print("Test error is {}% at epoch {}".format(test_err*100,b.epoch-1))

Batches per epoch 937
Train error is 3.125% at epoch 10
Test error is 2.4399999529123306% at epoch 10
Train error is 0.0% at epoch 20
Test error is 1.9200000911951065% at epoch 20
Train error is 0.0% at epoch 30
Test error is 1.7999999225139618% at epoch 30
Train error is 0.0% at epoch 40
Test error is 1.9600000232458115% at epoch 40
Train error is 0.0% at epoch 50
Test error is 1.7999999225139618% at epoch 50
Train error is 0.0% at epoch 60
Test error is 1.7799999564886093% at epoch 60
Train error is 0.0% at epoch 70
Test error is 1.7400000244379044% at epoch 70
Train error is 0.0% at epoch 80
Test error is 1.8200000748038292% at epoch 80
Train error is 0.0% at epoch 90
Test error is 1.7599999904632568% at epoch 90
Train error is 0.0% at epoch 100
Test error is 1.8400000408291817% at epoch 100
Train error is 0.0% at epoch 110
Test error is 1.769999973475933% at epoch 110
Train error is 0.0% at epoch 120
Test error is 1.7799999564886093% at epoch 120
Train error is 0.0% at epoch 130
Te

### Test FA, both learn

In [0]:
# Build the MNIST classifier graph in feedback-alignment mode with learned feedback weights.
tf.reset_default_graph() # THIS IS NECESSARY BEFORE MAKING NEW SESSION TO STOP IT ERRORING!!

# Close a session left over from a previously executed cell, if any.
if not(tf.get_default_session() is None):
    tf.get_default_session().close()

sess = tf.InteractiveSession()

# 'FA_update_both': like 'FA', but feedback_matmul is called with update_fb_weights=True,
# so the weight gradient is reported for weights_fb as well.
network_mode = 'FA_update_both'

inputs = tf.placeholder(tf.float32,shape=[None, 28,28],name='inputs')
# Flatten each 28x28 image into a 784-vector.
reshaped_inputs = tf.reshape(inputs,[-1,28*28])
# Fed as True while training and False while evaluating; controls dropout.
training = tf.placeholder(tf.bool)

n_categories = 10 

dropout_rate = 0.0
# Sizes of the hidden layers.
layers = [800,800]
activation = tf.nn.relu

x = reshaped_inputs
n_prev = x.shape[1].value

for i in range(len(layers)):
  if i == len(layers)-1:
    n_next = n_categories
  else:
    n_next = layers[i+1]
    
  # Uniform initialisation in +-1/sqrt(n_prev), n_prev being this layer's input size.
  weights_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_prev),maxval = 1/np.sqrt(n_prev))
  # NOTE(review): weights_fb_initializer is built but never passed to feedback_fc below,
  # so the feedback weights are created with feedback_fc's default initializer - confirm intent.
  weights_fb_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_next),maxval = 1/np.sqrt(n_next))
  biases_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_prev),maxval = 1/np.sqrt(n_prev))

  x = feedback_fc(x,layers[i],scope='fc_' + str(i),mode=network_mode,activation_fn=activation,
                 weights_initializer = weights_initializer, biases_initializer = biases_initializer)
  x = tf.layers.dropout(x,rate = dropout_rate, training = training)
  
  n_prev = x.shape[1].value
  
# Linear read-out layer producing one logit per category (default initializers).
logits = feedback_fc(x,n_categories,scope='soft',mode=network_mode,activation_fn=None)

targets = tf.placeholder(tf.int64,shape=[None],name='targets')
# Mean softmax cross-entropy over the batch.
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.squeeze(targets),logits = logits),axis=-1)
# Fraction of samples whose arg-max logit differs from the target.
error = tf.reduce_mean(tf.cast(tf.logical_not(tf.equal(tf.argmax(logits,axis=-1),tf.squeeze(targets))),'float'))
# Learning rate is 1e-4 here.
train_op = tf.train.RMSPropOptimizer(learning_rate = 1e-4).minimize(loss)

tf.summary.scalar('loss',loss)
tf.summary.scalar('error',error)

merged = tf.summary.merge_all()
tf.global_variables_initializer().run()

Run

In [115]:
# Train the 'FA_update_both' graph, reporting errors every `test_epoch_int` epochs.
mini_batch_size = 64

b = Batcher(x_train,y_train,mini_batch_size)
max_epochs = 100
test_epoch_int = 10

print("Batches per epoch %d" %b.nbatches)
while b.epoch <= max_epochs:
    # new_epoch is the Batcher's last-batch flag: True for the batch that closes an epoch.
    batch_x,batch_y, new_epoch = b.get_mini_batch()
    sess.run([train_op],feed_dict = {inputs : batch_x, targets:batch_y, training : True})
    if new_epoch:
        # The reported train error is measured on this final mini-batch only.
        batch_err,summary = sess.run([error,merged],feed_dict = {inputs : batch_x, targets:batch_y, training : False})
#         train_writer.add_summary(summary)
        # The Batcher has already advanced its counter, so the finished epoch is b.epoch-1.
        if (b.epoch-1) % test_epoch_int == 0:
          print("Train error is {}% at epoch {}".format(batch_err*100,b.epoch-1))
          # Test error is computed on the whole test set in a single run.
          test_err,summary = sess.run([error,merged],feed_dict = {inputs : x_test, targets : y_test, training : False})
          print("Test error is {}% at epoch {}".format(test_err*100,b.epoch-1))

Batches per epoch 937
Train error is 0.0% at epoch 10
Test error is 2.319999970495701% at epoch 10
Train error is 0.0% at epoch 20
Test error is 1.8899999558925629% at epoch 20
Train error is 0.0% at epoch 30
Test error is 1.9099999219179153% at epoch 30
Train error is 0.0% at epoch 40
Test error is 1.8400000408291817% at epoch 40
Train error is 0.0% at epoch 50
Test error is 1.8600000068545341% at epoch 50
Train error is 0.0% at epoch 60
Test error is 1.8400000408291817% at epoch 60


KeyboardInterrupt: ignored

## CIFAR10

### Load in the data

In [0]:
# Load CIFAR10 through Keras; this rebinds x_train / y_train / x_test / y_test.
from keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Cast the images to float32 and divide by 255.
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255
# Cast the labels to int32.
y_train = y_train.astype('int32')
y_test = y_test.astype('int32')

### Vanilla BP

In [0]:
# Build a fully connected CIFAR10 classifier graph in plain back-propagation mode.
tf.reset_default_graph() # THIS IS NECESSARY BEFORE MAKING NEW SESSION TO STOP IT ERRORING!!

# Close a session left over from a previously executed cell, if any.
if not(tf.get_default_session() is None):
    tf.get_default_session().close()

sess = tf.InteractiveSession()

# 'BP': feedback_fc uses an ordinary tf.matmul, so gradients are the standard ones.
network_mode = 'BP'

inputs = tf.placeholder(tf.float32,shape=[None, 32,32,3],name='inputs')
# Flatten each 32x32x3 image into a 3072-vector.
reshaped_inputs = tf.reshape(inputs,[-1,32*32*3])
# Fed as True while training and False while evaluating; controls dropout.
training = tf.placeholder(tf.bool)

n_categories = 10 

dropout_rate = 0.1
# Sizes of the hidden layers.
layers = [1024,1024,1024]
activation = tf.nn.tanh

x = reshaped_inputs

# Hidden layers use feedback_fc's default initializers here.
for i in range(len(layers)):
#   weights_fb_initializer = tf.initializers.random_uniform(minval = -1/np.sqrt(n_next),maxval = 1/np.sqrt(n_next))
 
  x = feedback_fc(x,layers[i],scope='fc_' + str(i),mode=network_mode,activation_fn=activation)
  x = tf.layers.dropout(x,rate = dropout_rate, training = training)
  
# Linear read-out layer producing one logit per category.
logits = feedback_fc(x,n_categories,scope='soft',mode=network_mode,activation_fn=None)

# Targets are fed as a column ([None, 1]); tf.squeeze removes the size-1 axis below.
# NOTE(review): tf.squeeze without an axis would presumably also drop the batch axis
# for a batch of exactly one sample - confirm this never occurs.
targets = tf.placeholder(tf.int64,shape=[None,1],name='targets')
# Mean softmax cross-entropy over the batch.
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.squeeze(targets),logits = logits),axis=-1)
# Fraction of samples whose arg-max logit differs from the target.
error = tf.reduce_mean(tf.cast(tf.logical_not(tf.equal(tf.argmax(logits,axis=-1),tf.squeeze(targets))),'float'))
train_op = tf.train.AdamOptimizer(learning_rate = 2e-5).minimize(loss)

tf.summary.scalar('loss',loss)
tf.summary.scalar('error',error)

merged = tf.summary.merge_all()
tf.global_variables_initializer().run()

Run

In [129]:
# Train the CIFAR10 graph, reporting errors every `test_epoch_int` epochs.
mini_batch_size = 64

b = Batcher(x_train,y_train,mini_batch_size)
max_epochs = 300
test_epoch_int = 10

print("Batches per epoch %d" %b.nbatches)
while b.epoch <= max_epochs:
    # new_epoch is the Batcher's last-batch flag: True for the batch that closes an epoch.
    batch_x,batch_y, new_epoch = b.get_mini_batch()
    sess.run([train_op],feed_dict = {inputs : batch_x, targets:batch_y, training : True})
    if new_epoch:
        # The reported train error is measured on this final mini-batch only.
        batch_err,summary = sess.run([error,merged],feed_dict = {inputs : batch_x, targets:batch_y, training : False})
#         train_writer.add_summary(summary)
        # The Batcher has already advanced its counter, so the finished epoch is b.epoch-1.
        if (b.epoch-1) % test_epoch_int == 0:
          print("Train error is {}% at epoch {}".format(batch_err*100,b.epoch-1))
          # Test error is computed on the whole test set in a single run.
          test_err,summary = sess.run([error,merged],feed_dict = {inputs : x_test, targets : y_test, training : False})
          print("Test error is {}% at epoch {}".format(test_err*100,b.epoch-1))

Batches per epoch 781
Train error is 50.0% at epoch 10
Test error is 53.039997816085815% at epoch 10
Train error is 37.5% at epoch 20
Test error is 50.24999976158142% at epoch 20
Train error is 31.25% at epoch 30
Test error is 47.65999913215637% at epoch 30
Train error is 50.0% at epoch 40
Test error is 46.93000018596649% at epoch 40
Train error is 25.0% at epoch 50
Test error is 45.100000500679016% at epoch 50
Train error is 25.0% at epoch 60
Test error is 44.65000033378601% at epoch 60
Train error is 31.25% at epoch 70
Test error is 43.84999871253967% at epoch 70
Train error is 31.25% at epoch 80
Test error is 44.200000166893005% at epoch 80
Train error is 25.0% at epoch 90
Test error is 43.25999915599823% at epoch 90
Train error is 31.25% at epoch 100
Test error is 42.71000027656555% at epoch 100
Train error is 12.5% at epoch 110
Test error is 42.809998989105225% at epoch 110
Train error is 12.5% at epoch 120
Test error is 42.64000058174133% at epoch 120


KeyboardInterrupt: ignored

```
Batches per epoch 781
Train error is 50.0% at epoch 10
Test error is 53.039997816085815% at epoch 10
Train error is 37.5% at epoch 20
Test error is 50.24999976158142% at epoch 20
Train error is 31.25% at epoch 30
Test error is 47.65999913215637% at epoch 30
Train error is 50.0% at epoch 40
Test error is 46.93000018596649% at epoch 40
Train error is 25.0% at epoch 50
Test error is 45.100000500679016% at epoch 50
Train error is 25.0% at epoch 60
Test error is 44.65000033378601% at epoch 60
Train error is 31.25% at epoch 70
Test error is 43.84999871253967% at epoch 70
Train error is 31.25% at epoch 80
Test error is 44.200000166893005% at epoch 80
Train error is 25.0% at epoch 90
Test error is 43.25999915599823% at epoch 90
Train error is 31.25% at epoch 100
Test error is 42.71000027656555% at epoch 100
Train error is 12.5% at epoch 110
Test error is 42.809998989105225% at epoch 110
Train error is 12.5% at epoch 120
Test error is 42.64000058174133% at epoch 120```

## Locally connected..

In [19]:
# Bare expression: as the last line of a notebook cell it displays the class and its defining module.
tf.keras.layers.LocallyConnected2D

tensorflow.python.keras.layers.local.LocallyConnected2D

In [0]:
class LocallyConnected2D(tf.keras.layers.LocallyConnected2D):
  '''
  Keras LocallyConnected2D layer with an extra feedback kernel.

  ``build`` creates ``kernel_fb`` next to the parent's ``kernel`` (same shape, initializer,
  regularizer and constraint) and ``call`` computes the output through
  ``local_conv_matmul_fb`` so that both kernels are handed to the feedback matmul.
  Only ``implementation == 2`` is supported.
  '''

  def build(self, input_shape):
    # Let the parent create kernel, bias, kernel_shape, etc. first.
    super(LocallyConnected2D, self).build(input_shape)

    # Registered under its own name so it does not clash with the forward 'kernel' weight.
    self.kernel_fb = self.add_weight(shape=self.kernel_shape,
                                     initializer=self.kernel_initializer,
                                     name='kernel_fb',
                                     regularizer=self.kernel_regularizer,
                                     constraint=self.kernel_constraint)

  def call(self, inputs):
    if self.implementation != 2:
      raise ValueError('Unsupported implementation mode: %d.'
                       % self.implementation)

    output = local_conv_matmul_fb(inputs, self.kernel, self.kernel_fb,
                                  self.kernel_mask,
                                  self.compute_output_shape(inputs.shape))

    if self.use_bias:
      # Use the backend via `tf`, which this file imports, rather than a bare `K` alias.
      output = tf.keras.backend.bias_add(output, self.bias, data_format=self.data_format)

    return self.activation(output)
  
def local_conv_matmul_fb(inputs, kernel, kernel_fb, kernel_mask, output_shape):
  '''
  Locally connected forward pass expressed as one masked matrix product, computed
  through ``feedback_matmul`` so that the backward pass can use ``kernel_fb``.

  Args:
    inputs: input tensor; everything but the first (batch) axis is flattened.
    kernel: forward kernel, multiplied element-wise by ``kernel_mask`` and reshaped to 2-D.
    kernel_fb: feedback kernel with the same shape as ``kernel``; it is masked and
      reshaped in exactly the same way so both matrices have identical 2-D shapes.
    kernel_mask: mask multiplied element-wise into both kernels.
    output_shape: TensorShape of the layer output; its non-batch dimensions are used
      to un-flatten the result.

  Returns:
    Output tensor reshaped to ``[batch] + output_shape[1:]``.
  '''
  # Bind the helpers locally so the function does not rely on notebook-level aliases.
  from tensorflow.python.keras.layers.local import make_2d
  K = tf.keras.backend

  inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1))

  # The first half of the kernel axes index the input, the second half the output.
  split_dim = K.ndim(kernel) // 2
  kernel = make_2d(kernel_mask * kernel, split_dim=split_dim)
  # NOTE(review): masking the feedback kernel restricts feedback to the same local
  # connections as the forward pass - confirm this is the intended variant.
  kernel_fb = make_2d(kernel_mask * kernel_fb, split_dim=split_dim)

  output_flat = feedback_matmul(inputs_flat,kernel,kernel_fb,update_fb_weights = False, is_sparse = True)

  output = K.reshape(output_flat,
                     [K.shape(output_flat)[0],] + output_shape.as_list()[1:])
  return output



