Merge pull request #12 from ymoch/release/v0.9.1-beta
Release/v0.9.1 beta
ymoch committed Mar 13, 2016
2 parents 1747e76 + 9891f76 commit 4faa907
Showing 11 changed files with 193 additions and 79 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
@@ -13,6 +13,9 @@ script:
- coverage run --source=apyori setup.py test
# Code quality check.
- pylint apyori.py test/*.py
# Integration test
- apyori-run data/integration_test_input_1.tsv > result.txt
- diff result.txt data/integration_test_output_1.txt
after_success:
- coveralls
notifications:
1 change: 1 addition & 0 deletions README.md
@@ -13,6 +13,7 @@ Features
- Consists of only one file and depends on no other libraries,
  which enables you to use Apyori portably.
- Can be used as an API.
- Supports a JSON output format.
- Supports a TSV output format for 2-item relations.


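The feature list mentions using Apyori as an API. A minimal sketch of such a call, assuming the `apriori()` signature shown in the `apyori.py` diff below (the transactions and threshold values here are made up for illustration):

```python
from apyori import apriori

# Made-up transactions; each inner list is one transaction.
transactions = [
    ['beer', 'nuts'],
    ['beer', 'cheese'],
    ['beer', 'nuts', 'jam'],
]

# min_support and min_confidence are keyword arguments of apriori().
for record in apriori(transactions, min_support=0.5, min_confidence=0.5):
    print(record.items, record.support)
```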
149 changes: 103 additions & 46 deletions apyori.py
@@ -1,20 +1,22 @@
#!/usr/bin/env python

"""
Implementation of Apriori algorithm.
A simple implementation of the Apriori algorithm in Python.
"""

import sys
import csv
import argparse
import json
import os
from collections import namedtuple
from itertools import combinations
from itertools import chain


__version__ = '0.9.0'
__version__ = '0.9.1'
__author__ = 'ymoch'
__author_email__ = 'ymoch@githib.com'
__author_email__ = 'ymoch@github.com'


# Ignore name errors because these names are namedtuples.
@@ -28,14 +30,16 @@

class TransactionManager(object):
"""
Transaction manager class.
Transaction managers.
"""

def __init__(self, transactions):
"""
Initialize.
@param transactions A transaction list.
Arguments:
transactions -- A transaction iterable object
(eg. [['A', 'B'], ['B', 'C']]).
"""
self.__num_transaction = 0
self.__items = []
@@ -48,7 +52,8 @@ def add_transaction(self, transaction):
"""
Add a transaction.
@param transaction A transaction.
Arguments:
transaction -- A transaction as an iterable object (eg. ['A', 'B']).
"""
for item in transaction:
if item not in self.__transaction_index_map:
@@ -60,6 +65,9 @@ def add_transaction(self, transaction):
def calc_support(self, items):
"""
Returns a support for items.
Arguments:
items -- Items as an iterable object (eg. ['A', 'B']).
"""
# Empty items is supported by all transactions.
if not items:
@@ -74,18 +82,20 @@ def calc_support(self, items):
return 0.0

if sum_indexes is None:
# Assign the indexes on the first time.
sum_indexes = indexes
else:
# Calculate the intersection on not the first time.
sum_indexes = sum_indexes.intersection(indexes)

# Calculate the support.
# Calculate and return the support.
return float(len(sum_indexes)) / self.__num_transaction

def initial_candidates(self):
"""
Returns the initial candidates.
"""
return [frozenset([item]) for item in self.__items]
return [frozenset([item]) for item in self.items]

@property
def num_transaction(self):
@@ -99,7 +109,7 @@ def items(self):
"""
Returns the item list that the transactions consist of.
"""
return self.__items
return sorted(self.__items)

@staticmethod
def create(transactions):
@@ -114,13 +124,18 @@ def create(transactions):
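For reference, a small sketch of exercising the `TransactionManager` API documented above (hypothetical usage, not part of this change; the sample transactions follow the docstring's `[['A', 'B'], ['B', 'C']]` example):

```python
from apyori import TransactionManager

# Build a manager from an iterable of transactions.
manager = TransactionManager.create([['A', 'B'], ['B', 'C']])

print(manager.num_transaction)           # 2
print(manager.items)                     # ['A', 'B', 'C'] (sorted, per this change)
print(manager.calc_support(['B']))       # 1.0 -- 'B' appears in both transactions
print(manager.calc_support(['A', 'B']))  # 0.5 -- only the first transaction
```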

def create_next_candidates(prev_candidates, length):
"""
Returns the apriori candidates.
Returns the apriori candidates as a list.
Arguments:
prev_candidates -- Previous candidates as a list.
length -- The length of the next candidates.
"""
# Solve the items.
items = set()
item_set = set()
for candidate in prev_candidates:
for item in candidate:
items.add(item)
item_set.add(item)
items = sorted(item_set)

def check_subsets(candidate):
"""
@@ -143,14 +158,25 @@ def check_subsets(candidate):
return next_candidates
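A sketch of what the candidate-generation step above yields for a tiny input (hypothetical values; `create_next_candidates` is the module-level function shown in this hunk):

```python
from apyori import create_next_candidates

# 1-item candidates, as produced by TransactionManager.initial_candidates().
prev_candidates = [frozenset(['A']), frozenset(['B']), frozenset(['C'])]

# All 2-item combinations of the underlying items: {A, B}, {A, C}, {B, C}.
for candidate in create_next_candidates(prev_candidates, 2):
    print(sorted(candidate))
```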


def gen_support_records(
transaction_manager,
min_support,
max_length=None,
_generate_candidates_func=create_next_candidates):
def gen_support_records(transaction_manager, min_support, **kwargs):
"""
Returns a generator of support records with given transactions.
Arguments:
transaction_manager -- Transactions as a TransactionManager instance.
min_support -- A minimum support (float).
Keyword arguments:
max_length -- The maximum length of relations (integer).
"""
# Parse arguments.
max_length = kwargs.get('max_length')

# For testing.
_create_next_candidates = kwargs.get(
'_create_next_candidates', create_next_candidates)

# Process.
candidates = transaction_manager.initial_candidates()
length = 1
while candidates:
@@ -165,17 +191,20 @@ def gen_support_records(
length += 1
if max_length and length > max_length:
break
candidates = _generate_candidates_func(relations, length)
candidates = _create_next_candidates(relations, length)
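A sketch of driving `gen_support_records` directly (hypothetical usage; the `items` and `support` fields of `SupportRecord` are inferred from how they are used elsewhere in this diff):

```python
from apyori import TransactionManager, gen_support_records

manager = TransactionManager.create([['beer', 'nuts'], ['beer', 'cheese']])

# Yields a SupportRecord for every itemset whose support reaches min_support.
for record in gen_support_records(manager, 0.5, max_length=2):
    print(sorted(record.items), record.support)
```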


def gen_ordered_statistics(transaction_manager, record):
"""
Returns a generator of ordered statistics.
Returns a generator of ordered statistics as OrderedStatistic instances.
Arguments:
transaction_manager -- Transactions as a TransactionManager instance.
record -- A support record as a SupportRecord instance.
"""
items = record.items
combination_sets = [
frozenset(x) for x in combinations(items, len(items) - 1)]
for items_base in combination_sets:
for combination_set in combinations(sorted(items), len(items) - 1):
items_base = frozenset(combination_set)
items_add = frozenset(items.difference(items_base))
confidence = (
record.support / transaction_manager.calc_support(items_base))
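As a concrete check of what this produces, take the beer/jam record from `data/integration_test_output_1.txt` below: support({beer, jam}) = 0.375 and support({beer}) = 0.625, so the confidence of beer -> jam is 0.375 / 0.625 = 0.6; the listed lift of 1.2 is consistent with dividing that confidence by support({jam}) = 0.5 (the lift computation itself sits in the collapsed part of this hunk). A small arithmetic sketch:

```python
# Figures taken from data/integration_test_output_1.txt.
support_beer_jam = 0.375   # support of {beer, jam}
support_beer = 0.625       # support of {beer}
support_jam = 0.5          # support of {jam}

confidence = support_beer_jam / support_beer   # 0.6
lift = confidence / support_jam                # 1.2
print(confidence, lift)
```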
@@ -186,15 +215,22 @@ def gen_ordered_statistics(transaction_manager, record):

def apriori(transactions, **kwargs):
"""
Run Apriori algorithm.
Executes Apriori algorithm and returns a RelationRecord generator.
Arguments:
transactions -- A transaction iterable object
(eg. [['A', 'B'], ['B', 'C']]).
@param transactions A list of transactions.
@param min_support The minimum support of the relation (float).
@param max_length The maximum length of the relation (integer).
Keyword arguments:
min_support -- The minimum support of the relation (float).
max_length -- The maximum length of the relation (integer).
"""
# Parse the arguments.
min_support = kwargs.get('min_support', 0.1)
max_length = kwargs.get('max_length', None)
min_confidence = kwargs.get('min_confidence', 0.0)

# For testing.
_gen_support_records = kwargs.get(
'_gen_support_records', gen_support_records)
_gen_ordered_statistics = kwargs.get(
Expand All @@ -203,7 +239,7 @@ def apriori(transactions, **kwargs):
# Calculate supports.
transaction_manager = TransactionManager.create(transactions)
support_records = _gen_support_records(
transaction_manager, min_support, max_length)
transaction_manager, min_support, max_length=max_length)

# Calculate ordered stats.
for support_record in support_records:
@@ -218,12 +254,30 @@ def apriori(transactions, **kwargs):
filtered_ordered_statistics)
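The records yielded above carry their ordered statistics, which can be unpacked as in this hedged sketch (field names follow the `OrderedStatistic` namedtuple used in the tests and output functions below; the transactions are made up):

```python
from apyori import apriori

transactions = [['beer', 'nuts'], ['beer', 'cheese'], ['nuts', 'jam']]
for record in apriori(transactions, min_support=0.3, min_confidence=0.5):
    for stat in record.ordered_statistics:
        print(sorted(stat.items_base), '->', sorted(stat.items_add),
              'confidence={0:.2f}, lift={1:.2f}'.format(stat.confidence, stat.lift))
```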


def load_transactions(input_file, **kwargs):
"""
Loads transactions and returns a generator of transactions.
Arguments:
input_file -- An input file.
Keyword arguments:
delimiter -- The delimiter of the transaction.
"""
delimiter = kwargs.get('delimiter', '\t')
for transaction in csv.reader(input_file, delimiter=delimiter):
if not transaction:
continue
yield transaction
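A sketch of pairing `load_transactions` with `apriori` (the file name here is hypothetical; one tab-separated transaction per line, as with the CLI):

```python
from apyori import load_transactions, apriori

# 'transactions.tsv' is a hypothetical tab-separated input file.
with open('transactions.tsv') as input_file:
    transactions = load_transactions(input_file, delimiter='\t')
    for record in apriori(transactions, min_support=0.1):
        print(record)
```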


def dump_as_json(record, output_file):
"""
Dump a relation record as a JSON value.
@param record A record.
@param output_file An output file.
Arguments:
record -- A RelationRecord instance to dump.
output_file -- A file to output.
"""
def default_func(value):
"""
@@ -235,31 +289,37 @@ def default_func(value):

converted_record = record._replace(
ordered_statistics=[x._asdict() for x in record.ordered_statistics])
output_file.write(
json.dumps(converted_record._asdict(), default=default_func))
output_file.write('\n')
json.dump(
converted_record._asdict(), output_file,
default=default_func, ensure_ascii=False)
output_file.write(os.linesep)
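With this change each relation record is written as one JSON object per line (as in the integration-test fixture below), so the output can be read back line by line. A hedged sketch, assuming a `result.txt` produced by the `apyori-run ... > result.txt` step added to `.travis.yml` above:

```python
import json

# Read the line-delimited JSON written by dump_as_json.
with open('result.txt') as result_file:
    for line in result_file:
        record = json.loads(line)
        print(record['items'], record['support'])
```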


def dump_as_two_item_tsv(record, output_file):
"""
Dump a relation record as TSV only for 2 item relations.
@param record A record.
@param output_file An output file.
Arguments:
record -- A RelationRecord instance to dump.
output_file -- A file to output.
"""
for ordered_stats in record.ordered_statistics:
if len(ordered_stats.items_base) != 1:
return
if len(ordered_stats.items_add) != 1:
return
output_file.write('{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}\n'.format(
output_file.write('{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}{5}'.format(
list(ordered_stats.items_base)[0], list(ordered_stats.items_add)[0],
record.support, ordered_stats.confidence, ordered_stats.lift))
record.support, ordered_stats.confidence, ordered_stats.lift,
os.linesep))
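The TSV writer above emits one line per (base, add) pair with five columns (base item, add item, support, confidence, lift), so its output can be consumed with the standard `csv` module. A hedged sketch; `result.tsv` is a hypothetical file name:

```python
import csv

# Columns written by dump_as_two_item_tsv: base, add, support, confidence, lift.
with open('result.tsv') as tsv_file:
    for base, add, support, confidence, lift in csv.reader(tsv_file, delimiter='\t'):
        print('{0} -> {1} (lift {2})'.format(base, add, lift))
```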


def parse_args(argv):
"""
Parse commandline arguments.
Arguments:
argv -- An argument list without the program name.
"""
output_funcs = {
'json': dump_as_json,
@@ -296,9 +356,10 @@ def parse_args(argv):
help='Delimiter for items of transactions (default: tab).',
type=str, default='\t')
parser.add_argument(
'-f', '--out-format', help='Output format (default or tsv).',
metavar='str', type=str, choices=output_funcs.keys(),
default=default_output_func_key)
'-f', '--out-format', metavar='str',
help='Output format ({0}; default: {1}).'.format(
', '.join(output_funcs.keys()), default_output_func_key),
type=str, choices=output_funcs.keys(), default=default_output_func_key)
args = parser.parse_args(argv)
if args.min_support <= 0:
raise ValueError('min support must be > 0')
@@ -308,19 +369,15 @@


def main():
"""
Main.
"""
""" Executes Apriori algorithm and print its result. """
args = parse_args(sys.argv[1:])

transactions = [
line.strip().split(args.delimiter) for line in chain(*args.input)]
transactions = load_transactions(
chain(*args.input), delimiter=args.delimiter)
result = apriori(
transactions,
max_length=args.max_length,
min_support=args.min_support,
min_confidence=args.min_confidence)

for record in result:
args.output_func(record, args.output)

18 changes: 18 additions & 0 deletions data/integration_test_output_1.txt
@@ -0,0 +1,18 @@
{"items": ["beer"], "support": 0.625, "ordered_statistics": [{"items_base": [], "items_add": ["beer"], "confidence": 0.625, "lift": 1.0}]}
{"items": ["jam"], "support": 0.5, "ordered_statistics": [{"items_base": [], "items_add": ["jam"], "confidence": 0.5, "lift": 1.0}]}
{"items": ["nuts"], "support": 0.625, "ordered_statistics": [{"items_base": [], "items_add": ["nuts"], "confidence": 0.625, "lift": 1.0}]}
{"items": ["beer", "butter"], "support": 0.25, "ordered_statistics": [{"items_base": ["butter"], "items_add": ["beer"], "confidence": 0.6666666666666666, "lift": 1.0666666666666667}]}
{"items": ["beer", "cheese"], "support": 0.25, "ordered_statistics": [{"items_base": ["cheese"], "items_add": ["beer"], "confidence": 0.6666666666666666, "lift": 1.0666666666666667}]}
{"items": ["beer", "jam"], "support": 0.375, "ordered_statistics": [{"items_base": ["beer"], "items_add": ["jam"], "confidence": 0.6, "lift": 1.2}, {"items_base": ["jam"], "items_add": ["beer"], "confidence": 0.75, "lift": 1.2}]}
{"items": ["beer", "nuts"], "support": 0.5, "ordered_statistics": [{"items_base": ["beer"], "items_add": ["nuts"], "confidence": 0.8, "lift": 1.28}, {"items_base": ["nuts"], "items_add": ["beer"], "confidence": 0.8, "lift": 1.28}]}
{"items": ["cheese", "nuts"], "support": 0.375, "ordered_statistics": [{"items_base": ["cheese"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["nuts"], "items_add": ["cheese"], "confidence": 0.6, "lift": 1.5999999999999999}]}
{"items": ["jam", "nuts"], "support": 0.375, "ordered_statistics": [{"items_base": ["jam"], "items_add": ["nuts"], "confidence": 0.75, "lift": 1.2}, {"items_base": ["nuts"], "items_add": ["jam"], "confidence": 0.6, "lift": 1.2}]}
{"items": ["beer", "butter", "jam"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "butter"], "items_add": ["jam"], "confidence": 0.5, "lift": 1.0}, {"items_base": ["butter", "jam"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]}
{"items": ["beer", "butter", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "butter"], "items_add": ["nuts"], "confidence": 0.5, "lift": 0.8}, {"items_base": ["butter", "nuts"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]}
{"items": ["beer", "cheese", "jam"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "cheese"], "items_add": ["jam"], "confidence": 0.5, "lift": 1.0}, {"items_base": ["cheese", "jam"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]}
{"items": ["beer", "cheese", "nuts"], "support": 0.25, "ordered_statistics": [{"items_base": ["beer", "cheese"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["beer", "nuts"], "items_add": ["cheese"], "confidence": 0.5, "lift": 1.3333333333333333}, {"items_base": ["cheese", "nuts"], "items_add": ["beer"], "confidence": 0.6666666666666666, "lift": 1.0666666666666667}]}
{"items": ["beer", "jam", "nuts"], "support": 0.375, "ordered_statistics": [{"items_base": ["beer", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["beer", "nuts"], "items_add": ["jam"], "confidence": 0.75, "lift": 1.5}, {"items_base": ["jam", "nuts"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]}
{"items": ["butter", "jam", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["butter", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["butter", "nuts"], "items_add": ["jam"], "confidence": 1.0, "lift": 2.0}]}
{"items": ["cheese", "jam", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["cheese", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}]}
{"items": ["beer", "butter", "jam", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "butter", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["beer", "butter", "nuts"], "items_add": ["jam"], "confidence": 1.0, "lift": 2.0}, {"items_base": ["butter", "jam", "nuts"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]}
{"items": ["beer", "cheese", "jam", "nuts"], "support": 0.125, "ordered_statistics": [{"items_base": ["beer", "cheese", "jam"], "items_add": ["nuts"], "confidence": 1.0, "lift": 1.6}, {"items_base": ["beer", "cheese", "nuts"], "items_add": ["jam"], "confidence": 0.5, "lift": 1.0}, {"items_base": ["cheese", "jam", "nuts"], "items_add": ["beer"], "confidence": 1.0, "lift": 1.6}]}
6 changes: 3 additions & 3 deletions test/test_apriori.py
@@ -17,7 +17,7 @@ def test_empty():
Test for empty data.
"""
transaction_manager = Mock(spec=TransactionManager)
def gen_support_records(*_):
def gen_support_records(*args, **kwargs): # pylint: disable=unused-argument
""" Mock for apyori.gen_support_records. """
return iter([])

@@ -47,10 +47,10 @@ def test_normal():
ordered_statistic2 = OrderedStatistic(
frozenset(['A']), frozenset(['B']), 0.3, 0.5)

def gen_support_records(*args):
def gen_support_records(*args, **kwargs):
""" Mock for apyori.gen_support_records. """
eq_(args[1], min_support)
eq_(args[2], max_length)
eq_(kwargs['max_length'], max_length)
yield support_record

def gen_ordered_statistics(*_):