### Customer Transactions

You are given a DataFrame of customer transactions with the columns `customer_id`, `transaction_date`, and `amount`.

Write PySpark code that does the following:
- Calculates the total transaction amount for each customer.
- Calculates the average transaction amount for each customer.
- Counts the number of transactions made by each customer.
- Keeps only customers who have made at least 5 transactions.

In [0]:
# Sample data
data = [
    (1, "2024-01-01", 200),
    (1, "2024-01-02", 150),
    (2, "2024-01-01", 300),
    (3, "2024-01-01", 100),
    (1, "2024-01-03", 250),
    (3, "2024-01-02", 200),
    (2, "2024-01-02", 100),
    (2, "2024-01-03", 200),
    (1, "2024-01-04", 300),
    (1, "2024-01-05", 100),
]

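# Create DataFrame
# `spark` is predefined in notebook environments such as Databricks;
# getOrCreate() reuses that session or starts one when running standalone.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(data, ["customer_id", "transaction_date", "amount"])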
df.show()

+-----------+----------------+------+
|customer_id|transaction_date|amount|
+-----------+----------------+------+
|          1|      2024-01-01|   200|
|          1|      2024-01-02|   150|
|          2|      2024-01-01|   300|
|          3|      2024-01-01|   100|
|          1|      2024-01-03|   250|
|          3|      2024-01-02|   200|
|          2|      2024-01-02|   100|
|          2|      2024-01-03|   200|
|          1|      2024-01-04|   300|
|          1|      2024-01-05|   100|
+-----------+----------------+------+



In [0]:
# Note: importing `sum` from pyspark.sql.functions shadows Python's built-in sum in this notebook.
from pyspark.sql.functions import sum, avg, col, count

In [0]:
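# One pass over the data: group by customer and compute all three metrics together.
# count(col("amount")) counts non-null amounts only; count("*") would count every row.
aggregated_df = df.groupBy("customer_id").agg(
    sum(col("amount")).alias("total_transaction_amount"),
    avg(col("amount")).alias("average_transaction_amount"),
    count(col("amount")).alias("transaction_count")
)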

aggregated_df.show()

+-----------+------------------------+--------------------------+-----------------+
|customer_id|total_transaction_amount|average_transaction_amount|transaction_count|
+-----------+------------------------+--------------------------+-----------------+
|          1|                    1000|                     200.0|                5|
|          2|                     600|                     200.0|                3|
|          3|                     300|                     150.0|                2|
+-----------+------------------------+--------------------------+-----------------+
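
As a sanity check, the aggregated values above can be reproduced without Spark. The following is a minimal plain-Python sketch over the same sample rows; the names `rows` and `summarize` are illustrative and not part of the original notebook. It keeps running totals instead of calling the built-in `sum`, so it still works in a session where `sum` has been imported from `pyspark.sql.functions`.

```python
from collections import defaultdict

# Same sample rows as the notebook: (customer_id, transaction_date, amount)
rows = [
    (1, "2024-01-01", 200),
    (1, "2024-01-02", 150),
    (2, "2024-01-01", 300),
    (3, "2024-01-01", 100),
    (1, "2024-01-03", 250),
    (3, "2024-01-02", 200),
    (2, "2024-01-02", 100),
    (2, "2024-01-03", 200),
    (1, "2024-01-04", 300),
    (1, "2024-01-05", 100),
]


def summarize(rows):
    """Return {customer_id: (total, average, count)} for the given rows."""
    totals = defaultdict(int)
    counts = defaultdict(int)
    for customer_id, _date, amount in rows:
        totals[customer_id] += amount
        counts[customer_id] += 1
    return {
        cid: (totals[cid], totals[cid] / counts[cid], counts[cid])
        for cid in totals
    }


summary = summarize(rows)
for cid in sorted(summary):
    print(cid, summary[cid])
# 1 (1000, 200.0, 5)
# 2 (600, 200.0, 3)
# 3 (300, 150.0, 2)
```

The printed tuples match the `total_transaction_amount`, `average_transaction_amount`, and `transaction_count` columns in the Spark output.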



In [0]:
# Keep only customers who have made at least 5 transactions.
filtered_df = aggregated_df.filter(col("transaction_count") >= 5)
filtered_df.show()

+-----------+------------------------+--------------------------+-----------------+
|customer_id|total_transaction_amount|average_transaction_amount|transaction_count|
+-----------+------------------------+--------------------------+-----------------+
|          1|                    1000|                     200.0|                5|
+-----------+------------------------+--------------------------+-----------------+
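
Whether customer 1 appears in this result depends on the comparison operator: with exactly 5 transactions, the customer passes `>= 5` but not `> 5`. A small plain-Python sketch using the counts from the aggregated table illustrates the difference; `transaction_counts` and `customers_with` are hypothetical names used only here.

```python
import operator

# Transaction counts taken from the aggregated output: {customer_id: count}
transaction_counts = {1: 5, 2: 3, 3: 2}


def customers_with(counts, compare, threshold):
    """Return sorted customer ids whose count satisfies compare(count, threshold)."""
    return sorted(cid for cid, n in counts.items() if compare(n, threshold))


at_least_5 = customers_with(transaction_counts, operator.ge, 5)
more_than_5 = customers_with(transaction_counts, operator.gt, 5)

print(at_least_5)   # [1]
print(more_than_5)  # []
```

In the PySpark filter, the strict version corresponds to `col("transaction_count") > 5`, which would return no rows for this sample data.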

