# In [None]:
"""
Transformer Generative Model for Antimicrobial Peptides
=======================================================

This script demonstrates a simplified Transformer for generating short protein sequences.
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model

# 1. Example AMP dataset
'''
amp_sequences: list of the AMP protein sequences the model will learn from.
'''
amp_sequences = [
            'FLPLLAGLAANFLPTIICKISYKC',
            'FLPFIARLAAKVFPSIICSVTKKC',
            'GVLSNVIGYLKKLGTGALNAVLKQ',
            'GLFSVLGAVAKHVLPHVVPVIAEK',
            'GLFKVLGSVAKHLLPHVAPVIAEK',
            'GLFKVLGSVAKHLLPHVVPVIAEK',
            'GLFGVLGSIAKHVLPHVVPVIAEK',
            'MFFSSKKCKTVSKTFRGPCVRNAN',
            'LLKELWTKMKGAGKAVLGKIKGLL',
            'LLKELWTKIKGAGKAVLGKIKGLL',
            'FWGALIKGAAKLIPSVVGLFKKKQ',
            'FLPVVAGLAAKVLPSIICAVTKKC',
            'FLPAIVGAAGQFLPKIFCAISKKC',
            'FLPAIVGAAGKFLPKIFCAISKKC',
            'FFPIVAGVAGQVLKKIYCTISKKC',
            'FLPIIAGIAAKVFPKIFCAISKKC',
            'FLPMLAGLAASMVPKLVCLITKKC',
            'FLPMLAGLAASMVPKFVCLITKKC',
            'FLPFIAGMAAKFLPKIFCAISKKC',
            'FLPAIAGMAAKFLPKIFCAISKKC',
            'FLPFIAGVAAKFLPKIFCAISKKC',
            'FLPAIAGVAAKFLPKIFCAISKKC',
            'FLPAIVGAAAKFLPKIFCVISKKC',
            'FLPFIAGMAANFLPKIFCAISKKC',
            'FLPIIAGVAAKVFPKIFCAISKKC',
            'FLPIIASVAAKVFSKIFCAISKKC',
            'FLPIIASVAANVFSKIFCAISKKC',
            'GLNTLKKVFQGLHEAIKLINNHVQ',
            'GLNALKKVFQGIHEAIKLINNHVQ',
            'DSHAKRHHGYKRKFHEKHHSHRGY',
            'FLPLLAGLAANFLPKIFCKITKKC',
            'FLPILAGLAAKIVPKLFCLATKKC',
            'FLPLIAGLAANFLPKIFCAITKKC',
            'FLPVIAGVAAKFLPKIFCAITKKC',
            'FWGALAKGALKLIPSLFSSFSKKD',
            'ITSVSWCTPGCTSEGGGSGCSHCC',
            'GLLNGLALRLGKRALKKIIKRLCR',
            'ALWKDILKNAGKAALNEINQLVNQ',
            'GLRSKIWLWVLLMIWQESNKFKKM',
            'GKGRWLERIGKAGGIIIGGALDHL',
            'FLGALIKGAIHGGRFIHGMIQNHH',
            'FLGLLFHGVHHVGKWIHGLIHGHH',
            'FLPMLAGLAANFLPKLFCKITKKC',
            'FLPLAVSLAANFLPKLFCKITKKC',
            'FLPLLAGLAANFFPKIFCKITRKC',
            'FLPILASLAAKFGPKLFCLVTKKC',
            'FLPILASLAAKLGPKLFCLVTKKC',
            'FLPILASLAATLGPKLLCLITKKC',
            'GIFSNMYARTPAGYFRGPAGYAAN',
            'GLKDKFKSMGEKLKQYIQTWKAKF',
            'SLKDKVKSMGEKLKQYIQTWKAKF',
            'GFRDVLKGAAKAFVKTVAGHIANI',
            'GIKDWIKGAAKKLIKTVASNIANQ',
            'GFKDWIKGAAKKLIKTVASSIANQ',
            'VIPFVASVAAEMMQHVYCAASKKC',
            'FFGTALKIAANVLPTAICKILKKC',
            'FFGTALKIAANILPTAICKILKKC',
            'ILPFVAGVAAEMMQHVYCAASKKC',
            'FLPAIVGAAAKFLPKIFCAISKKC',
            'FLPIIAGVAAKVLPKIFCAISKKC',
            'FLPIIAGIAAKFLPKIFCTISKKC',
            'FLPVIAGVAANFLPKLFCAISKKC',
            'FLPIIAGAAAKVVQKIFCAISKKC',
            'FLPIIAGAAAKVVEKIFCAISKKC',
            'FLPAVLRVAAKIVPTVFCAISKKC',
            'FLPAVLRVAAQVVPTVFCAISKKC',
            'FMGGLIKAATKIVPAAYCAITKKC',
            'FLPILAGLAAKLVPKVFCSITKKC',
            'FLPILAGLAANILPKVFCSITKKC',
            'FFPIIAGMAAKLIPSLFCKITKKC',
            'FMGSALRIAAKVLPAALCQIFKKC',
            'DSHEKRHHEHRRKFHEKHHSHRGY',
            'WRSLGRTLLRLSHALKPLARRSGW',
            'VTSWSLCTPGCTSPGGGSNCSFCC',
            'VIPFVASVAAEMMHHVYCAASKRC',
            'SPAGCRFCCGCCPNMRGCGVCCRF',
            'GRGREFMSNLKEKLSGVKEKMKNS',
            'FLPVLTGLTPSIVPKLVCLLTKKC',
            'FLPVLAGLTPSIVPKLVCLLTKKC',
            'FFPMLAGVAARVVPKVICLITKKC',
            'DSMGAVKLAKLLIDKMKCEVTKAC',
            'FLPGVLRLVTKVGPAVVCAITRNC',
            'VIVFVASVAAEMMQHVYCAASKKC',
            'FLPAVIRVAANVLPTAFCAISKKC',
            'IDPFVAGVAAEMMQHVYCAASKKC',
            'INPFVAGVAAEMMQHVYCAASKKC',
            'ILPFVAGVAAEMMKHVYCAASKKC',
            'IIPFVAGVAAEMMEHVYCAASKKC',
            'QLPFVAGVACEMCQCVYCAASKKC',
            'ILPFVAGVAAEMMEHVYCAASKKC',
            'ILPFVAGVAAMEMEHVYCAASKKC',
            'FLPAVLLVATHVLPTVFCAITRKC',
            'IPWKLPATFRPVERPFSKPFCRKD',
            'FLPLLAGVVANFLPQIICKIARKC',
            'FLGSLLGLVGKVVPTLFCKISKKC',
            'FIGPVLKIAAGILPTAICKIFKKC',
            'FVGPVLKIAAGILPTAICKIYKKC',
            'FLGPIIKIATGILPTAICKFLKKC',
            'FLPLIASLAANFVPKIFCKITKKC',
            'FLPLIASVAANLVPKIFCKITKKC',
            'FLSTLLKVAFKVVPTLFCPITKKC',
            'KRKCPKTPFDNTPGAWFAHLILGC',
            'FLGLIFHGLVHAGKLIHGLIHRNR',
            'FLPAVIRVAANVLPTVFCAISKKC',
            'FLPAVLRVAAKVVPTVFCLISKKC',
            'FLSTALKVAANVVPTLFCKITKKC',
            'FLPIVAGLAANFLPKIVCKITKKC',
            'FLSTLLNVASNVVPTLICKITKKC',
            'FLSTLLNVASKVVPTLFCKITKKC',
            'FLPMLAGLAANFLPKIVCKITKKC',
            'FIGPVLKMATSILPTAICKGFKKC',
            'FLGPIIKMATGILPTAICKGLKKC',
            'FLPIIAGVAAKVLPKLFCAITKKC',
            'FLPVIAGLAAKVLPKLFCAITKKC',
            'RKGWFKAMKSIAKFIAKEKLKEHL',
            'FLPAVLKVAAHILPTAICAISRRC',
            'FMGTALKIAANVLPAAFCKIFKKC',
            'KLGFENFLVKALKTVMHVPTSPLL',
            'GWLPTFGKILRKAMQLGPKLIQPI',
            'GNGVVLTLTHECNLATWTKKLKCC',
            'ITIPPIVKNTLKKFIKGAVSALMS',
            'FLPGLIKAAVGVGSTILCKITKKC',
            'FLPGLIKAAVGIGSTIFCKISKKC',
            'FLPGLIKVAVGVGSTILCKITKKC',
            'FLPGLIKAAVGIGSTIFCKISRKC',
            'FLPMLAGLAANFLPKIICKITKKC',
            'FLPIVASLAANFLPKIICKITKKC',
            'FWGALAKGALKLIPSLVSSFTKKD',
            'FFPLIAGLAARFLPKIFCSITKRC',
            'VIPFVASVAAEMMQHVYCAASKRC',
            'FFPSIAGLAAKFLPKIFCSITKRC',
            'FLPAVLRVAAKVGPAVFCAITQKC',
            'FLGMLLHGVGHAIHGLIHGKQNVE',
            'NPAGCRFCCGCCPNMIGCGVCCRF',
            'IWSFLIKAATKLLPSLFGGGKKDS',
            'RNGCIVDPRCPYQQCRRPLYCRRR',
            'ILELAGNAARDNKKTRIIPRHLQL',
            'FLPLLAGLAANFLPTIICKIARKC',
            'FLPAIIGMAAKVLPAFLCKITKKC',
            'RRRRRFRRVIRRIRLPKYLTINTE',
            'GNGVLKTISHECNMNTWQFLFTCC',
            'FLPILAGLAANLVPKLICSITKKC',
            'FLGAVLKVAGKLVPAAICKISKKC',
            'FLGALFKVASKLVPAAICSISKKC',
            'FLPVIAGIAANVLPKLFCKLTKRC',
            'FFPIIARLAAKVIPSLVCAVTKKC',
            'KRVNWRKVGRNTALGASYVLSFLG',
            'GHSVDRIPEYFGPPGLPGPVLFYS',
            'FLPLIAGVAAKVLPKIFCAISKKC',
            'SDSVVSDIICTTFCSVTWCQSNCC',
            'FLPLLAGLAANFLPQIICKIARKC',
            'FLGTVLKVAAKVLPAALCQIFKKC',
            'QSHLSMCRYCCCKGNKGCGFCCKF',
            'VFDIIKDAGKQLVAHAMGKIAEKV',
            'VFDIIKDAGRQLVAHAMGKIAEKV',
            'FLPLLAGLAASFLPTIFCKISRKC',
            'FFPIVAGVAAKVLKKIFCTISKKC',
    # ...
]


# 2. Build character-to-index mapping
#
# unique_amino_acids: every distinct residue letter found in amp_sequences,
#   in sorted order.
# char_to_idx: residue letter -> integer index (its position in
#   unique_amino_acids, starting at 0).
# idx_to_char: the inverse lookup, integer index -> residue letter.
# vocab_size: number of distinct residue letters in the dataset.
unique_amino_acids = sorted(set("".join(amp_sequences)))
char_to_idx = dict(zip(unique_amino_acids, range(len(unique_amino_acids))))
idx_to_char = dict(enumerate(unique_amino_acids))
vocab_size = len(char_to_idx)


# 3. Convert to integer arrays
#
# encoded_sequences: amp_sequences with every residue letter replaced by its
#   index from char_to_idx, stacked into a 2-D integer array.
# seq_length: the common length of the sequences (24 for the dataset above);
#   it is read from the encoded array so it cannot drift from the data.
# X and y: next-token training pairs. X is every sequence without its last
#   residue, y is every sequence without its first residue, so y[:, i] is the
#   residue that follows X[:, i].

# A rectangular array is only possible when all sequences share one length;
# fail early with a readable message instead of a NumPy shape/indexing error.
sequence_lengths = {len(seq) for seq in amp_sequences}
if len(sequence_lengths) != 1:
    raise ValueError(
        "All AMP sequences must have the same length to build the training "
        f"array; found lengths {sorted(sequence_lengths)}"
    )

encoded_sequences = np.array(
    [[char_to_idx[c] for c in seq] for seq in amp_sequences]
)  # shape: (num_sequences, seq_length)

# Prepare training data for next-token prediction
seq_length = encoded_sequences.shape[1]
X = encoded_sequences[:, :-1]  # shape: (num_sequences, seq_length-1)
y = encoded_sequences[:, 1:]   # shape: (num_sequences, seq_length-1)


# 4. Build a small Transformer model
#
# The model is assembled with the Keras functional API:
#   input (seq_length-1 token indices)
#   -> token embedding + learned positional embedding
#   -> one Transformer block (causal self-attention, then a feed-forward net,
#      each followed by dropout, a residual connection and layer norm)
#   -> dense softmax over the vocabulary at every position.
# embedding_dim, num_heads, ff_dim: hyperparameters controlling model size.
# model.compile: configures training with the adam optimizer,
#   sparse_categorical_crossentropy loss and an accuracy metric.
# model.summary(): prints the layer-by-layer architecture.

embedding_dim = 16
num_heads = 2
ff_dim = 32  # feed-forward layer size in transformer


class PositionalEmbedding(layers.Layer):
    """Add a learned per-position vector to a batch of token embeddings.

    The position table lives inside a Layer so that it is part of the model
    graph and its weights are updated during training.

    :param max_len: number of rows in the position table (largest supported
        sequence length)
    :param output_dim: size of each position vector; must match the last
        dimension of the token embeddings it is added to
    """

    def __init__(self, max_len, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.output_dim = output_dim
        self.position_embedding = layers.Embedding(input_dim=max_len, output_dim=output_dim)

    def call(self, inputs):
        # inputs: (batch, steps, output_dim). Positions 0..steps-1 are looked
        # up once and broadcast over the batch axis by the addition.
        positions = tf.range(start=0, limit=tf.shape(inputs)[1], delta=1)
        return inputs + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update({"max_len": self.max_len, "output_dim": self.output_dim})
        return config


# Define Input
inputs = layers.Input(shape=(seq_length-1,))  # each example is length-1 = 23

# Token Embedding + Positional Embedding
token_embedding = layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)(inputs)
x = PositionalEmbedding(max_len=seq_length, output_dim=embedding_dim)(token_embedding)

# Transformer block (simplified)
#
# layers.MultiHeadAttention computes the query/key/value projections and the
# attention weights internally; passing x as both query and value makes it
# self-attention. key_dim is the per-head size of the query/key projections.
# use_causal_mask=True stops position i from attending to positions > i.
# Without it, the model could read token i+1 -- the very target it is trained
# to predict at position i -- straight from its input.
attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim)(
    x, x, use_causal_mask=True
)
attention_output = layers.Dropout(0.1)(attention_output)
x = layers.LayerNormalization(epsilon=1e-6)(x + attention_output)

ffn = layers.Dense(ff_dim, activation='relu')(x)
ffn = layers.Dense(embedding_dim)(ffn)
ffn = layers.Dropout(0.1)(ffn)
x = layers.LayerNormalization(epsilon=1e-6)(x + ffn)

# Final Dense Layer over vocab
outputs = layers.Dense(vocab_size, activation='softmax')(x)

model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()


# 5. Train the Transformer
#
# epochs: number of full passes over the training data.
# batch_size: number of training examples processed per iteration.
# model.fit: runs the training loop on the prepared (X, y) pairs with the
#   settings above.
epochs, batch_size = 50, 16
model.fit(x=X, y=y, batch_size=batch_size, epochs=epochs)


# 6. Generation function
'''
generate_transformer_sequence: This function takes the trained model and a starting amino acid (start_token) and generates a new peptide sequence of the specified length.
It works by repeatedly predicting the next amino acid based on the previous ones, using the model's learned knowledge.
The example usage demonstrates how to generate 20 new sequences starting with 'F' and 20 starting with 'G'.
'''


def generate_transformer_sequence(model, start_token, length=24):
    """
    Sample a new peptide sequence, one residue at a time, from a trained model.

    At each step the residues generated so far (at most the last
    ``seq_length - 1`` of them) are placed at the start of a fixed-width input
    row, the model is queried, and the next residue is drawn at random from the
    probability distribution predicted at the last filled position.

    :param model: trained Keras model; ``model.predict`` is called with an
        integer array of shape (1, seq_length - 1) and must return
        per-position probabilities of shape (1, seq_length - 1, vocab)
    :param start_token: integer index of first amino acid (a key of
        ``idx_to_char``)
    :param length: desired total length (must be >= 1)
    :return: generated amino acid sequence (string)
    :raises ValueError: if ``length`` is smaller than 1 or ``start_token`` is
        not a known amino-acid index
    """
    if length < 1:
        raise ValueError(f"length must be at least 1, got {length}")
    if start_token not in idx_to_char:
        # Fail before any model call rather than with a late KeyError.
        raise ValueError(f"start_token {start_token!r} is not a valid amino-acid index")

    context_len = seq_length - 1  # input width the model was built with
    generated = [start_token]

    for _ in range(length - 1):
        # Keep only the most recent tokens once the history outgrows the window.
        window = generated[-context_len:]
        last_pos = len(window) - 1

        # Right-pad with zeros up to the fixed input width.
        # NOTE: index 0 is also a real amino-acid index (there is no dedicated
        # PAD token), so the prediction at last_pos is only unaffected by the
        # padding if the model's attention is causal.
        input_seq = np.zeros((1, context_len), dtype=np.int64)
        input_seq[0, :len(window)] = window

        preds = model.predict(input_seq, verbose=0)

        # Distribution over the next residue, taken at the last real position.
        # Renormalise in float64 so rounding in lower-precision model output
        # cannot trip np.random.choice's "probabilities do not sum to 1" check.
        prob_dist = np.asarray(preds[0, last_pos], dtype=np.float64)
        prob_dist = prob_dist / prob_dist.sum()

        next_idx = int(np.random.choice(prob_dist.shape[0], p=prob_dist))
        generated.append(next_idx)

    # Convert to string
    return "".join(idx_to_char[idx] for idx in generated)

# Example usage: sample 20 peptides seeded with 'F', then 20 seeded with 'G'.
for first_residue in ('F', 'G'):
    start_token = char_to_idx[first_residue]
    for i in range(20):
        new_peptide = generate_transformer_sequence(model, start_token, length=24)
        print("Generated Peptide (Transformer):", new_peptide)

Epoch 1/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 18ms/step - accuracy: 0.0562 - loss: 3.1822
Epoch 2/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.1537 - loss: 2.7338
Epoch 3/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.2450 - loss: 2.5757
Epoch 4/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.2859 - loss: 2.4463
Epoch 5/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.2835 - loss: 2.4370
Epoch 6/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.2805 - loss: 2.4297
Epoch 7/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.3135 - loss: 2.3554
Epoch 8/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.3292 - loss: 2.3276
Epoch 9/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━