In [None]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col, sum, to_date, collect_list, size
from datetime import datetime, timedelta
import random
from typing import List, Tuple

# Spark session shared by every cell below; the application name is "sql".
spark = (
    SparkSession.builder
    .appName("sql")
    .getOrCreate()
)

In [119]:
class ColumnNames:
    """Column-name string constants used by the sales DataFrames in this file."""

    # Columns of the generated raw sales rows.
    TRANSACTION_ID = "transaction_id"
    CUSTOMER_ID = "customer_id"
    SALES_DATE = "sales_date"
    PRODUCT_NAME = "product_name"
    PRODUCT_PRICE = "product_price"
    QUANTITY = "quantity"

    # Columns derived per row.
    TOTAL_PRICE = "total_price"
    VAT = "vat"

    # Aliases given to aggregated results.
    SUB_TOTAL = "sub_total"
    SUM_OF_VAT = "sum_of_vat"
    TOTAL_DUE = "total_due"
    TOTAL_SALES = "total_sales"
    PRODUCT_NAMES = "product_names"
    COUNT = "count"


def random_date(start: datetime, end: datetime) -> datetime:
    """Return 'start' moved forward by a random whole number of seconds, never past 'end'."""
    # The offset is drawn inclusively from 0 up to the full span in whole seconds.
    span_seconds = int((end - start).total_seconds())
    offset = timedelta(seconds=random.randint(0, span_seconds))
    return start + offset


def generate_data(num_rows: int = 10, num_dates: int = 5) -> DataFrame:
    """Generate random retail sales data.

    Every row holds a transaction id (0..num_rows-1), a product name and its
    price picked from a fixed catalogue, a quantity between 1 and 10, a sales
    timestamp picked from a small pool of random timestamps, and a customer id
    between 1000 and 1099.

    Args:
        num_rows: Number of transactions to generate. Defaults to 10.
        num_dates: Number of random timestamps (between 2020-01-01 and
            2020-12-31) in the pool that transactions draw from, so several
            transactions can share a timestamp. Defaults to 5.

    Returns:
        A DataFrame with the columns transaction_id, product_name, quantity,
        product_price, sales_date and customer_id.

    Raises:
        ValueError: If 'num_rows' or 'num_dates' is smaller than 1.
    """
    # An empty date pool would break random.choice, and an empty row list gives
    # createDataFrame nothing to infer column types from (only names are passed).
    if num_rows < 1 or num_dates < 1:
        raise ValueError("num_rows and num_dates must both be at least 1")

    products: List[Tuple[str, float]] = [
        ("Apple", 1.0),
        ("Orange", 2.0),
        ("Banana", 3.0),
        ("Kiwi", 4.0),
        ("Grape", 5.0),
        ("Strawberry", 6.0),
        ("Blueberry", 7.0),
        ("Blackberry", 8.0),
        ("Raspberry", 9.0),
        ("Pineapple", 10.0),
    ]
    random_dates: List[datetime] = [
        random_date(datetime(2020, 1, 1), datetime(2020, 12, 31))
        for _ in range(num_dates)
    ]
    data: List[Tuple[int, str, int, float, datetime, int]] = []
    for transaction_id in range(num_rows):
        product_name, product_price = random.choice(products)
        quantity = random.randint(1, 10)
        sales_date = random.choice(random_dates)
        customer_id = random.randint(1000, 1099)
        data.append(
            (
                transaction_id,
                product_name,
                quantity,
                product_price,
                sales_date,
                customer_id,
            )
        )

    # Only column names are supplied; the column types are inferred from 'data'.
    schema = [
        ColumnNames.TRANSACTION_ID,
        ColumnNames.PRODUCT_NAME,
        ColumnNames.QUANTITY,
        ColumnNames.PRODUCT_PRICE,
        ColumnNames.SALES_DATE,
        ColumnNames.CUSTOMER_ID,
    ]

    return spark.createDataFrame(data, schema=schema)


def with_total_price(dff: DataFrame) -> DataFrame:
    """Return 'dff' with a total price column: quantity multiplied by product price."""
    line_total = col(ColumnNames.QUANTITY) * col(ColumnNames.PRODUCT_PRICE)
    return dff.withColumn(ColumnNames.TOTAL_PRICE, line_total)


def with_vat(dff: DataFrame, rate: float = 0.2) -> DataFrame:
    """Add a new column to the dataframe with the vat.

    Args:
        dff: DataFrame that already contains the total price column.
        rate: Multiplier applied to the total price to obtain the vat.
            Defaults to 0.2.

    Returns:
        'dff' with an extra vat column equal to total price times 'rate'.
    """
    vat_amount = col(ColumnNames.TOTAL_PRICE) * rate
    return dff.withColumn(ColumnNames.VAT, vat_amount)


def calculate_sub_total(dff: DataFrame) -> float:
    """Calculate the sub total: the sum of the total price column over all rows.

    Args:
        dff: DataFrame that contains the total price column.

    Returns:
        The summed total price, or 0.0 when the sum is NULL (no rows, or
        every total price is NULL).
    """
    aggregated = dff.select(
        sum(col(ColumnNames.TOTAL_PRICE)).alias(ColumnNames.SUB_TOTAL)
    ).first()
    sub_total = aggregated[ColumnNames.SUB_TOTAL]
    # Spark's sum yields NULL (None in Python) when it has nothing to add up;
    # report that as 0.0 so the declared float return type always holds.
    return 0.0 if sub_total is None else sub_total


def calculate_total_vat(dff: DataFrame) -> float:
    """Calculate the total vat: the sum of the vat column over all rows.

    Args:
        dff: DataFrame that contains the vat column.

    Returns:
        The summed vat, or 0.0 when the sum is NULL (no rows, or every vat
        value is NULL).
    """
    aggregated = dff.select(
        sum(col(ColumnNames.VAT)).alias(ColumnNames.SUM_OF_VAT)
    ).first()
    total_vat = aggregated[ColumnNames.SUM_OF_VAT]
    # Spark's sum yields NULL (None in Python) when it has nothing to add up;
    # report that as 0.0 so the declared float return type always holds.
    return 0.0 if total_vat is None else total_vat


def calculate_total_due(dff: DataFrame) -> float:
    """Calculate the total due: the sum over all rows of total price plus vat.

    Args:
        dff: DataFrame that contains both the total price and the vat columns.

    Returns:
        The summed (total price + vat), or 0.0 when the sum is NULL (no rows,
        or every per-row sum is NULL).
    """
    amount_due = col(ColumnNames.TOTAL_PRICE) + col(ColumnNames.VAT)
    aggregated = dff.select(sum(amount_due).alias(ColumnNames.TOTAL_DUE)).first()
    total_due = aggregated[ColumnNames.TOTAL_DUE]
    # Spark's sum yields NULL (None in Python) when it has nothing to add up;
    # report that as 0.0 so the declared float return type always holds.
    return 0.0 if total_due is None else total_due


def calculate_total_sales(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product.

    Args:
        dff: DataFrame that contains the product name and quantity columns.

    Returns:
        One row per product name with the summed quantity in the total sales
        column.
    """
    # Alias the aggregate directly rather than renaming Spark's generated
    # "sum(quantity)" column: a rename whose source name does not match is
    # silently skipped, so the literal could drift from ColumnNames.QUANTITY.
    return dff.groupBy(ColumnNames.PRODUCT_NAME).agg(
        sum(ColumnNames.QUANTITY).alias(ColumnNames.TOTAL_SALES)
    )


def calculate_total_sales_by_sales_date(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each sales date, sorted by date.

    Args:
        dff: DataFrame that contains the sales date and quantity columns.

    Returns:
        One row per sales date (the sales date column passed through to_date)
        with the summed quantity in the total sales column, sorted by sales
        date.
    """
    sales_day = to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE)
    # Alias the aggregate directly rather than renaming Spark's generated
    # "sum(quantity)" column: a rename whose source name does not match is
    # silently skipped, so the literal could drift from ColumnNames.QUANTITY.
    return (
        dff.groupBy(sales_day)
        .agg(sum(ColumnNames.QUANTITY).alias(ColumnNames.TOTAL_SALES))
        .sort(ColumnNames.SALES_DATE)
    )


def calculate_total_sales_by_sales_date_and_product_name(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each (sales date, product name) pair.

    Args:
        dff: DataFrame that contains the sales date, product name and
            quantity columns.

    Returns:
        One row per sales date (the sales date column passed through to_date)
        and product name, with the summed quantity in the total sales column.
        No ordering is applied.
    """
    sales_day = to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE)
    # Alias the aggregate directly rather than renaming Spark's generated
    # "sum(quantity)" column: a rename whose source name does not match is
    # silently skipped, so the literal could drift from ColumnNames.QUANTITY.
    return dff.groupBy(sales_day, ColumnNames.PRODUCT_NAME).agg(
        sum(ColumnNames.QUANTITY).alias(ColumnNames.TOTAL_SALES)
    )


def calculate_total_sales_by_sales_date_and_product_name_and_customer_id(
    dff: DataFrame,
) -> DataFrame:
    """Calculate total sales for each (sales date, product name, customer id).

    Args:
        dff: DataFrame that contains the sales date, product name, customer
            id and quantity columns.

    Returns:
        One row per sales date (the sales date column passed through to_date),
        product name and customer id, with the summed quantity in the total
        sales column. No ordering is applied.
    """
    sales_day = to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE)
    # Alias the aggregate directly rather than renaming Spark's generated
    # "sum(quantity)" column: a rename whose source name does not match is
    # silently skipped, so the literal could drift from ColumnNames.QUANTITY.
    return dff.groupBy(
        sales_day,
        ColumnNames.PRODUCT_NAME,
        ColumnNames.CUSTOMER_ID,
    ).agg(sum(ColumnNames.QUANTITY).alias(ColumnNames.TOTAL_SALES))


def list_product_names_by_sales_date(dff: DataFrame) -> DataFrame:
    """Collect, for each sales date, the product names sold on that date.

    The result has one row per sales date with the names gathered into a list
    column, ordered by sales date.
    """
    sales_day = to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE)
    names_per_day = collect_list(ColumnNames.PRODUCT_NAME).alias(
        ColumnNames.PRODUCT_NAMES
    )
    grouped = dff.groupBy(sales_day).agg(names_per_day)
    return grouped.orderBy(ColumnNames.SALES_DATE)


def list_product_names_and_counts_by_sales_date(dff: DataFrame) -> DataFrame:
    """List product names with their summed quantity, grouped by date.

    Despite the name, the value reported per product is the sum of the
    quantity column (stored in the total sales column), not a row count.

    Args:
        dff: DataFrame that contains the sales date, product name and
            quantity columns.

    Returns:
        One row per sales date (the sales date column passed through to_date)
        and product name, with the summed quantity in the total sales column,
        ordered by sales date.
    """
    sales_day = to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE)
    # Alias the aggregate directly rather than renaming Spark's generated
    # "sum(quantity)" column: a rename whose source name does not match is
    # silently skipped, so the literal could drift from ColumnNames.QUANTITY.
    return (
        dff.groupBy(sales_day, ColumnNames.PRODUCT_NAME)
        .agg(sum(ColumnNames.QUANTITY).alias(ColumnNames.TOTAL_SALES))
        .orderBy(ColumnNames.SALES_DATE)
    )


def list_product_names_by_sales_date_v2(dff: DataFrame) -> DataFrame:
    """Summarise each sales date: product names, how many names, total quantity.

    The result has one row per sales date, ordered by sales date, with the
    columns sales date, product names (list), count (size of that list) and
    total sales (summed quantity).
    """
    sales_day = to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE)
    per_day = dff.groupBy(sales_day).agg(
        collect_list(ColumnNames.PRODUCT_NAME).alias(ColumnNames.PRODUCT_NAMES),
        sum(ColumnNames.QUANTITY).alias(ColumnNames.TOTAL_SALES),
    )
    # The count is derived from the collected list, one entry per grouped row.
    counted = per_day.withColumn(ColumnNames.COUNT, size(ColumnNames.PRODUCT_NAMES))
    ordered = counted.orderBy(ColumnNames.SALES_DATE)
    output_columns = [
        ColumnNames.SALES_DATE,
        ColumnNames.PRODUCT_NAMES,
        ColumnNames.COUNT,
        ColumnNames.TOTAL_SALES,
    ]
    return ordered.select(*output_columns)


# Build the sample sales data, then add the total price and vat (rate 0.2) columns.
df = generate_data()

df_extended = with_vat(with_total_price(df), 0.2)

In [53]:
# Display the extended rows sorted by sales date (ascending), without truncating cell values.
df_extended.sort(ColumnNames.SALES_DATE, ascending=True).show(truncate=False)

+--------------+------------+--------+-------------+-------------------+-----------+-----------+------------------+
|transaction_id|product_name|quantity|product_price|sales_date         |customer_id|total_price|vat               |
+--------------+------------+--------+-------------+-------------------+-----------+-----------+------------------+
|3             |Orange      |5       |2.0          |2020-02-22 09:46:37|1039       |10.0       |2.0               |
|9             |Blackberry  |9       |8.0          |2020-02-22 09:46:37|1083       |72.0       |14.4              |
|4             |Orange      |9       |2.0          |2020-04-10 01:51:23|1011       |18.0       |3.6               |
|6             |Kiwi        |2       |4.0          |2020-07-16 07:43:04|1029       |8.0        |1.6               |
|1             |Raspberry   |8       |9.0          |2020-07-16 07:43:04|1010       |72.0       |14.4              |
|7             |Blueberry   |6       |7.0          |2020-07-16 07:43:04|

In [54]:
# Display the summed quantity per product.
calculate_total_sales(df_extended).show(truncate=False)

                                                                                

+------------+-----------+
|product_name|total_sales|
+------------+-----------+
|Grape       |6          |
|Orange      |14         |
|Raspberry   |8          |
|Strawberry  |2          |
|Kiwi        |8          |
|Blackberry  |9          |
|Blueberry   |6          |
+------------+-----------+


In [57]:
# Display the summed quantity per sales date, sorted by date.
calculate_total_sales_by_sales_date(df_extended).show(truncate=False)

                                                                                

+----------+-----------+
|sales_date|total_sales|
+----------+-----------+
|2020-02-06|11         |
|2020-03-18|5          |
|2020-04-19|12         |
|2020-05-07|10         |
|2020-12-02|15         |
+----------+-----------+


In [58]:
# Display the summed quantity per (sales date, product name); no ordering is applied.
calculate_total_sales_by_sales_date_and_product_name(df_extended).show(truncate=False)

+----------+------------+-----------+
|sales_date|product_name|total_sales|
+----------+------------+-----------+
|2020-05-07|Pineapple   |8          |
|2020-12-02|Blackberry  |7          |
|2020-04-19|Apple       |4          |
|2020-04-19|Banana      |8          |
|2020-03-18|Grape       |5          |
|2020-12-02|Orange      |8          |
|2020-02-06|Blueberry   |11         |
|2020-05-07|Orange      |2          |
+----------+------------+-----------+


In [59]:
# Display the summed quantity per (sales date, product name, customer id) with show()'s default formatting.
calculate_total_sales_by_sales_date_and_product_name_and_customer_id(df_extended).show()

+----------+------------+-----------+-----------+
|sales_date|product_name|customer_id|total_sales|
+----------+------------+-----------+-----------+
|2020-12-02|      Orange|       1094|          2|
|2020-05-07|      Orange|       1063|          2|
|2020-05-07|   Pineapple|       1076|          8|
|2020-04-19|       Apple|       1098|          4|
|2020-12-02|  Blackberry|       1034|          7|
|2020-03-18|       Grape|       1076|          5|
|2020-04-19|      Banana|       1081|          8|
|2020-02-06|   Blueberry|       1069|          9|
|2020-12-02|      Orange|       1029|          6|
|2020-02-06|   Blueberry|       1067|          2|
+----------+------------+-----------+-----------+


In [60]:
# Compute the three scalar totals from the extended DataFrame.
sub_total = calculate_sub_total(df_extended)
total_vat = calculate_total_vat(df_extended)
total_due = calculate_total_due(df_extended)

# Print each total with its label; the values are printed unformatted.
print(f"Sub Total: {sub_total}")
print(f"Total vat: {total_vat}")
print(f"Total due: {total_due}")

Sub Total: 286.0
Total vat: 57.2
Total due: 343.20000000000005


In [73]:
# Display the list of product names collected per sales date, ordered by date.
list_product_names_by_sales_date(df_extended).show(truncate=False)

[Stage 402:>                                                      (0 + 10) / 10]

+----------+-----------------------------------------+
|sales_date|product_names                            |
+----------+-----------------------------------------+
|2020-05-08|[Strawberry, Raspberry, Kiwi, Strawberry]|
|2020-05-26|[Orange, Banana, Blackberry, Blueberry]  |
|2020-07-22|[Banana]                                 |
|2020-08-21|[Grape]                                  |
+----------+-----------------------------------------+


                                                                                

In [75]:
# Display the summed quantity per (sales date, product name), ordered by date.
list_product_names_and_counts_by_sales_date(df_extended).show(truncate=False)



+----------+------------+-----------+
|sales_date|product_name|total_sales|
+----------+------------+-----------+
|2020-05-08|Strawberry  |13         |
|2020-05-08|Kiwi        |9          |
|2020-05-08|Raspberry   |3          |
|2020-05-26|Blackberry  |7          |
|2020-05-26|Banana      |10         |
|2020-05-26|Orange      |1          |
|2020-05-26|Blueberry   |8          |
|2020-07-22|Banana      |3          |
|2020-08-21|Grape       |4          |
+----------+------------+-----------+


                                                                                

In [120]:
# Display per-date product names, the size of that list, and the summed quantity.
list_product_names_by_sales_date_v2(df_extended).show(truncate=False)

+----------+-------------------------------------+-----+-----------+
|sales_date|product_names                        |count|total_sales|
+----------+-------------------------------------+-----+-----------+
|2020-02-09|[Blueberry, Banana]                  |2    |6          |
|2020-06-15|[Raspberry, Blackberry]              |2    |9          |
|2020-06-28|[Pineapple]                          |1    |5          |
|2020-07-19|[Grape, Strawberry, Apple, Raspberry]|4    |13         |
|2020-08-13|[Blackberry]                         |1    |10         |
+----------+-------------------------------------+-----+-----------+
