# Q-learning tryout on Santa's uncertain bags

## States 

A state is characterized by a matrix of size `(N_BAGS, N_TYPES)`. For example, `s[0,:]=[1,0,1,0,0,0,0,0,0]`. The initial state is when the matrix is null or a customly defined. Terminal states are defined by state's score. 

How many state there are? There are at most `N_BAGS * 10^N_TYPES` states.


## Actions

Action is to add a toy following the list of available toys.


## Rewards

Action reward can be defined by the score of the bag where a toy has been added.


## Q-learning: Off-Policy Temporal Difference Control

In this algorithm we estimate action-value function $Q(s,a)$ as :
$$
Q(S_t,A_t) \leftarrow Q(S_t,A_t) + \alpha \left[ R_{t+1} + \gamma \max_{a} Q(S_{t+1}, a) - Q(S_t,A_t) \right], \, Q(\cal{S}^{+},a)=0
$$

**Algorithm**
<br>
<div style="background-color: #aaaaaa; padding: 10px; width: 75%; border: solid black; border-radius: 5px;">

    Initialize $Q(s, a)$, for all $s \in \cal{S}$, $a \in \cal{A}(s)$, arbitrarily, and $Q(\text{terminal-state}, \cdot) = 0$<br>
    Repeat (for each episode):<br>
    &emsp;Initialize $S$<br>
    &emsp;Choose $A$ from $S$ using policy derived from $Q$ (e.g., $\epsilon$-greedy)<br>
    &emsp;Repeat (for each step of episode):<br>
    &emsp;&emsp;Take action $A$, observe $R$, $S'$<br>
    &emsp;&emsp;$Q(S,A) \leftarrow Q(S,A) + \alpha \left[ R + \gamma \max_{a}Q(S', a) - Q(S,A) \right]$<br>
    &emsp;&emsp;$S \leftarrow S'; \, A \leftarrow A';$<br>
    &emsp;until $S$ is terminal
</div>

In [1]:
# https://ipython.org/ipython-doc/3/config/extensions/autoreload.html
%load_ext autoreload
%autoreload 2

In [2]:
from time import time
from copy import deepcopy
import numpy as np
np.random.seed(2016)

import logging
logging.getLogger().setLevel(logging.DEBUG)

In [3]:
import sys
sys.path.append('../common')
from utils import weight3 as weight_fn, weight_by_index
from utils import bag_weight, score
from utils import MAX_WEIGHT, AVAILABLE_GIFTS, GIFT_TYPES, N_TYPES, N_BAGS

In [214]:
REJECTED_BAGS_THRESHOLD = 0.1
TOTAL_MEAN_REJECTED_BAGS = 10
NEGATIVE_REWARD = -1000
POSITIVE_REWARD = 1000

def step_reward(rejected):    
    return 1.0 if rejected < REJECTED_BAGS_THRESHOLD else -rejected*10

In [5]:
initial_state = np.zeros((N_BAGS, N_TYPES), dtype=np.uint8)
alpha = 0.72
goal_weight = MAX_WEIGHT * N_BAGS * alpha

print goal_weight

36000.0


In [6]:
score(initial_state)

0.0

In [157]:
example_action = (500, 0)

In [158]:
def take_action(state, action):
    new_state = state.copy()
    new_state[action[0], action[1]] += 1
    return new_state

def is_available(state, available_gifts, gift_types=GIFT_TYPES):
    sum_gifts = np.sum(np.array(state), axis=0)
    for v, gift_type in zip(sum_gifts, gift_types):
        if available_gifts[gift_type] - v < 0:
            return False
    return True

def update_available_gifts(available_gifts, state, gift_types=GIFT_TYPES):
    sum_gifts = np.sum(state, axis=0)
    for v, gift_type in zip(sum_gifts, gift_types):
        assert available_gifts[gift_type] - v >= 0, "Found state is not available : {}, {}".format(state, available_gifts)
        available_gifts[gift_type] = available_gifts[gift_type] - v

In [159]:
new_state = take_action(initial_state, example_action)
new_score, rejected = score(new_state, return_rejected=True)
print new_state, new_score, is_available(new_state, AVAILABLE_GIFTS), step_reward(rejected)

[[0 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 0]
 ..., 
 [0 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 0]] 2.00230156967 True 0.0


In [173]:
from collections import defaultdict
import heapq
import hashlib

def state_to_str(state):
    return hashlib.md5(state).hexdigest()

def sum_gifts(state):
    return np.sum(np.array(state), axis=0)

In [174]:
print state_to_str(initial_state), sum_gifts(initial_state), sum_gifts(new_state)

5420afa22f6423a9f59e669540656bb4 [0 0 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0 0]


In [117]:
def find_value(action, actions_values, return_index=False):
    for i, (v, a) in enumerate(actions_values):
        if action == a:
            if return_index:
                return v, i
            return v
    raise Exception("No action={} in actions_values={}".format(action, actions_values))

In [177]:
def get_policy_action(state, action_value_function, epsilon=0.1):
    state_key = state_to_str(state)
    u = np.random.rand()
    if state_key in action_value_function and u > epsilon:
        actions_values = action_value_function[state_key]
        max_action_value = actions_values[0]
        return max_action_value[1]
    else:
        # Arbitrary initialization
        bag_id = np.random.randint(N_BAGS)
        toy_id = np.random.randint(N_TYPES)
        action = (bag_id, toy_id)
        # We store values as POSITIVE_REWARD - value to use heapq property that heap[0] is the smallest element
        # In our case this element corresponds to the largest value
        value = POSITIVE_REWARD - np.random.rand()
        heapq.heappush(action_value_function[state_key], [value, action])         
        return action

In [178]:
test_av_func = defaultdict(list)
state = initial_state
for i in range(2):
    a = get_policy_action(state, test_av_func, epsilon=0.5)
    state = take_action(state, a)
    print "Loop : ", sum_gifts(state), a
print "test_av_func: ", test_av_func

Loop :  [0 0 0 0 0 0 1 0 0] (319, 6)
Loop :  [0 0 0 0 0 0 2 0 0] (376, 6)
test_av_func:  defaultdict(<type 'list'>, {'19e80ad718223eca66e42c1da0f38ec0': [[999.7093964045671, (376, 6)]], '5420afa22f6423a9f59e669540656bb4': [[999.3494118180238, (319, 6)]]})


In [202]:
def q_learning(goal_weight, 
               available_gifts,
               initial_state=None,
               n_episodes=10, alpha=0.75, gamma=0.7, epsilon=0.001, action_value_function=None):
    
    logging.info("--- Q-learning : goal={}, n_episodes={}".format(goal_weight, n_episodes))
    if action_value_function is None:
        logging.info("-- Reset action_value_function")
        action_value_function = defaultdict(list)
    
    best_state = initial_state
    best_score = 0
    
    for i in range(n_episodes):

        logging.info("-- Episode : %i" % i)
        
        episode_length = N_BAGS * N_TYPES * 10
        state = np.zeros((N_BAGS, N_TYPES), dtype=np.uint8) if initial_state is None else initial_state        
        action = get_policy_action(state, action_value_function, epsilon=epsilon)
        state_score, rejected = score(state, return_rejected=True)
        is_terminal = state_score > goal_weight and is_available(state, available_gifts)
        
        logging.debug("Initial state score/action: {}, {}".format(state_score, action))
        
        while not is_terminal:
            
            episode_length -= 1 
            if episode_length < 0:
                logging.warn('Episode length is reached, but state score is still : %f / %f' % (state_score, goal_weight))
                break
            
            current_reward = 0 
            new_state = take_action(state, action)
            # Slower score computation
            #state_score, rejected = score(new_state, return_rejected=True)
            
            # Faster score computation
            new_bag = new_state[action[0]]
            bag_score, bag_rejected = score((new_bag,), return_rejected=True)
            rejected += bag_rejected
            state_score += bag_score
            
            
            if not is_available(new_state, available_gifts) or rejected > TOTAL_MEAN_REJECTED_BAGS:                
                current_reward = NEGATIVE_REWARD
                is_terminal = True
                logging.info("---> Episode finished with NEGATIVE reward")
            elif state_score >= goal_weight:
                current_reward = POSITIVE_REWARD
                is_terminal = True
                logging.info("---> Episode finished with POSITIVE reward")
                
                if best_score < state_score:
                    best_score = state_score
                    best_state = new_state
                
            elif state_score < goal_weight:
                current_reward = step_reward(rejected)
            else:
                raise Exception("Unclassified state: {}, score={}".format(new_state, state_score))

            logging.debug("New state score, reward, sum_gifts, action : {}, {}, {}, {}".format(state_score, current_reward, sum_gifts(new_state), action))                
                
            # Update Q(s,a)
            state_key = state_to_str(state)
            new_state_key = state_to_str(new_state)
            
            actions_values = action_value_function[state_key]
            action_value, action_index = find_value(action, actions_values, return_index=True)
            v = POSITIVE_REWARD - action_value
            # actions_values is a heap with first element being the smallest element
            # We store values in actions_values as POSITIVE_REWARD - Q(s,a)
            nv = POSITIVE_REWARD - actions_values[0][0]
            t = alpha * (current_reward + gamma * nv - v)
            
            action_value_function[state_key][action_index] = [POSITIVE_REWARD - (v + t), action]
            
            state = new_state
            action = get_policy_action(state, action_value_function, epsilon=epsilon)                        
                
    return action_value_function, best_score, best_state

In [215]:
final_action_value_function = defaultdict(list)
final_state = np.zeros((N_BAGS, N_TYPES), dtype=np.uint8)

In [220]:
logging.getLogger().setLevel(logging.INFO)
final_action_value_function, best_score, best_state = q_learning(goal_weight, 
                                                                 AVAILABLE_GIFTS,
                                                                 initial_state=final_state,
                                                                 n_episodes=1000, 
                                                                 alpha=0.75, 
                                                                 gamma=0.7, 
                                                                 epsilon=0.15, 
                                                                 action_value_function=final_action_value_function)

INFO:root:--- Q-learning : goal=36000.0, n_episodes=1000
INFO:root:-- Episode : 0
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 1
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 2
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 3
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 4
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 5
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 6
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 7
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 8
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 9
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 10
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 11
INFO:root:Episode finished with NEGATIVE reward
INFO:root:-- Episode : 12
INFO:root:Episode finished with NEGATIV

In [221]:
best_score, best_state

(0, array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ..., 
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8))

In [222]:
len(final_action_value_function)

701945

In [223]:
for k in final_action_value_function:
    actions_values = final_action_value_function[k]
    if len(actions_values) > 1:
        print final_action_value_function[k]

[[996.6873533852284, (831, 8)], [997.4318410780477, (206, 2)]]
[[996.6671534562975, (912, 1)], [998.6120117171893, (177, 4)], [998.0016038355111, (736, 6)], [997.5885779071169, (836, 6)], [997.347688639074, (922, 0)], [997.2625050269717, (926, 5)], [997.2811488829848, (413, 1)], [997.4527812735226, (337, 7)]]
[[996.6666666702142, (493, 5)], [998.3530442725099, (171, 6)], [997.5973349370843, (598, 5)], [997.6113232311084, (477, 6)], [997.4637027695585, (663, 5)], [997.3883207672454, (590, 5)], [997.4972471397264, (204, 5)], [997.469737780565, (706, 4)], [997.2685626527118, (722, 6)], [997.4146955571891, (805, 3)], [997.2563747713339, (322, 1)], [997.2693021727224, (413, 0)], [997.4525561081266, (735, 0)], [997.4411942548148, (93, 8)], [997.4971398631639, (782, 8)]]
[[996.9982855719969, (843, 0)], [997.7534895606111, (912, 3)]]
[[996.6666666666669, (731, 2)], [997.6047892983097, (583, 3)], [997.4887605569862, (37, 8)], [997.4905389182451, (368, 0)], [997.3891501710859, (291, 6)], [997.27