In [49]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.ml.feature import Imputer

spark = (
    SparkSession.builder.appName("MAST30034 Project 2")
    .config("spark.driver.memory", '4g')
    .config("spark.executor.memory", '8g')
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config("spark.sql.parquet.enableVectorizedReader","false")
    .config("spark.sql.repl.eagerEval.enabled", True) 
    .config("spark.sql.parquet.cacheMetadata", "true")
    .getOrCreate()
)

### Read in data

In [41]:
weights_sdf = spark.read.parquet(
    '../data/curated/demographic_weights.parquet'
)
consumers_sdf = spark.read.parquet(
    '../data/curated/cleaned_consumers.parquet'
)

user_details_sdf = spark.read.parquet(
    '../data/tables/consumer_user_details.parquet'
)

postcode_poa_sdf = spark.read.parquet(
    '../data/curated/census/postcode_poa.parquet'
)
# Reading in all transaction data and joining them
"""transactions_sdf = spark.read.parquet(
    '../data/tables/transactions_20210228_20210827_snapshot'
).union(
    spark.read.parquet(
        '../data/tables/transactions_20210828_20220227_snapshot'
    )
).union(
    spark.read.parquet(
        '../data/tables/transactions_20220228_20220828_snapshot'
    )
)"""
transactions_sdf = spark.read.parquet('../data/raw/samples/transaction_sample.parquet')

### Remove transactions outside valid bnpl range

In [42]:
# Round to 2 decimal places, and define a transaction range
min_value = 5
max_value = 10000

transactions = transactions_sdf.where(
    (F.col('dollar_value') >= min_value)
    & (F.col('dollar_value') <= max_value)
)

### Join transaction data with consumer data and weights

In [43]:
transactions_sdf = transactions_sdf.join(
    user_details_sdf,
    on = 'user_id',
    how = 'left'
).join(
    consumers_sdf.select(
        'consumer_id', 'postcode', 'gender'
    ),
    on = 'consumer_id',
    how = 'left'
).join(
    postcode_poa_sdf,
    on = 'postcode',
    how = 'left'
).join(
    weights_sdf,
    on = ['poa' ,'gender'],
    how = 'left'
)

Displays number of null values for each column (resulting from above joins)

In [48]:
transactions_sdf.select([F.count(F.when(F.col(x).isNull(), x)) for x in transactions_sdf.columns])

count(CASE WHEN (poa IS NULL) THEN poa END),count(CASE WHEN (gender IS NULL) THEN gender END),count(CASE WHEN (postcode IS NULL) THEN postcode END),count(CASE WHEN (consumer_id IS NULL) THEN consumer_id END),count(CASE WHEN (user_id IS NULL) THEN user_id END),count(CASE WHEN (merchant_abn IS NULL) THEN merchant_abn END),count(CASE WHEN (dollar_value IS NULL) THEN dollar_value END),count(CASE WHEN (order_id IS NULL) THEN order_id END),count(CASE WHEN (order_datetime IS NULL) THEN order_datetime END),count(CASE WHEN (weight IS NULL) THEN weight END)
449,345,345,0,0,0,0,0,0,0


Impute null weights

In [47]:

imputer = Imputer(inputCol = 'weight', outputCol='weight', strategy = 'mean')
transactions_sdf = imputer.fit(transactions_sdf).transform(transactions_sdf)



Apply weights

In [None]:
transaction_sdf = transactions_sdf.withColumn(
    'weighted_dollar_value',
    F.col('weight')*F.col('dollar_value')
)

In [None]:
transactions_sdf = transactions_sdf.withColumn('dollar_value', F.round(F.col('dollar_value'), 2))

### Check ABN validity

In [15]:
# Make sure ABN is valid, takes in long

def validateABN(merchant_abn):

    str_abn = str(merchant_abn)

    if len(str_abn) == 11:
        return True
    else:
        return False

In [16]:
# Create a list of all row values, used for validating ABN

sdf_list = sdf_consumer_transaction.select("merchant_abn").collect()

                                                                                

In [17]:
# Find any merchants without a valid ABN

i = 0
invalidABN = []

while i < len(sdf_list):
    abn = str(sdf_list[i].__getitem__('merchant_abn'))
    if validateABN(abn) == False:
        invalidABN.append(abn)
    i += 1

In [18]:
invalidABN

[]

ez no invalid abn

In [19]:
# Checking date range
start_date = '2021-02-28'
end_date = '2022-08-28'
sdf_consumer_transaction = sdf_consumer_transaction.where(
    (F.col('order_datetime') >= start_date) & (F.col('order_datetime') <= end_date)
)

In [20]:
sdf_consumer_transaction.count()

                                                                                

11965964

In [21]:
# Export cleaned data
sdf_consumer_transaction.write.mode('overwrite').parquet('../data/curated/cleaned_transactions.parquet')

                                                                                