[View in Colaboratory](https://colab.research.google.com/github/rex-yue-wu/Notebooks/blob/master/hyperNetConv.ipynb)

In [0]:
from keras.layers import Input, Lambda, Conv2D, Flatten, MaxPool2D, Dense
from keras import backend as K
from keras.models import Model
from keras.utils import conv_utils
import tensorflow as tf
import numpy as np 

class MyConv2D( Conv2D ) :
    """2D convolution whose kernel (and optional bias) arrive as layer inputs.

    The layer creates no weights of its own. It is called on a list:

        MyConv2D(...)([x, kernel])          # use_bias=False
        MyConv2D(...)([x, kernel, bias])    # use_bias=True (the default)

    ``kernel`` is reshaped to ``kernel_size + (in_channels, filters)`` and
    ``bias`` to ``(filters,)``. Both are typically produced by another branch
    of the network (a hyper-network).

    NOTE(review): the reshape discards the batch axis of ``kernel``/``bias``.
    It therefore only succeeds when those tensors contain exactly one kernel,
    i.e. batch size 1. Per-sample kernels would need the convolution mapped
    over the batch.
    """
    def __init__( self, filters,
                 kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super( MyConv2D, self ).__init__( filters,
                 kernel_size=kernel_size,
                 strides=strides,
                 padding=padding,
                 data_format=data_format,
                 dilation_rate=dilation_rate,
                 activation=activation,
                 use_bias=use_bias,
                 kernel_initializer=kernel_initializer,
                 bias_initializer=bias_initializer,
                 kernel_regularizer=kernel_regularizer,
                 bias_regularizer=bias_regularizer,
                 activity_regularizer=activity_regularizer,
                 kernel_constraint=kernel_constraint,
                 bias_constraint=bias_constraint,
                 **kwargs)
        # Conv2D sets a single-tensor InputSpec; this layer takes a list.
        self.input_spec = None

    def build( self, input_shapes ):
        """Record the kernel/bias shapes implied by the feature-map input.

        Args:
            input_shapes: list of shapes; ``input_shapes[0]`` is the shape of
                the feature map ``x``.

        Raises:
            ValueError: if the channel dimension of ``x`` is unknown.
        """
        input_shape = input_shapes[0]
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        self.kernel_shape = self.kernel_size + (input_dim, self.filters)
        if self.use_bias:
            self.bias_shape = (self.filters,)
        else:
            self.bias_shape = None
        self.built = True

    def call( self, inputs ) :
        """Convolve ``inputs[0]`` with the kernel given in ``inputs[1]``.

        Args:
            inputs: ``[x, kernel]`` when ``use_bias`` is False, or
                ``[x, kernel, bias]`` when ``use_bias`` is True.

        Returns:
            The convolution output, with bias and activation applied when
            configured.

        Raises:
            ValueError: if the number of inputs does not match ``use_bias``.
        """
        expected = 3 if self.use_bias else 2
        if len( inputs ) != expected:
            raise ValueError('MyConv2D(use_bias=%s) expects %d inputs '
                             '([x, kernel%s]), got %d.'
                             % (self.use_bias, expected,
                                ', bias' if self.use_bias else '',
                                len( inputs )))
        x = inputs[0]
        kernel = K.reshape( inputs[1], self.kernel_shape )
        outputs = K.conv2d( x, kernel,
                            strides=self.strides,
                            padding=self.padding,
                            data_format=self.data_format,
                            dilation_rate=self.dilation_rate)
        if self.use_bias:
            bias = K.reshape( inputs[2], self.bias_shape )
            outputs = K.bias_add( outputs,
                                  bias,
                                  data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shapes):
        """Return the output shape of a forward (not transposed) convolution.

        Spatial sizes are computed with ``conv_output_length``, so padding,
        strides and dilation are all honoured. The channel axis becomes
        ``filters``. Assumes a 4D ``x`` shape.
        """
        input_shape = input_shapes[0]
        output_shape = list(input_shape)
        if self.data_format == 'channels_first':
            c_axis, h_axis, w_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 3, 1, 2

        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides
        dilation_h, dilation_w = self.dilation_rate

        output_shape[c_axis] = self.filters
        # deconv_length (used previously) is the transposed-conv formula and
        # gives wrong sizes for 'valid' padding or strides > 1.
        output_shape[h_axis] = conv_utils.conv_output_length(
            output_shape[h_axis], kernel_h,
            padding=self.padding, stride=stride_h, dilation=dilation_h)
        output_shape[w_axis] = conv_utils.conv_output_length(
            output_shape[w_axis], kernel_w,
            padding=self.padding, stride=stride_w, dilation=dilation_w)
        return tuple(output_shape)

        




In [36]:
# branch 1: ref featex -- feature extractor for the 256x256 reference image
# (four conv + 2x2 max-pool stages, so 256 -> 16 spatially, 128 channels)
ref = Input(shape=(256,256,3))
fr = Conv2D(16,(3,3), activation='relu', padding='same')(ref)
fr = MaxPool2D((2,2))(fr)
fr = Conv2D(32,(3,3), activation='relu', padding='same')(fr)
fr = MaxPool2D((2,2))(fr)
fr = Conv2D(64,(3,3), activation='relu', padding='same')(fr)
fr = MaxPool2D((2,2))(fr)
fr = Conv2D(128,(3,3), activation='relu', padding='same')(fr)
fr = MaxPool2D((2,2))(fr)
# public backend API instead of the private Tensor._shape_as_list()
nb_filters_in = K.int_shape(fr)[-1]
# branch 2: query featex -- generates the conv kernel/bias from the query
kernel_size=(3,3)
nb_filters=256
qry = Input(shape=(16,16,3))
fq = Conv2D(8,(3,3),activation='relu',padding='same')(qry)
fq = Conv2D(16,(3,3),activation='relu',padding='same')(fq)
fq = Flatten()(fq)
fq = Dense(256, activation='relu')(fq)
# note: kq and bq must use the default (linear) activation so the generated
# weights can take any sign.
# NOTE: this Dense layer is large: 256 inputs x (3*3*256*128 = 294912) units,
# i.e. roughly 75M parameters.
kq = Dense(int(np.prod(kernel_size))*nb_filters*nb_filters_in, name='kernel')(fq)
bq = Dense(nb_filters, name='bias')(fq)
# branch 3: query-kernel conv ref-feat
res = MyConv2D( nb_filters, kernel_size, padding='same', activation='relu')([fr, kq, bq])
res = Conv2D(1,(3,3), activation='sigmoid', padding='same', name='pred')(res)

model = Model( inputs=[ref, qry], outputs=res, name='test')
# summary() prints itself and returns None, so don't wrap it in print
model.summary()




__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_37 (InputLayer)           (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_85 (Conv2D)              (None, 256, 256, 16) 448         input_37[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_57 (MaxPooling2D) (None, 128, 128, 16) 0           conv2d_85[0][0]                  
__________________________________________________________________________________________________
conv2d_86 (Conv2D)              (None, 128, 128, 32) 4640        max_pooling2d_57[0][0]           
__________________________________________________________________________________________________
input_38 (

In [37]:
# Smoke test: push one random (reference, query) pair through the model.
# Batch size stays 1 because MyConv2D reshapes the generated kernel without
# a batch axis.
np.random.seed(0)  # reproducible random inputs
a = np.random.randn(1,256,256,3)
b = np.random.randn(1,16,16,3)
c = model.predict([a,b])
print(c.shape)

(1, 16, 16, 1)


In [38]:
# Presumably a smoke test that training runs end-to-end: one epoch fit on the
# model's own sigmoid predictions `c`, used as soft binary-crossentropy targets.
model.compile(optimizer='sgd', loss='binary_crossentropy')
model.fit(x=[a, b], y=c)

Epoch 1/1


<keras.callbacks.History at 0x7f8509713490>