##### Copyright 2020 Google LLC.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Figure 3 and Tables


## Setup

In [None]:
import os
import zipfile

from IPython.display import display
from matplotlib import pyplot
import numpy
import pandas
import scipy.spatial.distance as distance
import scipy.stats
import seaborn


# The canonical single-letter code residue alphabet.
#
# 20 residues in alphabetical order of their one-letter codes. The position of
# a letter in this tuple is the index used by ResidueIdentityEncoder(RESIDUES).
RESIDUES = tuple('ACDEFGHIKLMNPQRSTVWY')

# Residues sorted by physicochemical properties.
#
# This ordering is useful when generating visualizations to highlight common
# behaviors among similar residues.
#
# Contains the same 20 letters as RESIDUES, in a different order.
RESIDUES_PHYSCHEM_ORDER = tuple('ILAVGMFYWEDQNHCRKSTP')

# The full VP1 sequence for AAV serotype 2.
AAV2_VP1_SEQ = 'MAADGYLPDWLEDTLSEGIRQWWKLKPGPPPPKPAERHKDDSRGLVLPGYKYLGPFNGLDKGEPVNEADAAALEHDKAYDRQLDSGDNPYLKYNHADAEFQERLKEDTSFGGNLGRAVFQAKKRVLEPLGLVEEPVKTAPGKKRPVEHSPVEPDSSSGTGKAGQQPARKRLNFGQTGDADSVPDPQPLGQPPAAPSGLGTNTMATGSGAPMADNNEGADGVGNSSGNWHCDSTWMGDRVITTSTRTWALPTYNNHLYKQISSQSGASNDNHYFGYSTPWGYFDFNRFHCHFSPRDWQRLINNNWGFRPKRLNFKLFNIQVKEVTQNDGTTTIANNLTSTVQVFTDSEYQLPYVLGSAHQGCLPPFPADVFMVPQYGYLTLNNGSQAVGRSSFYCLEYFPSQMLRTGNNFTFSYTFEDVPFHSSYAHSQSLDRLMNPLIDQYLYYLSRTNTPSGTTTQSRLQFSQAGASDIRDQSRNWLPGPCYRQQRVSKTSADNNNSEYSWTGATKYHLNGRDSLVNPGPAMASHKDDEEKFFPQSGVLIFGKQGSEKTNVDIEKVMITDEEEIRTTNPVATEQYGSVSTNLQRGNRQAATADVNTQGVLPGMVWQDRDVYLQGPIWAKIPHTDGHFHPSPLMGGFGLKHPPPQILIKNTPVPANPSTTFSAAKFASFITQYSTGQVSVEIEWELQKENSKRWNPEIQYTSNYNKSVNVDFTVDTNGVYSEPRPIGTRYLTRNL'

# The AAV serotype 2 wild type subsequence corresponding to tile #21 in round #1
# of experimental results.
#
# This 28-residue string is the slice AAV2_VP1_SEQ[560:588] (0-based, end
# exclusive), i.e. residues 561 through 588.
R1_TILE21_WT_SEQ = 'DEEEIRTTNPVATEQYGSVSTNLQRGNR'

# The start and end residue numbers (inclusive) for the tile 21 wild type seq.
#
# Residue numbering scheme corresponds to the 1-based index of the AAV2 VP1
# sequence.
R1_TILE21_WT_START_RESNUM = 561
R1_TILE21_WT_END_RESNUM = 588


# Display-name suffixes: partitions ending in '_seed' are shown as
# 'model-selected' and partitions ending in '_walked' as 'model-designed'.
seed_display_name = 'model-selected'
walked_display_name = 'model-designed'

# Maps a raw partition key to a human-readable label.
#
# The label prefix combines the model family (CNN / LR / RNN) with a letter for
# the training-set variant: 'rand_doubles_plus_single(s)' -> A,
# 'standard' -> B, 'designed_plus_rand_train' -> C.
#
# NOTE(review): the CNN and LR keys spell the A variant '..._plus_single_...'
# while the RNN keys spell it '..._plus_singles_...'. Presumably this mirrors
# the values in the data's `partition` column -- verify before normalizing.
partition_pretty_names = {
    'cnn_designed_plus_rand_train_seed': 'CNN-C ' + seed_display_name,
    'cnn_designed_plus_rand_train_walked': 'CNN-C ' + walked_display_name,
    'cnn_rand_doubles_plus_single_seed': 'CNN-A ' + seed_display_name,
    'cnn_rand_doubles_plus_single_walked': 'CNN-A ' + walked_display_name,
    'cnn_standard_seed': 'CNN-B ' + seed_display_name,
    'cnn_standard_walked': 'CNN-B ' + walked_display_name,
    'designed': 'Additive',
    'lr_designed_plus_rand_train_seed': 'LR-C ' + seed_display_name,
    'lr_designed_plus_rand_train_walked': 'LR-C ' + walked_display_name,
    'lr_rand_doubles_plus_single_seed': 'LR-A ' + seed_display_name,
    'lr_rand_doubles_plus_single_walked': 'LR-A ' + walked_display_name,
    'lr_standard_seed': 'LR-B ' + seed_display_name,
    'lr_standard_walked': 'LR-B ' + walked_display_name,
    'rand': 'Random',
    'rnn_designed_plus_rand_train_seed': 'RNN-C ' + seed_display_name,
    'rnn_designed_plus_rand_train_walked': 'RNN-C ' + walked_display_name,
    'rnn_rand_doubles_plus_singles_seed': 'RNN-A ' + seed_display_name,
    'rnn_rand_doubles_plus_singles_walked': 'RNN-A ' + walked_display_name,
    'rnn_standard_seed': 'RNN-B ' + seed_display_name,
    'rnn_standard_walked': 'RNN-B ' + walked_display_name,
}

# Every model-derived partition: both the '_walked' and the '_seed' partition
# of each of the nine (model family, training set) combinations.
ml_generated_partitions = [
    'rnn_designed_plus_rand_train_walked',
    'rnn_designed_plus_rand_train_seed',
    'rnn_rand_doubles_plus_singles_walked',
    'rnn_rand_doubles_plus_singles_seed',
    'rnn_standard_walked',
    'rnn_standard_seed',
    'cnn_designed_plus_rand_train_walked',
    'cnn_designed_plus_rand_train_seed',
    'cnn_rand_doubles_plus_single_walked',
    'cnn_rand_doubles_plus_single_seed',
    'cnn_standard_walked',
    'cnn_standard_seed',
    'lr_designed_plus_rand_train_walked',
    'lr_designed_plus_rand_train_seed',
    'lr_rand_doubles_plus_single_walked',
    'lr_rand_doubles_plus_single_seed',
    'lr_standard_walked',
    'lr_standard_seed',
]

# The nine '_walked' ("model-designed") partitions, grouped LR, CNN, RNN.
ml_designed_partitions = [
    'lr_rand_doubles_plus_single_walked',
    'lr_standard_walked',
    'lr_designed_plus_rand_train_walked',

    'cnn_rand_doubles_plus_single_walked',
    'cnn_standard_walked',
    'cnn_designed_plus_rand_train_walked',

    'rnn_rand_doubles_plus_singles_walked',
    'rnn_standard_walked',
    'rnn_designed_plus_rand_train_walked',
]

# The '_walked' partitions of the neural-network models only (CNN and RNN);
# i.e. ml_designed_partitions without the three LR entries.
nn_designed_partitions = [
    'cnn_rand_doubles_plus_single_walked',
    'cnn_standard_walked',
    'cnn_designed_plus_rand_train_walked',

    'rnn_rand_doubles_plus_singles_walked',
    'rnn_standard_walked',
    'rnn_designed_plus_rand_train_walked',
]

# The nine '_seed' ("model-selected") partitions, grouped LR, CNN, RNN.
ml_selected_partitions = [
    'lr_rand_doubles_plus_single_seed',
    'lr_standard_seed',
    'lr_designed_plus_rand_train_seed',
    
    'cnn_rand_doubles_plus_single_seed',
    'cnn_standard_seed',
    'cnn_designed_plus_rand_train_seed',

    'rnn_rand_doubles_plus_singles_seed',
    'rnn_standard_seed',
    'rnn_designed_plus_rand_train_seed',    
]

# '_walked' partitions split by training-set variant (one LR, one CNN and one
# RNN entry each); each group is drawn on its own perplexity plot.
ml_designed_partitions_doubles = [
    'lr_rand_doubles_plus_single_walked',
    'cnn_rand_doubles_plus_single_walked',
    'rnn_rand_doubles_plus_singles_walked',
]

ml_designed_partitions_standard = [
    'lr_standard_walked',
    'cnn_standard_walked',
    'rnn_standard_walked',
]

ml_designed_partitions_designed = [
    'lr_designed_plus_rand_train_walked',
    'cnn_designed_plus_rand_train_walked',
    'rnn_designed_plus_rand_train_walked',
]

# Baseline partitions; labelled 'Random' and 'Additive' respectively in
# partition_pretty_names.
baseline_random_partitions = ['rand']

baseline_additive_partitions = ['designed']

In [None]:
# Sentinel token value for denoting "no mutation here".
#
# tokenize_mutation_seq() uses the same literal '_' as the default for its
# `placeholder_token` argument; keep the two in sync.
_PLACEHOLDER_TOKEN = '_'
# Number of different mutation slots to use for each wildtype sequence position.
# Used as the size of the second axis of the array built by
# MutationSequenceEncoder.encode().
_NUM_MUTATION_SLOTS = 2  # 1 substitution + 1 insert possible per wt position.
# Slot index for substitution mutations.
_SUB_INDEX = 0
# Slot index for insertion mutations.
_INS_INDEX = 1


def tokenize_mutation_seq(seq, placeholder_token='_'):
  """Converts a variable-length mutation sequence to a fixed-length sequence.

  A mutation sequence has one character per reference position (a residue
  letter for a substitution, or the placeholder for "unchanged"), optionally
  followed by one lower-case letter denoting a residue inserted after that
  position. A leading lower-case letter denotes an insertion before the first
  reference position (the "prefix" insertion).

  The result has one entry for the prefix slot plus one entry per reference
  position, so an N-position mutation sequence yields N+1 tokens.

  Args:
    seq: (str) A mutation sequence to tokenize; e.g., "__A_" or "aTEST".
      An empty string is treated as a zero-length reference and yields only
      the (empty) prefix token.
    placeholder_token: (str) Sentinel value used to encode non-mutated positions
      in the mutation sequence.
  Returns:
    A length-N+1 list of ("<substitution_token>", "<insertion token>")
    2-tuples. Insertion tokens are upper-cased; substitution tokens are copied
    from `seq` unchanged. Slots without a mutation hold `placeholder_token`.
  """
  tokens = []
  i = 0
  # Consume the prefix insertion mutation if there is one.
  #
  # A prefix insertion is denoted by a leading lower case letter on the seq.
  # The `seq and` guard keeps an empty sequence from raising IndexError.
  if seq and seq[0].islower():
    tokens.append((placeholder_token, seq[0].upper()))
    i = 1
  else:
    tokens.append((placeholder_token, placeholder_token))

  seq_len = len(seq)
  while i < seq_len:
    # A lower-case letter directly after position i is an insertion attached to
    # that position, so the pair is consumed together.
    if i + 1 < seq_len and seq[i + 1].islower():
      tokens.append((seq[i], seq[i + 1].upper()))
      i += 2
    else:
      tokens.append((seq[i], placeholder_token))
      i += 1
  return tokens


class MutationSequenceEncoder(object):
  """Fixed-length array encoder for mutation sequences.

  Every reference position gets two slots:
    1. one holding the encoding of a substituted residue, and
    2. one holding the encoding of a single inserted residue.

  An extra leading position provides the same pair of slots for a prefix
  mutation, so the first axis has len(ref_seq) + 1 entries.

  Attributes:
    encoding_size: (int) The encoding length for a single residue.
  """

  def __init__(self, residue_encoder, ref_seq):
    """Constructor.

    Args:
      residue_encoder: (object) A single residue encoder; must expose an
        `encoding_size` attribute and an `encode(residue)` method.
      ref_seq: (str) The reference (non-mutated) sequence.
    """
    self._ref_seq = ref_seq
    self._residue_encoder = residue_encoder
    self.encoding_size = residue_encoder.encoding_size

  def encode(self, seq):
    """Encodes a mutation sequence as a fixed-length multi-dimensional array.

    Args:
      seq: (str) A mutation sequence to encode; e.g., "__A_".
    Returns:
      A numpy.ndarray(shape=(len(ref_seq)+1, 2, encoding_size), dtype=float).
      Slots without a mutation are left as all zeros.
    Raises:
      ValueError: if the mutation sequence references a different number of
        sequence positions than the specified ref_seq.
    """
    num_positions = len(self._ref_seq) + 1
    token_pairs = tokenize_mutation_seq(seq, _PLACEHOLDER_TOKEN)
    if len(token_pairs) != num_positions:
      raise ValueError('Mutation sequence dimension mismatch: '
                       '%d mutation positions vs %d in reference sequence'
                       % (len(token_pairs), num_positions))

    encoded = numpy.zeros(
        (num_positions, _NUM_MUTATION_SLOTS, self.encoding_size))
    slot_indices = (_SUB_INDEX, _INS_INDEX)
    for position, token_pair in enumerate(token_pairs):
      # Each pair is (substitution token, insertion token).
      for slot, token in zip(slot_indices, token_pair):
        if token == _PLACEHOLDER_TOKEN:
          continue
        encoded[position, slot, :] = self._residue_encoder.encode(token)
    return encoded


class ResidueIdentityEncoder(object):
  """Encodes one residue as a one-hot vector or as its alphabet index.

  Attributes:
    encoding_size: (int) The number of encoding dimensions for a single residue.
  """

  def __init__(self, alphabet, one_hot=True):
    """Constructor.

    Args:
      alphabet: (seq<char>) The alphabet of valid tokens for the sequence;
        e.g., the 20x 1-letter residue codes for standard peptides. Letters are
        upper-cased before being stored.
      one_hot: if true, performs one-hot encoding (dim = alphabet length);
        if false, encodes as the index of the residue in the alphabet (dim = 1).
    """
    self._alphabet = [letter.upper() for letter in alphabet]
    self._letter_to_id = {
        letter: index for index, letter in enumerate(self._alphabet)
    }
    self._one_hot = one_hot
    if one_hot:
      self.encoding_size = len(self._alphabet)
    else:
      self.encoding_size = 1

  def encode(self, residue):
    """Encodes a single residue as a one-hot vector or an index.

    Args:
      residue: (str) A single-character string representing one residue; e.g.,
        'A' for Alanine. The lookup is case-sensitive against the upper-cased
        alphabet.
    Returns:
      If one-hot=True, a numpy.ndarray(shape=(A,), dtype=float) with a single
      non-zero (1) value; the identity index for each residue in the alphabet is
      given by the residue's index in the alphabet ordered sequence; i.e., for
      the alphabet 'ACDE', a 'C' would be encoded as [0, 1, 0, 0].
      If one-hot=False, a numpy.ndarray(shape=(1,), dtype=float) with the
      residue's index in the alphabet ordered sequence.
    Raises:
      KeyError: if `residue` is not in the alphabet.
    """
    residue_index = self._letter_to_id[residue]
    if not self._one_hot:
      return numpy.array((residue_index,), dtype=float)
    encoding = numpy.zeros(self.encoding_size, dtype=float)
    encoding[residue_index] = 1
    return encoding


# Shared encoder used as the default by the plotting helpers: one-hot residue
# encoding over RESIDUES, with the round-1 tile-21 wild type as the reference.
ONEHOT_FIXEDLEN_MUTATION_ENCODER = MutationSequenceEncoder(
    ResidueIdentityEncoder(RESIDUES), R1_TILE21_WT_SEQ)

### Load data

In [None]:
# Unpack the dataset archive; the context manager closes the zip file handle
# once extraction is done.
with zipfile.ZipFile('allseqs_20191230.csv.zip') as my_zip:
  my_zip.extractall()  # extract csv file to the current working directory

df = pandas.read_csv('allseqs_20191230.csv', index_col=None)
del df['num_mutations']  # Prefer 'num_edits' column which is Levenshtein distance to WT

# Single-argument print(...) produces the same output on Python 2 and 3.
print(df.shape)
df.head()

(296970, 6)


Unnamed: 0,sequence,partition,mutation_sequence,num_edits,viral_selection,is_viable
0,ADEEIRATNPIATEMYGSVSTNLQLGNR,designed,AD____A___I___M_________L___,6,-2.027259,False
1,ADEEIRATNPVATEQYGSVSTNQQRQNR,designed,AD____A_______________Q__Q__,5,-0.429554,True
2,ADEEIRTTNPVATEQWGGVSTNLQIGNY,designed,AD_____________W_G______I__Y,6,-0.527843,True
3,ADEEIRTTNPVATEQYGEVSTNLQRGNR,designed,AD_______________E__________,3,2.887908,True
4,ADEEIRTTNPVATEQYGSVSTNLQRGNR,designed,AD__________________________,2,0.57573,True


## SI Tables: generated and viable capsid statistics

### Helper functions

In [None]:
def _get_percent_viable(num_viable, num_total):
  """Returns num_viable as a percentage of num_total (0 when the total is 0)."""
  # Guard against division by zero for empty groups.
  if num_total == 0:
    return 0
  fraction_viable = float(num_viable) / num_total
  return fraction_viable * 100


def _format_stats_table(stats_table):
  """Formats a stats table for display.

  Converts `percent_viable`, `num_total` and `num_viable` to strings (in place,
  on the frame that was passed in) and returns a renamed copy.

  Args:
    stats_table: (pandas.DataFrame) Table with numeric `num_total`,
      `num_viable` and `percent_viable` columns.
  Returns:
    A pandas.DataFrame whose three stats columns are renamed to
    '# generated', '# viable' and '% viable'.
  """
  # (column, format template): one decimal place plus '%' for percentages,
  # thousands separators and a minimum width of 7 for counts.
  column_templates = (
      ('percent_viable', '{:4.1f}%'),
      ('num_total', '{:7,}'),
      ('num_viable', '{:7,}'),
  )
  for column, template in column_templates:
    stats_table[column] = stats_table[column].apply(template.format)

  display_names = {
      'num_total': '# generated',
      'num_viable': '# viable',
      'percent_viable': '% viable',
  }
  return stats_table.rename(columns=display_names)


def performance_by_wt_distance(df, partitions, distances=range(2, 30)):
  """Tabulates viability of variants at or beyond each edit-distance threshold.

  Each row is cumulative: for a threshold t it counts the rows of `df` that
  belong to one of `partitions` and have `num_edits >= t`.

  Args:
    df: (pandas.DataFrame) Data with `partition`, `num_edits` and `is_viable`
      columns.
    partitions: (seq<str>) Partition names to pool together.
    distances: (iterable<int>) Minimum-edit thresholds, one output row each.
  Returns:
    A display-formatted pandas.DataFrame with columns 'min_mutations',
    '# generated', '# viable' and '% viable'.
  """
  # The partition filter does not depend on the threshold, so build it once
  # instead of twice per threshold.
  in_partitions = df.partition.isin(partitions)

  rows = []
  for t in distances:
    at_least_t = in_partitions & (df.num_edits >= t)
    num_total = len(df[at_least_t])
    num_viable = len(df[at_least_t & df.is_viable])

    rows.append({
        'min_mutations': t,
        'num_total': num_total,
        'num_viable': num_viable,
        'percent_viable': _get_percent_viable(num_viable, num_total),
    })
  col_order = [
      'min_mutations',
      'num_total',
      'num_viable',
      'percent_viable',
  ]
  return _format_stats_table(pandas.DataFrame(rows, columns=col_order))


def performance_by_model(df, partitions):
  """Tabulates generated/viable counts separately for each partition.

  Args:
    df: (pandas.DataFrame) Data with `partition` and `is_viable` columns.
    partitions: (seq<str>) Partition names; one output row per name, in order.
  Returns:
    A display-formatted pandas.DataFrame with columns 'partition',
    '# generated', '# viable' and '% viable'.
  """
  records = []
  for partition in partitions:
    in_partition = df.partition == partition
    total_count = len(df[in_partition])
    viable_count = len(df[in_partition & df.is_viable])
    records.append({
        'partition': partition,
        'num_total': total_count,
        'num_viable': viable_count,
        'percent_viable': _get_percent_viable(viable_count, total_count),
    })

  columns = ['partition', 'num_total', 'num_viable', 'percent_viable']
  return _format_stats_table(pandas.DataFrame(records, columns=columns))


# display(performance_by_wt_distance(df, ml_designed_partitions))
# display(performance_by_model(df, ml_designed_partitions))

### SI Table 1

In [None]:
# All model-derived partitions pooled (both '_seed' and '_walked'), tabulated
# cumulatively by minimum number of edits from wild type.
performance_by_wt_distance(df, ml_generated_partitions)

Unnamed: 0,min_mutations,# generated,# viable,% viable
0,2,201426,110689,55.0%
1,3,201426,110689,55.0%
2,4,201424,110687,55.0%
3,5,201368,110633,54.9%
4,6,193413,103403,53.5%
5,7,184424,95422,51.7%
6,8,175443,87571,49.9%
7,9,166361,79628,47.9%
8,10,157294,72180,45.9%
9,11,148167,64678,43.7%


### SI Table 2

In [None]:
# Only the '_walked' (model-designed) partitions of all three model families,
# pooled and tabulated by minimum number of edits from wild type.
performance_by_wt_distance(df, ml_designed_partitions)

Unnamed: 0,min_mutations,# generated,# viable,% viable
0,2,183466,106665,58.1%
1,3,183466,106665,58.1%
2,4,183464,106663,58.1%
3,5,183411,106612,58.1%
4,6,176351,100150,56.8%
5,7,168231,92923,55.2%
6,8,160096,85766,53.6%
7,9,151805,78411,51.7%
8,10,143464,71416,49.8%
9,11,135099,64267,47.6%


### SI Table 3

In [None]:
# Same as the table over ml_designed_partitions, restricted to the CNN and RNN
# '_walked' partitions (no LR).
performance_by_wt_distance(df, nn_designed_partitions)

Unnamed: 0,min_mutations,# generated,# viable,% viable
0,2,123331,79837,64.7%
1,3,123331,79837,64.7%
2,4,123329,79835,64.7%
3,5,123280,79788,64.7%
4,6,117855,74431,63.2%
5,7,112376,69020,61.4%
6,8,106907,63624,59.5%
7,9,101326,58145,57.4%
8,10,95698,52658,55.0%
9,11,90035,47192,52.4%


### SI Table 4

In [None]:
# One row per '_seed' (model-selected) partition, with no edit-distance
# threshold applied.
display(performance_by_model(df, ml_selected_partitions))

Unnamed: 0,partition,# generated,# viable,% viable
0,lr_rand_doubles_plus_single_seed,2071,114,5.5%
1,lr_standard_seed,1989,486,24.4%
2,lr_designed_plus_rand_train_seed,2030,340,16.7%
3,cnn_rand_doubles_plus_single_seed,2022,381,18.8%
4,cnn_standard_seed,1924,476,24.7%
5,cnn_designed_plus_rand_train_seed,1898,529,27.9%
6,rnn_rand_doubles_plus_singles_seed,2045,575,28.1%
7,rnn_standard_seed,1916,412,21.5%
8,rnn_designed_plus_rand_train_seed,2065,711,34.4%


### SI Table 5

In [None]:
# One row per '_walked' (model-designed) partition, with no edit-distance
# threshold applied.
display(performance_by_model(df, ml_designed_partitions))

Unnamed: 0,partition,# generated,# viable,% viable
0,lr_rand_doubles_plus_single_walked,19999,1483,7.4%
1,lr_standard_walked,20456,19211,93.9%
2,lr_designed_plus_rand_train_walked,19680,6134,31.2%
3,cnn_rand_doubles_plus_single_walked,20454,11229,54.9%
4,cnn_standard_walked,20395,13086,64.2%
5,cnn_designed_plus_rand_train_walked,20759,14968,72.1%
6,rnn_rand_doubles_plus_singles_walked,20154,13056,64.8%
7,rnn_standard_walked,20838,15525,74.5%
8,rnn_designed_plus_rand_train_walked,20731,11973,57.8%


### SI Table 6

In [None]:
# Additive baseline (the 'designed' partition). Uses thresholds 2..39 instead
# of the function's default range(2, 30).
performance_by_wt_distance(df, baseline_additive_partitions, distances=range(2, 40))

Unnamed: 0,min_mutations,# generated,# viable,% viable
0,2,56372,35217,62.5%
1,3,50572,30068,59.5%
2,4,41232,22129,53.7%
3,5,31561,14551,46.1%
4,6,22407,8159,36.4%
5,7,13892,2953,21.3%
6,8,12603,2181,17.3%
7,9,11387,1561,13.7%
8,10,10245,1101,10.7%
9,11,9171,757,8.3%


### SI Table 7

In [None]:
# Random baseline (the 'rand' partition), limited to thresholds 2..10.
performance_by_wt_distance(df, baseline_random_partitions, distances=range(2, 11))

Unnamed: 0,min_mutations,# generated,# viable,% viable
0,2,9885,964,9.8%
1,3,8129,461,5.7%
2,4,6378,213,3.3%
3,5,4631,93,2.0%
4,6,2883,32,1.1%
5,7,1154,3,0.3%
6,8,866,2,0.2%
7,9,576,1,0.2%
8,10,284,1,0.4%


## Additional: performance vs wt distance per model 

In [None]:
# Build one distance-vs-viability table per model-designed partition, tag each
# with its partition name, and stack them into a single table.
sub_tables = []
for partition in ml_designed_partitions:
  partition_perf = performance_by_wt_distance(df, [partition]).assign(
      partition=partition)
  sub_tables.append(partition_perf)

stats_table = pandas.concat(sub_tables)
stats_table

Unnamed: 0,min_mutations,# generated,# viable,% viable,partition
0,2,19999,1483,7.4%,lr_rand_doubles_plus_single_walked
1,3,19999,1483,7.4%,lr_rand_doubles_plus_single_walked
2,4,19999,1483,7.4%,lr_rand_doubles_plus_single_walked
3,5,19998,1482,7.4%,lr_rand_doubles_plus_single_walked
4,6,19375,1265,6.5%,lr_rand_doubles_plus_single_walked
5,7,18504,1005,5.4%,lr_rand_doubles_plus_single_walked
6,8,17613,881,5.0%,lr_rand_doubles_plus_single_walked
7,9,16712,592,3.5%,lr_rand_doubles_plus_single_walked
8,10,15829,570,3.6%,lr_rand_doubles_plus_single_walked
9,11,14895,349,2.3%,lr_rand_doubles_plus_single_walked


## Figure 3

### Perplexity of residues by position per model

In [None]:
def get_mutation_count_matrix(
    sequences, encoder=ONEHOT_FIXEDLEN_MUTATION_ENCODER):
  """Sums the encodings of `sequences` into per-position count matrices.

  Args:
    sequences: (iterable<str>) Sequences accepted by `encoder.encode`; must be
      non-empty.
    encoder: Encoder whose `encode` returns an array indexed as
      (position, slot, residue), with slot 0 = substitution, slot 1 = insertion
      and position 0 = the prefix slot.
  Returns:
    A (subs, inserts) 2-tuple of arrays indexed as (position, residue), with
    the prefix position dropped.
  """
  encodings = (encoder.encode(seq) for seq in sequences)
  # The first encoding doubles as the accumulator; the rest are added in place.
  totals = next(encodings, None)
  for encoding in encodings:
    totals += encoding

  without_prefix = totals[1:]
  return without_prefix[:, 0, :], without_prefix[:, 1, :]


def get_perplexity(mutation_count_matrix, replace_nan=True):
  """Computes the per-position perplexity of residue counts.

  Perplexity is 2**H, where H is the base-2 entropy of the counts at a position
  (the counts are normalized by scipy.stats.entropy).

  Args:
    mutation_count_matrix: (n_positions, 20) array containing #mutations of each
      residue type for the set of positions
    replace_nan: (bool) If true, NaN perplexities (e.g. from positions with no
      counts at all) are replaced by 0, which is convenient for plotting.
  Returns:
    perplexity per position (n_positions,) array with max value of 20
    (a uniform distribution for a given position would have perplexity of 20
    b/c complete confusion across 20 options).
  """
  # scipy.stats.entropy reduces over the first axis, so put residues first.
  residues_by_position = mutation_count_matrix.T
  entropy_bits = scipy.stats.entropy(residues_by_position, base=2)
  perplexity = numpy.power(2, entropy_bits)
  if not replace_nan:
    return perplexity
  undefined = numpy.isnan(perplexity)
  perplexity[undefined] = 0  # For plotting purposes
  return perplexity


def plot_mutation_perplexity(
    mutation_count_matrix, 
    start_resnum=R1_TILE21_WT_START_RESNUM,
    end_resnum=R1_TILE21_WT_END_RESNUM,  # inclusive      
    label=None, 
    linewidth=1,
    ):
  """Draws per-position perplexity as a step line on the current pyplot axes.

  Args:
    mutation_count_matrix: (n_positions, n_residues) count array, as accepted
      by get_perplexity; n_positions should equal
      end_resnum - start_resnum + 1.
    start_resnum: (int) Residue number of the first position.
    end_resnum: (int) Residue number of the last position (inclusive).
    label: (str) Legend label for the line.
    linewidth: (number) Line width passed to pyplot.step.
  """
  perplexity = get_perplexity(mutation_count_matrix)

  # Trick to make the final step in plot be full-width: add extra point.
  #
  # Built with list(...) so this also works on Python 3, where range() returns
  # an object without append().
  #
  # NOTE(review): the sentinel x is end_resnum + 2, so with where='post' the
  # last step spans two residue numbers while the others span one; confirm
  # whether end_resnum + 1 was intended before changing the figure.
  resnums = list(range(start_resnum, end_resnum + 1)) + [end_resnum + 2]
  perplexity = list(perplexity) + [0]

  pyplot.step(
      resnums, 
      perplexity,
      label=label,
      where='post', 
      linewidth=linewidth)

def plot_mutation_perplexity_multi(
    df,
    partitions, 
    start_resnum=R1_TILE21_WT_START_RESNUM,
    end_resnum=R1_TILE21_WT_END_RESNUM,  # inclusive      
    subs=True,
    figsize=(12, 3),
    tick_size=10,
    anno_fontsize=10,
    axis_label_size=10,
    ):
  """Plots one perplexity step line per partition on a shared figure.

  For each partition, only viable rows with at least 12 edits are used.

  Args:
    df: (pandas.DataFrame) Data with `partition`, `is_viable`, `num_edits` and
      `sequence` columns.
    partitions: (seq<str>) Partition names; each becomes one labelled line.
    start_resnum: (int) Residue number of the first plotted position.
    end_resnum: (int) Residue number of the last plotted position (inclusive).
    subs: (bool) If true plot substitution counts, otherwise insertion counts.
    figsize: (tuple) Figure size passed to pyplot.subplots.
    tick_size: (number) Tick label size for both axes.
    anno_fontsize: Accepted for interface compatibility; currently unused.
    axis_label_size: Accepted for interface compatibility; currently unused.
  """
  _, ax = pyplot.subplots(figsize=figsize)

  # get_mutation_count_matrix returns (substitutions, insertions).
  counts_index = 0 if subs else 1

  # NOTE(review): this reads the `sequence` column, whereas the heatmap code in
  # this notebook reads `mutation_sequence`; confirm that is intended.
  for partition in partitions:
    selected = df[
        (df.partition == partition)
        & (df.is_viable)
        & (df.num_edits >= 12)
    ]
    counts = get_mutation_count_matrix(selected.sequence)[counts_index]
    plot_mutation_perplexity(
        counts,
        start_resnum=start_resnum,
        end_resnum=end_resnum,
        label=partition,
        linewidth=1)

  pyplot.ylim(0, 15)
  pyplot.yticks([0, 5, 10, 15])
  # Faint dashed guide lines; the one at 20 lies above the y-limit set above.
  for guide_y in (5, 10, 15, 20):
    pyplot.axhline(
        y=guide_y, color='black', linestyle='--', alpha=.7, linewidth=.25)

  right_spine = ax.spines['right']
  right_spine.set_visible(True)
  right_spine.set_linewidth(0.5)
  ax.tick_params(axis='both', labelsize=tick_size)
  pyplot.legend(loc='upper left')


# One figure per training-set variant; each figure overlays the LR, CNN and RNN
# model-designed partitions and shows substitution perplexity (subs=True).
seaborn.set_style('white')
for p in [
          ml_designed_partitions_doubles, 
          ml_designed_partitions_standard,
          ml_designed_partitions_designed,
          ]:

  plot_mutation_perplexity_multi(df, p, subs=True)
  pyplot.show()

### Mutation distribution heatmaps by model

In [None]:
def show_mutation_heatmap_side_by_side_horizontal(
    df,
    encoder=ONEHOT_FIXEDLEN_MUTATION_ENCODER,
    log=True,
    normalize=False,
    colorbar_num_quantiles=None,
    wt_seq=R1_TILE21_WT_SEQ,
    figsize=(2.5, 1),
    cmap_name='viridis',
    color_rgb=None,
    scale_color_rgb=False, # should the rgb values be divided by 255
    wt_point_size=15,
    cbar=True,
    vmax=None,
    subs_only=False,
    dpi=300,
    linewidth=1,
    threshold_linewidth=.25,
    scale=4):
  """Draws a heatmap of mutation counts: residues (rows) by positions (columns).

  The top block of rows holds substitution counts and the bottom block holds
  insertion counts, both with residues in RESIDUES_PHYSCHEM_ORDER. White dots
  mark the wild type residue at each position. A new pyplot figure is created;
  the caller is responsible for showing or saving it.

  Args:
    df: (pandas.DataFrame) Data with a `mutation_sequence` column; must contain
      at least one row.
    encoder: Mutation sequence encoder; its residue alphabet must equal
      RESIDUES.
    log: (bool) If true, plot log10(1 + count) and label the colour bar with
      the matching count values.
    normalize: (bool) If true, divide counts by len(df) before any log scaling.
    colorbar_num_quantiles: (int or None) Number of discrete colours requested
      from the named colormap.
    wt_seq: (str) Wild type sequence; used for the white markers and for the
      number of x tick labels.
    figsize: (tuple) Base figure size, multiplied by `scale`.
    cmap_name: (str) Colormap name; ignored for colouring when `color_rgb` is
      given.
    color_rgb: (seq or None) If given, a light-to-colour palette built from this
      RGB triple replaces the named colormap.
    scale_color_rgb: (bool) If true, `color_rgb` is given in 0-255 and is
      divided by 255 first.
    wt_point_size: (number) Marker size of the wild type dots.
    cbar: (bool) Whether to draw a colour bar.
    vmax: (number or None) Upper colour limit in count units (log10 is applied
      when `log` is true). None lets seaborn pick the limit.
    subs_only: Accepted for interface compatibility; currently unused.
    dpi: (int) Figure resolution.
    linewidth: Accepted for interface compatibility; currently unused.
    threshold_linewidth: Accepted for interface compatibility; currently unused.
    scale: (number) Multiplier applied to `figsize`.
  """
  figsize = tuple(x * scale for x in figsize)
  pyplot.figure(figsize=figsize, dpi=dpi)

  alphabet = list(encoder._residue_encoder._alphabet)
  # Tuple comparison also fails cleanly when the alphabets differ in length.
  assert tuple(alphabet) == tuple(RESIDUES)

  # Per-position count matrices with the prefix position already dropped;
  # substitutions and insertions are laid side by side along the residue axis.
  subs, inserts = get_mutation_count_matrix(df['mutation_sequence'], encoder)
  mutations = numpy.concatenate([subs, inserts], axis=1)
  # Single-argument print(...) behaves identically on Python 2 and 3.
  print('mutation heatmap range <%d, %d>' % (mutations.min(), mutations.max()))

  residue_to_index = {aa: i for i, aa in enumerate(alphabet)}
  residue_order = [residue_to_index[aa] for aa in RESIDUES_PHYSCHEM_ORDER]
  num_residues = len(residue_order)
  # Same residue ordering for the insertion block, offset past the sub block.
  subs_and_ins_residue_order = (
      residue_order + [num_residues + i for i in residue_order])
  tile21_resnums = [
      R1_TILE21_WT_START_RESNUM + i for i in range(len(wt_seq))
  ]

  # NOTE(review): this labels the y axis, but residue numbers are the x tick
  # labels of the transposed heatmap, and seaborn.heatmap appears to reset the
  # axis labels for array input anyway -- confirm the intended label.
  pyplot.ylabel('AAV2 residue number')

  if normalize:
    mutations /= len(df)  # normalize by number of sequences
  if log:
    mutations = numpy.log10(1 + mutations)

  # Rotate the heatmap horizontally via transpose
  mutations = mutations.T

  # Locate the wild type residue of every position in heatmap coordinates.
  wt_mutations = encoder.encode(wt_seq)[1:, :, :]  # drop prefix slot
  wt_mutations = numpy.concatenate(
      [wt_mutations[:, 0, :], wt_mutations[:, 1, :]], axis=1).T
  wt_mutations = wt_mutations[subs_and_ins_residue_order, :]
  wt_residue_indices, wt_position_indices = numpy.where(wt_mutations > 0)
  # Cell centres are at +0.5; the residue axis is nudged by a small epsilon.
  marker_offset_epsilon = 0.1
  wt_position_indices = wt_position_indices + 0.5
  wt_residue_indices = wt_residue_indices + (0.5 - marker_offset_epsilon)

  # pyplot.get_cmap is used because matplotlib.cm.get_cmap has been removed
  # from newer matplotlib releases.
  cmap = pyplot.get_cmap(cmap_name, colorbar_num_quantiles)
  if color_rgb is not None:
    if scale_color_rgb:
      color_rgb = [c / 255. for c in color_rgb]
    cmap = seaborn.light_palette(color_rgb, n_colors=100, input="rgb")

  # Colour limit in the units of the plotted data. With vmax=None seaborn
  # chooses the limit itself; previously this path crashed in log10(None).
  if vmax is None:
    color_vmax = None
  elif log:
    color_vmax = numpy.log10(vmax)
  else:
    color_vmax = vmax

  ax = seaborn.heatmap(
      mutations[subs_and_ins_residue_order, :],
      cmap=cmap,
      xticklabels=tile21_resnums,
      yticklabels=RESIDUES_PHYSCHEM_ORDER + RESIDUES_PHYSCHEM_ORDER,
      robust=True,
      cbar=cbar,
      vmax=color_vmax,
  )

  if cbar and log:
    mappable = ax.collections[0]
    # Without an explicit vmax, fall back to the limit seaborn resolved.
    if color_vmax is not None:
      tick_limit = color_vmax
    else:
      tick_limit = mappable.get_clim()[1]
    tick_values = [
        t for t in (1, 10, 100, 1000) if numpy.log10(t) <= tick_limit
    ]
    # Skip relabelling when no candidate tick fits under the limit (this used
    # to raise IndexError).
    if tick_values:
      log_tick_labels = [str(t) for t in tick_values]
      if vmax is not None:
        # Values above the explicit cap share the top colour.
        log_tick_labels[-1] = '>' + log_tick_labels[-1]
      colorbar = mappable.colorbar
      colorbar.set_ticks([numpy.log10(t) for t in tick_values])
      colorbar.set_ticklabels(log_tick_labels)

  ax.scatter(
      wt_position_indices - .1,  # shift the point more to the center of the square
      wt_residue_indices,
      color='white',
      s=wt_point_size,
      )
  # horizontal separator between subs and inserts
  pyplot.axhline(y=num_residues, color='white', linewidth=1)



################################################################################
# One heatmap per model-designed partition, restricted to viable variants with
# at least 12 edits. Counts are shown on a log scale with the colour range
# capped at 1000.
for name in ml_designed_partitions:
    data = df[
              (df.partition == name) 
              & (df.is_viable) 
              & (df.num_edits >= 12)]
    show_mutation_heatmap_side_by_side_horizontal(
        data, 
        normalize=False, 
        log=True,
        colorbar_num_quantiles=30,
        scale=8,
        cmap_name='viridis',
        cbar=True,
        vmax=1000,
    )
    pyplot.show()
