In [1]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import avg
from datetime import datetime, timedelta
import random
from typing import List, Tuple

spark = SparkSession.builder.appName('sql').getOrCreate()

25/10/14 12:44:45 WARN Utils: Your hostname, kenans-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.0.0.2 instead (on interface en0)
25/10/14 12:44:45 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/14 12:44:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/10/14 12:44:46 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/10/14 12:44:46 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [2]:
def random_date(start: datetime, end: datetime) -> datetime:
    """Generate a random datetime between `start` and `end`"""
    delta = end - start
    random_second = random.randint(0, int(delta.total_seconds()))
    return start + timedelta(seconds=random_second)


def create_sample_data() -> DataFrame:
    """Create a sample DataFrame with retail sales data"""
    products: List[Tuple[str, float]] = [
        ("Apple", 1.0),
        ("Orange", 2.0),
        ("Banana", 3.0),
        ("Kiwi", 4.0),
        ("Grape", 5.0),
        ("Strawberry", 6.0),
        ("Blueberry", 7.0),
        ("Blackberry", 8.0),
        ("Raspberry", 9.0),
        ("Pineapple", 10.0),
    ]
    data: List[Tuple[int, str, int, float, datetime, int]] = [(i,
                                                               random.choice(products)[0],
                                                               random.randint(1, 10),
                                                               random.choice(products)[1],
                                                               random_date(datetime(2021, 1, 1),
                                                                           datetime(2021, 12, 31)),
                                                               random.randint(1000, 1099))
                                                              for i in range(10)]

    schema = [
        "transaction_id",
        "product_name",
        "quantity",
        "product_price",
        "sales_date",
        "customer_id"]
    return spark.createDataFrame(data, schema=schema)


def calculate_total_sales(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product"""
    return dff.groupBy("product_name").sum("product_price").withColumnRenamed("sum(product_price)", "total_sales")


def calculate_total_sales_by_month(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product by month"""
    return dff.groupBy("product_name", "sales_date").sum("product_price").withColumnRenamed("sum(product_price)",
                                                                                            "total_sales")


def calculate_total_sales_by_customer_and_month(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product by customer and month"""
    return dff.groupBy("product_name", "customer_id", "sales_date").sum("product_price").withColumnRenamed(
        "sum(product_price)", "total_sales")


def calculate_total_sales_by_customer_and_product(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product by customer and product"""
    return dff.groupBy("product_name", "customer_id").sum("product_price").withColumnRenamed("sum(product_price)",
                                                                                             "total_sales")


def calculate_average_quantity_by_product(dff: DataFrame) -> DataFrame:
    """Calculate average quantity for each product"""
    return dff.groupBy("product_name").avg("quantity").withColumnRenamed("avg(quantity)", "average_quantity")


def calculate_average_quantity(dff: DataFrame) -> DataFrame:
    """Calculate average quantity sold per transaction."""
    return dff.agg(avg("quantity").alias("average_quantity"))


def count_transactions_per_customer(dff: DataFrame) -> DataFrame:
    """Count the number of transactions per customer"""
    return dff.groupBy("customer_id").count().withColumnRenamed("count", "transaction_count")


def calculate_total_sales_by_customer(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each customer"""
    return dff.groupBy("customer_id").sum("product_price").withColumnRenamed("sum(product_price)", "total_sales")

In [3]:
# Creating and Analyzing the Dataset
df = create_sample_data()
df.show()

                                                                                

+--------------+------------+--------+-------------+-------------------+-----------+
|transaction_id|product_name|quantity|product_price|         sales_date|customer_id|
+--------------+------------+--------+-------------+-------------------+-----------+
|             0|  Strawberry|       8|          3.0|2021-04-28 15:03:20|       1021|
|             1|   Pineapple|      10|          6.0|2021-11-13 14:10:53|       1039|
|             2|   Pineapple|      10|          4.0|2021-09-26 10:36:22|       1018|
|             3|  Blackberry|      10|          3.0|2021-11-29 11:32:47|       1046|
|             4|        Kiwi|       6|          7.0|2021-12-27 09:34:32|       1010|
|             5|       Grape|       2|          3.0|2021-10-18 10:40:46|       1084|
|             6|  Blackberry|       7|          2.0|2021-01-02 23:56:54|       1012|
|             7|  Blackberry|       8|          6.0|2021-01-25 01:47:02|       1079|
|             8|   Pineapple|       7|          7.0|2021-12-15 20

In [4]:
calculate_total_sales(df).show()

+------------+-----------+
|product_name|total_sales|
+------------+-----------+
|  Strawberry|        3.0|
|   Pineapple|       17.0|
|  Blackberry|       11.0|
|        Kiwi|        7.0|
|       Grape|        3.0|
|   Raspberry|        5.0|
+------------+-----------+



In [5]:
calculate_total_sales_by_customer(df).show()

+-----------+-----------+
|customer_id|total_sales|
+-----------+-----------+
|       1021|        3.0|
|       1039|        6.0|
|       1018|        4.0|
|       1046|        3.0|
|       1010|        7.0|
|       1084|        3.0|
|       1012|        2.0|
|       1079|        6.0|
|       1058|        7.0|
|       1092|        5.0|
+-----------+-----------+



In [6]:
calculate_average_quantity(df).show()

+----------------+
|average_quantity|
+----------------+
|             6.9|
+----------------+



In [7]:
calculate_average_quantity_by_product(df).show()

+------------+-----------------+
|product_name| average_quantity|
+------------+-----------------+
|  Strawberry|              8.0|
|   Pineapple|              9.0|
|  Blackberry|8.333333333333334|
|        Kiwi|              6.0|
|       Grape|              2.0|
|   Raspberry|              1.0|
+------------+-----------------+



In [8]:
count_transactions_per_customer(df).show()

+-----------+-----------------+
|customer_id|transaction_count|
+-----------+-----------------+
|       1021|                1|
|       1039|                1|
|       1018|                1|
|       1046|                1|
|       1010|                1|
|       1084|                1|
|       1012|                1|
|       1079|                1|
|       1058|                1|
|       1092|                1|
+-----------+-----------------+

