In [None]:
import tensorflow as tf
import numpy as np
from keras.layers import Attention, Input, Conv2D, Flatten, Dense, Multiply, Activation, Lambda, Permute, RepeatVector, merge
from keras.models import Model
from keras.optimizers import Adam
from keras.losses import MeanSquaredError
from keras import backend as K

In [47]:
img_batch = np.array([[
                [1,1,1,1],
                [1,5,5,1],
                [1,1,1,1]
                ],
                [
                [1,1,1,1],
                [1,5,5,1],
                [1,1,1,1]
                ]]
            )
img_batch = img_batch.reshape((2,3,4,1))
print(img_batch)

[[[[1]
   [1]
   [1]
   [1]]

  [[1]
   [5]
   [5]
   [1]]

  [[1]
   [1]
   [1]
   [1]]]


 [[[1]
   [1]
   [1]
   [1]]

  [[1]
   [5]
   [5]
   [1]]

  [[1]
   [1]
   [1]
   [1]]]]


### Self-Attention based implementation with Attention() layer

In [7]:
inp = Input(shape=(3,4,1))

query = Conv2D(filters=1, kernel_size=1, use_bias=False)(inp) #filters = in_dimension//8
key = Conv2D(filters=1, kernel_size=1, use_bias=False)(inp) #filters = in_dimension//8 ==> TRANSPOSE per Torrado et al. 2019? --> transpose seems not necessary as the matmul() operation within Attention() appears to transpose automatically. Paper seems to transpose QUERY rather than KEY!!
value = Conv2D(filters=1, kernel_size=1, use_bias=False)(inp)

out = Attention(use_scale=True)([query, key, value]) #scale adds a learnable parameter applied to the attention scores
out = Activation("sigmoid")(out) #necessary? att 1x1 conv?
test = Model(inp, out)

test.summary()

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 3, 4, 1)]    0                                            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 3, 4, 1)      2           input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 3, 4, 1)      2           input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 3, 4, 1)      2           input_3[0][0]                    
_______________________________________________________________________________________

### Implementation Closer to Zhang et al. 2019

In [19]:
class LinScaler(tf.keras.layers.Layer):
    def __init__(self):
        super(LinScaler, self).__init__()
        self.scale = tf.Variable(trainable=True, name='AttentionMap_ScaleFactor', initial_value=0.)
        
    def call(self, inputs):
        att_map = inputs[0]
        att_input = inputs[1]
        return self.scale * att_map + att_input

inp = Input(shape=(3,4,1))

#linear combinations implemented as convolutional layer (input to self-attention layer)
query = Conv2D(filters=inp.shape[-1]/8, kernel_size=1)(inp) #filters = in_dimension//8
key = Conv2D(filters=inp.shape[-1]/8, kernel_size=1)(inp) #filters = in_dimension//8 ==> TRANSPOSE per Torrado et al. 2019? --> transpose seems not necessary as the matmul() operation within Attention() appears to transpose the second parameters automatically. Paper seems to transpose QUERY rather than KEY!!
value = Conv2D(filters=inp.shape[-1], kernel_size=1)(inp)

#scale necessary? not necessarily equal to scale parameter in Zhang et al.?
out = Attention(use_scale=True)([key, query, value]) #switched key + query since the second parameter gets transposed (query in Torrado et al.)
out = Conv2D(filters=out.shape[-1], kernel_size=1)(out) #output of Attention layer

#scaling and adding initial feature map
#scale = tf.Variable(trainable=True, name='AttentionMap_ScaleFactor', initial_value=0.) #this variable does not appear to be tracked --> create custom layer by subclassing tf.keras.layers.Layer: https://keras.io/api/layers/core_layers/lambda/
#out = scale * out + inp
out = LinScaler()([out, inp])

test = Model(inp, out)

test.summary()


Model: "functional_17"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_13 (InputLayer)           [(None, 3, 4, 1)]    0                                            
__________________________________________________________________________________________________
conv2d_46 (Conv2D)              (None, 3, 4, 1)      2           input_13[0][0]                   
__________________________________________________________________________________________________
conv2d_45 (Conv2D)              (None, 3, 4, 1)      2           input_13[0][0]                   
__________________________________________________________________________________________________
conv2d_47 (Conv2D)              (None, 3, 4, 1)      2           input_13[0][0]                   
______________________________________________________________________________________

## ABOVE DOES NOT WORK DUE TO TRANSPOSE / MATMUL NOT BEING COMMUTATIVE

In [None]:
# custom layer implementation self-attention
class SelfAttention(tf.keras.layers.Layer):
    def __init__(self, dim):
        super(LinScaler, self).__init__()
        self.dim = dim
        self.k = 8
        self.scale = tf.Variable(trainable=True, name='AttentionMap_ScaleFactor', initial_value=0.)
        
        # linear combination of input to query f(x), key g(x) and value h(x)
        self.query = Conv2D(filters=self.dim[-1]/self.k, kernel_size=1)(inputs) #filters = in_dimension//8
        self.key = Conv2D(filters=self.dim[-1]/self.k, kernel_size=1)(inputs) #filters = in_dimension//8 ==> TRANSPOSE per Torrado et al. 2019? --> transpose seems not necessary as the matmul() operation within Attention() appears to transpose the second parameters automatically. Paper seems to transpose QUERY rather than KEY!!
        self.value = Conv2D(filters=self.dim[-1], kernel_size=1)(inputs)
        
    def call(self, inputs):
               
        # matmul transposed query with key
        p_query = tf.transpose(self.query, perm=[0,2,1]))
        attention = K.batch_dot(p_query, self.key)
        attention = Activation('sigmoid')(attention)
        
        out = K.batch_dot(self.value, tf.transpose(attention, perm=[0,2,1]))
        
        return self.scale * out + self.value


### multi-headed attention with query, key and value are the same
== self-attent see https://keras.io/api/layers/attention_layers/multi_head_attention/

--> only available in tf 2.4.1!

In [109]:
test(img_batch)

<tf.Tensor: shape=(2, 3, 4, 1), dtype=float32, numpy=
array([[[[0.18451074],
         [0.18451074],
         [0.18451074],
         [0.18451074]],

        [[0.00062261],
         [0.0005926 ],
         [0.0005926 ],
         [0.00062261]],

        [[0.18451074],
         [0.18451074],
         [0.18451074],
         [0.18451074]]],


       [[[0.18451074],
         [0.18451074],
         [0.18451074],
         [0.18451074]],

        [[0.00062261],
         [0.0005926 ],
         [0.0005926 ],
         [0.00062261]],

        [[0.18451074],
         [0.18451074],
         [0.18451074],
         [0.18451074]]]], dtype=float32)>

In [84]:
# checking against manual implementation
input1 = Input(shape=(3,4,1))
e=Dense(1, activation='tanh')(input1)
# Now do all the softmax business taking the above o/p
e=Flatten()(e)
a=Activation('softmax')(e)
temp=RepeatVector(256)(a)
print(temp)
temp=Permute([2,1])(temp)
print(temp)
# multiply weight with lstm layer o/p
output = merge.Multiply()([input1, temp])
# Get the attention adjusted output state
output = Lambda(lambda values: K.sum(values, axis=1))(output)

test2 = Model(input1, output)

Tensor("repeat_vector_28/Tile:0", shape=(None, 256, 12), dtype=float32)
Tensor("permute_26/transpose:0", shape=(None, 12, 256), dtype=float32)


ValueError: Operands could not be broadcast together with shapes (3, 4, 1) (12, 256)