forked from ctlab/dqnroute
-
Notifications
You must be signed in to change notification settings - Fork 0
/
q_routing.py
executable file
·131 lines (110 loc) · 4.75 KB
/
q_routing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import networkx as nx
import random
from typing import List, Tuple, Dict
from ..base import *
from .link_state import *
from ...messages import *
from ...utils import dict_min
class SimpleQRouter(Router, RewardAgent):
    """
    A router which implements the Q-routing algorithm.

    It maintains a Q-table ``self.Q[dst][nbr]`` holding the estimated cost
    (delivery time) of sending a package destined for ``dst`` via neighbour
    ``nbr``, and always routes to the neighbour with the minimal estimate,
    updating the table from reward messages sent back by neighbours.
    """

    def __init__(self, learning_rate: float, nodes: List[AgentId], **kwargs):
        super().__init__(**kwargs)
        self.learning_rate = learning_rate
        self.nodes = nodes
        # Optimistic initialization: 0 when the destination is the neighbour
        # itself, a flat 10 otherwise.
        self.Q = {u: {v: 0 if u == v else 10
                      for v in self.interface_map.values()}
                  for u in self.nodes}

    def addLink(self, to: AgentId, params=None) -> List[Message]:
        """
        Register a new link and extend the Q-table with a column for the
        new neighbour, initialized the same way as in ``__init__``.

        ``params`` defaults to ``None`` instead of a shared mutable ``{}``
        (classic mutable-default pitfall); the effective default is still
        an empty dict.
        """
        msgs = super().addLink(to, {} if params is None else params)
        for (u, dct) in self.Q.items():
            if to not in dct:
                dct[to] = 0 if u == to else 10
        return msgs

    def route(self, sender: AgentId, pkg: Package, allowed_nbrs: List[AgentId]) -> Tuple[AgentId, List[Message]]:
        """
        Choose the neighbour with the minimal Q-value for ``pkg.dst`` and
        send a reward message back to ``sender`` (unless the package came
        from the outside world).
        """
        Qs = self._Q(pkg.dst, allowed_nbrs)
        to, estimate = dict_min(Qs)
        reward_msg = self.registerResentPkg(pkg, estimate, to, pkg.dst)
        return to, [OutMessage(self.id, sender, reward_msg)] if sender[0] != 'world' else []

    def pathCost(self, to: AgentId) -> float:
        """Best-case estimated cost to ``to`` over all known interfaces."""
        return min(self._Q(to, list(self.interface_map.values())).values())

    def handleMsgFrom(self, sender: AgentId, msg: Message) -> List[Message]:
        """
        On a reward message, nudge ``Q[dst][action]`` towards the new
        estimate by ``learning_rate``; delegate everything else upwards.
        """
        if isinstance(msg, RewardMsg):
            action, Q_new, dst = self.receiveReward(msg)
            self.Q[dst][action] += self.learning_rate * (Q_new - self.Q[dst][action])
            return []
        else:
            return super().handleMsgFrom(sender, msg)

    def _Q(self, d: AgentId, allowed_nbrs: List[AgentId]) -> Dict[AgentId, float]:
        """
        Return the Q-values towards destination ``d`` restricted to the
        currently available neighbours. (Annotations fixed: keys are
        ``AgentId``, not ``int``.)
        """
        return {n: self.Q[d][n] for n in allowed_nbrs}
class PredictiveQRouter(SimpleQRouter, RewardAgent):
    """
    Predictive Q-routing (PQ-routing) on top of ``SimpleQRouter``.

    In addition to the Q-table it tracks, per (destination, neighbour):
      - ``B``: the best (lowest) Q-value ever observed,
      - ``R``: a recovery rate at which a degraded estimate is assumed
        to improve over time,
      - ``U``: the time of the last update.
    Routing uses a time-adjusted estimate so that links whose estimates
    may have recovered get re-probed.
    """

    def __init__(self, beta: float, gamma: float, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta    # learning rate for the recovery rate R
        self.gamma = gamma  # decay factor applied to R on worsening estimates
        # NOTE(review): `deepcopy` is presumably provided by one of the
        # star-imports above — confirm it is in scope.
        self.B = deepcopy(self.Q)
        self.R = {u: {v: 0 for v in self.interface_map.values()}
                  for u in self.nodes}
        self.U = {u: {v: 0 for v in self.interface_map.values()}
                  for u in self.nodes}

    def addLink(self, to: AgentId, params=None) -> List[Message]:
        """
        Register a new link and extend the B/R/U tables for the new
        neighbour. ``params`` defaults to ``None`` instead of a shared
        mutable ``{}`` (mutable-default pitfall); behaviour is unchanged.
        """
        msgs = super().addLink(to, {} if params is None else params)
        for u in self.nodes:
            if to not in self.B[u]:
                self.B[u][to] = 0 if u == to else self.Q[u][to]
            if to not in self.R[u]:
                self.R[u][to] = 0
            if to not in self.U[u]:
                self.U[u][to] = self.env.time()
        return msgs

    def route(self, sender: AgentId, pkg: Package, allowed_nbrs: List[AgentId]) -> Tuple[AgentId, List[Message]]:
        """
        Route by the time-adjusted (altered) estimates, but report the
        plain best Q-value as the reward estimate.
        """
        Qs = self._Q(pkg.dst, allowed_nbrs)
        Qs_altered = self._Q_altered(pkg.dst, allowed_nbrs)
        to, _ = dict_min(Qs_altered)
        estimate = min(Qs.values())
        reward_msg = self.registerResentPkg(pkg, estimate, to, pkg.dst)
        return to, [OutMessage(self.id, sender, reward_msg)] if sender[0] != 'world' else []

    def handleMsgFrom(self, sender: AgentId, msg: Message) -> List[Message]:
        """
        On a reward message: update Q towards the new estimate, keep B as
        the running minimum, and adapt the recovery rate R — increase it
        (scaled by elapsed time) when the estimate improved, decay it by
        ``gamma`` when it worsened.
        """
        if isinstance(msg, RewardMsg):
            action, Q_new, dst = self.receiveReward(msg)
            dQ = Q_new - self.Q[dst][action]
            self.Q[dst][action] += self.learning_rate * dQ
            self.B[dst][action] = min(self.B[dst][action], self.Q[dst][action])

            now = self.env.time()
            elapsed = now - self.U[dst][action]
            if dQ < 0:
                # Guard against a zero-division crash when the reward
                # arrives at the exact simulation instant of the last
                # update (elapsed == 0): skip the R update in that case.
                if elapsed > 0:
                    self.R[dst][action] += self.beta * (dQ / elapsed)
            elif dQ > 0:
                self.R[dst][action] *= self.gamma
            self.U[dst][action] = now
            return []
        else:
            return super().handleMsgFrom(sender, msg)

    def _Q_altered(self, d: AgentId, allowed_nbrs: List[AgentId]) -> Dict[AgentId, float]:
        """
        Time-adjusted estimates for the available neighbours: the Q-value
        recovered by ``R * elapsed``, clamped from below by the best value
        ever seen. (Annotations fixed: keys are ``AgentId``, not ``int``.)
        """
        now = self.env.time()
        res = {}
        for n in allowed_nbrs:
            dt = now - self.U[d][n]
            res[n] = max(self.Q[d][n] + dt * self.R[d][n], self.B[d][n])
        return res
class SimpleQRouterNetwork(NetworkRewardAgent, SimpleQRouter):
    """
    Q-router which calculates rewards for the computer-network routing
    setting (reward computation supplied by ``NetworkRewardAgent``).
    """
    pass
class SimpleQRouterConveyor(LSConveyorMixin, ConveyorRewardAgent, SimpleQRouter, LinkStateRouter):
    """
    Q-router which calculates rewards for the conveyor routing setting
    (reward computation supplied by ``ConveyorRewardAgent``, topology
    maintenance by ``LSConveyorMixin``/``LinkStateRouter``).
    """
    pass
class PredictiveQRouterNetwork(NetworkRewardAgent, PredictiveQRouter):
    """
    Predictive Q-router which calculates rewards for the computer-network
    routing setting.
    """
    pass
class PredictiveQRouterConveyor(LSConveyorMixin, ConveyorRewardAgent, PredictiveQRouter, LinkStateRouter):
    """
    Predictive Q-router which calculates rewards for the conveyor routing
    setting.
    """
    pass