
# Structural Alignments with MMTF

### Features/Summary
Align a set of PDB structure against a reference structure. Works on chain level.
A sequence alignment is performed, followed by a structural alignment (SVD).
Visualization features (1:1 or 1:n alignments).
Additional functions for searching sets of PDB structures by name.



### Team Members
Florian, Brian, Ariel, Sebastian, Juexin, Xinlian, Spencer, Haipeng, Hansaim, Ansong


# Configuration/Imports

In [0]:
# Imports
import Bio
from Bio.SVDSuperimposer import SVDSuperimposer
from Bio import pairwise2

import mmtf
import py3Dmol
from pyspark import SparkContext
from pyspark.sql import SparkSession
from mmtfPyspark.io import mmtfReader, mmtfWriter
from mmtfPyspark.utils import traverseStructureHierarchy, ColumnarStructure
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains, StructureToPolymerSequences
from mmtfPyspark.webfilters import AdvancedQuery, ChemicalStructureQuery, PdbjMineSearch
from mmtfPyspark.datasets import g2sDataset, pdbjMineDataset, myVariantDataset
from mmtfPyspark.structureViewer import view_structure

import numpy as np

from io import BytesIO
import requests
from ipywidgets import IntSlider, interact



# MMTFStructAlign Class

In [0]:
class MMTFStructAlign:
  def __init__(self, reference_pdb, reference_chain, target_set):
    """"Initializes the class with a reference PDB chain."""
    self.spark = self.initialize_spark_session()
    self.passed = set()
    self.rotran_structures = {}
    self.target_set = target_set
    self.reference_pdb = reference_pdb.upper()
    self.reference_chain = reference_chain
    self.reference_structure, self.reference_sequence = self.get_reference_data()
    self.reference_arrays = ColumnarStructure(self.reference_structure.flatMap(StructureToPolymerChains()).filter(lambda x: x[0] == f'{reference_pdb}.{reference_chain}').values().first(), firstModelOnly=True)
    self.target_structures, self.target_sequences = self.get_target_data()
    self.atom_mappings = {} # Contains for each alignment the lists of reference and target atoms
    self.residue_sets = {} # Contains for each alignment the lists of aligned residues in reference and target
    self.structural_alignment()
    
  def initialize_spark_session(self):
    """Start the Spark session."""
    spark = SparkSession.builder.master("local[*]").appName("1-Input").getOrCreate()
    return spark.sparkContext

  def get_reference_data(self):
    """Get structure and sequence for the reference structure"""
    structure = mmtfReader.download_full_mmtf_files([self.reference_pdb],self.spark)
    sequences = structure.flatMap(StructureToPolymerChains()).collect()
    sequence = [self.extract_structure_from_protein_chain(x[1]) for x in sequences if x[0] == f'{self.reference_pdb}.{self.reference_chain}'][0]
    return structure, sequence
  
  def get_target_data(self):
    """Get structure and sequence for the target structure"""
    structures = mmtfReader.download_full_mmtf_files(self.target_set,self.spark)
    sequences = {x:self.extract_structure_from_protein_chain(y) for x,y in structures.flatMap(StructureToPolymerChains()).collect()}
    for target_id in sequences:
      if sequences[target_id] != None:
        self.passed.add(target_id)
    return structures, sequences
  
  def calculate_sequence_mapping_pair(self, target_id):
    """Perform a pairwise global sequence alignment to find corresponding sets of residues."""
    # Perform pairwise alignment
    target_sequence = self.target_sequences[target_id]
    alignments = pairwise2.align.globalxx(self.reference_sequence, target_sequence)
    print(pairwise2.format_alignment(*alignments[0]))
    aligned_ref, aligned_target, score, start, end = alignments[0]
    res_array_ref, res_array_target = [], []
    num1, num2 = 0, 0
    for seq1, seq2 in zip(aligned_ref, aligned_target):
      if seq1 != '-' and seq2 != '-':
        res_array_ref.append(num1)
        res_array_target.append(num2)
      if seq2 != '-':
        num2 += 1
      if seq1 != '-':
        num1 += 1
        
    res_array_ref = np.array(res_array_ref)
    res_array_target = np.array(res_array_target)
    return res_array_ref, res_array_target
 
  def structural_alignment(self):
    """Perform structural alignments of the reference against all target structures."""
    # get numpy representation of reference structures
    # compute alignment for each target against reference
    target_chains = self.target_structures.flatMap(StructureToPolymerChains()).collect()
    for target_chain in target_chains:
      if target_chain[0] in self.passed:
        self.align(target_chain)
      
    return None
  
  def align(self,target_chain):
    """Perform a structural alignment of the reference against a single target structure."""
    pdb_chain_target = target_chain[0]
    title = f"Aligning {self.reference_pdb}.{self.reference_chain} against {pdb_chain_target}"
    length_title = len(title)
    print(title)
    print(''.join('='*length_title))
    
    # get mapping of residues from sequence alignnment
    mapping = self.calculate_sequence_mapping_pair(pdb_chain_target)
    
    arrays = ColumnarStructure(target_chain[1]) 
    
    # get target coordinates
    x_target = np.asarray(arrays.get_x_coords())
    y_target = np.asarray(arrays.get_y_coords())
    z_target = np.asarray(arrays.get_z_coords())

    # determine residues that should be mapped
    target_atom_to_group_indices = arrays.get_atom_to_group_indices()
    reference_atom_to_group_indices = self.reference_arrays.get_atom_to_group_indices()
    
    # In case residues are missing from the structure, remove them here from the selection
    residues_target = set(target_atom_to_group_indices)
    
    # In case there are only unknown aa in the target ('XXXXX...'), skip it
    residues_reference = set(reference_atom_to_group_indices)
    mapping_target_refined_set = set()
    mapping_reference_refined_set = set()
    for res_ref, res_tar in zip(mapping[0], mapping[1]):
      if res_ref in residues_reference and res_tar in residues_target:
        mapping_target_refined_set.add(res_tar)
        mapping_reference_refined_set.add(res_ref)
    
    self.residue_sets[pdb_chain_target] = (mapping_reference_refined_set, mapping_target_refined_set)
    
    
    boolean_array_reference = np.array(([True if x in mapping_reference_refined_set else False for x in reference_atom_to_group_indices]))
    boolean_array_target = np.array(([True if x in mapping_target_refined_set else False for x in target_atom_to_group_indices]))
    
    # get per-atom information
    atom_names_target = arrays.get_atom_names()
    entity_types_target = arrays.get_entity_types()
    
    # reduce atom information of target
    reduced_target_idx = (atom_names_target == 'CA') & (boolean_array_target) & (entity_types_target == 'PRO')
    x_target_reduced = x_target[reduced_target_idx]
    y_target_reduced = y_target[reduced_target_idx]
    z_target_reduced = z_target[reduced_target_idx]
    target_coordinates =  np.swapaxes(np.array([x_target_reduced,y_target_reduced,z_target_reduced]), 0, 1)
    
    # reduce atom information of reference
    atom_names_reference = self.reference_arrays.get_atom_names()
    entity_types_reference = self.reference_arrays.get_entity_types()
    x_reference = np.asarray(self.reference_arrays.get_x_coords())
    y_reference = np.asarray(self.reference_arrays.get_y_coords())
    z_reference = np.asarray(self.reference_arrays.get_z_coords())
    reduced_reference_idx = (atom_names_reference == 'CA') & (boolean_array_reference) & (entity_types_reference == 'PRO')

    x_reference_reduced = x_reference[reduced_reference_idx]
    y_reference_reduced = y_reference[reduced_reference_idx]
    z_reference_reduced = z_reference[reduced_reference_idx]
    reference_coordinates = np.swapaxes(np.array([x_reference_reduced,y_reference_reduced,z_reference_reduced]), 0, 1)
    
    if len(target_coordinates) != len(reference_coordinates):
      print(f'Got different lengths of atom coordinates after alignment and mapping: {len(target_coordinates)} vs {len(reference_coordinates)}')
    
    self.atom_mappings[pdb_chain_target] = (reference_coordinates, target_coordinates)
    
    sup = SVDSuperimposer()
    sup.set(reference_coordinates, target_coordinates)
    sup.run()
    rmsd = sup.get_rms()
    rotran_matrix = sup.get_rotran()
    print(f'Alignment has RMSD of {rmsd}\n\n')
    self.rotran_structures[pdb_chain_target] = self.generate_rotran(target_chain, rotran_matrix)
    
  def generate_rotran(self, target, rotran_matrix):
    """Transform coordinates of target structure"""
    all_coords = []
    for x,y,z in zip(target[1].x_coord_list,target[1].y_coord_list,target[1].z_coord_list):
      all_coords.append((x,y,z))
    transformed_x, transformed_y, transformed_z = [], [], []
    transformed_coords = np.dot(all_coords, rotran_matrix[0]) + rotran_matrix[1]
    for x,y,z in transformed_coords:
        transformed_x.append(x)
        transformed_y.append(y)
        transformed_z.append(z)
    target[1].x_coord_list = transformed_x
    target[1].y_coord_list = transformed_y
    target[1].z_coord_list = transformed_z
    return target
  
  def extract_structure_from_protein_chain(self, mmtf_structure):
    """Given an MMTFEncoder structure, extract the real sequence as appearing in the protein"""
    arrays = ColumnarStructure(mmtf_structure) 
    groups = arrays.get_atom_to_group_indices()
    seq3 = arrays.get_group_names()
    seq1 = []
    for aa in seq3:
      aa_std = aa.lower().capitalize()
      if aa_std in Bio.Data.IUPACData.protein_letters_3to1:
        seq1.append(Bio.Data.IUPACData.protein_letters_3to1[aa_std])
      else:
        seq1.append('X')
    y = []
    max_num = -1

    for ele in groups:
        if ele > max_num:
            y.append(True)
        else:
            y.append(False)
        max_num = ele
    sequence_from_target_structure = (''.join([e for i, e in enumerate(seq1) if y[i]]))
    if set(sequence_from_target_structure) == set('X'):
      return None
    else:
      return sequence_from_target_structure
    
  def _get_3dmol_selection(self, target_id):
    """Get a selection for 3DMOL visualization"""
    # Alignment as residue indices
    ref_residx, tar_residx = self.residue_sets[target_id]
    
    # Convert reference residue indices to residue numbers
    ref_g2a = self.reference_arrays.get_group_to_atom_indices()
    ref_atmidx = {ref_g2a[i] for i in ref_residx}
    ref_gnum = self.reference_arrays.get_group_numbers()
    ref_grpname = {ref_gnum[i] for i in ref_atmidx}
    
    # build 3Dmol selection
    ref_sel = {'model': 0, 'chain': self.reference_chain, 'resi': list(ref_grpname)}
    
    # Convert target residue indices to residue numbers
    tar_struct = self.rotran_structures[target_id][1]
    tar_array = ColumnarStructure(tar_struct, firstModelOnly=True)
    tar_g2a = tar_array.get_group_to_atom_indices()
    tar_atmidx = {tar_g2a[i] for i in tar_residx}
    tar_gnum = tar_array.get_group_numbers()
    tar_grpname = {tar_gnum[i] for i in tar_atmidx}
    
    # build 3Dmol selection
    tar_sel = {'model': 1, 'resi': list(tar_grpname)}

    return ref_sel, tar_sel

  def visualize_alignments(self, target):
    """Visualize a single alignment"""
    
    """Show the 3D visualization"""
    viewer = py3Dmol.view()
    # Add reference structure
    reference_dump = mmtfWriter.to_mmtf_base64(self.reference_structure.first()[1])
    viewer.addModel(reference_dump, 'mmtf')
    # Add target structures
    selected_target = self.rotran_structures[target]
    target_id, target_struct = selected_target
    target_dump = mmtfWriter.to_mmtf_base64(target_struct)
    viewer.addModel(target_dump, 'mmtf')
    #viewer.setStyle({'cartoon': {'hidden':True}}) # Hide everything
    reference_selection = {'model': 0, 'chain': self.reference_chain}
    target_selection = {'model': 1}

    ref_sel,target_sel = self._get_3dmol_selection(target)
    
    
    red = '0xd95f02'
    lightred = "0xfedec5"
    purple = '0x7570b3'
    lightpurple = '0xe3e2ef'
    viewer.setStyle({'model': 0}, {})
    viewer.setStyle(reference_selection, {'cartoon': {'color': lightred}}) # Reference protein, selected chain
    viewer.setStyle(ref_sel, {'cartoon': {'color': red}})
    viewer.setStyle(target_selection, {'cartoon': {'color': lightpurple}}) # Target protein
    viewer.setStyle(target_sel, {'cartoon': {'color': purple}})

    # Add labels
    viewer.center()
    viewer.zoomTo()
    return viewer.show()

  def visualize_slider(self):
    """Visualization with slider for all targets. Doesn't work in colaboratory"""
    targets = list(self.residue_sets.keys())
    s_widget = IntSlider(min=0, max=len(targets)-1, description='Alignment', continuous_update=False)
    return interact(lambda i: self.visualize_alignments(targets[i]), i=s_widget)

  def visualize_all(self):
    """superimpose all targets"""
    viewer = py3Dmol.view()
    # Add reference structure
    reference_dump = mmtfWriter.to_mmtf_base64(self.reference_structure.first()[1])
    viewer.addModel(reference_dump, 'mmtf')
    # Add target structures
    for selected_target in self.rotran_structures.values():
        target_id, target_struct = selected_target
        target_dump = mmtfWriter.to_mmtf_base64(target_struct)
        viewer.addModel(target_dump, 'mmtf')
    #viewer.setStyle({'cartoon': {'hidden':True}}) # Hide everything
    reference_selection = {'model': 0, 'chain': self.reference_chain}
    target_selection = {'model': 1}    
    
    red = '0xd95f02'
    lightred = "0xfedec5"
    purple = '0x7570b3'
    lightpurple = '0xe3e2ef'
    
    viewer.setStyle({}, {'cartoon': {'color': lightpurple}}) # Target protein    
    viewer.setStyle({'model': 0}, {})
    viewer.setStyle(reference_selection, {'cartoon': {'color': red}}) # Reference protein, selected chain

    # Add labels
    viewer.center()
    viewer.zoomTo()
    return viewer.show()

  
  
    

In [0]:
spark = SparkSession.builder.master("local[4]").appName("proteinMultiplePDB").getOrCreate()
sc = spark.sparkContext

path = "/tmp/full"

def find_by_gene_name (query_name):
    """Find PDB structures by name."""
    #query_gene_name = "KRAS"
    query = (
      "<orgPdbQuery>"
        "<queryType>org.pdb.query.simple.UniprotGeneNameQuery</queryType>"
        "<query>"+query_name+"</query>"
      "</orgPdbQuery>"
    )
    pdb = mmtfReader.read_sequence_file(path, sc).cache()
    trimer_of_trimers = pdb.filter(AdvancedQuery(query))
    
    return pdb

def find_by_uniprot_name (query_name):
    """Find PDB structures by UniProt name."""
    #query_gene_name = "KRAS"
    query = (
      "<orgPdbQuery>"
        "<queryType>org.pdb.query.simple.UpAccessionIdQuery</queryType>"
        "<accessionIdList>"+query_name+"</accessionIdList>"
      "</orgPdbQuery>"
    )
    pdb = mmtfReader.read_sequence_file(path, sc).cache()
    trimer_of_trimers = pdb.filter(AdvancedQuery(query))
    
    return trimer_of_trimers

def find_G2S_name(query_name):
    """Find PDB structures by G2S name."""
    G2S_REST_URL = "https://g2s.genomenexus.org/api/alignments/uniprot/"

    data = []

    url = G2S_REST_URL + query_name
    #print(url)
    try:
        req = requests.get(url)
    except:
        print(f"WARNING: could not load data for: {query_name}")
        return data

    if b"\"error\":\"Not Found\"" in req.content:
        print(f"WARNING: could not load data for: {query_name}")
        return data

    results = [inputStream.decode() for inputStream in BytesIO(req.content).readlines()]

    records=results[0].split(",")
    for record in records:
        if record.startswith( '\"pdbId\"' ):           
            temp=record.split(":\"")
            data.append(temp[1].split("\"")[0].upper())
    pdbids = list(set(data))
    #print(pdbids)

    #trimer_of_trimers = mmtfReader.read_sequence_file(path, sc, pdbids).cache()
    #trimer_of_trimers = mmtfReader.download_full_mmtf_files(pdbids, sc)

    return pdbids


# Examples
- the alignment is computed by doing an initial global sequence alignment and aligning corresponding residue with each other
- the following example will align structures of histidyl-tRNA synthetases (HisRS) from different organisms
- **1ADJ.A**, a HisRS from *Thermus thermophilus*, is used as reference chain onto which all chains of 3HRI, 4YRE, and 5E3I are aligned


In [234]:
# Perform a structural alignment between aaRS structures
m = MMTFStructAlign(reference_pdb='1ADJ', reference_chain='A', target_set=['3HRI', '4YRE', '5E3I'])

Aligning 1ADJ.A against 3HRI.A
---TARA--VR-GTK--DLFGK--ELR-MHQRI-------VA---TARK---------VLEAAGAL-E-LVTP-IF----EE-T-QV-FEKGVGAATD--I---V--RK-EMFTFQD----RG-----GRSLTLR-PEGTAAMVRAYLEHGMKVW---PQPVRL-WMAGPMFRA-E---RPQKGRYRQF--H-QV-NYEAL----G--S---ENPI--LD--AEAVVLLYECLKE-----LGLRRLKVKLSS--VGDPEDRARYNAYLREVL---SPHREA--LS----EDSKE---------RLEENPMR-------I-LDSKSER----DQA----L-LKELGVRPML--DFLGEE-A-RAHLK--EVERHLERLS-VPY-EL-------EPAL------V-------RGLD-YYVRTAFE--VHHE-EIGAQS--------ALG-GGGRYDGLSE-LLG---G-PR--VPG-VGFA-FG--VER-VA-LALEA-EGFG--L-PEEKGPDLY------L-IPLTEEAVAEAFYLA-E-----AL---RPRLRAEYAL-AP-RKP-A-----KGLEEALKRGA----AF--AGFLGEDELRA-GEVT-LKR--L-AT------GEQVR---LSREEVPGYLLQA--LG---
   |     |  |    | |    |   |  |        |    ||||         |||      | |    |     || | |  |           |   |  |  || |       |      |||| |  |        |      | |   ||     |      |  |   |   || |    | |  |        |  |   |     |   | |      |        ||       |||  ||              |    |  |    |     || |          | ||         

Given a Gene name or Uniprot Protein name, output all the available structures in PDB and output the alignments

In [53]:
query_uniprot_name = "Q8IXA5"
pdbIds = find_G2S_name(query_uniprot_name)
print(pdbIds)
m2 = MMTFStructAlign(reference_pdb=pdbIds[0], reference_chain='A', target_set=pdbIds[1:10])
m2.visualize_alignments(pdbIds[1]+'.A')

['1UBZ', '1GE4', '1I1Z', '1INU', '133L', '1TAY', '1CKD', '1JKB', '1B7S', '1GB2', '1CJ7', '1B5X', '1WQO', '2ZWB', '1GFU', '1GE3', '2ZIL', '1YAN', '2ZIJ', '1OUH', '1GAZ', '1DI3', '1I22', '1D6Q', '1IP2', '1OUC', '1IP7', '2ZIK', '1REZ', '4YF2', '1OUF', '1GB0', '1GF3', '5LVK', '1CJ9', '1GFV', '1CJ8', '1IP3', '1GF7', '1GFJ', '3LN2']
Aligning 1UBZ.A against 1GB2.A
KV-FERCELARTLKRLGMDGYRGISLANWMCLAKWESGYNTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRE-PQGIRAWVAWRNRCQNRDVRQYVQGCGV
|  |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||  ||||||||||||||||||||||||||||
K-MFERCELARTLKRLGMDGYRGISLANWMCLAKWESGYNTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVR-DPQGIRAWVAWRNRCQNRDVRQYVQGCGV
  Score=128

Alignment has RMSD of 0.43361361041029994


Aligning 1UBZ.A against 1GE4.A
KVFERCELARTLKRLGMDGYRGISLANWMCLAKWESGYNTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRE-PQGIRAWVAWRNRCQN-RDVRQYVQGCGV
||||||||||||||||||||

## Visualization
- to visualize the alignment, the function `visualize_alignments` has to be called

In [236]:
m.visualize_alignments('3HRI.A')

In [237]:
m.visualize_alignments('4YRE.A')

In [31]:
m.visualize_all()

In [32]:
m.visualize_slider()

<function __main__.MMTFStructAlign.visualize_slider.<locals>.<lambda>>

# Alternative alignment implementation

In [170]:
import math;


##return a list a:
## a[0] = g;
## a[1] = h;
def ROTATE(a, i, j, k, l, s, tau):
	result = [];

	g=a[i][j];
	h=a[k][l];
	a[i][j] = g-s*(h+g*tau);
	a[k][l] = h+s*(g-h*tau);

	result.append(g);
	result.append(h);

	return result;





# 
# Computes al eigenvalues and eigenvectors of a real symmetric matrix
# a[1..n][1.n]. On output, elements above the diagonal are destroyed.
# Algorithm taken from Numerical Recipies pg#467.
# The output of the algorithm is unsorted.
# 
# param a the input matrix (a list of lists).
# param n the dimension of the matrix (an integer).
# param d the eigenvalues are returned here (a list).
# param v the eigenvectors are returned here (a list of lists).
# output nrot the number of Jacobi rotations required (returned as output, integer).
# 
def jacobi ( a, n, d, v ):
#  int j, iq, ip, i;
#  double tresh, theta, tau, t, sm, s, h, g, c, *b, *z;
	b = [];
	z = [];
	tresh = 0;

	# Allocate, initialzie
	for i in range(0, n) :
		b.append(0);
		z.append(0);

	# Initialize to the identity matrix.
	for ip in range( 0, n) :
		for iq in range(0, n) :
			v[ip][iq] = 0.0;
			v[ip][ip] = 1.0;
			
	for ip in range(0, n) :
		b[ip] = a[ip][ip];
		d[ip] = a[ip][ip];
		z[ip] = 0.0;

	#Initialize the number of rotations.
	nrot = 0;

	#main loop
	for i in range(0, 50) :
		##sum off diagonal elements.
		sm = 0;
		for ip in range(0, n-1) :
			for iq in range(ip+1, n):
				sm += math.fabs (a[ip][iq]);

		# The normal return, which relies on the quadratic convergence
		# to machine overflow.
		if (sm == 0.0):
			return;

		#On the first 3 sweeps.
		if (i < 3):
			tresh = 0.2*sm/(n*n);
		else:
			tresh = 0.0;

		###
		for ip in range(0, n-1) :
			for iq in range(ip + 1, n) :
				g = 100 * math.fabs (a[ip][iq]) ;
	
				#After four sweeps, skip the rotation if the
				#off-diagonal element is small.
				if ((i > 3) and ( (math.fabs (d[ip] + g)) == math.fabs (d[ip])) and (math.fabs (d[iq] + g) == math.fabs(d[iq])) ):
					a[ip][iq] = 0;
				elif (math.fabs(a[ip][iq]) > tresh):
					h = d[iq] - d[ip];

					if ( (math.fabs(h) + g) ==  math.fabs (h)):
						t = (a[ip][iq]) / h;
					else:
						theta = ( 0.5 * h / (a[ip][iq]) );
						t = ( 1.0 / (math.fabs (theta) + math.sqrt (1.0 + theta * theta)) );
						if (theta < 0.0):
							t = -t;

					c = ( 1.0 / math.sqrt (1.0 + t * t) );
					s = t * c;
					tau = (s / (1.0 + c) );
					h= t * a[ip][iq];
					z[ip] -= h;
					z[iq] += h;
					d[ip] -= h;
					d[iq] += h;
					a[ip][iq] = 0;

					#Case of rotations 0<=j<p.
					for j in range(0, ip):
						temp = []
						temp = ROTATE(a,j,ip,j,iq, s, tau);
						g = temp[0];
						h = temp[1];

					#Case of rotations p<j<q.
					for j in range(ip+1, iq):
						temp = []
						temp = ROTATE(a,ip,j,j,iq, s, tau);
						g = temp[0];
						h = temp[1];
					
					#Case of rotations q<j<n.
					for j in range(iq+1, n):
						temp = []
						temp = ROTATE(a,ip,j,iq,j, s, tau);
						g = temp[0];
						h = temp[1];
					
					for j in range(0, n):
						temp = []
						temp = ROTATE(v,j,ip,j,iq, s, tau);
						g = temp[0];
						h = temp[1];
					
					nrot = nrot+1;

		# Update d and reinitialize z.
		for ip in range( 0, n):
			b[ip] += z[ip];
			d[ip] = b[ip];
			z[ip] = 0.0;


	print("Too many iterations in jacobi");
	return(-1);





	
	
	
# 
# Given the eigenvalues d and eigenvectors v as output from Jacobi,
# this routine sorts the eigenvalues into descending order, and
# rearranges the columns of V correspondingly. The method is straight
# insertion. Notice that O(N^2) is affordable since Jacobi is cubic.
# Algorithm taken from Numerical Recipies pg#468.
# 
# param d the eigenvalues (a list().
# param v the eigenvalues (a list of lists).
# param n the dimension of the matrix.
# 
def eigsrt ( d, v, n):
	i = 0;
	j = 0;
	k = 0;
	p = 0;
	for i in range(0, n):
#		p = d[k = i];							##FIXME
		p = d[i];
		k = i
		for j in range( i+1, n):
			if (d[j] >= p):
#				p = d[k = j];					##FIXME
				p = d[j];						
				k = j;

		# The current is not the maximum.
		if (i != k):
			d[k] = d[i];
			d[i] = p;
			for j in range(0, n):
				p = v[j][i];
				v[j][i] = v[j][k];
				v[j][k] = p;
				
















## 
## Calculates the eigenvalues and eigenvectors of a matrix.
## Assumes that the matrix is symmetric.
## 
## param A the matrix (a list of lists).
## param n the dimension of the matrix (an integer).
## param u the eigenvectors (a list of lists).
## param w the eigenvalues (a list).
## 
def eigen ( A, n, u, w):
	#int i, j;
	#double ** a;
	
	## Copy a, since it is to be damaged by the jacobi calculation funciton.
	a = [];
	for i in range(0, n):
		a.append( [] );
		for j in range(0, n) :
			a[i].append( 0 );
	for i in range(0, n):
		for j in range(0, n):
			a[i][j] = A[i][j];
			
	nrot = jacobi (a, n, w, u);
  
	# Sort the eigenvectors & values to descending order.
	eigsrt (w, u, n);










# 
# Computes the Euclidean distance between to vectors.
# 
# param v1 the first vector.
# param v2 the second vector.
# 
def distance ( v1, v2):
	v = [0, 0, 0];
	v[0] = v1[0] - v2[0];
	v[1] = v1[1] - v2[1];
	v[2] = v1[2] - v2[2];
	
	result = math.sqrt (  (v[0]*v[0]) + (v[1]*v[1]) + (v[2]*v[2])  );
	return result;









# 
# Normalizes the vector.
# 
# param vector the vector. A 3d list (3 elements)
# 
def normalize (v):
	length = (math.sqrt ( (v[0] * v[0]) + (v[1] * v[1]) + (v[2] * v[2]) ) );
	
	if(length < .000001):
		v[0] = 0.0;
		v[1] = 0.0;
		v[2] = 0.0;
		return;

	v[0] = v[0] / length;
	v[1] = v[1] / length;
	v[2] = v[2] / length;







# 
# Performs matrix-vector multiplication.
# Specialized for the 3D case.
# This function is a helper for the RMSD calculation.
# 
# param a the matrix (a list of lists).
# param v the vector (a list).
# param vr the resulting vector (a list).
# 
def matrix_vector_mult3D(a, v, vr):
	for i in range(0, 3):
		vr[i] = (v[0] * a[0][i]) + (v[1] * a[1][i]) + (v[2] * a[2][i]);











# 
# Performs matrix-matrix multiplications.
# Specialized for the 3D case.
# This function is a helper for the RMSD calculation.
# 
# param a the first matrix.
# param b the second matrix.
# param c the result matrix.
# 
# REQUIRES INITIALIZED LIST C.
#
def matrix_matrix_mult3D(a, b, c):
	#Clear C:
	for i in range(0, 4):
		for j in range(0, 4):
			c[i][j] = 0;
	

	# Complete the multiplication.
	for i in range(0, 4):
		for j in range(0, 4):
			for k in range (0, 4):
				c[i][j] += a[i][k] * b[k][j];
			











###########################################################################
###########################################################################
### min_rmsd: this performs the actual superposition.  Executive function to be called from elsewhere.
###########################################################################
###########################################################################

##############################################################################################
#  Calculates the minimized RMSD between two molecular conformations.
# 
#  param ref_orig: These are the coordinates that do not move.
#  param fit_orig: These are the coordinates that are rotated and translated 
#                   to superpose with the ref_orig points.
#  return A python list of several things:
##	first the RMSD. (a single double)
##	second the aligned points. (a python list of lists, where each sublist is a triplet)
##	third the transform (a python list of 16 elements )
##	fourth the distances (a python list of doubles indicating the distances between aligned points)
##                     (this is of length equal to the input)
# 
#	NOTE: THIS CODE IS MATHEMATICALLY UNSTABLE IF PROVIDED 3 OR LESS PAIRS OF POINTS TO ALIGN.
#
##############################################################################################
def min_rmsd(ref_orig, fit_orig):
	if( len(ref_orig) != len(fit_orig) ):
		print("ERROR: Input arrays are not the same length. Exiting.");
		return[];

	i = 0;
	j = 0;
	k = 0;
	
	cmass_ref = [];
	cmass_fit = [];
	ref = [];
	fit = [];
	atoms = len(ref_orig);

	##this is the array of arrays of output stuffs.
	transform = [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];

	##copy the data to avoid influencing data passed in.
	for i in range(0, atoms):
		ref.append([ref_orig[i][0], ref_orig[i][1], ref_orig[i][2] ]);
		fit.append([fit_orig[i][0], fit_orig[i][1], fit_orig[i][2] ]);

	cmass_ref = [0,0,0];
	cmass_fit = [0,0,0];

	##compute the center of mass of both ref and fit.
	for i in range(0, atoms):
		cmass_ref[0] += ref[i][0];
		cmass_ref[1] += ref[i][1];
		cmass_ref[2] += ref[i][2];
		cmass_fit[0] += fit[i][0];
		cmass_fit[1] += fit[i][1];
		cmass_fit[2] += fit[i][2];
		
	##now compute the average
	cmass_ref[0] = cmass_ref[0] / atoms ;
	cmass_ref[1] = cmass_ref[1] / atoms ;
	cmass_ref[2] = cmass_ref[2] / atoms ;
	cmass_fit[0] = cmass_fit[0] / atoms ;
	cmass_fit[1] = cmass_fit[1] / atoms ;
	cmass_fit[2] = cmass_fit[2] / atoms ;


	# Compute R. R(i,j) is the dot product of i'th column of the
	# normalized reference and j'th column of the normalized fit.
	# But normalized we mean subtracting the appropriate Centers of
	# Masses. 
	R = [
		[0, 0, 0, 0],
		[0, 0, 0, 0],
		[0, 0, 0, 0],
		[0, 0, 0, 0],
	]
	R[3][3] = 1;

	##Here we are computing the dot products for each R[i][j], as mentioned above.
	for i in range(0, 3):
		for j in range(0, 3):
			for k in range(0, atoms):
				R[i][j] += (ref[k][i] - cmass_ref[i]) * (fit[k][j] - cmass_fit[j]);

	#Transpose R
	Rt = [
		[0, 0, 0, 0],
		[0, 0, 0, 0],
		[0, 0, 0, 0],
		[0, 0, 0, 0],
	] 

	for i in range(0, 4):
		for j in range(0, 4):
			Rt[i][j] = R[j][i];

	#compute Rt * R
	RtR = [
		[0, 0, 0, 0],
		[0, 0, 0, 0],
		[0, 0, 0, 0],
		[0, 0, 0, 0],
	] 


	##This call is made with a call implemented above.
	##Probably it can be replaced with an outside method
	##but be careful.
	matrix_matrix_mult3D(Rt, R, RtR);

	#Compute the eigenvalues/eigenvectors of RtR.
	#The eigenvectors will come sorted.
 
	evalue = [0, 0, 0];
	evector = [
		[0, 0, 0],
		[0, 0, 0],
		[0, 0, 0],
	];
	
	##Call this method, implemented above, compute the eigen vectors.
	eigen (RtR, 3, evector, evalue);

	#Transpose so that eigenvectors become rows.
	temp_float = 0;
	##This is a simplistic transpose because it is only a 3x3 matrix.
	temp_float = evector[0][1];
	evector[0][1] = evector[1][0];
	evector[1][0] = temp_float;  
	temp_float = evector[0][2];
	evector[0][2] = evector[2][0];
	evector[2][0] = temp_float;
	temp_float = evector[1][2];
	evector[1][2] = evector[2][1];
	evector[2][1] = temp_float;

  # Now compute  b(i) = R*a(i).
	b = [
		[0, 0, 0],
		[0, 0, 0],
		[0, 0, 0],
	];

	matrix_vector_mult3D (Rt, evector[0], b[0]);
	matrix_vector_mult3D (Rt, evector[1], b[1]);
	matrix_vector_mult3D (Rt, evector[2], b[2]);

	normalize (b[0]);
	normalize (b[1]);
	normalize (b[2]);

  # Compute U = u(i,j) = sum(b(k,i) * a(k,j))
	U = [
		[0, 0, 0, 0],
		[0, 0, 0, 0],
		[0, 0, 0, 0],
		[0, 0, 0, 0],
	] 
	U[3][3] = 1;

	##Here we compute the product u(i,j) = sum(b(k,i) * a(k,j))
	for i in range(0, 3):
		for j in range(0, 3):
			for k in range(0, 3):
				U[j][i] += (b[k][i] * evector[k][j]);

	##cull numerical underflow
	for i in range (0, 4):
		for j in range (0, 4):
			if( math.fabs(U[i][j]) < .000001 ):
				U[i][j] = 0.0;

	##catch degenerate rotations:
	if( (U[0][0] == 0) and (U[0][1] == 0) and (U[0][2] == 0) and (U[1][0] == 0) and (U[2][0] == 0)):
		U[0][0] = 1.0;
	if( (U[1][1] == 0) and (U[0][1] == 0) and (U[2][1] == 0) and (U[1][0] == 0) and (U[1][2] == 0)):
		U[1][1] = 1.0;
	if( (U[2][2] == 0) and (U[2][1] == 0) and (U[2][0] == 0) and (U[1][2] == 0) and (U[0][2] == 0)):
		U[2][2] = 1.0;

	# Update the rotational part of transform
	transform[0] = U[0][0];	transform[4] = U[1][0];		transform[8] = U[2][0];		transform[12] = U[3][0];
	transform[1] = U[0][1];	transform[5] = U[1][1];		transform[9] = U[2][1];		transform[13] = U[3][1];
	transform[2] = U[0][2];	transform[6] = U[1][2];		transform[10] = U[2][2];		transform[14] = U[3][2];
	transform[3] = 0;		transform[7] = 0;			transform[11] = 0;			transform[15] = 1.0;

	##update the fit atoms, moving them to their superposed locations:
	for i in range(0, atoms) :
		x = (fit[i][0] * U[0][0]) + (fit[i][1] * U[1][0]) + (fit[i][2] * U[2][0]) + U[3][0];
		y = (fit[i][0] * U[0][1]) + (fit[i][1] * U[1][1]) + (fit[i][2] * U[2][1]) + U[3][1];
		z = (fit[i][0] * U[0][2]) + (fit[i][1] * U[1][2]) + (fit[i][2] * U[2][2]) + U[3][2];
	
		fit.append([x, y, z]);

	##Compute the new center of mass.
	cmass_fit[0] = 0;
	cmass_fit[1] = 0;
	cmass_fit[2] = 0;
	for i in range(0, atoms):
		cmass_fit[0] += fit[i][0];
		cmass_fit[1] += fit[i][1];
		cmass_fit[2] += fit[i][2];
	cmass_fit[0] = cmass_fit[0] / atoms;
	cmass_fit[1] = cmass_fit[1] / atoms;
	cmass_fit[2] = cmass_fit[2] / atoms;

	##get the transforms
	tx = cmass_ref[0] - cmass_fit[0];
	ty = cmass_ref[1] - cmass_fit[1];
	tz = cmass_ref[2] - cmass_fit[2];

	#Catch underflow
	if( math.fabs(tx) < 0.000001 ):
		tx = 0.0;
	if( math.fabs(ty) < 0.000001 ):
		ty = 0.0;
	if( math.fabs(tz) < 0.000001 ):
		tz = 0.0;

	##translate the fit.
	for i in range(0, atoms):
		fit[i][0] += tx;
		fit[i][1] += ty;
		fit[i][2] += tz;

	# Compound this onto the translation matrix, and set the scale to 1.
	transform[12] += tx;
	transform[13] += ty;
	transform[14] += tz;
	transform[15] = 1.0;

	##Now compute the RMSD of the fit and the reference.
	#Fill the distances Array with distance data
	result = 0.0;  ##result is the RMSD.
	distances = [];
	for i in range(0, atoms):
		 distances.append( distance (ref[i], fit[i]) );
		 result = result + (distance (ref[i], fit[i])) * (distance (ref[i], fit[i]));
	result /= atoms;
	result = math.sqrt(result);

	# /////////////////////////////RMSD CLAMP/////////////////////////////////////////////////
	# ////This hack fixes retarded RMSD issues assuming that geometry that can
	# ////be very very closely aligned is actually identical.
	# ////This is to catch the situation with RMSD = 1.234e-15 gives bad
	# ////Transformations far from identity.
	# /////////////////////////////RMSD CLAMP/////////////////////////////////////////////////
#	if(result < .000001):
#	  transform[0] = 1;	transform[4] = 0;		transform[8] = 0;		transform[12] = 0;
#	  transform[1] = 0;	transform[5] = 1;		transform[9] = 0;		transform[13] = 0;
#	  transform[2] = 0;	transform[6] = 0;		transform[10] = 1;		transform[14] = 0;
#	  transform[3] = 0;	transform[7] = 0;		transform[11] = 0;		transform[15] = 1;
	# /////////////////////////////RMSD CLAMP/////////////////////////////////////////////////
	# ////The problem is that while an excellent alignment can be calculated
	# ////(i.e. with RMSD = epsilon) the transformation is not accurately 
	# ////calculated and thus is not well defined numerically.  This hopefully
	# ////resolves this issue.
	# /////////////////////////////RMSD CLAMP/////////////////////////////////////////////////


	print("generating results ");
	##Return final results
	results = [];
	##first the RMSD.
	results.append(result);
	##second the aligned points.
	results.append(fit);
	##third the transform.
	results.append(transform);
	##fourth the distances
	results.append(distances);

	print("RMSD=%0.8d" % results[0]);
	print("generating results ");

	return results;




def main():

	print("TEST");

	test_ref = [
		[85.375,2.707,83.119],
		[86.666,6.168,82.313],
		[86.801,7.401,78.710],
		[88.456,10.409,77.113]];
		
	test_tar = [[54.020,22.949,94.945],
		[52.446,21.686,91.695],
		[51.965, 23.777, 88.541],
		[50.020,23.198,85.345]];
	
	print("TEST");

	output = min_rmsd(test_ref, test_tar);

	print("TEST");

	print(output);

	print("TEST");


if __name__ == "__main__":
	main()

TEST
TEST
generating results 
RMSD=00000003
generating results 
TEST
[3.8922252070634356, [[88.73175, 6.717749999999999, 85.12725], [87.15775, 5.454749999999997, 81.87725], [86.67675, 7.545749999999998, 78.72325000000001], [84.73175, 6.9667499999999976, 75.52725000000001], [-45.72596565144282, -93.15373508341354, 41.12040529046622], [-44.426069635019566, -89.69885123573629, 40.11581933319529], [-44.266857143328515, -88.37643830563243, 36.54129371085271], [-42.60508369105185, -85.34769531932072, 34.99277605962108]], [-0.9958335631011788, -0.061266290367632426, 0.0675422554083498, 0, -0.018132621632342805, -0.5928506017224396, -0.8051082983488916, 0, 0.08936846554450995, -0.8029785835886647, 0.5892696086377996, 0, 34.711749999999995, -16.231250000000003, -9.81774999999999, 1.0], [5.602406106977612, 0.9697534158228057, 0.19122287389326825, 5.313541774325292]]
TEST


# Binding site comparison of drug-target from drugbank mapping to PDB (experimental)


In [0]:
from pyspark.sql import Row, SparkSession
from mmtfPyspark.datasets import customReportService, drugBankDataset
from mmtfPyspark.structureViewer import view_binding_site
from mmtfPyspark.utils import traverseStructureHierarchy, ColumnarStructure
from mmtfPyspark import structureViewer
from scipy.spatial.distance import pdist, squareform
from pyspark.sql.functions import concat_ws
from mmtfPyspark.datasets import g2sDataset, pdbjMineDataset, myVariantDataset
from mmtfPyspark.filters import ContainsGroup
from mmtfPyspark.interactions import InteractionFilter, InteractionFingerprinter
from mmtfPyspark.io import mmtfReader
from ipywidgets import interact, IntSlider
import py3Dmol
import numpy as np
import pandas as pd


## 
def get_binding_site_interactions(ligand_drug_id):
    
   
    pdb_input = pdb.filter(ContainsGroup((ligand_drug_id))) # 
    
    interactionFilter = InteractionFilter()
    interactionFilter.set_distance_cutoff(4.5)
    interactionFilter.set_min_interactions(1)
    interactionFilter.set_query_groups(True, [ligand_drug_id])
    
    interactions = InteractionFingerprinter.get_ligand_polymer_interactions(pdb_input, interactionFilter).cache()
    
    return interactions

def get_coords_of_binding_site(group_numbers_from_interactions, pdb_id):
    
    coords_all = []
        
    for group_number in group_numbers_from_interactions:

        pdb = mmtfReader.download_full_mmtf_files([pdb_id], sc)

        structure = pdb.values().first()

        arrays = ColumnarStructure(structure, firstModelOnly=True)

        x = arrays.get_x_coords()
        y = arrays.get_y_coords()
        z = arrays.get_z_coords()

        group_numbers = arrays.get_group_numbers()

        ca_idx = (group_numbers == group_number)

        xc = x[ca_idx]
        yc = y[ca_idx]
        zc = z[ca_idx]

        coords = np.swapaxes(np.array([xc,yc,zc]), 0, 1)

        coords_all.append(coords)

    return coords_all


def all_coords_of_binding_site(pd_ligand_id_interactions, interaction_num):
        from itertools import chain

        interaction_all = get_coords_of_binding_site(pd_ligand_id_interactions.groupNumbers[interaction_num], pd_ligand_id_interactions.structureChainId[interaction_num][0:4])

        list1 = list(chain.from_iterable(interaction_all))
        list2 = [x.tolist() for x  in list1]

        return list2


def get_all_coords_of_one_ligand_id(pd_ligand_id_interactions, ligand_drug_id):
    
    binding_site_ligand = []
    
    print('run...')
#     for j in range(0, pd_ligand_id_interactions.shape[0]): # this is for a full run.
    for j in range(0, 2):   # this is just for example.
        
        print('run j:', j)
        binding_site_ligand.append(all_coords_of_binding_site(pd_ligand_id_interactions, j))
        print('run j DONE.')
        
        
    
    return binding_site_ligand


spark = SparkSession.builder.master('local[4]').appName('drugbank-PDB-binding_site').getOrCreate()

sc = spark.sparkContext

## Get approved drugs from DrugBank
drugs = drugBankDataset.get_drug_links('approved', 'hliu312@gmail.com', 'LHPlhp123456')
drugs = drugs.filter('InChIKey IS NOT NULL').cache()
drugs.toPandas().head(5)

## Get all ligands from PDB
ligands = customReportService.get_dataset(["ligandId","InChIKey","ligandMolecularWeight"])
ligands = ligands.filter("InChIKey IS NOT NULL AND ligandMolecularWeight > 300").cache()
ligands.toPandas().head(10)

# path = "/mmtf_full_sample"
path = './resources/mmtf_full_sample'
pdb = mmtfReader.read_sequence_file(path, sc)

ligand_drug_id = "STI"

ligand_drugs.toPandas().head()

print('Getting binding site interactions:')
pd_ligand_id_interactions = get_binding_site_interactions(ligand_drug_id).toPandas()
print('done')

temp1 = get_all_coords_of_one_ligand_id(pd_ligand_id_interactions, ligand_drug_id)

# min_rmsd is from Brain's function above.
min_rmsd(temp1[0][0:min(len(temp1[0]), len(temp1[1]))], temp1[1][0:min(len(temp1[0]), len(temp1[1]))])

spark.stop()
## generating results 
# RMSD=00000039              ## this is example of the RMSD of two binding sites of STI (Gleevac) in two target proteins.
# generating results