

```

Variational Autoencoder (VAE) for AMP Sequences
===============================================

This script demonstrates a basic character-level VAE for antimicrobial peptide (AMP)
sequences, treating each amino-acid letter as one token. The encoder compresses a
one-hot-encoded sequence into a latent vector; the decoder reconstructs the original
sequence from that vector. New sequences can be generated by sampling latent vectors
from a standard normal distribution and decoding them.

```



In [2]:
"""
Variational Autoencoder (VAE) for AMP Sequences
==============================================

This script demonstrates a basic character-level VAE for protein sequences.
The encoder compresses sequences into a latent vector. The decoder reconstructs
the original sequence from the latent vector. New sequences can be generated by
sampling from the latent space.
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

# 1. Example data
# Training set: one string per peptide, one uppercase letter per amino-acid
# residue. The vocabulary below is built from exactly the letters that appear
# here, and the model works on a fixed-length one-hot tensor, so every
# sequence is expected to have the same length (see seq_length).
amp_sequences = [
            'FLPLLAGLAANFLPTIICKISYKC',
            'FLPFIARLAAKVFPSIICSVTKKC',
            'GVLSNVIGYLKKLGTGALNAVLKQ',
            'GLFSVLGAVAKHVLPHVVPVIAEK',
            'GLFKVLGSVAKHLLPHVAPVIAEK',
            'GLFKVLGSVAKHLLPHVVPVIAEK',
            'GLFGVLGSIAKHVLPHVVPVIAEK',
            'MFFSSKKCKTVSKTFRGPCVRNAN',
            'LLKELWTKMKGAGKAVLGKIKGLL',
            'LLKELWTKIKGAGKAVLGKIKGLL',
            'FWGALIKGAAKLIPSVVGLFKKKQ',
            'FLPVVAGLAAKVLPSIICAVTKKC',
            'FLPAIVGAAGQFLPKIFCAISKKC',
            'FLPAIVGAAGKFLPKIFCAISKKC',
            'FFPIVAGVAGQVLKKIYCTISKKC',
            'FLPIIAGIAAKVFPKIFCAISKKC',
            'FLPMLAGLAASMVPKLVCLITKKC',
            'FLPMLAGLAASMVPKFVCLITKKC',
            'FLPFIAGMAAKFLPKIFCAISKKC',
            'FLPAIAGMAAKFLPKIFCAISKKC',
            'FLPFIAGVAAKFLPKIFCAISKKC',
            'FLPAIAGVAAKFLPKIFCAISKKC',
            'FLPAIVGAAAKFLPKIFCVISKKC',
            'FLPFIAGMAANFLPKIFCAISKKC',
            'FLPIIAGVAAKVFPKIFCAISKKC',
            'FLPIIASVAAKVFSKIFCAISKKC',
            'FLPIIASVAANVFSKIFCAISKKC',
            'GLNTLKKVFQGLHEAIKLINNHVQ',
            'GLNALKKVFQGIHEAIKLINNHVQ',
            'DSHAKRHHGYKRKFHEKHHSHRGY',
            'FLPLLAGLAANFLPKIFCKITKKC',
            'FLPILAGLAAKIVPKLFCLATKKC',
            'FLPLIAGLAANFLPKIFCAITKKC',
            'FLPVIAGVAAKFLPKIFCAITKKC',
            'FWGALAKGALKLIPSLFSSFSKKD',
            'ITSVSWCTPGCTSEGGGSGCSHCC',
            'GLLNGLALRLGKRALKKIIKRLCR',
            'ALWKDILKNAGKAALNEINQLVNQ',
            'GLRSKIWLWVLLMIWQESNKFKKM',
            'GKGRWLERIGKAGGIIIGGALDHL',
            'FLGALIKGAIHGGRFIHGMIQNHH',
            'FLGLLFHGVHHVGKWIHGLIHGHH',
            'FLPMLAGLAANFLPKLFCKITKKC',
            'FLPLAVSLAANFLPKLFCKITKKC',
            'FLPLLAGLAANFFPKIFCKITRKC',
            'FLPILASLAAKFGPKLFCLVTKKC',
            'FLPILASLAAKLGPKLFCLVTKKC',
            'FLPILASLAATLGPKLLCLITKKC',
            'GIFSNMYARTPAGYFRGPAGYAAN',
            'GLKDKFKSMGEKLKQYIQTWKAKF',
            'SLKDKVKSMGEKLKQYIQTWKAKF',
            'GFRDVLKGAAKAFVKTVAGHIANI',
            'GIKDWIKGAAKKLIKTVASNIANQ',
            'GFKDWIKGAAKKLIKTVASSIANQ',
            'VIPFVASVAAEMMQHVYCAASKKC',
            'FFGTALKIAANVLPTAICKILKKC',
            'FFGTALKIAANILPTAICKILKKC',
            'ILPFVAGVAAEMMQHVYCAASKKC',
            'FLPAIVGAAAKFLPKIFCAISKKC',
            'FLPIIAGVAAKVLPKIFCAISKKC',
            'FLPIIAGIAAKFLPKIFCTISKKC',
            'FLPVIAGVAANFLPKLFCAISKKC',
            'FLPIIAGAAAKVVQKIFCAISKKC',
            'FLPIIAGAAAKVVEKIFCAISKKC',
            'FLPAVLRVAAKIVPTVFCAISKKC',
            'FLPAVLRVAAQVVPTVFCAISKKC',
            'FMGGLIKAATKIVPAAYCAITKKC',
            'FLPILAGLAAKLVPKVFCSITKKC',
            'FLPILAGLAANILPKVFCSITKKC',
            'FFPIIAGMAAKLIPSLFCKITKKC',
            'FMGSALRIAAKVLPAALCQIFKKC',
            'DSHEKRHHEHRRKFHEKHHSHRGY',
            'WRSLGRTLLRLSHALKPLARRSGW',
            'VTSWSLCTPGCTSPGGGSNCSFCC',
            'VIPFVASVAAEMMHHVYCAASKRC',
            'SPAGCRFCCGCCPNMRGCGVCCRF',
            'GRGREFMSNLKEKLSGVKEKMKNS',
            'FLPVLTGLTPSIVPKLVCLLTKKC',
            'FLPVLAGLTPSIVPKLVCLLTKKC',
            'FFPMLAGVAARVVPKVICLITKKC',
            'DSMGAVKLAKLLIDKMKCEVTKAC',
            'FLPGVLRLVTKVGPAVVCAITRNC',
            'VIVFVASVAAEMMQHVYCAASKKC',
            'FLPAVIRVAANVLPTAFCAISKKC',
            'IDPFVAGVAAEMMQHVYCAASKKC',
            'INPFVAGVAAEMMQHVYCAASKKC',
            'ILPFVAGVAAEMMKHVYCAASKKC',
            'IIPFVAGVAAEMMEHVYCAASKKC',
            'QLPFVAGVACEMCQCVYCAASKKC',
            'ILPFVAGVAAEMMEHVYCAASKKC',
            'ILPFVAGVAAMEMEHVYCAASKKC',
            'FLPAVLLVATHVLPTVFCAITRKC',
            'IPWKLPATFRPVERPFSKPFCRKD',
            'FLPLLAGVVANFLPQIICKIARKC',
            'FLGSLLGLVGKVVPTLFCKISKKC',
            'FIGPVLKIAAGILPTAICKIFKKC',
            'FVGPVLKIAAGILPTAICKIYKKC',
            'FLGPIIKIATGILPTAICKFLKKC',
            'FLPLIASLAANFVPKIFCKITKKC',
            'FLPLIASVAANLVPKIFCKITKKC',
            'FLSTLLKVAFKVVPTLFCPITKKC',
            'KRKCPKTPFDNTPGAWFAHLILGC',
            'FLGLIFHGLVHAGKLIHGLIHRNR',
            'FLPAVIRVAANVLPTVFCAISKKC',
            'FLPAVLRVAAKVVPTVFCLISKKC',
            'FLSTALKVAANVVPTLFCKITKKC',
            'FLPIVAGLAANFLPKIVCKITKKC',
            'FLSTLLNVASNVVPTLICKITKKC',
            'FLSTLLNVASKVVPTLFCKITKKC',
            'FLPMLAGLAANFLPKIVCKITKKC',
            'FIGPVLKMATSILPTAICKGFKKC',
            'FLGPIIKMATGILPTAICKGLKKC',
            'FLPIIAGVAAKVLPKLFCAITKKC',
            'FLPVIAGLAAKVLPKLFCAITKKC',
            'RKGWFKAMKSIAKFIAKEKLKEHL',
            'FLPAVLKVAAHILPTAICAISRRC',
            'FMGTALKIAANVLPAAFCKIFKKC',
            'KLGFENFLVKALKTVMHVPTSPLL',
            'GWLPTFGKILRKAMQLGPKLIQPI',
            'GNGVVLTLTHECNLATWTKKLKCC',
            'ITIPPIVKNTLKKFIKGAVSALMS',
            'FLPGLIKAAVGVGSTILCKITKKC',
            'FLPGLIKAAVGIGSTIFCKISKKC',
            'FLPGLIKVAVGVGSTILCKITKKC',
            'FLPGLIKAAVGIGSTIFCKISRKC',
            'FLPMLAGLAANFLPKIICKITKKC',
            'FLPIVASLAANFLPKIICKITKKC',
            'FWGALAKGALKLIPSLVSSFTKKD',
            'FFPLIAGLAARFLPKIFCSITKRC',
            'VIPFVASVAAEMMQHVYCAASKRC',
            'FFPSIAGLAAKFLPKIFCSITKRC',
            'FLPAVLRVAAKVGPAVFCAITQKC',
            'FLGMLLHGVGHAIHGLIHGKQNVE',
            'NPAGCRFCCGCCPNMIGCGVCCRF',
            'IWSFLIKAATKLLPSLFGGGKKDS',
            'RNGCIVDPRCPYQQCRRPLYCRRR',
            'ILELAGNAARDNKKTRIIPRHLQL',
            'FLPLLAGLAANFLPTIICKIARKC',
            'FLPAIIGMAAKVLPAFLCKITKKC',
            'RRRRRFRRVIRRIRLPKYLTINTE',
            'GNGVLKTISHECNMNTWQFLFTCC',
            'FLPILAGLAANLVPKLICSITKKC',
            'FLGAVLKVAGKLVPAAICKISKKC',
            'FLGALFKVASKLVPAAICSISKKC',
            'FLPVIAGIAANVLPKLFCKLTKRC',
            'FFPIIARLAAKVIPSLVCAVTKKC',
            'KRVNWRKVGRNTALGASYVLSFLG',
            'GHSVDRIPEYFGPPGLPGPVLFYS',
            'FLPLIAGVAAKVLPKIFCAISKKC',
            'SDSVVSDIICTTFCSVTWCQSNCC',
            'FLPLLAGLAANFLPQIICKIARKC',
            'FLGTVLKVAAKVLPAALCQIFKKC',
            'QSHLSMCRYCCCKGNKGCGFCCKF',
            'VFDIIKDAGKQLVAHAMGKIAEKV',
            'VFDIIKDAGRQLVAHAMGKIAEKV',
            'FLPLLAGLAASFLPTIFCKISRKC',
            'FFPIVAGVAAKVLKKIFCTISKKC',
]

# Vocabulary: every distinct residue letter in the data, in sorted order, so
# the index assignment is deterministic across runs.
unique_amino_acids = sorted(set("".join(amp_sequences)))
char_to_idx = {char: idx for idx, char in enumerate(unique_amino_acids)}
idx_to_char = dict(enumerate(unique_amino_acids))
vocab_size = len(unique_amino_acids)

# 2. Encode sequences as integers
# The model consumes a fixed-length one-hot tensor, so all sequences must share
# one length. Check up front instead of letting np.array build a ragged array
# that fails later with an unrelated-looking shape error.
sequence_lengths = {len(seq) for seq in amp_sequences}
if len(sequence_lengths) != 1:
    raise ValueError(
        "All sequences must have the same length; "
        f"found lengths {sorted(sequence_lengths)}"
    )
encoded_data = np.array([[char_to_idx[c] for c in seq] for seq in amp_sequences])
# shape: (num_sequences, seq_length)

# 3. One-hot encode for VAE
one_hot_data = tf.keras.utils.to_categorical(encoded_data, num_classes=vocab_size)
# shape: (num_sequences, seq_length, vocab_size)

# 4. Define hyperparameters
# Derived from the data rather than hard-coded, so it always matches the
# second axis of one_hot_data.
seq_length = encoded_data.shape[1]
latent_dim = 16  # dimension of the latent space
hidden_dim = 64  # dimension of LSTM or dense hidden units

# 5. Sampling function for the VAE
def sampling(args):
    """Sample a latent vector with the reparameterization trick.

    Computes ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with
    ``epsilon ~ N(0, I)``. Gradients flow through ``z_mean`` and ``z_log_var``,
    while the randomness is isolated in ``epsilon``.

    :param args: pair ``(z_mean, z_log_var)`` of tensors with the same shape
        (both come from ``Dense(latent_dim)`` layers in the encoder).
    :return: tensor of that shape holding the sampled latent vectors.
    """
    z_mean, z_log_var = args
    # Take the noise shape (and dtype) from z_mean itself instead of the global
    # ``latent_dim``, so the function stays correct for any latent size.
    epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

# 6. Define a custom loss function for the VAE
def vae_loss_fn(y_true, y_pred, z_mean, z_log_var):
    """Return the batch-averaged VAE objective: reconstruction + KL divergence.

    :param y_true: one-hot target sequences (as passed by ``train_step``).
    :param y_pred: decoder softmax outputs with the same shape as ``y_true``.
    :param z_mean: encoder posterior means, one row per sample.
    :param z_log_var: encoder posterior log-variances, same shape as ``z_mean``.
    :return: scalar tensor, the mean over the batch of the per-sample loss.
    """
    # Cross-entropy is taken over the last (vocabulary) axis, then summed over
    # sequence positions to get one reconstruction term per sample.
    per_position_ce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    recon_per_sample = tf.reduce_sum(per_position_ce, axis=1)

    # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over the
    # latent dimensions.
    kl_terms = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    kl_per_sample = -0.5 * tf.reduce_sum(kl_terms, axis=1)

    return tf.reduce_mean(recon_per_sample + kl_per_sample)

# 7. Encoder model
# Maps a one-hot batch of shape (batch, seq_length, vocab_size) to the
# posterior parameters (z_mean, z_log_var) and a sampled latent vector z.
encoder_inputs = layers.Input(shape=(seq_length, vocab_size))
# With return_sequences left at its default, the LSTM emits only its final
# output, i.e. one hidden_dim summary vector per sequence.
x = layers.LSTM(hidden_dim)(encoder_inputs)
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)
# Random latent sample drawn from (z_mean, z_log_var) by sampling().
z = layers.Lambda(sampling)([z_mean, z_log_var])
encoder = Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")

# 8. Decoder model
# Maps a latent vector (batch, latent_dim) to per-position softmax
# distributions over the vocabulary: (batch, seq_length, vocab_size).
latent_inputs = layers.Input(shape=(latent_dim,))
# Feed the same latent vector to the LSTM at every one of the seq_length steps.
dec_x = layers.RepeatVector(seq_length)(latent_inputs)
dec_x = layers.LSTM(hidden_dim, return_sequences=True)(dec_x)
decoder_outputs = layers.TimeDistributed(layers.Dense(vocab_size, activation='softmax'))(dec_x)
decoder = Model(latent_inputs, decoder_outputs, name="decoder")

# 9. VAE model that connects encoder and decoder
# NOTE(review): `vae` is not compiled or trained in this cell. Training below
# uses a VAEModel built from the same encoder/decoder objects, so `vae` shares
# their (trained) weights. It is kept in case later cells reference it.
z_mean_tensor, z_log_var_tensor, z_tensor = encoder(encoder_inputs)
outputs = decoder(z_tensor)
vae = Model(encoder_inputs, outputs, name="vae")

# 10. Create a custom model class to incorporate the custom loss
class VAEModel(tf.keras.Model):
    """Wrap an encoder/decoder pair and train them with the VAE objective.

    The loss needs the encoder's ``z_mean`` and ``z_log_var`` in addition to the
    reconstruction. A plain ``compile(loss=...)`` cannot see those, so the loss
    is computed inside ``train_step`` / ``test_step`` via ``vae_loss_fn``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running mean of the batch losses. Because it is returned by the
        # ``metrics`` property, Keras resets it at the start of every epoch.
        # The reported/recorded "loss" is therefore the epoch average, not the
        # value of whichever batch happened to be last.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Metrics listed here are reset automatically by fit()/evaluate().
        return [self.loss_tracker]

    def call(self, inputs, training=False):
        """Encode ``inputs``, sample a latent vector and return its reconstruction.

        Without this override the base ``Model.call`` is used, which cannot run
        this model. Defining it lets ``predict()`` and ``model(x)`` work.
        """
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)

    def _vae_loss(self, data, training):
        """Run one forward pass on ``data`` and return the scalar VAE loss."""
        # fit()/evaluate() may deliver the batch wrapped in a tuple. Only the
        # inputs are needed because the model reconstructs its own input.
        if isinstance(data, tuple):
            data = data[0]
        z_mean, z_log_var, z = self.encoder(data, training=training)
        reconstruction = self.decoder(z, training=training)
        return vae_loss_fn(data, reconstruction, z_mean, z_log_var)

    def train_step(self, data):
        with tf.GradientTape() as tape:
            loss = self._vae_loss(data, training=True)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        # Same objective without a gradient update, so evaluate() and
        # fit(validation_data=...) report the VAE loss for this model.
        loss = self._vae_loss(data, training=False)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

# 11. Create and compile the custom VAE model
# No `loss=` is passed to compile(): VAEModel.train_step computes the VAE loss
# itself from the encoder and decoder outputs.
custom_vae = VAEModel(encoder, decoder)
custom_vae.compile(optimizer='adam')

# 12. Build the model
# Build with the one-hot input shape (batch dimension left as None) so that
# summary() can be printed before any training has happened.
custom_vae.build(input_shape=(None, seq_length, vocab_size))
custom_vae.summary()

# 13. Train VAE
# Unsupervised: only the inputs are passed; train_step reconstructs them.
# The returned History object is not kept.
epochs = 50
batch_size = 8
custom_vae.fit(one_hot_data, epochs=epochs, batch_size=batch_size)

# 14. Generate new sequences by sampling from latent space
def generate_new_sequence(decoder, sample_z=None):
    """
    Decode a latent vector into a protein sequence string.

    :param decoder: the trained decoder model (latent vector -> per-position
        softmax over the vocabulary).
    :param sample_z: optional latent vector, either unbatched with shape
        (latent_dim,) or batched with shape (1, latent_dim). If None, a vector
        is drawn from N(0, 1) per dimension. If a batch of several vectors is
        given, only the first one is decoded.
    :return: generated protein sequence as a string, built by taking the most
        probable residue at each position (greedy decoding).
    """
    if sample_z is None:
        sample_z = np.random.randn(1, latent_dim)  # random from normal distribution
    # The decoder expects a batch dimension; accept a single unbatched vector too.
    sample_z = np.atleast_2d(sample_z)
    # verbose=0: a single-sample predict does not need a progress bar.
    pred = decoder.predict(sample_z, verbose=0)[0]  # shape: (seq_length, vocab_size)

    # Most probable vocabulary index at every position.
    seq_indices = pred.argmax(axis=-1)
    return "".join(idx_to_char[int(idx)] for idx in seq_indices)

# Example usage: decode one random latent vector drawn from N(0, 1).
# No random seed is set in this cell, so the generated peptide differs between runs.
new_peptide = generate_new_sequence(decoder)
print("Generated Peptide (VAE):", new_peptide)

Epoch 1/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 7s 28ms/step - loss: 68.4986
Epoch 2/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 35ms/step - loss: 66.0774
Epoch 3/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 23ms/step - loss: 63.4805
Epoch 4/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 24ms/step - loss: 62.5497
Epoch 5/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 22ms/step - loss: 61.8852
Epoch 6/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 23ms/step - loss: 61.5162
Epoch 7/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 24ms/step - loss: 60.7194
Epoch 8/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step - loss: 60.0854
Epoch 9/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 24ms/step - loss: 59.5951
Epoch 10/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - los