In [None]:
!pip install Bio --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m278.6/278.6 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from Bio import SeqIO
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read by later cells.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Data loading and parsing

Start by defining some useful utility functions for data parsing

In [None]:
def parse_fasta(file_path) -> tuple:
    """Parse a FASTA file into parallel lists of record ids and sequences.

    Args:
        file_path: Path of the FASTA file to read.

    Returns:
        A tuple ``(ids, sequences)`` of two lists of equal length.
        ``ids[i]`` is ``record.id`` of the i-th record in the file and
        ``sequences[i]`` is its ``record.seq``, stored exactly as returned by
        ``SeqIO.parse`` (it is not converted to ``str``).
    """
    ids, sequences = [], []
    with open(file_path, 'r') as fasta_file:
        for record in SeqIO.parse(fasta_file, "fasta"):
            ids.append(record.id)
            sequences.append(record.seq)
    return ids, sequences

def integer_encode(sequence: str) -> torch.Tensor:
    """Encode a protein sequence as a column vector of integer codes.

    Each character is mapped to its position in the global ``amino_acids``
    list; ``'X'`` is treated as the gap/missing symbol ``'-'``.

    Args:
        sequence: Sequence of amino-acid characters (must support ``replace``
            and iteration over single characters).

    Returns:
        An integer (``torch.long``) tensor of shape ``(len(sequence), 1)``.

    Raises:
        ValueError: If a character is not present in ``amino_acids``.
    """
    sequence = sequence.replace('X', '-')  # X also means missing
    encoding = [amino_acids.index(aa) for aa in sequence]
    # Explicit dtype: without it an empty sequence would produce a float tensor.
    return torch.tensor(encoding, dtype=torch.long).reshape(-1, 1)


def integer_decode(int_seq: torch.tensor) -> str:
    """Turn an integer-encoded sequence back into its amino-acid string.

    The tensor is flattened, and every integer is looked up in the global
    ``amino_acids`` list; the resulting characters are concatenated in order.
    """
    return ''.join(amino_acids[code] for code in int_seq.flatten().tolist())


def onehot_encode(sequence: str) -> torch.tensor:
    """One-hot encode a protein sequence, one row per position.

    ``'X'`` is first replaced by the gap symbol ``'-'``. The characters are
    then arranged as a column vector and passed through the global, already
    fitted ``onehot_encoder``; the encoder's output is returned as a tensor.
    """
    residues = list(sequence.replace('X', '-'))  # X also means missing
    column = np.array(residues).reshape(-1, 1)   # one character per row
    return torch.tensor(onehot_encoder.transform(column))

def onehot_decode(onehot_ary: torch.Tensor) -> str:
    """Decode a one-hot encoded sequence back to a string of amino acids.

    Args:
        onehot_ary: Matrix with one row per sequence position and one column
            per character known to the global ``onehot_encoder``. A tensor, a
            NumPy array or a nested list is accepted.

    Returns:
        The decoded sequence, one character per input row.
    """
    # Hand scikit-learn a plain NumPy array: tensors that require grad or live
    # on an accelerator cannot be converted implicitly.
    matrix = torch.as_tensor(onehot_ary).detach().cpu().numpy()
    char_vec = onehot_encoder.inverse_transform(matrix)  # column vector with character entries
    return ''.join(c[0] for c in char_vec)

Our data consists of multiple sequence alignments (MSAs), which are saved as FASTA files. In particular, we have MSAs for five different "domains" or "families" of proteins. They are the result of applying the MUSCLE alignment algorithm to sequences within each domain curated from previous studies (DSRM: Dias et al. 2017; RIG-like receptor: Mukherjee et al. 2014 and Pugh et al. 2016; CARD: Korithoski et al. 2015). We assume knowledge of the phylogeny relating the sequences within each of the domains.

In [None]:
# Folder holding the alignments. This is a relative path, so it is resolved
# against the current working directory.
# NOTE(review): Drive is mounted at '/content/drive'; this presumably relies on
# the working directory being '/content' — confirm.
project_folder = 'drive/MyDrive/ASR-AE'
train_file1 = f'{project_folder}/Training Data/card1_muscle.fasta'
train_file2 = f'{project_folder}/Training Data/drsm1_muscle.fasta'
train_file3 = f'{project_folder}/Training Data/rd1_muscle.fasta'
# NOTE(review): the next two names are spelled 'mucsle' (not 'muscle') and sit
# outside 'Training Data' — verify against the actual files on Drive.
train_file4 = f'{project_folder}/drsm2_mucsle.fasta' # originally intended for validation
train_file5 = f'{project_folder}/drsm3_mucsle.fasta' # originally intended for testing

# Get data for only the first MSA
ids, seqs = parse_fasta(train_file1)
# Display the first ten sequences as the cell output
seqs[:10]

[Seq('---------------KGDPVDKIKVITK-N---FVDGFVENLLDRDV-INRRNL...---'),
 Seq('---------------KNDPWDVLKNSAM-K---VLKDFCDDLIEQDV-FNQNEI...---'),
 Seq('-----------------DPLKSIETKAT-K---MIKNIFDDLIEQDV-INSNQI...---'),
 Seq('-------------PLQKDSVDTLKSMAK-N---LVGGILSDFKEKNV-IDENYL...---'),
 Seq('------------------------MAGN-----RVNDFIEDLKGKKV-LTKQEL...---'),
 Seq('------------------------MAGN-----RVNDFIEDLKGKKV-LTKQEL...---'),
 Seq('------------------------MAGN-----RVNDFIEDLKGKKV-LTKQEL...---'),
 Seq('------------------------MAGN-----RVNDFIEDLKGKKV-LTKQEL...---'),
 Seq('------------------------MAGN-----RVNDFIEDLKGKKA-LTKQEL...---'),
 Seq('------------------------MAGN-----RVNDFIEDLKGKKA-LTKQEL...---')]

We will need to encode our data. Our neural network will eventually require us to work with one-hot encoded data, so that each sequence becomes an $n_l \times n_c$ matrix, where $n_l$ is the sequence length in the MSA and $n_c = 21$ is the number of distinct characters (the 20 amino acids plus one gap symbol).

In [None]:
# Amino acid alphabet. The position of a character in this string is the
# integer code used by integer_encode / integer_decode.
amino_acids_str = '-ACDEFGHIKLMNPQRSTVWY' # "-" = missing
amino_acids = list(amino_acids_str)

# Instantiate and fit the OneHotEncoder transformation object. The categories
# are given explicitly so the column order follows `amino_acids`.
onehot_encoder = OneHotEncoder(categories=[amino_acids], sparse_output=False)
onehot_encoder.fit(np.array(list(amino_acids_str)).reshape(-1, 1))

# Encode the first sequence in two ways
int_code = integer_encode(seqs[0])
oh_code = onehot_encode(seqs[0])
print(int_code.shape)
print(oh_code.shape)

# Check that both encodings decode back to the same string
print(integer_decode(int_code))
print(onehot_decode(oh_code))

torch.Size([130, 1])
torch.Size([130, 21])
---------------KGDPVDKIKVITK-N---FVDGFVENLLDRDV-INRRNLQKLGNT----IGDIVKG--TQNLFEEFKEQ-SEK-GN---IVMV-----IG-------NPK--KQLSLKL------
---------------KGDPVDKIKVITK-N---FVDGFVENLLDRDV-INRRNLQKLGNT----IGDIVKG--TQNLFEEFKEQ-SEK-GN---IVMV-----IG-------NPK--KQLSLKL------


Now we one-hot encode each sequence in the MSA and stack the resulting $n$ matrices (one per sequence) into a single tensor of shape $n \times n_l \times n_c$.

In [None]:
# One-hot encode every sequence of the alignment
onehot_encoded_sequences = [onehot_encode(seq) for seq in seqs]
# Stack the per-sequence matrices along a new leading dimension
data_tensor = torch.stack(onehot_encoded_sequences)
# Display the shape of the stacked tensor as the cell output
data_tensor.shape

torch.Size([1273, 130, 21])

## Build Dataset class

In [None]:
class MSA_Dataset(Dataset):
    '''
    Dataset class for multiple sequence alignment.

    Each item is a tuple ``(one_hot_matrix, weight, key)`` for one sequence.
    '''

    def __init__(self, seq_msa_binary, seq_weight, seq_keys):
        '''
        seq_msa_binary: a three dimensional tensor.
                        size: [num_of_sequences, length_of_msa, num_amino_acid_types]
                        It is stored as float32.
        seq_weight: one dimensional tensor.
                    size: [num_sequences].
                    Weights for sequences in a MSA.
                    The sum of seq_weight has to be equal to 1 when training latent space models using VAE
        seq_keys: name of sequences in MSA
        '''
        # `super(MSA_Dataset)` without `self` creates an unbound super object and
        # never reaches the parent initialiser; the zero-argument form does.
        super().__init__()
        self.seq_msa_binary = seq_msa_binary.to(torch.float32) # for training
        self.seq_weight = seq_weight
        self.seq_keys = seq_keys

    def __len__(self):
        '''Number of sequences; also checks that weights and keys match that number.'''
        num_sequences = self.seq_msa_binary.shape[0]
        assert(num_sequences == len(self.seq_weight))
        assert(num_sequences == len(self.seq_keys))
        return num_sequences

    def __getitem__(self, idx):
        '''Return (one-hot matrix, weight, key) of the sequence at position idx.'''
        return self.seq_msa_binary[idx,:,:], self.seq_weight[idx], self.seq_keys[idx]

n_seq = len(ids)
# Uniform weights that sum to 1 (one weight per sequence)
wts = torch.ones(n_seq) / n_seq
data = MSA_Dataset(data_tensor, wts, ids)
# Inspect one item (index 100 is a fixed, arbitrary choice, not a random draw)
data[100]

(tensor([[1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]]),
 tensor(0.0008),
 'XP_008018881.1__Chlorocebus_sabaeus')

We see that a data item is a tuple containing the one-hot representation of the sequence, the weight assigned to the sequence (explained later), and the name of the sequence.

## Model

This code is adapted, with a few changes, from the PEVAE paper.

In [None]:
class VAE(nn.Module):
    def __init__(self, nl, nc=21, dim_latent_vars=10, num_hidden_units=[256, 256]):
        """
        Fully connected variational autoencoder for one-hot encoded sequences.

        For now, we keep our model simple with both encoder and decoder having
        only fully connected layers (and the same number of them)

        Our default is that the latent embeddings are dimension 10 and there are
        256 neurons in each of the hidden layers of the encoder and decoder

        nl: length of the sequences in the MSA
        nc: number of character types (columns of the one-hot encoding)
        dim_latent_vars: dimension of the latent space
        num_hidden_units: sizes of the hidden layers, used in this order by
                          both the encoder and the decoder
        """
        super(VAE, self).__init__()

        ## num of amino acid types
        self.nc = nc

        ## length of sequences in the MSA
        self.nl = nl

        ## dimension of input
        self.dim_input = nc * nl

        ## dimension of latent space
        self.dim_latent_vars = dim_latent_vars

        ## num of hidden neurons in encoder and decoder networks
        ## (copied so the instance never aliases the caller's list or the
        ## shared default-argument list)
        num_hidden_units = list(num_hidden_units)
        self.num_hidden_units = num_hidden_units

        ## encoder
        self.encoder_linears = nn.ModuleList()
        self.encoder_linears.append(nn.Linear(self.dim_input, num_hidden_units[0]))
        for i in range(1, len(num_hidden_units)):
            self.encoder_linears.append(nn.Linear(num_hidden_units[i-1], num_hidden_units[i]))
        self.encoder_mu = nn.Linear(num_hidden_units[-1], dim_latent_vars)
        self.encoder_logsigma = nn.Linear(num_hidden_units[-1], dim_latent_vars)

        ## decoder
        self.decoder_linears = nn.ModuleList()
        self.decoder_linears.append(nn.Linear(dim_latent_vars, num_hidden_units[0]))
        for i in range(1, len(num_hidden_units)):
            self.decoder_linears.append(nn.Linear(num_hidden_units[i-1], num_hidden_units[i]))
        self.decoder_linears.append(nn.Linear(num_hidden_units[-1], self.dim_input))

    def encoder(self, x):
        '''
        encoder transforms x into the parameters of q(z|x).

        x: one-hot data of shape (batch, nl, nc). Inputs with additional
           leading dimensions, e.g. (samples, batch, nl, nc), are also
           accepted; only the last two dimensions are merged in that case.
        returns: (mu, sigma), where sigma = exp(encoder_logsigma(h)) is the
                 scale that multiplies the noise in z = mu + sigma*eps.
        '''
        # convert from matrix to vector by concatenating rows (which are one-hot vectors)
        if x.dim() > 3:
            # extra leading dimensions: keep all of them, merge only (nl, nc)
            h = torch.flatten(x, start_dim=-2)
        else:
            h = torch.flatten(x, start_dim=1) # start_dim=1 to maintain batch dimension
        for T in self.encoder_linears:
            h = T(h)
            h = F.relu(h)
        mu = self.encoder_mu(h)
        sigma = torch.exp(self.encoder_logsigma(h))
        return mu, sigma

    def decoder(self, z):
        '''
        decoder transforms latent space z into log-probabilities over the nc
        characters at every position.

        z: tensor whose last dimension is dim_latent_vars
        returns: tensor of shape z.shape[:-1] + (nl, nc); exp of it sums to 1
                 along the last dimension.
        '''
        h = z
        for i in range(len(self.decoder_linears)-1):
            h = self.decoder_linears[i](h)
            h = F.relu(h)
        h = self.decoder_linears[-1](h) #Should now have dimension nc*nl

        # split the flat output into one row of nc logits per position
        fixed_shape = tuple(h.shape[0:-1])
        h = torch.reshape(h, fixed_shape + (-1, self.nc))
        log_p = F.log_softmax(h, dim = -1)

        return log_p

    def _reconstruction_log_prob(self, x, log_p):
        '''
        log p(x|z) per sequence: sum of the log-probabilities of the observed
        characters over all positions AND characters (last two dimensions).
        '''
        x = x.reshape(log_p.shape)  # no-op for (..., nl, nc); also accepts flattened x
        return torch.sum(x*log_p, dim=(-2, -1))

    def compute_weighted_elbo(self, x, weight):
        '''
        Weighted evidence lower bound of a batch, using one sample of z per sequence.

        x: one-hot data of shape (batch, nl, nc)
        weight: tensor of shape (batch,); it is renormalised to sum to 1
                within the batch
        returns: scalar tensor, the weighted sum of the per-sequence elbo
        '''
        ## sample z from q(z|x)
        mu, sigma = self.encoder(x)
        eps = torch.randn_like(sigma)
        z = mu + sigma*eps

        ## compute log p(x|z)
        ## The sum must run over positions and characters so the result has
        ## shape (batch,), matching the KL term below. Summing over characters
        ## only leaves (batch, nl), which cannot be combined with it.
        log_p = self.decoder(z)
        log_PxGz = self._reconstruction_log_prob(x, log_p)

        ## compute elbo
        kl = torch.sum(0.5*(sigma**2 + mu**2 - 2*torch.log(sigma) - 1), -1)
        elbo = log_PxGz - kl
        weight = weight / torch.sum(weight)
        elbo = torch.sum(elbo*weight)

        return elbo

    def compute_elbo_with_multiple_samples(self, x, num_samples):
        '''
        Evidence lower bound is an lower bound of log P(x). Although it is a lower
        bound, we can use elbo to approximate log P(x).
        Using multiple samples to calculate the elbo makes it be a better approximation
        of log P(x).

        x: one-hot data of shape (batch, nl, nc)
        num_samples: number of latent samples drawn per sequence
        returns: tensor of shape (batch,) (double precision), computed without
                 tracking gradients
        '''

        with torch.no_grad():
            # add a leading sample dimension: (num_samples, batch, nl, nc)
            x = x.reshape(x.shape[0], self.nl, self.nc)
            x = x.expand(num_samples, *x.shape)
            mu, sigma = self.encoder(x)
            eps = torch.randn_like(mu)
            z = mu + sigma * eps
            log_2pi = torch.log(2*z.new_tensor(np.pi))
            log_Pz = torch.sum(-0.5*z**2 - 0.5*log_2pi, -1)
            log_p = self.decoder(z)
            log_PxGz = self._reconstruction_log_prob(x, log_p)
            log_Pxz = log_Pz + log_PxGz

            log_QzGx = torch.sum(-0.5*(eps)**2 - 0.5*log_2pi - torch.log(sigma), -1)
            log_weight = (log_Pxz - log_QzGx).double()
            # subtract the per-sequence maximum before exponentiating (log-mean-exp)
            log_weight_max = torch.max(log_weight, 0)[0]
            log_weight = log_weight - log_weight_max
            weight = torch.exp(log_weight)
            elbo = torch.log(torch.mean(weight, 0)) + log_weight_max
            return elbo

In [None]:
# Dimensions of one-hot encoding
# Dimensions of the one-hot encoding: sequence length and number of characters
nl = data_tensor.shape[1]
nc = data_tensor.shape[2]
# For architecture hyper-parameters, we rely on the defaults in the class definition
model = VAE(nl = nl, nc = nc)

Let's check that our model processes data the way we want it to

In [None]:
#Encoding
# Encoding: run a single sequence through the encoder
one_hot_ary = data[100][0]
batch_one_hot_ary = torch.unsqueeze(one_hot_ary, 0)  # add the batch dimension
latent_parameters = model.encoder(batch_one_hot_ary)
# The encoder returns (mu, sigma); sigma is the scale in z = mu + sigma*eps,
# i.e. a standard deviation, not a variance.
print(f"Mean and standard deviation of latent vector: {latent_parameters}")
#Decoding: decode the mean of the latent distribution
mn_z = latent_parameters[0]
recon_log_probs = model.decoder(mn_z)
print(f"Decoded output has shape {recon_log_probs.shape} and is given by")
print(recon_log_probs)
# Drop only the batch dimension so rows are positions and columns are characters
probs = torch.exp(recon_log_probs.squeeze(0))
print("The probability for each amino acid in each position is")
print(probs)
print("Rows should sum to 1: ")
print(torch.sum(probs, dim = 1))

Mean and variance of latent vector: (tensor([[-0.0189,  0.0212, -0.0329,  0.0069, -0.0412,  0.0215,  0.0807,  0.0102,
         -0.0417, -0.0510]], grad_fn=<AddmmBackward0>), tensor([[0.9599, 1.0618, 1.0214, 0.9816, 0.9903, 1.0196, 0.9594, 1.0351, 1.0125,
         0.9293]], grad_fn=<ExpBackward0>))
Decoded output has shape torch.Size([1, 130, 21]) and is given by
tensor([[[-3.1071, -2.9985, -2.9959,  ..., -3.0843, -3.0167, -3.0363],
         [-3.0460, -3.0296, -3.0044,  ..., -3.0338, -3.1036, -3.0152],
         [-3.0653, -2.9674, -3.0949,  ..., -3.0450, -3.0645, -3.1256],
         ...,
         [-3.1468, -3.0576, -3.0766,  ..., -3.0858, -3.0362, -3.0575],
         [-3.0835, -3.0468, -2.9983,  ..., -3.0853, -3.0078, -2.9356],
         [-3.1057, -3.0256, -3.1402,  ..., -3.0402, -3.1947, -3.0470]]],
       grad_fn=<LogSoftmaxBackward0>)
The probability for each amino acid in each position is
tensor([[0.0447, 0.0499, 0.0500,  ..., 0.0458, 0.0490, 0.0480],
        [0.0475, 0.0483, 0.0496,  .

The decoder outputs log-probabilities over the $n_c$ characters at each of the $n_l$ positions, so its output has the same shape as the encoder's input.

In [None]:
# Device configuration
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Training hyperparameters
num_epochs = 30
learning_rate = 1e-3
batch_size = 128

# Batches are drawn from the dataset in a reshuffled order
data_loader = DataLoader(data, batch_size = batch_size, shuffle = True)