In [1]:
import random
from datetime import datetime, timedelta
from typing import List, Tuple

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col, collect_list, size, sum, to_date

spark = SparkSession.builder.appName("sql").getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/10/14 23:16:19 WARN Utils: Your hostname, kenans-MacBook-Pro.local, resolves to a loopback address: 127.0.0.1; using 192.168.1.102 instead (on interface en0)
25/10/14 23:16:19 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/14 23:16:20 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/10/14 23:16:20 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
class ColumnNames:
    TRANSACTION_ID = "transaction_id"
    PRODUCT_NAME = "product_name"
    QUANTITY = "quantity"
    PRODUCT_PRICE = "product_price"
    SALES_DATE = "sales_date"
    CUSTOMER_ID = "customer_id"
    TOTAL_PRICE = "total_price"
    VAT = "vat"
    SUB_TOTAL = "sub_total"
    SUM_OF_VAT = "sum_of_vat"
    TOTAL_DUE = "total_due"
    TOTAL_SALES = "total_sales"
    PRODUCT_NAMES = "product_names"
    COUNT = "count"


def random_date(start: datetime, end: datetime) -> datetime:
    """Generate a random datetime between 'start' and 'end'"""
    delta = end - start
    random_second = random.randint(0, int(delta.total_seconds()))
    return start + timedelta(seconds=random_second)


def generate_data() -> DataFrame:
    """Generate random retail sales data"""
    products: List[Tuple[str, float]] = [
        ("Apple", 1.0),
        ("Orange", 2.0),
        ("Banana", 3.0),
        ("Kiwi", 4.0),
        ("Grape", 5.0),
        ("Strawberry", 6.0),
        ("Blueberry", 7.0),
        ("Blackberry", 8.0),
        ("Raspberry", 9.0),
        ("Pineapple", 10.0),
    ]
    random_dates: List[datetime] = [random_date(datetime(2020, 1, 1), datetime(2020, 12, 31)) for _ in range(5)]
    data: List[Tuple[int, str, int, float, datetime, int]] = []
    for i in range(10):
        product_name, product_price = random.choice(products)
        quantity = random.randint(1, 10)
        sales_date = random.choice(random_dates)
        customer_id = random.randint(1000, 1099)
        data.append((i, product_name, quantity, product_price, sales_date, customer_id))

    schema = [
        ColumnNames.TRANSACTION_ID,
        ColumnNames.PRODUCT_NAME,
        ColumnNames.QUANTITY,
        ColumnNames.PRODUCT_PRICE,
        ColumnNames.SALES_DATE,
        ColumnNames.CUSTOMER_ID,
    ]

    return spark.createDataFrame(data, schema=schema)


def with_total_price(dff: DataFrame) -> DataFrame:
    """Add a new column to the dataframe with the total price"""
    return dff.withColumn(
        ColumnNames.TOTAL_PRICE,
        col(ColumnNames.QUANTITY) * col(ColumnNames.PRODUCT_PRICE),
    )


def with_vat(dff: DataFrame, rate: int = 0.2) -> DataFrame:
    """Add a new column to the dataframe with the vat"""
    return dff.withColumn(ColumnNames.VAT, col(ColumnNames.TOTAL_PRICE) * rate)


def calculate_sub_total(dff: DataFrame) -> float:
    """Calculate the sub total"""
    return dff.select(sum(col(ColumnNames.TOTAL_PRICE)).alias(ColumnNames.SUB_TOTAL)).first()[ColumnNames.SUB_TOTAL]


def calculate_total_vat(dff: DataFrame) -> float:
    """Calculate the total vat"""
    return dff.select(sum(col(ColumnNames.VAT)).alias(ColumnNames.SUM_OF_VAT)).first()[ColumnNames.SUM_OF_VAT]


def calculate_total_due(dff: DataFrame) -> float:
    return dff.select(sum(col(ColumnNames.TOTAL_PRICE) + col(ColumnNames.VAT)).alias(ColumnNames.TOTAL_DUE)).first()[
        ColumnNames.TOTAL_DUE
    ]


def calculate_total_sales(dff: DataFrame) -> DataFrame:
    """Calculate total sales for each product"""
    return (
        dff.groupBy(ColumnNames.PRODUCT_NAME)
        .sum(ColumnNames.QUANTITY)
        .withColumnRenamed("sum(quantity)", ColumnNames.TOTAL_SALES)
    )


def calculate_total_sales_by_sales_date(dff: DataFrame) -> DataFrame:
    return (
        dff.groupBy(to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE))
        .sum(ColumnNames.QUANTITY)
        .withColumnRenamed("sum(quantity)", ColumnNames.TOTAL_SALES)
        .sort(ColumnNames.SALES_DATE)
    )


def calculate_total_sales_by_sales_date_and_product_name(dff: DataFrame) -> DataFrame:
    return (
        dff.groupBy(
            to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE),
            ColumnNames.PRODUCT_NAME,
        )
        .sum(ColumnNames.QUANTITY)
        .withColumnRenamed("sum(quantity)", ColumnNames.TOTAL_SALES)
    )


def calculate_total_sales_by_sales_date_and_product_name_and_customer_id(
    dff: DataFrame,
) -> DataFrame:
    return (
        dff.groupBy(
            to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE),
            ColumnNames.PRODUCT_NAME,
            ColumnNames.CUSTOMER_ID,
        )
        .sum(ColumnNames.QUANTITY)
        .withColumnRenamed("sum(quantity)", ColumnNames.TOTAL_SALES)
    )


def list_product_names_by_sales_date(dff: DataFrame) -> DataFrame:
    """List product names grouped by date"""
    return (
        dff.groupBy(to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE))
        .agg(collect_list(ColumnNames.PRODUCT_NAME).alias(ColumnNames.PRODUCT_NAMES))
        .orderBy(ColumnNames.SALES_DATE)
    )


def list_product_names_and_counts_by_sales_date(dff: DataFrame) -> DataFrame:
    """List product names and counts grouped by date"""
    return (
        dff.groupBy(
            to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE),
            ColumnNames.PRODUCT_NAME,
        )
        .sum(ColumnNames.QUANTITY)
        .withColumnRenamed("sum(quantity)", ColumnNames.TOTAL_SALES)
        .orderBy(ColumnNames.SALES_DATE)
    )


def list_product_names_by_sales_date_v2(dff: DataFrame) -> DataFrame:
    """List product names grouped by date"""
    return (
        dff.groupBy(to_date(ColumnNames.SALES_DATE).alias(ColumnNames.SALES_DATE))
        .agg(
            collect_list(ColumnNames.PRODUCT_NAME).alias(ColumnNames.PRODUCT_NAMES),
            sum(ColumnNames.QUANTITY).alias(ColumnNames.TOTAL_SALES),
        )
        .withColumn(ColumnNames.COUNT, size(ColumnNames.PRODUCT_NAMES))
        .orderBy(ColumnNames.SALES_DATE)
        .select(
            ColumnNames.SALES_DATE,
            ColumnNames.PRODUCT_NAMES,
            ColumnNames.COUNT,
            ColumnNames.TOTAL_SALES,
        )
    )


df = generate_data()

df_extended = df.transform(with_total_price).transform(lambda dff: with_vat(dff, 0.2))

In [3]:
df_extended.sort(ColumnNames.SALES_DATE, ascending=True).show(truncate=False)

[Stage 0:>                                                        (0 + 16) / 16]

+--------------+------------+--------+-------------+-------------------+-----------+-----------+------------------+
|transaction_id|product_name|quantity|product_price|sales_date         |customer_id|total_price|vat               |
+--------------+------------+--------+-------------+-------------------+-----------+-----------+------------------+
|2             |Raspberry   |2       |9.0          |2020-03-17 00:03:27|1005       |18.0       |3.6               |
|3             |Strawberry  |1       |6.0          |2020-03-17 00:03:27|1070       |6.0        |1.2000000000000002|
|0             |Grape       |4       |5.0          |2020-06-24 20:26:27|1061       |20.0       |4.0               |
|5             |Pineapple   |4       |10.0         |2020-06-24 20:26:27|1007       |40.0       |8.0               |
|8             |Grape       |8       |5.0          |2020-07-22 17:11:00|1090       |40.0       |8.0               |
|1             |Pineapple   |3       |10.0         |2020-07-22 17:11:00|

                                                                                

In [4]:
calculate_total_sales(df_extended).show(truncate=False)

+------------+-----------+
|product_name|total_sales|
+------------+-----------+
|Grape       |12         |
|Pineapple   |7          |
|Raspberry   |2          |
|Strawberry  |1          |
|Banana      |3          |
|Orange      |3          |
|Kiwi        |3          |
|Apple       |1          |
+------------+-----------+



In [5]:
calculate_total_sales_by_sales_date(df_extended).show(truncate=False)

+----------+-----------+
|sales_date|total_sales|
+----------+-----------+
|2020-03-17|3          |
|2020-06-24|8          |
|2020-07-22|15         |
|2020-08-12|3          |
|2020-12-17|3          |
+----------+-----------+



In [6]:
calculate_total_sales_by_sales_date_and_product_name(df_extended).show(truncate=False)

+----------+------------+-----------+
|sales_date|product_name|total_sales|
+----------+------------+-----------+
|2020-06-24|Grape       |4          |
|2020-07-22|Pineapple   |3          |
|2020-03-17|Raspberry   |2          |
|2020-03-17|Strawberry  |1          |
|2020-07-22|Banana      |3          |
|2020-06-24|Pineapple   |4          |
|2020-08-12|Orange      |3          |
|2020-12-17|Kiwi        |3          |
|2020-07-22|Grape       |8          |
|2020-07-22|Apple       |1          |
+----------+------------+-----------+



In [7]:
calculate_total_sales_by_sales_date_and_product_name_and_customer_id(df_extended).show()

+----------+------------+-----------+-----------+
|sales_date|product_name|customer_id|total_sales|
+----------+------------+-----------+-----------+
|2020-06-24|       Grape|       1061|          4|
|2020-07-22|   Pineapple|       1029|          3|
|2020-03-17|   Raspberry|       1005|          2|
|2020-03-17|  Strawberry|       1070|          1|
|2020-07-22|      Banana|       1041|          3|
|2020-06-24|   Pineapple|       1007|          4|
|2020-08-12|      Orange|       1001|          3|
|2020-12-17|        Kiwi|       1089|          3|
|2020-07-22|       Grape|       1090|          8|
|2020-07-22|       Apple|       1043|          1|
+----------+------------+-----------+-----------+



In [8]:
sub_total = calculate_sub_total(df_extended)
total_vat = calculate_total_vat(df_extended)
total_due = calculate_total_due(df_extended)

print(f"Sub Total: {sub_total}")
print(f"Total vat: {total_vat}")
print(f"Total due: {total_due}")

Sub Total: 182.0
Total vat: 36.400000000000006
Total due: 218.39999999999998


In [9]:
list_product_names_by_sales_date(df_extended).show(truncate=False)

+----------+---------------------------------+
|sales_date|product_names                    |
+----------+---------------------------------+
|2020-03-17|[Raspberry, Strawberry]          |
|2020-06-24|[Grape, Pineapple]               |
|2020-07-22|[Pineapple, Banana, Grape, Apple]|
|2020-08-12|[Orange]                         |
|2020-12-17|[Kiwi]                           |
+----------+---------------------------------+



In [10]:
list_product_names_and_counts_by_sales_date(df_extended).show(truncate=False)

+----------+------------+-----------+
|sales_date|product_name|total_sales|
+----------+------------+-----------+
|2020-03-17|Raspberry   |2          |
|2020-03-17|Strawberry  |1          |
|2020-06-24|Grape       |4          |
|2020-06-24|Pineapple   |4          |
|2020-07-22|Pineapple   |3          |
|2020-07-22|Banana      |3          |
|2020-07-22|Grape       |8          |
|2020-07-22|Apple       |1          |
|2020-08-12|Orange      |3          |
|2020-12-17|Kiwi        |3          |
+----------+------------+-----------+



In [11]:
list_product_names_by_sales_date_v2(df_extended).show(truncate=False)

+----------+---------------------------------+-----+-----------+
|sales_date|product_names                    |count|total_sales|
+----------+---------------------------------+-----+-----------+
|2020-03-17|[Raspberry, Strawberry]          |2    |3          |
|2020-06-24|[Grape, Pineapple]               |2    |8          |
|2020-07-22|[Pineapple, Banana, Grape, Apple]|4    |15         |
|2020-08-12|[Orange]                         |1    |3          |
|2020-12-17|[Kiwi]                           |1    |3          |
+----------+---------------------------------+-----+-----------+

