<a href="https://colab.research.google.com/github/rohandawar/pyspark/blob/main/ANN_On_Pyspark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, I am trying to implement an ANN (artificial neural network) in PySpark.

In [84]:
! pip install pyspark



In [85]:
# Import Libs

# Pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Pandas
import pandas as pd

# Google
from google.colab import drive

In [86]:
# Create the Spark session for this notebook (or reuse one that is already
# running); the application is registered under the name 'ann'.
spark = (
    SparkSession.builder
    .appName('ann')
    .getOrCreate()
)

In [87]:
# Load the iris CSV from a commit-pinned gist URL into a pandas DataFrame,
# then preview the first rows.
iris_csv_url = r'https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv'
df = pd.read_csv(iris_csv_url)
df.head()

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa


In [88]:
# Convert the pandas frame into a Spark DataFrame so the pyspark.ml stages
# below can consume it, then print the first 20 rows (the default for show()).
data = spark.createDataFrame(df)
data.show(n=20)

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| setosa|
|         4.9|        3.0|         1.4|        0.2| setosa|
|         4.7|        3.2|         1.3|        0.2| setosa|
|         4.6|        3.1|         1.5|        0.2| setosa|
|         5.0|        3.6|         1.4|        0.2| setosa|
|         5.4|        3.9|         1.7|        0.4| setosa|
|         4.6|        3.4|         1.4|        0.3| setosa|
|         5.0|        3.4|         1.5|        0.2| setosa|
|         4.4|        2.9|         1.4|        0.2| setosa|
|         4.9|        3.1|         1.5|        0.1| setosa|
|         5.4|        3.7|         1.5|        0.2| setosa|
|         4.8|        3.4|         1.6|        0.2| setosa|
|         4.8|        3.0|         1.4|        0.1| setosa|
|         4.3|        3.0|         1.1| 

In [89]:
# Check the class distribution: number of rows per species.
species_counts = data.groupBy('species').count()
species_counts.show()

+----------+-----+
|   species|count|
+----------+-----+
|versicolor|   50|
|    setosa|   50|
| virginica|   50|
+----------+-----+



In [90]:
# Print the column names, types and nullability of the Spark DataFrame.
data.printSchema()

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- species: string (nullable = true)



In [91]:
# Build the list of feature columns: start from every column of `data`,
# then drop the target column 'species'.
col_list = data.columns
print('List for columns :', col_list)

# index() raises ValueError if 'species' is not present.
target_position = col_list.index('species')
del col_list[target_position]
print('List for columns :', col_list)

List for columns : ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
List for columns : ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']


In [92]:
# Vector Assembler: combine the feature columns in `col_list` into a single
# vector column named 'features'.
vec_assembler = VectorAssembler(
    inputCols=col_list,
    outputCol='features',
)

# Append the 'features' column to the data and preview five rows.
datadf = vec_assembler.transform(data)
datadf.show(5)


+------------+-----------+------------+-----------+-------+-----------------+
|sepal_length|sepal_width|petal_length|petal_width|species|         features|
+------------+-----------+------------+-----------+-------+-----------------+
|         5.1|        3.5|         1.4|        0.2| setosa|[5.1,3.5,1.4,0.2]|
|         4.9|        3.0|         1.4|        0.2| setosa|[4.9,3.0,1.4,0.2]|
|         4.7|        3.2|         1.3|        0.2| setosa|[4.7,3.2,1.3,0.2]|
|         4.6|        3.1|         1.5|        0.2| setosa|[4.6,3.1,1.5,0.2]|
|         5.0|        3.6|         1.4|        0.2| setosa|[5.0,3.6,1.4,0.2]|
+------------+-----------+------------+-----------+-------+-----------------+
only showing top 5 rows



In [93]:
# Encode the string target 'species' as a numeric 'label' column.
indexer = StringIndexer(inputCol='species', outputCol='label')

# Fit learns the species -> index mapping; transform appends the 'label' column.
indexer_model = indexer.fit(datadf)
finaldf = indexer_model.transform(datadf)
finaldf.show(5)

+------------+-----------+------------+-----------+-------+-----------------+-----+
|sepal_length|sepal_width|petal_length|petal_width|species|         features|label|
+------------+-----------+------------+-----------+-------+-----------------+-----+
|         5.1|        3.5|         1.4|        0.2| setosa|[5.1,3.5,1.4,0.2]|  0.0|
|         4.9|        3.0|         1.4|        0.2| setosa|[4.9,3.0,1.4,0.2]|  0.0|
|         4.7|        3.2|         1.3|        0.2| setosa|[4.7,3.2,1.3,0.2]|  0.0|
|         4.6|        3.1|         1.5|        0.2| setosa|[4.6,3.1,1.5,0.2]|  0.0|
|         5.0|        3.6|         1.4|        0.2| setosa|[5.0,3.6,1.4,0.2]|  0.0|
+------------+-----------+------------+-----------+-------+-----------------+-----+
only showing top 5 rows



In [94]:
# Check the class distribution again, this time showing which numeric label
# each species was mapped to.
species_and_label = finaldf.select('species', 'label')
label_counts = species_and_label.groupBy('species', 'label').count()
label_counts.show()

+----------+-----+-----+
|   species|label|count|
+----------+-----+-----+
|    setosa|  0.0|   50|
|versicolor|  1.0|   50|
| virginica|  2.0|   50|
+----------+-----+-----+



In [95]:
# Randomly split the rows into train and test sets with weights 0.6 / 0.4;
# a fixed seed is passed to the sampler.
split_weights = [0.6, 0.4]
traindf, testdf = finaldf.randomSplit(split_weights, seed=1)

In [96]:
# Instantiate the model.
# Layer sizes: 4 inputs (one per feature column), two hidden layers of
# 5 units each, and 3 outputs (one per species).
layers = [4, 5, 5, 3]
mlp = MultilayerPerceptronClassifier(
    featuresCol='features',  # default column name, spelled out for clarity
    labelCol='label',        # default column name, spelled out for clarity
    layers=layers,
    seed=1,
)

# Train on the training split.
mlp_model = mlp.fit(traindf)

# Predict on the held-out test split.
pred_df = mlp_model.transform(testdf)

In [97]:
# Preview five rows of the test set with the model's output columns appended.
pred_df.show(5)

+------------+-----------+------------+-----------+-------+-----------------+-----+--------------------+--------------------+----------+
|sepal_length|sepal_width|petal_length|petal_width|species|         features|label|       rawPrediction|         probability|prediction|
+------------+-----------+------------+-----------+-------+-----------------+-----+--------------------+--------------------+----------+
|         4.3|        3.0|         1.1|        0.1| setosa|[4.3,3.0,1.1,0.1]|  0.0|[11.7431310226109...|[0.99993211489659...|       0.0|
|         4.5|        2.3|         1.3|        0.3| setosa|[4.5,2.3,1.3,0.3]|  0.0|[10.7529016939915...|[0.99979306260253...|       0.0|
|         4.6|        3.1|         1.5|        0.2| setosa|[4.6,3.1,1.5,0.2]|  0.0|[11.5083568802337...|[0.99991375634132...|       0.0|
|         4.7|        3.2|         1.3|        0.2| setosa|[4.7,3.2,1.3,0.2]|  0.0|[11.7535046487494...|[0.99993417597894...|       0.0|
|         4.8|        3.1|         1.6|  

In [98]:
# Evaluator: compare the 'prediction' column against 'label' on the test
# predictions and report accuracy.
evaluator = MulticlassClassificationEvaluator(
    metricName='accuracy',
    labelCol='label',
    predictionCol='prediction',
)
mlp_accuracy = evaluator.evaluate(pred_df)
print(f'Accuracy:{mlp_accuracy}')

Accuracy:0.9827586206896551


Reference: https://medium.com/swlh/pysparks-multi-layer-perceptron-classifier-on-iris-dataset-dcf70d553cd8