# Extract Atom Data from MMTF Files
This notebook extracts coordinates for a given set of protein chains and residues numbers provided in an input file.

In [None]:
# Disable Numba: temporary workaround for https://github.com/sbl-sdsc/mmtf-pyspark/issues/288
import os
os.environ['NUMBA_DISABLE_JIT'] = "1" 

In [2]:
import pandas as pd
from pyspark.sql import Row, SparkSession
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.utils import MmtfSubstructure

### Load data
Load all data as strings. MMTF represents residue numbers as strings, since the may include insertion codes (e.g., 10A)

In [3]:
df = pd.read_csv("../data/sample.csv", dtype=str)

In [4]:
df.head()

Unnamed: 0.1,Unnamed: 0,pdbId,structureChainId,pdbPosition,pdbResNum,uniprotId
0,0,2gf5,2GF5.A,117,117,Q13158
1,2,3ezq,3EZQ.P,117,117,Q13158
2,4,3ezq,3EZQ.N,117,117,Q13158
3,6,3ezq,3EZQ.L,117,117,Q13158
4,8,3ezq,3EZQ.J,117,117,Q13158


Create a list of unique pdbIds (MMTF requires upper case PDB IDs)

In [5]:
df['pdbId'] = df['structureChainId'].str.split('.', expand=True)[0]
pdbIds = list(df['pdbId'].unique())

Group residues by structureChainId

In [6]:
df = df[['structureChainId', 'pdbResNum']]
chain_df = df.groupby('structureChainId').agg(lambda x: list(x))

In [7]:
chain_df.head()

Unnamed: 0_level_0,pdbResNum
structureChainId,Unnamed: 1_level_1
1A1W.A,"[63, 64, 8, 75, 73, 46, 14, 32, 66, 40]"
1A1Z.A,"[63, 64, 8, 75, 73, 46, 14, 32, 66, 40]"
2GF5.A,"[117, 113, 63, 131, 64, 158, 8, 75, 73, 138, 4..."
3EZQ.B,"[117, 113, 131, 158, 138, 181, 134, 103, 111, ..."
3EZQ.D,"[117, 113, 131, 158, 138, 181, 134, 103, 111, ..."


#### Initialize Spark

In [8]:
spark = SparkSession.builder.appName("ExtractCoords").getOrCreate()

In [9]:
# Enable Arrow-based columnar data transfers between Spark and Pandas dataframes
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

#### Print Spark configurations

Notice, `spark.driver.memory` has been set to 10G.

To set this memory option, add the following line to your .bash_profile file and restart the shell:

`export SPARK_DRIVER_MEMORY=10G`

In [10]:
spark.sql("SET").toPandas()

Unnamed: 0,key,value
0,spark.app.id,local-1592191519446
1,spark.app.name,ExtractCoords
2,spark.driver.host,192.168.1.4
3,spark.driver.memory,10G
4,spark.driver.port,56504
5,spark.executor.id,driver
6,spark.master,local[*]
7,spark.rdd.compress,True
8,spark.serializer.objectStreamReset,100
9,spark.sql.execution.arrow.enabled,true


### Read the MMTF Hadoop Sequence files
The environmental variable MMTF_FULL should be set to directory with the MMTF Hadoop Sequence files.
Make sure the MMTF Hadoop Sequence files are up to data (download instructions [here](http://mmtf.rcsb.org/download.html)).

In [11]:
if len(pdbIds) > 1000:
    # use local MMTF Hadoop Sequence files
    pdb = mmtfReader.read_full_sequence_file(pdbIds)
else:
    # for a small number of structures, it's less overhead to download them from RCSB PDB
    pdb = mmtfReader.download_full_mmtf_files(pdbIds)

### Extract atom data for specific chains and residues

In [12]:
def extract_atom_data(entry, chain_df):
    structure_id = entry[0] # pdbId
    structure = entry[1] # mmmtStructure
    
    rows = list()
    for chain_name in set(structure.chain_name_list):
        structure_chain_id = structure_id + '.' + chain_name
        
        # get list of residues numbers if structureChainId matches the current chain being processed
        group_numbers = chain_df.query(f"structureChainId == '{structure_chain_id}'")['pdbResNum']
        print(group_numbers)
        
        if len(group_numbers) == 1:
            # extract data for specific chain and list of residues(groups)
            s = MmtfSubstructure(structure, chain_name, chain_names=[chain_name], group_numbers=group_numbers)
           
            # add atom data to a row
            for i in range(s.num_atoms):
                rows.append(Row(structure_id, s.group_names[i], s.group_numbers[i], s.atom_names[i], 
                                float(s.x_coord_list[i]), float(s.y_coord_list[i]), float(s.z_coord_list[i])))

    return rows

In [13]:
rows = pdb.flatMap(lambda entry: extract_atom_data(entry, chain_df))

Create a Spark dataframe from the rows of atom data, then convert to a Pandas dataframe

In [14]:
col_names = ['structureChainId', 'groupName', 'group_number', 'atom_name', 'x', 'y', 'z']
spark_df = spark.createDataFrame(rows, col_names) 

In [15]:
results = spark_df.toPandas()
results.head(25)

Unnamed: 0,structureChainId,groupName,group_number,atom_name,x,y,z
0,1A1Z,LEU,8,N,-5.578,0.976,4.943
1,1A1Z,LEU,8,CA,-4.151,0.555,5.018
2,1A1Z,LEU,8,C,-3.778,0.229,6.469
3,1A1Z,LEU,8,O,-2.799,-0.441,6.727
4,1A1Z,LEU,8,CB,-3.259,1.689,4.508
5,1A1Z,LEU,8,CG,-3.344,1.762,2.983
6,1A1Z,LEU,8,CD1,-2.839,3.126,2.508
7,1A1Z,LEU,8,CD2,-2.478,0.658,2.371
8,1A1Z,LEU,8,H,-5.802,1.896,4.695
9,1A1Z,LEU,8,HA,-4.003,-0.321,4.404


### Always stop Spark. Multiple Spark jobs may interfere with each other and use up memory.

In [16]:
spark.stop()