In [1]:
# Initialization: describe the local Spark install through environment variables
import os
import sys

os.environ.update({
    "SPARK_HOME": "/home/talentum/spark",
    # In below two entries, use /usr/bin/python2.7 if you want to use Python 2
    "PYSPARK_PYTHON": "/usr/bin/python3.6",
    "PYSPARK_DRIVER_PYTHON": "/usr/bin/python3",
})
os.environ["PYLIB"] = f'{os.environ["SPARK_HOME"]}/python/lib'

# Each archive is inserted at index 0, so pyspark.zip ends up ahead of the py4j archive
for archive in ("py4j-0.10.7-src.zip", "pyspark.zip"):
    sys.path.insert(0, f'{os.environ["PYLIB"]}/{archive}')

# NOTE: Whichever package you want mention here.
# os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages com.databricks:spark-xml_2.11:0.6.0 pyspark-shell' 
# os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-avro_2.11:2.4.0 pyspark-shell'
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages com.databricks:spark-xml_2.11:0.6.0,org.apache.spark:spark-avro_2.11:2.4.3 pyspark-shell'
# os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages com.databricks:spark-xml_2.11:0.6.0,org.apache.spark:spark-avro_2.11:2.4.0 pyspark-shell'

In [2]:
#import important library 
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, year, month, dayofmonth, dayofweek, when, isnan, when, count, expr, min, max
from pyspark.sql.functions import countDistinct, mean, stddev, first ,col, to_date, regexp_replace, udf
from pyspark.sql.functions import sum as spark_sum, round, rank, split, trim, size, datediff, floor, lit
from pyspark.sql.types import StructType, StructField, StringType, FloatType, DateType, IntegerType
from pyspark.sql.types import StringType, BooleanType, IntegerType
from pyspark.sql.window import Window
import pandas as pd



In [4]:
# Entry point (Spark 2.x): create the session for this notebook, or reuse a running one
spark = (
    SparkSession.builder
    .appName("Bank Customer Segmentation")
    .getOrCreate()
)

# On yarn, specify .master("yarn"):
# spark = SparkSession.builder.appName("Spark SQL basic example").enableHiveSupport().master("yarn").getOrCreate()

# Handle to the underlying SparkContext
sc = spark.sparkContext

In [5]:
# Declare the schema for the DataFrame: (column name, Spark type) in file order
column_types = [
    ("TransactionID", StringType()),
    ("CustomerID", StringType()),
    ("CustomerDOB", StringType()),
    ("CustGender", StringType()),
    ("CustLocation", StringType()),
    ("CustAccountBalance", FloatType()),
    ("TransactionDate", StringType()),
    ("TransactionTime", StringType()),
    ("TransactionAmountINR", FloatType()),
]
# Every column is declared nullable
schema = StructType([StructField(name, dtype, True) for name, dtype in column_types])

file_path = "Bank_dataset.csv"

# Print the file_path
print("The file_path is", file_path)

The file_path is Bank_dataset.csv


# Basic ETL operations

In [6]:
# Read the CSV file with the explicit schema.
# `schema` is already a StructType, so it is passed as-is instead of being wrapped in a second StructType.
df = spark.read.csv(file_path, header=True, schema=schema)

# Inspect the first rows and the resulting column types
df.show(5)
df.printSchema()

+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+--------------------+
|TransactionID|CustomerID|CustomerDOB|CustGender|CustLocation|CustAccountBalance|TransactionDate|TransactionTime|TransactionAmountINR|
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+--------------------+
|           T1|  C5841053| 10-01-1994|         F|  JAMSHEDPUR|          17819.05|     02-08-2016|         143207|                25.0|
|           T2|  C2142763| 04-04-1957|         M|     JHAJJAR|           2270.69|     02-08-2016|         141858|             27999.0|
|           T3|  C4417068| 26-11-1996|         F|      MUMBAI|          17874.44|     02-08-2016|         142712|               459.0|
|           T4|  C5342380| 14-09-1973|         F|      MUMBAI|          866503.2|     02-08-2016|         142714|              2060.0|
|           T5|  C9031234| 24-03-1988|         F| NAVI 

In [7]:
# Count rows and columns of the raw DataFrame
row_count, column_count = df.count(), len(df.columns)
summary_template = "The Number of Row's is:{},\nThe Number of Columns is: {}"
print(summary_template.format(row_count, column_count))

The Number of Row's is:1048567,
The Number of Columns is: 9


In [8]:
# Parse the two dd-MM-yyyy string columns into date columns
for date_column in ("CustomerDOB", "TransactionDate"):
    df = df.withColumn(date_column, to_date(col(date_column), "dd-MM-yyyy"))

In [9]:
# Summary statistics (count, mean, stddev, min, max) for the DataFrame's columns
df.describe().show()

+-------+-------------+----------+----------+--------------------+------------------+------------------+--------------------+
|summary|TransactionID|CustomerID|CustGender|        CustLocation|CustAccountBalance|   TransactionTime|TransactionAmountINR|
+-------+-------------+----------+----------+--------------------+------------------+------------------+--------------------+
|  count|      1048567|   1048567|   1047467|             1048416|           1046198|           1048567|             1048567|
|   mean|         null|      null|      null|            400012.0|115403.54003532877|157087.52939297154|  1574.3350034945934|
| stddev|         null|      null|      null|                 0.0|  846485.381321891| 51261.85402232959|   6574.742984257266|
|    min|           T1|  C1010011|         F|(154) BHASKOLA FA...|               0.0|                 0|                 0.0|
|    max|      T999999|  C9099956|         T|           ZUNHEBOTO|      1.15035496E8|             95959|           156

In [10]:
# Check for null values: one aggregate per column, counting the rows where it is null
null_counts = df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns])

# The aggregate produces a single row; bring it to the driver
null_counts_result = null_counts.collect()[0]

# Report only the columns that actually contain nulls.
# The loop variable must not be called `count`: that would rebind the imported
# pyspark `count` function to an int for every cell that runs afterwards.
for column, n_nulls in null_counts_result.asDict().items():
    if n_nulls > 0:
        print(f"The column '{column}' contains {n_nulls} null values.")

The column 'CustomerDOB' contains 60736 null values.
The column 'CustGender' contains 1100 null values.
The column 'CustLocation' contains 151 null values.
The column 'CustAccountBalance' contains 2369 null values.


In [11]:
# Value counts for every string-typed column
categorical_columns = [name for name, kind in df.dtypes if kind == 'string']

for column in categorical_columns:
    value_counts = df.groupBy(column).count()
    value_counts.show()

+-------------+-----+
|TransactionID|count|
+-------------+-----+
|         T352|    1|
|         T590|    1|
|         T855|    1|
|         T929|    1|
|         T947|    1|
|        T1118|    1|
|        T1401|    1|
|        T1508|    1|
|        T1767|    1|
|        T1872|    1|
|        T2345|    1|
|        T2463|    1|
|        T2837|    1|
|        T2947|    1|
|        T3091|    1|
|        T3230|    1|
|        T3271|    1|
|        T3337|    1|
|        T3396|    1|
|        T4155|    1|
+-------------+-----+
only showing top 20 rows

+----------+-----+
|CustomerID|count|
+----------+-----+
|  C8732711|    1|
|  C6421261|    2|
|  C2939112|    2|
|  C8741166|    1|
|  C4440931|    1|
|  C8116854|    1|
|  C4029960|    1|
|  C7817677|    4|
|  C4240562|    1|
|  C2221420|    1|
|  C8524541|    1|
|  C5638051|    3|
|  C2331867|    1|
|  C5534211|    1|
|  C4940219|    1|
|  C2230276|    1|
|  C4341556|    1|
|  C7283217|    1|
|  C3825339|    1|
|  C4138928|    1|
+--------

In [12]:
# Summary statistics for the numerical features.
# The schema declares the numeric columns as float, so 'float' has to be matched too;
# matching only 'int' and 'double' selects nothing and the cell prints no output.
numeric_types = ('tinyint', 'smallint', 'int', 'bigint', 'float', 'double')
numeric_columns = [col_name for col_name, dtype in df.dtypes
                   if dtype in numeric_types or dtype.startswith('decimal')]

# One summary() call covers all numeric columns; skipped when there are none
if numeric_columns:
    df.select(numeric_columns).summary().show()

In [13]:
# Confirm the schema after the date conversion
df.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- CustomerDOB: date (nullable = true)
 |-- CustGender: string (nullable = true)
 |-- CustLocation: string (nullable = true)
 |-- CustAccountBalance: float (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- TransactionTime: string (nullable = true)
 |-- TransactionAmountINR: float (nullable = true)



In [14]:
# Discard every row that is missing any of these customer attributes
required_columns = ["CustGender", "CustLocation", "CustAccountBalance", "CustomerDOB"]
df = df.dropna(subset=required_columns)

# Re-count the nulls per column (1 for a null cell, 0 otherwise, summed per column)
null_flags = [
    spark_sum(when(col(c).isNull(), 1).otherwise(0)).alias(c)
    for c in df.columns
]
null_counts_after_drop = df.select(null_flags)

# Show the results
null_counts_after_drop.show()


+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+--------------------+
|TransactionID|CustomerID|CustomerDOB|CustGender|CustLocation|CustAccountBalance|TransactionDate|TransactionTime|TransactionAmountINR|
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+--------------------+
|            0|         0|          0|         0|           0|                 0|              0|              0|                   0|
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+--------------------+



In [15]:
# Function to check whether a value can be interpreted as a number
def is_valid_number(value):
    """Return True when ``float(value)`` succeeds, False otherwise.

    :param value: Any object; typically a cell value handed over by a Spark UDF
    :return: True if the value converts to float, False for non-numeric text,
             None, or any other type that float() rejects
    """
    try:
        float(value)
    except (TypeError, ValueError):
        # float() raises ValueError for non-numeric strings and TypeError for
        # None / unsupported types; both mean "not a valid number"
        return False
    return True

# Wrap the Python check as a Spark UDF returning a boolean
is_valid_number_udf = udf(is_valid_number, BooleanType())

# Columns to check
columns_to_check = ["CustAccountBalance", "TransactionAmountINR"]

# Add one "<column>_is_valid" flag column per checked column
for column in columns_to_check:
    flag_name = f"{column}_is_valid"
    df = df.withColumn(flag_name, is_valid_number_udf(col(column)))

# A row is invalid when either flag is false
has_invalid_value = ~col("CustAccountBalance_is_valid") | ~col("TransactionAmountINR_is_valid")
invalid_rows = df.filter(has_invalid_value)

invalid_rows.count()


0

In [16]:
# Confirm the schema, which now includes the two *_is_valid flag columns
df.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- CustomerDOB: date (nullable = true)
 |-- CustGender: string (nullable = true)
 |-- CustLocation: string (nullable = true)
 |-- CustAccountBalance: float (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- TransactionTime: string (nullable = true)
 |-- TransactionAmountINR: float (nullable = true)
 |-- CustAccountBalance_is_valid: boolean (nullable = true)
 |-- TransactionAmountINR_is_valid: boolean (nullable = true)



In [17]:
# Count rows and columns again
row_count, column_count = df.count(), len(df.columns)
summary_template = "The Number of Row's is:{},\nThe Number of Columns is: {}"
print(summary_template.format(row_count, column_count))

The Number of Row's is:985322,
The Number of Columns is: 11


In [18]:
# Re-import the functions used below; this restores `count` if an earlier cell rebound the name
from pyspark.sql.functions import col, year, count

# Birth year of each customer, taken from the date of birth
birth_year = year(col("CustomerDOB"))
df = df.withColumn("Year", birth_year)

# Number of rows per birth year
year_counts = df.groupBy("Year").agg(count("*").alias("Count"))

# Show the results
year_counts.show()

+----+-----+
|Year|Count|
+----+-----+
|1959| 2133|
|1990|72616|
|1975|12293|
|2025|    1|
|1977|16626|
|2027|   49|
|2003|   69|
|2007|   46|
|2018|    1|
|1974|10952|
|2015|   25|
|2023|    6|
|1955| 1761|
|2006|   21|
|1978|19605|
|2022|    3|
|1961| 2164|
|2013|   27|
|1942|  400|
|1944|  412|
+----+-----+
only showing top 20 rows



In [19]:
# Detect outliers in the birth year with the IQR rule.
# (The unused `window_spec = Window.orderBy("Year")` was removed: nothing in this cell is a window function.)

# Approximate first and third quartiles of Year
quantiles = df.selectExpr(
    "percentile_approx(Year, 0.25) as Q1",
    "percentile_approx(Year, 0.75) as Q3",
).collect()
q1 = quantiles[0]["Q1"]
q3 = quantiles[0]["Q3"]

# Interquartile range
iqr = q3 - q1

# Values more than 1.5 * IQR outside the quartiles are treated as outliers
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr

# Rows whose Year falls outside the bounds
outliers = df.filter((col("Year") < lower_bound) | (col("Year") > upper_bound))

# Number of outlier rows
outliers.count()

54394

In [20]:
# Keep only the rows whose Year lies within the IQR bounds (both ends inclusive)
year_in_bounds = col("Year").between(lower_bound, upper_bound)
filtered_df = df.filter(year_in_bounds)

# Show the filtered DataFrame
filtered_df.show()

+-------------+----------+-----------+----------+--------------------+------------------+---------------+---------------+--------------------+---------------------------+-----------------------------+----+
|TransactionID|CustomerID|CustomerDOB|CustGender|        CustLocation|CustAccountBalance|TransactionDate|TransactionTime|TransactionAmountINR|CustAccountBalance_is_valid|TransactionAmountINR_is_valid|Year|
+-------------+----------+-----------+----------+--------------------+------------------+---------------+---------------+--------------------+---------------------------+-----------------------------+----+
|           T1|  C5841053| 1994-01-10|         F|          JAMSHEDPUR|          17819.05|     2016-08-02|         143207|                25.0|                       true|                         true|1994|
|           T3|  C4417068| 1996-11-26|         F|              MUMBAI|          17874.44|     2016-08-02|         142712|               459.0|                       true|      

In [21]:
# Confirm the shape after removing the Year outliers
row_count, column_count = filtered_df.count(), len(filtered_df.columns)
summary_template = "The Number of Row's is:{},\nThe Number of Columns is: {}"
print(summary_template.format(row_count, column_count))

The Number of Row's is:930928,
The Number of Columns is: 12


In [22]:
# Print the schema of the filtered DataFrame
filtered_df.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- CustomerDOB: date (nullable = true)
 |-- CustGender: string (nullable = true)
 |-- CustLocation: string (nullable = true)
 |-- CustAccountBalance: float (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- TransactionTime: string (nullable = true)
 |-- TransactionAmountINR: float (nullable = true)
 |-- CustAccountBalance_is_valid: boolean (nullable = true)
 |-- TransactionAmountINR_is_valid: boolean (nullable = true)
 |-- Year: integer (nullable = true)



In [23]:
# Largest value in the Year column (`max` here is the pyspark aggregate imported at the top)
max_year = filtered_df.agg(max(col("Year"))).first()[0]
print("Maximum Year:", max_year)

Maximum Year: 2004


In [24]:
# Smallest value in the Year column (`min` here is the pyspark aggregate imported at the top)
min_year = filtered_df.agg(min(col("Year"))).first()[0]
print("Minimum Year:", min_year)

Minimum Year: 1969


In [25]:
# Latest transaction date in the data
max_date = filtered_df.agg(max(col("TransactionDate"))).first()[0]
print("Maximum Transaction Date:", max_date)

Maximum Transaction Date: 2016-10-21


In [26]:
# Function that adds the customer's age column to the data
def add_age_column(df, dob_col, reference_date_str):
    """
    Add a new integer column 'CustomerAge' to the DataFrame, calculating age from DOB to a reference date.

    The age is computed as floor(days between DOB and reference date / 365.25).

    :param df: DataFrame with a date of birth column
    :param dob_col: Name of the date of birth column (a 'dd-MM-yyyy' string column or a column that is not a string)
    :param reference_date_str: Reference date in 'dd-MM-yyyy' format
    :return: DataFrame with the new 'CustomerAge' column
    """
    # Parse the date of birth only while it is still a string; a column that was
    # already converted upstream is used as it is
    if dict(df.dtypes).get(dob_col) == "string":
        df = df.withColumn(dob_col, to_date(col(dob_col), "dd-MM-yyyy"))

    # Define the reference date
    reference_date = to_date(lit(reference_date_str), "dd-MM-yyyy")

    # Whole years elapsed between the date of birth and the reference date
    age_in_years = floor(datediff(reference_date, col(dob_col)) / 365.25)
    df = df.withColumn("CustomerAge", age_in_years.cast(IntegerType()))

    return df

# Add the age column, measured at the last transaction date.
# The reference date is derived from `max_date` (computed in the previous cell)
# instead of repeating it as a hard-coded literal.
reference_date_str = max_date.strftime("%d-%m-%Y")
df_with_age = add_age_column(filtered_df, "CustomerDOB", reference_date_str)

In [27]:
# Display only the CustomerAge column
df_with_age.select("CustomerAge").show()

+-----------+
|CustomerAge|
+-----------+
|         22|
|         19|
|         43|
|         28|
|         44|
|         24|
|         34|
|         28|
|         32|
|         34|
|         28|
|         38|
|         24|
|         38|
|         27|
|         25|
|         31|
|         23|
|         27|
|         30|
+-----------+
only showing top 20 rows



In [28]:
# Confirm the schema after adding the CustomerAge column
df_with_age.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- CustomerDOB: date (nullable = true)
 |-- CustGender: string (nullable = true)
 |-- CustLocation: string (nullable = true)
 |-- CustAccountBalance: float (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- TransactionTime: string (nullable = true)
 |-- TransactionAmountINR: float (nullable = true)
 |-- CustAccountBalance_is_valid: boolean (nullable = true)
 |-- TransactionAmountINR_is_valid: boolean (nullable = true)
 |-- Year: integer (nullable = true)
 |-- CustomerAge: integer (nullable = true)



In [29]:
# Keep only customers aged 18 or older, then count the remaining rows
is_adult = col("CustomerAge") >= 18
df_with_age = df_with_age.filter(is_adult)
df_with_age.count()

929410

In [30]:
# Find outliers in the CustomerAge column with the IQR rule
# Step 1: First and third quartiles (relative error 0.0)
quantiles = df_with_age.approxQuantile("CustomerAge", [0.25, 0.75], 0.0)
q1, q3 = quantiles

# Step 2: Compute IQR
iqr = q3 - q1

# Step 3: Bounds at 1.5 * IQR beyond the quartiles
lower_bound, upper_bound = q1 - 1.5 * iqr, q3 + 1.5 * iqr

# Step 4: Rows whose age is not within [lower_bound, upper_bound]
age_in_bounds = col("CustomerAge").between(lower_bound, upper_bound)
outliers = df_with_age.filter(~age_in_bounds)

# Number of outlier rows
outliers.count()


11877

In [31]:
# Keep only the rows whose CustomerAge lies within the IQR bounds (both ends inclusive)
age_in_bounds = col("CustomerAge").between(lower_bound, upper_bound)
df_filter = df_with_age.filter(age_in_bounds)

# Show the identifying columns of the filtered DataFrame
df_filter.select('CustomerID', 'CustomerDOB', 'CustomerAge').show()

+----------+-----------+-----------+
|CustomerID|CustomerDOB|CustomerAge|
+----------+-----------+-----------+
|  C5841053| 1994-01-10|         22|
|  C4417068| 1996-11-26|         19|
|  C5342380| 1973-09-14|         43|
|  C9031234| 1988-03-24|         28|
|  C1536588| 1972-10-08|         44|
|  C7126560| 1992-01-26|         24|
|  C1220223| 1982-01-27|         34|
|  C8536061| 1988-04-19|         28|
|  C6638934| 1984-06-22|         32|
|  C5430833| 1982-07-22|         34|
|  C6939838| 1988-07-07|         28|
|  C6339347| 1978-06-13|         38|
|  C8327851| 1992-01-05|         24|
|  C7917151| 1978-03-24|         38|
|  C8967349| 1989-07-16|         27|
|  C3732016| 1991-01-11|         25|
|  C8999019| 1985-06-24|         31|
|  C6121429| 1993-04-20|         23|
|  C4511244| 1989-08-31|         27|
|  C5830215| 1986-10-01|         30|
+----------+-----------+-----------+
only showing top 20 rows



In [32]:
# Confirm the shape after removing the age outliers
row_count, column_count = df_filter.count(), len(df_filter.columns)
summary_template = "The Number of Row's is:{},\nThe Number of Columns is: {}"
print(summary_template.format(row_count, column_count))

The Number of Row's is:917533,
The Number of Columns is: 13


In [33]:
# List all column names
df_filter.columns

['TransactionID',
 'CustomerID',
 'CustomerDOB',
 'CustGender',
 'CustLocation',
 'CustAccountBalance',
 'TransactionDate',
 'TransactionTime',
 'TransactionAmountINR',
 'CustAccountBalance_is_valid',
 'TransactionAmountINR_is_valid',
 'Year',
 'CustomerAge']

In [34]:
# Check for fully duplicated rows: group on every column and count each combination
all_columns = df_filter.columns
duplicate_counts = df_filter.groupBy(all_columns).count()

# A combination that occurs more than once is a duplicate
duplicates = duplicate_counts.where(col("count") > 1)

# Show duplicate rows
duplicates.show()

+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+--------------------+---------------------------+-----------------------------+----+-----------+-----+
|TransactionID|CustomerID|CustomerDOB|CustGender|CustLocation|CustAccountBalance|TransactionDate|TransactionTime|TransactionAmountINR|CustAccountBalance_is_valid|TransactionAmountINR_is_valid|Year|CustomerAge|count|
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+--------------------+---------------------------+-----------------------------+----+-----------+-----+
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+--------------------+---------------------------+-----------------------------+----+-----------+-----+



In [35]:
# Drop the helper and source columns that are not needed in the final DataFrame
columns_to_drop = [
    "CustomerDOB",
    "TransactionTime",
    "Year",
    "CustAccountBalance_is_valid",
    "TransactionAmountINR_is_valid",
]
df_final = df_filter.drop(*columns_to_drop)

In [36]:
# Preview the DataFrame after dropping the columns
df_final.show()

+-------------+----------+----------+--------------------+------------------+---------------+--------------------+-----------+
|TransactionID|CustomerID|CustGender|        CustLocation|CustAccountBalance|TransactionDate|TransactionAmountINR|CustomerAge|
+-------------+----------+----------+--------------------+------------------+---------------+--------------------+-----------+
|           T1|  C5841053|         F|          JAMSHEDPUR|          17819.05|     2016-08-02|                25.0|         22|
|           T3|  C4417068|         F|              MUMBAI|          17874.44|     2016-08-02|               459.0|         19|
|           T4|  C5342380|         F|              MUMBAI|          866503.2|     2016-08-02|              2060.0|         43|
|           T5|  C9031234|         F|         NAVI MUMBAI|           6714.43|     2016-08-02|              1762.5|         28|
|           T6|  C1536588|         F|            ITANAGAR|           53609.2|     2016-08-02|               676

In [37]:
# Confirm the shape and schema after dropping the unnecessary columns
row_count, column_count = df_final.count(), len(df_final.columns)
summary_template = "The Number of Row's is:{},\nThe Number of Columns is: {}"
print(summary_template.format(row_count, column_count))
df_final.printSchema()

The Number of Row's is:917533,
The Number of Columns is: 8
root
 |-- TransactionID: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- CustGender: string (nullable = true)
 |-- CustLocation: string (nullable = true)
 |-- CustAccountBalance: float (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- TransactionAmountINR: float (nullable = true)
 |-- CustomerAge: integer (nullable = true)



In [38]:
# Value counts for every string-typed column of the final DataFrame
categorical_columns = [name for name, kind in df_final.dtypes if kind == 'string']

for column in categorical_columns:
    value_counts = df_final.groupBy(column).count()
    value_counts.show()

+-------------+-----+
|TransactionID|count|
+-------------+-----+
|         T352|    1|
|         T590|    1|
|         T929|    1|
|         T947|    1|
|        T1118|    1|
|        T1401|    1|
|        T1508|    1|
|        T1767|    1|
|        T1872|    1|
|        T2345|    1|
|        T2463|    1|
|        T2837|    1|
|        T2947|    1|
|        T3091|    1|
|        T3230|    1|
|        T3271|    1|
|        T3337|    1|
|        T3396|    1|
|        T4381|    1|
|        T4418|    1|
+-------------+-----+
only showing top 20 rows

+----------+-----+
|CustomerID|count|
+----------+-----+
|  C8732711|    1|
|  C6421261|    2|
|  C2939112|    2|
|  C8741166|    1|
|  C4440931|    1|
|  C4029960|    1|
|  C7817677|    4|
|  C4240562|    1|
|  C2221420|    1|
|  C8524541|    1|
|  C5638051|    2|
|  C2331867|    1|
|  C5534211|    1|
|  C4940219|    1|
|  C2230276|    1|
|  C4341556|    1|
|  C7283217|    1|
|  C3825339|    1|
|  C4138928|    1|
|  C5997044|    1|
+--------

In [39]:
# Total transaction amount per customer location, rounded to 2 decimal places
total_amount = round(spark_sum("TransactionAmountINR"), 2).alias("TotalTransactionAmount")
Totalamount_location_wise = df_final.groupBy("CustLocation").agg(total_amount)

# Convert to a Pandas DataFrame for plotting and keep the 20 largest totals
Totalamount_location_wise_pd = Totalamount_location_wise.toPandas()
Totalamount_location_wise_pd = (
    Totalamount_location_wise_pd
    .sort_values(by="TotalTransactionAmount", ascending=False)
    .head(20)
)

In [40]:
# First rows of the per-location totals, sorted by total amount in descending order
Totalamount_location_wise_pd.head()

Unnamed: 0,CustLocation,TotalTransactionAmount
7694,MUMBAI,137799300.0
443,NEW DELHI,106731500.0
2203,BANGALORE,91282400.0
6148,GURGAON,87780790.0
2202,DELHI,79916310.0


In [41]:
# Number of distinct customer locations
unique_city = df_final.select("CustLocation").distinct()
unique_city.count()

7733

In [42]:
# Step 1: Total transaction amount per (location, customer) pair
location_customer_transaction = (
    df_final
    .groupBy("CustLocation", "CustomerID")
    .agg(spark_sum("TransactionAmountINR").alias("TotalTransactionAmount"))
)

# Step 2: Rank the customers inside each location, largest total first
window_spec = Window.partitionBy("CustLocation").orderBy(col("TotalTransactionAmount").desc())
customer_rank = rank().over(window_spec)
location_customer_transaction = location_customer_transaction.withColumn("rank", customer_rank)

# Step 3: Keep the rows ranked 1 to 3 in each location
top_3_customers_per_location = location_customer_transaction.where(col("rank") <= 3)

# Step 4: Convert the resulting DataFrame to Pandas for plotting
top_3_customers_per_location_pd = top_3_customers_per_location.toPandas()

In [43]:
# Preview the top-ranked customers per location
top_3_customers_per_location_pd.head()

Unnamed: 0,CustLocation,CustomerID,TotalTransactionAmount,rank
0,ACADEMY SAMASTIPUR,C5314885,569.0,1
1,ACADEMY SAMASTIPUR,C8814860,200.0,2
2,ACADEMY SAMASTIPUR,C8014864,120.0,3
3,ALIPURDUAR,C2819634,6376.5,1
4,ALIPURDUAR,C2219681,66.0,2


In [44]:
# Trim the location and collapse every run of whitespace into a single space
normalized_location = regexp_replace(trim(col('CustLocation')), r'\s+', ' ')
df_final = df_final.withColumn('CustLocation', normalized_location)

# Take the last space-separated word of CustLocation as the new column `LastPart`
location_words = split(col('CustLocation'), ' ')
last_word = location_words.getItem(size(location_words) - 1)
df_final = df_final.withColumn('LastPart', last_word)

# Show the result (including the new LastPart column) without truncating values
df_final.show(50, truncate=False)


+-------------+----------+----------+-----------------------+------------------+---------------+--------------------+-----------+-------------+
|TransactionID|CustomerID|CustGender|CustLocation           |CustAccountBalance|TransactionDate|TransactionAmountINR|CustomerAge|LastPart     |
+-------------+----------+----------+-----------------------+------------------+---------------+--------------------+-----------+-------------+
|T1           |C5841053  |F         |JAMSHEDPUR             |17819.05          |2016-08-02     |25.0                |22         |JAMSHEDPUR   |
|T3           |C4417068  |F         |MUMBAI                 |17874.44          |2016-08-02     |459.0               |19         |MUMBAI       |
|T4           |C5342380  |F         |MUMBAI                 |866503.2          |2016-08-02     |2060.0              |43         |MUMBAI       |
|T5           |C9031234  |F         |NAVI MUMBAI            |6714.43           |2016-08-02     |1762.5              |28         |MUMBAI 

In [45]:
# Confirm the schema after adding the LastPart column (last word of the location)
df_final.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- CustGender: string (nullable = true)
 |-- CustLocation: string (nullable = true)
 |-- CustAccountBalance: float (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- TransactionAmountINR: float (nullable = true)
 |-- CustomerAge: integer (nullable = true)
 |-- LastPart: string (nullable = true)



In [46]:
# Count the distinct values of the CustLocation column
unique = df_final.select("CustLocation").distinct()
unique.count()

7703

In [58]:
# Print the schema
df_final.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- CustGender: string (nullable = true)
 |-- CustLocation: string (nullable = true)
 |-- CustAccountBalance: float (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- TransactionAmountINR: float (nullable = true)
 |-- CustomerAge: integer (nullable = true)
 |-- LastPart: string (nullable = true)



In [59]:
# Count the distinct values of the new LastPart column
unique1 = df_final.select("LastPart").distinct()
unique1.count()

3444

In [47]:
# Preview the rows to confirm the values
df_final.show()

+-------------+----------+----------+--------------------+------------------+---------------+--------------------+-----------+------------+
|TransactionID|CustomerID|CustGender|        CustLocation|CustAccountBalance|TransactionDate|TransactionAmountINR|CustomerAge|    LastPart|
+-------------+----------+----------+--------------------+------------------+---------------+--------------------+-----------+------------+
|           T1|  C5841053|         F|          JAMSHEDPUR|          17819.05|     2016-08-02|                25.0|         22|  JAMSHEDPUR|
|           T3|  C4417068|         F|              MUMBAI|          17874.44|     2016-08-02|               459.0|         19|      MUMBAI|
|           T4|  C5342380|         F|              MUMBAI|          866503.2|     2016-08-02|              2060.0|         43|      MUMBAI|
|           T5|  C9031234|         F|         NAVI MUMBAI|           6714.43|     2016-08-02|              1762.5|         28|      MUMBAI|
|           T6|  C15

In [48]:
# Drop the original CustLocation column, whose values can contain several words
df_final = df_final.drop("CustLocation")

In [49]:
# Schema after adding LastPart and dropping CustLocation
df_final.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- CustGender: string (nullable = true)
 |-- CustAccountBalance: float (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- TransactionAmountINR: float (nullable = true)
 |-- CustomerAge: integer (nullable = true)
 |-- LastPart: string (nullable = true)



In [50]:
# Rename the LastPart column to CustLocation
df_final = df_final.withColumnRenamed("LastPart","CustLocation")

In [51]:
# Schema of the final DataFrame
df_final.printSchema()

root
 |-- TransactionID: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- CustGender: string (nullable = true)
 |-- CustAccountBalance: float (nullable = true)
 |-- TransactionDate: date (nullable = true)
 |-- TransactionAmountINR: float (nullable = true)
 |-- CustomerAge: integer (nullable = true)
 |-- CustLocation: string (nullable = true)



# Basic aggregation

In [52]:

# Start from the cleaned DataFrame
data = df_final

# Make sure 'TransactionDate' is a date column
data = data.withColumn("TransactionDate", col("TransactionDate").cast("date"))

# Extract datetime features
data = data.withColumn("TransactionYear", year(col("TransactionDate")))
data = data.withColumn("TransactionMonth", month(col("TransactionDate")))
data = data.withColumn("TransactionDay", dayofmonth(col("TransactionDate")))
data = data.withColumn("TransactionDayOfWeek", dayofweek(col("TransactionDate")))

# Create Age Groups.
# The boundaries now agree with the labels (between() is inclusive on both ends);
# the earlier half-open ranges put age 25 into "26-35", age 35 into "36-45", and so on.
# Any age matched by no range falls into "76+".
customer_age = col("CustomerAge")
age_group = (
    when(customer_age.between(18, 25), "18-25")
    .when(customer_age.between(26, 35), "26-35")
    .when(customer_age.between(36, 45), "36-45")
    .when(customer_age.between(46, 55), "46-55")
    .when(customer_age.between(56, 65), "56-65")
    .when(customer_age.between(66, 75), "66-75")
    .otherwise("76+")
)
data = data.withColumn("AgeGroup", age_group)

# Aggregate to one row per customer
customer_features = data.groupBy("CustomerID").agg(
    spark_sum("TransactionAmountINR").alias("TotalTransactionAmount"),
    mean("TransactionAmountINR").alias("AvgTransactionAmount"),
    stddev("TransactionAmountINR").alias("StdTransactionAmount"),
    count("TransactionAmountINR").alias("NumTransactions"),
    mean("CustAccountBalance").alias("AvgAccountBalance"),
    countDistinct("TransactionYear").alias("NumTransactionYears"),
    countDistinct("TransactionMonth").alias("NumTransactionMonths"),
    # first() takes one value per customer for the descriptive attributes
    first("CustLocation").alias("CustLocation"),
    first("CustGender").alias("CustGender"),
    first("AgeGroup").alias("AgeGroup")
)

# Show the resulting dataframe
customer_features.show()


+----------+----------------------+--------------------+--------------------+---------------+-----------------+-------------------+--------------------+------------+----------+--------+
|CustomerID|TotalTransactionAmount|AvgTransactionAmount|StdTransactionAmount|NumTransactions|AvgAccountBalance|NumTransactionYears|NumTransactionMonths|CustLocation|CustGender|AgeGroup|
+----------+----------------------+--------------------+--------------------+---------------+-----------------+-------------------+--------------------+------------+----------+--------+
|  C1010248|                2000.0|              2000.0|                 NaN|              1|     195876.65625|                  1|                   1|    VADODARA|         M|   36-45|
|  C1013555|                 110.0|               110.0|                 NaN|              1|     334906.65625|                  1|                   1|       DELHI|         M|   18-25|
|  C1013666|                5000.0|              5000.0|              

# Saving data into CSV format

In [53]:
# Save the final DataFrame as CSV with a header row, replacing any previous output
output_dir = "file:///home/talentum/dataset"
(
    df_final.coalesce(1)  # coalesce to one partition before writing
    .write
    .mode("overwrite")
    .option("header", "true")
    .csv(output_dir)
)