In [1]:
import pyspark
from pyspark.sql import SparkSession

# Create (or reuse, if one is already running in this kernel) the notebook's
# Spark session. The driver is bound to localhost and the Spark UI port is
# pinned to 4050.
spark = (
    SparkSession.builder
    .appName("test2")
    .config("spark.driver.bindAddress", "localhost")
    .config("spark.ui.port", "4050")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/28 10:16:39 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Record the installed PySpark package version alongside the outputs (helps reproduce results).
pyspark.__version__

'3.5.4'

In [3]:
import os

# Show the directory the kernel runs from: the dataset path in the next cell
# is built relative to it.
# NOTE(review): this echoes an absolute local path (including the OS username)
# into the saved output — consider clearing the output before committing.
print(os.path.abspath(os.curdir))

/Users/pvasud669@apac.comcast.com/repos/learnings/spark


In [4]:

# Load the M&M counts CSV from ./datasets. The first row supplies column names,
# and column types are inferred from the data (see the schema printed below).
mnm_file = os.path.join(os.getcwd(), "datasets", "mnm_dataset.csv")
mnm_dataset = spark.read.csv(mnm_file, header=True, inferSchema=True)

In [5]:
# Peek at the first 10 raw rows without truncating cell values.
mnm_dataset.show(n=10, truncate=False)

+-----+------+-----+
|State|Color |Count|
+-----+------+-----+
|TX   |Red   |20   |
|NV   |Blue  |66   |
|CO   |Blue  |79   |
|OR   |Blue  |71   |
|WA   |Yellow|93   |
|WY   |Blue  |16   |
|CA   |Yellow|53   |
|WA   |Green |60   |
|OR   |Green |71   |
|TX   |Green |68   |
+-----+------+-----+
only showing top 10 rows



In [6]:
# Confirm the column types produced by inferSchema (Count should be numeric for the sum below).
mnm_dataset.printSchema()

root
 |-- State: string (nullable = true)
 |-- Color: string (nullable = true)
 |-- Count: integer (nullable = true)



In [7]:
# Total M&M count per (State, Color), ordered by State then Color.
#
# Import PySpark's functions module under an alias rather than
# `from pyspark.sql.functions import count, sum`. That form rebinds Python's
# builtin `sum` for the rest of the kernel session, which breaks any later plain
# `sum([...])`. It also imported `count`, which was never used.
from pyspark.sql import functions as F

agg_data = (
    mnm_dataset
    .groupBy("State", "Color")
    .agg(F.sum("Count").alias("Total_count"))
    .orderBy("State", "Color")
)

In [8]:
# Display the first 20 aggregated rows (20 is what show() prints by default).
agg_data.show(n=20)

+-----+------+-----------+
|State| Color|Total_count|
+-----+------+-----------+
|   AZ|  Blue|      89971|
|   AZ| Brown|      92287|
|   AZ| Green|      91882|
|   AZ|Orange|      91684|
|   AZ|   Red|      90042|
|   AZ|Yellow|      90946|
|   CA|  Blue|      89123|
|   CA| Brown|      95762|
|   CA| Green|      93505|
|   CA|Orange|      90311|
|   CA|   Red|      91527|
|   CA|Yellow|     100956|
|   CO|  Blue|      93412|
|   CO| Brown|      93692|
|   CO| Green|      93724|
|   CO|Orange|      90971|
|   CO|   Red|      89465|
|   CO|Yellow|      95038|
|   NM|  Blue|      90150|
|   NM| Brown|      93447|
+-----+------+-----------+
only showing top 20 rows



In [9]:
# Keep only Texas rows, with the colors having the largest totals first.
tx_agg_data = (
    agg_data
    .filter(agg_data["State"] == "TX")
    .sort("Total_count", ascending=False)
)

In [10]:
# Display every Texas color total (at most 20 rows, show()'s default limit).
tx_agg_data.show(n=20)

+-----+------+-----------+
|State| Color|Total_count|
+-----+------+-----------+
|   TX| Green|      95753|
|   TX|   Red|      95404|
|   TX|Yellow|      93819|
|   TX|Orange|      92315|
|   TX| Brown|      90736|
|   TX|  Blue|      88466|
+-----+------+-----------+



In [11]:
# The first row of the descending sort is the Texas color with the highest total.
tx_agg_data.show(n=1)

+-----+-----+-----------+
|State|Color|Total_count|
+-----+-----+-----------+
|   TX|Green|      95753|
+-----+-----+-----------+
only showing top 1 row

