In [9]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import avg
from datetime import datetime, timedelta
import random
from typing import List, Tuple

spark = SparkSession.builder.appName('sql').getOrCreate()

23/11/15 09:41:48 WARN Utils: Your hostname, kenans-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 172.20.10.2 instead (on interface en0)
23/11/15 09:41:48 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/11/15 09:41:49 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/11/15 09:41:50 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [10]:
def random_date(start: datetime, end: datetime) -> datetime:
    """Generate a random datetime between `start` and `end`"""
    delta = end - start
    random_second = random.randint(0, int(delta.total_seconds()))
    return start + timedelta(seconds=random_second)


def create_sample_data() -> DataFrame:
    """Create a sample DataFrame with retail sales data"""
    products: List[Tuple[str, float]] = [
        ("Apple", 1.0),
        ("Orange", 2.0),
        ("Banana", 3.0),
        ("Kiwi", 4.0),
        ("Grape", 5.0),
        ("Strawberry", 6.0),
        ("Blueberry", 7.0),
        ("Blackberry", 8.0),
        ("Raspberry", 9.0),
        ("Pineapple", 10.0),
    ]
    data: List[Tuple[int, str, int, float, datetime, int]] = [(i,
                                                               random.choice(products)[0],
                                                               random.randint(1, 10),
                                                               random.choice(products)[1],
                                                               random_date(datetime(2021, 1, 1),
                                                                           datetime(2021, 12, 31)),
                                                               random.randint(1000, 1099))
                                                              for i in range(10)]

    schema = [
        "transaction_id",
        "product_name",
        "quantity",
        "product_price",
        "sales_date",
        "customer_id"]
    return spark.createDataFrame(data, schema=schema)


def calculate_total_sales(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product"""
    return dff.groupBy("product_name").sum("product_price").withColumnRenamed("sum(product_price)", "total_sales")


def calculate_total_sales_by_month(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product by month"""
    return dff.groupBy("product_name", "sales_date").sum("product_price").withColumnRenamed("sum(product_price)",
                                                                                            "total_sales")


def calculate_total_sales_by_customer_and_month(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product by customer and month"""
    return dff.groupBy("product_name", "customer_id", "sales_date").sum("product_price").withColumnRenamed(
        "sum(product_price)", "total_sales")


def calculate_total_sales_by_customer_and_product(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product by customer and product"""
    return dff.groupBy("product_name", "customer_id").sum("product_price").withColumnRenamed("sum(product_price)",
                                                                                             "total_sales")


def calculate_average_quantity_by_product(dff: DataFrame) -> DataFrame:
    """Calculate average quantity for each product"""
    return dff.groupBy("product_name").avg("quantity").withColumnRenamed("avg(quantity)", "average_quantity")


def calculate_average_quantity(dff: DataFrame) -> DataFrame:
    """Calculate average quantity sold per transaction."""
    return dff.agg(avg("quantity").alias("average_quantity"))


def count_transactions_per_customer(dff: DataFrame) -> DataFrame:
    """Count the number of transactions per customer"""
    return dff.groupBy("customer_id").count().withColumnRenamed("count", "transaction_count")


def calculate_total_sales_by_customer(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each customer"""
    return dff.groupBy("customer_id").sum("product_price").withColumnRenamed("sum(product_price)", "total_sales")

In [18]:
# Creating and Analyzing the Dataset
df = create_sample_data()
df.show()

+--------------+------------+--------+-------------+-------------------+-----------+
|transaction_id|product_name|quantity|product_price|         sales_date|customer_id|
+--------------+------------+--------+-------------+-------------------+-----------+
|             0|  Blackberry|       5|          3.0|2021-07-23 14:47:55|       1049|
|             1|  Strawberry|       8|          3.0|2021-02-01 03:33:24|       1060|
|             2|       Apple|       8|          2.0|2021-09-11 15:01:31|       1033|
|             3|       Apple|       3|          9.0|2021-04-19 08:54:12|       1099|
|             4|        Kiwi|       6|          8.0|2021-08-20 10:19:57|       1024|
|             5|      Banana|       9|          4.0|2021-03-16 11:30:43|       1020|
|             6|        Kiwi|       9|          4.0|2021-07-11 01:52:49|       1073|
|             7|       Apple|       6|          8.0|2021-07-07 06:32:21|       1096|
|             8|        Kiwi|       7|          8.0|2021-07-15 07

In [19]:
calculate_total_sales(df).show()

+------------+-----------+
|product_name|total_sales|
+------------+-----------+
|      Banana|        4.0|
|  Strawberry|        3.0|
|        Kiwi|       30.0|
|  Blackberry|        3.0|
|       Apple|       19.0|
+------------+-----------+


In [20]:
calculate_total_sales_by_customer(df).show()

+-----------+-----------+
|customer_id|total_sales|
+-----------+-----------+
|       1073|        4.0|
|       1099|        9.0|
|       1060|        3.0|
|       1024|        8.0|
|       1000|        8.0|
|       1096|        8.0|
|       1020|        4.0|
|       1049|        3.0|
|       1017|       10.0|
|       1033|        2.0|
+-----------+-----------+


In [21]:
calculate_average_quantity(df).show()

+----------------+
|average_quantity|
+----------------+
|             6.6|
+----------------+


In [22]:
calculate_average_quantity_by_product(df).show()

+------------+-----------------+
|product_name| average_quantity|
+------------+-----------------+
|      Banana|              9.0|
|  Strawberry|              8.0|
|        Kiwi|             6.75|
|  Blackberry|              5.0|
|       Apple|5.666666666666667|
+------------+-----------------+


In [23]:
count_transactions_per_customer(df).show()

+-----------+-----------------+
|customer_id|transaction_count|
+-----------+-----------------+
|       1073|                1|
|       1099|                1|
|       1060|                1|
|       1024|                1|
|       1000|                1|
|       1096|                1|
|       1020|                1|
|       1049|                1|
|       1017|                1|
|       1033|                1|
+-----------+-----------------+
