In [14]:
import tensorflow as tf
from keras import backend as K
from keras.engine.topology import Layer

from keras import layers, models

import numpy as np

In [None]:
def aggregate(W, R):
    ''' Where R_ij = X_i - C_j
    
    Compute: E_k = Sum_i (W_ik*(X_i - C_k))
    '''
    pass

def scaledL2(R, S): 
    ''' Where R_ij = X_i - C_j:
    
    Compute: W_ik = S_k * ||X_i - C_k||^2
    '''
    pass

In [87]:
class Encoding(Layer):
    '''Dictionary encoding layer with learnable codebook and residual weights.

    Encodes an collection of N features of size D into a KxD representation.
    
    If input.shape = (N, D) and num_codes = K, then
        codebook.shape = (K, D)
        residual.shape = (N, K, D)
        weights.shape  = (N, K)
        output.shape   = (K, D)

    Input of shape (batch, H, W, D) from e.g., a conv layer, should be 
    squeezed to (batch, H*W, D) before using as Encoding input tensor.

    TODO: - allow for ndim(x) == 4
          - support for dropout?
    
    # Arguments:
        D : size of the features in X (& codewords)
        K : number of codewords in the codebook
    '''
    def __init__(self, D, K, **kwargs):
        super(Encoding, self).__init__(**kwargs)
        self.D, self.K = D, K

    def build(self, input_shape):
        # TODO: opt to init codebook manually? Or with GMM?

        self.codes = self.add_weight(name='codebook',
                                    shape=(self.K, self.D,),
                                    initializer='orthogonal', # should use uniform +/-std1?
                                    trainable=True)

        self.scale = self.add_weight(name='scale_factors',
                                    shape=(self.K,),
                                    initializer='uniform',
                                    trainable=True)

        super(Encoding, self).build(input_shape)

    def call(self, x): 
        ''' 
        Pseudo-code
        -----------
        Weights   : w_ij = exp(-S_j*||r_ij||^2) / Sum_k(exp(-S_k*||r_ik||^2))
        Aggregate : E_j  = Sum_i(w_ij * r_ij)
        '''
        #x_shape = K.shape(x).eval(session=K.get_session())
        if K.ndim(x) != 3:
            raise ValueError('`x` should have shape BxNxD')
        
        N = x.shape[1]
        
        # Compute residual vectors
        x_repd  = K.repeat_elements(x, self.K, 1)  #(?, N*K, D)
        c_tile = K.tile(self.codes, (N, 1)) #(N*K, D)
        R = K.reshape(x_rep - c_tile, (N, self.K, self.D))
        
        W = scaledL2(R, self.scale)
        E = aggregate(W, R)
        
        
        return self.codes
    
        # code-wise dot product

        #W = K.softmax(scaledL2(X, self.codewords, self.scale))

        #E = aggregate(W, X, self.codewords)
        
        #E = tf.einsum('ij,ij->j', W, R)

        #return E

    def compute_output_shape(self, input_shape):
        return (self.K, self.D)


In [89]:
inp = layers.Input(shape=(5,32,))
enc = Encoding(32, 3)
enc(inp)

(5, 3, 32)


<tf.Variable 'encoding_20/codebook:0' shape=(3, 32) dtype=float32_ref>

In [83]:
inp.shape[1]

Dimension(5)

In [41]:
N = 3
k = 2
D = 5

x = np.random.randint(0, 10, size=(N,D))
c = np.random.randint(0, 10, size=(k,D))

In [42]:
X = K.variable(x, dtype='float32')
C = K.variable(c, dtype='float32')

In [45]:
print(K.eval(X), end='\n\n')
print(K.eval(C))

[[8. 1. 6. 8. 3.]
 [9. 1. 4. 8. 5.]
 [0. 5. 9. 6. 9.]]

[[3. 8. 9. 9. 1.]
 [2. 3. 5. 5. 7.]]


In [49]:
X_expand = K.repeat_elements(X, k, 0)
C_expand = K.tile(C, (N,1))


print(X_expand.shape, C_expand.shape)
print(K.eval(X_expand), end='\n\n')
print(K.eval(C_expand))

(6, 5) (6, 5)
[[8. 1. 6. 8. 3.]
 [8. 1. 6. 8. 3.]
 [9. 1. 4. 8. 5.]
 [9. 1. 4. 8. 5.]
 [0. 5. 9. 6. 9.]
 [0. 5. 9. 6. 9.]]

[[3. 8. 9. 9. 1.]
 [2. 3. 5. 5. 7.]
 [3. 8. 9. 9. 1.]
 [2. 3. 5. 5. 7.]
 [3. 8. 9. 9. 1.]
 [2. 3. 5. 5. 7.]]


In [50]:
R = X_expand - C_expand
K.eval(R)

array([[ 5., -7., -3., -1.,  2.],
       [ 6., -2.,  1.,  3., -4.],
       [ 6., -7., -5., -1.,  4.],
       [ 7., -2., -1.,  3., -2.],
       [-3., -3.,  0., -3.,  8.],
       [-2.,  2.,  4.,  1.,  2.]], dtype=float32)

In [51]:
R = K.reshape(R, (N, k, D))
R = K.eval(R)

In [52]:
R

array([[[ 5., -7., -3., -1.,  2.],
        [ 6., -2.,  1.,  3., -4.]],

       [[ 6., -7., -5., -1.,  4.],
        [ 7., -2., -1.,  3., -2.]],

       [[-3., -3.,  0., -3.,  8.],
        [-2.,  2.,  4.,  1.,  2.]]], dtype=float32)

In [55]:
R[1,:,:]

array([[ 6., -7., -5., -1.,  4.],
       [ 7., -2., -1.,  3., -2.]], dtype=float32)

In [39]:
np.repeat(np.expand_dims(X, axis=-1), (1,1,K)).shape

ValueError: operands could not be broadcast together with shape (352,) (3,)

In [40]:
x = layers.Input(shape=(121,32,))

(11, 32, 1)