In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession \
        .builder \
        .appName('Predict grape quality from wine properties') \
        .getOrCreate()

In [3]:
spark

In [5]:
# If a hive db related exception raises - delete the db.lck file from C:\tools\spark2\bin\code\metastore_db
rawData = spark.read \
            .format('csv') \
            .option('header', 'false') \
            .load('dataset/sales.csv')

In [6]:
rawData

DataFrame[_c0: string, _c1: string, _c2: string, _c3: string, _c4: string, _c5: string, _c6: string, _c7: string]

In [7]:
rawData.show(5)

+---------+---------+--------------------+--------+-------------------+---------+----------+--------------+
|      _c0|      _c1|                 _c2|     _c3|                _c4|      _c5|       _c6|           _c7|
+---------+---------+--------------------+--------+-------------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|        InvoiceDate|UnitPrice|CustomerID|       Country|
|   536365|   85123A|WHITE HANGING HEA...|       6|2010-12-01 08:26:00|     2.55|   17850.0|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6|2010-12-01 08:26:00|     3.39|   17850.0|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8|2010-12-01 08:26:00|     2.75|   17850.0|United Kingdom|
|   536365|   84029G|KNITTED UNION FLA...|       6|2010-12-01 08:26:00|     3.39|   17850.0|United Kingdom|
+---------+---------+--------------------+--------+-------------------+---------+----------+--------------+
only showing top 5 rows



In [8]:
dataset = rawData.toDF('Label',
                'Alcohol',
                'MalicAcid',
                'Ash',
                'AshAlkalinity',
                'Magnesium',
                'TotalPhenols',
                'Flavanoids',
                'NonflavanoidPhenols',
                'Proanthocyanins',
                'ColorIntensity',
                'Hue',
                'OD',
                'Proline'
                )

IllegalArgumentException: "requirement failed: The number of columns doesn't match.\nOld column names (8): _c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7\nNew column names (14): Label, Alcohol, MalicAcid, Ash, AshAlkalinity, Magnesium, TotalPhenols, Flavanoids, NonflavanoidPhenols, Proanthocyanins, ColorIntensity, Hue, OD, Proline"

In [9]:
dataset.show(5)

NameError: name 'dataset' is not defined

In [16]:
dataset.take(5)

[Row(Label='1', Alcohol='14.23', MalicAcid='1.71', Ash='2.43', AshAlkalinity='15.6', Magnesium='127', TotalPhenols='2.8', Flavanoids='3.06', NonflavanoidPhenols='.28', Proanthocyanins='2.29', ColorIntensity='5.64', Hue='1.04', OD='3.92', Proline='1065'),
 Row(Label='1', Alcohol='13.2', MalicAcid='1.78', Ash='2.14', AshAlkalinity='11.2', Magnesium='100', TotalPhenols='2.65', Flavanoids='2.76', NonflavanoidPhenols='.26', Proanthocyanins='1.28', ColorIntensity='4.38', Hue='1.05', OD='3.4', Proline='1050'),
 Row(Label='1', Alcohol='13.16', MalicAcid='2.36', Ash='2.67', AshAlkalinity='18.6', Magnesium='101', TotalPhenols='2.8', Flavanoids='3.24', NonflavanoidPhenols='.3', Proanthocyanins='2.81', ColorIntensity='5.68', Hue='1.03', OD='3.17', Proline='1185'),
 Row(Label='1', Alcohol='14.37', MalicAcid='1.95', Ash='2.5', AshAlkalinity='16.8', Magnesium='113', TotalPhenols='3.85', Flavanoids='3.49', NonflavanoidPhenols='.24', Proanthocyanins='2.18', ColorIntensity='7.8', Hue='.86', OD='3.45', P

In [11]:
from pyspark.ml.linalg import Vectors

def vectorize(data):
    return data.rdd.map(lambda r: [r[0], Vectors.dense(r[1:])]).toDF(['label','features'])

In [12]:
vectorizedData = vectorize(dataset)

In [13]:
vectorizedData.show(5)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|    1|[14.23,1.71,2.43,...|
|    1|[13.2,1.78,2.14,1...|
|    1|[13.16,2.36,2.67,...|
|    1|[14.37,1.95,2.5,1...|
|    1|[13.24,2.59,2.87,...|
+-----+--------------------+
only showing top 5 rows



In [15]:
vectorizedData.take(5)

[Row(label='1', features=DenseVector([14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0])),
 Row(label='1', features=DenseVector([13.2, 1.78, 2.14, 11.2, 100.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.4, 1050.0])),
 Row(label='1', features=DenseVector([13.16, 2.36, 2.67, 18.6, 101.0, 2.8, 3.24, 0.3, 2.81, 5.68, 1.03, 3.17, 1185.0])),
 Row(label='1', features=DenseVector([14.37, 1.95, 2.5, 16.8, 113.0, 3.85, 3.49, 0.24, 2.18, 7.8, 0.86, 3.45, 1480.0])),
 Row(label='1', features=DenseVector([13.24, 2.59, 2.87, 21.0, 118.0, 2.8, 2.69, 0.39, 1.82, 4.32, 1.04, 2.93, 735.0]))]

In [18]:
from pyspark.ml.feature import StringIndexer

labelIndexer = StringIndexer(inputCol='label',
                             outputCol='indexedLabel')

In [19]:
indexedData = labelIndexer.fit(vectorizedData).transform(vectorizedData)
indexedData.take(2)

[Row(label='1', features=DenseVector([14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06, 0.28, 2.29, 5.64, 1.04, 3.92, 1065.0]), indexedLabel=1.0),
 Row(label='1', features=DenseVector([13.2, 1.78, 2.14, 11.2, 100.0, 2.65, 2.76, 0.26, 1.28, 4.38, 1.05, 3.4, 1050.0]), indexedLabel=1.0)]

In [20]:
indexedData

DataFrame[label: string, features: vector, indexedLabel: double]

In [21]:
indexedData.select('label').distinct().show()

+-----+
|label|
+-----+
|    3|
|    1|
|    2|
+-----+



In [22]:
indexedData.select('indexedLabel').distinct().show()

+------------+
|indexedLabel|
+------------+
|         0.0|
|         1.0|
|         2.0|
+------------+



In [23]:
(trainingData, testData) = indexedData.randomSplit([0.8, 0.2])

In [24]:
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [25]:
# specify layers for the neural network:
# input layer of size 4 (features), two intermediate of size 5 and 4
# and output of size 3 (classes)
layers = [4, 5, 4, 3]

In [27]:
# create the trainer and set its parameters
trainer = MultilayerPerceptronClassifier(maxIter=10, layers=layers, blockSize=100, seed=1234)

In [28]:
# train the model
model = trainer.fit(trainingData)

IllegalArgumentException: 'requirement failed: Column label must be of type NumericType but was actually of type StringType.'