# In [2]:
import seq2seq

from seq2seq.models import AttentionSeq2Seq

# In [10]:
def vae_loss(x, x_decoded_mean):
    '''
    Return a scaled reconstruction loss plus a latent regularisation term.

    x: target tensor.
    x_decoded_mean: tensor that is compared against x.

    The reconstruction part is the binary cross-entropy of the two arguments
    multiplied by original_dim. The regularisation part is computed from
    z_mean and z_log_var, which are NOT parameters: together with
    original_dim, metrics and K they are looked up in the enclosing (module)
    scope when the loss is evaluated.
    '''
    # Cross-entropy between target and reconstruction, scaled by original_dim.
    reconstruction_term = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)

    # Summed over the last axis, then scaled by -0.5.
    latent_terms = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    regularisation_term = - 0.5 * K.sum(latent_terms, axis=-1)

    return reconstruction_term + regularisation_term
    
def sampling(args):
    '''
    Draw a noisy latent sample: z_mean + exp(z_log_var / 2) * epsilon.

    args: a pair (z_mean, z_log_var) of backend tensors.

    epsilon is random normal noise with mean 0 and standard deviation
    epsilon_std (looked up in the enclosing module scope). Returns a tensor
    with the same shape as z_mean.
    '''
    z_mean, z_log_var = args
    # Size the noise from z_mean itself rather than from module-level
    # (batch_size, latent_dim): that hard-coded shape only fits one batch size
    # and a rank-2 z_mean, so it fails for a partial final batch or for
    # per-timestep (rank-3) statistics.
    runtime_shape = K.shape(z_mean)
    noise_shape = tuple(runtime_shape[axis] for axis in range(K.ndim(z_mean)))
    epsilon = K.random_normal(shape=noise_shape, mean=0., stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2) * epsilon

def AttentionSeq2Seq(output_dim, output_length, hidden_dim=None, depth=1, bidirectional=True, dropout=0., **kwargs):
    '''
    Build and compile an attention Seq2seq model with a variational bound.

    The encoder output is projected to z_mean / z_log_var, a latent tensor is
    drawn from them with sampling(), mapped back to hidden_dim by a ReLU Dense
    layer, and fed to an attention decoder.

    Arguments:
        output_dim: feature size of the last decoder cell.
        output_length: passed to the decoder container as output_length.
        hidden_dim: size of the recurrent cells; falls back to output_dim
            when falsy.
        depth: an int (used for both encoder and decoder) or a pair
            [encoder_depth, decoder_depth].
        bidirectional: wrap the encoder in Bidirectional(merge_mode='sum').
        dropout: rate of the Dropout layers placed between stacked cells.
        **kwargs: the input shape must be given as one of
            batch_input_shape, input_shape, or input_dim (optionally with
            input_length). 'unroll' and 'stateful' (default False) are passed
            to both recurrent containers. Everything else is forwarded to
            every encoder LSTMCell.

    Returns:
        A Model compiled with the 'adam' optimizer and the VAE loss.

    Raises:
        ValueError: if no input shape keyword is supplied.

    Note: latent_dim, original_dim, metrics, K and the layer classes are
    expected to be available in the enclosing module scope.
    '''
    if isinstance(depth, int):
        depth = [depth, depth]

    # Resolve the full batch shape (batch, timesteps, features) up front so a
    # missing shape is reported clearly instead of as an UnboundLocalError.
    if 'batch_input_shape' in kwargs:
        shape = kwargs.pop('batch_input_shape')
    elif 'input_shape' in kwargs:
        shape = (None,) + tuple(kwargs.pop('input_shape'))
    elif 'input_dim' in kwargs:
        shape = (None, kwargs.pop('input_length', None), kwargs.pop('input_dim'))
    else:
        raise ValueError(
            'AttentionSeq2Seq requires one of batch_input_shape, '
            'input_shape or input_dim to be specified.')

    unroll = kwargs.pop('unroll', False)
    stateful = kwargs.pop('stateful', False)
    if not hidden_dim:
        hidden_dim = output_dim

    # Encoder: depth[0] stacked LSTM cells with dropout in between.
    encoder = RecurrentContainer(unroll=unroll, stateful=stateful, return_sequences=True, input_length=shape[1])
    encoder.add(LSTMCell(hidden_dim, batch_input_shape=(shape[0], shape[2]), **kwargs))
    for _ in range(1, depth[0]):
        encoder.add(Dropout(dropout))
        encoder.add(LSTMCell(hidden_dim, **kwargs))
    encoder_input = Input(batch_shape=shape)
    encoder_input._keras_history[0].supports_masking = True
    if bidirectional:
        encoder = Bidirectional(encoder, merge_mode='sum')

    h_encoded = encoder(encoder_input)

    # Variational part
    z_mean = Dense(latent_dim)(h_encoded)
    z_log_var = Dense(latent_dim)(h_encoded)

    def _same_shape_as_mean(input_shapes):
        # sampling() returns a tensor shaped like z_mean, the first input.
        return input_shapes[0]

    z = Lambda(sampling, output_shape=_same_shape_as_mean)([z_mean, z_log_var])

    h_0 = Dense(hidden_dim, activation='relu')(z)

    # Decoder: one attention cell, optionally followed by plain LSTM decoder
    # cells; only the last cell emits output_dim features.
    decoder = RecurrentContainer(decode=True, output_length=output_length, unroll=unroll, stateful=stateful, input_length=shape[1])
    decoder.add(Dropout(dropout, batch_input_shape=(shape[0], shape[1], hidden_dim)))
    if depth[1] == 1:
        decoder.add(AttentionDecoderCell(output_dim=output_dim, hidden_dim=hidden_dim))
    else:
        decoder.add(AttentionDecoderCell(output_dim=hidden_dim, hidden_dim=hidden_dim))
        for _ in range(depth[1] - 2):
            decoder.add(Dropout(dropout))
            decoder.add(LSTMDecoderCell(output_dim=hidden_dim, hidden_dim=hidden_dim))
        decoder.add(Dropout(dropout))
        decoder.add(LSTMDecoderCell(output_dim=output_dim, hidden_dim=hidden_dim))

    # Teacher forcing is not wired up (this builder has no teacher_force
    # option), so the encoder input is the model's only input.
    inputs = [encoder_input]

    x_decoded_mean = decoder(h_0)

    def _vae_loss(x, x_decoded):
        # Defined here so the regularisation term uses THIS model's z_mean and
        # z_log_var; a module-level loss cannot see these local tensors.
        # NOTE(review): original_dim is still read from module scope, as in
        # the module-level loss - confirm it matches this model's output size.
        xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded)
        kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return xent_loss + kl_loss

    model = Model(inputs, x_decoded_mean)

    model.compile(optimizer='adam', loss=_vae_loss)

    return model