In [1]:
from collections import Counter, defaultdict as dd
from itertools import combinations

import os

In [2]:
# Colab-only setup: fetch the input basket file (browsing.txt) from Google Drive.
# NOTE(review): these imports only resolve inside Google Colab with PyDrive and
# oauth2client available; when running elsewhere, place browsing.txt in the
# working directory and skip this cell.
# Import PyDrive and associated libraries.
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate the Colab user, then build a PyDrive client that uses the
# application-default credentials. This only needs to be done once per notebook.
# NOTE(review): presumably get_application_default() relies on
# authenticate_user() having run first — keep this statement order.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Download a file based on its file ID and save it locally as browsing.txt
# (a relative path, i.e. the current working directory), which is the path the
# loading cell later reads.
#
# A file ID looks like: laggVyWshwcyP6kEI-y_W3P8D26sz
file_id = '1ZlP3yxyDWknKO7wOTDYklTu_redUQVv2'
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('browsing.txt')

In [3]:
# Minimum count an item/itemset needs to be kept as "frequent": it is the
# default cut-off of remove_counts_below_support, which keeps counts >= it.
SUPPORT_THRESHOLD: int = 100

In [4]:
def load_dataset(file_path, encoding='utf-8'):
    '''
    Read a whitespace-separated basket file into a list of transactions.

    Each line of the file becomes one transaction: the list of item IDs on
    that line, split on runs of whitespace. A blank line yields an empty
    transaction, so the result has exactly one entry per line of the file.

    Args:
        file_path: path of the text file to read.
        encoding: text encoding used to decode the file. Defaults to UTF-8
            so the result does not depend on the platform's locale.

    Returns:
        list[list[str]]: one list of item IDs per line of the file.
    '''
    with open(file_path, 'r', encoding=encoding) as f:
        # str.split() with no separator already discards leading/trailing
        # whitespace (including the newline), so no separate strip() pass is
        # needed; iterating the file avoids materialising readlines() first.
        return [line.split() for line in f]

In [5]:
def flatten(outer_list):
    '''
    Concatenate the inner iterables of a 2-D list into one flat list.

    Args:
        outer_list: iterable of iterables (e.g. a list of transactions).

    Returns:
        list: every element of every inner iterable, in order.
    '''
    flat = []
    for inner in outer_list:
        flat.extend(inner)
    return flat

In [6]:
def remove_counts_below_support(c, support_threshold=SUPPORT_THRESHOLD):
    '''
    Keep only the entries of ``c`` whose count reaches the support threshold.

    Args:
        c: mapping from an item/itemset to its count.
        support_threshold: minimum count (inclusive) an entry needs to be
            kept. Defaults to SUPPORT_THRESHOLD as it was when this function
            was defined.

    Returns:
        dict: a new dict of the surviving key/count pairs, in the same order
        as ``c``; ``c`` itself is not modified.
    '''
    return {key: count for key, count in c.items() if count >= support_threshold}

In [7]:
def get_candidate_set_counts_2(transactions, r, frequent_items=None):
    '''
    Count, for every r-item combination, how many transactions contain it.

    Each transaction is de-duplicated before its combinations are generated,
    so an itemset is counted at most once per transaction (i.e. its support)
    and degenerate itemsets such as ('A', 'A') can never appear.

    Args:
        transactions: iterable of transactions, each an iterable of hashable,
            mutually sortable item IDs.
        r: number of items per combination (2 for pairs, 3 for triples, ...).
        frequent_items: optional collection of the items allowed to appear in
            a candidate (a dict such as L1 works too: its keys are used). When
            given, all other items are dropped from each transaction first.
            This is Apriori pruning: an itemset containing an infrequent item
            cannot itself be frequent, so fewer combinations are enumerated.
            Defaults to None, meaning every item is considered.

    Returns:
        collections.defaultdict(int): maps each sorted r-tuple of items to the
        number of transactions that contain all of its items.
    '''
    if frequent_items is not None:
        # A set gives O(1) membership and an easy intersection below.
        frequent_items = set(frequent_items)

    candidate_set_counts = dd(int)

    for transaction in transactions:
        items = set(transaction)
        if frequent_items is not None:
            items &= frequent_items
        # Combinations of a sorted sequence are themselves sorted, giving each
        # itemset one canonical tuple key regardless of the basket's order.
        for itemset in combinations(sorted(items), r):
            candidate_set_counts[itemset] += 1

    return candidate_set_counts

In [8]:
def get_confidence(intersect, denominator, frequent_itemsets):
    '''
    This functions calculates the confidence for the 
    specified intersect and denominator value. 
    '''

    numerator = tuple(sorted(list(intersect)))
    
    if numerator not in frequent_itemsets.keys():
        print(f"Error! ({numerator}) does not exist!")
        return 0
    
    if isinstance(denominator, tuple):
        denominator = tuple(sorted(list(denominator)))

    confidence_score = frequent_itemsets[numerator] / frequent_itemsets[denominator]

    return confidence_score

Parse the specified file into a list of transactions. After parsing, the data has the following structure:

```
[['FRO11987', 'ELE17451', 'ELE89019', 'SNA90258', 'GRO99222'],
 ['GRO99222',
  'GRO12298',
  'FRO12685',
  'ELE91550',
  'SNA11465',
  'ELE26917',
  'ELE52966',
  'FRO90334',
  'SNA30755',
  'ELE17451',
  'FRO84225',
  'SNA80192'],
 ['ELE17451', 'GRO73461', 'DAI22896', 'SNA99873', 'FRO86643']]
```

where each inner list is a 'basket' (transaction) and each string inside a basket is the ID of a purchased item.

In [9]:
# Parse the basket file into one list of item IDs per line, count the
# transactions, and keep a flattened list of every item occurrence.
Q2_FILE_PATH = 'browsing.txt'
transactions = load_dataset(Q2_FILE_PATH)
NUM_TRANSACTIONS = len(transactions)
transactions_flat = flatten(transactions)

print(f"There are {NUM_TRANSACTIONS} rows/customer sessions in the file")

There are 31101 rows/customer sessions in the file


The initial candidate set of single items (C1) is then obtained and pruned with the specified support threshold, giving the frequent items (L1).

In [10]:
# Support of a single item = number of baskets containing it. Each basket is
# de-duplicated first (order-preserving via dict.fromkeys) so an item listed
# twice in one basket is only counted once.
C1 = dict(Counter(item for basket in transactions for item in dict.fromkeys(basket)))
L1 = remove_counts_below_support(C1)
# L1 holds only the items that pass the support threshold, not every unique item.
print(f"There are {len(L1)} frequent items/products (support >= {SUPPORT_THRESHOLD}) in the file")

There are 647 unique items/products in the file


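Next, we obtain the bigram (item-pair) candidate set C2 and prune it in the same way to get the frequent pairs L2. Note that C2 is counted over every pair in each basket; it is not restricted to pairs built from the frequent items in L1.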

In [11]:
# Count every item pair across the baskets, then keep only the frequent pairs (L2).
C2 = get_candidate_set_counts_2(transactions, r=2)
L2 = remove_counts_below_support(C2)
# L2 holds frequent item *pairs*, not unique items.
print(f"There are {len(L2)} frequent item pairs (support >= {SUPPORT_THRESHOLD}) in the file")

There are 1334 unique items/products in the file


In [12]:
# Five most frequent pairs: highest count first, ties broken by the pair itself.
sorted(L2.items(), key=lambda pair_count: (-pair_count[1], pair_count[0]))[:5]

[(('DAI62779', 'ELE17451'), 1592),
 (('FRO40251', 'SNA80324'), 1412),
 (('DAI75645', 'FRO40251'), 1254),
 (('FRO40251', 'GRO85051'), 1213),
 (('DAI62779', 'GRO73461'), 1139)]

In [13]:
# Count every item triple across the baskets, then keep only the frequent triples (L3).
C3 = get_candidate_set_counts_2(transactions, r=3)
L3 = remove_counts_below_support(C3)
# L3 holds frequent item *triples*, not unique items.
print(f"There are {len(L3)} frequent item triples (support >= {SUPPORT_THRESHOLD}) in the file")

There are 233 unique items/products in the file


In [14]:
# Five most frequent triples: highest count first, ties broken by the triple itself.
sorted(L3.items(), key=lambda triple_count: (-triple_count[1], triple_count[0]))[:5]

[(('DAI75645', 'FRO40251', 'SNA80324'), 550),
 (('DAI62779', 'FRO40251', 'SNA80324'), 476),
 (('FRO40251', 'GRO85051', 'SNA80324'), 471),
 (('DAI62779', 'ELE92920', 'SNA18336'), 432),
 (('DAI62779', 'DAI75645', 'SNA80324'), 421)]

In [15]:
# One lookup table for single items, pairs and triples. Their keys never
# collide: single items are item strings, pairs are 2-tuples, triples 3-tuples.
frequent_itemsets = dict(L1)
frequent_itemsets.update(L2)
frequent_itemsets.update(L3)

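Confidence is computed for every association rule that can be formed from the frequent pairs (L2) and the frequent triples (L3), where $\text{conf}(X \Rightarrow Y) = \dfrac{\text{support}(X \cup Y)}{\text{support}(X)}$. The five highest-confidence rules of each kind are shown.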

In [16]:
# Every frequent pair {a, b} yields two rules, a => b and b => a; each rule is
# stored as ((antecedent, consequent), confidence).
tuple_confidence_scores = [
    ((antecedent, consequent), get_confidence(pair, antecedent, frequent_itemsets))
    for pair in L2
    for antecedent, consequent in (pair, pair[::-1])
]

In [17]:
# Highest confidence first; ties broken by antecedent, then consequent.
tuple_confidence_scores.sort(key=lambda rule_score: (-rule_score[1], *rule_score[0]))
tuple_confidence_scores[:5]

[(('DAI93865', 'FRO40251'), 1.0),
 (('GRO85051', 'FRO40251'), 0.999176276771005),
 (('GRO38636', 'FRO40251'), 0.9906542056074766),
 (('ELE12951', 'FRO40251'), 0.9905660377358491),
 (('DAI88079', 'FRO40251'), 0.9867256637168141)]

In [18]:
# Every frequent triple (x, y, z), stored in sorted order, yields three rules
# with a pair as antecedent: {x, y} => z, {x, z} => y and {y, z} => x. Each
# rule is stored as ((antecedent_1, antecedent_2, consequent), confidence).
tuple_3_confidence_scores = []
for triple in L3:
    x, y, z = triple
    for rule in ((x, y, z), (x, z, y), (y, z, x)):
        tuple_3_confidence_scores.append(
            (rule, get_confidence(triple, rule[:2], frequent_itemsets))
        )

In [19]:
# Highest confidence first; ties broken by the rule's items in order
# (first antecedent, second antecedent, consequent).
tuple_3_confidence_scores.sort(key=lambda rule_score: (-rule_score[1], *rule_score[0]))
tuple_3_confidence_scores[:5]

[(('DAI23334', 'ELE92920', 'DAI62779'), 1.0),
 (('DAI31081', 'GRO85051', 'FRO40251'), 1.0),
 (('DAI55911', 'GRO85051', 'FRO40251'), 1.0),
 (('DAI62779', 'DAI88079', 'FRO40251'), 1.0),
 (('DAI75645', 'GRO85051', 'FRO40251'), 1.0)]