In [None]:
import random
from datetime import datetime, timedelta
from typing import List, Tuple

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import avg

spark = SparkSession.builder.appName("sql").getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/10/14 23:16:40 WARN Utils: Your hostname, kenans-MacBook-Pro.local, resolves to a loopback address: 127.0.0.1; using 192.168.1.102 instead (on interface en0)
25/10/14 23:16:40 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/14 23:16:41 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/10/14 23:16:41 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/10/14 23:16:41 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [None]:
def random_date(start: datetime, end: datetime) -> datetime:
    """Generate a random datetime between `start` and `end`"""
    delta = end - start
    random_second = random.randint(0, int(delta.total_seconds()))
    return start + timedelta(seconds=random_second)


def create_sample_data() -> DataFrame:
    """Create a sample DataFrame with retail sales data"""
    products: List[Tuple[str, float]] = [
        ("Apple", 1.0),
        ("Orange", 2.0),
        ("Banana", 3.0),
        ("Kiwi", 4.0),
        ("Grape", 5.0),
        ("Strawberry", 6.0),
        ("Blueberry", 7.0),
        ("Blackberry", 8.0),
        ("Raspberry", 9.0),
        ("Pineapple", 10.0),
    ]
    data: List[Tuple[int, str, int, float, datetime, int]] = [
        (
            i,
            random.choice(products)[0],
            random.randint(1, 10),
            random.choice(products)[1],
            random_date(datetime(2021, 1, 1), datetime(2021, 12, 31)),
            random.randint(1000, 1099),
        )
        for i in range(10)
    ]

    schema = ["transaction_id", "product_name", "quantity", "product_price", "sales_date", "customer_id"]
    return spark.createDataFrame(data, schema=schema)


def calculate_total_sales(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product"""
    return dff.groupBy("product_name").sum("product_price").withColumnRenamed("sum(product_price)", "total_sales")


def calculate_total_sales_by_month(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product by month"""
    return dff.groupBy("product_name", "sales_date").sum("product_price").withColumnRenamed("sum(product_price)", "total_sales")


def calculate_total_sales_by_customer_and_month(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product by customer and month"""
    return (
        dff.groupBy("product_name", "customer_id", "sales_date")
        .sum("product_price")
        .withColumnRenamed("sum(product_price)", "total_sales")
    )


def calculate_total_sales_by_customer_and_product(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product by customer and product"""
    return dff.groupBy("product_name", "customer_id").sum("product_price").withColumnRenamed("sum(product_price)", "total_sales")


def calculate_average_quantity_by_product(dff: DataFrame) -> DataFrame:
    """Calculate average quantity for each product"""
    return dff.groupBy("product_name").avg("quantity").withColumnRenamed("avg(quantity)", "average_quantity")


def calculate_average_quantity(dff: DataFrame) -> DataFrame:
    """Calculate average quantity sold per transaction."""
    return dff.agg(avg("quantity").alias("average_quantity"))


def count_transactions_per_customer(dff: DataFrame) -> DataFrame:
    """Count the number of transactions per customer"""
    return dff.groupBy("customer_id").count().withColumnRenamed("count", "transaction_count")


def calculate_total_sales_by_customer(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each customer"""
    return dff.groupBy("customer_id").sum("product_price").withColumnRenamed("sum(product_price)", "total_sales")

In [3]:
# Creating and Analyzing the Dataset
df = create_sample_data()
df.show()

+--------------+------------+--------+-------------+-------------------+-----------+
|transaction_id|product_name|quantity|product_price|         sales_date|customer_id|
+--------------+------------+--------+-------------+-------------------+-----------+
|             0|  Strawberry|       7|          2.0|2021-04-26 22:12:47|       1096|
|             1|   Raspberry|       1|          4.0|2021-04-14 21:40:41|       1027|
|             2|      Banana|      10|          4.0|2021-07-10 07:40:14|       1031|
|             3|       Grape|       2|          1.0|2021-03-22 21:44:27|       1073|
|             4|  Blackberry|       6|          6.0|2021-12-23 09:11:36|       1092|
|             5|        Kiwi|       9|          5.0|2021-04-30 15:00:14|       1078|
|             6|  Strawberry|      10|          8.0|2021-11-30 17:49:28|       1041|
|             7|   Pineapple|       1|         10.0|2021-10-20 08:48:06|       1017|
|             8|  Blackberry|       2|          4.0|2021-09-23 21

                                                                                

In [4]:
calculate_total_sales(df).show()

+------------+-----------+
|product_name|total_sales|
+------------+-----------+
|  Strawberry|       10.0|
|   Raspberry|        4.0|
|      Banana|        4.0|
|       Grape|        1.0|
|  Blackberry|       10.0|
|        Kiwi|        5.0|
|   Pineapple|       18.0|
+------------+-----------+



In [5]:
calculate_total_sales_by_customer(df).show()

+-----------+-----------+
|customer_id|total_sales|
+-----------+-----------+
|       1096|        2.0|
|       1027|        4.0|
|       1031|        4.0|
|       1073|        1.0|
|       1092|        6.0|
|       1078|        5.0|
|       1041|        8.0|
|       1017|       10.0|
|       1023|        4.0|
|       1089|        8.0|
+-----------+-----------+



In [6]:
calculate_average_quantity(df).show()

+----------------+
|average_quantity|
+----------------+
|             5.7|
+----------------+



In [7]:
calculate_average_quantity_by_product(df).show()

+------------+----------------+
|product_name|average_quantity|
+------------+----------------+
|  Strawberry|             8.5|
|   Raspberry|             1.0|
|      Banana|            10.0|
|       Grape|             2.0|
|  Blackberry|             4.0|
|        Kiwi|             9.0|
|   Pineapple|             5.0|
+------------+----------------+



In [8]:
count_transactions_per_customer(df).show()

+-----------+-----------------+
|customer_id|transaction_count|
+-----------+-----------------+
|       1096|                1|
|       1027|                1|
|       1031|                1|
|       1073|                1|
|       1092|                1|
|       1078|                1|
|       1041|                1|
|       1017|                1|
|       1023|                1|
|       1089|                1|
+-----------+-----------------+

