-
Notifications
You must be signed in to change notification settings - Fork 93
/
test_apriori.py
83 lines (71 loc) · 2.3 KB
/
test_apriori.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Tests for apyori.apriori.
"""
from nose.tools import eq_
from mock import Mock
from apyori import TransactionManager
from apyori import SupportRecord
from apyori import RelationRecord
from apyori import OrderedStatistic
from apyori import apriori
def test_empty():
    """
    Check that apriori yields no relation records when the support-record
    generator produces nothing.
    """
    manager = Mock(spec=TransactionManager)

    def no_support_records(*_):
        """ Stand-in for apyori.gen_support_records that yields nothing. """
        return iter([])

    def one_ordered_statistic(*_):
        """ Stand-in for apyori.gen_ordered_statistics. """
        # Would produce a statistic, but with no support records it should
        # never contribute to the output.
        yield OrderedStatistic(
            frozenset(['A']), frozenset(['B']), 0.1, 0.7)

    records = apriori(
        manager,
        _gen_support_records=no_support_records,
        _gen_ordered_statistics=one_ordered_statistic,
    )
    eq_(list(records), [])
def test_normal():
    """
    Test for normal data.

    A single support record is fed through apriori together with two ordered
    statistics. With min_confidence=0.4 no relation record is expected; with
    min_confidence=0.3 one record is expected, carrying only
    ordered_statistic2.
    """
    transaction_manager = Mock(spec=TransactionManager)
    min_support = 0.1
    max_length = 2
    support_record = SupportRecord(frozenset(['A', 'B']), 0.5)
    ordered_statistic1 = OrderedStatistic(
        frozenset(['A']), frozenset(['B']), 0.1, 0.7)
    ordered_statistic2 = OrderedStatistic(
        frozenset(['A']), frozenset(['B']), 0.3, 0.5)

    def gen_support_records(*args, **kwargs):
        """ Mock for apyori.gen_support_records. """
        # These checks only run once apriori iterates this generator.
        eq_(args[1], min_support)
        # Accept max_length whether it is forwarded positionally or by
        # keyword, so the mock does not depend on apriori's call convention
        # for this optional argument.
        if 'max_length' in kwargs:
            eq_(kwargs['max_length'], max_length)
        else:
            eq_(args[2], max_length)
        yield support_record

    def gen_ordered_statistics(*_, **__):
        """ Mock for apyori.gen_ordered_statistics. """
        yield ordered_statistic1
        yield ordered_statistic2

    def run_apriori(min_confidence):
        """ Run apriori with the mocks above; return its records as a list. """
        return list(apriori(
            transaction_manager,
            min_support=min_support,
            min_confidence=min_confidence,
            max_length=max_length,
            _gen_support_records=gen_support_records,
            _gen_ordered_statistics=gen_ordered_statistics,
        ))

    # Will not create any records because of confidence.
    eq_(run_apriori(0.4), [])

    # Will create a record containing only ordered_statistic2.
    eq_(run_apriori(0.3), [RelationRecord(
        support_record.items, support_record.support, [ordered_statistic2]
    )])