# PySpark on Google Colab 101

In this article, we will run PySpark in a Google Colaboratory notebook and then perform some basic exploratory data analysis tasks that are common to most data science problems. So, let’s get cracking!

In [1]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null

In [2]:
!wget -q https://archive.apache.org/dist/spark/spark-3.3.1/spark-3.3.1-bin-hadoop3.tgz

In [3]:
!tar xf spark-3.3.1-bin-hadoop3.tgz

In [4]:
!pip install -q findspark

In [5]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.3.1-bin-hadoop3"
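If `findspark.init()` fails in the next cell, a common cause is that one of these two paths does not match what was actually installed, for example after switching to a different Spark or Java version. Below is a minimal standard-library sketch of a pre-flight check. `missing_spark_paths` is a hypothetical helper for illustration and is not part of findspark.

```python
import os
from pathlib import Path


def missing_spark_paths(env=None):
    """Hypothetical helper: return which of JAVA_HOME / SPARK_HOME are unset
    or do not point to an existing directory."""
    env = os.environ if env is None else env
    missing = []
    for name in ("JAVA_HOME", "SPARK_HOME"):
        value = env.get(name)
        # Treat an empty/unset variable or a non-directory path as a problem
        if not value or not Path(value).is_dir():
            missing.append(name)
    return missing
```

In the notebook, you could call `missing_spark_paths()` just before `findspark.init()` and stop if it returns a non-empty list. That gives an early, specific hint about which path to fix.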

In [6]:
import findspark
findspark.init()

In [8]:
from pyspark.sql import SparkSession

# "local" runs Spark with a single worker thread; "local[*]" would use all available cores.
# spark.ui.port sets the port of the Spark web UI (4050 here).
spark = (SparkSession.builder
         .master("local")
         .appName("Colab")
         .config("spark.ui.port", "4050")
         .getOrCreate())

In [9]:
spark

In [10]:
!wget --continue https://raw.githubusercontent.com/weejessada/pyspark-tutorial/main/sample-data/iris.csv -O /tmp/iris.csv

--2023-01-16 02:46:18--  https://raw.githubusercontent.com/weejessada/pyspark-tutorial/main/sample-data/iris.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4550 (4.4K) [text/plain]
Saving to: ‘/tmp/iris.csv’


2023-01-16 02:46:18 (42.4 MB/s) - ‘/tmp/iris.csv’ saved [4550/4550]



In [11]:
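# Read the CSV file into a DataFrame; the file has no header row,
# so Spark names the columns _c0, _c1, ... and infers their types
df = (spark.read
      .option("header", "false")
      .option("inferSchema", "true")
      .csv("/tmp/iris.csv"))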

In [33]:
type(df)

pyspark.sql.dataframe.DataFrame

In [12]:
df.printSchema()

root
 |-- _c0: double (nullable = true)
 |-- _c1: double (nullable = true)
 |-- _c2: double (nullable = true)
 |-- _c3: double (nullable = true)
 |-- _c4: string (nullable = true)
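With `header` set to `false`, Spark assigns the positional names `_c0` to `_c4`. `inferSchema` then scans the data to pick a type for each column, which is why the four measurements came out as `double` and the species as `string`.

The sketch below is a deliberately simplified plain-Python illustration of that idea. It only chooses between "double" and "string", whereas Spark's real inference also recognizes integers, booleans and timestamps, among others. `sketch_infer_schema` is a hypothetical name.

Inference costs an extra pass over the data. A common alternative is to supply the schema on the reader yourself, e.g. `.schema("sepal_length DOUBLE, sepal_width DOUBLE, petal_length DOUBLE, petal_width DOUBLE, label STRING")`, which also names the columns at read time.

```python
def sketch_infer_schema(rows):
    """Simplified illustration of header=false + inferSchema=true.

    Columns get positional names _c0, _c1, ...; a column is typed "double"
    if every value parses as a float, otherwise "string".
    """
    def is_float(text):
        try:
            float(text)
            return True
        except ValueError:
            return False

    schema = []
    for i in range(len(rows[0])):
        values = [row[i] for row in rows]
        col_type = "double" if all(is_float(v) for v in values) else "string"
        schema.append((f"_c{i}", col_type))
    return schema


# First three data lines of iris.csv, as shown in the df.show() output
sample = [line.split(",") for line in [
    "5.1,3.5,1.4,0.2,Iris-setosa",
    "4.9,3.0,1.4,0.2,Iris-setosa",
    "4.7,3.2,1.3,0.2,Iris-setosa",
]]
print(sketch_infer_schema(sample))
```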



In [13]:
df.show(20, truncate=False)

+---+---+---+---+-----------+
|_c0|_c1|_c2|_c3|_c4        |
+---+---+---+---+-----------+
|5.1|3.5|1.4|0.2|Iris-setosa|
|4.9|3.0|1.4|0.2|Iris-setosa|
|4.7|3.2|1.3|0.2|Iris-setosa|
|4.6|3.1|1.5|0.2|Iris-setosa|
|5.0|3.6|1.4|0.2|Iris-setosa|
|5.4|3.9|1.7|0.4|Iris-setosa|
|4.6|3.4|1.4|0.3|Iris-setosa|
|5.0|3.4|1.5|0.2|Iris-setosa|
|4.4|2.9|1.4|0.2|Iris-setosa|
|4.9|3.1|1.5|0.1|Iris-setosa|
|5.4|3.7|1.5|0.2|Iris-setosa|
|4.8|3.4|1.6|0.2|Iris-setosa|
|4.8|3.0|1.4|0.1|Iris-setosa|
|4.3|3.0|1.1|0.1|Iris-setosa|
|5.8|4.0|1.2|0.2|Iris-setosa|
|5.7|4.4|1.5|0.4|Iris-setosa|
|5.4|3.9|1.3|0.4|Iris-setosa|
|5.1|3.5|1.4|0.3|Iris-setosa|
|5.7|3.8|1.7|0.3|Iris-setosa|
|5.1|3.8|1.5|0.3|Iris-setosa|
+---+---+---+---+-----------+
only showing top 20 rows



In [14]:
df.count()

150

In [15]:
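# Rename the columns: selectExpr renames the four measurements,
# then withColumnRenamed renames the species column to "label"
renamed_df = df.selectExpr("_c0 as sepal_length", "_c1 as sepal_width",
                           "_c2 as petal_length", "_c3 as petal_width", "_c4")
renamed_df = renamed_df.withColumnRenamed("_c4", "label")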

renamed_df.show(5)

print("\n")

renamed_df.printSchema()

+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|      label|
+------------+-----------+------------+-----------+-----------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|
+------------+-----------+------------+-----------+-----------+
only showing top 5 rows



root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- label: string (nullable = true)
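The two-step rename above can be collapsed into a single call. A small plain-Python helper can generate the `"old as new"` strings that `selectExpr` accepts, pairing names by position; `rename_exprs` is a hypothetical name. The resulting list would then be passed as `df.selectExpr(*exprs)`. Alternatively, `df.toDF(*new_names)` renames every column positionally in one step.

```python
def rename_exprs(old_names, new_names):
    """Hypothetical helper: build selectExpr-style "old as new" strings,
    pairing the two name lists by position."""
    if len(old_names) != len(new_names):
        raise ValueError("old_names and new_names must have the same length")
    return [f"{old} as {new}" for old, new in zip(old_names, new_names)]


new_names = ["sepal_length", "sepal_width", "petal_length", "petal_width", "label"]
exprs = rename_exprs(["_c0", "_c1", "_c2", "_c3", "_c4"], new_names)
print(exprs)
```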



In [32]:
# List the distinct values of the label column
renamed_df.select("label").distinct().show()

+---------------+
|          label|
+---------------+
| Iris-virginica|
|    Iris-setosa|
|Iris-versicolor|
+---------------+
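`distinct()` returns one row per unique label. Spark does not guarantee the row order of that result, so add `.orderBy("label")` if a stable order matters. To also see how many rows each label has, the usual DataFrame call is `renamed_df.groupBy("label").count()`. In plain Python, these two ideas correspond to a sorted set and a `collections.Counter`, sketched here on a made-up list of labels:

```python
from collections import Counter

# Made-up sample of label values (not the full dataset)
labels = ["Iris-setosa", "Iris-virginica", "Iris-setosa", "Iris-versicolor", "Iris-virginica"]

distinct_labels = sorted(set(labels))  # like .select("label").distinct().orderBy("label")
label_counts = Counter(labels)         # like .groupBy("label").count()

print(distinct_labels)
print(label_counts)
```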



In [42]:
from pyspark.sql.functions import col

# Filter rows: keep only the Iris-virginica samples
df_filtered = renamed_df.filter(col("label") == "Iris-virginica")
df_filtered.show(5)

print(f"count df_filtered : {df_filtered.count()}")

+------------+-----------+------------+-----------+--------------+
|sepal_length|sepal_width|petal_length|petal_width|         label|
+------------+-----------+------------+-----------+--------------+
|         6.3|        3.3|         6.0|        2.5|Iris-virginica|
|         5.8|        2.7|         5.1|        1.9|Iris-virginica|
|         7.1|        3.0|         5.9|        2.1|Iris-virginica|
|         6.3|        2.9|         5.6|        1.8|Iris-virginica|
|         6.5|        3.0|         5.8|        2.2|Iris-virginica|
+------------+-----------+------------+-----------+--------------+
only showing top 5 rows

count df_filtered : 50
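`filter` also accepts several conditions at once, combined with `&` (and), `|` (or) and `~` (not). For example: `renamed_df.filter((col("label") == "Iris-virginica") & (col("sepal_length") > 6.0))`. The parentheses are required because Python's `&` binds more tightly than `==` and `>`.

The sketch below shows why without a Spark session. It uses a hypothetical `ToyColumn` class, not the real `pyspark.sql.Column`, that builds an expression string the same way. Like PySpark, it refuses conversion to a Python `bool`, which is exactly what the unparenthesized version ends up needing.

```python
class ToyColumn:
    """Hypothetical stand-in for pyspark.sql.Column that only builds an
    expression string. It is NOT the real Spark class."""

    def __init__(self, expr):
        self.expr = expr

    @staticmethod
    def _text(value):
        # Columns keep their expression text; Python literals are quoted via repr()
        return value.expr if isinstance(value, ToyColumn) else repr(value)

    def __eq__(self, other):
        return ToyColumn(f"({self.expr} = {self._text(other)})")

    def __gt__(self, other):
        return ToyColumn(f"({self.expr} > {self._text(other)})")

    def __and__(self, other):
        return ToyColumn(f"({self.expr} AND {self._text(other)})")

    def __rand__(self, other):
        # Called for `literal & column`, e.g. "text" & toy_col("x")
        return ToyColumn(f"({self._text(other)} AND {self.expr})")

    def __bool__(self):
        # Like Spark, refuse to turn an unevaluated expression into True/False
        raise ValueError("Cannot convert column into bool: use '&', '|' and '~'")

    def __repr__(self):
        return self.expr


def toy_col(name):
    # Named toy_col so it does not shadow pyspark.sql.functions.col in the notebook
    return ToyColumn(name)


# Correct: each comparison is parenthesized, then combined with &
combined = (toy_col("label") == "Iris-virginica") & (toy_col("sepal_length") > 6.0)
print(combined)

# Wrong: & binds tighter than == and >, so Python reads this as the chained
# comparison toy_col("label") == ("Iris-virginica" & toy_col("sepal_length")) > 6.0,
# which needs bool() of the first comparison and therefore raises
try:
    toy_col("label") == "Iris-virginica" & toy_col("sepal_length") > 6.0
    error = None
except ValueError as exc:
    error = str(exc)
print(error)
```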


In [25]:
# Select a subset of columns
df_filtered.select("sepal_length", "sepal_width", "label").show(5, truncate=False)

+------------+-----------+--------------+
|sepal_length|sepal_width|label         |
+------------+-----------+--------------+
|6.3         |3.3        |Iris-virginica|
|5.8         |2.7        |Iris-virginica|
|7.1         |3.0        |Iris-virginica|
|6.3         |2.9        |Iris-virginica|
|6.5         |3.0        |Iris-virginica|
+------------+-----------+--------------+
only showing top 5 rows



In [43]:
# Group by label and compute the mean of each measurement
avg_df = renamed_df.groupBy("label").avg("sepal_length","sepal_width","petal_length","petal_width")
avg_df.show()

+---------------+-----------------+------------------+-----------------+------------------+
|          label|avg(sepal_length)|  avg(sepal_width)|avg(petal_length)|  avg(petal_width)|
+---------------+-----------------+------------------+-----------------+------------------+
| Iris-virginica|6.587999999999998|2.9739999999999998|            5.552|             2.026|
|    Iris-setosa|5.005999999999999|3.4180000000000006|            1.464|0.2439999999999999|
|Iris-versicolor|            5.936|2.7700000000000005|             4.26|1.3259999999999998|
+---------------+-----------------+------------------+-----------------+------------------+
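Values such as `6.587999999999998` are binary floating-point rounding noise from summing `double` values; they do not indicate a problem in the data. A common fix for display is to round inside the aggregation: `from pyspark.sql import functions as F`, then `renamed_df.groupBy("label").agg(F.round(F.avg("sepal_length"), 3).alias("avg_sepal_length"))`.

The plain-Python sketch below shows what a grouped average computes conceptually and where rounding fits. It uses a hypothetical helper, `group_avg`, run on ten rows copied from the outputs above.

```python
from collections import defaultdict


def group_avg(rows, key, value, ndigits=None):
    """Hypothetical sketch of df.groupBy(key).avg(value): keep a running
    sum and count per group, divide at the end, optionally round."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for row in rows:
        sums[row[key]] += row[value]
        counts[row[key]] += 1
    result = {}
    for group, total in sums.items():
        mean = total / counts[group]
        result[group] = mean if ndigits is None else round(mean, ndigits)
    return result


# First five setosa and first five virginica rows from the outputs above
rows = [{"label": "Iris-setosa", "sepal_length": v} for v in (5.1, 4.9, 4.7, 4.6, 5.0)]
rows += [{"label": "Iris-virginica", "sepal_length": v} for v in (6.3, 5.8, 7.1, 6.3, 6.5)]

print(group_avg(rows, "label", "sepal_length"))     # raw means, may show float noise
print(group_avg(rows, "label", "sepal_length", 3))  # rounded to 3 decimals
```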



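An aggregation like the one above can also be written in SQL once the DataFrame is registered as a view, as the next cell does with the name `iris`. For example: `spark.sql("SELECT label, ROUND(AVG(sepal_length), 3) AS avg_sepal_length FROM iris GROUP BY label")`.

To experiment with such a query without a Spark session, the sketch below runs it on a small in-memory SQLite table via Python's built-in `sqlite3` module. SQLite is only a stand-in here: its SQL dialect differs from Spark SQL in places, so not every query will transfer unchanged.

```python
import sqlite3

# A few (sepal_length, label) pairs copied from the outputs above
rows = [
    (5.1, "Iris-setosa"), (4.9, "Iris-setosa"), (4.7, "Iris-setosa"),
    (6.3, "Iris-virginica"), (5.8, "Iris-virginica"), (7.1, "Iris-virginica"),
]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE iris (sepal_length REAL, label TEXT)")
conn.executemany("INSERT INTO iris VALUES (?, ?)", rows)

# ORDER BY makes the output order deterministic; GROUP BY alone does not
query = """
    SELECT label, ROUND(AVG(sepal_length), 3) AS avg_sepal_length
    FROM iris
    GROUP BY label
    ORDER BY label
"""
result = conn.execute(query).fetchall()
conn.close()
print(result)
```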
In [44]:
# Register the DataFrame as a temporary view named "iris", then query it with SQL
renamed_df.createOrReplaceTempView("iris")
all_df = spark.sql("select * from iris")
all_df.show()

+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|      label|
+------------+-----------+------------+-----------+-----------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|
|         5.4|        3.9|         1.7|        0.4|Iris-setosa|
|         4.6|        3.4|         1.4|        0.3|Iris-setosa|
|         5.0|        3.4|         1.5|        0.2|Iris-setosa|
|         4.4|        2.9|         1.4|        0.2|Iris-setosa|
|         4.9|        3.1|         1.5|        0.1|Iris-setosa|
|         5.4|        3.7|         1.5|        0.2|Iris-setosa|
|         4.8|        3.4|         1.6|        0.2|Iris-setosa|
|         4.8|        3.0|         1.4| 