Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/refactor #3

Merged
merged 11 commits into from
Mar 4, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ A simple implementation of *Apriori algorithm* by Python.
Features
--------

- Consisted of only one file and able to be used portably.
- Is consisted of only one file and depends on no other libraries,
which enable you to use Apyori portably.
- Can be used as APIs.
- Supports a TSV output format for 2-items relations.

Expand Down
140 changes: 87 additions & 53 deletions apyori.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import sys
import argparse
import json
from collections import namedtuple
from itertools import combinations
from itertools import chain


__version__ = '0.1.0'
Expand Down Expand Up @@ -38,21 +40,22 @@ def __init__(self, transactions):
self.__num_transaction = 0
self.__items = []
self.__transaction_index_map = {}
self.add_transactions(transactions)

def add_transactions(self, transactions):
for transaction in transactions:
self.add_transaction(transaction)

def add_transaction(self, transaction):
"""
Add transactions.
Add a transaction.

@param transactions A transaction list.
@param transaction A transaction.
"""
for transaction in transactions:
for item in transaction:
if item not in self.__transaction_index_map:
self.__items.append(item)
self.__transaction_index_map[item] = set()
self.__transaction_index_map[item].add(self.__num_transaction)
self.__num_transaction += 1
for item in transaction:
if item not in self.__transaction_index_map:
self.__items.append(item)
self.__transaction_index_map[item] = set()
self.__transaction_index_map[item].add(self.__num_transaction)
self.__num_transaction += 1

def calc_support(self, items):
"""
Expand Down Expand Up @@ -146,7 +149,7 @@ def gen_support_records(
max_length=None,
_generate_candidates_func=create_next_candidates):
"""
Returns the supported relations.
Returns a generator of support records with given transactions.
"""
candidates = transaction_manager.initial_candidates()
length = 1
Expand All @@ -167,7 +170,7 @@ def gen_support_records(

def gen_ordered_statistics(transaction_manager, record):
"""
Returns the relation stats.
Returns a generator of ordered statistics.
"""
items = record.items
combination_sets = [
Expand All @@ -192,15 +195,19 @@ def apriori(transactions, **kwargs):
min_support = kwargs.get('min_support', 0.1)
max_length = kwargs.get('max_length', None)
min_confidence = kwargs.get('min_confidence', 0.0)
_gen_support_records = kwargs.get(
'_gen_support_records', gen_support_records)
_gen_ordered_statistics = kwargs.get(
'_gen_ordered_statistics', gen_ordered_statistics)

# Calculate supports.
transaction_manager = TransactionManager.create(transactions)
support_records = gen_support_records(
support_records = _gen_support_records(
transaction_manager, min_support, max_length)

# Calculate stats.
# Calculate ordered stats.
for support_record in support_records:
ordered_statistics = gen_ordered_statistics(
ordered_statistics = _gen_ordered_statistics(
transaction_manager, support_record)
filtered_ordered_statistics = [
x for x in ordered_statistics if x.confidence >= min_confidence]
Expand All @@ -211,24 +218,31 @@ def apriori(transactions, **kwargs):
filtered_ordered_statistics)


def print_record_default(record, output_file):
def dump_as_json(record, output_file):
"""
Print an Apriori algorithm result.
Dump an relation record as a json value.

@param record A record.
@param output_file An output file.
"""
for ordered_stats in record.ordered_statistics:
output_file.write(
'{{{0}}} => {{{1}}} {2:.8f} {3:.8f} {4:.8f}\n'.format(
','.join(ordered_stats.items_base),
','.join(ordered_stats.items_add),
record.support, ordered_stats.confidence, ordered_stats.lift))
def default_func(value):
"""
Default conversion for JSON value.
"""
if isinstance(value, frozenset):
return sorted(value)
raise TypeError(repr(value) + " is not JSON serializable")

converted_record = record._replace(
ordered_statistics=[x._asdict() for x in record.ordered_statistics])
output_file.write(
json.dumps(converted_record._asdict(), default=default_func))
output_file.write('\n')


def print_record_as_two_item_tsv(record, output_file):
def dump_as_two_item_tsv(record, output_file):
"""
Print an Apriori algorithm result as two item TSV.
Dump a relation record as TSV only for 2 item relations.

@param record A record.
@param output_file An output file.
Expand All @@ -238,57 +252,77 @@ def print_record_as_two_item_tsv(record, output_file):
return
if len(ordered_stats.items_add) != 1:
return
output_file.write(
'{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}\n'.format(
[x for x in ordered_stats.items_base][0],
[x for x in ordered_stats.items_add][0],
record.support, ordered_stats.confidence, ordered_stats.lift))
output_file.write('{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}\n'.format(
list(ordered_stats.items_base)[0], list(ordered_stats.items_add)[0],
record.support, ordered_stats.confidence, ordered_stats.lift))


def main():
def parse_args(argv):
"""
Main.
Parse commandline arguments.
"""
output_funcs = {
'json': dump_as_json,
'tsv': dump_as_two_item_tsv,
}
default_output_func_key = 'json'

parser = argparse.ArgumentParser()
parser.add_argument(
'-i', '--input-file', help='Input file.', metavar='path',
type=argparse.FileType('r'), default=sys.stdin)
'-v', '--version', action='version',
version='%(prog)s {0}'.format(__version__))
parser.add_argument(
'input', metavar='inpath', nargs='*',
help='Input transaction file (default: stdin).',
type=argparse.FileType('r'), default=[sys.stdin])
parser.add_argument(
'-o', '--output-file', help='Output file.', metavar='path',
'-o', '--output', metavar='outpath',
help='Output file (default: stdout).',
type=argparse.FileType('w'), default=sys.stdout)
parser.add_argument(
'-l', '--max-length', help='Max length.', metavar='int',
'-l', '--max-length', metavar='int',
help='Max length of relations (default: infinite).',
type=int, default=None)
parser.add_argument(
'-s', '--min-support', help='Minimum support (0.0-1.0).',
metavar='float', type=float, default=0.15)
'-s', '--min-support', metavar='float',
help='Minimum support ratio (must be > 0, default: 0.1).',
type=float, default=0.1)
parser.add_argument(
'-c', '--min-confidence', help='Minimum confidence (0.0-1.0).',
metavar='float', type=float, default=0.6)
'-c', '--min-confidence', metavar='float',
help='Minimum confidence (default: 0.5).',
type=float, default=0.5)
parser.add_argument(
'-d', '--delimiter', help='Delimiter for input.',
metavar='str', type=str, default='\t')
'-d', '--delimiter', metavar='str',
help='Delimiter for items of transactions (default: tab).',
type=str, default='\t')
parser.add_argument(
'-f', '--out-format', help='Output format (default or tsv).',
metavar='str', type=str, choices=['default', 'tsv'],
default='default')
args = parser.parse_args()
metavar='str', type=str, choices=output_funcs.keys(),
default=default_output_func_key)
args = parser.parse_args(argv)
if args.min_support <= 0:
raise ValueError('min support must be > 0')

args.output_func = output_funcs[args.out_format]
return args


def main():
"""
Main.
"""
args = parse_args(sys.argv[1:])

transactions = [
line.strip().split(args.delimiter)
for line in args.input_file]
line.strip().split(args.delimiter) for line in chain(*args.input)]
result = apriori(
transactions,
max_length=args.max_length,
min_support=args.min_support,
min_confidence=args.min_confidence)

output_func = {
'default': print_record_default,
'tsv': print_record_as_two_item_tsv
}.get(args.out_format)
for record in result:
output_func(record, args.output_file)
args.output_func(record, args.output)


if __name__ == '__main__':
Expand Down
83 changes: 83 additions & 0 deletions test/test_apriori.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Tests for apyori.apriori.
"""

from nose.tools import eq_
from mock import Mock

from apyori import TransactionManager
from apyori import SupportRecord
from apyori import RelationRecord
from apyori import OrderedStatistic
from apyori import apriori


def test_empty():
"""
Test for empty data.
"""
transaction_manager = Mock(spec=TransactionManager)
def gen_support_records(*_):
""" Mock for apyori.gen_support_records. """
return iter([])

def gen_ordered_statistics(*_):
""" Mock for apyori.gen_ordered_statistics. """
yield OrderedStatistic(
frozenset(['A']), frozenset(['B']), 0.1, 0.7)

result = list(apriori(
transaction_manager,
_gen_support_records=gen_support_records,
_gen_ordered_statistics=gen_ordered_statistics,
))
eq_(result, [])


def test_normal():
"""
Test for normal data.
"""
transaction_manager = Mock(spec=TransactionManager)
min_support = 0.1
max_length = 2
support_record = SupportRecord(frozenset(['A', 'B']), 0.5)
ordered_statistic1 = OrderedStatistic(
frozenset(['A']), frozenset(['B']), 0.1, 0.7)
ordered_statistic2 = OrderedStatistic(
frozenset(['A']), frozenset(['B']), 0.3, 0.5)

def gen_support_records(*args):
""" Mock for apyori.gen_support_records. """
eq_(args[1], min_support)
eq_(args[2], max_length)
yield support_record

def gen_ordered_statistics(*_):
""" Mock for apyori.gen_ordered_statistics. """
yield ordered_statistic1
yield ordered_statistic2

# Will not create any records because of confidence.
result = list(apriori(
transaction_manager,
min_support=min_support,
min_confidence=0.4,
max_length=max_length,
_gen_support_records=gen_support_records,
_gen_ordered_statistics=gen_ordered_statistics,
))
eq_(result, [])

# Will create a record.
result = list(apriori(
transaction_manager,
min_support=min_support,
min_confidence=0.3,
max_length=max_length,
_gen_support_records=gen_support_records,
_gen_ordered_statistics=gen_ordered_statistics,
))
eq_(result, [RelationRecord(
support_record.items, support_record.support, [ordered_statistic2]
)])
58 changes: 58 additions & 0 deletions test/test_dump_as_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Tests for apyori.dump_as_json.
"""

import json

# For Python 2 compatibility.
try:
from StringIO import StringIO
except ImportError:
from io import StringIO

from nose.tools import raises
from nose.tools import eq_

from apyori import RelationRecord
from apyori import OrderedStatistic
from apyori import dump_as_json


def test_normal():
"""
Test for normal data.
"""
test_data = RelationRecord(
frozenset(['A']), 0.5,
[OrderedStatistic(frozenset([]), frozenset(['A']), 0.8, 1.2)]
)
output_file = StringIO()
dump_as_json(test_data, output_file)

output_file.seek(0)
result = json.loads(output_file.read())
eq_(result, {
'items': ['A'],
'support': 0.5,
'ordered_statistics': [
{
'items_base': [],
'items_add': ["A"],
'confidence': 0.8,
'lift': 1.2
}
]
})


@raises(TypeError)
def test_bad():
"""
Test for bad data.
"""
test_data = RelationRecord(
set(['A']), 0.5,
[OrderedStatistic(frozenset([]), frozenset(['A']), 0.8, 1.2)]
)
output_file = StringIO()
dump_as_json(test_data, output_file)
Loading