## "EvoMPNN" preparation notebook

### Goal of this notebook
* This notebook takes a multiple sequence alignment (MSA) for a protein of interest and outputs the residues to lock (i.e. keep fixed, not redesign) when running MPNN
* Residues are selected following the ["EvoMPNN" paper](https://www.biorxiv.org/content/10.1101/2023.10.03.560713v1.abstract)
    * The paper ranked every position in the MSA by the frequency of its most common amino acid
    * Then the top 30%, 50%, or 70% most conserved positions were locked, i.e. not designed, during MPNN
    * Empirically, enzymes with 50% or 70% locked residues performed well in that paper, so this notebook focuses on those
    * In addition, a position is only locked if the residue in your protein is the same as the most frequent residue at that position
    * To use other cutoff values, change the `cutoff_perc` argument (a fraction between 0 and 1) when calling `cutoff_ranked_resis()`
* The final cell (bottom of the notebook) prints the residues to lock in the format MPNN expects: a space-delimited list of residue numbers.

### How to use:
1. Make an MSA in .a3m format using a tool such as hhblits [(link here)](https://toolkit.tuebingen.mpg.de/tools/hhblits)
    - Input your sequence in fasta format
    - When the run is finished, go to the Query MSA tab and click Download Full A3M
2. In the cell directly below, set the variables to the path of your .a3m file and the name of your protein
    - The protein name must exactly match the header (the text after `>`) of the first sequence in the .a3m file
    - You can lock additional residues (e.g. active site residues) by adding their residue numbers to the `resis_to_lock` list
3. Run all cells (spam shift+enter)

In [None]:
# Path to the .a3m alignment (e.g. downloaded from the hhblits toolkit).
input_a3m: str = ''
# Name of the query protein; must match its header line (text after '>') in the .a3m file.
input_protein: str = ''
# Extra 1-indexed residue numbers to always keep fixed (e.g. active-site residues).
resis_to_lock: list = []


In [None]:

# these all assume that your input protein at the top of your a3m file has no '-' characters!

def process_fasta(fasta_file):
    """Parse a FASTA/A3M file into a ``{header: aligned sequence}`` dict.

    Only uppercase letters and '-' (gap) characters are kept from sequence
    lines. In a3m output, lowercase letters mark insertions relative to the
    query, so dropping them leaves every sequence aligned column-for-column
    with the query. Blank lines and any lines before the first header are
    ignored. A header with no sequence lines maps to ''.

    Args:
        fasta_file: path to the FASTA/A3M file.

    Returns:
        dict mapping each header (text after '>') to its filtered sequence. If
        a header appears more than once, the last occurrence's sequence wins.
    """
    fasta_dict = {}
    name = None
    chunks = []

    with open(fasta_file) as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # Finish the previous record before starting a new one.
                if name is not None:
                    fasta_dict[name] = ''.join(chunks)
                name = line[1:]
                chunks = []
            elif name is not None:
                chunks.append(''.join(c for c in line if c.isupper() or c == '-'))

    # Flush the final record; this also covers a trailing header with no sequence.
    if name is not None:
        fasta_dict[name] = ''.join(chunks)

    return fasta_dict

def normalize_freqs(d):
    """Return the most frequent non-gap residue(s) in one column and their frequency.

    Args:
        d: mapping of one-letter residue code -> count for a single alignment
            column. Gap characters ('-') are excluded both from the candidates
            and from the total used for normalization.

    Returns:
        (max_resi, max_freq). ``max_resi`` is the residue with the highest count.
        If several residues tie for the highest count, it is the concatenation
        of all of them, in ``d``'s iteration order, and ``max_freq`` is the sum
        of their individual frequencies. A column containing only gaps yields
        ('', 0.0).
    """
    counts = {res: num for res, num in d.items() if res != '-'}
    if not counts:
        # Only gaps in this column: there is no residue to report.
        return '', 0.0

    total_resis = sum(counts.values())
    top_count = max(counts.values())
    # Compare integer counts against a fixed maximum so every tied residue is
    # collected. A growing running-sum threshold would stop after the second tie.
    tied = [res for res, num in counts.items() if num == top_count]

    max_resi = ''.join(tied)
    max_freq = sum(counts[res] / total_resis for res in tied)
    return max_resi, max_freq

def get_consensus_dict(fasta_dict):
    """Compute the per-position consensus residue for a set of aligned sequences.

    Args:
        fasta_dict: mapping of sequence name -> aligned sequence. All sequences
            must have the same length, as produced by ``process_fasta`` on an a3m.

    Returns:
        (perc_aas_by_pos, consensus):
        - perc_aas_by_pos is ``{resnum: (one-letter res, frequency)}``. It is
          1-indexed to be consistent with protein numbering. Gap ('-')
          characters are ignored. Ties are reported as a string of all tied
          one-letter codes with their frequencies summed (see normalize_freqs).
        - consensus is a list whose element ``resnum - 1`` is that position's
          consensus residue string.

    Raises:
        ValueError: if fasta_dict is empty or its sequences differ in length.
    """
    from collections import Counter

    seqs = list(fasta_dict.values())
    if not seqs:
        raise ValueError('no sequences given; check the alignment file')

    # Sequences must share one length to be read as columns of one alignment.
    # Raising (instead of returning a message) avoids a confusing
    # "too many values to unpack" error in callers that unpack the result.
    if len({len(seq) for seq in seqs}) != 1:
        raise ValueError('not all seqs are the same len! check alignment')

    perc_aas_by_pos = {}
    consensus = []
    # zip(*seqs) yields one tuple per alignment column. Counter keeps first-seen
    # order, which normalize_freqs uses to order tied residues.
    for resnum, column in enumerate(zip(*seqs), start=1):
        max_resi, max_freq = normalize_freqs(Counter(column))
        perc_aas_by_pos[resnum] = (max_resi, max_freq)
        consensus.append(max_resi)

    return perc_aas_by_pos, consensus

# goal: rank the conservation of residues and take only the top x% that are conserved.

def cutoff_ranked_resis(perc_aas_by_pos, cutoff_perc):
    """Return the most conserved positions, keeping the top ``cutoff_perc`` fraction.

    Args:
        perc_aas_by_pos: ``{resnum: (one-letter res, frequency)}`` as returned
            by get_consensus_dict.
        cutoff_perc: fraction of positions to keep, between 0 and 1 (e.g. 0.5
            keeps the top 50%). The number kept is rounded down.

    Returns:
        A list of (resnum, one-letter res, frequency) tuples sorted by frequency,
        highest first. Positions with equal frequency stay in ascending resnum
        order.

    Raises:
        ValueError: if cutoff_perc is outside [0, 1].
    """
    if not 0 <= cutoff_perc <= 1:
        # A negative value would otherwise slice from the end of the list.
        raise ValueError(f'cutoff_perc must be between 0 and 1, got {cutoff_perc!r}')

    ordered_resis = [(resnum, *perc_aas_by_pos[resnum]) for resnum in sorted(perc_aas_by_pos)]
    # list.sort is stable (also with reverse=True), so ties keep resnum order.
    ordered_resis.sort(key=lambda entry: entry[2], reverse=True)

    # The tolerance stops float round-off (e.g. 0.58 * 100 == 57.99999999999999)
    # from dropping a position when the product should be a whole number.
    cutoff_num = int(cutoff_perc * len(ordered_resis) + 1e-9)
    return ordered_resis[:cutoff_num]

def renum_ranked_to_seq(seqname, consensus, seq_dict, ranked_resis, keep_same_as_consensus=False):
    """Select which residue numbers of one sequence to lock.

    Args:
        seqname: key of the sequence in seq_dict (the query protein).
        consensus: per-position consensus strings, index ``resnum - 1``, as
            returned by get_consensus_dict.
        seq_dict: mapping of sequence name -> aligned sequence.
        ranked_resis: tuples whose first element is a 1-indexed resnum, as
            returned by cutoff_ranked_resis.
        keep_same_as_consensus: if True, a ranked position is kept only when the
            sequence's residue there appears in the consensus string for that
            position.

    Returns:
        Ascending list of 1-indexed positions of ``seqname`` to lock.

    Raises:
        KeyError: if seqname is not in seq_dict.
        ValueError: if the sequence contains gap ('-') characters. Alignment
            columns map 1:1 to residue numbers only for a gap-free query.
    """
    seq = seq_dict[seqname]
    # An explicit check, rather than assert, keeps the validation under `python -O`.
    if '-' in seq:
        raise ValueError(f"sequence {seqname!r} contains gap ('-') characters; the query must be gap-free")

    keep_resis = {entry[0] for entry in ranked_resis}

    return [
        resnum
        for resnum, seq_aa in enumerate(seq, start=1)
        if resnum in keep_resis
        and (not keep_same_as_consensus or seq_aa in consensus[resnum - 1])
    ]

In [None]:
# a3m output made using https://toolkit.tuebingen.mpg.de/tools/hhblits
# output contains insertions (lower case chars), which are cut out during the process_fasta() step

name_to_seq = process_fasta(input_a3m)
if input_protein not in name_to_seq:
    raise ValueError(
        f'{input_protein!r} is not a sequence name in {input_a3m!r}; '
        "set input_protein to the header (text after '>') of your query sequence"
    )

perc_aas_by_pos, consensus = get_consensus_dict(name_to_seq)

ranked_resis_50 = cutoff_ranked_resis(perc_aas_by_pos, 0.5)
ranked_resis_70 = cutoff_ranked_resis(perc_aas_by_pos, 0.7)

# Lock a conserved position only if the query residue matches the consensus there,
# as the printed labels below state.
resis_to_keep_50 = renum_ranked_to_seq(
    input_protein, consensus, name_to_seq, ranked_resis_50, keep_same_as_consensus=True
)
resis_to_keep_70 = renum_ranked_to_seq(
    input_protein, consensus, name_to_seq, ranked_resis_70, keep_same_as_consensus=True
)

# Merge in user-specified residues; the set union avoids listing a position twice.
extra_locked = {int(resi) for resi in resis_to_lock}
resis_to_keep_50 = sorted(set(resis_to_keep_50) | extra_locked)
resis_to_keep_70 = sorted(set(resis_to_keep_70) | extra_locked)

# print outputs for pymol coloring (uncomment to use)
# print('color palegreen, sele and resi ' + '+'.join([str(i) for i in resis_to_keep_70]))
# print('color paleyellow, sele and resi ' + '+'.join([str(i) for i in resis_to_keep_50]))

# print outputs for MPNN
print('70% most frequent residues, unless your aa is different than consensus:')
print(f'fixed_positions="{" ".join(str(i) for i in resis_to_keep_70)}"')
print('\n50% most frequent residues, unless your aa is different than consensus:')
print(f'fixed_positions="{" ".join(str(i) for i in resis_to_keep_50)}"')

