diff --git a/deepctr/layers/sequence.py b/deepctr/layers/sequence.py
index 6ba8448b..ded9ba97 100644
--- a/deepctr/layers/sequence.py
+++ b/deepctr/layers/sequence.py
@@ -317,7 +317,9 @@ class Transformer(Layer):
     """Transformer proposed in 《Attention is all you need》
 
       Input shape
-        - 3D tensor with shape ``(batch_size, timesteps, input_dim)``.
+        - a list of two 3D tensors with shape ``(batch_size, timesteps, input_dim)`` if supports_masking=True.
+        - a list of 4 tensors, the first two with shape ``(batch_size, timesteps, input_dim)``, the last two with shape ``(batch_size, 1)``, if supports_masking=False.
+
 
       Output shape
         - 3D tensor with shape: ``(batch_size, 1, input_dim)``.
@@ -357,9 +359,6 @@ def __init__(self, att_embedding_size=1, head_num=8, dropout_rate=0.0, use_posit
         self.supports_masking = supports_masking
 
     def build(self, input_shape):
-        if len(input_shape) != 3:
-            raise ValueError(
-                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(input_shape)))
         embedding_size = input_shape[0][-1].value
         self.seq_len_max = input_shape[0][-2].value
@@ -392,9 +391,6 @@ def build(self, input_shape):
         super(Transformer, self).build(input_shape)
 
     def call(self, inputs, mask=None, **kwargs):
-        if K.ndim(inputs) != 3:
-            raise ValueError(
-                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))
         if self.supports_masking:
             queries, keys = inputs
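For context, here is a minimal usage sketch of the input contract described by the updated docstring. This is not taken from the repository: it assumes the TF 1.x-era DeepCTR this diff targets and default values for the remaining constructor arguments; the dimension values, the int32 dtype of the length inputs, and the variable names are illustrative. Only the argument names visible in the diff (att_embedding_size, head_num, dropout_rate, supports_masking) and the documented input/output shapes come from the change itself.

```python
# Minimal sketch of the Transformer layer's new input contract.
# Assumes a TF 1.x backend, matching the era of this diff; values are illustrative.
import tensorflow as tf
from deepctr.layers.sequence import Transformer

seq_len, embed_dim = 10, 8  # embed_dim chosen to match att_embedding_size * head_num below

queries = tf.keras.layers.Input(shape=(seq_len, embed_dim))
keys = tf.keras.layers.Input(shape=(seq_len, embed_dim))

# supports_masking=False: pass the two sequences plus their true lengths,
# each length tensor of shape (batch_size, 1), as the new docstring states.
query_len = tf.keras.layers.Input(shape=(1,), dtype="int32")
key_len = tf.keras.layers.Input(shape=(1,), dtype="int32")

attn_out = Transformer(att_embedding_size=1, head_num=8, dropout_rate=0.0,
                       supports_masking=False)(
    [queries, keys, query_len, key_len])
# attn_out has shape (batch_size, 1, input_dim), per the documented output shape.

# With supports_masking=True, the layer instead takes just [queries, keys]
# and relies on the Keras masks carried by those tensors.
```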