# This is spark implementation of OneVsAll - Ex3 of ML course by Andrew Ng at coursera.

## Import sparkml libraries

In [1]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression, OneVsRest

## Load data 
Point to note here is that data needs to be formatted as labeled dataset - [label, feature].  The datafile has 401 columns, first column is the digit labet [1-10] and rest 400 represents 20x20 image of a digit.
In addition, for some reason my accuracy dropped by 5-6% if I did not use float conversion on label.
Data file is stored in HDFS.

In [2]:
data_rdd = sc.textFile('ex3/ex3data1.csv')
X_df = data_rdd.map(lambda x: x.split(",")).map(lambda x: (float(x[0]),Vectors.dense(x[1:400]))
                                               ).toDF(['label','features'])

## Train your model

In [3]:
lr = LogisticRegression(regParam=0.01,tol=1E-6, fitIntercept=True)
onevsrest = OneVsRest(classifier=lr)
model = onevsrest.fit(X_df)

## Check accuracy of model

In [4]:
predictions = model.transform(X_df)

In [5]:
print "Prediction %age on training set : " , (predictions.filter(predictions.label == predictions.prediction).count()*100.0)/predictions.count()

Prediction %age on training set :  93.88


In [6]:
predictions.filter(predictions.label != predictions.prediction)['label','prediction'].collect()

[Row(label=10.0, prediction=5.0),
 Row(label=10.0, prediction=8.0),
 Row(label=10.0, prediction=4.0),
 Row(label=10.0, prediction=3.0),
 Row(label=10.0, prediction=6.0),
 Row(label=1.0, prediction=5.0),
 Row(label=1.0, prediction=5.0),
 Row(label=1.0, prediction=5.0),
 Row(label=1.0, prediction=8.0),
 Row(label=1.0, prediction=2.0),
 Row(label=1.0, prediction=5.0),
 Row(label=2.0, prediction=7.0),
 Row(label=2.0, prediction=8.0),
 Row(label=2.0, prediction=1.0),
 Row(label=2.0, prediction=6.0),
 Row(label=2.0, prediction=5.0),
 Row(label=2.0, prediction=7.0),
 Row(label=2.0, prediction=1.0),
 Row(label=2.0, prediction=5.0),
 Row(label=2.0, prediction=1.0),
 Row(label=2.0, prediction=1.0),
 Row(label=2.0, prediction=4.0),
 Row(label=2.0, prediction=9.0),
 Row(label=2.0, prediction=3.0),
 Row(label=2.0, prediction=8.0),
 Row(label=2.0, prediction=7.0),
 Row(label=2.0, prediction=8.0),
 Row(label=2.0, prediction=8.0),
 Row(label=2.0, prediction=4.0),
 Row(label=2.0, prediction=1.0),
 Row(