In [None]:
import networkx as nx
import numpy as np
import pandas as pd
import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class RoutingEnv(object):
    """Episodic routing environment on a graph.

    The agent starts at ``init_state`` and moves along edges of ``g`` until it
    reaches ``target_state`` or a node without neighbours (a dead end).

    Per-step rewards (see ``step``):
        +100 on reaching the target (checked first, since the target may also be a dead end),
        -100 on entering any other dead end,
        -5   for re-using an edge already taken earlier in the current episode,
        -1   otherwise.

    Args:
        g: graph exposing ``nodes``, ``edges`` and ``neighbors``; ``render`` additionally
            reads a custom ``g.my_pos`` attribute (node -> (x, y)).
        init_state: node each episode starts from.
        target_state: goal node.
        keep_hist: if True, record the states visited / edges taken in the current
            episode (``state_hist`` / ``act_hist``); ``reset`` archives them into
            ``episode_hist``.
    """

    def __init__(self, g, init_state, target_state, keep_hist=False):
        self.g = g
        self.all_states = list(self.g.nodes)
        self.all_actions = list(self.g.edges)
        self.s = init_state
        self.t = target_state
        self.keep_hist = keep_hist
        self.start()

    def start(self):
        """Place the agent on the start node and clear both current and archived history."""
        self.state = self.s
        self.act_hist = []
        self.state_hist = []
        # Edges taken in the current episode. Tracked regardless of keep_hist because the
        # repeated-edge penalty in step() depends on it.
        self._taken_actions = set()
        # Archive of finished episodes, appended to by reset() when keep_hist is True.
        self.episode_hist = dict(state=[], act=[])
        self.hist = self.episode_hist  # legacy name, kept for backward compatibility

    def reset(self):
        """End the current episode: archive its history (if keep_hist) and return to the start node."""
        if self.keep_hist:
            self.episode_hist['state'].append(self.state_hist)
            self.episode_hist['act'].append(self.act_hist)

        self.state = self.s
        self.act_hist = []
        self.state_hist = []
        self._taken_actions = set()

    def step(self, a):
        '''
        Move along edge ``a`` from the current node.

        Args:
            a: an edge (pair of nodes) in the graph g whose first node is the current
               state. Any 2-element sequence (e.g. a list) is accepted; it is
               normalised to a tuple so repeated-edge detection is type-independent.

        Returns:
            (reward, episode_over, new_state)

        Raises:
            AssertionError: if ``a`` does not start at the current state or is not an edge of g.
        '''
        a = tuple(a)
        assert a[0] == self.state
        assert a in self.g.edges

        repeated_action = a in self._taken_actions
        self._taken_actions.add(a)

        self.state = a[1]
        if self.keep_hist:
            self.state_hist.append(self.state)
            self.act_hist.append(a)

        is_deadend = not any(True for _ in self.g.neighbors(self.state))
        episode_over = self.state == self.t or is_deadend

        # test this first since the correct target can also be a deadend
        if self.state == self.t:
            reward = 100
        elif is_deadend:
            reward = -100
        elif repeated_action:
            reward = -5
        else:
            reward = -1

        return reward, episode_over, self.state

    def render(self, ax=None):
        """Draw the graph: start node yellow, target cyan, remaining nodes in the default
        colour, and the edges in ``act_hist`` (only recorded when keep_hist is True)
        overlaid as thick red lines."""
        pos = self.g.my_pos
        non_terminal = [n for n in self.g.nodes if n not in (self.s, self.t)]
        nx.draw_networkx(
            self.g, pos=pos, nodelist=[self.s], edgelist=[], node_color='y', ax=ax
        )
        nx.draw_networkx(
            self.g, pos=pos, nodelist=[self.t], edgelist=[], node_color='c', ax=ax
        )
        nx.draw_networkx(self.g, pos=pos, nodelist=non_terminal, ax=ax)
        nx.draw_networkx_edges(
            self.g, pos=pos, edgelist=self.act_hist, width=8, edge_color='r', ax=ax
        )

In [None]:
# Directed "super graph": a 3x3 grid of nodes '0'..'8' (0 at bottom-left, 8 at top-right)
# whose links run in both directions, plus, for every grid node k except the centre '4',
# a source node 'o<k>' (single edge into k) and a sink node 'i<k>' (single edge out of k).
edges = [
    (6, 7), (7, 8), (6, 3), (7, 4), (8, 5),
    (3, 4), (4, 5), (3, 0), (4, 1), (5, 2),
    (0, 1), (1, 2),
]
rev_edges = [(j, i) for i, j in edges]
edges = edges + rev_edges
# Keep the original k -> 'i<k>', 'o<k>' -> k interleaving: edge order fixes the graph's
# node/edge iteration order, and hence the row/column order of any table built from it.
border_nodes = [0, 1, 2, 3, 5, 6, 7, 8]
edges = edges + [e for k in border_nodes for e in ((k, f'i{k}'), (f'o{k}', k))]
edges = [(str(i), str(j)) for i, j in edges]

# cartesian frame, i.e. (x, y) pairs, with origin at bottom left
pos = {
    '0': (0, 0), '1': (1, 0), '2': (2, 0), '3': (0, 1),
    '4': (1, 1), '5': (2, 1), '6': (0, 2), '7': (1, 2),
    '8': (2, 2), 'i0': (-.4, -.8), 'o0': (-.8, -.4),
    'i1': (.8, -1), 'o1': (1.2, -1),
    'i2': (2.6, -.4), 'o2': (2.4, -.6),
    'i3': (-1, .8), 'o3': (-1, 1.2),
    'i5': (3, .8), 'o5': (3, 1.2),
    'i6': (-.4, 2.6), 'o6': (-.6, 2.4),
    'i7': (.8, 3), 'o7': (1.2, 3),
    'i8': (2.4, 2.6), 'o8': (2.6, 2.4),
}
super_graph = nx.DiGraph(edges)
assert set(super_graph.nodes) == set(pos), 'every node needs a drawing position'
super_graph.my_pos = pos  # custom attribute read by RoutingEnv.render

# Smoke test: take two hops from 'o0' and draw them. keep_hist=True is required for
# render() to have a path (act_hist) to highlight.
env = RoutingEnv(super_graph, 'o0', 'i7', keep_hist=True)
env.step(('o0', '0'))
env.step(('0', '3'))
env.render()

In [None]:

# Scratch probe: how does pandas treat tuple-valued column labels, and what does a
# row-wise idxmax return for them? (Presumably why the Q-table below uses 'u_v' string labels.)
aa = pd.DataFrame(
    data=np.arange(12).reshape(6, 2),
    index=[0, 1, 2, 3, 'o', 'i'],
    columns=[(1, 2), ('o', 'i')],
)
aa.loc[0].idxmax()

In [None]:
# Tabular Q-learning hyper-parameters
𝛼 = .01  # learning rate
𝛾 = .9   # discount factor
𝜖 = .1   # prob explore (take a uniformly random outgoing edge)
N_EPISODES = 100
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so training is reproducible

env = RoutingEnv(super_graph, 'o0', 'i7')

# using a tuple/edge as columns confuses pandas, so columns are 'u_v' strings;
# keep an exact lookup back to the edge tuple instead of re-parsing the string.
all_actions_str = [f'{i}_{j}' for i, j in env.all_actions]
str_to_action = dict(zip(all_actions_str, env.all_actions))
# Float fill: cells receive fractional updates below; an int-typed table would need an
# implicit (deprecated) dtype upcast on the first assignment.
q = pd.DataFrame(0.0, index=env.all_states, columns=all_actions_str)


def out_edge_labels(g, state):
    """Q-table column labels ('u_v') of the edges leaving `state` in graph `g`."""
    return [f'{u}_{v}' for u, v in g.edges(state)]


# Q-rows are always sliced to the current state's outgoing edges: columns of edges that
# don't leave a state are never updated and stay 0, which would otherwise beat the
# (mostly negative) learned values in both idxmax and max.
env.start()
reward_hist = []
for ep in tqdm.tqdm(range(N_EPISODES)):
    episode_over = False
    ep_reward = 0
    while not episode_over:
        avail_actions_str = out_edge_labels(env.g, env.state)
        if rng.random() < 𝜖:  # explore
            act_str = avail_actions_str[rng.integers(len(avail_actions_str))]
        else:  # exploit: row slice is a Series, so idxmax returns the column label
            act_str = q.loc[env.state, avail_actions_str].idxmax()
        act = str_to_action[act_str]

        old_state = env.state
        reward, episode_over, _ = env.step(act)

        # Bootstrap value of the successor state: 0 when the episode has ended
        # (target / dead end) or the node has no outgoing edges.
        next_actions_str = out_edge_labels(env.g, env.state)
        if episode_over or not next_actions_str:
            exp_q_cur_state = 0.0
        else:
            exp_q_cur_state = q.loc[env.state, next_actions_str].max()

        q.loc[old_state, act_str] = (1 - 𝛼) * q.loc[old_state, act_str] + 𝛼 * (reward + 𝛾 * exp_q_cur_state)
        assert np.isfinite(q.loc[old_state, act_str])

        ep_reward += reward
    reward_hist.append(ep_reward)
    env.reset()

fig, ax = plt.subplots()
ax.plot(reward_hist)
ax.set(title='Q-learning: total reward per episode', xlabel='episode', ylabel='total reward');

In [None]:
# Greedy rollout of the learned Q-table (no exploration); history is recorded so the
# taken path can be rendered.
MAX_ROLLOUT_STEPS = 15  # safety cap in case the greedy policy cycles

env.keep_hist = True
env.start()
episode_over = False
ep_reward = 0
count = 0
while not episode_over and count < MAX_ROLLOUT_STEPS:
    count += 1
    avail_actions_str = [f'{i}_{j}' for i, j in env.g.edges(env.state)]
    act_str = q.loc[env.state, avail_actions_str].idxmax()
    # tuple (not list) so the action compares equal to the graph's edge tuples
    act = tuple(act_str.split('_'))

    reward, episode_over, _ = env.step(act)
    ep_reward += reward

fig, ax = plt.subplots()
env.render(ax=ax)
ax.set_title(f'greedy rollout: reward {ep_reward}, reached target: {env.state == env.t}');