# Extract Atom Data from MMTF Files
This notebook extracts atom coordinates for a given set of protein chains and residue numbers provided in an input file.

In [None]:
# Temporary workaround for https://github.com/sbl-sdsc/mmtf-pyspark/issues/288:
# disable Numba by setting its NUMBA_DISABLE_JIT environment variable.
import os

os.environ["NUMBA_DISABLE_JIT"] = "1"

In [2]:
import pandas as pd
from pyspark.sql import Row, SparkSession
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.utils import MmtfSubstructure

### Load data
Load all data as strings. MMTF represents residue numbers as strings, since they may include insertion codes (e.g., 10A).

In [3]:
# dtype=str keeps every column, including the residue numbers, as text
df = pd.read_csv("../data/sample.csv", dtype=str)

In [4]:
# Preview the first rows of the input table
df.head()

Unnamed: 0.1,Unnamed: 0,pdbId,structureChainId,pdbPosition,pdbResNum,uniprotId
0,0,2gf5,2GF5.A,117,117,Q13158
1,2,3ezq,3EZQ.P,117,117,Q13158
2,4,3ezq,3EZQ.N,117,117,Q13158
3,6,3ezq,3EZQ.L,117,117,Q13158
4,8,3ezq,3EZQ.J,117,117,Q13158


Create a list of unique pdbIds (MMTF requires upper case PDB IDs)

In [5]:
# The PDB ID is the part of structureChainId before the '.' (e.g. '2GF5.A' -> '2GF5');
# it replaces the pdbId column that was read from the input file.
df['pdbId'] = df['structureChainId'].str.split('.').str[0]
pdbIds = df['pdbId'].unique().tolist()

In [6]:
# Number of distinct PDB IDs to process
len(pdbIds)

5

Group residues by structureChainId

In [7]:
# Keep only the chain identifier and residue number, then collect the
# residue numbers of each chain into a single list per structureChainId.
df = df[['structureChainId', 'pdbResNum']]
chain_df = df.groupby('structureChainId').agg(list)

In [8]:
# Show up to 100 chains together with their grouped residue-number lists
chain_df.head(100)

Unnamed: 0_level_0,pdbResNum
structureChainId,Unnamed: 1_level_1
1A1W.A,"[63, 64, 8, 75, 73, 46, 14, 32, 66, 40]"
1A1Z.A,"[63, 64, 8, 75, 73, 46, 14, 32, 66, 40]"
2GF5.A,"[117, 113, 63, 131, 64, 158, 8, 75, 73, 138, 4..."
3EZQ.B,"[117, 113, 131, 158, 138, 181, 134, 103, 111, ..."
3EZQ.D,"[117, 113, 131, 158, 138, 181, 134, 103, 111, ..."
3EZQ.F,"[117, 113, 131, 158, 138, 181, 134, 103, 111, ..."
3EZQ.H,"[117, 113, 131, 158, 138, 181, 134, 103, 111, ..."
3EZQ.J,"[117, 113, 131, 158, 138, 181, 134, 103, 111, ..."
3EZQ.L,"[117, 113, 131, 158, 138, 181, 134, 103, 111, ..."
3EZQ.N,"[117, 113, 131, 158, 138, 181, 134, 103, 111, ..."


#### Initialize Spark

In [9]:
# Create the Spark session for this notebook, or reuse one that is already running
spark = (
    SparkSession.builder
    .appName("ExtractCoords")
    .getOrCreate()
)

In [10]:
# Enable Arrow-based columnar data transfers between Spark and Pandas dataframes
# NOTE(review): newer Spark releases appear to name this option
# "spark.sql.execution.arrow.pyspark.enabled" — confirm against the installed Spark version.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

#### Print Spark configurations

Notice, `spark.driver.memory` has been set to 10G.

To set this memory option, add the following line to your .bash_profile file and restart the shell:

`export SPARK_DRIVER_MEMORY=10G`

In [11]:
# List the Spark SQL configuration properties (key/value) as a Pandas dataframe
spark.sql("SET").toPandas()

Unnamed: 0,key,value
0,spark.app.id,local-1593129272634
1,spark.app.name,ExtractCoords
2,spark.driver.host,192.168.1.4
3,spark.driver.memory,10G
4,spark.driver.port,51439
5,spark.executor.id,driver
6,spark.master,local[*]
7,spark.rdd.compress,True
8,spark.serializer.objectStreamReset,100
9,spark.sql.execution.arrow.enabled,true


### Read the MMTF Hadoop Sequence files
The environment variable `MMTF_FULL` should be set to the directory containing the MMTF Hadoop Sequence files.
Make sure the MMTF Hadoop Sequence files are up to date (download instructions [here](http://mmtf.rcsb.org/download.html)).

In [12]:
# Read the full set of MMTF Hadoop Sequence files (no PDB ID filter).
# Note: `pdb` is reassigned by the following cell, which restricts it to `pdbIds`.
pdb = mmtfReader.read_full_sequence_file()

Hadoop Sequence file path: MMTF_FULL=/Users/peter/MMTF_Files/full


In [13]:
# More than 1000 structures: use the local MMTF Hadoop Sequence files.
# Otherwise it is less overhead to download the structures from RCSB PDB.
use_local_files = len(pdbIds) > 1000

pdb = (
    mmtfReader.read_full_sequence_file(pdbIds)
    if use_local_files
    else mmtfReader.download_full_mmtf_files(pdbIds)
)

### Extract atom data for specific chains and residues

In [14]:
def extract_atom_data(entry, chain_df):
    """Collect atom rows for the requested chains and residues of one structure.

    Parameters
    ----------
    entry : tuple
        ``(structure_id, structure)`` pair: the PDB ID string and the MMTF
        structure, which must expose ``chain_name_list``.
    chain_df : pandas.DataFrame
        Lookup table addressable by ``structureChainId`` (e.g. ``'3EZQ.L'``)
        with a ``pdbResNum`` column holding the residue numbers requested
        for that chain.

    Returns
    -------
    list
        One ``Row`` per selected atom with the fields (structureChainId,
        group name, group number, atom name, x, y, z). The list is empty when
        none of the structure's chains is listed in ``chain_df``.
    """
    structure_id = entry[0]  # pdbId
    structure = entry[1]  # MMTF structure

    rows = []
    # set() makes sure every distinct chain name is processed only once
    for chain_name in set(structure.chain_name_list):
        structure_chain_id = structure_id + '.' + chain_name

        # residue numbers requested for this chain (empty if the chain is not listed)
        group_numbers = chain_df.query(f"structureChainId == '{structure_chain_id}'")['pdbResNum']

        # only chains with exactly one matching entry in chain_df are extracted
        if len(group_numbers) != 1:
            continue

        # extract data for the specific chain and list of residues (groups)
        # NOTE(review): group_numbers is a one-element Series whose value is the
        # residue list; it is passed through unchanged — confirm that
        # MmtfSubstructure accepts this form.
        s = MmtfSubstructure(structure, chain_name, chain_names=[chain_name], group_numbers=group_numbers)

        # one row per atom; coordinates are converted to plain Python floats
        rows.extend(
            Row(structure_chain_id, s.group_names[i], s.group_numbers[i], s.atom_names[i],
                float(s.x_coord_list[i]), float(s.y_coord_list[i]), float(s.z_coord_list[i]))
            for i in range(s.num_atoms)
        )

    return rows

In [15]:
# Apply extract_atom_data to every (pdbId, structure) entry and flatten the
# per-structure lists into a single collection of atom rows.
rows = pdb.flatMap(lambda entry: extract_atom_data(entry, chain_df))

Create a Spark dataframe from the rows of atom data, then convert to a Pandas dataframe

In [16]:
# Column labels, in the same order as the fields of each Row built by extract_atom_data
col_names = [
    'structureChainId',
    'groupName',
    'group_number',
    'atom_name',
    'x',
    'y',
    'z',
]
spark_df = spark.createDataFrame(rows, schema=col_names)

In [17]:
# Convert the Spark dataframe to Pandas, save it as CSV (without the index column),
# and preview the first 25 atom rows.
results = spark_df.toPandas()
results.to_csv("../data/All_PDB_info_xyz.csv", index=False)
results.head(25)

Unnamed: 0,structureChainId,groupName,group_number,atom_name,x,y,z
0,3EZQ.L,VAL,103,N,-99.361,81.342003,-83.228996
1,3EZQ.L,VAL,103,CA,-100.249001,80.994003,-82.113998
2,3EZQ.L,VAL,103,C,-100.664001,82.253998,-81.364998
3,3EZQ.L,VAL,103,O,-101.841003,82.453003,-81.077003
4,3EZQ.L,VAL,103,CB,-99.574997,79.969002,-81.156998
5,3EZQ.L,VAL,103,CG1,-100.477997,79.644997,-79.966003
6,3EZQ.L,VAL,103,CG2,-99.183998,78.676003,-81.931999
7,3EZQ.L,ASP,111,N,-105.887001,92.769997,-75.222
8,3EZQ.L,ASP,111,CA,-104.885002,93.328003,-74.305
9,3EZQ.L,ASP,111,C,-103.462997,93.223,-74.882004


### Always stop Spark. Multiple Spark jobs may interfere with each other and use up memory.

In [18]:
# Stop the Spark session used by this notebook
spark.stop()