In [None]:
# imports
import os
import time

from azure.storage.filedatalake import DataLakeServiceClient
from pyspark.sql import DataFrame
from pyspark.sql.functions import col
# NOTE: this shadows the builtin round(); the notebook relies on the Spark column version
from pyspark.sql.functions import round
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType, DoubleType
from pyspark.sql.utils import AnalysisException


In [None]:
# SECRETS
# NOTE(review): credentials are hardcoded in the notebook (the values appear redacted in this copy).
# Prefer fetching them at runtime from a secret store (e.g. Azure Key Vault) instead.
connection_string, jdbc_username, jdbc_password = (
    "**********",  # ADLS connection string, used by upload_report_to_adls
    "**********",  # SQL login, used by write_df_to_db
    "**********",  # SQL password, used by write_df_to_db
)

In [None]:
# define schema of source
# The monetary columns use DoubleType rather than FloatType. FloatType is a 32-bit float with only
# ~7 significant digits, so larger amounts cannot be held to the cent (1000000.07 is stored as
# 1000000.0625). That makes the later start_balance + mutation == end_balance check fail spuriously.
schema = (
    StructType()
    .add("Reference", IntegerType(), True)
    .add("Account Number", StringType(), True)
    .add("Description", StringType(), True)
    .add("Start Balance", DoubleType(), True)
    .add("Mutation", DoubleType(), True)
    .add("End Balance", DoubleType(), True)
)

try:
    # load source into dataframe; column names and types come from `schema`, not from inference
    df = spark.read.load(
        'abfss://codingassignmentfilesystem@codingassignmentstorage.dfs.core.windows.net/arriving_data/*',
        format='csv',
        header=True,
        delimiter=",",
        schema=schema
    )
except AnalysisException:
    # Raised e.g. when the arriving_data path matches no files.
    print("No data found in path")
    # NOTE(review): later cells read `df`, which is unbound on this path. Confirm that stopping the
    # session really halts the run here (mssparkutils.notebook.exit may be what is intended).
    mssparkutils.session.stop()

In [None]:
def convert_cols_snakecase(df: DataFrame) -> DataFrame:
    """
    Converts all column names in the provided dataframe to snake_case

    Each name is lower-cased and every space is replaced by an underscore,
    e.g. "Account Number" -> "account_number". Other characters are kept as-is.

    Parameters
    ----------
    df: DataFrame
        The DataFrame whose columns should be renamed

    Returns
    -------
    df: DataFrame
        The same data with its columns renamed positionally via toDF
    """
    # Use `name` rather than `col` so that pyspark.sql.functions.col, imported at the top, is not shadowed.
    snake_case_names = [name.replace(" ", "_").lower() for name in df.columns]
    return df.toDF(*snake_case_names)

# Normalise headers such as "Account Number" to snake_case ("account_number")
df = convert_cols_snakecase(df=df)

In [None]:
def find_duplicate_column_values(df: DataFrame, column: str) -> DataFrame:
    """
    Returns the distinct values of `column` that occur on more than one row of `df`

    Parameters
    ----------
    df: DataFrame
        The DataFrame to scan for repeated values
    column: str
        Name of the column whose values are checked for repeats

    Returns
    -------
    duplicate_references: DataFrame
        Single-column DataFrame (named `column`) with one row per repeated value.
        Spark groups NULL keys together, so several rows missing the value yield a NULL row here.
    """
    value_counts = df.groupBy(column).count()
    repeated_values = value_counts.filter(value_counts["count"] > 1)
    return repeated_values.drop("count")

# Reference values that occur on more than one row; displayed for inspection
duplicate_references = find_duplicate_column_values(df=df, column="reference")
duplicate_references.show()

In [None]:
def deduplicate_rows_on_column(df: DataFrame, duplicate_reference_df: DataFrame, column: str = "reference") -> (DataFrame, DataFrame):
    """
    Splits df into rows whose `column` value is unique and rows whose value appears in duplicate_reference_df

    Parameters
    ----------
    df: DataFrame
        The DataFrame from which the duplicate values should be removed
    duplicate_reference_df: DataFrame
        The DataFrame holding the duplicated values in a column named `column`
        (e.g. the output of find_duplicate_column_values)
    column: str, default "reference"
        The column of df (and of duplicate_reference_df) on which rows are matched

    Returns
    -------
    (not_duplicate_records: DataFrame, duplicate_records: DataFrame)
        not_duplicate_records: rows of df whose value is not listed in duplicate_reference_df
        duplicate_records: rows of df whose value is listed in duplicate_reference_df

    NOTE(review): rows whose `column` value is NULL appear in neither output. isin() evaluates
    to NULL for them and filter() drops NULL conditions. Confirm whether they should be routed
    to the rejected records instead.
    """
    # Collect the duplicate keys once so the duplicate-detection job runs a single time for both filters.
    # A NULL key (several rows missing the value) is skipped: isin() can never match it, and
    # converting it with int() would raise TypeError.
    duplicate_values = [
        row[column] for row in duplicate_reference_df.collect() if row[column] is not None
    ]
    is_duplicate = df[column].isin(duplicate_values)
    not_duplicate_records = df.filter(~is_duplicate)
    duplicate_records = df.filter(is_duplicate)
    return not_duplicate_records, duplicate_records

# df keeps only rows with a unique reference; filtered_df collects every row whose reference is duplicated
df, filtered_df = deduplicate_rows_on_column(df=df, duplicate_reference_df=duplicate_references)

In [None]:
def round_numerical_columns(df: DataFrame, columns: list, decimals: int) -> DataFrame:
    """
    Rounds each column listed in `columns` of the provided dataframe to `decimals` decimal places

    Parameters
    ----------
    df: DataFrame
        The DataFrame for which the column values should be rounded
    columns: list
        Names of the columns that should be rounded; other columns are left unchanged
    decimals: int
        The number of decimals the numerical values should be rounded to

    Returns
    -------
    df: DataFrame
        The DataFrame with the listed columns replaced by their rounded values
    """
    for column in columns:
        # Honour `decimals`; the previous body always rounded to 2 regardless of the argument.
        df = df.withColumn(column, round(df[column], decimals))
    return df

# Round the three balance-related columns to two decimals
df = round_numerical_columns(df, columns=["start_balance", "mutation", "end_balance"], decimals=2)

In [None]:
def check_balance_after_mutation(df: DataFrame, col_start_balance: str, col_mutation: str, col_end_balance: str, decimals: int = 2) -> (DataFrame, DataFrame):
    """
    Checks if balances are correct after mutation by checking whether round(start_balance + mutation) == end_balance
    Correct and incorrect transactions are split into separate DataFrames which are returned by the function

    Parameters
    ----------
    df: DataFrame
        The DataFrame for which the balances should be checked
    col_start_balance: str
        The column that contains the start_balance
    col_mutation: str
        The column that contains the mutation
    col_end_balance: str
        The column that contains the end_balance
    decimals: int, default 2
        Number of decimals start_balance + mutation is rounded to before comparing with end_balance

    Returns
    -------
    df_correct_balance: DataFrame
        The rows whose end_balance matches the rounded start_balance + mutation
    df_incorrect_balance: DataFrame
        All other rows, including rows where any of the three columns is NULL
    """
    end_balance = df[col_end_balance]
    expected_end_balance = round(df[col_start_balance] + df[col_mutation], decimals)
    # A plain == / != yields NULL when any operand is NULL, and filter() drops NULL in both branches,
    # so such rows used to vanish from both outputs. eqNullSafe never returns NULL, and the isNotNull()
    # guard keeps NULL <=> NULL from counting as a match. As a result every row lands in exactly one output.
    is_balanced = expected_end_balance.eqNullSafe(end_balance) & end_balance.isNotNull()
    df_correct_balance = df.filter(is_balanced)
    df_incorrect_balance = df.filter(~is_balanced)
    return df_correct_balance, df_incorrect_balance

# Split into rows whose end balance reconciles and rows whose end balance does not
df, df_incorrect_balance = check_balance_after_mutation(
    df,
    col_start_balance="start_balance",
    col_mutation="mutation",
    col_end_balance="end_balance",
)

# Duplicate-reference rows and incorrect-balance rows share one "failed" frame. union() matches columns
# by position, which holds here because both frames derive from the same df.
filtered_df = filtered_df.union(df_incorrect_balance)

In [None]:
def generate_report_content(filtered_df: DataFrame) -> str:
    """
    Generates the plain-text report for transactions that did not pass validation.
    Outputs a header followed by one "reference: ...; description: ..." line per failed transaction.

    Parameters
    ----------
    filtered_df: DataFrame
        The DataFrame containing rows that did not pass validation; must have
        `reference` and `description` columns.

    Returns
    -------
    report_content: str
        The header alone when filtered_df is empty.
    """
    # Collect only the two columns the report needs, as Spark Rows. Going through toPandas() turned a
    # nullable integer `reference` column holding any NULL into float64, printing "12.0" / "nan".
    failed_rows = filtered_df.select("reference", "description").collect()
    report_header = "Transactions that did not pass validation:\n------------------------------------------\n"
    report_lines = "\n".join(
        f"reference: {row['reference']}; description: {row['description']}" for row in failed_rows
    )
    return report_header + report_lines

# Build the plain-text report listing the rejected transactions
report_content = generate_report_content(filtered_df=filtered_df)

In [None]:
def upload_report_to_adls(storage_account_name: str, file_system_name: str, destination_dir: str, destination_path: str, connection_string: str, report_content: str) -> None:
    """
    Uploads the text report to ADLS Gen2, overwriting any existing file at the target location

    Parameters
    ----------
    storage_account_name: str
        Not used; the account is taken from connection_string. Kept for backward compatibility.
    file_system_name: str
        The name of the adls filesystem
    destination_dir: str
        The directory in the filesystem to store the report in, with or without a trailing "/"
    destination_path: str
        Path whose final component is used as the file name (its directory part is ignored)
    connection_string: str
        Storage account connection string used for authentication
    report_content: str
        The content that should be written to adls

    Returns
    -------
    None
    """
    # Strip trailing slashes instead of using os.path.dirname. dirname("/reports/") is "/reports",
    # but dirname("/reports") is "/", so omitting the slash used to write into the root directory.
    directory_name = destination_dir.rstrip("/") or "/"
    file_name = os.path.basename(destination_path)
    # The with-block closes the client's HTTP transport once the upload finishes.
    with DataLakeServiceClient.from_connection_string(conn_str=connection_string) as service_client:
        file_client = (
            service_client
            .get_file_system_client(file_system=file_system_name)
            .get_directory_client(directory_name)
            .get_file_client(file_name)
        )
        file_client.upload_data(report_content, overwrite=True)

# Location in ADLS where the plain-text validation report is written
storage_account_name = "codingassignmentstorage"
file_system_name = "codingassignmentfilesystem"
destination_dir = "/reports/"
# The time.time() suffix (epoch seconds, with fraction) makes name collisions between runs unlikely
destination_path = "/reports/failed_transactions_" + str(time.time())

upload_report_to_adls(
    storage_account_name,
    file_system_name,
    destination_dir,
    destination_path,
    connection_string,
    report_content,
)

In [None]:
def write_df_to_db(jdbc_hostname: str, jdbc_database: str, jdbc_username: str, jdbc_password: str, df: DataFrame, table_name: str, mode: str="ignore", jdbc_port: int=1433) -> None:
    """
    Writes a Spark DataFrame to a SQL Server table over JDBC

    Parameters
    ----------
    jdbc_hostname: str
        The hostname of the sql server
    jdbc_database: str
        The name of the sql database
    jdbc_username: str
        The username for connecting to the database
    jdbc_password: str
        The password for connecting to the database
    df: DataFrame
        The DataFrame to write to the database
    table_name: str
        The target table
    mode: str, default "ignore"
        Spark save mode passed to DataFrameWriter.jdbc. With "ignore", Spark skips the write
        entirely when the table already exists. NOTE(review): confirm this is intended;
        incremental loads usually use "append".
    jdbc_port: int, default 1433
        The sql server port

    Returns
    -------
    None
    """
    url = "jdbc:sqlserver://{host}:{port};database={db}".format(
        host=jdbc_hostname, port=jdbc_port, db=jdbc_database
    )
    properties = dict(
        user=jdbc_username,
        password=jdbc_password,
        driver="com.microsoft.sqlserver.jdbc.SQLServerDriver",
    )
    df.write.jdbc(url=url, table=table_name, mode=mode, properties=properties)


# Azure SQL server and database that receive the validated and the rejected transactions
jdbc_hostname = "codingassignmentsqlserver.database.windows.net"
jdbc_database = "codingassignmentdb"

# Valid rows go to "transactions"; rows that failed validation go to "failed_transactions".
# NOTE(review): no mode is passed, so write_df_to_db's default mode="ignore" applies and Spark skips
# a write whenever the target table already exists. Confirm that is intended.
for target_df, target_table in ((df, "transactions"), (filtered_df, "failed_transactions")):
    write_df_to_db(jdbc_hostname, jdbc_database, jdbc_username, jdbc_password, target_df, target_table)