In [41]:
import sys 
from itertools import chain , combinations
from collections import defaultdict 
from optparse import OptionParser 
import time 

In [2]:
def subsets(arr):
    """Return an iterator over every non-empty combination of ``arr``.

    Combinations are produced grouped by size: all 1-element tuples first,
    then all 2-element tuples, and so on up to the full collection.
    """
    by_size = []
    # One combinations() iterator per size 1..len(arr); enumerate is used only
    # to count the elements, the element itself is ignored.
    for size, _ in enumerate(arr, start=1):
        by_size.append(combinations(arr, size))
    return chain(*by_size)

In [20]:
def runItemsWithMinSupport(itemSet , transactionList , minSupport , freqSet):
    """Count how often each itemset occurs and keep those meeting minSupport.

    Args:
        itemSet: iterable of frozensets (candidate itemsets).
        transactionList: sized collection of transactions.
        minSupport: minimum fraction of transactions that must contain an
            itemset for it to be returned.
        freqSet: mapping updated in place; each match adds 1 to
            ``freqSet[itemset]`` (counts accumulate across calls).

    Returns:
        A set with the itemsets whose support is >= minSupport.
    """
    n_transactions = len(transactionList)
    hits = defaultdict(int)
    # Tally the occurrences of every candidate itemset.
    for candidate in itemSet:
        for basket in transactionList:
            if not candidate.issubset(basket):
                continue
            freqSet[candidate] += 1
            hits[candidate] += 1
    # Drop the candidates whose support falls below the threshold.
    return {
        candidate
        for candidate, count in hits.items()
        if count*1.0/n_transactions >= minSupport
    }

In [32]:
def joinSet(itemSet,length):
    """Build new candidate itemsets from the existing frequent itemsets.

    Every pair of itemsets in ``itemSet`` is unioned; the unions that contain
    exactly ``length`` elements are returned as a set.
    """
    candidates = set()
    for left in itemSet:
        for right in itemSet:
            merged = left.union(right)
            if len(merged) == length:
                candidates.add(merged)
    return candidates

In [5]:
def getItemSetAndTransactionList(data):
    """Read the raw transactions and collect every 1-item itemset.

    The returned itemsets are not filtered by support: they are all the
    single items that appear anywhere in the data.

    Returns:
        (transactionList, itemSet): the transactions in input order, and a
        set of one-element frozensets.
    """
    transactionList = []
    singletons = set()
    for record in data:
        transactionList.append(record)
        # frozenset() removes duplicate items within one transaction.
        singletons.update(frozenset([element]) for element in frozenset(record))
    return transactionList, singletons

In [89]:
def runApriori(dataIterator,minSupport,minConfidence):
    """Compute the frequent itemsets and derive association rules from them.

    Args:
        dataIterator: iterable of transactions; each transaction is an
            iterable of hashable items.
        minSupport: minimum fraction of transactions that must contain an
            itemset for it to count as frequent.
        minConfidence: minimum value of count(left | right) / count(left)
            for a rule ``left -> right`` to be reported.

    Returns:
        (itemsets, rules) where ``itemsets`` is a list of
        ``(items_tuple, support)`` and ``rules`` is a list of
        ``(left_tuple, right_tuple, confidence)``. Entries follow set
        iteration order, so their ordering is not guaranteed.
    """
    # Load the transactions and the 1-item candidate itemsets.
    transactionList, itemSet = getItemSetAndTransactionList(dataIterator)
    transactionLen = len(transactionList)
    # Occurrence count of every itemset that has been evaluated.
    freqSet = defaultdict(int)
    # Frequent itemsets grouped by size: {size: set of frozensets}.
    globalFreqSet = {}
    # Association rules that satisfy minConfidence.
    globalRules = list()

    # Level-wise search: start from the frequent 1-itemsets and keep growing
    # the itemset size until a level produces no frequent itemsets.
    currentSet = runItemsWithMinSupport(itemSet, transactionList, minSupport, freqSet)
    k = 1
    while len(currentSet) > 0:
        # Key the level by the real itemset size.
        globalFreqSet[k] = currentSet
        k += 1
        # Candidates with one more element, then filtered by support.
        candidateSet = joinSet(currentSet, k)
        currentSet = runItemsWithMinSupport(candidateSet, transactionList, minSupport, freqSet)

    def get_probability(seta):
        """Support of ``seta``: the fraction of transactions containing it."""
        return freqSet[seta]*1.0/transactionLen

    globalFreqSetWithSupport = []
    for value in globalFreqSet.values():
        globalFreqSetWithSupport.extend([(tuple(item), get_probability(item))
                           for item in value])

    # Derive rules left -> right from every frequent itemset.
    for freqItems in globalFreqSet.values():
        for oneSet in freqItems:
            # Each non-empty subset is tried as the left-hand side.
            for leftSet in map(frozenset, subsets(oneSet)):
                rightSet = oneSet.difference(leftSet)
                if len(rightSet) == 0 or len(leftSet) == 0:
                    continue
                # Divide the raw counts: dividing two already-rounded supports
                # can differ in the last bit (e.g. 0.1 / 0.3 vs 1 / 3) and
                # flip the comparison against minConfidence.
                confidence = freqSet[oneSet]*1.0/freqSet[leftSet]
                if confidence >= minConfidence:
                    globalRules.append( (tuple(leftSet),tuple(rightSet),confidence) )
    return  globalFreqSetWithSupport,globalRules

In [84]:
# Sample transactions: one transaction per line, items separated by commas.
data = """
apple,beer,rice,chicken
apple,beer,rice
apple,beer
apple,mango
milk,beer,rice,chicken
milk,beer,rice
milk,beer
milk,mango
"""
def dataFromString(s):
    """Yield one transaction per non-blank line of a comma-separated string.

    Args:
        s: text with one transaction per line and items separated by commas.

    Yields:
        A frozenset with the items of each line; whitespace around every
        item is removed.

    Blank lines and empty items (for example from ``"a,,b"`` or a trailing
    comma) are skipped, so an empty string yields no transactions rather
    than a single transaction holding an empty-string item.
    """
    for line in s.strip().split('\n'):
        tokens = (token.strip() for token in line.split(','))
        record = frozenset(token for token in tokens if token)
        if record:
            yield record
# Thresholds passed to runItemsWithMinSupport / runApriori below.
minSupport = 0.2 
minConfidence = 0.5
# dataFromString is a generator function, so this iterator can be consumed only once.
dataIterator = dataFromString(data)

In [85]:
# Materialise the transactions and the 1-item itemsets.
# This exhausts dataIterator; it must be re-created before being passed on again.
transactions , sets  = getItemSetAndTransactionList(dataIterator)

In [10]:
sets

{frozenset({'mango'}),
 frozenset({'beer'}),
 frozenset({'rice'}),
 frozenset({'milk'}),
 frozenset({'apple'}),
 frozenset({'chicken'})}

In [11]:
transactions

[frozenset({'apple', 'beer', 'chicken', 'rice'}),
 frozenset({'apple', 'beer', 'rice'}),
 frozenset({'apple', 'beer'}),
 frozenset({'apple', 'mango'}),
 frozenset({'beer', 'chicken', 'milk', 'rice'}),
 frozenset({'beer', 'milk', 'rice'}),
 frozenset({'beer', 'milk'}),
 frozenset({'mango', 'milk'})]

In [24]:
# Empty counter that will receive the itemset frequencies.
lset = defaultdict(int)

In [25]:
# Frequent 1-itemsets. The call increments the counts in lset in place, so
# re-running this cell without resetting lset first doubles the counts.
set2 = runItemsWithMinSupport(sets, transactions, minSupport,lset)

In [26]:
set2

{frozenset({'mango'}),
 frozenset({'beer'}),
 frozenset({'rice'}),
 frozenset({'milk'}),
 frozenset({'apple'}),
 frozenset({'chicken'})}

In [27]:
lset

defaultdict(int,
            {frozenset({'mango'}): 2,
             frozenset({'beer'}): 6,
             frozenset({'rice'}): 4,
             frozenset({'milk'}): 4,
             frozenset({'apple'}): 4,
             frozenset({'chicken'}): 2})

In [28]:
# Minimum number of transactions an itemset needs to reach minSupport
# (uses the configured threshold instead of repeating the literal 0.2).
len(transactions) * minSupport

1.6

In [90]:
# Build a fresh generator: the one created earlier has already been consumed.
dataIterator = dataFromString(data)
# globalSet: (itemset, support) pairs; globalRule: (left, right, confidence) triples.
globalSet , globalRule =  runApriori(dataIterator,minSupport,minConfidence) 

In [91]:
print(globalSet)

[(('mango',), 0.25), (('beer',), 0.75), (('rice',), 0.5), (('milk',), 0.5), (('apple',), 0.5), (('chicken',), 0.25), (('beer', 'chicken'), 0.25), (('beer', 'milk'), 0.375), (('beer', 'apple'), 0.375), (('rice', 'milk'), 0.25), (('beer', 'rice'), 0.5), (('rice', 'apple'), 0.25), (('rice', 'chicken'), 0.25), (('beer', 'rice', 'apple'), 0.25), (('beer', 'rice', 'chicken'), 0.25), (('beer', 'rice', 'milk'), 0.25)]


In [92]:
print(globalRule)

[(('chicken',), ('beer',), 1.0), (('beer',), ('milk',), 0.5), (('milk',), ('beer',), 0.75), (('beer',), ('apple',), 0.5), (('apple',), ('beer',), 0.75), (('rice',), ('milk',), 0.5), (('milk',), ('rice',), 0.5), (('beer',), ('rice',), 0.6666666666666666), (('rice',), ('beer',), 1.0), (('rice',), ('apple',), 0.5), (('apple',), ('rice',), 0.5), (('rice',), ('chicken',), 0.5), (('chicken',), ('rice',), 1.0), (('rice',), ('beer', 'apple'), 0.5), (('apple',), ('beer', 'rice'), 0.5), (('beer', 'rice'), ('apple',), 0.5), (('beer', 'apple'), ('rice',), 0.6666666666666666), (('rice', 'apple'), ('beer',), 1.0), (('rice',), ('beer', 'chicken'), 0.5), (('chicken',), ('beer', 'rice'), 1.0), (('beer', 'rice'), ('chicken',), 0.5), (('beer', 'chicken'), ('rice',), 1.0), (('rice', 'chicken'), ('beer',), 1.0), (('rice',), ('beer', 'milk'), 0.5), (('milk',), ('beer', 'rice'), 0.5), (('beer', 'rice'), ('milk',), 0.5), (('beer', 'milk'), ('rice',), 0.6666666666666666), (('rice', 'milk'), ('beer',), 1.0)]


In [93]:
len(globalRule)

28

In [94]:
len(globalSet)

16

In [96]:
# chain() with a single iterable just yields that iterable's elements in order.
for x in chain([1,2,3,4,5]) :
    print(x)

1
2
3
4
5


In [97]:
# combinations(seq, 3) yields every 3-element tuple, keeping the input order.
for x in combinations([1,2,3,4,5,6],3):
    print(x)

(1, 2, 3)
(1, 2, 4)
(1, 2, 5)
(1, 2, 6)
(1, 3, 4)
(1, 3, 5)
(1, 3, 6)
(1, 4, 5)
(1, 4, 6)
(1, 5, 6)
(2, 3, 4)
(2, 3, 5)
(2, 3, 6)
(2, 4, 5)
(2, 4, 6)
(2, 5, 6)
(3, 4, 5)
(3, 4, 6)
(3, 5, 6)
(4, 5, 6)


In [102]:
# Unpacking a list of iterators into chain() concatenates them: all the
# 3-combinations first, then all the 4-combinations (the pattern subsets() uses).
for x in chain(*[combinations([1,2,3,4,5,6],3),combinations([1,2,3,4,5,6],4)]):
    print(x)

(1, 2, 3)
(1, 2, 4)
(1, 2, 5)
(1, 2, 6)
(1, 3, 4)
(1, 3, 5)
(1, 3, 6)
(1, 4, 5)
(1, 4, 6)
(1, 5, 6)
(2, 3, 4)
(2, 3, 5)
(2, 3, 6)
(2, 4, 5)
(2, 4, 6)
(2, 5, 6)
(3, 4, 5)
(3, 4, 6)
(3, 5, 6)
(4, 5, 6)
(1, 2, 3, 4)
(1, 2, 3, 5)
(1, 2, 3, 6)
(1, 2, 4, 5)
(1, 2, 4, 6)
(1, 2, 5, 6)
(1, 3, 4, 5)
(1, 3, 4, 6)
(1, 3, 5, 6)
(1, 4, 5, 6)
(2, 3, 4, 5)
(2, 3, 4, 6)
(2, 3, 5, 6)
(2, 4, 5, 6)
(3, 4, 5, 6)


In [109]:
# With a single iterator in the list, chain(*[...]) yields the same tuples
# as iterating that iterator directly.
for x in chain(*[combinations([1,2,3,4,5,6],3) ]):
    print(x)

(1, 2, 3)
(1, 2, 4)
(1, 2, 5)
(1, 2, 6)
(1, 3, 4)
(1, 3, 5)
(1, 3, 6)
(1, 4, 5)
(1, 4, 6)
(1, 5, 6)
(2, 3, 4)
(2, 3, 5)
(2, 3, 6)
(2, 4, 5)
(2, 4, 6)
(2, 5, 6)
(3, 4, 5)
(3, 4, 6)
(3, 5, 6)
(4, 5, 6)


In [116]:
# Two tuples used to contrast chain(l) with chain(*l).
l = [(1,2,4),(2,3,5)]

In [118]:
# Without unpacking, chain() receives one iterable (the list) and yields its
# elements unchanged, i.e. the tuples themselves.
for x in chain(l):
    print(x)

(1, 2, 4)
(2, 3, 5)


In [119]:
# With unpacking, each tuple is passed as its own iterable, so chain()
# flattens them into the individual numbers.
for x in chain(*l):
    print(x)

1
2
4
2
3
5
