# Multi-Armed Bandit Problem

#### Explore vs exploit.

Consider two slot machines, A and B.

Each has some unknown probability $p$ of winning whenever we play it.

At each iteration, we can choose to play either A or B, with the goal of maximising our long-term success rate.

In [172]:
import random
import pandas

In [146]:
class Bandit:
    def __init__(self,ratio):
        self.ratio=ratio
        
    def play(self):
        'returns 1 for a win, 0 for a loss'
        return int(random.random()<self.ratio)
        

In [147]:
a=Bandit(random.random())
b=Bandit(random.random())

So we have two bandits. We need to formalize the concept of a *strategy*.

A strategy is just a function that is given the current game history (a list of `(choice, outcome)` tuples, one per turn played so far) and returns one of 'a' or 'b'.

In [4]:
def always_a(history):
    'Simplest possible strategy'
    return 'a'


def random_lever(history):
    return random.choice(['a','b']) 

    

Now define a couple of functions for evaluating the outcome of applying a particular strategy.

In [175]:

def apply_strategy(a,b,strategy,turns=1000):
    history=[]
    for i in range(turns):
        choice = strategy(history)
        assert choice in ('a','b')
        bandit = a if choice=='a' else b
        outcome=bandit.play()
        history.append((choice,outcome))
    return sum(outcome for choice,outcome in history)

def test_strategies(strategies,turns=1000):

    a=Bandit(random.random())
    b=Bandit(random.random())
    return {
        strategy.__name__:apply_strategy(a,b,strategy,turns=turns)
        for strategy in strategies
    }


In [176]:
test_strategies([
    random_lever
])

{'random_lever': 192}

Some more strategies:

In [177]:
def win_stay_lose_shift(history):
    if not history:
        return 'a'
    else:
        # what happened on the last turn?
        choice,outcome=history[-1]
        if outcome: # last choice worked, stick with it
            return choice
        else: # last choice didn't work, switch straight away
            return 'b' if choice=='a' else 'a'
        
        
# and the opposite...
def win_shift_lose_stay(history):
    if not history:
        return 'a'
    else:
        choice,outcome=history[-1]
        if outcome:
            return 'b' if choice=='a' else 'a'
        else:
            return choice


In [178]:
test_strategies([random_lever,win_stay_lose_shift,win_shift_lose_stay])

{'random_lever': 821, 'win_stay_lose_shift': 901, 'win_shift_lose_stay': 813}

#### Now test multiple strategies on many bandit pairs

In [179]:
strategies=[random_lever,win_stay_lose_shift,win_shift_lose_stay,always_a]
df=pandas.DataFrame([
    test_strategies(strategies,turns=1000)
    for i in range(10)
])

df

   always_a  random_lever  win_shift_lose_stay  win_stay_lose_shift
0       651           640                  614                  624
1       167           186                  190                  179
2       878           878                  901                  878
3       719           796                  798                  800
4       951           607                  491                  898
5       129           483                  244                  745
6       126           440                  168                  676
7       883           740                  732                  811
8       313           277                  231                  291
9       422           334                  317                  374


In [180]:
df.sum().sort_values()

win_shift_lose_stay    4686
always_a               5239
random_lever           5381
win_stay_lose_shift    6276
dtype: int64
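
Raw totals are easier to read as win rates per play. The sketch below divides the totals printed above by the number of plays (10 bandit pairs × 1000 turns). The names `totals`, `plays` and `rates` are introduced here for illustration only. As a rough baseline, a strategy that ignores outcomes (such as `always_a` or `random_lever`) should win about half the time on average, because each bandit's $p$ is drawn uniformly from $[0, 1)$.

```python
# Totals copied from the df.sum() output above.
totals = {
    "win_shift_lose_stay": 4686,
    "always_a": 5239,
    "random_lever": 5381,
    "win_stay_lose_shift": 6276,
}

# 10 bandit pairs, 1000 turns each.
plays = 10 * 1000

# Fraction of plays won by each strategy.
rates = {name: total / plays for name, total in totals.items()}

for name, rate in sorted(rates.items(), key=lambda item: item[1]):
    print(f"{name:<22}{rate:.4f}")
```

With only 10 bandit pairs the rates are noisy, so small differences between strategies should not be read too closely.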

### Explore/Exploit

Now try out some kind of explore vs exploit strategy.

First I just use hardcoded thresholds. On every turn, estimate the win rate of each bandit from the outcomes observed so far, defaulting to 0.5 for a bandit we haven't played yet. Then, on a weighted coin toss, decide whether to play what we currently believe to be the better bandit (exploit) or the other one in order to learn more (explore). The `threshold` is the probability of exploiting.


In [181]:
def explore_exploit(threshold):
    assert 0<=threshold<=1
    def inner(history):
        aa = [outcome for choice,outcome in history if choice=='a']
        bb = [outcome for choice,outcome in history if choice=='b']

        if aa:
            a_estimate=sum(aa)/len(aa)
        else:
            a_estimate=0.5
        if bb:
            b_estimate=sum(bb)/len(bb)
        else:
            b_estimate=0.5

        best = 'a' if a_estimate>b_estimate else 'b'
        worst= 'a' if best=='b' else 'b'

        exploit = random.random() < threshold

        if exploit:
            return best
        else:
            return worst
        
    inner.__name__=f"explore_exploit_{threshold}"
    return inner

In [182]:
test_strategies([explore_exploit(0.9)])

{'explore_exploit_0.9': 678}

In [183]:
strategies=[
    random_lever,
    win_stay_lose_shift,
    win_shift_lose_stay,
    always_a,
    explore_exploit(0.4),
    explore_exploit(0.5),
    explore_exploit(0.6),
    explore_exploit(0.7),
    explore_exploit(0.8),
    explore_exploit(0.9)
]
df=pandas.DataFrame([
    test_strategies(strategies,turns=1000)
    for i in range(50)
])


In [184]:
df.sum().sort_values()

win_shift_lose_stay    19242
explore_exploit_0.4    22017
always_a               22416
explore_exploit_0.5    23905
random_lever           23973
explore_exploit_0.6    25664
explore_exploit_0.7    27200
win_stay_lose_shift    28499
explore_exploit_0.8    28755
explore_exploit_0.9    30302
dtype: int64
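
A rough sanity check on these totals is possible with a little arithmetic. Assume, as an idealisation, that the strategy always identifies the better bandit correctly. With both win probabilities drawn uniformly from $[0, 1)$, the expected value of the larger one is $2/3$ and of the smaller one $1/3$, so a threshold $t$ gives an expected win rate of $t \cdot 2/3 + (1 - t) \cdot 1/3$ per turn. The function name `expected_total` and the constants are introduced here for illustration only.

```python
# Expected win probability of the better and worse of two Uniform(0, 1) bandits.
E_BEST = 2 / 3
E_WORST = 1 / 3

def expected_total(threshold, pairs=50, turns=1000):
    """Idealised expected wins, assuming the better bandit is always identified."""
    per_turn = threshold * E_BEST + (1 - threshold) * E_WORST
    return pairs * turns * per_turn

for t in (0.4, 0.5, 0.6, 0.7, 0.8, 0.9):
    print(f"threshold {t:.1f}: about {expected_total(t):.0f} expected wins")
```

Under this assumption a threshold of 0.5 is equivalent to random play (25,000 expected wins over 50,000 plays), and anything below 0.5 deliberately favours the worse bandit, which is in line with `explore_exploit_0.4` scoring below `random_lever` in the run above.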

Finally, a more adaptive version of explore/exploit that starts out with a tendency to explore, but becomes more conservative as time progresses.

In [185]:
def explore_exploit_adaptive(history):
    
    # I want threshold to trend towards 1 as the history grows.

    l=len(history)
    threshold = l/(10+l) 
    assert 0<=threshold<=1
    
    aa = [outcome for choice,outcome in history if choice=='a']
    bb = [outcome for choice,outcome in history if choice=='b']

    if aa:
        a_estimate=sum(aa)/len(aa)
    else:
        a_estimate=0.5
    if bb:
        b_estimate=sum(bb)/len(bb)
    else:
        b_estimate=0.5

    best = 'a' if a_estimate>b_estimate else 'b'
    worst= 'a' if best=='b' else 'b'

    exploit = random.random() < threshold

    if exploit:
        return best
    else:
        return worst
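
To see how quickly the threshold in `explore_exploit_adaptive` approaches 1, the schedule `l/(10+l)` can be tabulated on its own. This is a small sketch: the helper name `adaptive_threshold` and its `scale` parameter are hypothetical, with `scale=10` matching the constant used above.

```python
def adaptive_threshold(turns_played, scale=10):
    """Exploit probability after `turns_played` turns: l / (scale + l)."""
    return turns_played / (scale + turns_played)

# The threshold starts at 0, reaches 0.5 after `scale` turns, then creeps towards 1.
for n in (0, 1, 10, 90, 990):
    print(f"after {n:>3} turns: threshold = {adaptive_threshold(n):.3f}")
```

Note that the threshold is below 0.5 for the first 10 turns, so during that period the strategy picks the bandit it currently believes to be worse more often than not.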
        

In [186]:

strategies=[
    random_lever,
    win_stay_lose_shift,
    win_shift_lose_stay,
    always_a,
    explore_exploit(0.4),
    explore_exploit(0.5),
    explore_exploit(0.6),
    explore_exploit(0.7),
    explore_exploit(0.8),
    explore_exploit(0.9),
    explore_exploit_adaptive
]
df=pandas.DataFrame([
    test_strategies(strategies,turns=1000)
    for i in range(50)
])
df.sum().sort_values()


win_shift_lose_stay         19123
explore_exploit_0.4         21768
always_a                    22659
explore_exploit_0.5         23619
random_lever                23828
explore_exploit_0.6         25578
explore_exploit_0.7         27390
win_stay_lose_shift         28476
explore_exploit_0.8         29032
explore_exploit_0.9         30612
explore_exploit_adaptive    31628
dtype: int64

So `explore_exploit_adaptive` is the best strategy we've come up with so far, at least on this run of 50 bandit pairs.
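
The totals above come from a single unseeded run, so one way to check how stable the ranking is would be to repeat the comparison with several fixed seeds. The sketch below is a self-contained approximation of the notebook's logic rather than a drop-in replacement: `play_game`, `compare` and `THRESHOLDS` are names introduced here, and it keeps running win counts instead of rescanning the full history on every turn.

```python
import random

def play_game(p_a, p_b, threshold_fn, turns, rng):
    """Total wins in one game; threshold_fn(t) is the exploit probability on turn t."""
    wins = {"a": 0, "b": 0}
    plays = {"a": 0, "b": 0}
    total = 0
    for t in range(turns):
        # Running win-rate estimates, defaulting to 0.5 for an unplayed bandit.
        est = {
            arm: (wins[arm] / plays[arm] if plays[arm] else 0.5)
            for arm in ("a", "b")
        }
        best = "a" if est["a"] > est["b"] else "b"
        worst = "b" if best == "a" else "a"
        choice = best if rng.random() < threshold_fn(t) else worst
        outcome = int(rng.random() < (p_a if choice == "a" else p_b))
        wins[choice] += outcome
        plays[choice] += 1
        total += outcome
    return total

# Exploit probability as a function of the number of turns already played.
THRESHOLDS = {
    "fixed_0.9": lambda t: 0.9,
    "adaptive": lambda t: t / (10 + t),
}

def compare(seed, pairs=50, turns=1000):
    """Total wins per strategy over `pairs` random bandit pairs, for one seed."""
    rng = random.Random(seed)
    totals = {name: 0 for name in THRESHOLDS}
    for _ in range(pairs):
        p_a, p_b = rng.random(), rng.random()
        # Both strategies face the same pair of bandits.
        for name, fn in THRESHOLDS.items():
            totals[name] += play_game(p_a, p_b, fn, turns, rng)
    return totals

for seed in range(3):
    print(seed, compare(seed))
```

If the adaptive variant comes out ahead for most seeds, that supports the conclusion; if the ordering flips between seeds, the gap seen in a single run is within the noise.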