# 01 Shapley numba tutorial

In [None]:
import numba
import numpy as np

from shapley_numba import numba_game
from shapley_numba.shapley import shapley

To define a game with `shapley_numba`, create a class that will be [_jit-compiled_](https://numba.readthedocs.io/en/stable/user/jitclass.html) by *numba*. The class must have a method called `value` that takes a parameter `subset`. `subset` is a numpy integer array of zeros and ones with one entry per player, where a one marks a player who belongs to the coalition.
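
To make the `subset` encoding concrete, here is a minimal sketch that builds such an array by hand. The helper `coalition_mask` is hypothetical and not part of `shapley_numba`; it only illustrates what a `value` method can expect to receive.

```python
import numpy as np


def coalition_mask(members, num_players):
    """Hypothetical helper: 0/1 integer array marking the given player indices."""
    mask = np.zeros(num_players, dtype=np.int64)
    mask[np.asarray(members, dtype=np.intp)] = 1
    return mask


# Coalition made of players 0 and 2 in a three-player game.
subset = coalition_mask([0, 2], num_players=3)
print(subset.tolist())    # [1, 0, 1]
print(int(subset.sum()))  # 2 (size of the coalition)
```

Because the array contains only zeros and ones, summing a slice of it counts how many players from that slice are in the coalition.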

The game can have an internal state, which is defined with `__init__`. The internal variables need to be declared using a numba spec.

To use `shapley_numba` with your class, add the `@numba_game` decorator. The decorator tries to apply `numba.experimental.jitclass`, which _jit-compiles_ your class with numba. It takes a spec parameter that defines the types of the internal variables of your class.

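As an example, consider the glove game: the first `num_left_gloves` players each hold a left glove, the remaining players each hold a right glove, and the value of a coalition is the number of matched pairs it can form.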

In [None]:
glove_spec = [('num_left_gloves', numba.int_)]


@numba_game(glove_spec)
class GloveGame:
    def __init__(self, num_left_gloves):
        self.num_left_gloves = num_left_gloves

    def value(self, subset):
        left_gloves = np.sum(subset[: self.num_left_gloves])
        right_gloves = np.sum(subset[self.num_left_gloves :])
        return min(left_gloves, right_gloves)
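
To see what `value` returns for individual coalitions, the sketch below copies the same logic into a plain function. The name `glove_value` is an assumption made for illustration; it bypasses `@numba_game` entirely and is not part of `shapley_numba`.

```python
import numpy as np


def glove_value(subset, num_left_gloves):
    """Plain-Python stand-in for GloveGame.value (illustration only)."""
    left_gloves = np.sum(subset[:num_left_gloves])
    right_gloves = np.sum(subset[num_left_gloves:])
    return int(min(left_gloves, right_gloves))


# Three players, player 0 holds the only left glove.
for coalition in ([1, 0, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]):
    subset = np.array(coalition, dtype=np.int64)
    print(coalition, glove_value(subset, num_left_gloves=1))
```

A coalition is worth something only when it contains the left-glove holder and at least one right-glove holder, and a second right glove adds nothing.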

Define game parameters:

In [3]:
num_players = 3
num_left_gloves = 1

Now we can run the computation:

In [5]:
glove_game = GloveGame(num_left_gloves)
result = shapley(glove_game, num_players)
result

array([0.66666667, 0.16666667, 0.16666667])

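The result is, of course, the well-known $\frac{2}{3}, \frac{1}{6}, \frac{1}{6}$ split: the single left-glove holder receives $\frac{2}{3}$ and each right-glove holder receives $\frac{1}{6}$.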
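
These values can be checked independently with a brute-force sketch that averages each player's marginal contribution over all orderings of the players. This is the textbook permutation definition of the Shapley value, written in pure Python; it is not necessarily how `shapley_numba` computes it, and the function names are assumptions made for illustration.

```python
from fractions import Fraction
from itertools import permutations


def glove_value(subset, num_left_gloves):
    """Glove game value on a plain 0/1 list (illustration only)."""
    left = sum(subset[:num_left_gloves])
    right = sum(subset[num_left_gloves:])
    return min(left, right)


def shapley_by_permutations(value, num_players):
    """Average marginal contribution of each player over all join orders."""
    totals = [Fraction(0)] * num_players
    orderings = list(permutations(range(num_players)))
    for order in orderings:
        subset = [0] * num_players
        previous = value(subset)
        for player in order:
            subset[player] = 1  # the player joins the coalition
            current = value(subset)
            totals[player] += current - previous
            previous = current
    return [total / len(orderings) for total in totals]


result = shapley_by_permutations(lambda s: glove_value(s, 1), 3)
print([str(x) for x in result])  # ['2/3', '1/6', '1/6']
```

Enumerating all orderings takes factorial time in the number of players, so this check is practical only for very small games.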

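The `@numba_game` decorator lets you use the game even if numba compilation fails: you can still run the regular, Python-only computation, although it is slower. You can also request it explicitly with `use_numba=False`. In the timings below, the Python-only run is roughly five times slower than the numba run.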

In [11]:
%%timeit
shapley(glove_game, num_players, use_numba=False)

44.1 μs ± 3.8 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
%%timeit
shapley(glove_game, num_players)  # use_numba=True is the default

8.31 μs ± 599 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
