In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install Bio --quiet
!pip install keras==3.0.0 --upgrade --quiet

In [3]:
from Bio import SeqIO

import os
os.environ["KERAS_BACKEND"] = "torch"

import torch
from torch.utils.data import Dataset, DataLoader

import keras
from keras import backend as K, layers, activations

import numpy as np
from sklearn.preprocessing import OneHotEncoder

In [4]:
print(keras.__version__)

3.0.0


In [5]:
print(K.backend())

torch


# Data processing

## Pre-processing

In [6]:
# Global variables
folder_path = 'drive/MyDrive/ae_training'
file_name1 = f'{folder_path}/card1_1273x130.fasta'
file_name2 = f'{folder_path}/drsm1_1376x177.fasta'
file_name3 = f'{folder_path}/rd1_935x221.fasta'
file_name4 = f'{folder_path}/drsm3_718x103_testing.fasta'

amino_acids_str = ' ACDEFGHIKLMNPQRSTVWY-'
amino_acids = [' ', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
               'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '-']

onehot_encoder = OneHotEncoder(categories=[amino_acids])
onehot_encoder.fit(np.array(list(amino_acids_str)).reshape(-1, 1))

# Hyperparameters
max_len = 221
num_epochs = 30
batch_size = 128
learning_rate = 1e-3

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [7]:
# Utility functions

def parse_fasta(file_path) -> list:
    "Parse a fasta file into an array of Seq"
    sequences = []
    with open(file_path, 'r') as fasta_file:
        for record in SeqIO.parse(fasta_file, "fasta"):
            sequences.append(record.seq)
    return sequences


def integer_encode(sequence, max_length) -> torch.tensor:
    "Encode a protein sequence into a sequence of integers"
    sequence = sequence.replace('X', '-')  # X also means missing
    encoding = [amino_acids.index(aa) for aa in sequence]
    # Pad the sequence to the specified maximum length
    if len(encoding) < max_length:
        encoding += [0] * (max_length - len(encoding))
    return torch.tensor(encoding).reshape(-1, 1)


def integer_decode(int_seq) -> str:
    "Decode an integer encoded sequence back to a sequence of amino acids"
    # Convert the torch tensor to a list of integers
    encoded_list = int_seq.flatten().tolist()
    # Decode each integer back to the corresponding amino acid
    decoded_sequence = ''.join([amino_acids[i] for i in encoded_list])
    return decoded_sequence


def onehot_encode(sequence, max_length) -> torch.tensor:
    "Encode a protein sequence into a sequence of one-hot vectors"
    sequence = sequence.replace('X', '-')  # X also means missing
    # Pad the sequence with whitespaces
    padding = ' ' * (max_length - len(sequence))
    sequence += padding
    protein_sequence_array = np.array(list(sequence)).reshape(-1, 1)
    one_hot_encoded_sequence = onehot_encoder.transform(protein_sequence_array)
    one_hot_encoded_array = one_hot_encoded_sequence.toarray()
    return torch.tensor(one_hot_encoded_array)  # Whitespace is [1,0,...,0] for now


def onehot_decode(onehot_seq: torch.tensor) -> str:
    "Decode a one-hot encoded sequence back to a sequence of amino acids"
    original_seq = onehot_encoder.inverse_transform(onehot_seq)
    s = [''.join(c) for c in original_seq]
    return ''.join(s)

In [8]:
sequences = parse_fasta(file_name1)
onehot_encoded_sequences = [onehot_encode(seq, max_len) for seq in sequences]
data = torch.stack(onehot_encoded_sequences)
data.shape

torch.Size([1273, 221, 22])

In [9]:
data[1]  # onehot encoding

tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)

In [10]:
onehot_decode(data[1])

'---------------KNDPWDVLKNSAM-K---VLKDFCDDLIEQDV-FNQNEIKNMGKQ----LSTVKDK--SEDLVKIVTHK-GSQ-IG---DIFVKRV--LM-------AAK--QLHS---------                                                                                           '

In [11]:
sequences = parse_fasta(file_name1)
int_encoded_sequences = [integer_encode(seq, max_len) for seq in sequences]
data = torch.stack(int_encoded_sequences)
data.shape

torch.Size([1273, 221, 1])

In [12]:
integer_decode(data[1])

'---------------KNDPWDVLKNSAM-K---VLKDFCDDLIEQDV-FNQNEIKNMGKQ----LSTVKDK--SEDLVKIVTHK-GSQ-IG---DIFVKRV--LM-------AAK--QLHS---------                                                                                           '

In [13]:
sequences1 = torch.stack([onehot_encode(seq, max_len) for seq in parse_fasta(file_name1)])[:,:,1:]
sequences2 = torch.stack([onehot_encode(seq, max_len) for seq in parse_fasta(file_name2)])[:,:,1:]
sequences3 = torch.stack([onehot_encode(seq, max_len) for seq in parse_fasta(file_name3)])[:,:,1:]

data = torch.cat((sequences1, sequences2, sequences3))
# data.shape
data[1]

tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)

## Build the dataloaders

Build the training dataset as in PyTorch `DataLoader`.

Ideally, the model (as `keras.Model`) should be instantiated as a PyTorch `Module` in PyTorch backend.

In [14]:
class TrainDataset(Dataset):
    def __init__(self):
        self.encoding = onehot_encode

        sequences1 = torch.stack([self.encoding(seq, max_len) for seq in parse_fasta(file_name1)])[:,:,1:]
        sequences2 = torch.stack([self.encoding(seq, max_len) for seq in parse_fasta(file_name2)])[:,:,1:]
        sequences3 = torch.stack([self.encoding(seq, max_len) for seq in parse_fasta(file_name3)])[:,:,1:]
        self.data = torch.cat((sequences1, sequences2, sequences3))

    def __getitem__(self,idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


class TestDataset(Dataset):
    def __init__(self):
        self.data = torch.stack([onehot_encode(seq, max_len) for seq in parse_fasta(file_name4)])[:,:,1:]

    def __getitem__(self,idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

In [15]:
# Create torch Datasets
train_dataset = TrainDataset()
val_dataset = TestDataset()

# Create DataLoaders for the Datasets
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Build the LSTM Variational Autoencoder (VAE)

Rewrite the code to build the model: **Protein sequence LSTM Variational AutoEncoder (VAE)**

Write it in Keras 3 as subclass of `keras.Model` to handle variable length sequences (and missing characters).

Sources:
- VAE in Keras 3: https://keras.io/examples/generative/vae/
- LSTM Autoencoder: https://machinelearningmastery.com/lstm-autoencoders/
- Variable length: https://machinelearningmastery.com/handle-missing-timesteps-sequence-prediction-problems-python/


## Sampling layer

In [16]:
class Sampling(layers.Layer):
    "Uses (z_mean, z_log_var) to sample z, the vector encoding a sequence."
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = keras.ops.shape(z_mean)[0]
        dim = keras.ops.shape(z_mean)[1]
        epsilon = keras.random.normal(shape=(batch, dim))
        return z_mean + keras.ops.exp(0.5 * z_log_var) * epsilon

## Encoder

In [17]:
latent_dim = 64
# input_shape = (max_len, 1)  # (max_len, 1) for integer encoding
input_shape = (max_len, 21)  # (max_len, 21) for one-hot encoding

encoder_inputs = keras.Input(shape=input_shape)
x = layers.Masking(mask_value=0.0)(encoder_inputs)
z_mean = layers.LSTM(latent_dim, activation='relu',
                     input_shape=input_shape, name="z_mean")(x)
z_log_var = layers.LSTM(latent_dim, activation='relu',
                        input_shape=input_shape, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

  super().__init__(**kwargs)


## Decoder

In [18]:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.RepeatVector(max_len)(latent_inputs)
x = layers.LSTM(latent_dim, activation='relu', return_sequences=True)(x)
decoder_outputs = layers.TimeDistributed(layers.Dense(21))(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

## VAE Model

In [19]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

# Training

In [20]:
model = VAE(encoder, decoder).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Start training
for epoch in range(num_epochs):
    for i, x in enumerate(train_dataloader):
        # Forward pass
        x = x.to(device)
        z_mean, z_log_var, z = model.encoder(x)
        reconstruction = model.decoder(z)

        # Compute reconstruction loss and kl divergence
        # x = activations.sigmoid(x / 4)
        # reconstruction = activations.sigmoid(reconstruction / 4)
        # print(reconstruction)

        reconstruction_loss = keras.ops.mean(
            keras.ops.sum(
                keras.losses.binary_crossentropy(x, reconstruction),
                axis=1,
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - keras.ops.square(z_mean) - keras.ops.exp(z_log_var))
        kl_loss = keras.ops.mean(keras.ops.sum(kl_loss, axis=1))

        # Backprop and optimize
        loss = reconstruction_loss + kl_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                   .format(epoch+1, num_epochs, i+1, len(train_dataloader), reconstruction_loss.item(), kl_loss.item()))

Epoch[1/30], Step [10/28], Reconst Loss: 796.1511, KL Div: 0.1215
Epoch[1/30], Step [20/28], Reconst Loss: 645.0278, KL Div: 0.1289
Epoch[2/30], Step [10/28], Reconst Loss: 536.7051, KL Div: 0.1370
Epoch[2/30], Step [20/28], Reconst Loss: 489.1587, KL Div: 0.1347
Epoch[3/30], Step [10/28], Reconst Loss: 483.3050, KL Div: 0.1339
Epoch[3/30], Step [20/28], Reconst Loss: 415.3433, KL Div: 0.1350
Epoch[4/30], Step [10/28], Reconst Loss: 341.9256, KL Div: 0.1360
Epoch[4/30], Step [20/28], Reconst Loss: 348.0709, KL Div: 0.1344
Epoch[5/30], Step [10/28], Reconst Loss: 277.3434, KL Div: 0.1289
Epoch[5/30], Step [20/28], Reconst Loss: 309.0907, KL Div: 0.1323
Epoch[6/30], Step [10/28], Reconst Loss: 303.1607, KL Div: 0.1304
Epoch[6/30], Step [20/28], Reconst Loss: 259.0059, KL Div: 0.1386
Epoch[7/30], Step [10/28], Reconst Loss: 269.4931, KL Div: 0.1342
Epoch[7/30], Step [20/28], Reconst Loss: 254.0221, KL Div: 0.1338
Epoch[8/30], Step [10/28], Reconst Loss: 210.8120, KL Div: 0.1301
Epoch[8/30

In [21]:
# Save and load only the model parameters (recommended).
torch.save(model.state_dict(), f'{folder_path}/params.ckpt')

In [None]:
# model = VAE(encoder, decoder).to(device)
# model.load_state_dict(torch.load(f'{folder_path}/params.ckpt'))