In [21]:
from pyspark import SparkConf
from pyspark.sql import SparkSession

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import IndexToString, StringIndexer, VectorAssembler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [3]:
# Spark configuration: the cluster master URL and the application name.
conf = SparkConf()
# Each setter hands back the same SparkConf, so the calls can be chained and
# the configuration object is still the cell's final expression.
conf.setMaster("spark://master:7077").setAppName("samplePySpark")

<pyspark.conf.SparkConf at 0x7f04f15e5630>

In [4]:
# Obtain a SparkSession (reusing an existing one if present) configured by `conf`.
spark = (
    SparkSession.builder
    .config(conf=conf)
    .getOrCreate()
)

In [5]:
# Load the Pokemon CSV: the first row is treated as the header and column
# types are inferred from the data.
df = (
    spark.read
    .options(inferSchema='true', header='true')
    .csv('/tmp/data/Pokemon.csv')
)

In [6]:
# Print the first rows of the loaded DataFrame as a text table for a quick sanity check.
df.show()

+---+--------------------+------+------+-----+---+------+-------+-------+-------+-----+----------+---------+
|  #|                Name|Type 1|Type 2|Total| HP|Attack|Defense|Sp. Atk|Sp. Def|Speed|Generation|Legendary|
+---+--------------------+------+------+-----+---+------+-------+-------+-------+-----+----------+---------+
|  1|           Bulbasaur| Grass|Poison|  318| 45|    49|     49|     65|     65|   45|         1|    false|
|  2|             Ivysaur| Grass|Poison|  405| 60|    62|     63|     80|     80|   60|         1|    false|
|  3|            Venusaur| Grass|Poison|  525| 80|    82|     83|    100|    100|   80|         1|    false|
|  3|VenusaurMega Venu...| Grass|Poison|  625| 80|   100|    123|    122|    120|   80|         1|    false|
|  4|          Charmander|  Fire|  null|  309| 39|    52|     43|     60|     50|   65|         1|    false|
|  5|          Charmeleon|  Fire|  null|  405| 58|    64|     58|     80|     65|   80|         1|    false|
|  6|           Cha

In [8]:
# List the column names; several contain spaces or dots (e.g. 'Type 1', 'Sp. Atk').
df.columns

['#',
 'Name',
 'Type 1',
 'Type 2',
 'Total',
 'HP',
 'Attack',
 'Defense',
 'Sp. Atk',
 'Sp. Def',
 'Speed',
 'Generation',
 'Legendary']

In [11]:
# 70 % / 30 % train/test split; the seed is fixed so the split can be repeated.
train_df, test_df = df.randomSplit(weights=[0.7, 0.3], seed=42)

#### build pipeline

In [37]:
# Numeric columns that are packed into the single 'features' vector column.
# NOTE(review): 'Sp. Atk' and 'Sp. Def' are not part of the feature set --
# possibly because of the dots in their names; confirm this is intentional.
feature_cols = ['Total', 'HP', 'Attack', 'Defense', 'Speed', 'Generation']
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

In [38]:
# Map the 'Type 1' strings to numeric indices stored in the 'label' column.
# The indexer is fitted on the full `df` (not only `train_df`), so it is
# already a fitted model when it is placed in the pipeline.
# NOTE(review): presumably done so every type receives an index even if it
# is missing from the training split -- confirm.
indexer = StringIndexer(inputCol='Type 1', outputCol='label').fit(df)

In [39]:
# Random forest classifier reading the assembled 'features' column, with
# trees limited to depth 5 and a fixed seed. The label column is not set
# here, so the estimator's default is used.
rfc = RandomForestClassifier(featuresCol='features', maxDepth=5, seed=42)

In [40]:
# Translate the numeric 'prediction' back into a type name ('predictedLabel')
# using the label list learned by the fitted indexer.
label_converter = IndexToString(
    inputCol='prediction',
    outputCol='predictedLabel',
    labels=indexer.labels,
)

In [41]:
# Stage order: assemble features -> index label -> classify -> decode prediction.
pipeline = Pipeline(stages=[assembler, indexer, rfc, label_converter])

#### create model

In [42]:
# Fit every pipeline stage on the training split to obtain the fitted pipeline model.
model = pipeline.fit(train_df)

#### prediction

In [43]:
# Run the fitted pipeline over the held-out test split.
predict_df = model.transform(test_df)

In [44]:
# Show the actual types next to the predicted type name and its numeric index.
(
    predict_df
    .select('Name', 'Type 1', 'Type 2', 'predictedLabel', 'prediction')
    .show()
)

+--------------------+-------+------+--------------+----------+
|                Name| Type 1|Type 2|predictedLabel|prediction|
+--------------------+-------+------+--------------+----------+
|             Ivysaur|  Grass|Poison|         Water|       0.0|
|            Venusaur|  Grass|Poison|         Water|       0.0|
|VenusaurMega Venu...|  Grass|Poison|         Water|       0.0|
|            Caterpie|    Bug|  null|        Normal|       1.0|
|              Weedle|    Bug|Poison|        Normal|       1.0|
|             Spearow| Normal|Flying|        Normal|       1.0|
|               Ekans| Poison|  null|         Water|       0.0|
|               Arbok| Poison|  null|        Normal|       1.0|
|            Nidoran♀| Poison|  null|        Normal|       1.0|
|            Nidorina| Poison|  null|         Water|       0.0|
|           Vileplume|  Grass|Poison|         Water|       0.0|
|               Paras|    Bug| Grass|         Grass|       2.0|
|             Venonat|    Bug|Poison| 

In [45]:
# Stop the Spark session now that the notebook is finished with it.
spark.stop()