Skip to content

Commit

Permalink
Apriori use sets (#393)
Browse files Browse the repository at this point in the history
  • Loading branch information
WLaney authored and rasbt committed Jun 15, 2018
1 parent bcf3f44 commit 84fde97
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/sources/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ The CHANGELOG for the current development version is available at

##### Changes

- -
- Itemsets generated with `apriori` are now sets ([#344](https://github.com/rasbt/mlxtend/issues/344) by [William Laney](https://github.com/WLaney))


##### Bug Fixes
Expand Down
6 changes: 3 additions & 3 deletions mlxtend/frequent_patterns/apriori.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def apriori(df, min_support=0.5, use_colnames=False, max_len=None):
all_res = []
for k in sorted(itemset_dict):
support = pd.Series(support_dict[k])
itemsets = pd.Series([i for i in itemset_dict[k]])
itemsets = pd.Series([set(i) for i in itemset_dict[k]])

res = pd.concat((support, itemsets), axis=1)
all_res.append(res)
Expand All @@ -139,8 +139,8 @@ def apriori(df, min_support=0.5, use_colnames=False, max_len=None):
res_df.columns = ['support', 'itemsets']
if use_colnames:
mapping = {idx: item for idx, item in enumerate(df.columns)}
res_df['itemsets'] = res_df['itemsets'].apply(lambda x: [mapping[i]
for i in x])
res_df['itemsets'] = res_df['itemsets'].apply(lambda x: set([mapping[i]
for i in x]))
res_df = res_df.reset_index(drop=True)

return res_df
10 changes: 10 additions & 0 deletions mlxtend/frequent_patterns/tests/test_apriori.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,13 @@ def test_max_len():

res_df2 = apriori(df, max_len=2)
assert len(res_df2.iloc[-1, -1]) == 2


def test_itemsets_type():
    """Check that every itemset returned by `apriori` is a `set`.

    The check runs once with column indices (`use_colnames=False`, the
    default behavior) and once with column names (`use_colnames=True`).
    """
    for use_colnames in (False, True):
        result = apriori(df, use_colnames=use_colnames)
        for itemset in result['itemsets']:
            assert isinstance(itemset, set)

0 comments on commit 84fde97

Please sign in to comment.