# Machine learning - Protein Chain Classification

This demo classifies a protein chain as either an all-alpha or an all-beta protein from its sequence alone. The feature vector is a Word2Vec representation of the sequence's overlapping n-grams.

[Word2Vec model](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#word2vec)

[Word2Vec example](https://spark.apache.org/docs/latest/ml-features.html#word2vec)

## Imports

In [1]:
from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webfilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.datasets import secondaryStructureExtractor
from mmtfPyspark.ml import ProteinSequenceEncoder, SparkMultiClassClassifier, datasetBalancer
from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, MultilayerPerceptronClassifier, RandomForestClassifier

## Configure Spark

In [2]:
spark = SparkSession.builder.master("local[*]").appName("ProteinChainClassification").getOrCreate()

## Read MMTF file and create a non-redundant set (<=40% seq. identity) of L-protein chains

In [3]:
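# Split each structure into its polymer chains.
# Note: the Pisces filter below is commented out, so no sequence-identity
# or resolution filtering is applied in this run.
pdb = mmtfReader.read_sequence_file('../../resources/mmtf_reduced_sample/') \
                .flatMap(StructureToPolymerChains()) #\
#                 .filter(Pisces(sequenceIdentity=40,resolution=3.0))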

## Get secondary structure content

In [4]:
data = secondaryStructureExtractor.get_dataset(pdb)

## Define add_protein_fold_type function

In [5]:
def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''
    Adds a column "foldType" with one of four secondary structure classes:
    "alpha", "beta", "alpha+beta", or "other", based on the fraction of alpha/beta content.

    The simplified syntax used in this method relies on two imports:
        from pyspark.sql.functions import when
        from pyspark.sql.functions import col

    Args:
        data (DataFrame): input dataset with alpha, beta composition
        minThreshold (float): below this fraction, a secondary structure type is ignored
        maxThreshold (float): above this fraction, a secondary structure type is treated as a major component
    '''

    return data.withColumn("foldType", \
                           when((col("alpha") > maxThreshold) & (col("beta") < minThreshold), "alpha"). \
                           when((col("beta") > maxThreshold) & (col("alpha") < minThreshold), "beta"). \
                           when((col("alpha") > maxThreshold) & (col("beta") > minThreshold), "alpha+beta"). \
                           otherwise("other")\
                           )

## Classify chains by secondary structure type

In [6]:
data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)
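
To make the threshold rule concrete, here is a minimal plain-Python sketch of the same decision logic, applied to alpha/beta fractions of chains that appear in this notebook's outputs. The helper name `fold_type` is hypothetical and not part of mmtfPyspark; it only mirrors the `when`/`otherwise` chain in `add_protein_fold_type`, where the first matching rule wins.

```python
def fold_type(alpha, beta, min_threshold=0.05, max_threshold=0.15):
    """Hypothetical helper mirroring the when/otherwise chain above."""
    if alpha > max_threshold and beta < min_threshold:
        return "alpha"
    if beta > max_threshold and alpha < min_threshold:
        return "beta"
    if alpha > max_threshold and beta > min_threshold:
        return "alpha+beta"
    return "other"

# (alpha, beta) fractions of chains shown in this notebook's outputs
print(fold_type(0.857143, 0.0))       # 1LGH.G
print(fold_type(0.022901, 0.51145))   # 1LMI.A
print(fold_type(0.361502, 0.107981))  # 1LBU.A
```

Note that a chain with beta above `maxThreshold` but alpha between the two thresholds (for example alpha = 0.10, beta = 0.50) falls through to "other", because the third rule only tests alpha against `maxThreshold`.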

## Create a Word2Vec representation of the protein sequences

**n = 2**     # create 2-grams 

**windowSize = 25**    # 25-residue window size for Word2Vec

**vectorSize = 50**    # dimension of feature vector

In [7]:
encoder = ProteinSequenceEncoder(data)
data = encoder.overlapping_ngram_word2vec_encode(n=2, windowSize=25, vectorSize=50).cache()

data.toPandas().head(5)

| | structureChainId | sequence | alpha | beta | coil | dsspQ8Code | dsspQ3Code | foldType | ngram | features |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1LBU.A | DGCYTWSGTLSEGSSGEAVRQLQIRVAGYPGTGAQLAIDGQFGPAT... | 0.361502 | 0.107981 | 0.530516 | CCSCCCCSCBCTTCBSHHHHHHHHHTTTCSCTTCCCCCSSBCCHHH... | CCCCCCCCCECCCCECHHHHHHHHHCCCCCCCCCCCCCCCECCHHH... | alpha+beta | [DG, GC, CY, YT, TW, WS, SG, GT, TL, LS, SE, E... | [0.004772469902165093, -0.26769595273401375, 0... |
| 1 | 1LC0.A | MITNSGKFGVVVVGVGRAGSVRLRDLKDPRSAAFLNLIGFVSRREL... | 0.410345 | 0.275862 | 0.313793 | CCCCCCSEEEEEECCSHHHHHHHHHHTSHHHHTTEEEEEEECSSCC... | CCCCCCCEEEEEECCCHHHHHHHHHHCCHHHHCCEEEEEEECCCCC... | alpha+beta | [MI, IT, TN, NS, SG, GK, KF, FG, GV, VV, VV, V... | [-0.07511192831010542, -0.3225324001646388, -0... |
| 2 | 1LC5.A | MALFNTAHGGNIREPATVLGISPDQLLDFSANINPLGMPVSVKRAL... | 0.428169 | 0.157746 | 0.414084 | XXCCCCSSSCCCHHHHHHHTSCGGGSEECSSCCCTTCCCHHHHHHH... | XXCCCCCCCCCCHHHHHHHCCCHHHCEECCCCCCCCCCCHHHHHHH... | alpha+beta | [MA, AL, LF, FN, NT, TA, AH, HG, GG, GN, NI, I... | [-0.15778549242323425, -0.22345690293745563, -... |
| 3 | 1LFP.A | MAGHSHWAQIKHKKAKVDAQRGKLFSKLIREIIVATRLGGPNPEFN... | 0.427984 | 0.234568 | 0.337449 | XXXXCCSCCSSSSSSCTTTSHHHHHHHHHHHHHHHHHHHCSCGGGC... | XXXXCCCCCCCCCCCCCCCCHHHHHHHHHHHHHHHHHHHCCCHHHC... | alpha+beta | [MA, AG, GH, HS, SH, HW, WA, AQ, QI, IK, KH, H... | [-0.1296106138221559, -0.26989139161343056, 0.... |
| 4 | 1LFW.A | MDLNFKELAEAKKDAILKDLEELIAIDSSEDLENATEEYPVGKGPV... | 0.32265 | 0.273504 | 0.403846 | CCCCHHHHHHTTHHHHHHHHHHHHTSCCBCCGGGCCSSSTTCHHHH... | CCCCHHHHHHCCHHHHHHHHHHHHCCCCECCHHHCCCCCCCCHHHH... | alpha+beta | [MD, DL, LN, NF, FK, KE, EL, LA, AE, EA, AK, K... | [-0.20264662569885186, -0.24491193602834618, 0... |
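
The `ngram` column in the output above can be reproduced by sliding a window of length n over the sequence one residue at a time. The following is a minimal plain-Python sketch of that idea; `overlapping_ngrams` is a hypothetical helper for illustration and is not the mmtfPyspark implementation.

```python
def overlapping_ngrams(sequence, n=2):
    """Return all overlapping n-grams of a sequence, in order."""
    return [sequence[i:i + n] for i in range(len(sequence) - n + 1)]

# First twelve residues of chain 1LBU.A from the output above
print(overlapping_ngrams("DGCYTWSGTLSE", n=2))
```

A sequence of length L yields L - n + 1 overlapping n-grams, so every residue except those at the ends contributes to n different "words" that Word2Vec then embeds.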


## Keep only a subset of relevant fields for further processing

In [8]:
data = data.select(['structureChainId','alpha','beta','coil','foldType','features'])

## Select only chains with alpha or beta foldType

In [9]:
data = data.where((data.foldType == 'alpha') | (data.foldType == 'beta')) #| (data.foldType == 'other'))

print(f"Total number of data: {data.count()}")
data.toPandas().head()

Total number of data: 2398


| | structureChainId | alpha | beta | coil | foldType | features |
|---|---|---|---|---|---|---|
| 0 | 1LGH.G | 0.857143 | 0.0 | 0.142857 | alpha | [-0.1801832006736235, -0.1530499759045514, 0.0... |
| 1 | 1LGH.H | 0.860465 | 0.0 | 0.139535 | alpha | [-0.13151330687105656, -0.35479396708648314, -... |
| 2 | 1LKI.A | 0.674419 | 0.023256 | 0.302326 | alpha | [-0.06014170982567958, -0.06312079258346358, 0... |
| 3 | 1LMI.A | 0.022901 | 0.51145 | 0.465649 | beta | [-0.016958576053954087, -0.2168502055681669, 0... |
| 4 | 1M0K.A | 0.765766 | 0.045045 | 0.189189 | alpha | [-0.15340718354358984, -0.3088978866750367, 0.... |


## Basic dataset information and settings

In [10]:
label = 'foldType'
testFraction = 0.1
seed = 123

vector = data.first()["features"]
featureCount = len(vector)
print(f"Feature count    : {featureCount}")
    
classCount = int(data.select(label).distinct().count())
print(f"Class count    : {classCount}")

print(f"Dataset size (unbalanced)    : {data.count()}")
    
data.groupby(label).count().show(classCount)
data = datasetBalancer.downsample(data, label, 1)
print(f"Dataset size (balanced)  : {data.count()}")
    
data.groupby(label).count().show(classCount)

Feature count    : 50
Class count    : 2
Dataset size (unbalanced)    : 2398
+--------+-----+
|foldType|count|
+--------+-----+
|    beta|  693|
|   alpha| 1705|
+--------+-----+

Dataset size (balanced)  : 1397
+--------+-----+
|foldType|count|
+--------+-----+
|    beta|  693|
|   alpha|  704|
+--------+-----+
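
The balanced class sizes are close but not identical (693 vs. 704), which is what one would expect from fraction-based random sampling. The sketch below shows only the arithmetic for per-class sampling fractions. It is an assumption about how such a balancer could work, not a description of the internals of `datasetBalancer.downsample`, and the name `downsample_fractions` is hypothetical.

```python
def downsample_fractions(class_counts):
    """Fraction of each class to keep so that every class is sampled
    down to roughly the size of the smallest class."""
    smallest = min(class_counts.values())
    return {label: smallest / count for label, count in class_counts.items()}

counts = {"beta": 693, "alpha": 1705}  # unbalanced counts from the output above
fractions = downsample_fractions(counts)
for label, fraction in fractions.items():
    print(f"{label}\tkeep fraction {fraction:.3f}")
```

In PySpark, fractions like these could be passed to `DataFrame.sampleBy(col, fractions, seed)`, which samples rows independently per stratum, so the resulting counts are approximate rather than exact.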



## Decision Tree Classifier

In [11]:
dtc = DecisionTreeClassifier()
mcc = SparkMultiClassClassifier(dtc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")


 Class	Train	Test
alpha	638	66
beta	620	73

Sample predictions: DecisionTreeClassifier
+----------------+-----------+-----------+-----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|       coil|foldType|            features|indexedLabel|rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+-----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
|          4GXB.B|        0.0| 0.53333336| 0.46666667|    beta|[-0.0819993153214...|         1.0|  [3.0,183.0]|[0.01612903225806...|       1.0|          beta|
|          3WG3.B|0.025806451|  0.5225806|  0.4516129|    beta|[-0.0410713091661...|         1.0|  [3.0,183.0]|[0.01612903225806...|       1.0|          beta|
|          3D33.A|0.031914894|  0.5851064| 0.38297874|    beta|[-0.1077686789680...|         1.0|    [7.0,6.0]|[0.538
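
In the sample predictions above, the `probability` values are consistent with normalizing `rawPrediction` so that it sums to 1: for a Spark decision tree, the raw prediction holds the per-class training counts of the leaf a chain falls into. A minimal sketch of that normalization, using the two raw vectors visible in the output (`leaf_probabilities` is a hypothetical helper name):

```python
def leaf_probabilities(raw_prediction):
    """Normalize per-class counts at a leaf into probabilities."""
    total = sum(raw_prediction)
    return [count / total for count in raw_prediction]

# rawPrediction vectors taken from the sample predictions above
for raw in ([3.0, 183.0], [7.0, 6.0]):
    print(raw, "->", leaf_probabilities(raw))
```

The leaf with counts [7.0, 6.0] is nearly evenly split, so predictions from that leaf are far less certain than those from the [3.0, 183.0] leaf.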

## Random Forest Classifier

In [12]:
rfc = RandomForestClassifier()
mcc = SparkMultiClassClassifier(rfc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")


 Class	Train	Test
alpha	638	66
beta	620	73

Sample predictions: RandomForestClassifier
+----------------+-----------+-----------+-----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|       coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+-----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          4GXB.B|        0.0| 0.53333336| 0.46666667|    beta|[-0.0819993153214...|         1.0|[3.70021402930451...|[0.18501070146522...|       1.0|          beta|
|          3WG3.B|0.025806451|  0.5225806|  0.4516129|    beta|[-0.0410713091661...|         1.0|[2.4603671877253,...|[0.12301835938626...|       1.0|          beta|
|          3D33.A|0.031914894|  0.5851064| 0.38297874|    beta|[-0.1077686789680..
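
For the random forest, the raw prediction is a sum of per-tree class probabilities, and `probability` is that sum normalized. The sketch below assumes Spark's default of 20 trees, since `RandomForestClassifier()` is constructed without arguments. Only the first component of each vector is visible in the truncated output, so the check is limited to that component. `forest_probability` is a hypothetical helper name.

```python
NUM_TREES = 20  # assumed: Spark's default numTrees for RandomForestClassifier

def forest_probability(raw_component, num_trees=NUM_TREES):
    """Average one class's summed per-tree probability over all trees."""
    return raw_component / num_trees

# First rawPrediction components of the two complete rows above
for raw in (3.70021402930451, 2.4603671877253):
    print(f"{raw} -> {forest_probability(raw):.6f}")
```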

## Logistic Regression Classifier

In [13]:
lr = LogisticRegression()
mcc = SparkMultiClassClassifier(lr, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")


 Class	Train	Test
alpha	638	66
beta	620	73

Sample predictions: LogisticRegression
+----------------+-----------+-----------+-----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|       coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+-----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          4GXB.B|        0.0| 0.53333336| 0.46666667|    beta|[-0.0819993153214...|         1.0|[-3.3034225218434...|[0.03545396309729...|       1.0|          beta|
|          3WG3.B|0.025806451|  0.5225806|  0.4516129|    beta|[-0.0410713091661...|         1.0|[-0.9278849415037...|[0.28335401096387...|       1.0|          beta|
|          3D33.A|0.031914894|  0.5851064| 0.38297874|    beta|[-0.1077686789680...|  
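
For binary logistic regression, the `probability` column follows from `rawPrediction` through the logistic (sigmoid) function. The sketch below applies it to the first raw component of the two complete rows above, which should reproduce the first component of the corresponding `probability` vectors. It is an illustration of the relationship, not Spark's own code.

```python
import math

def sigmoid(margin):
    """Logistic function mapping a raw margin to a probability."""
    return 1.0 / (1.0 + math.exp(-margin))

# First rawPrediction components of the two complete rows above
for raw in (-3.3034225218434, -0.9278849415037):
    print(f"{raw} -> {sigmoid(raw):.6f}")
```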

## Simple Multilayer Perceptron Classifier

In [14]:
layers = [featureCount, 64, 64, classCount]
mpc = MultilayerPerceptronClassifier().setLayers(layers) \
                                          .setBlockSize(128) \
                                          .setSeed(1234) \
                                          .setMaxIter(100)
mcc = SparkMultiClassClassifier(mpc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")


 Class	Train	Test
alpha	638	66
beta	620	73

Sample predictions: MultilayerPerceptronClassifier
+----------------+-----------+-----------+-----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|       coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+-----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          4GXB.B|        0.0| 0.53333336| 0.46666667|    beta|[-0.0819993153214...|         1.0|[-1.6416983478329...|[0.14315186039957...|       1.0|          beta|
|          3WG3.B|0.025806451|  0.5225806|  0.4516129|    beta|[-0.0410713091661...|         1.0|[-2.7802598864981...|[0.01528569041399...|       1.0|          beta|
|          3D33.A|0.031914894|  0.5851064| 0.38297874|    beta|[-0.1077686
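
To get a feel for the size of the network defined by `layers`, the sketch below counts its trainable parameters, assuming a standard fully connected architecture with one bias per unit. `mlp_parameter_count` is a hypothetical helper, not part of Spark or mmtfPyspark.

```python
def mlp_parameter_count(layers):
    """Count weights and biases of a fully connected network."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))

layers = [50, 64, 64, 2]  # featureCount = 50 and classCount = 2 in this notebook
print(mlp_parameter_count(layers))
```

With only 1397 balanced samples, a network of this size has several parameters per training example, which is worth keeping in mind when comparing it with the simpler classifiers.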

## Terminate Spark

In [17]:
spark.stop()