# In [None]:
"""
LSTM Generative Model for Antimicrobial Peptides
================================================

This script shows how to:
1) Preprocess AMP sequences (tokenize amino acids).
2) Train an LSTM-based model to predict the next amino acid.
3) Generate new sequences by sampling from the trained model.

Note: With only ~150 AMP sequences (each length 24), overfitting is likely.
      Consider data augmentation, dropout, or pretraining on larger protein sets.
"""

import numpy as np
from tensorflow.keras.layers import LSTM, Dense, Embedding, Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# 1. Example dataset of antimicrobial peptide (AMP) sequences.
#    Every entry is a 24-residue peptide written with one-letter amino-acid codes.
amp_sequences = [
    'FLPLLAGLAANFLPTIICKISYKC',
    'FLPFIARLAAKVFPSIICSVTKKC',
    'GVLSNVIGYLKKLGTGALNAVLKQ',
    'GLFSVLGAVAKHVLPHVVPVIAEK',
    'GLFKVLGSVAKHLLPHVAPVIAEK',
    'GLFKVLGSVAKHLLPHVVPVIAEK',
    'GLFGVLGSIAKHVLPHVVPVIAEK',
    'MFFSSKKCKTVSKTFRGPCVRNAN',
    'LLKELWTKMKGAGKAVLGKIKGLL',
    'LLKELWTKIKGAGKAVLGKIKGLL',
    'FWGALIKGAAKLIPSVVGLFKKKQ',
    'FLPVVAGLAAKVLPSIICAVTKKC',
    'FLPAIVGAAGQFLPKIFCAISKKC',
    'FLPAIVGAAGKFLPKIFCAISKKC',
    'FFPIVAGVAGQVLKKIYCTISKKC',
    'FLPIIAGIAAKVFPKIFCAISKKC',
    'FLPMLAGLAASMVPKLVCLITKKC',
    'FLPMLAGLAASMVPKFVCLITKKC',
    'FLPFIAGMAAKFLPKIFCAISKKC',
    'FLPAIAGMAAKFLPKIFCAISKKC',
    'FLPFIAGVAAKFLPKIFCAISKKC',
    'FLPAIAGVAAKFLPKIFCAISKKC',
    'FLPAIVGAAAKFLPKIFCVISKKC',
    'FLPFIAGMAANFLPKIFCAISKKC',
    'FLPIIAGVAAKVFPKIFCAISKKC',
    'FLPIIASVAAKVFSKIFCAISKKC',
    'FLPIIASVAANVFSKIFCAISKKC',
    'GLNTLKKVFQGLHEAIKLINNHVQ',
    'GLNALKKVFQGIHEAIKLINNHVQ',
    'DSHAKRHHGYKRKFHEKHHSHRGY',
    'FLPLLAGLAANFLPKIFCKITKKC',
    'FLPILAGLAAKIVPKLFCLATKKC',
    'FLPLIAGLAANFLPKIFCAITKKC',
    'FLPVIAGVAAKFLPKIFCAITKKC',
    'FWGALAKGALKLIPSLFSSFSKKD',
    'ITSVSWCTPGCTSEGGGSGCSHCC',
    'GLLNGLALRLGKRALKKIIKRLCR',
    'ALWKDILKNAGKAALNEINQLVNQ',
    'GLRSKIWLWVLLMIWQESNKFKKM',
    'GKGRWLERIGKAGGIIIGGALDHL',
    'FLGALIKGAIHGGRFIHGMIQNHH',
    'FLGLLFHGVHHVGKWIHGLIHGHH',
    'FLPMLAGLAANFLPKLFCKITKKC',
    'FLPLAVSLAANFLPKLFCKITKKC',
    'FLPLLAGLAANFFPKIFCKITRKC',
    'FLPILASLAAKFGPKLFCLVTKKC',
    'FLPILASLAAKLGPKLFCLVTKKC',
    'FLPILASLAATLGPKLLCLITKKC',
    'GIFSNMYARTPAGYFRGPAGYAAN',
    'GLKDKFKSMGEKLKQYIQTWKAKF',
    'SLKDKVKSMGEKLKQYIQTWKAKF',
    'GFRDVLKGAAKAFVKTVAGHIANI',
    'GIKDWIKGAAKKLIKTVASNIANQ',
    'GFKDWIKGAAKKLIKTVASSIANQ',
    'VIPFVASVAAEMMQHVYCAASKKC',
    'FFGTALKIAANVLPTAICKILKKC',
    'FFGTALKIAANILPTAICKILKKC',
    'ILPFVAGVAAEMMQHVYCAASKKC',
    'FLPAIVGAAAKFLPKIFCAISKKC',
    'FLPIIAGVAAKVLPKIFCAISKKC',
    'FLPIIAGIAAKFLPKIFCTISKKC',
    'FLPVIAGVAANFLPKLFCAISKKC',
    'FLPIIAGAAAKVVQKIFCAISKKC',
    'FLPIIAGAAAKVVEKIFCAISKKC',
    'FLPAVLRVAAKIVPTVFCAISKKC',
    'FLPAVLRVAAQVVPTVFCAISKKC',
    'FMGGLIKAATKIVPAAYCAITKKC',
    'FLPILAGLAAKLVPKVFCSITKKC',
    'FLPILAGLAANILPKVFCSITKKC',
    'FFPIIAGMAAKLIPSLFCKITKKC',
    'FMGSALRIAAKVLPAALCQIFKKC',
    'DSHEKRHHEHRRKFHEKHHSHRGY',
    'WRSLGRTLLRLSHALKPLARRSGW',
    'VTSWSLCTPGCTSPGGGSNCSFCC',
    'VIPFVASVAAEMMHHVYCAASKRC',
    'SPAGCRFCCGCCPNMRGCGVCCRF',
    'GRGREFMSNLKEKLSGVKEKMKNS',
    'FLPVLTGLTPSIVPKLVCLLTKKC',
    'FLPVLAGLTPSIVPKLVCLLTKKC',
    'FFPMLAGVAARVVPKVICLITKKC',
    'DSMGAVKLAKLLIDKMKCEVTKAC',
    'FLPGVLRLVTKVGPAVVCAITRNC',
    'VIVFVASVAAEMMQHVYCAASKKC',
    'FLPAVIRVAANVLPTAFCAISKKC',
    'IDPFVAGVAAEMMQHVYCAASKKC',
    'INPFVAGVAAEMMQHVYCAASKKC',
    'ILPFVAGVAAEMMKHVYCAASKKC',
    'IIPFVAGVAAEMMEHVYCAASKKC',
    'QLPFVAGVACEMCQCVYCAASKKC',
    'ILPFVAGVAAEMMEHVYCAASKKC',
    'ILPFVAGVAAMEMEHVYCAASKKC',
    'FLPAVLLVATHVLPTVFCAITRKC',
    'IPWKLPATFRPVERPFSKPFCRKD',
    'FLPLLAGVVANFLPQIICKIARKC',
    'FLGSLLGLVGKVVPTLFCKISKKC',
    'FIGPVLKIAAGILPTAICKIFKKC',
    'FVGPVLKIAAGILPTAICKIYKKC',
    'FLGPIIKIATGILPTAICKFLKKC',
    'FLPLIASLAANFVPKIFCKITKKC',
    'FLPLIASVAANLVPKIFCKITKKC',
    'FLSTLLKVAFKVVPTLFCPITKKC',
    'KRKCPKTPFDNTPGAWFAHLILGC',
    'FLGLIFHGLVHAGKLIHGLIHRNR',
    'FLPAVIRVAANVLPTVFCAISKKC',
    'FLPAVLRVAAKVVPTVFCLISKKC',
    'FLSTALKVAANVVPTLFCKITKKC',
    'FLPIVAGLAANFLPKIVCKITKKC',
    'FLSTLLNVASNVVPTLICKITKKC',
    'FLSTLLNVASKVVPTLFCKITKKC',
    'FLPMLAGLAANFLPKIVCKITKKC',
    'FIGPVLKMATSILPTAICKGFKKC',
    'FLGPIIKMATGILPTAICKGLKKC',
    'FLPIIAGVAAKVLPKLFCAITKKC',
    'FLPVIAGLAAKVLPKLFCAITKKC',
    'RKGWFKAMKSIAKFIAKEKLKEHL',
    'FLPAVLKVAAHILPTAICAISRRC',
    'FMGTALKIAANVLPAAFCKIFKKC',
    'KLGFENFLVKALKTVMHVPTSPLL',
    'GWLPTFGKILRKAMQLGPKLIQPI',
    'GNGVVLTLTHECNLATWTKKLKCC',
    'ITIPPIVKNTLKKFIKGAVSALMS',
    'FLPGLIKAAVGVGSTILCKITKKC',
    'FLPGLIKAAVGIGSTIFCKISKKC',
    'FLPGLIKVAVGVGSTILCKITKKC',
    'FLPGLIKAAVGIGSTIFCKISRKC',
    'FLPMLAGLAANFLPKIICKITKKC',
    'FLPIVASLAANFLPKIICKITKKC',
    'FWGALAKGALKLIPSLVSSFTKKD',
    'FFPLIAGLAARFLPKIFCSITKRC',
    'VIPFVASVAAEMMQHVYCAASKRC',
    'FFPSIAGLAAKFLPKIFCSITKRC',
    'FLPAVLRVAAKVGPAVFCAITQKC',
    'FLGMLLHGVGHAIHGLIHGKQNVE',
    'NPAGCRFCCGCCPNMIGCGVCCRF',
    'IWSFLIKAATKLLPSLFGGGKKDS',
    'RNGCIVDPRCPYQQCRRPLYCRRR',
    'ILELAGNAARDNKKTRIIPRHLQL',
    'FLPLLAGLAANFLPTIICKIARKC',
    'FLPAIIGMAAKVLPAFLCKITKKC',
    'RRRRRFRRVIRRIRLPKYLTINTE',
    'GNGVLKTISHECNMNTWQFLFTCC',
    'FLPILAGLAANLVPKLICSITKKC',
    'FLGAVLKVAGKLVPAAICKISKKC',
    'FLGALFKVASKLVPAAICSISKKC',
    'FLPVIAGIAANVLPKLFCKLTKRC',
    'FFPIIARLAAKVIPSLVCAVTKKC',
    'KRVNWRKVGRNTALGASYVLSFLG',
    'GHSVDRIPEYFGPPGLPGPVLFYS',
    'FLPLIAGVAAKVLPKIFCAISKKC',
    'SDSVVSDIICTTFCSVTWCQSNCC',
    'FLPLLAGLAANFLPQIICKIARKC',
    'FLGTVLKVAAKVLPAALCQIFKKC',
    'QSHLSMCRYCCCKGNKGCGFCCKF',
    'VFDIIKDAGKQLVAHAMGKIAEKV',
    'VFDIIKDAGRQLVAHAMGKIAEKV',
    'FLPLLAGLAASFLPTIFCKISRKC',
    'FFPIVAGVAAKVLKKIFCTISKKC',
]

# 2. Build a character-to-index mapping
#    The vocabulary is every distinct letter that occurs in the data, sorted
#    alphabetically, so index 0 is the alphabetically first residue present
#    (e.g. "A"). No index is reserved for padding or start/end tokens; real data
#    may need the 20 canonical amino acids plus such special tokens.
unique_amino_acids = sorted(set("".join(amp_sequences)))

char_to_idx = {char: idx for idx, char in enumerate(unique_amino_acids)}
idx_to_char = dict(enumerate(unique_amino_acids))

vocab_size = len(unique_amino_acids)  # e.g., could be 20 if strictly canonical

# 3. Convert sequences to integer arrays
#    The model consumes fixed-length windows, so all sequences must share one
#    length. Check explicitly so a bad dataset fails with a clear message rather
#    than numpy's ragged-array error (or an object array that breaks slicing later).
seq_lengths = {len(seq) for seq in amp_sequences}
if len(seq_lengths) != 1:
    raise ValueError(
        "amp_sequences must be a non-empty list of equal-length sequences; "
        f"found lengths {sorted(seq_lengths)}"
    )

# shape: (num_sequences, seq_length)
encoded_sequences = np.array([[char_to_idx[c] for c in seq] for seq in amp_sequences])

# 4. Prepare training data
#    Next-residue prediction, treated like a time series: at every position t
#    the model sees residue t and must predict residue t+1. Inputs therefore
#    drop each sequence's final residue and targets drop its first residue,
#    so for a length-24 peptide (applied to every sequence at once):
#
#      input : [x_0, x_1, ..., x_22]
#      target: [x_1, x_2, ..., x_23]
X, y = encoded_sequences[:, :-1], encoded_sequences[:, 1:]

# 5. Define the LSTM model
#
# A linear stack (`Sequential`) of:
#   Input     -- integer-encoded prefix of length `seq_input_len`. Declaring it
#                builds the model immediately, so `model.summary()` reports real
#                output shapes and parameter counts.
#   Embedding -- maps each amino-acid index to a dense vector of size
#                `embedding_dim`, letting the model learn which residues behave alike.
#   LSTM      -- `lstm_units` sets the size of the hidden state. With
#                `return_sequences=True` it emits an output at every position,
#                which is needed to predict the next residue at each position.
#   Dense     -- `vocab_size` units with softmax: a probability distribution over
#                the next amino acid at each position.
# Adam drives optimisation. Sparse categorical cross-entropy is used because the
# targets `y` are integer indices rather than one-hot vectors.

seq_input_len = X.shape[1]  # 23 for the 24-residue training peptides

model = Sequential()
# Explicit fixed-length integer input. This replaces `Embedding(..., input_length=23)`:
# that argument is deprecated in Keras 3, and an Input layer fixes the input
# shape in every Keras version.
model.add(Input(shape=(seq_input_len,), dtype="int32"))

# Embedding layer: (vocab_size) distinct amino acid characters -> embedding_dim vectors
embedding_dim = 8
model.add(Embedding(input_dim=vocab_size, output_dim=embedding_dim))

# LSTM layer (unidirectional, one output per position)
lstm_units = 64
model.add(LSTM(lstm_units, return_sequences=True))

# Final Dense layer for classification over the vocabulary
model.add(Dense(vocab_size, activation='softmax'))

optimizer = Adam(learning_rate=0.01)
model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

model.summary()

# 6. Train the model
#    The dataset is small (~150 peptides), so this is primarily an illustrative
#    run and overfitting is likely.
#
#    X: each encoded peptide without its last residue (the model input).
#    y: each encoded peptide without its first residue, i.e. X shifted one
#       position to the LEFT, so y[:, t] is the residue that follows X[:, t].
#    epochs: number of full passes over the training data.
#    batch_size: number of peptides processed per parameter update.
#    `history` keeps the per-epoch loss/accuracy returned by `model.fit`, so
#    overfitting can be inspected after training.
epochs = 50
batch_size = 16
history = model.fit(X, y, epochs=epochs, batch_size=batch_size)

# 7. Generating new sequences
#
# `generate_sequence` grows a peptide one residue at a time. It runs the
# trained model on the current prefix, samples the next residue from the
# predicted distribution, appends it, and repeats until `length` residues
# exist. The example below seeds every peptide with 'F' and prints 20 samples.
def generate_sequence(model, seed_seq, length=24, *, context_length=23, index_to_char=None):
    """
    Generate a new sequence of desired length using the trained model.

    The model emits a next-residue distribution at every input position
    (``return_sequences=True``). Because its LSTM is unidirectional, the
    output at position ``t`` depends only on tokens ``0..t``.

    The whole prefix generated so far is therefore placed at the start of a
    ``context_length`` window, the rest of the window is zero-padded, and
    the prediction is read at the prefix's last position. Feeding only the
    last token behind zero padding would discard the real context. Index 0
    is also a real amino acid, not a padding token, so that would present
    the model with a fabricated prefix.

    :param model: trained model; ``model.predict(x, verbose=0)`` is expected
        to return an array of shape ``(1, context_length, vocab_size)``
    :param seed_seq: non-empty sequence of integer-encoded amino acids
        (starting sequence)
    :param length: desired total length of generated sequence. A seed that
        is already this long or longer is returned decoded, unchanged.
    :param context_length: input window length the model was trained on
        (23 here: 24-residue peptides minus the last residue). Prefixes
        longer than this are truncated to their most recent
        ``context_length`` tokens.
    :param index_to_char: mapping from token index to amino-acid letter;
        defaults to the module-level ``idx_to_char``
    :return: string of amino acids
    :raises ValueError: if ``seed_seq`` is empty
    """
    if index_to_char is None:
        index_to_char = idx_to_char

    generated = [int(token) for token in seed_seq]  # copy; plain ints for dict lookup
    if not generated:
        raise ValueError("seed_seq must contain at least one token")

    for _ in range(length - len(generated)):
        # Most recent tokens that fit into the model's fixed input window.
        prefix = generated[-context_length:]
        window = np.zeros((1, context_length), dtype=np.int32)
        window[0, :len(prefix)] = prefix

        # The zero padding sits *after* the prefix, so the unidirectional
        # LSTM's output at the prefix's last position is unaffected by it.
        probs = model.predict(window, verbose=0)[0, len(prefix) - 1]

        # Softmax output is float32 and may not sum to exactly 1;
        # renormalise in float64 before sampling.
        probs = np.asarray(probs, dtype=np.float64)
        probs /= probs.sum()
        next_idx = np.random.choice(len(probs), p=probs)
        generated.append(int(next_idx))

    # Convert generated integer tokens to string
    return "".join(index_to_char[idx] for idx in generated)

# Example usage: sample 20 peptides, each grown from the single residue 'F'
# (any token that exists in the vocabulary can serve as the seed).
start_residue = 'F'
seed = [char_to_idx[start_residue]]
for _ in range(20):
    new_peptide = generate_sequence(model, seed, length=24)
    print("Generated Peptide:", new_peptide)




Epoch 1/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 21ms/step - accuracy: 0.1121 - loss: 2.8813
Epoch 2/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.1955 - loss: 2.5743
Epoch 3/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.2207 - loss: 2.4862
Epoch 4/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 23ms/step - accuracy: 0.2384 - loss: 2.4447
Epoch 5/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.3113 - loss: 2.2447
Epoch 6/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.3391 - loss: 2.1580
Epoch 7/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.4090 - loss: 2.0077
Epoch 8/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 31ms/step - accuracy: 0.4239 - loss: 1.9185
Epoch 9/50
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━