In [None]:
# Lakes' model on a single dataset with linear decoder

In [None]:
# Mount Google Drive (Colab only) at /content/drive so files stored there
# can be read and written by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install Bio --quiet
!pip install keras==3.0.0 --upgrade --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m278.6/278.6 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m997.1/997.1 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.0 which is incompatible.[0m[31m
[0m

In [None]:
from Bio import SeqIO

import os
os.environ["KERAS_BACKEND"] = "torch"

import torch
from torch.utils.data import Dataset, DataLoader

import keras
from keras import backend as K, layers, activations

import numpy as np
from sklearn.preprocessing import OneHotEncoder

In [None]:
print(keras.__version__)

3.0.0


In [None]:
print(K.backend())

torch


# Data processing

## Pre-processing

In [None]:
# Global variables: dataset locations
folder_path = 'drive/MyDrive/ae_training'
file_name1 = f'{folder_path}/card1_1273x130.fasta'
file_name2 = f'{folder_path}/drsm1_1376x177.fasta'
file_name3 = f'{folder_path}/rd1_935x221.fasta'
file_name4 = f'{folder_path}/drsm3_718x103_testing.fasta'

# Alphabet: index 0 is the whitespace used for padding, the last symbol is the gap '-'
amino_acids_str = ' ACDEFGHIKLMNPQRSTVWY-'
amino_acids = list(amino_acids_str)

# One-hot encoder with a fixed category order (same order as `amino_acids`)
onehot_encoder = OneHotEncoder(categories=[amino_acids]).fit(
    np.array(amino_acids).reshape(-1, 1)
)

# Hyperparameters
max_len = 221
num_epochs = 30
batch_size = 128
learning_rate = 1e-3

# Device configuration: prefer the GPU when one is visible to torch
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [None]:
# Utility functions

def parse_fasta(file_path) -> list:
    "Read every record of a fasta file and return their sequences as a list of Seq"
    with open(file_path, 'r') as handle:
        return [entry.seq for entry in SeqIO.parse(handle, "fasta")]


def integer_encode(sequence, max_length) -> torch.Tensor:
    """Encode a protein sequence into a column of integer residue codes.

    Each residue is mapped to its position in the module-level `amino_acids`
    list; 'X' is treated as the gap symbol '-'. Sequences shorter than
    `max_length` are right-padded with 0 (the code of the whitespace symbol).
    Longer sequences are returned as-is (neither padded nor truncated).

    Returns:
        Integer tensor of shape (max(len(sequence), max_length), 1).

    Raises:
        ValueError: if the sequence contains a symbol that is not in `amino_acids`.
    """
    sequence = sequence.replace('X', '-')  # X also means missing
    # Symbol -> code table; setdefault keeps the first index of a symbol,
    # which is what list.index would return.
    codes = {}
    for position, symbol in enumerate(amino_acids):
        codes.setdefault(symbol, position)
    encoding = []
    for aa in sequence:
        if aa not in codes:
            raise ValueError(
                f"Unknown residue {aa!r}; expected one of {''.join(amino_acids)!r}"
            )
        encoding.append(codes[aa])
    # Pad the sequence to the specified maximum length
    if len(encoding) < max_length:
        encoding += [0] * (max_length - len(encoding))
    return torch.tensor(encoding).reshape(-1, 1)


def integer_decode(int_seq) -> str:
    "Turn an integer encoded sequence back into its string of amino acids"
    symbols = []
    # Flatten the tensor and look every code up in the alphabet
    for code in int_seq.flatten().tolist():
        symbols.append(amino_acids[code])
    return ''.join(symbols)


def onehot_encode(sequence, max_length) -> torch.tensor:
    "Turn a protein sequence into a whitespace-padded sequence of one-hot vectors"
    # X also means missing; shorter sequences are right-padded with whitespaces
    padded = sequence.replace('X', '-') + ' ' * (max_length - len(sequence))
    # The encoder expects one symbol per row
    column = np.array(list(padded)).reshape(-1, 1)
    dense = onehot_encoder.transform(column).toarray()
    return torch.tensor(dense)  # Whitespace is [1,0,...,0] for now


def onehot_decode(onehot_seq: "torch.Tensor | np.ndarray") -> str:
    """Decode a one-hot encoded sequence back to a string of amino acids.

    Args:
        onehot_seq: 2-D one-hot matrix (one row per position) as a torch tensor
            or NumPy array, with columns in the order the module-level
            `onehot_encoder` was fitted on.

    Returns:
        The decoded sequence, one character per row of `onehot_seq`.
    """
    if isinstance(onehot_seq, torch.Tensor):
        # scikit-learn works on NumPy arrays; tensors that live on the GPU or
        # track gradients cannot be converted implicitly, so do it explicitly.
        onehot_seq = onehot_seq.detach().cpu().numpy()
    original_seq = onehot_encoder.inverse_transform(onehot_seq)
    return ''.join(''.join(c) for c in original_seq)

In [None]:
# Sanity check of the one-hot pipeline on the first dataset. All symbols of the
# alphabet are kept here, including the whitespace padding column (index 0).
sequences = parse_fasta(file_name1)
onehot_encoded_sequences = [onehot_encode(seq, max_len) for seq in sequences]
data = torch.stack(onehot_encoded_sequences)
data.shape

torch.Size([1273, 221, 22])

In [None]:
data[1]  # onehot encoding

tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)

In [None]:
onehot_decode(data[1])

'---------------KNDPWDVLKNSAM-K---VLKDFCDDLIEQDV-FNQNEIKNMGKQ----LSTVKDK--SEDLVKIVTHK-GSQ-IG---DIFVKRV--LM-------AAK--QLHS---------                                                                                           '

In [None]:
# Same sanity check with the integer encoding: one code per position.
# Note that `data` is rebound here and no longer holds the one-hot tensor.
sequences = parse_fasta(file_name1)
int_encoded_sequences = [integer_encode(seq, max_len) for seq in sequences]
data = torch.stack(int_encoded_sequences)
data.shape

torch.Size([1273, 221, 1])

In [None]:
integer_decode(data[1])

'---------------KNDPWDVLKNSAM-K---VLKDFCDDLIEQDV-FNQNEIKNMGKQ----LSTVKDK--SEDLVKIVTHK-GSQ-IG---DIFVKRV--LM-------AAK--QLHS---------                                                                                           '

In [None]:
# Assemble the one-hot data used for modelling. The `[:,:,1:]` slice drops
# column 0 (the whitespace symbol), so padded positions become all-zero vectors.
# Only the third dataset is active; the other two are kept commented out.
# sequences1 = torch.stack([onehot_encode(seq, max_len) for seq in parse_fasta(file_name1)])[:,:,1:]
# sequences2 = torch.stack([onehot_encode(seq, max_len) for seq in parse_fasta(file_name2)])[:,:,1:]
sequences3 = torch.stack([onehot_encode(seq, max_len) for seq in parse_fasta(file_name3)])[:,:,1:]

# data = torch.cat((sequences1, sequences2, sequences3))
data = torch.cat((sequences3,))
data.shape

torch.Size([935, 221, 21])

## Build the dataloaders

Wrap the encoded sequences in PyTorch `Dataset` classes and serve them in batches through PyTorch `DataLoader`s.

Ideally, the model (a `keras.Model`) should also behave as a PyTorch `Module` when Keras runs on the PyTorch backend.

In [None]:
class TrainDataset(Dataset):
    """One-hot encoded training sequences read from one or more fasta files.

    Args:
        file_names: a path or an iterable of paths to fasta files. Defaults to
            `file_name3`, the single dataset used in this notebook.
        encoding: callable `(sequence, max_length) -> tensor` applied to every
            sequence. Defaults to `onehot_encode`.

    Every sequence is encoded to `max_len` positions and column 0 (the
    whitespace symbol) is dropped, so padded positions are all-zero vectors.
    """
    def __init__(self, file_names=None, encoding=None):
        self.encoding = onehot_encode if encoding is None else encoding

        if file_names is None:
            file_names = (file_name3,)
        elif isinstance(file_names, str):
            # A single path would otherwise be iterated character by character
            file_names = (file_names,)

        per_file = [
            torch.stack([self.encoding(seq, max_len) for seq in parse_fasta(path)])[:, :, 1:]
            for path in file_names
        ]
        # Concatenate along the sample axis
        self.data = torch.cat(per_file)

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


class TestDataset(Dataset):
    """One-hot encoded evaluation sequences read from one or more fasta files.

    Args:
        file_names: a path or an iterable of paths to fasta files. Defaults to
            `file_name3`, i.e. the same file `TrainDataset` uses by default;
            pass e.g. `file_name4` to evaluate on a different file.
        encoding: callable `(sequence, max_length) -> tensor` applied to every
            sequence. Defaults to `onehot_encode`.

    Every sequence is encoded to `max_len` positions and column 0 (the
    whitespace symbol) is dropped, so padded positions are all-zero vectors.
    """
    def __init__(self, file_names=None, encoding=None):
        self.encoding = onehot_encode if encoding is None else encoding

        if file_names is None:
            file_names = (file_name3,)
        elif isinstance(file_names, str):
            # A single path would otherwise be iterated character by character
            file_names = (file_names,)

        per_file = [
            torch.stack([self.encoding(seq, max_len) for seq in parse_fasta(path)])[:, :, 1:]
            for path in file_names
        ]
        # Concatenate along the sample axis
        self.data = torch.cat(per_file)

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

In [None]:
# Create torch Datasets
# NOTE: with their defaults both datasets read `file_name3`, so the
# "validation" data is the same set of sequences as the training data.
train_dataset = TrainDataset()
val_dataset = TestDataset()

# Create DataLoaders for the Datasets
# Training batches are reshuffled on every pass; validation order is kept fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Build the LSTM Variational Autoencoder (VAE)

Rewrite the code to build the model: **Protein sequence LSTM Variational AutoEncoder (VAE)**

Write it in Keras 3 as a subclass of `keras.Model` so that it can handle variable-length sequences (and missing characters).

Sources:
- VAE in Keras 3: https://keras.io/examples/generative/vae/
- LSTM Autoencoder: https://machinelearningmastery.com/lstm-autoencoders/
- Variable length: https://machinelearningmastery.com/handle-missing-timesteps-sequence-prediction-problems-python/


## Sampling layer

In [None]:
class Sampling(layers.Layer):
    "Draws z, the latent vector encoding a sequence, from (z_mean, z_log_var)."
    def call(self, inputs):
        mean, log_var = inputs
        # One standard-normal draw per (sample, latent dimension)
        shape = keras.ops.shape(mean)
        noise = keras.random.normal(shape=(shape[0], shape[1]))
        # exp(0.5 * log_var) is the standard deviation
        std = keras.ops.exp(0.5 * log_var)
        return mean + std * noise

## Encoder

In [None]:
latent_dim = 64
# input_shape = (max_len, 1)  # (max_len, 1) for integer encoding
input_shape = (max_len, 21)  # (max_len, 21) for one-hot encoding

# The input shape is declared once, on keras.Input. The LSTM layers infer it
# from the tensor they are called on, so passing `input_shape=` to them is
# redundant and only triggers a Keras warning.
encoder_inputs = keras.Input(shape=input_shape)
# Timesteps whose feature vector is entirely 0.0 (the padded positions, once the
# whitespace column has been dropped) are masked and skipped by the LSTMs.
x = layers.Masking(mask_value=0.0)(encoder_inputs)
# Two independent LSTMs summarise the sequence into the posterior parameters.
# NOTE(review): with activation='relu' the LSTM outputs are non-negative, so
# z_mean >= 0 and z_log_var >= 0 (variance >= 1) - confirm this is intended.
z_mean = layers.LSTM(latent_dim, activation='relu', name="z_mean")(x)
z_log_var = layers.LSTM(latent_dim, activation='relu', name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

  super().__init__(**kwargs)


## Decoder

In [None]:
# Alternative LSTM decoder (disabled):
# latent_inputs = keras.Input(shape=(latent_dim,))
# x = layers.RepeatVector(max_len)(latent_inputs)
# x = layers.LSTM(latent_dim, activation='relu', return_sequences=True)(x)
# x = layers.TimeDistributed(layers.Dense(21))(x)
# decoder_outputs = layers.Softmax()(x)
# decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
# decoder.summary()

# Linear decoder: the Dense layers have no activation, so the stack is a purely
# linear map from the latent vector to (max_len, 21) logits. Sizes are derived
# from `max_len` so the output always matches the encoder's input shape.
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(128)(latent_inputs)
x = layers.Dense(max_len)(x)
x = layers.Dense(max_len * 21)(x)
x = layers.Reshape((max_len, 21))(x)
# Softmax over the last axis: a distribution over the 21 symbols at each position
decoder_outputs = layers.Softmax()(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

## VAE Model

In [None]:
@keras.saving.register_keras_serializable()
class VAE(keras.Model):
    """Variational autoencoder that pairs an encoder with a decoder.

    Args:
        encoder: model mapping a batch of sequences to `[z_mean, z_log_var, z]`.
        decoder: model mapping latent vectors back to sequence space.
        **kwargs: forwarded to `keras.Model`.

    Three `keras.metrics.Mean` trackers (total, reconstruction and KL loss) are
    created and exposed through `metrics`; this class does not update them
    itself.
    """
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def call(self, inputs):
        "Encode `inputs`, then decode the sampled latent vector z."
        # Defining the forward pass makes the model itself callable
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    @property
    def metrics(self):
        "Loss trackers, in the order: total, reconstruction, KL."
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

# Training

In [None]:
weights_path = f'{folder_path}/vae-simple.weights.h5'

model = VAE(encoder, decoder).to(device)
# Resume from a previous checkpoint when one exists; on a fresh run (no file
# yet) training starts from the current weights instead of failing here.
if os.path.exists(weights_path):
    model.load_weights(weights_path, skip_mismatch=False)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Start training
for epoch in range(num_epochs):
    for i, x in enumerate(train_dataloader):
        # Forward pass
        x = x.to(device)
        z_mean, z_log_var, z = model.encoder(x)
        reconstruction = model.decoder(z)

        # Reconstruction loss: cross-entropy per position, summed over the
        # sequence axis and averaged over the batch.
        # NOTE(review): the decoder ends in a softmax over symbols, so a
        # categorical cross-entropy may be a better match than the binary one
        # used here - confirm.
        reconstruction_loss = keras.ops.mean(
            keras.ops.sum(
                keras.losses.binary_crossentropy(x, reconstruction),
                axis=1,
            )
        )
        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
        # summed over latent dimensions and averaged over the batch.
        kl_loss = -0.5 * (1 + z_log_var - keras.ops.square(z_mean) - keras.ops.exp(z_log_var))
        kl_loss = keras.ops.mean(keras.ops.sum(kl_loss, axis=1))

        # Backprop and optimize
        loss = reconstruction_loss + kl_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                   .format(epoch+1, num_epochs, i+1, len(train_dataloader), reconstruction_loss.item(), kl_loss.item()))

model.save_weights(weights_path, overwrite=True)

# Testing

In [None]:
# Wrap the same `encoder` / `decoder` objects in a new VAE and load the weights
# stored on disk, so testing uses the saved checkpoint.
model = VAE(encoder, decoder).to(device)
model.load_weights(f'{folder_path}/vae-simple.weights.h5', skip_mismatch=False)

In [None]:
# Evaluation only: gradients are not needed
with torch.no_grad():
    for i, x in enumerate(val_dataloader):
        x = x.to(device)
        z_mean, z_log_var, z = model.encoder(x)
        reconstruction = model.decoder(z)

        # Cross-entropy per position, summed over the sequence, averaged over the batch
        per_position_bce = keras.losses.binary_crossentropy(x, reconstruction)
        reconstruction_loss = keras.ops.mean(keras.ops.sum(per_position_bce, axis=1))

        message = "Step [{}/{}], Reconst Loss: {:.4f}"
        print(message.format(i+1, len(val_dataloader), reconstruction_loss.item()))

Step [1/8], Reconst Loss: 19.7697
Step [2/8], Reconst Loss: 12.9240
Step [3/8], Reconst Loss: 14.2191
Step [4/8], Reconst Loss: 11.9006
Step [5/8], Reconst Loss: 14.1742
Step [6/8], Reconst Loss: 11.6779
Step [7/8], Reconst Loss: 10.4778
Step [8/8], Reconst Loss: 11.4295


In [None]:
# Reload the third dataset with all symbols kept (whitespace column included),
# so single sequences can be decoded back to text with `onehot_decode`.
sequences = parse_fasta(file_name3)
onehot_encoded_sequences = [onehot_encode(seq, max_len) for seq in sequences]
data = torch.stack(onehot_encoded_sequences)
data.shape

torch.Size([935, 221, 22])

In [None]:
data[0]

tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], dtype=torch.float64)

In [None]:
origin = onehot_decode(data[0])
origin

'-FANSLYKLNCVG-CSTTFCMSSD-I-K-KVYS---NYMAFDPAAW-------------------------------------QFFTV--ESK--KKKPNSYLSEDT--Q--PLSIL--KC-----A-K--C--VTTVIGKA---YKMRGVY------LPQIDVKSVFFVEE----NS-SE----------------SKT--AKKWSSVEQELFYV-GEA-'

In [None]:
# Drop the whitespace column and add a batch axis; `input_shape` is the
# (max_len, 21) shape the encoder was built with, so no sizes are hard-coded.
x = data[0][:, 1:].reshape(1, *input_shape)
# Inference only: skip building the autograd graph
with torch.no_grad():
    z_mean, z_log_var, z = model.encoder(x)
    # Decode the posterior mean rather than a random sample
    reconst = model.decoder(z_mean)
reconst.reshape(input_shape)

tensor([[1.3024e-06, 1.2420e-06, 1.1734e-06,  ..., 1.0581e-06, 9.6972e-07,
         9.9934e-01],
        [6.3325e-03, 8.3627e-07, 6.2426e-02,  ..., 8.7345e-04, 1.5985e-02,
         2.9862e-01],
        [1.3382e-02, 8.1546e-04, 1.0301e-01,  ..., 8.7649e-07, 1.7583e-03,
         2.8733e-01],
        ...,
        [5.6290e-02, 8.3085e-07, 3.0076e-01,  ..., 8.6538e-07, 9.2334e-07,
         4.5644e-02],
        [5.0494e-03, 6.4836e-07, 6.7131e-07,  ..., 6.6579e-07, 2.5962e-03,
         4.4311e-02],
        [7.2389e-07, 6.5524e-07, 6.7681e-07,  ..., 5.7198e-07, 6.6830e-07,
         9.9426e-01]], device='cuda:0', grad_fn=<ViewBackward0>)

In [None]:
# Per-position symbol probabilities as a NumPy array of shape (1, positions, symbols)
reconst_probs = reconst.cpu().detach().numpy()
num_symbols = reconst_probs.shape[-1]
# Keep only the most likely symbol at each position, as a one-hot matrix
reconst_onehot = keras.utils.to_categorical(
    np.argmax(reconst_probs, axis=2), num_symbols
).reshape(-1, num_symbols)
# Re-insert an all-zero whitespace column so the layout matches `onehot_encoder`
reconst_onehot = np.hstack((np.zeros((reconst_onehot.shape[0], 1)), reconst_onehot))
# NOTE: `reconst` is rebound from the decoder output to the decoded string
reconst = onehot_decode(reconst_onehot)

In [None]:
reconst

'--NPSLVKLLCKN-CKVLVCSGSD-I-R-VIEGM--HHVNVNPAFK-------------------------------------ELYIV--REN--KPLQKKF--ADY--E--PNGEI--IC-----K-N--C-------GQD---WGIMMVY-KGLD-LPCLKIKN-FVVET----PT--G---------------KKQY---KKWKEVP---FTF-PDF-'