In [2]:
from keras.layers import *
from keras.models import Model
from keras.losses import binary_crossentropy
from keras.initializers import Ones
import numpy as np
import warnings
warnings.filterwarnings("ignore")


In [3]:
DIM_EMBEDDING = 100
LEN_VOC = 1000
LSTM_UNIT1 = 32
LSTM_UNIT2 = 16
LEN_S1 = 50
LEN_S2 = 40
POOL_SIZE = 2
BATCH_SIZE = 8

## 测试自己写的ESIM是否正确

### 自己编写ESIM

In [6]:
WordEncoder = Embedding(input_dim=LEN_VOC, output_dim=DIM_EMBEDDING, name='word_encoder',
                        embeddings_initializer=Ones())
ContextEncoder = Bidirectional(LSTM(LSTM_UNIT1, return_sequences=True, name='context_encoder',
                     recurrent_initializer=Ones(), kernel_initializer=Ones(), bias_initializer=Ones()))
ContextEncoder2 = Bidirectional(LSTM(LSTM_UNIT2, return_sequences=True, name='context_encoder2',
                     recurrent_initializer=Ones(), kernel_initializer=Ones(), bias_initializer=Ones()))
SharpenContext = Subtract()
FilterContext = Multiply()
ContextConcatenate = Concatenate(axis=-1, name='context_concatenate')
RelationConcatenate = Concatenate(axis=-1, name='relation_concatenate')


print('build model...')
# input encoding
s1 = Input(batch_shape=[None, LEN_S1], name='sen1')
s2 = Input(batch_shape=[None, LEN_S2], name='sen2')
s1_embeded = WordEncoder(s1)
s2_embeded = WordEncoder(s2)
s1_encoded = ContextEncoder(s1_embeded)
s2_encoded = ContextEncoder(s2_embeded)
# local inference
dot = Dot(axes=-1)([s1_encoded, s2_encoded])
s1_attention_weight = Activation(activation='softmax')(dot)                        # shape = [BATCH_SIZE, LEN_S1, LEN_S2]
s2_attention_weight = Activation(activation='softmax')(Permute(dims=[2, 1])(dot))  # shape = [BATCH_SIZE, LEN_S2, LEN_S1]
s1_ = Dot(axes=-1)([s1_attention_weight, Permute(dims=[2, 1])(s2_encoded)])
s2_ = Dot(axes=-1)([s2_attention_weight, Permute(dims=[2, 1])(s1_encoded)])
ctx1 = ContextConcatenate([s1_encoded, s1_, SharpenContext([s1_encoded, s1_]), FilterContext([s1_encoded, s1_])])
ctx2 = ContextConcatenate([s2_encoded, s2_, SharpenContext([s2_encoded, s2_]), FilterContext([s2_encoded, s2_])])
# inference composition
ctx1 = ContextEncoder2(ctx1)
ctx2 = ContextEncoder2(ctx2)
c1_avg = GlobalAveragePooling1D()(ctx1)
c1_max = GlobalMaxPooling1D()(ctx1)
c2_avg = GlobalAveragePooling1D()(ctx2)
c2_max = GlobalMaxPooling1D()(ctx2)
v = ContextConcatenate([c1_avg, c1_max, c2_avg, c2_max])
result = Dense(1, activation='sigmoid', kernel_initializer=Ones(), bias_initializer=Ones())(v)

# compile model
mine = Model(inputs=[s1, s2], outputs=[result])
mine.compile(optimizer='sgd', loss=binary_crossentropy)
mine.summary()

build model...
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
sen1 (InputLayer)               (None, 50)           0                                            
__________________________________________________________________________________________________
sen2 (InputLayer)               (None, 40)           0                                            
__________________________________________________________________________________________________
word_encoder (Embedding)        multiple             100000      sen1[0][0]                       
                                                                 sen2[0][0]                       
__________________________________________________________________________________________________
bidirectional_1 (Bidirectional) multiple             34048       word_encoder[0][0]           

### 别人编写的ESIM

In [7]:
# Based on arXiv:1609.06038
q1 = Input(batch_shape=[None, None])  # shape = [batch_size, seq_len]
q2 = Input(batch_shape=[None, None])  # shape = [batch_size, seq_len]

# Embedding
embedding = Embedding(input_dim=LEN_VOC, output_dim=DIM_EMBEDDING, embeddings_initializer=Ones())

q1_embed = embedding(q1)
q2_embed = embedding(q2)

# Encode
bilstm = Bidirectional(LSTM(LSTM_UNIT1, return_sequences=True, recurrent_initializer=Ones(), kernel_initializer=Ones(), bias_initializer=Ones()))
x1 = bilstm(q1_embed)
x2 = bilstm(q2_embed)

e = Dot(axes=2)([x1, x2])
e1 = Softmax(axis=2)(e)
e2 = Softmax(axis=1)(e)
e1 = Lambda(K.expand_dims, arguments={'axis': 3})(e1)
e2 = Lambda(K.expand_dims, arguments={'axis': 3})(e2)

_x1 = Lambda(K.expand_dims, arguments={'axis': 1})(x2)
_x1 = Multiply()([e1, _x1])
_x1 = Lambda(K.sum, arguments={'axis': 2})(_x1)
_x2 = Lambda(K.expand_dims, arguments={'axis': 2})(x1)
_x2 = Multiply()([e2, _x2])
_x2 = Lambda(K.sum, arguments={'axis': 1})(_x2)

m1 = Concatenate()([x1, _x1, Subtract()([x1, _x1]), Multiply()([x1, _x1])])
m2 = Concatenate()([x2, _x2, Subtract()([x2, _x2]), Multiply()([x2, _x2])])

bilstm2 = Bidirectional(LSTM(LSTM_UNIT2, return_sequences=True, recurrent_initializer=Ones(), kernel_initializer=Ones(), bias_initializer=Ones()))
y1 = bilstm2(m1)
y2 = bilstm2(m2)

mx1 = Lambda(K.max, arguments={'axis': 1})(y1)
av1 = Lambda(K.mean, arguments={'axis': 1})(y1)
mx2 = Lambda(K.max, arguments={'axis': 1})(y2)
av2 = Lambda(K.mean, arguments={'axis': 1})(y2)

y = Concatenate()([av1, mx1, av2, mx2])
output = Dense(1, activation='sigmoid', kernel_initializer=Ones(), bias_initializer=Ones())(y)
esim = Model(inputs=[q1, q2], outputs=[output])
esim.compile(optimizer='sgd', loss=binary_crossentropy, metrics=['acc'])
esim.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 100)    100000      input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
bidirectional_3 (Bidirectional) (None, None, 64)     34048       embedding_1[0][0]                
          

### 一些层的输出

In [11]:
x1 = np.random.rand(BATCH_SIZE, LEN_S1)
x2 = np.random.rand(BATCH_SIZE, LEN_S2)
print('x1[:1]\n\n', x1[:1])
print('x2[:1]\n\n', x2[:1])
y_mine = mine.predict(x=[x1, x2])
y_esim = esim.predict(x=[x1, x2])
print('y_mine\n\n', y_mine)
print('y_esim\n\n', y_esim)

x1[:1]

 [[0.20315027 0.19419504 0.0391941  0.82134968 0.1305034  0.94363338
  0.13618826 0.97851304 0.24774574 0.37293895 0.17771148 0.75703405
  0.50921326 0.42383154 0.81686769 0.58488338 0.59101397 0.15340413
  0.51231427 0.08793192 0.21079994 0.63293082 0.54531202 0.14813556
  0.16403294 0.76736216 0.5181336  0.83074243 0.20470304 0.01973633
  0.14425149 0.98327597 0.82249118 0.69700737 0.35186999 0.9811738
  0.82977792 0.11704572 0.47706707 0.0146338  0.7075972  0.27989424
  0.55196786 0.5637819  0.97258895 0.61683193 0.91119054 0.06611965
  0.44703284 0.56086396]]
x2[:1]

 [[0.78120823 0.72945504 0.7667443  0.58191879 0.39841785 0.43449338
  0.63027284 0.50604932 0.15369607 0.88741095 0.78567249 0.41675698
  0.0122022  0.58232796 0.13115551 0.2263666  0.2947192  0.71964085
  0.24682478 0.96040925 0.09982975 0.14335222 0.76129882 0.32549239
  0.48723311 0.85125139 0.20472467 0.5770125  0.41947864 0.85608947
  0.94057238 0.03546743 0.2575981  0.50893252 0.39837772 0.05501654
  0.9

In [13]:
mine_middle_layer = Model(inputs=mine.inputs, outputs=[ctx1])
esim_middle_layer = Model(inputs=esim.inputs, outputs=[av1])
print('mine_middle_layer', mine_middle_layer.predict(x=[x1, x2]))
print('esim_middle_layer', esim_middle_layer.predict(x=[x1, x2]))

mine_middle_layer [[[0.7615942 0.7615942 0.7615942 ... 1.        1.        1.       ]
  [0.9640276 0.9640276 0.9640276 ... 1.        1.        1.       ]
  [0.9950547 0.9950547 0.9950547 ... 1.        1.        1.       ]
  ...
  [1.        1.        1.        ... 0.9950547 0.9950547 0.9950547]
  [1.        1.        1.        ... 0.9640276 0.9640276 0.9640276]
  [1.        1.        1.        ... 0.7615942 0.7615942 0.7615942]]

 [[0.7615942 0.7615942 0.7615942 ... 1.        1.        1.       ]
  [0.9640276 0.9640276 0.9640276 ... 1.        1.        1.       ]
  [0.9950547 0.9950547 0.9950547 ... 1.        1.        1.       ]
  ...
  [1.        1.        1.        ... 0.9950547 0.9950547 0.9950547]
  [1.        1.        1.        ... 0.9640276 0.9640276 0.9640276]
  [1.        1.        1.        ... 0.7615942 0.7615942 0.7615942]]

 [[0.7615942 0.7615942 0.7615942 ... 1.        1.        1.       ]
  [0.9640276 0.9640276 0.9640276 ... 1.        1.        1.       ]
  [0.9950547 0

### 结论

- 从网络结构看，没错
- 从参数数量上看是没错的
- 从简单计算结果上看，也是没错的