In [1]:
# Run this cell first, at the beginning of the session:
# findspark.init() is meant to make the local Spark install importable,
# so it has to execute before the `from pyspark...` import below.
import findspark

findspark.init()

# Start (or reuse, via getOrCreate) the SparkSession used by every later cell.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName('data_example')
    .getOrCreate()
)

In [2]:
# Toy table: one row per product with its monthly sale figure and a
# two-letter state code. Column types are inferred from the Python values
# (ints -> long, str -> string), as the printSchema output further down shows.
data = spark.createDataFrame(
    data=[
        (1, 289, 'MD'),
        (2, 889, 'MD'),
        (3, 633, 'VA'),
        (4, 455, 'WA'),
        (5, 550, 'WA'),
        (6, 400, 'WA'),
    ],
    schema=['product_id', 'monthly_sale', 'state'],
)

In [3]:
# Display the raw rows (up to 20 rows, long cell values truncated).
data.show(n=20, truncate=True)

+----------+------------+-----+
|product_id|monthly_sale|state|
+----------+------------+-----+
|         1|         289|   MD|
|         2|         889|   MD|
|         3|         633|   VA|
|         4|         455|   WA|
|         5|         550|   WA|
|         6|         400|   WA|
+----------+------------+-----+



In [4]:
# Print the schema Spark inferred for `data`: the Python ints became `long`
# and the state codes became `string`, all nullable.
data.printSchema()

root
 |-- product_id: long (nullable = true)
 |-- monthly_sale: long (nullable = true)
 |-- state: string (nullable = true)



In [5]:
# StringIndexer: map each distinct 'state' label to a numeric index in 'stateNum'.
# Under the default ordering (most frequent label first) WA -> 0.0, MD -> 1.0
# and VA -> 2.0, matching this cell's output.
# OneHotEncoder and VectorAssembler are imported here because later cells use them.
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

indexer = StringIndexer(outputCol='stateNum', inputCol='state')
indexed_data = (
    indexer
    .fit(data)        # learn the label -> index mapping from `data`
    .transform(data)  # append the 'stateNum' column
)
indexed_data.show(n=10, truncate=False)

+----------+------------+-----+--------+
|product_id|monthly_sale|state|stateNum|
+----------+------------+-----+--------+
|1         |289         |MD   |1.0     |
|2         |889         |MD   |1.0     |
|3         |633         |VA   |2.0     |
|4         |455         |WA   |0.0     |
|5         |550         |WA   |0.0     |
|6         |400         |WA   |0.0     |
+----------+------------+-----+--------+



In [6]:
# OneHotEncoder: turn the 'stateNum' index into a sparse one-hot vector 'stateVec'.
# dropLast=True (the default, spelled out for the reader) omits the last
# category. As a result the vectors have size 2 rather than 3, and VA
# (index 2.0) is encoded as the all-zeros vector (2,[],[]).
encoder = OneHotEncoder(
    inputCol='stateNum',
    outputCol='stateVec',
    dropLast=True,
)
onehotdata = (
    encoder
    .fit(indexed_data)
    .transform(indexed_data)
)
onehotdata.show(n=10, truncate=False)

+----------+------------+-----+--------+-------------+
|product_id|monthly_sale|state|stateNum|stateVec     |
+----------+------------+-----+--------+-------------+
|1         |289         |MD   |1.0     |(2,[1],[1.0])|
|2         |889         |MD   |1.0     |(2,[1],[1.0])|
|3         |633         |VA   |2.0     |(2,[],[])    |
|4         |455         |WA   |0.0     |(2,[0],[1.0])|
|5         |550         |WA   |0.0     |(2,[0],[1.0])|
|6         |400         |WA   |0.0     |(2,[0],[1.0])|
+----------+------------+-----+--------+-------------+



In [7]:
# VectorAssembler: concatenate product_id, monthly_sale and the 2-wide stateVec
# into one 4-element vector column named 'features'.
# NOTE(review): product_id looks like an identifier rather than a measurement;
# feeding it in as a numeric feature is usually unintended — confirm before modelling.
assembler1 = VectorAssembler(
    outputCol="features",
    inputCols=["product_id", "monthly_sale", "stateVec"],
)
outdata1 = assembler1.transform(onehotdata)
outdata1.show(n=10, truncate=False)

+----------+------------+-----+--------+-------------+-------------------+
|product_id|monthly_sale|state|stateNum|stateVec     |features           |
+----------+------------+-----+--------+-------------+-------------------+
|1         |289         |MD   |1.0     |(2,[1],[1.0])|[1.0,289.0,0.0,1.0]|
|2         |889         |MD   |1.0     |(2,[1],[1.0])|[2.0,889.0,0.0,1.0]|
|3         |633         |VA   |2.0     |(2,[],[])    |[3.0,633.0,0.0,0.0]|
|4         |455         |WA   |0.0     |(2,[0],[1.0])|[4.0,455.0,1.0,0.0]|
|5         |550         |WA   |0.0     |(2,[0],[1.0])|[5.0,550.0,1.0,0.0]|
|6         |400         |WA   |0.0     |(2,[0],[1.0])|[6.0,400.0,1.0,0.0]|
+----------+------------+-----+--------+-------------+-------------------+



### End of notebook