Code is adapted from: https://github.com/OATML-Markslab/EVE

In [None]:
import os
import re
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
import torch.nn as nn
import time
import tqdm
from scipy.special import erfinv
from sklearn.model_selection import train_test_split
import torch.optim as optim
import torch.backends.cudnn as cudnn
import statistics
from tqdm import notebook


In [None]:
# Hyper-parameters for the notebook, split into the three groups that the
# model and training code read from.
model_params = {
    # Passed to VAE_MLP_encoder (together with seq_len / alphabet_size, which
    # VAE_model fills in from the data).
    "encoder_parameters": {
        "hidden_layers_sizes": [2000, 1000, 300],
        "z_dim": 2,
        "convolve_input": False,
        "convolution_input_depth": 40,
        "nonlinear_activation": "relu",
        "dropout_proba": 0.0,
    },
    # Passed to the decoder; "bayesian_decoder" selects which decoder class
    # VAE_model instantiates.
    "decoder_parameters": {
        "hidden_layers_sizes": [300, 1000, 2000],
        "z_dim": 2,
        "bayesian_decoder": False,
        "first_hidden_nonlinearity": "relu",
        "last_hidden_nonlinearity": "relu",
        "dropout_proba": 0.1,
        "convolve_output": True,
        "convolution_output_depth": 40,
        "include_temperature_scaler": True,
        "include_sparsity": False,
        "num_tiles_sparsity": 0,
        "logit_sparsity_p": 0,
    },
    # Optimisation / logging settings; "calc_weights" is also forwarded to
    # get_data to switch sequence re-weighting on or off.
    "training_parameters": {
        "num_training_steps": 150000,
        "learning_rate": 1e-4,
        "batch_size": 256,
        "annealing_warm_up": 0,
        "kl_latent_scale": 1.0,
        "kl_global_params_scale": 1.0,
        "l2_regularization": 0.0,
        "use_lr_scheduler": False,
        "use_validation_set": False,
        "validation_set_pct": 0.2,
        "validation_freq": 1000,
        "log_training_info": False,
        "log_training_freq": 1000,
        "save_model_params_freq": 500000,
        "model_checkpoint_location": ".",
        "calc_weights": False,
    },
}

In [None]:
# FASTA parser requires Biopython
try:
    from Bio import SeqIO
except:
    !pip install biopython
    from Bio import SeqIO
    
# Retrieve protein alignment file
if not os.path.exists('BLAT_ECOLX_1_b0.5_labeled.fasta'):
    !wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_1_b0.5_labeled.fasta
        
# Retrieve file with experimental measurements
if not os.path.exists('BLAT_ECOLX_Ranganathan2015.csv'):
    !wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_Ranganathan2015.csv
        
# Options
batch_size = 16

In [None]:
# One-letter symbol alphabet. A symbol's position in this string is its
# integer code: the first 20 entries are the amino acids, followed by
# 'X', 'Z' and the gap symbol '-'.
aa1 = "ACDEFGHIKLMNPQRSTVWYXZ-"

# Mapping from amino acids to integers (the inverse lookup of `aa1`).
aa1_to_index = {symbol: code for code, symbol in enumerate(aa1)}

# Phylum labels kept as separate classes; every other label is folded
# into 'Other' when the alignment is parsed.
phyla = [
    'Acidobacteria',
    'Actinobacteria',
    'Bacteroidetes',
    'Chloroflexi',
    'Cyanobacteria',
    'Deinococcus-Thermus',
    'Firmicutes',
    'Fusobacteria',
    'Proteobacteria',
    'Other',
]

def convert_to_one_hot_encoding(seq_list):
    """One-hot encode a single sequence over the full `aa1_to_index` alphabet.

    The input is converted with ``str()``, upper-cased, and '.' is treated as
    the gap symbol '-', before each character is looked up in `aa1_to_index`.

    Args:
        seq_list: The sequence to encode (a string, or any object whose
            ``str()`` is the sequence).

    Returns:
        A float tensor of shape ``(1, len(sequence), len(aa1_to_index))``
        with a single 1 per position.

    Raises:
        KeyError: If the sequence contains a character that is not in
            `aa1_to_index`.
    """
    letters = str(seq_list).upper().replace('.', '-')
    # Build the index tensor as int64 explicitly: F.one_hot only accepts long
    # tensors, and the platform default integer is not always 64-bit.
    indices = torch.tensor([aa1_to_index[aa] for aa in letters], dtype=torch.long)
    # Fixing num_classes gives every sequence the same alphabet depth without
    # a separate zero-padding step, and also makes an empty sequence valid.
    one_hot = F.one_hot(indices, num_classes=len(aa1_to_index)).float()
    # Add the leading batch dimension.
    return one_hot.unsqueeze(0)

def get_data(data_filename, calc_weights=False, weights_similarity_threshold=0.8):
    """Create the dataset tensors from a labelled FASTA alignment.

    Each record's sequence is upper-cased, '.' is treated as the gap symbol
    '-', and the characters are mapped to integers through `aa1_to_index`.
    The class label is the first ``[...]`` group in the record description;
    labels that are not listed in `phyla` (or records without a label) are
    assigned to 'Other'.

    Args:
        data_filename: Path of the FASTA file, parsed with ``SeqIO.parse``.
        calc_weights: Sequence re-weighting is computed unless this is
            exactly ``False``; otherwise every sequence gets weight 1.
        weights_similarity_threshold: Two sequences count as neighbours when
            the fraction of matching standard residues (relative to the
            number of standard residues of the query sequence) is strictly
            greater than this value.

    Returns:
        A 7-tuple ``(seqs, seqs_letters, one_hot, phyla_idx, weights, neff,
        phyla_lookup_table)``:

        - seqs: integer tensor, one row per sequence.
        - seqs_letters: numpy array with the same layout holding the letters.
        - one_hot: int64 tensor ``(n_sequences, seq_len, len(aa1_to_index))``
          placed on CUDA when available; positions whose symbol index is
          above 19 are all-zero.
        - phyla_idx: tensor of class indices into `phyla_lookup_table`.
        - weights: numpy array with one weight per sequence.
        - neff: sum of the weights.
        - phyla_lookup_table: sorted unique labels (from ``np.unique``).

    Raises:
        KeyError: If a sequence contains a symbol missing from `aa1_to_index`.
    """
    label_re = re.compile(r'\[([^\]]*)\]')
    seqs = []
    seqs_letters = []
    labels = []
    for record in SeqIO.parse(data_filename, "fasta"):
        # Normalise the sequence text once and reuse it for both encodings.
        letters = str(record.seq).upper().replace('.', '-')
        seqs_letters.append(np.array(list(letters)))
        seqs.append(np.array([aa1_to_index[aa] for aa in letters]))

        label_match = label_re.search(record.description)
        # A description without a [label] is treated like an uncommon label.
        label = label_match.group(1) if label_match else 'Other'
        # Only use most common classes
        if label not in phyla:
            label = 'Other'
        labels.append(label)
    seqs = torch.from_numpy(np.vstack(seqs))
    seqs_letters = np.vstack(seqs_letters)
    labels = np.array(labels)

    phyla_lookup_table, phyla_idx = np.unique(labels, return_inverse=True)

    # One-hot encode in two halves and convert each to bool straight away:
    # F.one_hot cannot emit bool directly, and materialising the whole int64
    # encoding at once ran out of memory on Colab.
    # num_classes is fixed so that both halves always share the same depth,
    # even if one half happens not to contain the highest-indexed symbol
    # (and so that an empty half is accepted).
    num_classes = len(aa1_to_index)
    half = len(seqs) // 2
    one_hot = torch.cat([
        F.one_hot(seqs[:half].long(), num_classes=num_classes).bool(),
        F.one_hot(seqs[half:].long(), num_classes=num_classes).bool(),
    ])
    # Blank out every position holding a non-standard symbol (index > 19:
    # 'X', 'Z' and the gap), so they contribute nothing to the encoding.
    one_hot[seqs > 19] = 0

    # `is not False` is kept on purpose: any value other than the literal
    # False (including None) has always enabled the weight computation.
    if calc_weights is not False:
        flat_one_hot = one_hot.flatten(1).float()
        weight_batch_size = 1000
        weights = []
        # Process the queries in batches to bound the size of the
        # (batch, n_sequences) similarity matrix.
        for start in range(0, seqs.size(0), weight_batch_size):
            stop = start + weight_batch_size
            x = flat_one_hot[start:stop]
            # Number of matching standard residues against every sequence.
            similarities = torch.mm(x, flat_one_hot.T)
            # Number of standard residues in each query sequence.
            lengths = (seqs[start:stop] <= 19).sum(1).unsqueeze(-1)
            # Weight = 1 / (number of sequences above the similarity threshold).
            w = 1.0 / (similarities / lengths).gt(weights_similarity_threshold).sum(1).float()
            weights.append(w)
        weights = torch.cat(weights).numpy()
    else:
        weights = np.ones(seqs.shape[0])
    neff = weights.sum()

    # Convert the bool encoding to integers and move it to the compute device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    one_hot = one_hot.long().to(device)

    return seqs, seqs_letters, one_hot, torch.from_numpy(phyla_idx), weights, neff, phyla_lookup_table



# Parse the alignment once; sequence re-weighting follows the training config.
(dataset, dataset_letters, dataset_one_hot, phyla_idx,
 weights, neff, phyla_lookup_table) = get_data(
    'BLAT_ECOLX_1_b0.5_labeled.fasta',
    calc_weights=model_params["training_parameters"]['calc_weights'],
)
print(dataset.shape[0])

# Pair each sequence (one-hot and integer-encoded form) with its phylum index.
dataset_one_hot_tensor = torch.utils.data.TensorDataset(dataset_one_hot, phyla_idx)
dataset_tensor = torch.utils.data.TensorDataset(dataset, phyla_idx)

dataloader = torch.utils.data.DataLoader(
    dataset_one_hot_tensor,
    batch_size=batch_size,
    shuffle=True,
)

# print(phyla_lookup_table)
print(weights)

7844
[1. 1. 1. ... 1. 1. 1.]


In [None]:
def read_experimental_data(filename, alignment_data, measurement_col_name = '2500', sequence_offset=0):
    '''Load the experimental measurements and index them by alignment position.

       Reads the 'mutant' column and the measurement column from the csv file.
       Each mutant string is split into its first character (wild-type
       residue), its last character (substituted residue) and the integer in
       between (position). The wild-type residue is checked against the first
       sequence of the alignment.

       measurement_col_name specifies which column in the csv file contains the
       experimental observation. In our case, this is the one called 2500.

       sequence_offset is used in case there is an overall offset between the
       indices in the two files.

       Returns a DataFrame indexed by (pos, WT) with one column per
       substituted residue.
       '''

    measurement_df = pd.read_csv(filename, delimiter=',', usecols=['mutant', measurement_col_name])
    # The first alignment entry is taken as the wild type; its label is not needed.
    wt_sequence, _ = alignment_data[0]

    zero_index = None
    per_position = {}
    for _, entry in measurement_df.iterrows():
        mutation = entry['mutant']
        mutant_from = mutation[:1]
        position = int(mutation[1:-1])
        mutant_to = mutation[-1:]

        # The position of the very first entry defines the offset (kept
        # separately in case there are index gaps in the experimental data).
        if zero_index is None:
            zero_index = position

        # Corresponding position in our alignment
        seq_position = position - zero_index + sequence_offset

        # Both inputs must agree on the indices: the amino acids in the
        # first entry of the alignment should be identical to those in
        # the experimental file.
        assert mutant_from == aa1[wt_sequence[seq_position]]

        measurements = per_position.setdefault(seq_position, {})

        # Check that there is only a single experimental value for mutant
        assert mutant_to not in measurements

        measurements['pos'] = seq_position
        measurements['WT'] = mutant_from
        measurements[mutant_to] = entry[measurement_col_name]

    # One row per position, then promote (pos, WT) to the index.
    table = pd.DataFrame(per_position).transpose()
    return table.set_index(['pos', 'WT'])
        
        
# Load the measurements, using the first sequence of the integer-encoded
# alignment dataset as the wild type to validate against.
experimental_data = read_experimental_data("BLAT_ECOLX_Ranganathan2015.csv", dataset_tensor)
# Bare expression: displays the table as the cell output.
experimental_data

Unnamed: 0_level_0,Unnamed: 1_level_0,A,C,E,D,G,F,I,K,M,L,N,Q,P,S,R,T,W,V,Y,H
pos,WT,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,H,-0.00978356,-0.41826,-0.279024,-0.181607,-0.0602417,-0.818487,-0.359191,0.0144696,-0.224781,-0.480347,-0.0430932,-0.135568,-1.01085,0.0361661,-0.00252626,-0.0671875,-1.34759,-0.026874,-0.885025,
1,P,-1.6426,-0.364138,0.143258,-0.0284025,-0.969268,-0.199804,-0.0735238,0.13559,-0.0283657,-0.211869,0.0206405,-0.0282136,,-0.16013,0.054154,-0.0911999,-0.109139,0.045913,0.00174467,0.0457846
2,E,0.0109131,-0.158233,,-0.0757852,0.0813101,-0.232106,-0.153907,0.0871198,-0.036441,-0.0581804,-0.0064688,0.0496907,-0.387232,-0.0395849,-0.220003,-0.135909,-0.44234,-0.0645674,-0.245436,0.0209168
3,T,-1.45459,-2.41902,-2.41446,-2.29488,-2.35671,-2.60457,-0.280446,-1.42789,-1.8431,-0.765521,-2.48572,-1.6671,-1.79017,-1.39248,-2.37509,,-2.8417,0.0341893,-2.78913,-1.84954
4,L,-0.202228,-1.95959,-1.72164,-2.71077,-1.4842,-0.720047,0.0173958,0.0695957,-0.0480697,,-1.42071,-0.222812,-2.19535,-1.2641,-0.0649357,-0.313656,-0.299738,0.0502655,-0.2186,-0.889277
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258,L,-0.103001,-0.617685,-2.95432,-2.52951,-2.41565,-0.663794,0.0786722,-2.57103,0.0126656,,-2.32138,-1.72228,-2.42792,-1.65506,-2.70359,-0.231658,-2.78162,-0.0205734,-2.60799,-3.01127
259,I,-0.537631,-0.657012,-2.6017,-2.78886,-2.24781,-0.123746,,-2.24993,-0.0885522,-0.190349,-2.48529,-1.95522,-2.61388,-1.48797,-1.94825,-0.300995,-1.71315,-0.030749,-0.0821665,-2.07812
260,K,-0.0389621,-0.91867,-0.0274808,-0.0614761,-0.0924616,-0.286851,-0.514835,,-0.107796,-0.0442595,-0.0745288,-0.0715558,-2.59327,-0.0202318,-0.04172,-0.090441,-0.449885,-0.261297,-0.111472,-0.00521766
261,H,-0.465274,-0.251095,-0.338762,-1.53546,-0.509937,-0.149276,-2.72312,-0.271887,-0.226841,-0.553996,0.0240883,-0.141955,-2.72487,-0.160349,-0.157191,-0.113391,-0.0762315,-0.825731,-0.0946859,


In [None]:
def get_wt_sequence(experimental_data):
    """Rebuild the wild-type sequence from the (pos, WT) index of the table.

    The wild-type residues are concatenated in row order.

    Returns:
        A tuple ``(wt_seq, wt_one_hot)``: the sequence as a string and its
        encoding as returned by `convert_to_one_hot_encoding`.
    """
    wt_seq = "".join(wt_residue for _, wt_residue in experimental_data.index)
    return wt_seq, convert_to_one_hot_encoding(wt_seq)

# Wild-type sequence (string and one-hot form) recovered from the measurement table.
wt_seq, wt_one_hot = get_wt_sequence(experimental_data)

def get_mutated_sequences(experimental_data, wt_seq):
    """Build the one-hot encoding of every single-site mutant in the table.

    For each (pos, WT) row and each column whose residue differs from the
    wild type, the residue at `pos` in `wt_seq` is replaced and the resulting
    sequence is encoded with `convert_to_one_hot_encoding`. Entries are
    emitted for every such column, whatever the measurement value is.

    Args:
        experimental_data: DataFrame indexed by (pos, WT) with one column per
            substituted residue.
        wt_seq: Wild-type sequence as a string.

    Returns:
        A tuple ``(mutated_list, mutation_list, exp_list)`` of equal-length
        lists: the encoded mutants, the ``(position, mutant_to)`` pairs and
        the corresponding measurement values.
    """
    mutated_list = []
    mutation_list = []
    exp_list = []
    for (position, mutant_from), row in experimental_data.iterrows():
        # Series.items() replaces iteritems(), which was removed in pandas 2.0.
        for mutant_to, exp_value in row.items():
            if mutant_from == mutant_to:
                # The wild-type residue itself is not a mutation.
                continue
            mut_seq = wt_seq[:position] + mutant_to + wt_seq[position+1:]
            mutated_list.append(convert_to_one_hot_encoding(mut_seq))
            mutation_list.append((position, mutant_to))
            exp_list.append(exp_value)
    return (mutated_list, mutation_list, exp_list)

# Encode every single-site mutant together with its position info and measurement.
mut_one_hot, pos_info, exp_list = get_mutated_sequences(experimental_data, wt_seq)
# print(torch.stack(mut_one_hot).squeeze().shape)
# Stack the per-mutant encodings and drop the size-1 dimensions, then serve
# them in their original order (shuffle=False).
test_dataloader = torch.utils.data.DataLoader(torch.stack(mut_one_hot).squeeze(), batch_size=model_params['training_parameters']['batch_size'], shuffle=False)


In [None]:

def get_basic_model(experimental_data, dataset_letters, wt_seq, seq_weights=None):
    """Score each mutation by a weighted frequency ratio in its alignment column.

    For every (pos, WT) row, the alignment column at `pos` is taken from
    `dataset_letters`. Each substitution is scored as

        sum of weights of sequences carrying the mutant residue
        -------------------------------------------------------
        sum of weights of sequences carrying the wild-type residue

    Args:
        experimental_data: DataFrame indexed by (pos, WT) with one column per
            substituted residue.
        dataset_letters: 2-D array of letters, indexed as ``[:, position]``.
        wt_seq: Unused; kept so existing calls keep working.
        seq_weights: Per-sequence weights. When omitted, the module-level
            `weights` array is used, as before.

    Returns:
        A tuple ``(val_list, exp_list)``: the frequency-ratio scores and the
        matching measurement values, in row-then-column order.
    """
    if seq_weights is None:
        # Backward-compatible default: the weights computed when the
        # alignment was loaded.
        seq_weights = weights

    val_list = []
    exp_list = []
    for (position, mutant_from), row in experimental_data.iterrows():
        col = dataset_letters[:, position]
        counter_wt = np.sum(seq_weights[col == str(mutant_from)])
        # Series.items() replaces iteritems(), which was removed in pandas 2.0.
        for mutant_to, exp_value in row.items():
            if mutant_from == mutant_to:
                continue
            counter_mut = np.sum(seq_weights[col == str(mutant_to)])
            val_list.append(counter_mut / counter_wt)
            exp_list.append(exp_value)
    return (val_list, exp_list)

# Frequency-ratio baseline scores; note that this rebinds exp_list, which was
# previously assigned from get_mutated_sequences.
val_list, exp_list = get_basic_model(experimental_data, dataset_letters, wt_seq)

In [None]:
# Spearman rank correlation between the baseline scores and the measurements.
print(spearmanr(val_list, exp_list))

SpearmanrResult(correlation=0.5637427310119946, pvalue=0.0)


In [None]:
class data_class:
    """Small container exposing the alignment data loaded by `get_data`.

    Attributes set on construction: `seq_len` and `alphabet_size` (second and
    third dimension of the one-hot encoding), `Neff`, `one_hot_encoding` and
    `weights`.
    """

    def __init__(self, data_filename, calc_weights):
        # Only the one-hot encoding, the weights and their sum are kept.
        _, _, encoded, _, seq_weights, effective_n, _ = get_data(data_filename, calc_weights)
        self.one_hot_encoding = encoded
        self.weights = seq_weights
        self.Neff = effective_n
        self.seq_len = encoded.shape[1]
        self.alphabet_size = encoded.shape[2]

                      

In [None]:
class VAE_MLP_encoder(nn.Module):
    """
    MLP encoder class for the VAE model.

    Maps a batch of one-hot encoded sequences to the mean and log-variance of
    the latent distribution.
    """

    # Supported values of params['nonlinear_activation'] and the module each one builds.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'elu': nn.ELU,
        'linear': nn.Identity,
    }

    def __init__(self,params):
        """
        Required input parameters (keys of `params`):
        - seq_len: (Int) Sequence length of sequence alignment
        - alphabet_size: (Int) Alphabet size of sequence alignment (will be driven by the data helper object)
        - hidden_layers_sizes: (List) List of sizes of DNN linear layers
        - z_dim: (Int) Size of latent space
        - convolve_input: (Bool) Whether to perform 1d convolution on input (kernel size 1, stride 1)
        - convolution_input_depth: (Int) Size of the 1D-convolution on input
        - nonlinear_activation: (Str) Type of non-linear activation to apply on each hidden layer;
          one of 'relu', 'tanh', 'sigmoid', 'elu', 'linear'
        - dropout_proba: (Float) Dropout probability applied on all hidden layers. If 0.0 then no dropout applied

        Raises:
            ValueError: If `nonlinear_activation` is not one of the supported names.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.seq_len = params['seq_len']
        self.alphabet_size = params['alphabet_size']
        self.hidden_layers_sizes = params['hidden_layers_sizes']
        self.z_dim = params['z_dim']
        self.convolve_input = params['convolve_input']
        self.convolution_depth = params['convolution_input_depth']
        self.dropout_proba = params['dropout_proba']

        self.mu_bias_init = 0.1
        self.log_var_bias_init = -10.0

        #Convolving input with kernels of size 1 to capture potential similarities across amino acids when encoding sequences
        if self.convolve_input:
            self.input_convolution = nn.Conv1d(in_channels=self.alphabet_size,out_channels=self.convolution_depth,kernel_size=1,stride=1,bias=False)
            self.channel_size = self.convolution_depth
        else:
            self.channel_size = self.alphabet_size

        # Fully connected stack; the first layer consumes the flattened
        # (channel_size * seq_len) input, each later layer the previous output.
        self.hidden_layers=torch.nn.ModuleDict()
        in_features = self.channel_size * self.seq_len
        for layer_index, out_features in enumerate(self.hidden_layers_sizes):
            layer = nn.Linear(in_features, out_features)
            nn.init.constant_(layer.bias, self.mu_bias_init)
            self.hidden_layers[str(layer_index)] = layer
            in_features = out_features

        # Two heads on top of the last hidden layer: latent mean and log-variance.
        self.fc_mean = nn.Linear(self.hidden_layers_sizes[-1],self.z_dim)
        nn.init.constant_(self.fc_mean.bias, self.mu_bias_init)
        self.fc_log_var = nn.Linear(self.hidden_layers_sizes[-1],self.z_dim)
        nn.init.constant_(self.fc_log_var.bias, self.log_var_bias_init)

        # set up non-linearity; fail here rather than at the first forward
        # pass when the name is not recognised
        activation_name = params['nonlinear_activation']
        if activation_name not in self._ACTIVATIONS:
            raise ValueError(
                "Unknown nonlinear_activation {!r}; expected one of {}".format(
                    activation_name, sorted(self._ACTIVATIONS)))
        self.nonlinear_activation = self._ACTIVATIONS[activation_name]()

        if self.dropout_proba > 0.0:
            self.dropout_layer = nn.Dropout(p=self.dropout_proba)

    def forward(self, x):
        """
        Encode a batch of sequences.

        x is indexed as (batch, seq_len, alphabet_size) when convolve_input is
        set (it is permuted to channels-first for the convolution); otherwise
        it only needs to flatten to (batch, seq_len * alphabet_size).

        Returns the tuple (z_mean, z_log_var), each of shape (batch, z_dim).
        """
        if self.dropout_proba > 0.0:
            x = self.dropout_layer(x)

        if self.convolve_input:
            x = x.permute(0,2,1)
            x = self.input_convolution(x)

        # Flatten everything but the batch dimension. reshape (rather than
        # view) also accepts non-contiguous input.
        x = x.reshape(-1,self.seq_len*self.channel_size)

        for layer_index in range(len(self.hidden_layers_sizes)):
            x = self.nonlinear_activation(self.hidden_layers[str(layer_index)](x))
            if self.dropout_proba > 0.0:
                x = self.dropout_layer(x)

        z_mean = self.fc_mean(x)
        z_log_var = self.fc_log_var(x)

        return z_mean, z_log_var

class VAE_Standard_MLP_decoder(nn.Module):
    """
    Standard MLP decoder class for the VAE model.

    Maps latent vectors to per-position log-probabilities over the alphabet.
    """

    # Supported nonlinearity names and the module each one builds.
    _NONLINEARITIES = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'elu': nn.ELU,
        'linear': nn.Identity,
    }

    def __init__(self, params):
        """
        Required input parameters (keys of `params`):
        - seq_len: (Int) Sequence length of sequence alignment
        - alphabet_size: (Int) Alphabet size of sequence alignment (will be driven by the data helper object)
        - hidden_layers_sizes: (List) List of the sizes of the hidden layers (all DNNs)
        - z_dim: (Int) Dimension of latent space
        - first_hidden_nonlinearity: (Str) Type of non-linear activation applied on the first (set of) hidden layer(s)
        - last_hidden_nonlinearity: (Str) Type of non-linear activation applied on the very last hidden layer (pre-sparsity)
        - dropout_proba: (Float) Dropout probability applied on all hidden layers. If 0.0 then no dropout applied
        - convolve_output: (Bool) Whether to perform 1d convolution on output (kernel size 1, stride 1)
        - convolution_output_depth: (Int) Size of the 1D-convolution on output
        - include_temperature_scaler: (Bool) Whether we apply the global temperature scaler
        - include_sparsity: (Bool) Whether we use the sparsity inducing scheme on the output from the last hidden layer
        - num_tiles_sparsity: (Int) Number of tiles to use in the sparsity inducing scheme (the more the tiles, the stronger the sparsity)

        Both nonlinearity names must be one of 'relu', 'tanh', 'sigmoid',
        'elu', 'linear'. `bayesian_decoder` is always False for this class.

        Raises:
            ValueError: If a nonlinearity name is not supported.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.seq_len = params['seq_len']
        self.alphabet_size = params['alphabet_size']
        self.hidden_layers_sizes = params['hidden_layers_sizes']
        self.z_dim = params['z_dim']
        self.bayesian_decoder = False
        self.dropout_proba = params['dropout_proba']
        self.convolve_output = params['convolve_output']
        self.convolution_depth = params['convolution_output_depth']
        self.include_temperature_scaler = params['include_temperature_scaler']
        self.include_sparsity = params['include_sparsity']
        self.num_tiles_sparsity = params['num_tiles_sparsity']

        self.mu_bias_init = 0.1

        # Fully connected stack from the latent space up to the last hidden layer.
        self.hidden_layers=nn.ModuleDict()
        in_features = self.z_dim
        for layer_index, out_features in enumerate(self.hidden_layers_sizes):
            layer = nn.Linear(in_features, out_features)
            nn.init.constant_(layer.bias, self.mu_bias_init)
            self.hidden_layers[str(layer_index)] = layer
            in_features = out_features

        self.first_hidden_nonlinearity = self._build_nonlinearity(
            'first_hidden_nonlinearity', params['first_hidden_nonlinearity'])
        self.last_hidden_nonlinearity = self._build_nonlinearity(
            'last_hidden_nonlinearity', params['last_hidden_nonlinearity'])

        if self.dropout_proba > 0.0:
            self.dropout_layer = nn.Dropout(p=self.dropout_proba)

        if self.convolve_output:
            self.output_convolution = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False)
            self.channel_size = self.convolution_depth
        else:
            self.channel_size = self.alphabet_size

        if self.include_sparsity:
            self.sparsity_weight = nn.Parameter(torch.randn(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len))

        # Output layer, kept as raw parameters because its weight is combined
        # with the convolution / sparsity terms before being applied.
        self.W_out = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1]))
        nn.init.xavier_normal_(self.W_out) #Initialize weights with Glorot initialization
        self.b_out = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len))
        nn.init.constant_(self.b_out, self.mu_bias_init)

        if self.include_temperature_scaler:
            self.temperature_scaler = nn.Parameter(torch.ones(1))

    @classmethod
    def _build_nonlinearity(cls, param_name, name):
        """Instantiate the nonlinearity called `name`, or raise ValueError."""
        if name not in cls._NONLINEARITIES:
            raise ValueError(
                "Unknown {} {!r}; expected one of {}".format(
                    param_name, name, sorted(cls._NONLINEARITIES)))
        return cls._NONLINEARITIES[name]()

    def forward(self, z):
        """
        Decode latent vectors z of shape (batch_size, z_dim).

        Returns log-probabilities of shape (batch_size, seq_len, alphabet_size),
        normalised over the last dimension.
        """
        batch_size = z.shape[0]
        x = self.dropout_layer(z) if self.dropout_proba > 0.0 else z

        last_index = len(self.hidden_layers_sizes) - 1
        for layer_index in range(last_index):
            x = self.first_hidden_nonlinearity(self.hidden_layers[str(layer_index)](x))
            if self.dropout_proba > 0.0:
                x = self.dropout_layer(x)

        x = self.last_hidden_nonlinearity(self.hidden_layers[str(last_index)](x)) #of size (batch_size,H)
        if self.dropout_proba > 0.0:
            x = self.dropout_layer(x)

        # Use the parameter itself, not `.data`: `.data` detaches the tensor
        # from autograd, so the output weights would never receive a gradient.
        W_out = self.W_out

        if self.convolve_output:
            # NOTE(review): the Conv1d weight is stored as
            # (alphabet_size, channel_size, 1); this view reinterprets it as
            # (channel_size, alphabet_size) without transposing - confirm intended.
            W_out = torch.mm(W_out.view(self.seq_len * self.hidden_layers_sizes[-1], self.channel_size),
                                    self.output_convolution.weight.view(self.channel_size,self.alphabet_size))

        if self.include_sparsity:
            sparsity_tiled = self.sparsity_weight.repeat(self.num_tiles_sparsity,1) #of size (H,seq_len)
            sparsity_tiled = nn.Sigmoid()(sparsity_tiled).unsqueeze(2) #of size (H,seq_len,1)
            W_out = W_out.view(self.hidden_layers_sizes[-1], self.seq_len, self.alphabet_size) * sparsity_tiled

        W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1])

        x = F.linear(x, weight=W_out, bias=self.b_out)

        if self.include_temperature_scaler:
            # softplus of the scaler keeps the multiplicative factor positive
            x = torch.log(1.0+torch.exp(self.temperature_scaler)) * x

        x = x.view(batch_size, self.seq_len, self.alphabet_size)
        x_recon_log = F.log_softmax(x, dim=-1) #of shape (batch_size, seq_len, alphabet)

        return x_recon_log

class VAE_Bayesian_MLP_decoder(nn.Module):
    """
    Bayesian MLP decoder class for the VAE model.

    Every weight and bias is represented by a mean and a log-variance
    parameter; each forward pass draws a fresh sample of all of them.
    """

    # Supported nonlinearity names and the module each one builds.
    _NONLINEARITIES = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'elu': nn.ELU,
        'linear': nn.Identity,
    }

    def __init__(self, params):
        """
        Required input parameters (keys of `params`):
        - seq_len: (Int) Sequence length of sequence alignment
        - alphabet_size: (Int) Alphabet size of sequence alignment (will be driven by the data helper object)
        - hidden_layers_sizes: (List) List of the sizes of the hidden layers (all DNNs)
        - z_dim: (Int) Dimension of latent space
        - first_hidden_nonlinearity: (Str) Type of non-linear activation applied on the first (set of) hidden layer(s)
        - last_hidden_nonlinearity: (Str) Type of non-linear activation applied on the very last hidden layer (pre-sparsity)
        - dropout_proba: (Float) Dropout probability applied on all hidden layers. If 0.0 then no dropout applied
        - convolve_output: (Bool) Whether to perform 1d convolution on output (kernel size 1, stride 1)
        - convolution_output_depth: (Int) Size of the 1D-convolution on output
        - include_temperature_scaler: (Bool) Whether we apply the global temperature scaler
        - include_sparsity: (Bool) Whether we use the sparsity inducing scheme on the output from the last hidden layer
        - num_tiles_sparsity: (Int) Number of tiles to use in the sparsity inducing scheme (the more the tiles, the stronger the sparsity)

        Both nonlinearity names must be one of 'relu', 'tanh', 'sigmoid',
        'elu', 'linear'. `bayesian_decoder` is always True for this class.

        Raises:
            ValueError: If a nonlinearity name is not supported.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.seq_len = params['seq_len']
        self.alphabet_size = params['alphabet_size']
        self.hidden_layers_sizes = params['hidden_layers_sizes']
        self.z_dim = params['z_dim']
        self.bayesian_decoder = True
        self.dropout_proba = params['dropout_proba']
        self.convolve_output = params['convolve_output']
        self.convolution_depth = params['convolution_output_depth']
        self.include_temperature_scaler = params['include_temperature_scaler']
        self.include_sparsity = params['include_sparsity']
        self.num_tiles_sparsity = params['num_tiles_sparsity']

        self.mu_bias_init = 0.1
        self.logvar_init = -10.0
        self.logit_scale_p = 0.001

        # For each hidden layer, one nn.Linear holds the means of the weights
        # and biases and a second one holds their log-variances.
        self.hidden_layers_mean=nn.ModuleDict()
        self.hidden_layers_log_var=nn.ModuleDict()
        in_features = self.z_dim
        for layer_index, out_features in enumerate(self.hidden_layers_sizes):
            key = str(layer_index)
            self.hidden_layers_mean[key] = nn.Linear(in_features, out_features)
            self.hidden_layers_log_var[key] = nn.Linear(in_features, out_features)
            nn.init.constant_(self.hidden_layers_mean[key].bias, self.mu_bias_init)
            nn.init.constant_(self.hidden_layers_log_var[key].weight, self.logvar_init)
            nn.init.constant_(self.hidden_layers_log_var[key].bias, self.logvar_init)
            in_features = out_features

        self.first_hidden_nonlinearity = self._build_nonlinearity(
            'first_hidden_nonlinearity', params['first_hidden_nonlinearity'])
        self.last_hidden_nonlinearity = self._build_nonlinearity(
            'last_hidden_nonlinearity', params['last_hidden_nonlinearity'])

        if self.dropout_proba > 0.0:
            self.dropout_layer = nn.Dropout(p=self.dropout_proba)

        if self.convolve_output:
            self.output_convolution_mean = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False)
            self.output_convolution_log_var = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False)
            nn.init.constant_(self.output_convolution_log_var.weight, self.logvar_init)
            self.channel_size = self.convolution_depth
        else:
            self.channel_size = self.alphabet_size

        if self.include_sparsity:
            self.sparsity_weight_mean = nn.Parameter(torch.zeros(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len))
            self.sparsity_weight_log_var = nn.Parameter(torch.ones(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len))
            nn.init.constant_(self.sparsity_weight_log_var, self.logvar_init)

        # Output layer (mean / log-variance of its weight and bias).
        self.last_hidden_layer_weight_mean = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1]))
        self.last_hidden_layer_weight_log_var = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1]))
        nn.init.xavier_normal_(self.last_hidden_layer_weight_mean) #Glorot initialization
        nn.init.constant_(self.last_hidden_layer_weight_log_var, self.logvar_init)

        self.last_hidden_layer_bias_mean = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len))
        self.last_hidden_layer_bias_log_var = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len))
        nn.init.constant_(self.last_hidden_layer_bias_mean, self.mu_bias_init)
        nn.init.constant_(self.last_hidden_layer_bias_log_var, self.logvar_init)

        if self.include_temperature_scaler:
            self.temperature_scaler_mean = nn.Parameter(torch.ones(1))
            self.temperature_scaler_log_var = nn.Parameter(torch.ones(1) * self.logvar_init)

    @classmethod
    def _build_nonlinearity(cls, param_name, name):
        """Instantiate the nonlinearity called `name`, or raise ValueError."""
        if name not in cls._NONLINEARITIES:
            raise ValueError(
                "Unknown {} {!r}; expected one of {}".format(
                    param_name, name, sorted(cls._NONLINEARITIES)))
        return cls._NONLINEARITIES[name]()

    def sampler(self, mean, log_var):
        """
        Samples a tensor via the reparametrization trick:
        mean + exp(0.5 * log_var) * eps, with eps ~ N(0, I).
        """
        # randn_like already creates eps on the device (and with the dtype) of
        # `mean`. It is deliberately not moved to self.device: that attribute
        # reflects CUDA availability at construction time, not where the
        # parameters actually live, and moving eps there breaks a model that
        # is kept on the CPU of a CUDA-capable machine.
        eps = torch.randn_like(mean)
        z = torch.exp(0.5*log_var) * eps + mean
        return z

    def forward(self, z):
        """
        Decode latent vectors z of shape (batch_size, z_dim) using one fresh
        sample of every decoder weight.

        Returns log-probabilities of shape (batch_size, seq_len, alphabet_size),
        normalised over the last dimension.
        """
        batch_size = z.shape[0]
        x = self.dropout_layer(z) if self.dropout_proba > 0.0 else z

        # Hidden stack: every layer but the last uses the "first" nonlinearity.
        last_index = len(self.hidden_layers_sizes)-1
        for layer_index in range(last_index + 1):
            key = str(layer_index)
            layer_weight = self.sampler(self.hidden_layers_mean[key].weight, self.hidden_layers_log_var[key].weight)
            layer_bias = self.sampler(self.hidden_layers_mean[key].bias, self.hidden_layers_log_var[key].bias)
            if layer_index < last_index:
                nonlinearity = self.first_hidden_nonlinearity
            else:
                nonlinearity = self.last_hidden_nonlinearity
            x = nonlinearity(F.linear(x, weight=layer_weight, bias=layer_bias))
            if self.dropout_proba > 0.0:
                x = self.dropout_layer(x)

        W_out = self.sampler(self.last_hidden_layer_weight_mean, self.last_hidden_layer_weight_log_var)
        b_out = self.sampler(self.last_hidden_layer_bias_mean, self.last_hidden_layer_bias_log_var)

        if self.convolve_output:
            output_convolution_weight = self.sampler(self.output_convolution_mean.weight, self.output_convolution_log_var.weight)
            # NOTE(review): the Conv1d weight is stored as
            # (alphabet_size, channel_size, 1); this view reinterprets it as
            # (channel_size, alphabet_size) without transposing - confirm intended.
            W_out = torch.mm(W_out.view(self.seq_len * self.hidden_layers_sizes[-1], self.channel_size),
                                    output_convolution_weight.view(self.channel_size,self.alphabet_size)) #product of size (H * seq_len, alphabet)

        if self.include_sparsity:
            sparsity_weights = self.sampler(self.sparsity_weight_mean,self.sparsity_weight_log_var)
            sparsity_tiled = sparsity_weights.repeat(self.num_tiles_sparsity,1)
            sparsity_tiled = nn.Sigmoid()(sparsity_tiled).unsqueeze(2)

            W_out = W_out.view(self.hidden_layers_sizes[-1], self.seq_len, self.alphabet_size) * sparsity_tiled

        W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1])

        x = F.linear(x, weight=W_out, bias=b_out)

        if self.include_temperature_scaler:
            temperature_scaler = self.sampler(self.temperature_scaler_mean,self.temperature_scaler_log_var)
            # softplus of the sampled scaler keeps the multiplicative factor positive
            x = torch.log(1.0+torch.exp(temperature_scaler)) * x

        x = x.view(batch_size, self.seq_len, self.alphabet_size)
        x_recon_log = F.log_softmax(x, dim=-1) #of shape (batch_size, seq_len, alphabet)

        return x_recon_log


In [None]:

# from . import VAE_encoder, VAE_decoder

class VAE_model(nn.Module):
    """
    Class for the VAE model with estimation of weights distribution parameters via Mean-Field VI.
    """
    def __init__(self,
            model_name,
            data,
            encoder_parameters,
            decoder_parameters,
            random_seed
            ):
        """
        Build the encoder/decoder pair for one alignment.

        Args:
            model_name: Label stored on the model; used as prefix of checkpoint file names.
            data: Alignment object. Only its ``seq_len``, ``alphabet_size`` and ``Neff``
                attributes are read here.
            encoder_parameters: Encoder hyper-parameter dict. ``seq_len`` and
                ``alphabet_size`` are written into it, i.e. the caller's dict is
                modified in place.
            decoder_parameters: Decoder hyper-parameter dict, completed in place the
                same way. ``bayesian_decoder`` selects the decoder class and
                ``logit_sparsity_p`` is copied onto the model.
            random_seed: Seed handed to ``torch.manual_seed`` and kept on the instance.
        """
        super().__init__()

        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float32
        self.random_seed = random_seed
        torch.manual_seed(random_seed)

        self.seq_len = data.seq_len
        self.alphabet_size = data.alphabet_size
        self.Neff = data.Neff

        # The sub-networks size their layers from these two entries.
        encoder_parameters['seq_len'] = self.seq_len
        encoder_parameters['alphabet_size'] = self.alphabet_size
        decoder_parameters['seq_len'] = self.seq_len
        decoder_parameters['alphabet_size'] = self.alphabet_size

        # Keep the hyper-parameter dicts reachable from the instance so that
        # checkpoints can store them together with the weights.
        self.encoder_parameters = encoder_parameters
        self.decoder_parameters = decoder_parameters

        self.encoder = VAE_MLP_encoder(params=encoder_parameters)
        if decoder_parameters['bayesian_decoder']:
            self.decoder = VAE_Bayesian_MLP_decoder(params=decoder_parameters)
        else:
            self.decoder = VAE_Standard_MLP_decoder(params=decoder_parameters)
        self.logit_sparsity_p = decoder_parameters['logit_sparsity_p']
        


    def sample_latent(self, mu, log_var):
        """
        Samples a latent vector via reparametrization trick
        """
        eps = torch.randn_like(mu).to(self.device)
        z = torch.exp(0.5*log_var) * eps + mu
        return z

    def KLD_diag_gaussians(self, mu, logvar, p_mu, p_logvar):
        """
        KL divergence between diagonal gaussian with prior diagonal gaussian.
        """
        KLD = 0.5 * (p_logvar - logvar) + 0.5 * (torch.exp(logvar) + torch.pow(mu-p_mu,2)) / (torch.exp(p_logvar)+1e-20) - 0.5

        return torch.sum(KLD)

    def annealing_factor(self, annealing_warm_up, training_step):
        """
        Annealing schedule of KL to focus on reconstruction error in early stages of training
        """
        if training_step < annealing_warm_up:
            return training_step/annealing_warm_up
        else:
            return 1

    def KLD_global_parameters(self):
        """
        KL divergence between the variational distributions and the priors (for the decoder weights).
        """
        KLD_decoder_params = 0.0
        zero_tensor = torch.tensor(0.0).to(self.device) 
        
        for layer_index in range(len(self.decoder.hidden_layers_sizes)):
            for param_type in ['weight','bias']:
                KLD_decoder_params += self.KLD_diag_gaussians(
                                    self.decoder.state_dict(keep_vars=True)['hidden_layers_mean.'+str(layer_index)+'.'+param_type].flatten(),
                                    self.decoder.state_dict(keep_vars=True)['hidden_layers_log_var.'+str(layer_index)+'.'+param_type].flatten(),
                                    zero_tensor,
                                    zero_tensor
                )
                
        for param_type in ['weight','bias']:
                KLD_decoder_params += self.KLD_diag_gaussians(
                                        self.decoder.state_dict(keep_vars=True)['last_hidden_layer_'+param_type+'_mean'].flatten(),
                                        self.decoder.state_dict(keep_vars=True)['last_hidden_layer_'+param_type+'_log_var'].flatten(),
                                        zero_tensor,
                                        zero_tensor
                )

        if self.decoder.include_sparsity:
            self.logit_scale_sigma = 4.0
            self.logit_scale_mu = 2.0**0.5 * self.logit_scale_sigma * erfinv(2.0 * self.logit_sparsity_p - 1.0)

            sparsity_mu = torch.tensor(self.logit_scale_mu).to(self.device) 
            sparsity_log_var = torch.log(torch.tensor(self.logit_scale_sigma**2)).to(self.device)
            KLD_decoder_params += self.KLD_diag_gaussians(
                                    self.decoder.state_dict(keep_vars=True)['sparsity_weight_mean'].flatten(),
                                    self.decoder.state_dict(keep_vars=True)['sparsity_weight_log_var'].flatten(),
                                    sparsity_mu,
                                    sparsity_log_var
            )
            
        if self.decoder.convolve_output:
            for param_type in ['weight']:
                KLD_decoder_params += self.KLD_diag_gaussians(
                                    self.decoder.state_dict(keep_vars=True)['output_convolution_mean.'+param_type].flatten(),
                                    self.decoder.state_dict(keep_vars=True)['output_convolution_log_var.'+param_type].flatten(),
                                    zero_tensor,
                                    zero_tensor
                )

        if self.decoder.include_temperature_scaler:
            KLD_decoder_params += self.KLD_diag_gaussians(
                                    self.decoder.state_dict(keep_vars=True)['temperature_scaler_mean'].flatten(),
                                    self.decoder.state_dict(keep_vars=True)['temperature_scaler_log_var'].flatten(),
                                    zero_tensor,
                                    zero_tensor
            )        
        return KLD_decoder_params

    def loss_function(self, x_recon_log, x, mu, log_var, kl_latent_scale, kl_global_params_scale, annealing_warm_up, training_step, Neff):
        """
        Returns mean of negative ELBO, reconstruction loss and KL divergence across batch x.
        """
        BCE = F.binary_cross_entropy_with_logits(x_recon_log, x, reduction='sum') / x.shape[0]
        KLD_latent = (-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())) / x.shape[0]
        if self.decoder.bayesian_decoder:
            KLD_decoder_params_normalized = self.KLD_global_parameters() / Neff
        else:
            KLD_decoder_params_normalized = 0.0
        warm_up_scale = self.annealing_factor(annealing_warm_up,training_step)
        neg_ELBO = BCE + warm_up_scale * (kl_latent_scale * KLD_latent + kl_global_params_scale * KLD_decoder_params_normalized)
        return neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized
    
    def all_likelihood_components(self, x):
        """
        Returns tensors of ELBO, reconstruction loss and KL divergence for each point in batch x.
        """
        mu, log_var = self.encoder(x)
        z = self.sample_latent(mu, log_var)
        recon_x_log = self.decoder(z)

        recon_x_log = recon_x_log.view(-1,self.alphabet_size*self.seq_len)
        x = x.view(-1,self.alphabet_size*self.seq_len)
        
        BCE_batch_tensor = torch.sum(F.binary_cross_entropy_with_logits(recon_x_log, x, reduction='none'),dim=1)
        KLD_batch_tensor = (-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(),dim=1))
        
        ELBO_batch_tensor = -(BCE_batch_tensor + KLD_batch_tensor)

        return ELBO_batch_tensor, BCE_batch_tensor, KLD_batch_tensor

    def train_model(self, data, training_parameters):
        """
        Training procedure for the VAE model.

        Runs num_training_steps Adam updates on batches sampled with probability
        proportional to the sequence weights.
        If use_validation_set is True then:
            - we split the alignment data in train/val sets.
            - every validation_freq steps one validation batch is scored and the
              model with the lowest validation loss so far is saved as "<model_name>_best".
        Independently of that, a checkpoint "<model_name>_step_<k>" is written
        every save_model_params_freq steps.

        Every log_training_freq steps the losses are printed, mutation scores are
        correlated with experimental values and the latent space is plotted.

        NOTE(review): the monitoring code reads the module-level names
        wt_one_hot, mut_one_hot, test_dataloader, exp_list and dataloader; they
        must exist in the notebook before training starts.

        Args:
            data: Alignment object providing ``one_hot_encoding`` and ``weights``.
            training_parameters: Dict of training options read by key below.
        """
        if torch.cuda.is_available():
            cudnn.benchmark = True
        self.train()

        log_to_file = training_parameters['log_training_info']
        if log_to_file:
            # The log-writing branches below need this path; it was previously
            # only defined in commented-out code.
            filename = training_parameters['training_logs_location']+os.sep+self.model_name+"_losses.csv"

        # Hyper-parameter dicts stored next to the weights in every checkpoint.
        # They are read from the instance and fall back to None when the
        # instance does not carry them.
        encoder_parameters = getattr(self, 'encoder_parameters', None)
        decoder_parameters = getattr(self, 'decoder_parameters', None)
        checkpoint_prefix = training_parameters['model_checkpoint_location']+os.sep+self.model_name

        optimizer = optim.Adam(self.parameters(), lr=training_parameters['learning_rate'], weight_decay = training_parameters['l2_regularization'])

        use_lr_scheduler = training_parameters['use_lr_scheduler']
        if use_lr_scheduler:
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=training_parameters['lr_scheduler_step_size'], gamma=training_parameters['lr_scheduler_gamma'])

        use_validation_set = training_parameters['use_validation_set']
        if use_validation_set:
            x_train, x_val, weights_train, weights_val = train_test_split(data.one_hot_encoding, data.weights, test_size=training_parameters['validation_set_pct'], random_state=self.random_seed)
            best_val_loss = float('inf')
        else:
            x_train = data.one_hot_encoding
            weights_train = data.weights
            best_val_loss = None

        batch_order = np.arange(x_train.shape[0])
        seq_sample_probs = weights_train / np.sum(weights_train)

        # Also read by test_model to normalize the decoder-parameter KL.
        self.Neff_training = np.sum(weights_train)

        start = time.time()

        for training_step in tqdm.tqdm(range(1,training_parameters['num_training_steps']+1), desc="Training model", leave=False, ascii=True):

            batch_index = np.random.choice(batch_order, training_parameters['batch_size'], p=seq_sample_probs).tolist()
            x = torch.tensor(x_train[batch_index], dtype=self.dtype).to(self.device)
            optimizer.zero_grad()

            mu, log_var = self.encoder(x)
            z = self.sample_latent(mu, log_var)
            recon_x_log = self.decoder(z)

            neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized = self.loss_function(recon_x_log, x, mu, log_var, training_parameters['kl_latent_scale'], training_parameters['kl_global_params_scale'], training_parameters['annealing_warm_up'], training_step, self.Neff_training)

            neg_ELBO.backward()
            optimizer.step()

            if use_lr_scheduler:
                scheduler.step()

            if training_step % training_parameters['log_training_freq'] == 0:
                progress = "|Train : Update {0}. Negative ELBO : {1:.3f}, BCE: {2:.3f}, KLD_latent: {3:.3f}, KLD_decoder_params_norm: {4:.3f}, Time: {5:.2f} |".format(training_step, neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized, time.time() - start)
                print(progress)
                mean_mut_score = self.get_mutation_likelihood(wt_one_hot, mut_one_hot, test_dataloader, 50, training_parameters['batch_size'])

                print(spearmanr(mean_mut_score, exp_list))

                z_list, label_list = self.get_z(dataloader)
                self.visualize_z(z_list, label_list, training_step, 1, 1)

                # The monitoring calls above switch the model to eval mode;
                # go back to train mode so dropout stays active for the next updates.
                self.train()

                if log_to_file:
                    with open(filename, "a") as logs:
                        logs.write(progress+"\n")

            if training_step % training_parameters['save_model_params_freq']==0:
                self.save(model_checkpoint=checkpoint_prefix+"_step_"+str(training_step),
                            encoder_parameters=encoder_parameters,
                            decoder_parameters=decoder_parameters,
                            training_parameters=training_parameters)

            if use_validation_set and training_step % training_parameters['validation_freq'] == 0:
                # test_model samples its own batch and converts it to a tensor,
                # so the validation split is passed through unchanged.
                val_neg_ELBO, val_BCE, val_KLD_latent, val_KLD_global_parameters = self.test_model(x_val, weights_val, training_parameters['batch_size'])

                progress_val = "\t\t\t|Val : Update {0}. Negative ELBO : {1:.3f}, BCE: {2:.3f}, KLD_latent: {3:.3f}, KLD_decoder_params_norm: {4:.3f}, Time: {5:.2f} |".format(training_step, val_neg_ELBO, val_BCE, val_KLD_latent, val_KLD_global_parameters, time.time() - start)
                print(progress_val)
                if log_to_file:
                    with open(filename, "a") as logs:
                        logs.write(progress_val+"\n")

                if val_neg_ELBO < best_val_loss:
                    best_val_loss = val_neg_ELBO
                    self.save(model_checkpoint=checkpoint_prefix+"_best",
                                encoder_parameters=encoder_parameters,
                                decoder_parameters=decoder_parameters,
                                training_parameters=training_parameters)
                self.train()
    
    def test_model(self, x_val, weights_val, batch_size):
        self.eval()
        
        with torch.no_grad():
            val_batch_order = np.arange(x_val.shape[0])
            val_seq_sample_probs = weights_val / np.sum(weights_val)

            val_batch_index = np.random.choice(val_batcneffh_order, batch_size, p=val_seq_sample_probs).tolist()
            x = torch.tensor(x_val[val_batch_index], dtype=self.dtype).to(self.device)
            mu, log_var = self.encoder(x)
            z = self.sample_latent(mu, log_var)
            recon_x_log = self.decoder(z)
            
            neg_ELBO, BCE, KLD_latent, KLD_global_parameters = self.loss_function(recon_x_log, x, mu, log_var, kl_latent_scale=1.0, kl_global_params_scale=1.0, annealing_warm_up=0, training_step=1, Neff = self.Neff_training) #set annealing factor to 1
            
        return neg_ELBO.item(), BCE.item(), KLD_latent.item(), KLD_global_parameters.item()


    def get_z(self, dataloader):
        z_list = []
        label_list = []
        self.eval()
        with torch.no_grad():
            for x,y in dataloader:
                # For illustrative purposes, make sure we can see the entire tensor
                torch.set_printoptions(threshold=np.inf)
                # print("data: ", x)
                # print("labels: ", y)
                x = torch.tensor(x, dtype=self.dtype).to(self.device)

                mu, log_var = self.encoder(x)
                z = self.sample_latent(mu, log_var)
                # interrupt after first batch
                z_list.append(mu)
                label_list.append(y)

        label_list = torch.cat(label_list, dim=0).cpu().numpy()
        z_list = torch.cat(z_list, dim=0).cpu().numpy()
        return(z_list, label_list)

    def get_mutation_likelihood(self, wt_sequence, mut_one_hot_df, mut_dataloader, num_samples, batch_size=256):
        """
        Score mutants by their mean ELBO relative to the wild type.

        The model is put in eval mode (and left there). The ELBO of every
        sequence is averaged over ``num_samples`` latent draws.

        Args:
            wt_sequence: Wild-type input tensor; it is moved to the model device
                and passed to ``all_likelihood_components`` as is.
            mut_one_hot_df: Only its length is used, as the number of mutants.
            mut_dataloader: Yields tensor batches of mutant sequences, in the
                same order as ``mut_one_hot_df``.
            num_samples: Number of latent draws per sequence.
            batch_size: Unused; rows are placed by counting the sequences
                actually yielded. Kept for call compatibility.

        Returns:
            numpy array with one value per mutant, equal to
            (mean mutant ELBO) - (mean wild-type ELBO).
        """
        self.eval()

        with torch.no_grad():
            x_wt = wt_sequence.to(self.device)
            prediction_matrix_wt = torch.zeros((1,num_samples))
            for i in range(num_samples):
                seq_predictions_wt, _, _ = self.all_likelihood_components(x_wt)
                prediction_matrix_wt[0,i] = seq_predictions_wt
            mean_wt_score = prediction_matrix_wt.mean(dim=1, keepdim=False)[0]

            prediction_matrix = torch.zeros((len(mut_one_hot_df),num_samples))
            # Running offset: correct even if the loader's batch size differs
            # from the batch_size argument.
            row_start = 0
            for batch in tqdm.tqdm(mut_dataloader, 'Looping through mutation batches', leave=True, ascii=True):
                x = batch.type(self.dtype).to(self.device)
                row_stop = row_start + len(x)
                for j in range(num_samples):
                    seq_predictions, _, _ = self.all_likelihood_components(x)
                    prediction_matrix[row_start:row_stop,j] = seq_predictions
                row_start = row_stop

            mean_predictions = prediction_matrix.mean(dim=1, keepdim=False)
            delta_elbos = mean_wt_score - mean_predictions
            evol_indices = - delta_elbos.detach().cpu().numpy()
        return evol_indices

    def visualize_z(self, z_list, label_list, iteration, s, alpha):
        """
        Scatter-plot the first two latent coordinates, one colour per label.

        The figure is written to "<iteration>.png" in the working directory and
        then displayed.

        Args:
            z_list: Array indexed as ``z_list[rows, 0]`` / ``z_list[rows, 1]``.
            label_list: Array of labels; the values must be keys 0-9 of the
                colour table and valid indices of the module-level
                ``phyla_lookup_table``.
            iteration: Used as the file name stem.
            s: Marker size passed to ``scatter``.
            alpha: Marker opacity passed to ``scatter``.
        """
        cdict = {0:'tab:cyan', 1: 'tab:blue', 2: 'tab:orange', 3: 'tab:green', 4: 'tab:red', 5:'tab:purple', 6:'tab:brown', 7:'tab:pink' ,8:'tab:gray', 9: 'tab:olive'}
        fig, ax = plt.subplots()
        for g in np.unique(label_list):
            ix = np.where(label_list == g)
            ax.scatter(z_list[ix,0], z_list[ix,1], c = cdict[g], marker='.', label = phyla_lookup_table[g], s=s, alpha=alpha)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.65))
        ax.set_xlabel('values of Z1')
        ax.set_ylabel('values of Z2')
        ax.set_title('latent space of dimension 2 of VAE model')
        # Save before showing: after show() the figure may already be released,
        # which would write an empty image. 'tight' keeps the legend, which sits
        # outside the axes, inside the saved file.
        fig.savefig(f"{iteration}.png", bbox_inches='tight')
        plt.show()
        # This method runs at every logging step; close the figure so they do
        # not pile up in memory.
        plt.close(fig)

    def save(self, model_checkpoint, encoder_parameters, decoder_parameters, training_parameters, batch_size=256):
        torch.save({
            'model_state_dict':self.state_dict(),
            'encoder_parameters':encoder_parameters,
            'decoder_parameters':decoder_parameters,
            'training_parameters':training_parameters,
            }, model_checkpoint)
    
    def compute_evol_indices(self, msa_data, list_mutations_location, num_samples, batch_size=256):
        """
        Score the mutants listed in a CSV file against the wild-type sequence.

        The CSV must have a "mutations" column; several substitutions in one
        variant are colon-separated. Variants with any substitution that does not
        match the focus sequence are reported and skipped. The ELBO of each
        remaining sequence is averaged over ``num_samples`` latent draws.

        Returns:
            (list of valid mutations with 'wt' first,
             evolutionary indices = mean wt ELBO - mean mutant ELBO,
             mean wt ELBO,
             per-sequence standard deviation of the sampled ELBOs)
        """
        mutation_table = pd.read_csv(list_mutations_location, header=0)
        wild_type = msa_data.focus_seq_trimmed

        # 'wt' goes first so that row 0 of the prediction matrix is the reference.
        list_valid_mutations = ['wt']
        sequence_of = {'wt': wild_type}
        for mutation in mutation_table['mutations']:
            mutated_sequence = list(wild_type)
            for mut in mutation.split(':'):
                wt_aa, pos, mut_aa = mut[0], int(mut[1:-1]), mut[-1]
                is_known_substitution = (
                    pos in msa_data.uniprot_focus_col_to_wt_aa_dict
                    and msa_data.uniprot_focus_col_to_wt_aa_dict[pos] == wt_aa
                    and mut in msa_data.mutant_to_letter_pos_idx_focus_list
                )
                if not is_known_substitution:
                    print ("Not a valid mutant: "+mutation)
                    break
                _, _, idx_focus = msa_data.mutant_to_letter_pos_idx_focus_list[mut]
                mutated_sequence[idx_focus] = mut_aa #perform the corresponding AA substitution
            else:
                # Reached only when no substitution was rejected.
                list_valid_mutations.append(mutation)
                sequence_of[mutation] = ''.join(mutated_sequence)

        #One-hot encoding of mutated sequences
        n_variants = len(list_valid_mutations)
        mutated_sequences_one_hot = np.zeros((n_variants,len(msa_data.focus_cols),len(msa_data.alphabet)))
        for row, mutation in enumerate(list_valid_mutations):
            for col, letter in enumerate(sequence_of[mutation]):
                if letter not in msa_data.aa_dict:
                    continue
                mutated_sequences_one_hot[row, col, msa_data.aa_dict[letter]] = 1.0

        mutated_sequences_one_hot = torch.tensor(mutated_sequences_one_hot)
        dataloader = torch.utils.data.DataLoader(mutated_sequences_one_hot, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
        prediction_matrix = torch.zeros((n_variants,num_samples))

        with torch.no_grad():
            for batch_number, batch in enumerate(tqdm.tqdm(dataloader, 'Looping through mutation batches')):
                x = batch.type(self.dtype).to(self.device)
                first_row = batch_number*batch_size
                last_row = first_row+len(x)
                for sample_number in range(num_samples):
                    seq_predictions, _, _ = self.all_likelihood_components(x)
                    prediction_matrix[first_row:last_row,sample_number] = seq_predictions
                tqdm.tqdm.write('\n')
            mean_predictions = prediction_matrix.mean(dim=1, keepdim=False)
            std_predictions = prediction_matrix.std(dim=1, keepdim=False)
            delta_elbos = mean_predictions - mean_predictions[0]
            evol_indices =  - delta_elbos.detach().cpu().numpy()

        wt_mean = mean_predictions[0].detach().cpu().numpy()
        return list_valid_mutations, evol_indices, wt_mean, std_predictions.detach().cpu().numpy()

In [None]:
# Load the labeled alignment; sequence weighting follows the training config.
data = data_class(
    data_filename='BLAT_ECOLX_1_b0.5_labeled.fasta',
    calc_weights=model_params["training_parameters"]['calc_weights'],
)
   

In [None]:
import os, sys
import argparse
import pandas as pd
import json

# Set to True to restore a saved model instead of training a new one.
load_model = False

# Needed by both branches (model label and checkpoint file prefix).
model_name = 'experiment1'

if load_model:
    data = data_class(data_filename='BLAT_ECOLX_1_b0.5_labeled.fasta', calc_weights=model_params["training_parameters"]['calc_weights'])
    # Read the checkpoint file once. map_location="cpu" lets a checkpoint
    # saved on a GPU machine be opened anywhere; the model is moved to its own
    # device below.
    checkpoint = torch.load("experiment1_final", map_location="cpu") # Change name of saved file
    model = VAE_model(
                    model_name=model_name,
                    data=data,
                    encoder_parameters=checkpoint["encoder_parameters"],
                    decoder_parameters=checkpoint["decoder_parameters"],
                    random_seed=42)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(model.device)

else:

    model = VAE_model(
                    model_name=model_name,
                    data=data,
                    encoder_parameters=model_params["encoder_parameters"],
                    decoder_parameters=model_params["decoder_parameters"],
                    random_seed=42
    )
    model = model.to(model.device)

    model.train_model(data=data, training_parameters=model_params["training_parameters"])

    print("Saving model: " + model_name)
    model.save(model_checkpoint=model_params["training_parameters"]['model_checkpoint_location']+os.sep+model_name+"_final", 
                encoder_parameters=model_params["encoder_parameters"], 
                decoder_parameters=model_params["decoder_parameters"], 
                training_parameters=model_params["training_parameters"]
    )

[1. 1. 1. ... 1. 1. 1.]




KeyboardInterrupt: ignored

Get sequences for test set

## Get Z values and mut_scores

In [None]:
if __name__=='__main__':
    z_list, label_list = model.get_z(dataloader)
    # get_mutation_likelihood also requires the mutant dataloader and the number
    # of latent samples; same arguments as the monitoring call during training.
    mut_scores = model.get_mutation_likelihood(wt_one_hot, mut_one_hot, test_dataloader, 50, model_params["training_parameters"]['batch_size'])

In [None]:
# Label value -> marker colour (keys 0..9).
cdict = dict(enumerate([
    'tab:cyan', 'tab:blue', 'tab:orange', 'tab:green', 'tab:red',
    'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive',
]))
fig, ax = plt.subplots()
# One scatter call per label so that each label gets its own legend entry.
for g in np.unique(label_list):
    ix = np.where(label_list == g)
    ax.scatter(
        z_list[ix,0],
        z_list[ix,1],
        c=cdict[g],
        marker='.',
        label=phyla_lookup_table[g],
        alpha=1,
    )
ax.legend(loc='center left', bbox_to_anchor=(1, 0.65))
ax.set_xlabel('values of Z1')
ax.set_ylabel('values of Z2')
ax.set_title('latent space of dimension 2 of VAE model')
plt.show()

In [None]:
# Rank correlation between the model's mutation scores and exp_list.
spear = spearmanr(a=mut_scores, b=exp_list)
print(spear)