# Machine learning - Feature extraction

This demo creates a feature vector for protein fold classification.
The goal is to classify a protein chain as either an all-alpha or an all-beta protein based on its sequence. The sequence is split into n-grams, and a Word2Vec representation of those n-grams serves as the feature vector.

[Word2Vec model](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#word2vec)

[Word2Vec example](https://spark.apache.org/docs/latest/ml-features.html#word2vec)
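
To make the n-gram step concrete, here is a minimal pure-Python sketch of overlapping n-gram extraction. The helper name `overlapping_ngrams` is hypothetical and is not part of mmtfPyspark; it only illustrates the idea of sliding a window of width n along the sequence one residue at a time.

```python
def overlapping_ngrams(sequence, n=2):
    """Return all overlapping n-grams of a one-letter-code protein sequence.

    Hypothetical helper for illustration only; not part of mmtfPyspark.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    # Slide a window of width n along the sequence, one residue at a time.
    return [sequence[i:i + n] for i in range(len(sequence) - n + 1)]


# First ten residues of chain 1RCQ.A from the sample data
print(overlapping_ngrams("MRPARALIDL", n=2))
# ['MR', 'RP', 'PA', 'AR', 'RA', 'AL', 'LI', 'ID', 'DL']
```

A sequence of length L yields L - n + 1 overlapping n-grams, so a sequence shorter than n yields none.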

## Imports

In [35]:
from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webFilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.datasets import secondaryStructureExtractor
from mmtfPyspark.ml import ProteinSequenceEncoder

## Configure Spark Context

In [36]:
conf = SparkConf() \
            .setMaster("local[*]") \
            .setAppName("MachineLearningFeaturesExtractionDemo")

sc = SparkContext(conf = conf)

## Read MMTF file and create a non-redundant set (<=40% seq. identity) of L-protein chains

In [37]:
pdb = mmtfReader.read_sequence_file('../../resources/mmtf_reduced_sample/', sc) \
                .flatMap(StructureToPolymerChains()) \
                .filter(Pisces(sequenceIdentity=40,resolution=3.0))

## Get secondary structure content

In [38]:
data = secondaryStructureExtractor.get_dataset(pdb)

## Define add_protein_fold_type function

In [39]:
def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''
    Adds a column "foldType" with one of four secondary structure classes:
    "alpha", "beta", "alpha+beta", or "other", based on the fraction of alpha/beta content.

    The simplified syntax used in this method relies on two imports:
        from pyspark.sql.functions import when
        from pyspark.sql.functions import col

    Args:
        data (DataFrame): input dataset with "alpha" and "beta" fraction columns
        minThreshold (float): a secondary structure fraction below this value is treated as absent
        maxThreshold (float): a secondary structure fraction above this value is treated as present
    '''

    # Conditions are evaluated in order; the first match sets the fold type.
    return data.withColumn(
        "foldType",
        when((col("alpha") > maxThreshold) & (col("beta") < minThreshold), "alpha")
        .when((col("beta") > maxThreshold) & (col("alpha") < minThreshold), "beta")
        .when((col("alpha") > maxThreshold) & (col("beta") > minThreshold), "alpha+beta")
        .otherwise("other"),
    )

## Classify chains by secondary structure type

In [40]:
data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)
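
The rule applied here can be traced for a single chain in plain Python. This is only a sketch that mirrors the `when` chain in `add_protein_fold_type`; the function name `fold_type` is hypothetical, and the alpha/beta fractions are taken from the sample output shown later in this demo.

```python
def fold_type(alpha, beta, min_threshold=0.05, max_threshold=0.15):
    """Plain-Python mirror of the Spark `when` chain (illustrative only)."""
    # Conditions are checked in the same order as in the Spark expression.
    if alpha > max_threshold and beta < min_threshold:
        return "alpha"
    if beta > max_threshold and alpha < min_threshold:
        return "beta"
    if alpha > max_threshold and beta > min_threshold:
        return "alpha+beta"
    return "other"


print(fold_type(0.316527, 0.240896))  # chain 1RCQ.A -> alpha+beta
print(fold_type(0.06383, 0.375887))   # chain 1RG8.B -> other
```

Chain 1RG8.B has a high beta fraction, but its alpha fraction (0.06383) lies between the two thresholds, so none of the three named classes matches and it falls through to "other".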

## Create a Word2Vec representation of the protein sequences

**n = 2**     # create overlapping 2-grams

**windowSize = 25**    # 25-residue context window for Word2Vec

**vectorSize = 50**    # dimension of the feature vector

In [41]:
encoder = ProteinSequenceEncoder(data)
data = encoder.overlapping_ngram_word2vec_encode(n=2, windowSize=25, vectorSize=50).cache()

data.toPandas().head(5)

| | structureChainId | sequence | alpha | beta | coil | dsspQ8Code | dsspQ3Code | foldType | ngram | features |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1RCQ.A | MRPARALIDLQALRHNYRLAREATGARALAVIKADAYGHGAVRCAE... | 0.316527 | 0.240896 | 0.442577 | CCCCEEEEEHHHHHHHHHHHHHHHCSEEEEECHHHHHTTCHHHHHH... | CCCCEEEEEHHHHHHHHHHHHHHHCCEEEEECHHHHHCCCHHHHHH... | alpha+beta | [MR, RP, PA, AR, RA, AL, LI, ID, DL, LQ, QA, A... | [-0.4455247169773858, -0.05284532651127306, -0... |
| 1 | 1REG.Y | MIEITLKKPEDFLKVKETLTRMGIANNKDKVLYQSCHILQKKGLYY... | 0.308333 | 0.291667 | 0.4 | CEEEECSSGGHHHHHHHHHTTEEEEETTTTEEEECEEEEEETTEEE... | CEEEECCCHHHHHHHHHHHCCEEEEECCCCEEEECEEEEEECCEEE... | alpha+beta | [MI, IE, EI, IT, TL, LK, KK, KP, PE, ED, DF, F... | [-0.2694746898564179, 0.3966265844165786, 0.05... |
| 2 | 1REQ.B | SSTDQGTNPADTDDLTPTTLSLAGDFPKATEEQWEREVEKVLNRGR... | 0.470113 | 0.121163 | 0.408724 | XXXXXXXXXXXXXXXXXXCCCSGGGSCCCCHHHHHHHHHHHHHTTC... | XXXXXXXXXXXXXXXXXXCCCCHHHCCCCCHHHHHHHHHHHHHCCC... | alpha+beta | [SS, ST, TD, DQ, QG, GT, TN, NP, PA, AD, DT, T... | [-0.4618564432899837, 0.19905433491159397, -0.... |
| 3 | 1RFE.A | GTKQRADIVMSEAEIADFVNSSRTGTLATIGPDGQPHLTAMWYAVI... | 0.3125 | 0.35625 | 0.33125 | XCCCCTTTCCCHHHHHHHHHHCCCEEEEEECTTSCEEEEEECCEEE... | XCCCCCCCCCCHHHHHHHHHHCCCEEEEEECCCCCEEEEEECCEEE... | alpha+beta | [GT, TK, KQ, QR, RA, AD, DI, IV, VM, MS, SE, E... | [-0.3139253464695182, 0.2763726514046838, -0.0... |
| 4 | 1RG8.B | HHHHHHFNLPPGNYKKPKLLYCSNGGHFLRILPDGTVDGTRDRSDQ... | 0.06383 | 0.375887 | 0.560284 | XXCCSCCCCCSCCSSSCEEEEETTTTEEEEECTTSCEEEESCTTCT... | XXCCCCCCCCCCCCCCCEEEEECCCCEEEEECCCCCEEEEECCCCCC... | other | [HH, HH, HH, HH, HH, HF, FN, NL, LP, PP, PG, G... | [-0.4487889504381295, 0.1226459019656839, -0.0... |
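
Each row ends up with one fixed-length `features` vector even though chains differ in length. Spark ML's Word2Vec model does this by averaging the vectors of all words in a document. The sketch below illustrates that averaging step, assuming `ProteinSequenceEncoder` relies on the same behavior. The function name and the 3-dimensional toy embeddings are hypothetical; the real vectors here have 50 dimensions and are learned from the data.

```python
def average_vectors(ngrams, embeddings):
    """Average the embedding vectors of the n-grams found in `embeddings`.

    Illustrative only: `embeddings` stands in for a trained Word2Vec model.
    """
    vectors = [embeddings[g] for g in ngrams if g in embeddings]
    if not vectors:
        return []
    dim = len(vectors[0])
    # Component-wise mean over all n-gram vectors.
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


# Hypothetical toy embeddings, not values produced by the demo
toy_embeddings = {
    "MR": [1.0, 0.0, 2.0],
    "RP": [3.0, 2.0, 0.0],
    "PA": [2.0, 4.0, 1.0],
}
print(average_vectors(["MR", "RP", "PA"], toy_embeddings))  # [2.0, 2.0, 1.0]
```

Because of the averaging, the order of the n-grams within a chain does not affect the final feature vector; only the window-based training of the individual n-gram vectors uses positional context.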


## Keep only a subset of relevant fields for further processing

In [42]:
data = data.select(['structureChainId','alpha','beta','coil','foldType','features'])

## Write to parquet file

In [43]:
data.write.format('parquet').save('./features')

## Terminate Spark

In [45]:
sc.stop()