# Machine learning - Protein Chain Classification

This demo classifies protein chains as either all-alpha or all-beta based on their sequence alone. Each sequence is split into n-grams, and a Word2Vec representation of those n-grams serves as the feature vector.

[Word2Vec model](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#word2vec)

[Word2Vec example](https://spark.apache.org/docs/latest/ml-features.html#word2vec)

## Imports

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webfilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.datasets import secondaryStructureExtractor
from mmtfPyspark.ml import ProteinSequenceEncoder, SparkMultiClassClassifier, datasetBalancer
from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, MultilayerPerceptronClassifier, RandomForestClassifier

## Configure Spark

In [2]:
spark = SparkSession.builder.appName("ProteinChainClassification").getOrCreate()

## Read MMTF file and create a non-redundant set (<=40% seq. identity) of L-protein chains

In [3]:
pdb = mmtfReader.read_sequence_file('../../resources/mmtf_reduced_sample/') \
                .flatMap(StructureToPolymerChains()) \
                .filter(Pisces(sequenceIdentity=40,resolution=3.0))

## Get secondary structure content

In [4]:
data = secondaryStructureExtractor.get_dataset(pdb)

## Define add_protein_fold_type function

In [5]:
def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''
    Adds a column "foldType" with one of four secondary structure classes:
    "alpha", "beta", "alpha+beta", or "other", based on the fraction of alpha/beta content.

    The simplified syntax used in this method relies on two imports:
        from pyspark.sql.functions import when
        from pyspark.sql.functions import col

    Args:
        data (DataFrame): input dataset with alpha, beta composition
        minThreshold (float): a secondary structure type whose fraction is below this threshold is treated as absent
        maxThreshold (float): a secondary structure type whose fraction is above this threshold is treated as present

    Returns:
        DataFrame: the input dataset with an added "foldType" column
    '''

    return data.withColumn(
        "foldType",
        when((col("alpha") > maxThreshold) & (col("beta") < minThreshold), "alpha")
        .when((col("beta") > maxThreshold) & (col("alpha") < minThreshold), "beta")
        .when((col("alpha") > maxThreshold) & (col("beta") > maxThreshold), "alpha+beta")
        .otherwise("other"),
    )

## Classify chains by secondary structure type

In [6]:
data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)
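
To make the threshold logic concrete, here is a minimal plain-Python sketch of the same rules applied to a single chain. The function name `fold_type` is hypothetical and not part of mmtfPyspark; the alpha and beta fractions are copied from chains that appear in this notebook's output tables.

```python
def fold_type(alpha, beta, min_threshold=0.05, max_threshold=0.15):
    """Mirror the when/otherwise chain of add_protein_fold_type for one chain."""
    if alpha > max_threshold and beta < min_threshold:
        return "alpha"
    if beta > max_threshold and alpha < min_threshold:
        return "beta"
    if alpha > max_threshold and beta > max_threshold:
        return "alpha+beta"
    return "other"


# (alpha, beta) fractions copied from chains in this notebook's output
print(fold_type(0.827451, 0.0))       # 1H6G.A
print(fold_type(0.039216, 0.503268))  # 1GWM.A
print(fold_type(0.170819, 0.263345))  # 4WMY.B
print(fold_type(0.456954, 0.119205))  # 4WP6.A
```

A chain such as 4WP6.A, with about 46% helix but 12% strand, ends up in "other" because its beta content lies between the two thresholds.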

## Create a Word2Vec representation of the protein sequences

**n = 2**     # create 2-grams 

**windowSize = 25**    # window size of 25 amino acid residues for Word2Vec

**vectorSize = 50**    # dimension of feature vector
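
As an illustration of the n-gram step, overlapping 2-grams can be produced by sliding a window of two residues along the sequence one position at a time. This is a plain-Python sketch, not the encoder's actual implementation, and `overlapping_ngrams` is a hypothetical helper name.

```python
def overlapping_ngrams(sequence, n=2):
    """Return all overlapping n-grams of a sequence, shifting by one residue."""
    return [sequence[i:i + n] for i in range(len(sequence) - n + 1)]


# First ten residues of chain 4WMY.B
print(overlapping_ngrams("TDWSHPQFEK"))
# ['TD', 'DW', 'WS', 'SH', 'HP', 'PQ', 'QF', 'FE', 'EK']
```

Each 2-gram then plays the role of a "word" when the Word2Vec model is trained.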

In [7]:
encoder = ProteinSequenceEncoder(data)
data = encoder.overlapping_ngram_word2vec_encode(n=2, windowSize=25, vectorSize=50).cache()

data.toPandas().head(5)

| | structureChainId | sequence | alpha | beta | coil | dsspQ8Code | dsspQ3Code | foldType | ngram | features |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 4WMY.B | TDWSHPQFEKSTDEANTYFKEWTCSSSPSLPRSCKEIKDECPSAFD... | 0.170819 | 0.263345 | 0.565836 | XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCSSHHHHHHHCTTCCS... | XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCCCHHHHHHHCCCCCC... | alpha+beta | [TD, DW, WS, SH, HP, PQ, QF, FE, EK, KS, ST, T... | [-0.2564170308716473, -0.1645781719721243, -0.... |
| 1 | 4WN5.A | GSHMGRGAFLSRHSLDMKFTYCDDRIAEVAGYSPDDLIGCSAYEYI... | 0.296296 | 0.37963 | 0.324074 | XXCCCCCCEEEEECTTCBEEEECGGHHHHHSCCHHHHBTSBGGGGB... | XXCCCCCCEEEEECCCCEEEEECHHHHHHHCCCHHHHECCEHHHHE... | alpha+beta | [GS, SH, HM, MG, GR, RG, GA, AF, FL, LS, SR, R... | [-0.06936275625699445, -0.047815551295092226, ... |
| 2 | 4WND.B | GPLGSDLPPKVVPSKQLLHSDHMEMEPETMETKSVTDYFSKLHMGS... | 0.115385 | 0.0 | 0.884615 | XXXXXXXXXXXXXXXCCCCCCCCCCCCCCCCCCCGGGTTCCXXXXX... | XXXXXXXXXXXXXXXCCCCCCCCCCCCCCCCCCCHHHCCCCXXXXX... | other | [GP, PL, LG, GS, SD, DL, LP, PP, PK, KV, VV, V... | [-0.20551030337810516, -0.009358869435695503, ... |
| 3 | 4WP6.A | GSHHHHHHSQDPMQAAQETKQKLTSCLARRYNAEQKLLDLSALGTD... | 0.456954 | 0.119205 | 0.423841 | XXXXXXXXXXXXXXXXXXCHHHHHHHHHHHEETTTTEEECTTGGGC... | XXXXXXXXXXXXXXXXXXCHHHHHHHHHHHEECCCCEEECCCHHHC... | other | [GS, SH, HH, HH, HH, HH, HH, HS, SQ, QD, DP, P... | [-0.09747267771348522, -0.11870341383657614, -... |
| 4 | 4WP9.A | FQGAMGSRVVILFTDIEESTALNERIGDRAWVKLISSHDKLVSDLV... | 0.393939 | 0.315152 | 0.290909 | XXCCSSEEEEEEEEEETTHHHHHHHHCHHHHHHHHHHHHHHHHHHH... | XXCCCCEEEEEEEEEECCHHHHHHHHCHHHHHHHHHHHHHHHHHHH... | alpha+beta | [FQ, QG, GA, AM, MG, GS, SR, RV, VV, VI, IL, L... | [0.0453769032293084, -0.09423059372189031, -0.... |


## Keep only a subset of relevant fields for further processing

In [8]:
data = data.select(['structureChainId','alpha','beta','coil','foldType','features'])

## Select only chains with alpha or beta foldType

In [9]:
data = data.where((data.foldType == 'alpha') | (data.foldType == 'beta'))

print(f"Total number of data: {data.count()}")
data.toPandas().head()

Total number of data: 2341


| | structureChainId | alpha | beta | coil | foldType | features |
|---|---|---|---|---|---|---|
| 0 | 1GWM.A | 0.039216 | 0.503268 | 0.457516 | beta | [-0.0784562407429085, -0.23499373910262394, -0... |
| 1 | 1GXR.A | 0.0 | 0.543284 | 0.456716 | beta | [-0.19994124448082098, -0.12346233704149545, -... |
| 2 | 1H2K.S | 0.333333 | 0.0 | 0.666667 | alpha | [-0.09425241947174073, -0.1447215843014419, -0... |
| 3 | 1H32.B | 0.318519 | 0.037037 | 0.644444 | alpha | [-0.3225205926510104, -0.07207524626223492, -0... |
| 4 | 1H6G.A | 0.827451 | 0.0 | 0.172549 | alpha | [-0.10379326214977339, -0.13861494888277615, -... |


## Basic dataset information and settings

In [10]:
label = 'foldType'
testFraction = 0.1
seed = 123

# The length of the first feature vector gives the feature dimension
vector = data.first()["features"]
featureCount = len(vector)
print(f"Feature count    : {featureCount}")

# Number of distinct fold types to classify
classCount = int(data.select(label).distinct().count())
print(f"Class count    : {classCount}")

print(f"Dataset size (unbalanced)    : {data.count()}")

data.groupby(label).count().show(classCount)

# Downsample to balance the class sizes
data = datasetBalancer.downsample(data, label, 1)
print(f"Dataset size (balanced)  : {data.count()}")

data.groupby(label).count().show(classCount)

Feature count    : 50
Class count    : 2
Dataset size (unbalanced)    : 2341
+--------+-----+
|foldType|count|
+--------+-----+
|    beta|  613|
|   alpha| 1728|
+--------+-----+

Dataset size (balanced)  : 1230
+--------+-----+
|foldType|count|
+--------+-----+
|    beta|  613|
|   alpha|  617|
+--------+-----+
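
The idea behind downsampling can be sketched in plain Python: every class is randomly reduced to the size of the smallest class. This is a hypothetical illustration, not the `datasetBalancer` implementation. The Spark output above ends up with 613 and 617 chains, which is consistent with fraction-based sampling, whereas this sketch draws exact counts.

```python
import random
from collections import defaultdict


def downsample(rows, label_key, seed=1):
    """Randomly reduce every class to the size of the smallest class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for row in rows:
        by_class[row[label_key]].append(row)
    target = min(len(members) for members in by_class.values())
    balanced = []
    for members in by_class.values():
        balanced.extend(rng.sample(members, target))
    return balanced


# Toy rows with the same class sizes as the unbalanced dataset above
rows = [{"foldType": "alpha"}] * 1728 + [{"foldType": "beta"}] * 613
balanced = downsample(rows, "foldType")
print(len(balanced))  # 1226, i.e. 613 per class
```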



## Decision Tree Classifier

In [11]:
dtc = DecisionTreeClassifier()
mcc = SparkMultiClassClassifier(dtc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")


 Class	Train	Test
alpha	555	62
beta	551	62

Sample predictions: DecisionTreeClassifier
+----------------+-----------+----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
|structureChainId|      alpha|      beta|      coil|foldType|            features|indexedLabel|rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
|          2CO3.A| 0.02962963|0.54814816|0.42222223|    beta|[-0.3705799010691...|         1.0| [11.0,288.0]|[0.03678929765886...|       1.0|          beta|
|          4CE8.C|        0.0| 0.7105263|0.28947368|    beta|[-0.3512168330824...|         1.0| [11.0,288.0]|[0.03678929765886...|       1.0|          beta|
|          4KU0.D|        0.0|0.39583334| 0.6041667|    beta|[-0.3439378420381...|         1.0| [11.0,288.0]|[0.03678929765886.
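
The `rawPrediction` and `probability` columns are closely related for a decision tree. In Spark ML, `rawPrediction` holds the per-class training counts of the leaf that a chain falls into, and `probability` is that vector normalized to sum to one. A small sketch using the leaf counts from the rows above (the helper name `normalize` is hypothetical):

```python
def normalize(raw_prediction):
    """Turn per-class counts into probabilities that sum to one."""
    total = sum(raw_prediction)
    return [count / total for count in raw_prediction]


# Leaf counts from the sample predictions: [11.0, 288.0]
probability = normalize([11.0, 288.0])
print(probability)
```

The first entry, 11/299, is approximately 0.0368, which agrees with the truncated value in the `probability` column. Since `indexedLabel` 1.0 corresponds to beta here, the second entry is the probability of the beta class.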

## Random Forest Classifier

In [12]:
rfc = RandomForestClassifier()
mcc = SparkMultiClassClassifier(rfc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")


 Class	Train	Test
alpha	555	62
beta	551	62

Sample predictions: RandomForestClassifier
+----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|      beta|      coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          2CO3.A| 0.02962963|0.54814816|0.42222223|    beta|[-0.3705799010691...|         1.0|[0.89778332956146...|[0.04488916647807...|       1.0|          beta|
|          4CE8.C|        0.0| 0.7105263|0.28947368|    beta|[-0.3512168330824...|         1.0|[1.01261707581008...|[0.05063085379050...|       1.0|          beta|
|          4KU0.D|        0.0|0.39583334| 0.6041667|    beta|[-0.3439378420381...|         1
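
For a random forest, Spark ML sums the per-tree class probabilities into `rawPrediction` and normalizes that sum to obtain `probability`. With the default of 20 trees the raw values add up to 20, so dividing the first row's 0.8977... by 20 gives the 0.0448... shown above. The sketch below illustrates the averaging with made-up probabilities for three trees; it is not Spark's implementation, and `forest_probability` is a hypothetical name.

```python
def forest_probability(tree_probabilities):
    """Sum per-tree class probabilities, then normalize the sum."""
    n_classes = len(tree_probabilities[0])
    raw = [sum(tree[c] for tree in tree_probabilities) for c in range(n_classes)]
    total = sum(raw)
    return raw, [value / total for value in raw]


# Hypothetical [alpha, beta] leaf probabilities from three trees
trees = [[0.1, 0.9], [0.0, 1.0], [0.2, 0.8]]
raw, probability = forest_probability(trees)
print(raw, probability)
```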

## Logistic Regression Classifier

In [13]:
lr = LogisticRegression()
mcc = SparkMultiClassClassifier(lr, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")


 Class	Train	Test
alpha	555	62
beta	551	62

Sample predictions: LogisticRegression
+----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|      beta|      coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          2CO3.A| 0.02962963|0.54814816|0.42222223|    beta|[-0.3705799010691...|         1.0|[-6.3102939898716...|[0.00181420152525...|       1.0|          beta|
|          4CE8.C|        0.0| 0.7105263|0.28947368|    beta|[-0.3512168330824...|         1.0|[-6.9051730622340...|[0.00100158138251...|       1.0|          beta|
|          4KU0.D|        0.0|0.39583334| 0.6041667|    beta|[-0.3439378420381...|         1.0|[
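
For binary logistic regression, the first entry of `probability` is the logistic (sigmoid) function applied to the first entry of `rawPrediction`. The sketch below checks this against the truncated values of the first two sample rows, so agreement is only expected to the digits shown.

```python
import math


def sigmoid(x):
    """Logistic function mapping a real-valued margin to the interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# First rawPrediction component of the first two sample rows (truncated in the output)
for raw in (-6.3102939898716, -6.9051730622340):
    print(raw, sigmoid(raw))
```

Both results are close to zero, meaning the model assigns almost all probability to class 1.0 (beta) for these chains.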

## Simple Multilayer Perceptron Classifier

In [14]:
layers = [featureCount, 64, 64, classCount]
mpc = MultilayerPerceptronClassifier().setLayers(layers) \
                                          .setBlockSize(128) \
                                          .setSeed(1234) \
                                          .setMaxIter(100)
mcc = SparkMultiClassClassifier(mpc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")


 Class	Train	Test
alpha	555	62
beta	551	62

Sample predictions: MultilayerPerceptronClassifier
+----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|      beta|      coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          2CO3.A| 0.02962963|0.54814816|0.42222223|    beta|[-0.3705799010691...|         1.0|[-2.6646716073380...|[0.00946286442072...|       1.0|          beta|
|          4CE8.C|        0.0| 0.7105263|0.28947368|    beta|[-0.3512168330824...|         1.0|[-2.7568692500591...|[0.00783295954254...|       1.0|          beta|
|          4KU0.D|        0.0|0.39583334| 0.6041667|    beta|[-0.3439378420381...|  
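
The `layers` list used above defines a fully connected network with 50 inputs, two hidden layers of 64 units each, and 2 outputs. Assuming one weight per input-output pair plus one bias per output unit in each layer (the usual convention for such networks), the number of trainable parameters can be estimated with a short sketch; `mlp_parameter_count` is a hypothetical helper.

```python
def mlp_parameter_count(layers):
    """Count weights and biases of a fully connected feed-forward network."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))


# featureCount = 50 and classCount = 2, as printed earlier in this notebook
layers = [50, 64, 64, 2]
print(mlp_parameter_count(layers))  # 7554
```

With roughly 1,100 training chains, this model has several times more parameters than training examples, which is worth keeping in mind when comparing it with the simpler classifiers.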

## Terminate Spark

In [15]:
spark.stop()