In [1]:
# Order matters here: findspark.init() is called before pyspark is imported.
# NOTE(review): presumably findspark locates the local Spark installation and
# makes pyspark importable — confirm against the findspark documentation.
import findspark
findspark.init()

import pyspark

In [2]:
from pyspark import SparkContext
from pyspark.sql import SparkSession

from pyspark.sql.functions import expr, col
from pyspark.ml.feature import RFormula
from pyspark.ml.classification import LogisticRegression

In [3]:
# Entry point for all DataFrame work in this notebook; getOrCreate() hands back
# an already-active session if there is one instead of building a second.
spark = (
    SparkSession.builder
    .appName('dataframe app')
    .getOrCreate()
)

In [7]:
# Read zoo.csv (first row is the header, column types are inferred) and append
# a binary IsMammal column: 1 where Type equals 1, otherwise 0.
zoo_data = (
    spark.read
    .csv("zoo.csv", header=True, inferSchema=True)
    .withColumn("IsMammal", expr("CASE WHEN Type = 1 THEN 1 ELSE 0 END"))
)


+----------+----+--------+----+----+--------+-------+--------+-------+--------+--------+--------+----+----+----+--------+-------+----+--------+
|AnimalName|Hair|Feathers|Eggs|Milk|Airborne|Aquatic|Predator|Toothed|Backbone|Breathes|Venomous|Fins|Legs|Tail|Domestic|Catsize|Type|IsMammal|
+----------+----+--------+----+----+--------+-------+--------+-------+--------+--------+--------+----+----+----+--------+-------+----+--------+
|  aardvark|   1|       0|   0|   1|       0|      0|       1|      1|       1|       1|       0|   0|   4|   0|       0|      1|   1|       1|
|  antelope|   1|       0|   0|   1|       0|      0|       0|      1|       1|       1|       0|   0|   4|   1|       0|      1|   1|       1|
|      bass|   0|       0|   1|   0|       0|      1|       1|      1|       1|       0|       0|   1|   0|   1|       0|      0|   4|       0|
|      bear|   1|       0|   0|   1|       0|      0|       1|      1|       1|       1|       0|   0|   4|   0|       0|      1|   1|  

In [9]:
# preprocess dataset using RFormula
#
# The formula declares IsMammal as the label and the listed columns as the
# feature set; transforming with the fitted formula appends the `features`
# and `label` columns consumed by the classifier.
# NOTE(review): the "Breathes" column exists in the data but is not part of
# the formula — confirm that leaving it out is intentional.
zoo_formula = RFormula(
    formula=(
        "IsMammal ~ Hair + Feathers + Eggs + Milk + Airborne + Aquatic +"
        " Predator + Toothed + Backbone + Venomous + Fins + Legs+"
        "Tail + Domestic + Catsize"
    )
)

# Estimator, fitted model and transformed DataFrame each get their own name.
# Previously all three were bound to `preprocessed_data`, so a failure part-way
# through the cell left that name pointing at a non-DataFrame object.
zoo_formula_model = zoo_formula.fit(zoo_data)
preprocessed_data = zoo_formula_model.transform(zoo_data)

In [11]:
# split dataset into training and test data
#
# A fixed seed is passed so the 70/30 split (and therefore the trained model
# and its predictions) can be reproduced after a kernel restart; without it
# every run samples a different partition of the rows.
SPLIT_SEED = 42
train, test = preprocessed_data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [13]:
# Logistic-regression estimator wired to the `features` / `label` columns
# present in the preprocessed DataFrame.
lr = LogisticRegression(
    featuresCol="features",
    labelCol="label",
)

In [18]:
# Fit the logistic-regression estimator on the training partition only.
fittedLR = lr.fit(dataset=train)

In [19]:
# Score the held-out test partition with the fitted model, then preview the
# first 20 rows (the default row count for show()).
result = fittedLR.transform(dataset=test)
result.show(n=20)

+----------+----+--------+----+----+--------+-------+--------+-------+--------+--------+--------+----+----+----+--------+-------+----+--------+--------------------+-----+--------------------+--------------------+----------+
|AnimalName|Hair|Feathers|Eggs|Milk|Airborne|Aquatic|Predator|Toothed|Backbone|Breathes|Venomous|Fins|Legs|Tail|Domestic|Catsize|Type|IsMammal|            features|label|       rawPrediction|         probability|prediction|
+----------+----+--------+----+----+--------+-------+--------+-------+--------+--------+--------+----+----+----+--------+-------+----+--------+--------------------+-----+--------------------+--------------------+----------+
|      bass|   0|       0|   1|   0|       0|      1|       1|      1|       1|       0|       0|   1|   0|   1|       0|      0|   4|       0|(15,[2,5,6,7,8,10...|  0.0|[25.5394423674425...|[0.99999999999190...|       0.0|
|      bear|   1|       0|   0|   1|       0|      0|       1|      1|       1|       1|       0|   0|  

In [22]:
# Narrow the scored rows down to the animal name, the IsMammal flag and the
# model's prediction so the two can be compared side by side.
result_extracted = result.select(
    "AnimalName",
    "IsMammal",
    "prediction",
)
# Print at most 105 rows.
result_extracted.show(n=105)

+----------+--------+----------+
|AnimalName|IsMammal|prediction|
+----------+--------+----------+
|      bass|       0|       0.0|
|      bear|       1|       1.0|
|   buffalo|       1|       1.0|
|      calf|       1|       1.0|
|   catfish|       0|       0.0|
|   dogfish|       0|       0.0|
|      dove|       0|       0.0|
|      duck|       0|       0.0|
|  elephant|       1|       1.0|
|      frog|       0|       0.0|
|      goat|       1|       1.0|
|   haddock|       0|       0.0|
|   herring|       0|       0.0|
|  honeybee|       0|       0.0|
|  housefly|       0|       0.0|
|      lion|       1|       1.0|
|      mink|       1|       1.0|
|      mole|       1|       1.0|
|  mongoose|       1|       1.0|
|      newt|       0|       0.0|
|   ostrich|       0|       0.0|
|  porpoise|       1|       1.0|
|      puma|       1|       1.0|
|   sealion|       1|       1.0|
|  seasnake|       0|       0.0|
|   skimmer|       0|       0.0|
|  slowworm|       0|       0.0|
|      slu

In [None]:
# Shut down the Spark session once the analysis is finished; any cell that
# uses `spark` after this point needs the session to be created again.
spark.stop()