In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import regexp_replace, col
from pyspark.sql.types import FloatType,IntegerType

In [2]:
spark = SparkSession.builder.appName("DecisionTree").config("spark.driver.memory", "8g").getOrCreate()

<h4>Loading Data from csv file</h4>

In [3]:
data = spark.read.load('googleplaystore.csv', format='csv', sep=',', header='true', inferSchema='true', escape='"')

In [4]:
data.show(10)

+--------------------+--------------+------+-------+----+-----------+----+-----+--------------+--------------------+------------------+------------------+------------+
|                 App|      Category|Rating|Reviews|Size|   Installs|Type|Price|Content Rating|              Genres|      Last Updated|       Current Ver| Android Ver|
+--------------------+--------------+------+-------+----+-----------+----+-----+--------------+--------------------+------------------+------------------+------------+
|Photo Editor & Ca...|ART_AND_DESIGN|   4.1|    159| 19M|    10,000+|Free|    0|      Everyone|        Art & Design|   January 7, 2018|             1.0.0|4.0.3 and up|
| Coloring book moana|ART_AND_DESIGN|   3.9|    967| 14M|   500,000+|Free|    0|      Everyone|Art & Design;Pret...|  January 15, 2018|             2.0.0|4.0.3 and up|
|U Launcher Lite –...|ART_AND_DESIGN|   4.7|  87510|8.7M| 5,000,000+|Free|    0|      Everyone|        Art & Design|    August 1, 2018|             1.2.4|4.0.3 

In [5]:
data.printSchema()

root
 |-- App: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Rating: double (nullable = true)
 |-- Reviews: string (nullable = true)
 |-- Size: string (nullable = true)
 |-- Installs: string (nullable = true)
 |-- Type: string (nullable = true)
 |-- Price: string (nullable = true)
 |-- Content Rating: string (nullable = true)
 |-- Genres: string (nullable = true)
 |-- Last Updated: string (nullable = true)
 |-- Current Ver: string (nullable = true)
 |-- Android Ver: string (nullable = true)



<H4>Data Cleaning and Data PreProcessing</H4>

In [6]:
data=data.withColumn("Reviews",col("Reviews").cast(IntegerType()))\
.withColumn("Installs",regexp_replace(col("Installs"),"[^0-9]",""))\
.withColumn("Installs",col("Installs").cast(IntegerType()))\
.withColumn("Price",regexp_replace(col("price"),"[$]",""))\
.withColumn("Price",col("price").cast(FloatType()))\
.withColumn("Size",regexp_replace(col("Size"), "M", ""))\
.withColumn("Size", regexp_replace(col("Size"), "k", ""))\
.withColumn("Size", col("Size").cast(FloatType()))\
.withColumn('Size', regexp_replace(col('Size'), 'Varies with device', '0.0'))\
.withColumn('Size', col('Size').cast(FloatType()))

In [7]:
data.printSchema()
data = data.fillna(0)
data.columns

root
 |-- App: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Rating: double (nullable = true)
 |-- Reviews: integer (nullable = true)
 |-- Size: float (nullable = true)
 |-- Installs: integer (nullable = true)
 |-- Type: string (nullable = true)
 |-- Price: float (nullable = true)
 |-- Content Rating: string (nullable = true)
 |-- Genres: string (nullable = true)
 |-- Last Updated: string (nullable = true)
 |-- Current Ver: string (nullable = true)
 |-- Android Ver: string (nullable = true)



['App',
 'Category',
 'Rating',
 'Reviews',
 'Size',
 'Installs',
 'Type',
 'Price',
 'Content Rating',
 'Genres',
 'Last Updated',
 'Current Ver',
 'Android Ver']

<h4>Convert the 'Type' column to a numerical format</h4>

In [8]:
indexer = StringIndexer(inputCol='Type', outputCol='label')
data = indexer.fit(data).transform(data)

In [9]:
data.printSchema()
data.select('label').show()

root
 |-- App: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Rating: double (nullable = false)
 |-- Reviews: integer (nullable = true)
 |-- Size: float (nullable = false)
 |-- Installs: integer (nullable = true)
 |-- Type: string (nullable = true)
 |-- Price: float (nullable = false)
 |-- Content Rating: string (nullable = true)
 |-- Genres: string (nullable = true)
 |-- Last Updated: string (nullable = true)
 |-- Current Ver: string (nullable = true)
 |-- Android Ver: string (nullable = true)
 |-- label: double (nullable = false)

+-----+
|label|
+-----+
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
|  0.0|
+-----+
only showing top 20 rows



<h4>Select features and label</h4>

In [10]:
feature_cols = ['Size','Rating', 'Reviews','Installs']
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')
data = assembler.transform(data)

<h4>Split data into training and test sets</h4>

In [11]:
train_data, test_data = data.randomSplit([0.8, 0.2], seed=1234)

<h4>Create a decision tree model</h4>

In [12]:
dt = DecisionTreeClassifier(seed=1234,featuresCol='features', labelCol='label')

<h4>Train the model</h4>

In [13]:
model = dt.fit(train_data)
predictions = model.transform(test_data)

print("Decision Tree Model:")
print(model.toDebugString)

Decision Tree Model:
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_af9b854d3126, depth=5, numNodes=37, numClasses=4, numFeatures=4
  If (feature 3 <= 30000.0)
   If (feature 2 <= 764.0)
    If (feature 0 <= 43.5)
     If (feature 2 <= 273.5)
      Predict: 0.0
     Else (feature 2 > 273.5)
      If (feature 3 <= 7500.0)
       Predict: 1.0
      Else (feature 3 > 7500.0)
       Predict: 0.0
    Else (feature 0 > 43.5)
     If (feature 3 <= 7500.0)
      If (feature 2 <= 104.5)
       Predict: 0.0
      Else (feature 2 > 104.5)
       Predict: 1.0
     Else (feature 3 > 7500.0)
      If (feature 2 <= 469.5)
       Predict: 0.0
      Else (feature 2 > 469.5)
       Predict: 1.0
   Else (feature 2 > 764.0)
    If (feature 2 <= 1255.5)
     If (feature 0 <= 2.549999952316284)
      Predict: 1.0
     Else (feature 0 > 2.549999952316284)
      If (feature 0 <= 50.5)
       Predict: 0.0
      Else (feature 0 > 50.5)
       Predict: 1.0
    Else (feature 2 > 1255.5)
     If (feat

<h4>Evaluate the model using MulticlassClassifierEvaluator</h4>

In [14]:
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="label", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

print(f"Accuracy: {accuracy}")

Accuracy: 0.9310185185185185


<h4>Confusion Matrix</h4>

In [15]:
confusion_matrix = predictions.groupBy("label", "prediction").count()
confusion_matrix.show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  2.0|       0.0|    1|
|  1.0|       1.0|   27|
|  0.0|       1.0|   11|
|  1.0|       0.0|  137|
|  0.0|       0.0| 1984|
+-----+----------+-----+



<ul>
    <li>True Positives (TP): 27 (Paid apps correctly predicted as paid)</li>
    <li>False Positives (FP): 11 (Free apps incorrectly predicted as paid)</li>
    <li>True Negatives (TN): 1984 (Free apps correctly predicted as free)</li>
    <li>False Negatives (FN): 137 (Paid apps incorrectly predicted as free)</li>
</ul>

<ul>
    <li>
        <b>For label 0.0 (Free apps):</b>

-> 1984 instances were correctly predicted as free apps (True Negatives).
-> 11 instances were incorrectly predicted as paid apps (False Positives).</li>
    <li>
        <b>For label 1.0 (Paid apps):</b>

-> 27 instances were correctly predicted as paid apps (True Positives).
-> 137 instances were incorrectly predicted as free apps (False Negatives).</li>
    <li>
        <b>For label 2.0 (Type "Varies with device"):</b>

-> 1 instance was incorrectly predicted as a free app.</li></ul>

In [16]:
from pyspark.mllib.evaluation import MulticlassMetrics
predictionAndLabels = predictions.select("prediction", "label").rdd
metrics = MulticlassMetrics(predictionAndLabels)
confusion_matrix = metrics.confusionMatrix().toArray()
print("Confusion Matrix:")
for row in confusion_matrix:
    print([int(elem) for elem in row])



Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.runJob.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 26.0 failed 1 times, most recent failure: Lost task 0.0 in stage 26.0 (TID 24) (LAPTOP-KDMQ8FBJ executor driver): org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:192)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:166)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:92)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:139)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1529)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:705)
	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:749)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:673)
	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:639)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:615)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:572)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:530)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:179)
	... 15 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2785)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2721)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2720)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2720)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1206)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1206)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1206)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2984)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2923)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2912)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:971)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2263)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2284)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2303)
	at org.apache.spark.api.python.PythonRDD$.runJob(PythonRDD.scala:179)
	at org.apache.spark.api.python.PythonRDD.runJob(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:76)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:52)
	at java.base/java.lang.reflect.Method.invoke(Method.java:577)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:192)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:166)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:92)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:139)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1529)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	... 1 more
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:705)
	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:749)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:673)
	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:639)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:615)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:572)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:530)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:179)
	... 15 more
