# Machine learning - Feature extraction

This demo creates a feature vector for protein fold classification.
The goal is to classify a protein chain as either an all-alpha or an all-beta protein based on its sequence. The feature vector is built from n-grams of the protein sequence, embedded with Word2Vec.

[Word2Vec model](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#word2vec)

[Word2Vec example](https://spark.apache.org/docs/latest/ml-features.html#word2vec)

## Imports

In [12]:
from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webFilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.datasets import secondaryStructureExtractor
from mmtfPyspark.ml import ProteinSequenceEncoder

## Configure Spark Context

In [13]:
conf = SparkConf() \
            .setMaster("local[*]") \
            .setAppName("MachineLearningFeaturesExtractionDemo")

sc = SparkContext(conf = conf)

## Read MMTF file and create a non-redundant set (<=40% seq. identity) of L-protein chains

In [14]:
pdb = mmtfReader.read_sequence_file('../resources/mmtf_reduced_sample/', sc) \
                .flatMap(StructureToPolymerChains()) \
                .filter(Pisces(sequenceIdentity=40,resolution=3.0))

## Get secondary structure content

In [15]:
data = secondaryStructureExtractor.get_dataset(pdb)

## Define add_protein_fold_type function

In [16]:
def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''
    Adds a column "foldType" that assigns each chain to one of three major
    secondary structure classes ("alpha", "beta", "alpha+beta") or to "other",
    based on the fraction of alpha/beta content.

    The simplified syntax used in this method relies on two imports:
        from pyspark.sql.functions import when
        from pyspark.sql.functions import col

    Args:
        data (DataFrame): input dataset with alpha, beta composition
        minThreshold (float): below this threshold, the secondary structure is ignored
        maxThreshold (float): above this threshold, the secondary structure is counted as present
    '''

    return data.withColumn("foldType", \
                           when((col("alpha") > maxThreshold) & (col("beta") < minThreshold), "alpha"). \
                           when((col("beta") > maxThreshold) & (col("alpha") < minThreshold), "beta"). \
                           when((col("alpha") > maxThreshold) & (col("beta") > minThreshold), "alpha+beta"). \
                           otherwise("other")\
                           )

## Classify chains by secondary structure type

In [17]:
data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)
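
To make the thresholds concrete, the same rule can be written in plain Python without Spark. This is only an illustrative sketch: `fold_type` is a hypothetical helper, not part of mmtfPyspark, and it mirrors the `when`/`otherwise` chain of `add_protein_fold_type` with the threshold values used in this cell.

```python
def fold_type(alpha, beta, min_threshold=0.05, max_threshold=0.15):
    """Hypothetical plain-Python mirror of add_protein_fold_type for one chain."""
    if alpha > max_threshold and beta < min_threshold:
        return "alpha"
    if beta > max_threshold and alpha < min_threshold:
        return "beta"
    if alpha > max_threshold and beta > min_threshold:
        return "alpha+beta"
    return "other"


# alpha/beta fractions taken from the sample output of this notebook
print(fold_type(0.372093, 0.319767))  # chain 3F5B.A -> alpha+beta
print(fold_type(0.027778, 0.527778))  # chain 3F62.A -> beta
```

Note that the conditions are checked in order, so a chain is labelled "alpha+beta" only after the pure "alpha" and "beta" cases have been ruled out.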

## Create a Word2Vec representation of the protein sequences

**n = 2**     # create 2-grams 

**windowSize = 25**    # window size of 25 amino acid residues for Word2Vec

**vectorSize = 50**    # dimension of feature vector
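
With `n = 2`, the sequence is split into overlapping 2-grams by sliding a two-residue window along the chain one position at a time. The sketch below illustrates this in plain Python, assuming this is how the encoder tokenizes the sequence (the `ngram` column in the output is consistent with it). The helper name `overlapping_ngrams` is hypothetical and not part of mmtfPyspark.

```python
def overlapping_ngrams(sequence, n=2):
    """Return all overlapping n-grams of a sequence (hypothetical helper)."""
    return [sequence[i:i + n] for i in range(len(sequence) - n + 1)]


# First ten residues of chain 3F5B.A from the sample data
print(overlapping_ngrams("SNAMMIKAST"))
# ['SN', 'NA', 'AM', 'MM', 'MI', 'IK', 'KA', 'AS', 'ST']
```

A sequence of length L yields L - n + 1 overlapping n-grams, so these ten residues produce nine 2-grams.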

In [18]:
encoder = ProteinSequenceEncoder(data)
data = encoder.overlapping_ngram_word2vec_encode(n=2, windowSize=25, vectorSize=50).cache()

data.toPandas().head(5)

| | structureChainId | sequence | alpha | beta | coil | dsspQ8Code | dsspQ3Code | foldType | ngram | features |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 3F5B.A | SNAMMIKASTNEFRFCFKQMNKSQHELVLGWIHQPHINEWLHGDGL... | 0.372093 | 0.319767 | 0.30814 | XCTTCXXXXXXCCCEEEEECCGGGHHHHHHHTTSHHHHTTSCHHHH... | XCCCCXXXXXXCCCEEEEECCHHHHHHHHHHCCCHHHHCCCCHHHH... | alpha+beta | [SN, NA, AM, MM, MI, IK, KA, AS, ST, TN, NE, E... | [0.29378673649336423, -0.5251977472715286, 0.3... |
| 1 | 3F5O.F | MTSMTQSLREVIKAMTKARNFERVLGKITLVSAAPGKVICEMKVEE... | 0.282609 | 0.413043 | 0.304348 | XCCHHHHHHHHHHHHHHCSSGGGGGTTCEEEEEETTEEEEEEECCG... | XCCHHHHHHHHHHHHHHCCCHHHHHCCCEEEEEECCEEEEEEECCH... | alpha+beta | [MT, TS, SM, MT, TQ, QS, SL, LR, RE, EV, VI, I... | [-0.08282158001750504, -0.5237155609643783, 0.... |
| 2 | 3F5R.A | MGSSHHHHHHSSGRENLYFQGMSTDFDRIYLNQSKFSGRFRIADSG... | 0.19469 | 0.424779 | 0.380531 | XXXXXXXXXXXXXXXXXXXCCSSEEEEEEEETTCSSCEEEEEETTE... | XXXXXXXXXXXXXXXXXXXCCCCEEEEEEEECCCCCCEEEEEECCE... | alpha+beta | [MG, GS, SS, SH, HH, HH, HH, HH, HH, HS, SS, S... | [-0.11817694425270402, -0.3900981570054826, 0.... |
| 3 | 3F62.A | GAMVETKCPNLDIVTSSGEFHCSGCVEHMPEFSYMYWLAKDMKSDE... | 0.027778 | 0.527778 | 0.444444 | CCCSCCCSCCCEEEEETTEEEEEEEECSSTTSCEEEEEEEETTSCC... | CCCCCCCCCCCEEEEECCEEEEEEEECCCCCCCEEEEEEEECCCCC... | beta | [GA, AM, MV, VE, ET, TK, KC, CP, PN, NL, LD, D... | [0.1076028709196382, -0.6029256812331301, 0.46... |
| 4 | 3F67.A | SNAIIAGETSIPSQGENMPAYHARPKNADGPLPIVIVVQEIFGVHE... | 0.358333 | 0.25 | 0.391667 | XCCEEEEEEEEEETTEEEEEEEEEETTCCSCEEEEEEECCTTCSCH... | XCCEEEEEEEEEECCEEEEEEEEEECCCCCCEEEEEEECCCCCCCH... | alpha+beta | [SN, NA, AI, II, IA, AG, GE, ET, TS, SI, IP, P... | [0.1345422140727654, -0.5889589260177066, 0.73... |
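
According to the Spark documentation linked at the top of this notebook, the Word2Vec model turns each document (here, the list of n-grams of one chain) into a single vector by averaging the vectors of its words, which is why every chain ends up with one fixed-length `features` vector. The sketch below shows only that averaging step. It uses made-up 3-dimensional embeddings instead of vectors from a trained model and assumes every n-gram has an embedding. Both `average_vectors` and the `embeddings` table are hypothetical.

```python
def average_vectors(ngrams, embeddings):
    """Average the embedding vectors of all n-grams, component by component.

    Assumes every n-gram is present in `embeddings`.
    """
    vectors = [embeddings[g] for g in ngrams]
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


# Made-up embeddings for illustration only (this notebook uses vectorSize=50)
embeddings = {
    "SN": [1.0, 0.0, 2.0],
    "NA": [3.0, 2.0, 0.0],
    "AM": [2.0, 4.0, 1.0],
}

features = average_vectors(["SN", "NA", "AM"], embeddings)
print(features)  # [2.0, 2.0, 1.0]
```

Because the result is an average, its dimension depends only on the embedding size and not on the length of the protein sequence.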


## Keep only a subset of relevant fields for further processing

In [19]:
data = data.select(['structureChainId','alpha','beta','coil','foldType','features'])

## Write to parquet file

In [22]:
data.write.mode('overwrite').format('parquet').save('./features')

## Terminate Spark

In [11]:
sc.stop()