<a href="https://colab.research.google.com/github/spetti/SMURF/blob/main/examples/af_msa_backprop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%bash
# Fetch the code and weights used by this notebook. Each step is skipped when
# its target directory already exists, so re-running the cell is cheap.

# af_backprop checkout; the Python packages are installed only on a fresh clone.
# NOTE(review): the packages are unpinned, and the install is skipped whenever
# af_backprop/ already exists (e.g. a new environment with an old checkout).
if [ ! -d af_backprop ]; then
  git clone https://github.com/sokrypton/af_backprop.git
  pip -q install py3Dmol biopython ml_collections dm-haiku
fi
# SMURF checkout (put on sys.path by the import cell).
if [ ! -d SMURF ]; then
  git clone https://github.com/spetti/SMURF.git
fi
# AlphaFold parameter archive, streamed straight into params/.
# NOTE(review): params/ is created before the download starts, so a failed
# download leaves the directory behind and later runs skip the download.
if [ ! -d params ]; then
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
fi

In [2]:
# Environment check: print the CUDA compiler version available to the kernel.
! nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Tue_Sep_15_19:10:02_PDT_2020
Cuda compilation tools, release 11.1, V11.1.74
Build cuda_11.1.TC455_06.29069683_0


In [22]:
import os
import jax
import jax.numpy as jnp
from jax.example_libraries.optimizers import adam
import pickle

import numpy as np
import matplotlib.pyplot as plt

import sys
sys.path.append('af_backprop')
sys.path.append('SMURF')
from utils import *

import laxy
import sw_functions as sw
import network_functions as nf

# import libraries
from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.model import data, config, model
from alphafold.data import parsers
from alphafold.model import all_atom

In [23]:
#os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/n/helmod/apps/centos7/Core/cuda/11.1.0-fasrc01"
#os.environ["TF_CPP_MIN_LOG_LEVEL"]= "0"

In [24]:
# Environment check: show which device JAX places a freshly created array on.
# NOTE(review): `device_buffer` appears to have been deprecated/removed in newer
# JAX releases - confirm against the JAX version this notebook is pinned to.
print(jnp.ones(3).device_buffer.device()) 

gpu:0


In [25]:
# Environment check: list the visible GPU(s) and their current memory usage.
! nvidia-smi

Thu Mar 17 16:50:09 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:1C:00.0 Off |                    0 |
| N/A   43C    P0    72W / 300W |   1675MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [26]:
def get_feat(filename, alphabet="ARNDCQEGHILKMFPSTWYV"):
  '''
  Parse an A3M file (e.g. from hhblits) into integer-encoded sequences.

  Characters are interpreted as follows:
    - "-"                 : gap in an alignment column
    - lowercase letter    : insertion (belongs to the sequence, not to a column)
    - any other character : residue occupying an alignment column

  Args:
    filename: path to the A3M (FASTA-formatted) file.
    alphabet: residue ordering; each letter is encoded as its index in this
      string and any character not in it is encoded as -1.

  Returns:
    (msa, ms, aln), three lists with one entry per sequence, in file order:
      msa: aligned sequence - one code per alignment column, -1 for gaps
           (insertions are dropped).
      ms:  unaligned sequence - one code per residue (gaps are dropped,
           insertions are kept).
      aln: for every residue in ms, the index of the alignment column it
           occupies, or -1 if the residue is an insertion.
  '''
  def parse_fasta(filename):
    '''Read a FASTA-style file and return (headers, sequences).'''
    headers, chunks = [], []
    # `with` guarantees the handle is closed even if a malformed line raises.
    with open(filename, "r") as handle:
      for line in handle:
        line = line.rstrip()
        if not line:
          continue
        if line[0] == ">":
          headers.append(line[1:])
          chunks.append([])
        else:
          # sequences may be wrapped over several lines
          chunks[-1].append(line)
    return headers, [''.join(chunk) for chunk in chunks]

  _, seqs = parse_fasta(filename)
  a2n = {a:n for n,a in enumerate(alphabet)}

  def get_seqref(x):
    '''Encode one A3M row; returns (seq, ref, aligned_seq) as described above.'''
    col, seq, ref, aligned_seq = 0, [], [], []
    for aa in x:
      if aa == "-":
        # gap: consumes an alignment column but no residue
        aligned_seq.append(-1)
        col += 1
        continue
      code = a2n.get(aa.upper(), -1)
      seq.append(code)
      if aa.islower():
        # insertion: consumes a residue but no alignment column
        ref.append(-1)
      else:
        ref.append(col)
        aligned_seq.append(code)
        col += 1
    return seq, ref, aligned_seq

  # encode every sequence of the alignment
  ms, aln, msa = [], [], []
  for seq in seqs:
    seq_, ref_, aligned_seq_ = get_seqref(seq)
    ms.append(seq_)
    msa.append(aligned_seq_)
    aln.append(ref_)

  return msa, ms, aln

In [27]:
def prep_inputs(DOM):
  '''
  Load the A3M alignment and reference PDB structure of one CASP domain.

  Both files are read from SMURF/examples/CASP_examples/ using DOM as the
  file-name stem.

  Args:
    DOM: domain identifier, e.g. "T1039-D1".

  Returns:
    dict with keys
      "N":            number of sequences in the alignment
      "lens":         unpadded length of every (unaligned) sequence
      "ms":           unaligned sequences after nf.pad_max + nf.one_hot
      "aln":          per-residue alignment columns after nf.pad_max + nf.one_hot
      "feature_dict": AlphaFold sequence + MSA features built from the A3M,
                      with "residue_index" taken from the PDB structure
      "protein_obj":  protein object parsed from the PDB file
      "batch":        atom types/positions/mask plus the frames computed by
                      all_atom.atom37_to_frames
  '''
  a3m_file = f"SMURF/examples/CASP_examples/{DOM}.mmseqs.id90cov75.a3m"
  _, ms, aln = get_feat(a3m_file)
  lens = np.asarray([len(m) for m in ms])
  ms = nf.one_hot(nf.pad_max(ms))
  aln = nf.one_hot(nf.pad_max(aln))
  N = len(ms)
  protein_obj = protein.from_pdb_string(pdb_to_string(f"SMURF/examples/CASP_examples/{DOM}.pdb"))
  batch = {'aatype': protein_obj.aatype,
          'all_atom_positions': protein_obj.atom_positions,
          'all_atom_mask': protein_obj.atom_mask}
  batch.update(all_atom.atom37_to_frames(**batch)) # for fape calculation
  # read the alignment text through a context manager so the handle is closed
  with open(a3m_file, "r") as handle:
    a3m_text = handle.read()
  msa, mtx = parsers.parse_a3m(a3m_text)
  feature_dict = {
      **pipeline.make_sequence_features(sequence=msa[0],description="none",num_res=len(msa[0])),
      **pipeline.make_msa_features(msas=[msa], deletion_matrices=[mtx])
  }
  feature_dict["residue_index"] = protein_obj.residue_index
  return {"N":N,"lens":lens,
          "ms":ms,"aln":aln,
          "feature_dict":feature_dict,
          "protein_obj":protein_obj, "batch":batch}

In [28]:
def get_model_runner(num_seq, model_name="model_3_ptm", dropout=False, backprop_recycles=False):
  '''
  Build an AlphaFold model runner configured for this notebook.

  Args:
    num_seq: value written to data.eval.max_msa_clusters.
    model_name: AlphaFold parameter set to configure and load.
    dropout: when False, the config is passed through set_dropout(..., 0).
    backprop_recycles: value written to the backprop_recycle / backprop_dgram
      switches of the model config.

  Returns:
    (model_runner, model_params); parameters are loaded with data_dir=".".
  '''
  cfg = config.model_config(model_name)

  # data pipeline: MSA size, no extra MSA beyond one row, no masked replacement
  cfg.data.eval.max_msa_clusters = num_seq
  cfg.data.common.max_extra_msa = 1
  cfg.data.eval.masked_msa_replace_fraction = 0
  cfg.data.common.num_recycle = 3

  # model: rematerialisation on, three recycles
  cfg.model.global_config.use_remat = True
  cfg.model.num_recycle = 3

  # optionally backprop through recycles
  cfg.model.backprop_recycle = backprop_recycles
  cfg.model.embeddings_and_evoformer.backprop_dgram = backprop_recycles

  if not dropout:
    cfg = set_dropout(cfg,0)

  # load weights and wrap everything in a runner
  weights = data.get_model_haiku_params(model_name=model_name, data_dir=".")
  runner = model.RunModel(cfg, weights, is_training=True)
  return runner, weights

In [29]:
# Register "-" as symbol 20 so that get_seqs can decode the gap channel
# (index 20) that add_gap_to_one_hot writes for padded positions.
from utils import order_restype
order_restype[20]="-"

In [30]:
def get_seqs(seq):
  '''
  Decode per-position score/one-hot vectors into residue strings.

  The last axis is reduced with argmax and each index is looked up in
  order_restype. A 1-D index array yields a single string; otherwise one
  string is returned per row.
  '''
  indices = seq.argmax(-1)

  def decode(row):
    return "".join(order_restype[a] for a in row)

  if indices.ndim != 1:
    return [decode(row) for row in indices]
  return decode(indices)

In [31]:
def add_gap_to_one_hot(x, lens):
    """
    Append a gap channel to a batch of one-hot sequences.

    Args:
        x: array of shape (num_seqs, max_len, num_symbols).
        lens: iterable with the true length of each sequence, or None.
            Positions at index >= lens[i] in sequence i are marked as gap.
            With None the gap channel is left all-zero.

    Returns:
        jnp array of shape (num_seqs, max_len, num_symbols + 1). The first
        num_symbols channels are a copy of x; the last one is the gap channel.
    """
    num_seqs, max_len, num_symbols = x.shape
    # The gap channel is the one appended here, i.e. index num_symbols
    # (20 for the usual 20-letter alphabet) rather than a hard-coded 20.
    gap_channel = num_symbols
    xnew = jnp.zeros((num_seqs, max_len, num_symbols + 1)).at[..., :num_symbols].set(x)
    if lens is not None:
        # Build the padding mask on the host and write it with a single update
        # instead of copying the whole array once per sequence.
        padding = np.zeros((num_seqs, max_len), dtype=np.float32)
        for i, l in enumerate(lens):
            padding[i, l:] = 1
        xnew = xnew.at[..., gap_channel].set(padding)
    return xnew

In [32]:
# only works with an alignment-- not a distribution over alignments
def naive_insert_counter(A, true_len):
  '''
  Count, for one sequence, the residues inserted around each reference column.

  Only meaningful for a hard alignment (0/1 or boolean entries), not for a
  distribution over alignments.

  Args:
    A: 2-D alignment matrix of shape (query_len, ref_len); A[j, i] is nonzero
      when query residue j is aligned to reference column i.
    true_len: unpadded length of the query sequence.

  Returns:
    list of length ref_len + 1:
      out[0]     number of query residues before the first aligned one
      out[i + 1] number of query residues between the residue aligned to
                 reference column i and the residue aligned to the next
                 aligned reference column (0 if column i is unaligned)
      out[-1]    is overwritten with the number of query residues after the
                 last aligned one whenever that number is positive
    If nothing is aligned, the whole sequence is reported as a leading insert
    (out[0] == true_len), so the result always has ref_len + 1 entries.
  '''
  A = np.asarray(A)  # work on the host: avoids one device read per element
  n_query, n_ref = A.shape[-2], A.shape[-1]

  # 1-based query position aligned to each reference column (0 = unaligned)
  align_to = A.T @ np.arange(1, n_query + 1)
  aligned_cols = np.flatnonzero(align_to != 0)
  aligned_rows = np.flatnonzero(A.sum(axis=-1) != 0)

  out = [0] * (n_ref + 1)
  if aligned_rows.size == 0 or aligned_cols.size == 0:
    out[0] = int(true_len)
    return out

  out[0] = int(aligned_rows[0])

  # residues skipped between consecutive aligned reference columns
  for cur, nxt in zip(aligned_cols[:-1], aligned_cols[1:]):
    out[cur + 1] = (align_to[nxt] - align_to[cur] - 1).item()

  # residues left over after the last aligned one go after the final column
  last = aligned_cols[-1]
  if align_to[last] < true_len:
    out[-1] = int(true_len - align_to[last])
  return out

In [33]:
def get_readable_msa(x, aln, lens):
    """
    Rebuild a printable MSA (with insert columns) from sequences + alignments.

    Args:
        x: one-hot sequences; a gap channel is appended via add_gap_to_one_hot.
        aln: per-sequence alignment matrices, indexed as aln[s, query_pos, ref_pos];
            values are thresholded at 0.5 to obtain a hard alignment.
        lens: true (unpadded) length of every sequence.

    Returns:
        (msa_strings, red_pos):
            msa_strings: output of get_seqs on the rebuilt one-hot MSA.
            red_pos: int array over the rebuilt columns, 1 for insert columns
                and 0 for columns that correspond to a reference position.
    """
    aln = aln>.5
    # aln is MSA where inserts wrt to first sequence are removed
    # this function adds them back
    
    seqs = add_gap_to_one_hot(x, lens)
    
    # compute where inserts are:
    # inserts[s, 0] = residues of sequence s before its first aligned residue,
    # inserts[s, i+1] = residues of sequence s inserted after reference column i
    inserts = jnp.zeros((aln.shape[0],aln.shape[-1]+1))
    for _ in range(aln.shape[0]):
        inserts = inserts.at[_,...].set(naive_insert_counter(aln[_,...], lens[_]))
    inserts = np.array(inserts, dtype = 'int')
    # every insert slot must be wide enough for the sequence with the longest insert
    max_inserts = np.max(inserts, axis = 0)
    #print(max_inserts)
    total_length = int((jnp.sum(max_inserts) + aln.shape[-1]).item())
    
    # compute MSA with inserts
    gap_pos = seqs.shape[-1]-1 # last channel = gap channel added by add_gap_to_one_hot
    msa = jnp.zeros((aln.shape[0],total_length, gap_pos+1))
    for s in range(seqs.shape[0]): # number of seqs
        #print(f"SEQUENCE {s}")
        # leading insert slot: this sequence's leading residues, then gaps up to the slot width
        p = max_inserts[0] # reference with respect to the full alignment that we are building (total_length)
        #print(f"filling 0 to {inserts[s,0]} with values from seq")
        msa = msa.at[s,0:inserts[s,0],:].set(seqs[s,:inserts[s,0],:])
        #print(f"filling {inserts[s,0]} to {p} with gaps")
        msa = msa.at[s,inserts[s,0]:p,gap_pos].set(1)
        #print(get_seqs(np.array(msa[s,...])))
        for i in range(aln.shape[-1]): #length of consensus
            #print(p,i, max_inserts[i])
            if jnp.sum(aln[s,:,i]) == 0: # no position in seq is aligned to reference pos i
                msa = msa.at[s,p,gap_pos].set(1)
                #print(f"gap at reference pos {i}")
            else: 
                s_pos = jnp.argmax(aln[s,:,i]) # position in seq that is aligned to reference pos i
                #print(f"position {s_pos} is aligned to reference pos {i}")
                msa = msa.at[s,p,:].set(seqs[s,s_pos,:])
            # fill the insert slot that follows reference column i
            # NOTE(review): s_pos keeps its value from the most recent aligned column;
            # this relies on inserts[s,i+1] being 0 until some column has been aligned.
            for k in range(p+1, p + max_inserts[i+1]+1):
                # this sequence has an insert
                if k-p<= inserts[s,i+1]:
                    #print(f"inserting pos {s_pos+k-p} at {k}")
                    msa = msa.at[s,k,:].set(seqs[s,s_pos+k-p,:])
                # this sequence doesn't have an insert
                else: 
                    msa = msa.at[s,k,gap_pos].set(1)
            # advance past reference column i and its insert slot
            p += (max_inserts[i+1] + 1)
    
    #return positions to highlight as inserts wrt first seq
    red_pos = np.zeros(total_length, dtype = 'int')
    p = max_inserts[0]
    red_pos[0:p] = 1
    for i in range(aln.shape[-1]): #length of consensus
        red_pos[p]=0
        red_pos[p+1: p + max_inserts[i+1]+1]= 1
        p = p + max_inserts[i+1]+1
    
    return get_seqs(np.array(msa)), red_pos

In [34]:
from termcolor import colored

In [35]:
def print_seqs_color(seqs, min_i, max_i, red_pos, color = "blue"):
    """
    Print columns [min_i, max_i) of every sequence, one numbered row each.

    Passing False for min_i / max_i selects the start / end of the first
    sequence; max_i is always clamped to the length of the first sequence.
    Characters in columns where red_pos is 1 are wrapped with
    termcolor.colored(..., color).
    """
    width = len(seqs[0])
    start = 0 if min_i is False else min_i
    stop = min(width if max_i is False else max_i, width)
    columns = range(start, stop)
    for row, seq in enumerate(seqs):
        rendered = [
            colored(seq[col], color) if red_pos[col] == 1 else seq[col]
            for col in columns
        ]
        print(f"{row:03} " + "".join(rendered))

In [36]:
def print_chunks(input_msa, red_pos, cs):
    """
    Print an MSA in consecutive blocks of cs columns.

    Args:
        input_msa: list of equal-length aligned sequence strings.
        red_pos: per-column flags forwarded to print_seqs_color.
        cs: chunk size (number of columns per block), must be positive.

    Each block is preceded by a "positions A to B" header. Nothing is printed
    for an empty MSA, and no trailing empty block is printed when the
    alignment length is an exact multiple of cs.
    """
    if len(input_msa) == 0:
        return
    total = len(input_msa[0])
    # stepping by cs yields ceil(total / cs) blocks
    for start in range(0, total, cs):
        print(f"\npositions {start} to {start + cs}")
        print_seqs_color(input_msa, start, start + cs, red_pos)

In [37]:
def save_sto(msa, msa_name, sto_file_name):
    """
    Write an alignment to sto_file_name in a Stockholm-style layout.

    Args:
        msa: iterable of aligned sequence strings; row n is named "seq{n}".
        msa_name: text written on the "#=GF" line.
        sto_file_name: output path (overwritten if it exists).
    """
    # The context manager flushes and closes the file; previously the handle
    # was never closed explicitly.
    with open(sto_file_name, "w") as fh:
        fh.write("# STOCKHOLM 1.0\n")
        # NOTE(review): Stockholm "#=GF" lines are normally "#=GF <feature> <text>"
        # (e.g. "#=GF ID name"); kept as-is to leave the output unchanged.
        fh.write(f'#=GF {msa_name}\n')
        for n,seq in enumerate(msa):
            fh.write(f"seq{n}\t{seq}\n")
        fh.write("//")
        

In [38]:
def print_aln_matrix(input_aln, output_aln, hard_aln = True):
    """
    Plot, for every sequence, the input alignment, the output alignment and
    their difference (output - input) side by side.

    With hard_aln=True the output alignment is thresholded at 0.5 first.
    """
    shown_output = (output_aln>.5) if hard_aln else output_aln
    for idx in range(input_aln.shape[0]):
        fig, axes = plt.subplots(1, 3)
        ax_in, ax_out, ax_diff = axes

        ax_in.imshow(input_aln[idx,...])
        ax_in.set_title(f"input seq {idx}")

        ax_out.imshow(shown_output[idx,...])
        ax_out.set_title(f"output seq {idx}")

        delta = shown_output[idx,...]-input_aln[idx,...]
        ax_diff.imshow(delta, vmin = -1, vmax = 1, cmap="RdBu")
        ax_diff.set_title("difference")

        plt.show()

In [39]:
def check_seqs_same(input_msa, output_msa):
    """
    Verify that two MSAs contain the same sequences once gaps are removed.

    Rows are compared pairwise by index, for every row of input_msa.

    Returns:
        True when every compared pair matches.

    Raises:
        ValueError: at the first row whose ungapped sequences differ (both
            ungapped sequences are printed before raising).
    """
    for idx, row in enumerate(input_msa):
        expected = row.replace("-","")
        observed = output_msa[idx].replace("-","")
        if expected == observed:
            continue
        message = f"not the same sequence at position {idx}"
        print(message)
        print(expected)
        print(observed)
        raise ValueError(message)
    return True

In [40]:
def highlist_ref_seqs_pos(positions, red_pos):
    """
    Translate reference-sequence positions into MSA column flags.

    Args:
        positions: indices into the list of reference columns, i.e. the
            columns where red_pos is 0, in order.
        red_pos: per-column flags; 0 marks a column that corresponds to the
            reference sequence.

    Returns:
        Array shaped like red_pos with 1 at the selected columns, 0 elsewhere.
    """
    ref_columns = [col for col, flag in enumerate(red_pos) if flag == 0]
    selected = [ref_columns[p] for p in positions]
    highlights = np.zeros_like(red_pos)
    highlights[selected] = 1
    return highlights

## Save all MSAs

In [22]:
# Export the input alignment and the best optimized alignment of every domain
# as Stockholm files (run settings: temp = "None").
adv_loss = "None"
out_path_base = "2_26"
mode = "random"
e_val = "None"
temp = "None"
for DOM in ["T1064-D1","T1070-D1","T1043-D1","T1039-D1"]:
    out_path = f"{out_path_base}/{DOM}_e_{e_val}_t_{temp}"
    npy_file = f"{out_path}/{DOM}.traj.{adv_loss}.{mode}.npy"
    print(f"processing {DOM}")
    INPUTS = prep_inputs(DOM)
    inputs = {"x":INPUTS["ms"], "aln":INPUTS["aln"], "lengths":INPUTS["lens"]}
    # Load the saved results through context managers so the handles are closed.
    # NOTE: pickle can execute arbitrary code - only load files you produced yourself.
    with open(f"{out_path}/{DOM}.{adv_loss}.{mode}.out_dict.best", "rb") as fh:
        outputs = pickle.load(fh)
    with open(f"{out_path}/{DOM}.{adv_loss}.{mode}.p_dict.best", "rb") as fh:
        params = pickle.load(fh)
    lr, seed = params["lr"], params["seed"]
    print(f"lr: {lr} seed: {seed}")
    # same sequences, rendered once with the input and once with the learned alignment
    input_msa, input_red_pos = get_readable_msa(inputs['x'], inputs["aln"], inputs["lengths"])
    output_msa, output_red_pos = get_readable_msa(inputs['x'], outputs["aln"], inputs["lengths"])
    check_seqs_same(input_msa, output_msa)
    save_sto(output_msa, DOM, f"{DOM}_opt_aln_{temp}.sto")
    save_sto(input_msa, DOM, f"{DOM}_input_aln.sto")

processing T1064-D1
lr: 0.01 seed: 6
processing T1070-D1
lr: 0.001 seed: 85
processing T1043-D1
lr: 0.001 seed: 70
processing T1039-D1
lr: 0.01 seed: 66


In [18]:
# Export the input alignment and the best optimized alignment of every domain
# as Stockholm files (run settings: temp = "Gentle_Cool").
adv_loss = "None"
out_path_base = "2_26"
mode = "random"
e_val = "None"
temp = "Gentle_Cool"
for DOM in ["T1064-D1","T1070-D1","T1043-D1","T1039-D1"]:
#for DOM in ["T1043-D1","T1039-D1"]:
    out_path = f"{out_path_base}/{DOM}_e_{e_val}_t_{temp}"
    npy_file = f"{out_path}/{DOM}.traj.{adv_loss}.{mode}.npy"
    print(f"processing {DOM}")
    INPUTS = prep_inputs(DOM)
    inputs = {"x":INPUTS["ms"], "aln":INPUTS["aln"], "lengths":INPUTS["lens"]}
    # Load the saved results through context managers so the handles are closed.
    # NOTE: pickle can execute arbitrary code - only load files you produced yourself.
    with open(f"{out_path}/{DOM}.{adv_loss}.{mode}.out_dict.best", "rb") as fh:
        outputs = pickle.load(fh)
    with open(f"{out_path}/{DOM}.{adv_loss}.{mode}.p_dict.best", "rb") as fh:
        params = pickle.load(fh)
    lr, seed = params["lr"], params["seed"]
    print(f"lr: {lr} seed: {seed}")
    # same sequences, rendered once with the input and once with the learned alignment
    input_msa, input_red_pos = get_readable_msa(inputs['x'], inputs["aln"], inputs["lengths"])
    output_msa, output_red_pos = get_readable_msa(inputs['x'], outputs["aln"], inputs["lengths"])
    check_seqs_same(input_msa, output_msa)
    save_sto(output_msa, DOM, f"{DOM}_opt_aln_{temp}.sto")
    save_sto(input_msa, DOM, f"{DOM}_input_aln.sto")

processing T1064-D1


2022-03-02 08:45:30.086034: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /n/helmod/apps/centos7/Core/cudnn/8.1.0.77_cuda11.2-fasrc01/lib64:/n/helmod/apps/centos7/Core/cuda/11.1.0-fasrc01/cuda/extras/CUPTI/lib64:/n/helmod/apps/centos7/Core/cuda/11.1.0-fasrc01/cuda/lib64:/n/helmod/apps/centos7/Core/cuda/11.1.0-fasrc01/cuda/lib
2022-03-02 08:45:30.086078: W external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)


lr: 0.0001 seed: 65
processing T1070-D1
lr: 0.01 seed: 60
processing T1043-D1
lr: 0.001 seed: 37
processing T1039-D1
lr: 0.0001 seed: 68


## View learned alignment of 1039

In [41]:
# Load the best optimized alignment for T1039-D1 (temp = "None") and rebuild
# printable versions of the input and learned MSAs.
adv_loss = "None"
out_path_base = "2_26"
mode = "random"
e_val = "None"
temp = "None"
DOM = "T1039-D1"    
out_path = f"{out_path_base}/{DOM}_e_{e_val}_t_{temp}"
print(f"processing {DOM}")
INPUTS = prep_inputs(DOM)
inputs = {"x":INPUTS["ms"], "aln":INPUTS["aln"], "lengths":INPUTS["lens"]}
# Load the saved results through context managers so the handles are closed.
# NOTE: pickle can execute arbitrary code - only load files you produced yourself.
with open(f"{out_path}/{DOM}.{adv_loss}.{mode}.out_dict.best", "rb") as fh:
    outputs = pickle.load(fh)
with open(f"{out_path}/{DOM}.{adv_loss}.{mode}.p_dict.best", "rb") as fh:
    params = pickle.load(fh)
lr, seed = params["lr"], params["seed"]
print(f"lr: {lr} seed: {seed}")
input_msa, input_red_pos = get_readable_msa(inputs['x'], inputs["aln"], inputs["lengths"])
output_msa, output_red_pos = get_readable_msa(inputs['x'], outputs["aln"], inputs["lengths"])
# sanity check: both MSAs must contain the same ungapped sequences
check_seqs_same(input_msa, output_msa)

processing T1039-D1
lr: 0.01 seed: 66


True

In [42]:
# Print both alignments in blocks of 100 columns; columns flagged in *_red_pos
# (inserts with respect to the first sequence) are shown in colour.
print("input MSA")
print_chunks(input_msa, input_red_pos, 100)
print("\noutput MSA")
print_chunks(output_msa, output_red_pos, 100)

input MSA

positions 0 to 100
000 NNPI[34m-[0m[34m-[0m[34m-[0mS[34m-[0m[34m-[0mSKLTEYYTN[34m-[0mF[34m-[0m[34m-[0m[34m-[0mKY[34m-[0mK[34m-[0mIL[34m-[0m[34m-[0mP[34m-[0m[34m-[0m[34m-[0m[34m-[0mG[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0mGKLNKGKLK[34m-[0m[34m-[0mDL[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0mQSTV[34m-[0m[34m-[0m[34m-[0mT[34m-[0m[34m-[0m[34m-[0mSLLEKTRKE[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0mN[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m
001 --PI[34m-[0m[34m-[0m[34m-[0mD[34m-[0m[34m-[0mNKLKDYYVN[34m-[0mF[34m-[0m[34m-[0m[34m-[0mKN[34m-[0mL[34m-[0mFL[34m-[0m[34m-[0mD[34m-[0m[34m-[0m[34m-[0m[34m-[0mK[34mK[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0mNNLDKESLK[34m-[0m[34m-[0mSI[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0m[34m-[0mKKEV[34m-[0m[34m-[0m[34m-[0mG[34m-[0m[34m

In [None]:
# For the figure: manually trimmed, keeping sequences 0, 2, 8, 9, 11, and 14.