In [1]:
import sys
import random
import csv

In [2]:
def get_random_choice(lst):
    """Pick one element of ``lst`` uniformly at random.

    Thin wrapper over ``random.choice``: it draws from the module-level
    ``random`` generator, so results are reproducible after ``random.seed``.
    An empty ``lst`` makes ``random.choice`` raise ``IndexError``.
    """
    chosen = random.choice(lst)
    return chosen

In [3]:
states = ["CA", "WA", "TX", "NV", "CO", "OR", "AZ", "WY", "NM", "UT"]
colors = ["Brown", "Blue", "Orange", "Yellow", "Green", "Red"]
fieldnames = ['State', 'Color', 'Count']

# Number of data rows to generate; the header row is written in addition.
entries = 10000
dataset_fn = "mnm_dataset.csv"

# Seed the module-level RNG (used by get_random_choice and randint below) so
# re-running this cell regenerates exactly the same dataset.
SEED = 42
random.seed(SEED)

# newline='' is required by the csv module so that it controls line endings
# itself; without it, Windows output gets a blank line after every row.
with open(dataset_fn, mode='w', newline='') as dataset_file:
    dataset_writer = csv.writer(dataset_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    dataset_writer.writerow(fieldnames)
    # range(entries), not range(1, entries): write exactly `entries` data rows.
    for _ in range(entries):
        dataset_writer.writerow([get_random_choice(states), get_random_choice(colors), random.randint(10, 100)])
print("Wrote %d entries (plus header) in %s file" % (entries, dataset_fn))

Wrote 10000 lines in mnm_dataset.csv file


In [5]:
from pyspark.sql import SparkSession

In [6]:
spark = (SparkSession
    .builder
    .appName("PythonMnMCount")
    .getOrCreate())


def sum_counts_by_state_color(df):
    """Total the ``Count`` column per (State, Color) pair, largest totals first.

    The result has columns ``State``, ``Color`` and ``sum(Count)`` (the column
    name Spark generates for ``GroupedData.sum("Count")``).
    """
    return (df.groupBy("State", "Color")
              .sum("Count")
              .orderBy("sum(Count)", ascending=False))


# read the file into a Spark DataFrame; the header row supplies the column
# names and inferSchema lets Spark infer column types (so Count is numeric
# and can be summed) instead of reading every column as a string
mnm_df = (spark.read.format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
    .load(dataset_fn))

mnm_df.show(n=5, truncate=False)

# aggregate the counts per state and color, in descending order of total
count_mnm_df = sum_counts_by_state_color(mnm_df)

# show every aggregated (state, color) row: size the display from the actual
# row count rather than a hard-coded 60 (which only matches 10 states x 6 colors)
total_rows = count_mnm_df.count()
count_mnm_df.show(n=total_rows, truncate=False)
print("Total Rows = %d" % (total_rows))

# the same aggregation, restricted to California by filtering first
ca_count_mnm_df = sum_counts_by_state_color(mnm_df.where(mnm_df.State == 'CA'))

# show the resulting aggregation for California
ca_count_mnm_df.show(n=10, truncate=False)

22/06/07 14:09:26 WARN Utils: Your hostname, Pauls-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.4.90 instead (on interface en0)
22/06/07 14:09:26 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/06/07 14:09:26 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


+-----+------+-----+
|State|Color |Count|
+-----+------+-----+
|NV   |Orange|47   |
|CO   |Brown |80   |
|WY   |Blue  |40   |
|WA   |Brown |55   |
|WA   |Orange|44   |
+-----+------+-----+
only showing top 5 rows

+-----+------+----------+
|State|Color |sum(Count)|
+-----+------+----------+
|AZ   |Blue  |10938     |
|WA   |Blue  |10789     |
|CO   |Brown |10567     |
|WY   |Brown |10419     |
|WA   |Yellow|10150     |
|UT   |Red   |10130     |
|NV   |Orange|10079     |
|AZ   |Orange|10060     |
|TX   |Blue  |10023     |
|TX   |Green |9984      |
|NM   |Red   |9970      |
|WY   |Green |9959      |
|NM   |Yellow|9955      |
|NM   |Blue  |9874      |
|OR   |Blue  |9720      |
|NV   |Yellow|9666      |
|TX   |Brown |9597      |
|WY   |Blue  |9572      |
|AZ   |Green |9561      |
|CO   |Yellow|9540      |
|TX   |Yellow|9531      |
|CO   |Blue  |9497      |
|OR   |Red   |9429      |
|CA   |Blue  |9422      |
|WY   |Red   |9369      |
|NV   |Blue  |9331      |
|OR   |Yellow|9202      |
|CA   