# 基于关联规则的推荐
《机器学习系统设计》第八章

## 购物篮分析

In [1]:
# use download_retail.sh to download train data
!ls data/retail.dat.gz

data/retail.dat.gz


In [3]:
import numpy as np
from collections import defaultdict
from itertools import chain
from gzip import GzipFile
# Load the retail dataset: one basket per line, written as whitespace-separated
# integer product ids. The ``with`` statement closes the gzip file as soon as
# it has been read instead of leaving the handle open for the whole session.
with GzipFile('data/retail.dat.gz') as retail_file:
    dataset = [[int(tok) for tok in line.strip().split()]
               for line in retail_file]

# Count how often each product id occurs across all baskets.
# ``chain.from_iterable`` walks the baskets lazily; ``chain(*dataset)`` would
# first unpack every basket into one huge argument tuple.
counts = defaultdict(int)
for elem in chain.from_iterable(dataset):
    counts[elem] += 1
# Keep only the per-product totals, as an array for vectorised comparisons.
counts = np.array(list(counts.values()))

# Histogram: number of products whose count falls in each [bot, top) range.
# The last range is open-ended: its upper bound is a sentinel larger than any
# realistic count and is printed as a blank.
bins = [1, 2, 4, 8, 16, 32, 64, 128, 512]
print(' {0:11} | {1:12}'.format('Nr of baskets', 'Nr of products'))
print('--------------------------------')
for bot, top in zip(bins, bins[1:] + [100000000000]):
    print('  {0:4} - {1:3}   | {2:12}'.format(
        bot, (top if top < 1000 else ''), np.sum((counts >= bot) & (counts < top))))

 Nr of baskets | Nr of products
--------------------------------
     1 -   2   |         2224
     2 -   4   |         2438
     4 -   8   |         2508
     8 -  16   |         2251
    16 -  32   |         2182
    32 -  64   |         1940
    64 - 128   |         1523
   128 - 512   |         1225
   512 -       |          179


## Apriori
Apriori以一组集合（购物篮）作为输入，返回在这些集合中频繁出现的子集（即同时出现在很多购物篮中的商品组合）。

这个算法是以自底向上的方式工作的：从最小的候选集合开始(只包含一个元素)，然后每次加入一个元素，并且不断增大。

定义最小支持度：minsupport = 80

支持度就是一组商品被一起购买的次数（即同时包含这些商品的购物篮数量）。Apriori的目标是寻找高支持度的项集（itemset）。从逻辑上讲，任何达到最小支持度的项集，其中每个物品单独的支持度也都至少达到该最小支持度，因此可以先筛选出满足条件的单个物品：
```python
valid = set(k for k,v in counts.items() if v >= minsupport)
```

In [4]:
from collections import namedtuple


def apriori(dataset, minsupport, maxsize):
    """
    freqsets, support = apriori(dataset, minsupport, maxsize)

    Find every itemset contained in at least ``minsupport`` baskets, growing
    candidates bottom-up one item at a time.

    Parameters
    ----------
    dataset : sequence of sequences
        input dataset; each inner sequence is one basket of hashable items
    minsupport : int
        Minimal support (number of baskets) for an itemset to be frequent
    maxsize : int
        Maximal size of frequent itemsets to return (inclusive)

    Returns
    -------
    freqsets : list of frozenset
        Frequent itemsets, ordered by increasing size
    support : dictionary
        This associates each itemset (represented as a frozenset) with a float
        (the number of baskets containing that itemset). It holds every
        singleton seen in ``dataset`` plus every frequent itemset found.
    """
    from collections import defaultdict

    baskets = defaultdict(list)
    pointers = defaultdict(list)

    for i, ds in enumerate(dataset):
        for ell in ds:
            pointers[ell].append(i)
            baskets[frozenset([ell])].append(i)

    # Convert basket-index lists to frozensets (which also removes repeats of
    # an item inside one basket) so support checks are set intersections.
    # Only items that are frequent on their own can extend an itemset, so the
    # rest are dropped; the threshold uses the deduplicated count, matching
    # how ``valid`` is computed below.
    pointers = {k: frozenset(v) for k, v in pointers.items()}
    pointers = {k: v for k, v in pointers.items() if len(v) >= minsupport}
    for k in baskets:
        baskets[k] = frozenset(baskets[k])

    # Valid are all elements whose support is >= minsupport
    valid = set()
    for el, c in baskets.items():
        if len(c) >= minsupport:
            valid.update(el)

    # Itemsets at the first iteration are all singletons of valid elements.
    itemsets = [frozenset([v]) for v in valid]
    freqsets = []
    for i in range(maxsize):
        print("At iteration {}, number of frequent baskets: {}".format(i, len(itemsets)))
        # Every itemset at this level has size i + 1 <= maxsize, so keep it.
        freqsets.extend(itemsets)
        if i == maxsize - 1:
            # Largest requested size reached: do not grow bigger itemsets.
            break
        newsets = []
        for it in itemsets:
            ccounts = baskets[it]

            for v, pv in pointers.items():
                if v not in it:
                    csup = (ccounts & pv)
                    if len(csup) >= minsupport:
                        new = frozenset(it | frozenset([v]))
                        # The same itemset is reachable from several parents;
                        # record it only the first time it is produced.
                        if new not in baskets:
                            newsets.append(new)
                            baskets[new] = csup
        itemsets = newsets
        if not itemsets:
            break
    support = {}
    for k in baskets:
        support[k] = float(len(baskets[k]))
    return freqsets, support


# Everything worth reporting about one rule ``antecedent -> consequent``.
# NOTE: the field name ``antecendent`` (sic) is part of the public interface;
# callers read it by that exact name, so it must not be "fixed" here.
AssociationRule = namedtuple('AssociationRule', ['antecendent', 'consequent', 'base', 'py_x', 'lift'])

def association_rules(dataset, freqsets, support, minlift):
    """
    for assoc_rule in association_rules(dataset, freqsets, support, minlift):
        ...

    Derive single-consequent rules from the output of ``apriori``.

    For every itemset S in ``freqsets`` with at least two items, and every
    item y in S, the rule (S - {y}) -> {y} is scored. It is yielded when its
    lift is strictly greater than ``minlift``. This is a generator.

    Parameters
    ----------
    dataset : sequence of sequences
        input dataset; only its length (the number of baskets) is used
    freqsets : sequence of frozensets
        itemsets as returned by ``apriori``
    support : dictionary
        maps an itemset (frozenset) to its support count; it is looked up for
        each itemset, each of its one-item consequents and the remaining
        antecedents
    minlift : float
        rules whose lift is not above this value are skipped

    Yields
    ------
    assoc_rule : AssociationRule
        with ``base`` = support(consequent) / number of baskets,
        ``py_x`` = support(itemset) / support(antecedent) and
        ``lift`` = py_x / base
    """
    n_baskets = float(len(dataset))
    # One-item sets cannot be split into a non-empty antecedent and consequent.
    candidates = list(filter(lambda s: len(s) > 1, freqsets))
    for itemset in candidates:
        for item in itemset:
            rhs = frozenset([item])
            lhs = itemset - rhs
            # Conditional frequency of rhs given lhs, and rhs's overall rate.
            confidence = support[itemset] / support[lhs]
            prior = support[rhs] / n_baskets
            lift = confidence / prior
            if lift > minlift:
                yield AssociationRule(lhs, rhs, prior, confidence, lift)

## 关联规则挖掘
『如果X则Y』，如果用户购买了X，那么他们还会购买Y。

提升度（lift）：$\mathrm{lift}(X\rightarrow Y)=\frac{P(Y|X)}{P(Y)}$，其中 $P(Y|X)$ 是包含X的购物篮中同时包含Y的比例，$P(Y)$ 是所有购物篮中包含Y的比例。

提升度可以避免推荐热销商品：对于一个热销商品Y，P(Y)和P(Y|X)都会很大，二者的比值接近1。因此，如果提升度接近1，说明购买X对购买Y几乎没有影响，这条规则就会被认为是无关的。在实践中，我们希望这个值至少是10，甚至是100。

In [7]:
from gzip import GzipFile

# Load dataset
# Load dataset; the ``with`` statement closes the gzip file after reading
# instead of leaving the handle open.
with GzipFile('data/retail.dat.gz') as retail_file:
    dataset = [[int(tok) for tok in line.strip().split()]
               for line in retail_file]

# Mine itemsets present in at least 80 baskets (maxsize=16), then keep only
# rules whose lift is above 30.
freqsets, support = apriori(dataset, 80, maxsize=16)
rules = list(association_rules(dataset, freqsets, support, minlift=30.0))

# Highest-lift rules first.
rules.sort(key=(lambda ar: -ar.lift))
for ar in rules:
    print('{} -> {} (lift = {:.4})' .format(set(ar.antecendent), set(ar.consequent), ar.lift))

At iteration 0, number of frequent baskets: 2370
At iteration 1, number of frequent baskets: 3797
At iteration 2, number of frequent baskets: 2131
At iteration 3, number of frequent baskets: 483
At iteration 4, number of frequent baskets: 47
At iteration 5, number of frequent baskets: 1
set([3537, 3402]) -> set([3535]) (lift = 343.3)
set([696]) -> set([699]) (lift = 338.3)
set([699]) -> set([696]) (lift = 338.3)
set([3537, 3535]) -> set([3402]) (lift = 325.7)
set([1818, 795, 3311]) -> set([1819]) (lift = 318.1)
set([1219]) -> set([4486]) (lift = 313.1)
set([4486]) -> set([1219]) (lift = 313.1)
set([3402, 39]) -> set([3535]) (lift = 312.3)
set([3535, 39]) -> set([3402]) (lift = 306.0)
set([3535]) -> set([3402]) (lift = 305.2)
set([3402]) -> set([3535]) (lift = 305.2)
set([795, 1818, 1819]) -> set([3311]) (lift = 302.7)
set([795, 1819, 3311]) -> set([1818]) (lift = 302.0)
set([3402, 3535]) -> set([3537]) (lift = 298.3)
set([795, 3311]) -> set([1819]) (lift = 296.3)
set([1818, 1819, 3311]