agent1.py
import random
import numpy as np
import torch
import torch.nn.functional as F
import copy


class DDPGAgent:
    def __init__(self, config):
        self.config = config
        self.seed = config.seed
        # Actor Network (w/ Target Network)
        self.actor_local = config.actor_network_fn()
        self.actor_target = config.actor_network_fn()
        self.actor_optimizer = config.actor_optimizer_fn(self.actor_local.parameters())
        # Critic Network (w/ Target Network)
        self.critic_local = config.critic_network_fn()
        self.critic_target = config.critic_network_fn()
        self.critic_optimizer = config.critic_optimizer_fn(self.critic_local.parameters())
        # ----------------------- initialize target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, 1)
        self.soft_update(self.actor_local, self.actor_target, 1)
        # Exploration noise process
        self.noise = config.noise_fn()
        # Replay memory (optionally shared between agents)
        if config.shared_replay_buffer:
            self.memory = config.memory
        else:
            self.memory = config.memory_fn()

    def reset(self):
        """Reset the exploration noise process."""
        self.noise.reset()

    def act(self, states):
        """Returns actions for the given states as per the current policy."""
        states = torch.from_numpy(states).float().to(self.config.device)
        self.actor_local.eval()
        with torch.no_grad():
            actions = self.actor_local(states).cpu().data.numpy()
        self.actor_local.train()
        # Add exploration noise, then clip to the valid action range
        actions += self.noise.sample()
        return np.clip(actions, -1, 1)

    def learn(self, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences
        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optimizer.step()
        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, self.config.tau)
        self.soft_update(self.actor_local, self.actor_target, self.config.tau)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
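

# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original agent): one possible way to
# wire up the Config object that DDPGAgent consumes, assuming small
# fully-connected actor/critic networks and an Ornstein-Uhlenbeck exploration
# noise process. The names Actor, Critic, OUNoise, the SimpleNamespace-based
# config, and all hyperparameter values below are assumptions for
# demonstration only; the agent merely requires that config expose the
# attributes referenced above (seed, device, tau, shared_replay_buffer,
# actor_network_fn, critic_network_fn, *_optimizer_fn, noise_fn, memory_fn).
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    from types import SimpleNamespace
    import torch.nn as nn
    import torch.optim as optim

    state_size, action_size, device = 8, 2, torch.device("cpu")

    class Actor(nn.Module):
        """Maps states to actions in [-1, 1]."""
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_size, 64), nn.ReLU(),
                nn.Linear(64, action_size), nn.Tanh())

        def forward(self, state):
            return self.net(state)

    class Critic(nn.Module):
        """Maps (state, action) pairs to scalar Q-values."""
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_size + action_size, 64), nn.ReLU(),
                nn.Linear(64, 1))

        def forward(self, state, action):
            return self.net(torch.cat([state, action], dim=-1))

    class OUNoise:
        """Ornstein-Uhlenbeck process exposing the reset()/sample()
        interface the agent uses for exploration noise."""
        def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2):
            self.mu = mu * np.ones(size)
            self.theta, self.sigma = theta, sigma
            self.reset()

        def reset(self):
            self.state = copy.copy(self.mu)

        def sample(self):
            dx = self.theta * (self.mu - self.state) \
                 + self.sigma * np.random.standard_normal(len(self.state))
            self.state = self.state + dx
            return self.state

    config = SimpleNamespace(
        seed=0,
        device=device,
        tau=1e-3,
        shared_replay_buffer=False,
        actor_network_fn=lambda: Actor().to(device),
        critic_network_fn=lambda: Critic().to(device),
        actor_optimizer_fn=lambda params: optim.Adam(params, lr=1e-4),
        critic_optimizer_fn=lambda params: optim.Adam(params, lr=1e-3),
        noise_fn=lambda: OUNoise(action_size),
        memory_fn=lambda: [],  # placeholder; a real replay buffer goes here
    )

    agent = DDPGAgent(config)
    agent.reset()
    actions = agent.act(np.random.randn(1, state_size).astype(np.float32))
    print("sampled actions:", actions)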