# Sequence landscape figure

Requires: theano and biopython.

Must be run in Python 2 due to theano requirements. Other scripts used to produce the required data:

  - "Load Taxonomy Labels from Uniprot" (jupyter notebook) was used to produce the taxonomy assignments (it generates `PSE1_NATURAL_TAXONOMY.csv` and `PSEAB_NATURAL_TAXONOMY.csv`)

Note: training of the variational autoencoder takes a long time.

Requires `sequence_space_data.zip` from https://evcouplings.org/3Dseq


In [0]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='png'

In [0]:
# Root directory of the project; it is appended to sys.path in the next cell
# and used as the prefix of the data and figure directories further down.
PROJECT_ROOT = '.'

In [0]:
import sys
sys.path.append(PROJECT_ROOT)

import os, copy, time
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import scipy

from Bio import SeqIO
from Bio import pairwise2

#then import the deepseqs package code
import model
import helper
import train

import theano
# Pass -Wno-c++11-narrowing to the C++ compiler theano invokes, so that the
# narrowing diagnostic does not break theano's on-the-fly compilation.
theano.config.gcc.cxxflags = "-Wno-c++11-narrowing" #FIXES ISSUE WITH COMPILING THEANO LIVE COMPILATION

# Global Variables

In [0]:
# NOTE: PROJECT_ROOT is also assigned (to the same value) in an earlier cell.
PROJECT_ROOT = '.'

# Label names. They are used as keys of the 'label_filenames' and
# 'label_color' dictionaries below and as the values of the 'label' column
# of the dataframes built by loadFastafileIntoDF.
# (ROUNDS is not referenced anywhere in the visible part of this notebook.)
ROUNDS = 'ROUNDS'
QUERY = 'QUERY'
NATURAL_FULL = 'Natural Full'
NATURAL_SUBSAMPLE_RANDOM = 'Natural random'
RND10 = 'Round 10'
RND20 = 'Round 20'
RND2 = 'Round 2'
RND4 = 'Round 4'
RND8 = 'Round 8'

# Maps the fasta record name of a query (wild-type) sequence to the name
# used for it in the plot legend (see plotSequenceSpace).
WT_FASTANAME_TO_COMMONNAME = {
    'BLAT_PSE1/1-266': 'PSE1',
    'WT_PSEAB': 'AAC6',
    'WT_PSEAB/1-148': 'AAC6'
}

# Per-protein data directories and the figure output directory, all relative
# to PROJECT_ROOT.
ALIGNMENT_DIR_PSE1 = PROJECT_ROOT+'/data.PSE1'
ALIGNMENT_DIR_AAC6 = PROJECT_ROOT+'/data.AAC6'
FIGURE_OUTPUT_DIR = PROJECT_ROOT+'/figures'

# Dataset description for PSE1. The dictionary is passed as `attributes` to
# the loading functions below.
PSE1_ATTRIBUTES = {
    'name': 'PSE1',
    # number of residues trimmed from the start / end of every loaded
    # sequence when loadFastafileIntoDF is called with focus_seq_only=True
    'sequence_start_offset': 6, #these are lower case - not aligned
    'sequence_end_offset': 4,   #these are lower case - not aligned
    'alignment_directory': ALIGNMENT_DIR_PSE1,
    # full natural alignment: loaded for the NATURAL_FULL label and used by
    # createSubsample as the source of the random subsample
    'full_natural_alignment_filename': 
        ALIGNMENT_DIR_PSE1+'/7fa1c5691376beab198788a726917d48_b0.4.a2m',
    # label -> sequence file loaded for that label
    'label_filenames': {
        QUERY: ALIGNMENT_DIR_PSE1+'/PSE1.fas',
        NATURAL_SUBSAMPLE_RANDOM: 
            ALIGNMENT_DIR_PSE1+'/PSE1_natural_>85%ungapped_5K_subsample.a2m',
        RND20: ALIGNMENT_DIR_PSE1+'/Rnd20_init.fas', 
        RND10: ALIGNMENT_DIR_PSE1+'/Rnd10.fas'
    },
    # label -> colour stored in the 'color' column of the dataframe
    'label_color':{ #taken from "tab10" matplotlib colors
        NATURAL_FULL: 'black',
        NATURAL_SUBSAMPLE_RANDOM: 'black',
        RND10: '#3a76af',#'firebrick',
        RND20: '#ef8536', #'darkgreen', 
        QUERY: 'yellow'
    },
    # tab-separated table that loadDataframe merges in on 'seq_name'
    'taxonomy_filename': ALIGNMENT_DIR_PSE1+'/PSE1_NATURAL_TAXONOMY.csv',
    # settings handed to the VAE loading functions (loadVAEModelFromDisk)
    'vae_random':{
        'file_prefix': 'PSE1_2latent_5ksubsample',
        'alignment_filename': ALIGNMENT_DIR_PSE1+'/PSE1_natural_>85%ungapped_5K_subsample.a2m',
        'working_dir': ALIGNMENT_DIR_PSE1
    },
    # taxonomic class -> colour; passed to plotSequenceSpace as
    # lineage_colors_hm when colouring the natural sequences by class
    'class_colors': { #min 100 seq members
       'Actinobacteria': "#1f77b4", 'Alphaproteobacteria': "#ff7f0e", 'Bacilli': "#2ca02c", 
       'Betaproteobacteria': "#d62728", 'Gammaproteobacteria': "#17becf", 
    }
}

# Dataset description for AAC6. The dictionary is passed as `attributes` to
# the loading functions below.
AAC6_ATTRIBUTES = {
    'name': 'AAC6',
    # number of residues trimmed from the start / end of every loaded
    # sequence when loadFastafileIntoDF is called with focus_seq_only=True
    # (0 means nothing is trimmed)
    'sequence_start_offset': 0,
    'sequence_end_offset': 0,
    'alignment_directory': ALIGNMENT_DIR_AAC6,
    # full natural alignment: loaded for the NATURAL_FULL label and used by
    # createSubsample as the source of the random subsample
    'full_natural_alignment_filename': 
        ALIGNMENT_DIR_AAC6+'/44883374318b63406a7415d2f4d4cfc1_b0.4.a2m',
    # label -> sequence file loaded for that label
    'label_filenames': {
        QUERY: ALIGNMENT_DIR_AAC6+'/WT_PSEAB.fas',
        NATURAL_SUBSAMPLE_RANDOM: 
            ALIGNMENT_DIR_AAC6+'/AAC6_natural_>85%ungapped_5K_subsample.a2m',
        RND2: ALIGNMENT_DIR_AAC6+'/G2_PSEAB_0_St115_50ksub.fas', 
        RND4: ALIGNMENT_DIR_AAC6+'/G4_PSEAB_0_St115_50ksub.fas', 
        RND8: ALIGNMENT_DIR_AAC6+'/G8_all_PSEAB_0_St115_50ksub.fas',
    },
    # label -> colour stored in the 'color' column of the dataframe
    'label_color':{
        NATURAL_FULL: 'black',#'gray',
        NATURAL_SUBSAMPLE_RANDOM: 'black',#'gray',
        RND2: '#c53932',#'red', 
        RND4: '#8d6bb8',#'limegreen', 
        RND8: '#85584e',#'cornflowerblue',
        QUERY: 'yellow'
    },
    # tab-separated table that loadDataframe merges in on 'seq_name'
    'taxonomy_filename': ALIGNMENT_DIR_AAC6+'/AAC6_NATURAL_TAXONOMY.csv',
    # settings handed to the VAE loading functions (loadVAEModelFromDisk)
    'vae_random':{
        'file_prefix': 'AAC6_2latent_5ksubsample',
        'alignment_filename': ALIGNMENT_DIR_AAC6+'/AAC6_natural_>85%ungapped_5K_subsample.a2m',
        'working_dir': ALIGNMENT_DIR_AAC6,
    },
    # taxonomic class -> colour; passed to plotSequenceSpace as
    # lineage_colors_hm when colouring the natural sequences by class
    'class_colors': { #min 100 seq members
       'Actinobacteria': "#1f77b4", 'Alphaproteobacteria': "#ff7f0e", 'Bacilli': "#2ca02c", 
       'Betaproteobacteria': "#d62728", 'Clostridia': "#9467bd", 'Gammaproteobacteria': "#17becf", 
    }
}

# Amino-acid letters that get their own column in the one-hot encoding;
# the position of a letter in this string is its column index.
ALPHABET = 'ACDEFGHIKLMNPQRSTVWY'
# amino-acid letter -> column index
AA_IDX_DICT = {aa: i for i, aa in enumerate(ALPHABET)}
# column index -> amino-acid letter. dict.items() is used instead of the
# Python-2-only dict.iteritems() so the line runs on Python 2 and 3 alike.
IDX_AA_DICT = {idx: aa for aa, idx in AA_IDX_DICT.items()}

# Setup Data Munging Functions

In [0]:
def getMutationCountFromQuery(query_in_binary, sequences_in_binary):
    '''
    Returns a list of the number of mutations each sequence in sequences has
    relative to the query sequence.

    Parameters:
        query_in_binary:      flattened one-hot encoding of the query, i.e.
                              len(ALPHABET) entries per sequence position.
        sequences_in_binary:  iterable of flattened one-hot encodings with
                              the same layout as the query.

    A position only counts as matching when both encodings have a 1 in the
    same slot; every other position (including positions where either
    encoding is all zeros) is counted as a mutation.
    '''
    # Floor division: the number of positions must stay an integer under
    # Python 3 as well, where `/` on two ints returns a float.
    num_positions = len(query_in_binary) // len(ALPHABET)
    return [
        num_positions - np.sum(np.logical_and(query_in_binary, target_seq_in_binary))
        for target_seq_in_binary in sequences_in_binary
    ]

def encodeSequencesAsBinary(sequences, flatten_each_sequence = False):
    '''
    Encodes a set of polypeptide sequences in to binary (one-hot).

    Parameters:
        sequences:              iterable of strings or of SeqRecord objects.
                                The length of the first sequence fixes the
                                number of positions of the returned array.
        flatten_each_sequence:  if True each sequence is returned as one
                                flat vector instead of a 2-d matrix.

    Returns:
        A numpy array of shape (n_sequences, n_positions, len(ALPHABET)),
        or (n_sequences, n_positions * len(ALPHABET)) when flattened.
        Characters that are not in ALPHABET are encoded as all zeros.
        An empty input yields an empty array.
    '''
    # Materialise once so that the first element can be inspected by
    # position, whatever container (list, pandas Series, ...) was passed in.
    sequences = list(sequences)
    if not sequences:
        # reshape(..., -1) is undefined for an empty array, so build the
        # empty result directly
        empty_shape = (0, 0) if flatten_each_sequence else (0, 0, len(ALPHABET))
        return np.zeros(empty_shape)

    if isinstance(sequences[0], SeqIO.SeqRecord):
        sequences = [str(seq.seq) for seq in sequences]

    seq_arr = np.zeros((len(sequences), len(sequences[0]), len(ALPHABET)))
    for i, seq in enumerate(sequences):
        for j, aa in enumerate(seq):
            if aa in AA_IDX_DICT: #non-normal letters just get all zeros
                seq_arr[i, j, AA_IDX_DICT[aa]] = 1.

    if flatten_each_sequence:
        return seq_arr.reshape((len(sequences), -1))
    return seq_arr

def loadSequenceFile(filename, remove_start=0, remove_end= 0):
    '''
    Reads every record of a fasta file into a list. When remove_start
    and/or remove_end are non-zero, that many residues are cut from the
    start and/or end of each record.
    '''
    message = 'Loading "{0}" (removing {1} start and {2} end residues)'
    print(message.format(filename, remove_start, remove_end))

    records = list(SeqIO.parse(filename, 'fasta'))
    if not (remove_start or remove_end):
        # nothing to trim: hand back the records exactly as parsed
        return records

    # a stop index of None keeps the tail of the record
    stop = -1*remove_end if remove_end else None
    return [record[remove_start:stop] for record in records]


def loadFastafileIntoDF(attributes, 
                        filename, 
                        label, 
                        focus_seq_only=True,
                        override_remove_start = None,
                        override_remove_end = None):
    '''
    Load a single fastafile into a dataframe (one row per sequence).

    Parameters:
        attributes:      dataset dictionary; supplies the default trimming
                         offsets ('sequence_start_offset' /
                         'sequence_end_offset'), the QUERY filename and the
                         colour for `label`.
        filename:        sequence file to load.
        label:           value written to the 'label' column; must be a key
                         of attributes['label_color'].
        focus_seq_only:  if False, the default offsets are ignored and no
                         amino acids are removed from the sequences (unless
                         an override is given).
        override_remove_start / override_remove_end:
                         when not None, the number of residues to remove
                         from the start / end instead of the defaults.
                         An explicit 0 is honoured.

    Returns:
        DataFrame with the columns seq_record, label, color, seq_name,
        seq_str, num_aminoacids, seq_binary, seq_binary_flat and
        mutcount_from_query. The query sequence is trimmed with the same
        offsets before mutations are counted.
    '''
    remove_start = attributes['sequence_start_offset']
    remove_end = attributes['sequence_end_offset']
    if focus_seq_only == False:
        remove_start = 0
        remove_end = 0
    # `is not None` so that an override of 0 is not silently ignored
    if override_remove_start is not None:
        remove_start = override_remove_start
    if override_remove_end is not None:
        # this previously overwrote remove_start instead of remove_end
        remove_end = override_remove_end

    query_sequence = loadSequenceFile(
        attributes['label_filenames'][QUERY], 
        remove_start=remove_start, 
        remove_end=remove_end
    )[0]

    toreturn = pd.DataFrame()
    seq_records = loadSequenceFile(
        filename, 
        remove_start=remove_start, 
        remove_end=remove_end
    )

    toreturn['seq_record'] = seq_records
    toreturn['label'] = label
    toreturn['color'] = attributes['label_color'][label]
    toreturn['seq_name'] = [seq.name for seq in toreturn['seq_record']]
    toreturn['seq_str'] = [str(seq.seq) for seq in toreturn['seq_record']]
    # number of characters that are not gap symbols ('.' or '-')
    toreturn['num_aminoacids'] = [
        len(seq_str.replace('.', '').replace('-', ''))
        for seq_str in toreturn['seq_str']
    ]

    #encode sequences as binary
    toreturn['seq_binary'] = list(
        encodeSequencesAsBinary(toreturn['seq_record'])
    )
    toreturn['seq_binary_flat'] = list(encodeSequencesAsBinary(
        toreturn['seq_record'], flatten_each_sequence=True
    ))

    #mut count from the query sequence
    toreturn['mutcount_from_query'] = getMutationCountFromQuery(
        encodeSequencesAsBinary([query_sequence], True)[0],
        toreturn.seq_binary_flat
    )

    return toreturn


def createSubsample(attributes):
    '''
    Full natural alignments were generated by EVcouplings 
    using a bitscore of 0.4. This method will generate a 
    5000 sequence subsample of the full alignment using
    the following process
      1. remove any sequences that are not 85% ungapped 
         relative to the query sequence.
      2. remove sequences identical to the query sequence.
      3. randomly draw 4999 of the remaining sequences (all of them if
         fewer remain).
      4. Save the query followed by the drawn sequences to
         attributes['label_filenames'][NATURAL_SUBSAMPLE_RANDOM]
    '''
    query_sequence = loadSequenceFile(
        attributes['label_filenames'][QUERY]
    )[0]

    tmpdf = loadFastafileIntoDF(
        attributes, 
        attributes['full_natural_alignment_filename'], 
        NATURAL_FULL,
        focus_seq_only=False
    )

    len_all_natural_sequences = len(tmpdf)
    min_num_aminoacids = 0.85 * len(query_sequence)
    tmpdf = tmpdf[tmpdf.num_aminoacids >= min_num_aminoacids]
    tmpdf = tmpdf[tmpdf.seq_str != str(query_sequence.seq)] #remove query sequence

    print('{0} seqs removed due to 85% length requirement ({1})'.format(
        len_all_natural_sequences - len(tmpdf), 
        min_num_aminoacids
    ))

    filename = attributes['label_filenames'][NATURAL_SUBSAMPLE_RANDOM]

    # DataFrame.sample raises when asked for more rows than are available,
    # so cap the request for small alignments
    num_to_sample = min(4999, len(tmpdf))
    if num_to_sample < 4999:
        print('warning: only {0} natural sequences available to subsample'.format(
            num_to_sample
        ))
    tmpdf = tmpdf.sample(num_to_sample)

    SeqIO.write(
        [query_sequence] + list(tmpdf.seq_record), 
        filename, 
        'fasta'
    )
    # +1 accounts for the query sequence written first
    print('done saving {0} seq subsample'.format(num_to_sample + 1))

    

def loadDataframe(attributes, 
                  labels_list=None, 
                  load_taxonomy=True,
                  override_remove_start = None,
                  override_remove_end = None):
    '''
    Load the dataframe for a set of sequences.
    
    Parameters:
        attributes:       a dictionary of attributes for the requested
                          dataset to load
        labels_list:      a list of labels to load. If None, then the
                          full list will be used that is specified in: 
                             attributes['label_filenames'].keys()
        load_taxonomy:    whether or not to include taxonomy in the
                          returned dataframe.
        override_remove_start / override_remove_end:
                          forwarded unchanged to loadFastafileIntoDF.

    Returns:
        One dataframe holding the rows of every requested label. With
        load_taxonomy the table in attributes['taxonomy_filename'] is
        left-joined on 'seq_name'.

    Side effect: the random natural subsample file is created first if it
    does not exist yet.
    '''
    if not os.path.isfile(attributes['label_filenames'][NATURAL_SUBSAMPLE_RANDOM]):
        createSubsample(attributes)

    if labels_list is None:
        labels_list = attributes['label_filenames'].keys()

    label_frames = []
    for lbl in labels_list:
        # NATURAL_FULL has no entry in 'label_filenames'
        if lbl == NATURAL_FULL:
            filename = attributes['full_natural_alignment_filename']
        else:
            filename = attributes['label_filenames'][lbl]

        label_frames.append(loadFastafileIntoDF(
            attributes, filename, lbl,
            override_remove_start=override_remove_start,
            override_remove_end=override_remove_end
        ))

    # Concatenate once: DataFrame.append copied the whole frame on every
    # iteration and no longer exists in recent pandas releases.
    if label_frames:
        toreturn_df = pd.concat(label_frames, ignore_index=True)
    else:
        toreturn_df = pd.DataFrame()

    if not load_taxonomy: return toreturn_df

    #load taxonomy
    taxonomy = pd.read_csv(attributes['taxonomy_filename'], sep='\t')
    return pd.merge(toreturn_df, taxonomy, how='left', 
                    left_on='seq_name', right_on='seq_name')#.drop('uniprot_name', axis=1)


# Load Data

In [0]:
# PSE1: query, random natural subsample and the two lab-evolution rounds,
# with taxonomy merged in (load_taxonomy defaults to True).
pse1_df = loadDataframe(
    PSE1_ATTRIBUTES, 
    labels_list=[QUERY, NATURAL_SUBSAMPLE_RANDOM, RND10, RND20]
)

In [0]:
# AAC6: query, random natural subsample and the three lab-evolution rounds,
# with taxonomy merged in (load_taxonomy defaults to True).
aac6_df = loadDataframe(
    AAC6_ATTRIBUTES, 
    labels_list=[QUERY, NATURAL_SUBSAMPLE_RANDOM, RND2, RND4, RND8]
)

# lab evolution sequence space vs natural sequence space

In [0]:
#
#
# data munging functions
#
#
def getVAEDataHelper(vae_attributes):
    '''
    Builds the helper.DataHelper for the alignment file and working
    directory named in vae_attributes (weights are not calculated).
    '''
    helper_kwargs = {
        'alignment_file': vae_attributes['alignment_filename'],
        'working_dir': vae_attributes['working_dir'],
        'calc_weights': False,
    }
    return helper.DataHelper(**helper_kwargs)

def loadVAEModel(vae_attributes):
    '''
    Constructs a model.VariationalAutoencoder with a 2-dimensional latent
    space on top of the data helper for vae_attributes and reports how
    long construction took. No parameters are loaded here.
    '''
    began = time.time()
    alignment_helper = getVAEDataHelper(vae_attributes)

    # fixed architecture / hyper-parameters of the autoencoder
    hyperparameters = dict(
        batch_size=100,
        encoder_architecture=[1500,1500],
        decoder_architecture=[100,500],
        n_latent=2,
        n_patterns=4,
        warm_up=0.0,
        convolve_patterns=True,
        conv_decoder_size=40,
        logit_p=0.001,
        sparsity='logit',
        encode_nonlinearity_type='relu',
        decode_nonlinearity_type='relu',
        final_decode_nonlinearity='sigmoid',
        output_bias=True,
        final_pwm_scale=True,
        working_dir=alignment_helper.working_dir,
    )
    autoencoder = model.VariationalAutoencoder(alignment_helper, **hyperparameters)

    print('Done loading model - took {0}s'.format(time.time()-began))
    return autoencoder

def loadVAEParameters(vae_attributes, vae_model):
    '''
    Asks vae_model to load the parameters stored under
    vae_attributes['file_prefix'], prints the elapsed time and returns
    the same model object.
    '''
    began = time.time()
    prefix = vae_attributes['file_prefix']
    vae_model.load_parameters(file_prefix=prefix)
    elapsed = time.time() - began
    print('Done Loading parameters - took {0}s'.format(elapsed))
    return vae_model

def loadVAEModelFromDisk(vae_attributes):
    '''
    Shorthand to construct the model and then load its stored parameters.

    Nothing is trained here. If the stored parameters cannot be read
    (IOError) an error message is printed and the model is returned
    WITHOUT loaded parameters, so it has to be trained before it is used.
    If the model itself cannot be constructed the IOError is re-raised
    after the message is printed.
    '''
    try:
        vae_model = loadVAEModel(vae_attributes)
    except IOError:
        print('error: unable to load vae model. please train the model first.')
        # there is no model object to return; previously this path ended in
        # an UnboundLocalError that hid the real cause
        raise

    try:
        vae_model = loadVAEParameters(vae_attributes, vae_model)
    except IOError:
        print('error: unable to load vae model. please train the model first.')
    return vae_model

    
def projectSequencesIntoVAESpace(model, sequences_df):
    '''
    Projects every sequence of the dataframe into the latent (Z) space of
    the model and returns a copy of the dataframe with the two latent
    coordinates added as the columns 'z1' and 'z2'. The input dataframe is
    left untouched.
    '''
    encoded_inputs = list( sequences_df.seq_binary )
    # recognize() hands back (mu, log_sigma); only the means are used
    latent_mu, _latent_log_sigma = model.recognize(encoded_inputs)
    coordinates = np.asarray(latent_mu)

    projected = sequences_df.copy()
    for column, axis in (('z1', 0), ('z2', 1)):
        projected[column] = coordinates[:, axis]
    return projected


In [0]:
# Construct the PSE1 model and load its stored parameters.
pse1_vae_model = loadVAEModelFromDisk(PSE1_ATTRIBUTES['vae_random'])

In [0]:
# Construct the AAC6 model and load its stored parameters.
aac6_vae_model = loadVAEModelFromDisk(AAC6_ATTRIBUTES['vae_random'])

In [0]:
# Add the latent coordinates (columns z1, z2) to a copy of the AAC6 dataframe.
aac6_vae_df = projectSequencesIntoVAESpace(aac6_vae_model, aac6_df)

In [0]:
# Add the latent coordinates (columns z1, z2) to a copy of the PSE1 dataframe.
pse1_vae_df = projectSequencesIntoVAESpace(pse1_vae_model, pse1_df)

In [0]:
def plotSingleJointPlot(dataframe, 
                        label, 
                        g=None, 
                        plot_width_height=10, 
                        alpha=0.4,
                        hist_alpha = 0.2,
                        num_bins = 50, 
                        dot_size = 2,
                        plot_histograms=True,
                        subsample_size=False,
                        color='grey'):
    '''
    Scatter the z1/z2 columns of `dataframe` onto a seaborn JointGrid and
    return the grid.

    If `g` is given the points are added to that grid, otherwise a new
    grid of height `plot_width_height` is created. When `subsample_size`
    is set and smaller than the dataframe, a random sample of that many
    rows is drawn instead of all rows. With `plot_histograms` the
    marginal axes receive histograms of z1 and z2 (`num_bins` bins).
    '''
    points = dataframe
    if subsample_size and subsample_size < len(dataframe):
        points = dataframe.sample(subsample_size)

    if g:
        # reuse the existing grid: point it at the new data
        g.x, g.y = points.z1, points.z2
    else:
        g = sns.JointGrid(
            x=points.z1, 
            y=points.z2,
            height=plot_width_height
        )

    legend_text = '{0} (n={1})'.format(label, len(g.x))
    g = g.plot_joint(func=plt.scatter, color=color, 
                     s=dot_size, alpha=alpha,
                     label=legend_text)

    if not plot_histograms:
        return g

    # both marginal histograms share colour, transparency and bin count
    shared_hist_kwargs = dict(color=color, alpha=hist_alpha, bins=num_bins)
    g.ax_marg_x.hist(points.z1, **shared_hist_kwargs)
    g.ax_marg_y.hist(points.z2, orientation='horizontal', **shared_hist_kwargs)
    return g

def plotSequenceSpace(dataframe, 
                      title, 
                      orderedLabels = [RND8, RND4, RND2],
                      plot_width_height=12,
                      plot_histograms=True,
                      highlightQuery=True,
                      save_as_filenames=None,
                      save_as_dpi=300,
                      alpha=0.4,
                      hist_alpha=0.2,
                      dot_size=2,
                      subsample_size=False,
                      color_lineage = None,   #color_lineage accepts "phylum" or "class"
                      lineage_colors_hm = None,
                      include_title=True,
                      include_legend=True,
                      override_num_y_bins = None,
                      override_num_x_bins = None,
                      query_dot_size=100):
    '''
    Plot the latent-space positions (columns z1/z2) of several labels on
    one joint plot.

    Parameters:
        dataframe:          dataframe with at least the columns label,
                            color, z1 and z2 (plus seq_name for the query
                            and the `color_lineage` column when used).
        title:              figure title (only drawn if include_title).
        orderedLabels:      labels to draw, in drawing order. The default
                            list is never modified.
        plot_width_height:  height passed on to the joint grid.
        plot_histograms:    whether marginal histograms are drawn.
        highlightQuery:     overlay the single QUERY row as a yellow dot
                            with a black edge.
        save_as_filenames:  a list of filenames, or a single filename, to
                            save the figure to; nothing is saved if empty.
        save_as_dpi:        dpi used when saving.
        alpha, hist_alpha, dot_size: styling of points and histograms.
        subsample_size:     if set, at most this many randomly chosen rows
                            are drawn per label.
        color_lineage:      "phylum" or "class": colour the natural labels
                            by that dataframe column instead of by label.
        lineage_colors_hm:  lineage name -> colour. Only rows whose lineage
                            is a key of this mapping are drawn.
        include_legend:     whether a legend is added to the joint axes.
        override_num_y_bins / override_num_x_bins:
                            None leaves the ticks alone, 0 removes them,
                            any other value is passed to
                            plt.locator_params as nbins.
        query_dot_size:     marker size of the highlighted query.
    '''
    g = None
    if color_lineage and color_lineage not in ['phylum', 'class']:
        print('warning: color_lineage must be either "phylum" or "class" NOT {0}'.format(color_lineage))
        color_lineage = None
    if color_lineage and lineage_colors_hm is None:
        # without a colour mapping there is nothing to iterate over below
        print('warning: color_lineage requires lineage_colors_hm; coloring by label instead')
        color_lineage = None

    for label in orderedLabels:
        color = list(dataframe[dataframe.label==label].color)[0]
        if color_lineage and label in [NATURAL_FULL, 
                                       NATURAL_SUBSAMPLE_RANDOM]:

            lbl_df = dataframe[dataframe.label == label]
            # rows without a lineage assignment cannot be coloured
            lbl_lineage_df = lbl_df[pd.isna(lbl_df[color_lineage]) == False]
            if subsample_size and subsample_size < len(lbl_lineage_df):
                #subsample the label, not the color
                lbl_lineage_df = lbl_lineage_df.sample( subsample_size )

            for classorphylum in lineage_colors_hm:
                color = lineage_colors_hm[classorphylum]
                tmpdf = lbl_lineage_df[lbl_lineage_df[color_lineage] == classorphylum]

                g = plotSingleJointPlot(
                    tmpdf, classorphylum, g=g, 
                    plot_width_height=plot_width_height,
                    plot_histograms=plot_histograms,
                    alpha=alpha, hist_alpha=hist_alpha,
                    color = color, dot_size=dot_size
                )
        else:
            g = plotSingleJointPlot(
                dataframe[dataframe.label==label], label, g=g, 
                plot_width_height=plot_width_height,
                plot_histograms=plot_histograms,
                alpha=alpha, hist_alpha=hist_alpha,
                subsample_size=subsample_size,
                color = color, dot_size=dot_size
            )

    if g is None:
        # no label produced a plot, so there is no grid to decorate or save
        print('Error: nothing was plotted (no labels to draw)')
        return

    if include_title:
        g.fig.suptitle(title, fontsize=16, y=1.01)

    g.ax_joint.xaxis.get_label().set_fontsize(36)
    g.ax_joint.yaxis.get_label().set_fontsize(36)
    for ticklbl in g.ax_joint.xaxis.get_ticklabels():
        ticklbl.set_fontsize(28)
    for ticklbl in g.ax_joint.yaxis.get_ticklabels():
        ticklbl.set_fontsize(28)

    if override_num_y_bins is not None:
        if override_num_y_bins != 0:
            plt.locator_params(axis='y', nbins=override_num_y_bins)
        else: plt.yticks([], [])

    if override_num_x_bins is not None:
        if override_num_x_bins != 0:
            plt.locator_params(axis='x', nbins=override_num_x_bins)
        else: plt.xticks([], [])

    if highlightQuery:
        query = dataframe[dataframe.label == QUERY]
        if len(query) != 1:
            print('Error: QUERY label occurs {0} times (expected only 1 occurance)'.format(
                len(query)
            ))
        else:
            g.x=query.z1
            g.y=query.z2
            g.plot_joint(
                plt.scatter, 
                s=query_dot_size, 
                color='yellow', 
                edgecolor='black', 
                label=WT_FASTANAME_TO_COMMONNAME[list(query.seq_name)[0]]
            )
    if include_legend:
        g.ax_joint.legend()

    if save_as_filenames:
        # A single filename is treated like a one-element list. The old
        # single-filename branch referenced an undefined name and raised
        # a NameError.
        if not isinstance(save_as_filenames, list):
            save_as_filenames = [save_as_filenames]
        for filename in save_as_filenames:
            g.savefig(filename, dpi=save_as_dpi)


In [0]:
# PSE1: natural subsample coloured by taxonomic class, with both lab
# rounds drawn on top; marginal histograms are switched off.
# The output filename is commented out, so nothing is written to disk.
plotSequenceSpace(
    pse1_vae_df, 
    'PSE1 All Sequences\ncolored by class', 
    orderedLabels=[NATURAL_SUBSAMPLE_RANDOM, RND20, RND10],
    color_lineage='class',
    lineage_colors_hm=PSE1_ATTRIBUTES['class_colors'],
    save_as_filenames = [
        #FIGURE_OUTPUT_DIR+'/PSE1_VAE_trainnatural5ksubsample_showall_classcolor.pdf'
    ],
    save_as_dpi=300,
    alpha=1,
    hist_alpha=0.4,
    dot_size=10,
    plot_histograms=False,
    subsample_size=10000,
    include_legend=False,
    include_title=False,
    override_num_y_bins=4,
    override_num_x_bins=4
)
# PSE1: lab-evolved rounds only, with an enlarged query marker.
# The output filename is commented out, so nothing is written to disk.
plotSequenceSpace(
    pse1_vae_df, 
    'PSE1 Lab Sequences', 
    orderedLabels=[RND20, RND10],
    save_as_filenames = [
        #FIGURE_OUTPUT_DIR+'/PSE1_VAE_trainnatural5ksubsample_showlab.pdf'
    ],
    save_as_dpi=300,
    alpha=1,
    hist_alpha=0.4,
    dot_size=5,
    subsample_size=10000,
    include_legend=False,
    include_title=False,
    override_num_y_bins=4,
    override_num_x_bins=4,
    query_dot_size=350
)

In [0]:
# AAC6: natural subsample coloured by taxonomic class, with the three lab
# rounds drawn on top.
# The output filename is commented out, so nothing is written to disk.
plotSequenceSpace(
    aac6_vae_df, 
    'AAC6 All Sequences\ncolored by class', 
    orderedLabels=[NATURAL_SUBSAMPLE_RANDOM, RND8, RND4, RND2],
    color_lineage='class',
    lineage_colors_hm=AAC6_ATTRIBUTES['class_colors'],
    save_as_filenames = [
        #FIGURE_OUTPUT_DIR+'/AAC6_VAE_trainnatural5ksubsample_showall_colorclass.pdf'
    ],
    save_as_dpi=300,
    alpha=1,
    hist_alpha=0.4,
    dot_size=10,
    subsample_size=10000,
    include_legend=False,
    include_title=False,
    override_num_y_bins=4,
    override_num_x_bins=4
)
# AAC6: lab-evolved rounds only, with an enlarged query marker.
# The output filename is commented out, so nothing is written to disk.
plotSequenceSpace(
    aac6_vae_df, 
    'AAC6 Lab Sequences', 
    orderedLabels=[RND8, RND4, RND2],
    save_as_filenames = [
        #FIGURE_OUTPUT_DIR+'/AAC6_VAE_trainnatural5ksubsample_showlab.pdf'
    ],
    save_as_dpi=300,
    alpha=1,
    hist_alpha=0.4,
    dot_size=5,
    subsample_size=10000,
    include_legend=False,
    include_title=False,
    override_num_y_bins=4,
    override_num_x_bins=4,
    query_dot_size=350
)