Skip to content

Commit

Permalink
Merge pull request #3 from ymoch/features/refactor
Browse files Browse the repository at this point in the history
Features/refactor
  • Loading branch information
ymoch committed Mar 4, 2016
2 parents 49ba6a5 + ddadf0c commit 23ad310
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 59 deletions.
3 changes: 2 additions & 1 deletion README.md
Expand Up @@ -10,7 +10,8 @@ A simple implementation of *Apriori algorithm* by Python.
Features
--------

- Consisted of only one file and able to be used portably.
- Consists of only one file and depends on no other libraries,
which enables you to use Apyori portably.
- Can be used as APIs.
- Supports a TSV output format for 2-items relations.

Expand Down
140 changes: 87 additions & 53 deletions apyori.py
Expand Up @@ -6,8 +6,10 @@

import sys
import argparse
import json
from collections import namedtuple
from itertools import combinations
from itertools import chain


__version__ = '0.1.0'
Expand Down Expand Up @@ -38,21 +40,22 @@ def __init__(self, transactions):
self.__num_transaction = 0
self.__items = []
self.__transaction_index_map = {}
self.add_transactions(transactions)

def add_transactions(self, transactions):
for transaction in transactions:
self.add_transaction(transaction)

def add_transaction(self, transaction):
"""
Add transactions.
Add a transaction.
@param transactions A transaction list.
@param transaction A transaction.
"""
for transaction in transactions:
for item in transaction:
if item not in self.__transaction_index_map:
self.__items.append(item)
self.__transaction_index_map[item] = set()
self.__transaction_index_map[item].add(self.__num_transaction)
self.__num_transaction += 1
for item in transaction:
if item not in self.__transaction_index_map:
self.__items.append(item)
self.__transaction_index_map[item] = set()
self.__transaction_index_map[item].add(self.__num_transaction)
self.__num_transaction += 1

def calc_support(self, items):
"""
Expand Down Expand Up @@ -146,7 +149,7 @@ def gen_support_records(
max_length=None,
_generate_candidates_func=create_next_candidates):
"""
Returns the supported relations.
Returns a generator of support records with given transactions.
"""
candidates = transaction_manager.initial_candidates()
length = 1
Expand All @@ -167,7 +170,7 @@ def gen_support_records(

def gen_ordered_statistics(transaction_manager, record):
"""
Returns the relation stats.
Returns a generator of ordered statistics.
"""
items = record.items
combination_sets = [
Expand All @@ -192,15 +195,19 @@ def apriori(transactions, **kwargs):
min_support = kwargs.get('min_support', 0.1)
max_length = kwargs.get('max_length', None)
min_confidence = kwargs.get('min_confidence', 0.0)
_gen_support_records = kwargs.get(
'_gen_support_records', gen_support_records)
_gen_ordered_statistics = kwargs.get(
'_gen_ordered_statistics', gen_ordered_statistics)

# Calculate supports.
transaction_manager = TransactionManager.create(transactions)
support_records = gen_support_records(
support_records = _gen_support_records(
transaction_manager, min_support, max_length)

# Calculate stats.
# Calculate ordered stats.
for support_record in support_records:
ordered_statistics = gen_ordered_statistics(
ordered_statistics = _gen_ordered_statistics(
transaction_manager, support_record)
filtered_ordered_statistics = [
x for x in ordered_statistics if x.confidence >= min_confidence]
Expand All @@ -211,24 +218,31 @@ def apriori(transactions, **kwargs):
filtered_ordered_statistics)


def print_record_default(record, output_file):
def dump_as_json(record, output_file):
"""
Print an Apriori algorithm result.
Dump an relation record as a json value.
@param record A record.
@param output_file An output file.
"""
for ordered_stats in record.ordered_statistics:
output_file.write(
'{{{0}}} => {{{1}}} {2:.8f} {3:.8f} {4:.8f}\n'.format(
','.join(ordered_stats.items_base),
','.join(ordered_stats.items_add),
record.support, ordered_stats.confidence, ordered_stats.lift))
def default_func(value):
"""
Default conversion for JSON value.
"""
if isinstance(value, frozenset):
return sorted(value)
raise TypeError(repr(value) + " is not JSON serializable")

converted_record = record._replace(
ordered_statistics=[x._asdict() for x in record.ordered_statistics])
output_file.write(
json.dumps(converted_record._asdict(), default=default_func))
output_file.write('\n')


def print_record_as_two_item_tsv(record, output_file):
def dump_as_two_item_tsv(record, output_file):
"""
Print an Apriori algorithm result as two item TSV.
Dump a relation record as TSV only for 2 item relations.
@param record A record.
@param output_file An output file.
Expand All @@ -238,57 +252,77 @@ def print_record_as_two_item_tsv(record, output_file):
return
if len(ordered_stats.items_add) != 1:
return
output_file.write(
'{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}\n'.format(
[x for x in ordered_stats.items_base][0],
[x for x in ordered_stats.items_add][0],
record.support, ordered_stats.confidence, ordered_stats.lift))
output_file.write('{0}\t{1}\t{2:.8f}\t{3:.8f}\t{4:.8f}\n'.format(
list(ordered_stats.items_base)[0], list(ordered_stats.items_add)[0],
record.support, ordered_stats.confidence, ordered_stats.lift))


def main():
def parse_args(argv):
"""
Main.
Parse commandline arguments.
"""
output_funcs = {
'json': dump_as_json,
'tsv': dump_as_two_item_tsv,
}
default_output_func_key = 'json'

parser = argparse.ArgumentParser()
parser.add_argument(
'-i', '--input-file', help='Input file.', metavar='path',
type=argparse.FileType('r'), default=sys.stdin)
'-v', '--version', action='version',
version='%(prog)s {0}'.format(__version__))
parser.add_argument(
'input', metavar='inpath', nargs='*',
help='Input transaction file (default: stdin).',
type=argparse.FileType('r'), default=[sys.stdin])
parser.add_argument(
'-o', '--output-file', help='Output file.', metavar='path',
'-o', '--output', metavar='outpath',
help='Output file (default: stdout).',
type=argparse.FileType('w'), default=sys.stdout)
parser.add_argument(
'-l', '--max-length', help='Max length.', metavar='int',
'-l', '--max-length', metavar='int',
help='Max length of relations (default: infinite).',
type=int, default=None)
parser.add_argument(
'-s', '--min-support', help='Minimum support (0.0-1.0).',
metavar='float', type=float, default=0.15)
'-s', '--min-support', metavar='float',
help='Minimum support ratio (must be > 0, default: 0.1).',
type=float, default=0.1)
parser.add_argument(
'-c', '--min-confidence', help='Minimum confidence (0.0-1.0).',
metavar='float', type=float, default=0.6)
'-c', '--min-confidence', metavar='float',
help='Minimum confidence (default: 0.5).',
type=float, default=0.5)
parser.add_argument(
'-d', '--delimiter', help='Delimiter for input.',
metavar='str', type=str, default='\t')
'-d', '--delimiter', metavar='str',
help='Delimiter for items of transactions (default: tab).',
type=str, default='\t')
parser.add_argument(
'-f', '--out-format', help='Output format (default or tsv).',
metavar='str', type=str, choices=['default', 'tsv'],
default='default')
args = parser.parse_args()
metavar='str', type=str, choices=output_funcs.keys(),
default=default_output_func_key)
args = parser.parse_args(argv)
if args.min_support <= 0:
raise ValueError('min support must be > 0')

args.output_func = output_funcs[args.out_format]
return args


def main():
"""
Main.
"""
args = parse_args(sys.argv[1:])

transactions = [
line.strip().split(args.delimiter)
for line in args.input_file]
line.strip().split(args.delimiter) for line in chain(*args.input)]
result = apriori(
transactions,
max_length=args.max_length,
min_support=args.min_support,
min_confidence=args.min_confidence)

output_func = {
'default': print_record_default,
'tsv': print_record_as_two_item_tsv
}.get(args.out_format)
for record in result:
output_func(record, args.output_file)
args.output_func(record, args.output)


if __name__ == '__main__':
Expand Down
83 changes: 83 additions & 0 deletions test/test_apriori.py
@@ -0,0 +1,83 @@
"""
Tests for apyori.apriori.
"""

from nose.tools import eq_
from mock import Mock

from apyori import TransactionManager
from apyori import SupportRecord
from apyori import RelationRecord
from apyori import OrderedStatistic
from apyori import apriori


def test_empty():
    """
    Check that apriori yields no records when the support record
    generator produces nothing.
    """
    def no_support_records(*_):
        """ Stand-in for apyori.gen_support_records (always empty). """
        return iter([])

    def single_ordered_statistic(*_):
        """ Stand-in for apyori.gen_ordered_statistics. """
        yield OrderedStatistic(
            frozenset(['A']), frozenset(['B']), 0.1, 0.7)

    manager = Mock(spec=TransactionManager)
    records = list(apriori(
        manager,
        _gen_support_records=no_support_records,
        _gen_ordered_statistics=single_ordered_statistic,
    ))
    eq_(records, [])


def test_normal():
    """
    Check that apriori forwards min_support / max_length to the support
    record generator and filters ordered statistics by min_confidence.
    """
    manager = Mock(spec=TransactionManager)
    expected_min_support = 0.1
    expected_max_length = 2
    record = SupportRecord(frozenset(['A', 'B']), 0.5)
    # The third positional field is the confidence: 0.1 and 0.3.
    weaker_statistic = OrderedStatistic(
        frozenset(['A']), frozenset(['B']), 0.1, 0.7)
    stronger_statistic = OrderedStatistic(
        frozenset(['A']), frozenset(['B']), 0.3, 0.5)

    def fake_support_records(*args):
        """ Stand-in for apyori.gen_support_records. """
        eq_(args[1], expected_min_support)
        eq_(args[2], expected_max_length)
        yield record

    def fake_ordered_statistics(*_):
        """ Stand-in for apyori.gen_ordered_statistics. """
        yield weaker_statistic
        yield stronger_statistic

    def run_apriori(min_confidence):
        """ Run apriori with the stand-ins and the given confidence. """
        return list(apriori(
            manager,
            min_support=expected_min_support,
            min_confidence=min_confidence,
            max_length=expected_max_length,
            _gen_support_records=fake_support_records,
            _gen_ordered_statistics=fake_ordered_statistics,
        ))

    # Both statistics fall below the threshold, so no record is created.
    eq_(run_apriori(0.4), [])

    # Only the statistic whose confidence reaches the threshold is kept.
    eq_(run_apriori(0.3), [RelationRecord(
        record.items, record.support, [stronger_statistic]
    )])
58 changes: 58 additions & 0 deletions test/test_dump_as_json.py
@@ -0,0 +1,58 @@
"""
Tests for apyori.dump_as_json.
"""

import json

# For Python 2 compatibility.
try:
from StringIO import StringIO
except ImportError:
from io import StringIO

from nose.tools import raises
from nose.tools import eq_

from apyori import RelationRecord
from apyori import OrderedStatistic
from apyori import dump_as_json


def test_normal():
    """
    Check that a relation record is dumped as the expected JSON object.
    """
    statistic = OrderedStatistic(frozenset([]), frozenset(['A']), 0.8, 1.2)
    record = RelationRecord(frozenset(['A']), 0.5, [statistic])
    expected = {
        'items': ['A'],
        'support': 0.5,
        'ordered_statistics': [
            {
                'items_base': [],
                'items_add': ["A"],
                'confidence': 0.8,
                'lift': 1.2
            }
        ]
    }

    buffer = StringIO()
    dump_as_json(record, buffer)

    # Parse everything that was written and compare it as plain data.
    eq_(json.loads(buffer.getvalue()), expected)


@raises(TypeError)
def test_bad():
    """
    Check that dumping a record whose items are a plain set (not a
    frozenset) raises TypeError.
    """
    statistic = OrderedStatistic(frozenset([]), frozenset(['A']), 0.8, 1.2)
    unserializable = RelationRecord(set(['A']), 0.5, [statistic])
    dump_as_json(unserializable, StringIO())

0 comments on commit 23ad310

Please sign in to comment.