<a href="https://colab.research.google.com/github/thowley0824/howley_thomas_capstone/blob/main/association_rules.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install mlxtend

In [None]:
%run /databricks-yoda

In [None]:
import pandas as pd
import numpy as np
from mlxtend.frequent_patterns import apriori, association_rules
from mlxtend.preprocessing import TransactionEncoder

In [None]:
# CONFIGURE SPARK CONNECTOR
# Builds `engine`, the connection object that the pd.read_sql calls further
# down use to query Snowflake. The steps must run in this order: obtain a
# token, register it, look up the current user, then open the connection.

# Refresh token
# NOTE(review): get_refresh_token and DatabricksYoda are not defined in this
# notebook; presumably they come from the `%run /databricks-yoda` cell -- confirm.
REFRESH_TOKEN = get_refresh_token()

# Register token
snowflake_databricks = DatabricksYoda()
snowflake_databricks.request_token(REFRESH_TOKEN, 'refresh_token')

# Get user email
# Runs a Spark SQL query and takes the first field of the first row.
# Assumes `spark` is the session object injected by the notebook runtime and
# that CURRENT_USER resolves to an email address -- TODO confirm.
user_email = spark.sql("SELECT CURRENT_USER").rdd.map(list).first()[0]

# Create connection to the snowflake database
engine = snowflake_databricks.create_snowflake_oauth_connection(user_email)

In [None]:
# A helper function that builds cleanly styled table displays
def format_table(df, title, index_remap =None,float_cols=None):

    if index_remap is not None:
        table = df.rename(index=index_remap)
    else:
        table = df.copy()

    s = table.style.set_table_styles([{'selector': 'th.col_heading',
                         'props': 'text-align: center; font-size: 1em;'},
                         {'selector': 'th.index_name',
                         'props': 'text-align: center; font-size: 1em;'},
                        {'selector': '.output',
                         'props': 'flex-direction: row;'},
                        {'selector': 'th:not(.index_name)',
                        'props': 'background-color: #c9c9c9; padding: 0.4em 0.4em 0.4em 0.4em; column-width: 100px;'},
                        {'selector': 'th.blank.level0',
                        'props': 'background-color: white;'},
                        {'selector': 'td',
                        'props': 'text-align: center; padding: 0.1em 0.1em 0.1em 0.1em; border: 1px solid #444444;'}],
                                     overwrite=False)
    s = s.set_table_attributes("style='display:inline'").set_caption(title)\
            .set_table_styles([{'selector': 'caption',
                        'props': 'caption-side:top; text-align:left; font-size:1.2em; font-weight:bold; padding:1em;'}],
                        overwrite=False)
    if float_cols is not None:
        float_format_dict = {n:'{:0.2}'.format for n in float_cols}

        s = s.format(
                    formatter=float_format_dict)
    return s

In [None]:
# RETRIEVE DATA FROM THE SNOWFLAKE SOURCE TABLE

# The three queries differ only in their user-subset filter, so they are built
# from one template; this keeps the column list and table name in sync.
SOURCE_TABLE = 'prod__workspace__us.SCRATCH_T_STRATEGICFINANCE.user_transactions_by_product'
PRODUCT_COLUMNS = [f'product_{i}' for i in range(1, 11)]


def build_user_query(new_user_flag=None):
    """Return the SELECT statement for one user subset.

    Parameters
    ----------
    new_user_flag : int, optional
        When given, rows are restricted to ``new_user_flag = <value>``.
        When omitted, no filter is applied and every row is selected.

    Returns
    -------
    str
        A SQL statement selecting the ten product columns from SOURCE_TABLE.
    """
    if new_user_flag is None:
        where_clause = ''
    else:
        # int() guarantees only a numeric literal is interpolated into the SQL.
        where_clause = f'\nwhere new_user_flag = {int(new_user_flag)}'
    select_list = ', '.join(PRODUCT_COLUMNS)
    return f'select {select_list}\nfrom {SOURCE_TABLE}{where_clause};'


# Define the queries for obtaining data from the snowflake table
q_all_users = build_user_query()
q_new_users = build_user_query(new_user_flag=1)
q_existing_users = build_user_query(new_user_flag=0)

# Create pandas dataframes containing data from Snowflake
all_users_df = pd.read_sql(q_all_users, engine)
new_users_df = pd.read_sql(q_new_users, engine)
existing_users_df = pd.read_sql(q_existing_users, engine)

In [None]:
# CLEAN UP THE DATA
# Each frame is modified in place, so re-running this cell operates on the
# already-converted boolean columns.

for frame in (all_users_df, new_users_df, existing_users_df):
    # Lower-case the column names
    frame.columns = frame.columns.str.lower()

    # Turn every cell into a bool: anything other than 0 becomes True
    for name in frame.columns:
        frame[name] = np.where(frame[name] != 0, True, False)

In [None]:
# GENERATE THE SINGLE ITEM SUPPORT DATA AND FORMAT FOR DISPLAY

# One column per user subset, holding the column-wise mean of that subset's frame.
all_single_supports = pd.DataFrame({
    'All Users': all_users_df.mean(),
    'Existing Users': existing_users_df.mean(),
    'New Users': new_users_df.mean(),
})

# Summary statistics of the supports, without the 'count' and 'std' rows.
descr_stats = all_single_supports.describe()
descr_stats = descr_stats.drop(index=['count', 'std'], errors='ignore')

# Display labels for the row index of each table.
sup_new_index_names = {f'product_{n}': f'Product {n}' for n in range(1, 11)}
descr_new_index_names = {
    'mean': 'Mean',
    'min': 'Minimum',
    '25%': '25th Perc.',
    '50%': 'Median',
    '75%': '75th Perc.',
    'max': 'Maximum',
}

# Both tables format the same three subset columns.
sup_float_col_names = ['All Users', 'Existing Users', 'New Users']
descr_float_col_names = list(sup_float_col_names)

sup_tbl = format_table(
    all_single_supports,
    'Single Item Supports By User Subset',
    index_remap=sup_new_index_names,
    float_cols=sup_float_col_names,
)
descr_tbl = format_table(
    descr_stats,
    '',
    index_remap=descr_new_index_names,
    float_cols=descr_float_col_names,
)

In [None]:
# GENERATE ALL FREQUENT ITEMSETS

# Options shared by every apriori call; only min_support differs per subset.
# NOTE(review): the rationale for each min_support value is not shown in this
# notebook -- confirm before changing them.
apriori_options = {'use_colnames': True, 'verbose': 1}

all_users_freq_itemsets = apriori(
    all_users_df, min_support=0.0021, **apriori_options)
existing_users_freq_itemsets = apriori(
    existing_users_df, min_support=0.00165, **apriori_options)
new_users_freq_itemsets = apriori(
    new_users_df, min_support=0.00006, **apriori_options)

In [None]:
# A function for obtaining the top rules ranked both by lift and by confidence

def obtain_top_rules_by_lift_and_confidence(freq_itemsets, top_n=5):
    """Return the top single-item rules ranked by lift and by confidence.

    Rules are generated with ``association_rules`` (metric "confidence",
    threshold 0.0) and then restricted to rules with exactly one item on
    each side. The single item is unwrapped from its set so the LHS/RHS
    columns hold plain values.

    Parameters
    ----------
    freq_itemsets
        Passed unchanged to ``association_rules``; in this notebook it is
        the frame returned by ``apriori``.
    top_n : int, default 5
        Number of rules to keep in each ranking.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(rules_by_lift, rules_by_conf)``: the first sorted by Lift then
        Confidence, the second by Confidence, both descending, each with a
        fresh 0-based index and at most ``top_n`` rows.
    """
    rules = association_rules(freq_itemsets, metric="confidence", min_threshold=0.0)

    col_rename_dict = {'antecedents': 'LHS',
                       'consequents': 'RHS',
                       'antecedent support': 'LHS Support',
                       'consequent support': 'RHS Support',
                       'support': 'Support',
                       'confidence': 'Confidence',
                       'lift': 'Lift'}
    output_columns = ['LHS', 'RHS', 'Lift', 'Confidence', 'Support',
                      'LHS Support', 'RHS Support']

    rules = rules.rename(columns=col_rename_dict)

    # Keep only one-item -> one-item rules. Taking an explicit copy means the
    # assignments below never write into a slice of the unfiltered frame.
    is_single_item = (rules['LHS'].map(len) == 1) & (rules['RHS'].map(len) == 1)
    rules = rules.loc[is_single_item, output_columns].copy()

    # Unwrap the only element of each set. This works on the values and does
    # not depend on the filtered rows still being labelled 0..n-1.
    rules['LHS'] = rules['LHS'].map(lambda items: next(iter(items)))
    rules['RHS'] = rules['RHS'].map(lambda items: next(iter(items)))

    rules_by_lift = rules.sort_values(
        by=['Lift', 'Confidence'], ascending=False).reset_index(drop=True)
    rules_by_conf = rules.sort_values(
        by=['Confidence'], ascending=False).reset_index(drop=True)

    return rules_by_lift.head(top_n), rules_by_conf.head(top_n)

In [None]:
# Obtain the top rules for each user group
# Each call returns a (ranked by lift, ranked by confidence) pair.

all_users_top, existing_users_top, new_users_top = map(
    obtain_top_rules_by_lift_and_confidence,
    (all_users_freq_itemsets,
     existing_users_freq_itemsets,
     new_users_freq_itemsets),
)

all_users_lift, all_users_conf = all_users_top
existing_users_lift, existing_users_conf = existing_users_top
new_users_lift, new_users_conf = new_users_top

In [None]:
# Format the top rules for display

# Relabel the 0-based row index as ranks 1..5.
rules_new_index_names = {position: position + 1 for position in range(5)}

rules_float_col_names = ['Lift','Confidence', 'Support', 'LHS Support','RHS Support']


def _format_rules_table(rules, title):
    """Style one top-rules frame with the shared rank labels and float columns."""
    return format_table(rules, title,
                        index_remap=rules_new_index_names,
                        float_cols=rules_float_col_names)


all_lift_tbl = _format_rules_table(all_users_lift, 'All Users: Top Rules By Lift')
all_conf_tbl = _format_rules_table(all_users_conf, 'All Users: Top Rules By Confidence')
existing_lift_tbl = _format_rules_table(existing_users_lift, 'Existing Users: Top Rules By Lift')
existing_conf_tbl = _format_rules_table(existing_users_conf, 'Existing Users: Top Rules By Confidence')
new_lift_tbl = _format_rules_table(new_users_lift, 'New Users: Top Rules By Lift')
new_conf_tbl = _format_rules_table(new_users_conf, 'New Users: Top Rules By Confidence')

In [None]:
# Generate the tabular views

# Twenty non-breaking spaces separate the two inline tables on the first row.
space = "\xa0" * 20
displayHTML(space.join([sup_tbl._repr_html_(), descr_tbl._repr_html_()]))

# The rule tables are emitted one per displayHTML call, in this order.
for rules_tbl in (all_lift_tbl, all_conf_tbl,
                  existing_lift_tbl, existing_conf_tbl,
                  new_lift_tbl, new_conf_tbl):
    displayHTML(rules_tbl._repr_html_())