In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import regexp_replace, col
from pyspark.sql.types import FloatType,IntegerType

In [2]:
spark = SparkSession.builder.appName("DecisionTree").config("spark.driver.memory", "8g").getOrCreate()

<h4>Loading Data from csv file</h4>

In [3]:
data = spark.read.load('googleplaystore.csv', format='csv', sep=',', header='true', inferSchema='true', escape='"')

In [4]:
data.show(10)

+--------------------+--------------+------+-------+----+-----------+----+-----+--------------+--------------------+------------------+------------------+------------+
|                 App|      Category|Rating|Reviews|Size|   Installs|Type|Price|Content Rating|              Genres|      Last Updated|       Current Ver| Android Ver|
+--------------------+--------------+------+-------+----+-----------+----+-----+--------------+--------------------+------------------+------------------+------------+
|Photo Editor & Ca...|ART_AND_DESIGN|   4.1|    159| 19M|    10,000+|Free|    0|      Everyone|        Art & Design|   January 7, 2018|             1.0.0|4.0.3 and up|
| Coloring book moana|ART_AND_DESIGN|   3.9|    967| 14M|   500,000+|Free|    0|      Everyone|Art & Design;Pret...|  January 15, 2018|             2.0.0|4.0.3 and up|
|U Launcher Lite –...|ART_AND_DESIGN|   4.7|  87510|8.7M| 5,000,000+|Free|    0|      Everyone|        Art & Design|    August 1, 2018|             1.2.4|4.0.3 

In [5]:
data.printSchema()

root
 |-- App: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Rating: double (nullable = true)
 |-- Reviews: string (nullable = true)
 |-- Size: string (nullable = true)
 |-- Installs: string (nullable = true)
 |-- Type: string (nullable = true)
 |-- Price: string (nullable = true)
 |-- Content Rating: string (nullable = true)
 |-- Genres: string (nullable = true)
 |-- Last Updated: string (nullable = true)
 |-- Current Ver: string (nullable = true)
 |-- Android Ver: string (nullable = true)



<H4>Data Cleaning and Data PreProcessing</H4>

In [6]:
data=data.withColumn("Reviews",col("Reviews").cast(IntegerType()))\
.withColumn("Installs",regexp_replace(col("Installs"),"[^0-9]",""))\
.withColumn("Installs",col("Installs").cast(IntegerType()))\
.withColumn("Price",regexp_replace(col("price"),"[$]",""))\
.withColumn("Price",col("price").cast(FloatType()))\
.withColumn("Size",regexp_replace(col("Size"), "M", ""))\
.withColumn("Size", regexp_replace(col("Size"), "k", ""))\
.withColumn("Size", col("Size").cast(FloatType()))\
.withColumn('Size', regexp_replace(col('Size'), 'Varies with device', '0.0'))\
.withColumn('Size', col('Size').cast(FloatType()))

In [7]:
data.printSchema()
data = data.fillna(0)
data.columns

root
 |-- App: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Rating: double (nullable = true)
 |-- Reviews: integer (nullable = true)
 |-- Size: float (nullable = true)
 |-- Installs: integer (nullable = true)
 |-- Type: string (nullable = true)
 |-- Price: float (nullable = true)
 |-- Content Rating: string (nullable = true)
 |-- Genres: string (nullable = true)
 |-- Last Updated: string (nullable = true)
 |-- Current Ver: string (nullable = true)
 |-- Android Ver: string (nullable = true)



['App',
 'Category',
 'Rating',
 'Reviews',
 'Size',
 'Installs',
 'Type',
 'Price',
 'Content Rating',
 'Genres',
 'Last Updated',
 'Current Ver',
 'Android Ver']

<h4>Convert the 'Type' column to a numerical format</h4>

In [8]:
indexer = StringIndexer(inputCol='Type', outputCol='label')
data = indexer.fit(data).transform(data)

In [9]:
data.printSchema()
data.select('label').show()

root
 |-- App: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Rating: double (nullable = false)
 |-- Reviews: integer (nullable = true)
 |-- Size: float (nullable = false)
 |-- Installs: integer (nullable = true)
 |-- Type: string (nullable = true)
 |-- Price: float (nullable = false)
 |-- Content Rating: string (nullable = true)
 |-- Genres: string (nullable = true)
 |-- Last Updated: string (nullable = true)
 |-- Current Ver: string (nullable = true)
 |-- Android Ver: string (nullable = true)
 |-- label: double (nullable = false)

+-----+
|label|
+-----+
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
+-----+
only showing top 20 rows



<h4>Select features and label</h4>

In [10]:
feature_cols = ['Rating', 'Reviews','Installs']
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')
data = assembler.transform(data)

<h4>Split data into training and test sets</h4>

In [11]:
train_data, test_data = data.randomSplit([0.8, 0.2], seed=1234)
test_data.count()

2160

<h4>Create a decision tree model</h4>

In [12]:
dt = DecisionTreeClassifier(seed=1234,featuresCol='features', labelCol='label')

<h4>Train the model</h4>

In [13]:
model = dt.fit(train_data)
predictions = model.transform(test_data)

print("Decision Tree Model:")
print(model.toDebugString)

Decision Tree Model:
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_1fadc9afb983, depth=5, numNodes=33, numClasses=4, numFeatures=3
  If (feature 2 <= 30000.0)
   If (feature 1 <= 764.0)
    If (feature 1 <= 104.5)
     Predict: 0.0
    Else (feature 1 > 104.5)
     If (feature 2 <= 7500.0)
      If (feature 1 <= 169.5)
       Predict: 0.0
      Else (feature 1 > 169.5)
       Predict: 1.0
     Else (feature 2 > 7500.0)
      Predict: 0.0
   Else (feature 1 > 764.0)
    If (feature 1 <= 1255.5)
     If (feature 0 <= 4.45)
      If (feature 0 <= 4.05)
       Predict: 0.0
      Else (feature 0 > 4.05)
       Predict: 1.0
     Else (feature 0 > 4.45)
      If (feature 2 <= 3000.0)
       Predict: 1.0
      Else (feature 2 > 3000.0)
       Predict: 0.0
    Else (feature 1 > 1255.5)
     If (feature 0 <= 3.95)
      Predict: 0.0
     Else (feature 0 > 3.95)
      If (feature 0 <= 4.85)
       Predict: 1.0
      Else (feature 0 > 4.85)
       Predict: 0.0
  Else (feature 2 > 300

<h4>Evaluate the model using MulticlassClassifierEvaluator</h4>

In [14]:
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="label", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

print(f"Accuracy: {accuracy}")

Accuracy: 0.9273148148148148


<h4>Confusion Matrix</h4>

In [15]:
confusion_matrix = predictions.groupBy("label", "prediction").count()
confusion_matrix.show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  2.0|       0.0|    1|
|  1.0|       1.0|   23|
|  0.0|       1.0|   15|
|  1.0|       0.0|  141|
|  0.0|       0.0| 1980|
+-----+----------+-----+



<ul>
    <li>True Positives (TP): 23 (Paid apps correctly predicted as paid)</li>
    <li>False Positives (FP): 15 (Free apps incorrectly predicted as paid)</li>
    <li>True Negatives (TN): 1980 (Free apps correctly predicted as free)</li>
    <li>False Negatives (FN): 141 (Paid apps incorrectly predicted as free)</li>
</ul>

<ul>
    <li>
        <b>For label 0.0 (Free apps):</b>

-> 1980 instances were correctly predicted as free apps (True Negatives).
-> 15 instances were incorrectly predicted as paid apps (False Positives).</li>
    <li>
        <b>For label 1.0 (Paid apps):</b>

-> 23 instances were correctly predicted as paid apps (True Positives).
-> 141 instances were incorrectly predicted as free apps (False Negatives).</li>
    <li>
        <b>For label 2.0 (Type "Varies with device"):</b>

-> 1 instance was incorrectly predicted as a free app.</li></ul>