From 84fde9737f977298139f3ed658dc0704874625a1 Mon Sep 17 00:00:00 2001 From: William Laney Date: Thu, 14 Jun 2018 21:44:09 -0400 Subject: [PATCH] Apriori use sets (#393) --- docs/sources/CHANGELOG.md | 2 +- mlxtend/frequent_patterns/apriori.py | 6 +++--- mlxtend/frequent_patterns/tests/test_apriori.py | 10 ++++++++++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/docs/sources/CHANGELOG.md b/docs/sources/CHANGELOG.md index 8407a4350..c52e175f3 100755 --- a/docs/sources/CHANGELOG.md +++ b/docs/sources/CHANGELOG.md @@ -26,7 +26,7 @@ The CHANGELOG for the current development version is available at ##### Changes -- - +- Itemsets generated with `apriori` are now sets ([#344](https://github.com/rasbt/mlxtend/issues/344) by [William Laney](https://github.com/WLaney)) ##### Bug Fixes diff --git a/mlxtend/frequent_patterns/apriori.py b/mlxtend/frequent_patterns/apriori.py index 8aa087a47..18e3c1fdd 100644 --- a/mlxtend/frequent_patterns/apriori.py +++ b/mlxtend/frequent_patterns/apriori.py @@ -130,7 +130,7 @@ def apriori(df, min_support=0.5, use_colnames=False, max_len=None): all_res = [] for k in sorted(itemset_dict): support = pd.Series(support_dict[k]) - itemsets = pd.Series([i for i in itemset_dict[k]]) + itemsets = pd.Series([set(i) for i in itemset_dict[k]]) res = pd.concat((support, itemsets), axis=1) all_res.append(res) @@ -139,8 +139,8 @@ def apriori(df, min_support=0.5, use_colnames=False, max_len=None): res_df.columns = ['support', 'itemsets'] if use_colnames: mapping = {idx: item for idx, item in enumerate(df.columns)} - res_df['itemsets'] = res_df['itemsets'].apply(lambda x: [mapping[i] - for i in x]) + res_df['itemsets'] = res_df['itemsets'].apply(lambda x: set([mapping[i] + for i in x])) res_df = res_df.reset_index(drop=True) return res_df diff --git a/mlxtend/frequent_patterns/tests/test_apriori.py b/mlxtend/frequent_patterns/tests/test_apriori.py index 9e24f3ce6..967345398 100644 --- a/mlxtend/frequent_patterns/tests/test_apriori.py +++ b/mlxtend/frequent_patterns/tests/test_apriori.py @@ -52,3 +52,13 @@ def test_max_len(): res_df2 = apriori(df, max_len=2) assert len(res_df2.iloc[-1, -1]) == 2 + + +def test_itemsets_type(): + res_colindice = apriori(df, use_colnames=False) # This is default behavior + for i in res_colindice['itemsets']: + assert isinstance(i, set) is True + + res_colnames = apriori(df, use_colnames=True) + for i in res_colnames['itemsets']: + assert isinstance(i, set) is True