import torch
import torch.nn.functional as F
from torch.optim import Adam
from utils.misc import soft_update, hard_update, enable_gradients, disable_gradients
from utils.agents import AttentionAgent
from utils.critics import AttentionCritic
MSELoss = torch.nn.MSELoss()
class AttentionSAC(object):
"""
    Wrapper class for SAC agents with a central attention critic in
    multi-agent tasks
"""
def __init__(self, agent_init_params, sa_size,
gamma=0.95, tau=0.01, pi_lr=0.01, q_lr=0.01,
reward_scale=10.,
pol_hidden_dim=128,
critic_hidden_dim=128, attend_heads=4,
**kwargs):
"""
Inputs:
agent_init_params (list of dict): List of dicts with parameters to
initialize each agent
num_in_pol (int): Input dimensions to policy
num_out_pol (int): Output dimensions to policy
sa_size (list of (int, int)): Size of state and action space for
each agent
gamma (float): Discount factor
tau (float): Target update rate
pi_lr (float): Learning rate for policy
q_lr (float): Learning rate for critic
            reward_scale (float): Scaling for reward (affects the optimal
                                  policy entropy)
            pol_hidden_dim (int): Number of hidden dimensions for policy
                                  networks
            critic_hidden_dim (int): Number of hidden dimensions for the
                                     critic
            attend_heads (int): Number of attention heads in the critic
"""
self.nagents = len(sa_size)
self.agents = [AttentionAgent(lr=pi_lr,
hidden_dim=pol_hidden_dim,
**params)
for params in agent_init_params]
self.critic = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim,
attend_heads=attend_heads)
self.target_critic = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim,
attend_heads=attend_heads)
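        # initialize the target critic as an exact copy of the critic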
hard_update(self.target_critic, self.critic)
self.critic_optimizer = Adam(self.critic.parameters(), lr=q_lr,
weight_decay=1e-3)
self.agent_init_params = agent_init_params
self.gamma = gamma
self.tau = tau
self.pi_lr = pi_lr
self.q_lr = q_lr
self.reward_scale = reward_scale
self.pol_dev = 'cpu' # device for policies
self.critic_dev = 'cpu' # device for critics
self.trgt_pol_dev = 'cpu' # device for target policies
self.trgt_critic_dev = 'cpu' # device for target critics
self.niter = 0
@property
def policies(self):
return [a.policy for a in self.agents]
@property
def target_policies(self):
return [a.target_policy for a in self.agents]
def step(self, observations, explore=False):
"""
Take a step forward in environment with all agents
Inputs:
            observations: List of observations for each agent
            explore (boolean): Whether to sample actions for exploration
Outputs:
actions: List of actions for each agent
"""
return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
observations)]
def update_critic(self, sample, soft=True, logger=None, **kwargs):
"""
Update central critic for all agents
"""
obs, acs, rews, next_obs, dones = sample
# Q loss
next_acs = []
next_log_pis = []
for pi, ob in zip(self.target_policies, next_obs):
curr_next_ac, curr_next_log_pi = pi(ob, return_log_pi=True)
next_acs.append(curr_next_ac)
next_log_pis.append(curr_next_log_pi)
trgt_critic_in = list(zip(next_obs, next_acs))
critic_in = list(zip(obs, acs))
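        # the critics consume a list of per-agent (observation, action) pairs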
next_qs = self.target_critic(trgt_critic_in)
critic_rets = self.critic(critic_in, regularize=True,
logger=logger, niter=self.niter)
q_loss = 0
for a_i, nq, log_pi, (pq, regs) in zip(range(self.nagents), next_qs,
next_log_pis, critic_rets):
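            # soft Bellman target for agent a_i:
            #   y = r + gamma * (1 - done) * (Q_target(o', a') - log pi(a'|o') / reward_scale)
            # i.e. 1 / reward_scale plays the role of the entropy temperature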
target_q = (rews[a_i].view(-1, 1) +
self.gamma * nq *
(1 - dones[a_i].view(-1, 1)))
if soft:
target_q -= log_pi / self.reward_scale
q_loss += MSELoss(pq, target_q.detach())
for reg in regs:
q_loss += reg # regularizing attention
q_loss.backward()
self.critic.scale_shared_grads()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.critic.parameters(), 10 * self.nagents)
self.critic_optimizer.step()
self.critic_optimizer.zero_grad()
if logger is not None:
logger.add_scalar('losses/q_loss', q_loss, self.niter)
logger.add_scalar('grad_norms/q', grad_norm, self.niter)
self.niter += 1
def update_policies(self, sample, soft=True, logger=None, **kwargs):
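        """
        Update policies for all agents using advantages computed from the
        central attention critic
        """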
obs, acs, rews, next_obs, dones = sample
samp_acs = []
all_probs = []
all_log_pis = []
all_pol_regs = []
for a_i, pi, ob in zip(range(self.nagents), self.policies, obs):
curr_ac, probs, log_pi, pol_regs, ent = pi(
ob, return_all_probs=True, return_log_pi=True,
regularize=True, return_entropy=True)
            if logger is not None:
                logger.add_scalar('agent%i/policy_entropy' % a_i, ent,
                                  self.niter)
samp_acs.append(curr_ac)
all_probs.append(probs)
all_log_pis.append(log_pi)
all_pol_regs.append(pol_regs)
critic_in = list(zip(obs, samp_acs))
critic_rets = self.critic(critic_in, return_all_q=True)
for a_i, probs, log_pi, pol_regs, (q, all_q) in zip(range(self.nagents), all_probs,
all_log_pis, all_pol_regs,
critic_rets):
curr_agent = self.agents[a_i]
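            # expected Q under the current policy acts as a baseline, so
            # pol_target below is an advantage estimate for the sampled action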
v = (all_q * probs).sum(dim=1, keepdim=True)
pol_target = q - v
if soft:
pol_loss = (log_pi * (log_pi / self.reward_scale - pol_target).detach()).mean()
else:
pol_loss = (log_pi * (-pol_target).detach()).mean()
for reg in pol_regs:
pol_loss += 1e-3 * reg # policy regularization
# don't want critic to accumulate gradients from policy loss
disable_gradients(self.critic)
pol_loss.backward()
enable_gradients(self.critic)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                curr_agent.policy.parameters(), 0.5)
curr_agent.policy_optimizer.step()
curr_agent.policy_optimizer.zero_grad()
if logger is not None:
logger.add_scalar('agent%i/losses/pol_loss' % a_i,
pol_loss, self.niter)
logger.add_scalar('agent%i/grad_norms/pi' % a_i,
grad_norm, self.niter)
def update_all_targets(self):
"""
Update all target networks (called after normal updates have been
performed for each agent)
"""
soft_update(self.target_critic, self.critic, self.tau)
for a in self.agents:
soft_update(a.target_policy, a.policy, self.tau)
def prep_training(self, device='gpu'):
self.critic.train()
self.target_critic.train()
for a in self.agents:
a.policy.train()
a.target_policy.train()
if device == 'gpu':
fn = lambda x: x.cuda()
else:
fn = lambda x: x.cpu()
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
if not self.critic_dev == device:
self.critic = fn(self.critic)
self.critic_dev = device
if not self.trgt_pol_dev == device:
for a in self.agents:
a.target_policy = fn(a.target_policy)
self.trgt_pol_dev = device
if not self.trgt_critic_dev == device:
self.target_critic = fn(self.target_critic)
self.trgt_critic_dev = device
def prep_rollouts(self, device='cpu'):
for a in self.agents:
a.policy.eval()
if device == 'gpu':
fn = lambda x: x.cuda()
else:
fn = lambda x: x.cpu()
# only need main policy for rollouts
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
def save(self, filename):
"""
Save trained parameters of all agents into one file
"""
self.prep_training(device='cpu') # move parameters to CPU before saving
save_dict = {'init_dict': self.init_dict,
'agent_params': [a.get_params() for a in self.agents],
'critic_params': {'critic': self.critic.state_dict(),
'target_critic': self.target_critic.state_dict(),
'critic_optimizer': self.critic_optimizer.state_dict()}}
torch.save(save_dict, filename)
@classmethod
def init_from_env(cls, env, gamma=0.95, tau=0.01,
pi_lr=0.01, q_lr=0.01,
reward_scale=10.,
pol_hidden_dim=128, critic_hidden_dim=128, attend_heads=4,
**kwargs):
"""
Instantiate instance of this class from multi-agent environment
env: Multi-agent Gym environment
gamma: discount factor
tau: rate of update for target networks
        pi_lr: learning rate for policy networks
        q_lr: learning rate for the critic
        reward_scale: scaling for reward (affects the optimal policy entropy)
        pol_hidden_dim / critic_hidden_dim: number of hidden dimensions for
            the policy and critic networks
        attend_heads: number of attention heads in the critic
"""
agent_init_params = []
sa_size = []
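        # assumes per-agent Box observation spaces and Discrete action spaces
        # (obsp.shape[0] and acsp.n below)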
for acsp, obsp in zip(env.action_space,
env.observation_space):
agent_init_params.append({'num_in_pol': obsp.shape[0],
'num_out_pol': acsp.n})
sa_size.append((obsp.shape[0], acsp.n))
init_dict = {'gamma': gamma, 'tau': tau,
'pi_lr': pi_lr, 'q_lr': q_lr,
'reward_scale': reward_scale,
'pol_hidden_dim': pol_hidden_dim,
'critic_hidden_dim': critic_hidden_dim,
'attend_heads': attend_heads,
'agent_init_params': agent_init_params,
'sa_size': sa_size}
instance = cls(**init_dict)
instance.init_dict = init_dict
return instance
@classmethod
def init_from_save(cls, filename, load_critic=False):
"""
Instantiate instance of this class from file created by 'save' method
"""
save_dict = torch.load(filename)
instance = cls(**save_dict['init_dict'])
instance.init_dict = save_dict['init_dict']
for a, params in zip(instance.agents, save_dict['agent_params']):
a.load_params(params)
if load_critic:
critic_params = save_dict['critic_params']
instance.critic.load_state_dict(critic_params['critic'])
instance.target_critic.load_state_dict(critic_params['target_critic'])
instance.critic_optimizer.load_state_dict(critic_params['critic_optimizer'])
return instance
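
# A minimal usage sketch, assuming a multi-agent Gym-style environment and a
# replay buffer that yields (obs, acs, rews, next_obs, dones) batches in the
# per-agent list format expected above; `make_env` and `buffer` are
# hypothetical placeholders:
#
#     env = make_env()
#     model = AttentionSAC.init_from_env(env, tau=0.001, pi_lr=0.001,
#                                        q_lr=0.001, reward_scale=100.)
#     model.prep_rollouts(device='cpu')
#     obs = env.reset()
#     actions = model.step(obs, explore=True)   # collect experience
#     ...
#     model.prep_training(device='gpu')
#     sample = buffer.sample(1024)
#     model.update_critic(sample, soft=True)
#     model.update_policies(sample, soft=True)
#     model.update_all_targets()
#     model.save('model.pt')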