
updated to match latest version of paper

sfujim committed Jan 29, 2019
1 parent 7c945d8 commit 05c07fc442a2be96f6249b966682cf065045500f
Showing with 57 additions and 75 deletions.
+57 −75 BCQ.py
@@ -1,7 +1,6 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import utils

@@ -33,42 +32,42 @@ def __init__(self, state_dim, action_dim):
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, 1)

self.l4 = nn.Linear(state_dim + action_dim, 400)
self.l5 = nn.Linear(400, 300)
self.l6 = nn.Linear(300, 1)


def forward(self, state, action):
q = F.relu(self.l1(torch.cat([state, action], 1)))
q = F.relu(self.l2(q))
q = self.l3(q)
return q
q1 = F.relu(self.l1(torch.cat([state, action], 1)))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)

q2 = F.relu(self.l4(torch.cat([state, action], 1)))
q2 = F.relu(self.l5(q2))
q2 = self.l6(q2)
return q1, q2

class Value(nn.Module):
def __init__(self, state_dim, action_dim):
super(Value, self).__init__()
self.l1 = nn.Linear(state_dim, 400)
self.l2 = nn.Linear(400, 300)
self.l3 = nn.Linear(300, 1)


def forward(self, state):
v = F.relu(self.l1(state))
v = F.relu(self.l2(v))
v = self.l3(v)
return v
def q1(self, state, action):
q1 = F.relu(self.l1(torch.cat([state, action], 1)))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)
return q1
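
The single-output critic above (and the separate Value network it worked with) is replaced by a twin-head critic in the style of Clipped Double Q-learning, with a q1() helper exposing only the first head. A minimal sketch of how the two heads are queried, using the Critic class defined in this file (the dimensions below are hypothetical, not from the commit):

import torch

critic = Critic(state_dim=17, action_dim=6)   # hypothetical dimensions
state = torch.randn(5, 17)                    # batch of 5 states
action = torch.randn(5, 6)                    # matching batch of actions
q1, q2 = critic(state, action)                # two independent Q estimates, each of shape (5, 1)
q1_first = critic.q1(state, action)           # first head only, as used by the perturbation-model update below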


# Vanilla Variational Auto-Encoder
class VAE(nn.Module):
def __init__(self, state_dim, action_dim, latent_dim, max_action):
super(VAE, self).__init__()
self.e1 = nn.Linear(state_dim + action_dim, 400)
self.e2 = nn.Linear(400, 300)
self.e1 = nn.Linear(state_dim + action_dim, 750)
self.e2 = nn.Linear(750, 750)

self.mean = nn.Linear(300, latent_dim)
self.log_std = nn.Linear(300, latent_dim)
self.mean = nn.Linear(750, latent_dim)
self.log_std = nn.Linear(750, latent_dim)

self.d1 = nn.Linear(state_dim + latent_dim, 400)
self.d2 = nn.Linear(400, 300)
self.d3 = nn.Linear(300, action_dim)
self.d1 = nn.Linear(state_dim + latent_dim, 750)
self.d2 = nn.Linear(750, 750)
self.d3 = nn.Linear(750, action_dim)

self.max_action = max_action
self.latent_dim = latent_dim
@@ -92,7 +91,7 @@ def forward(self, state, action):
def decode(self, state, z=None):
# When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5]
if z is None:
z = torch.FloatTensor(np.random.normal(0, 1, size=(state.size(0), self.latent_dim))).clamp(-0.5, 0.5).to(device)
z = torch.FloatTensor(np.random.normal(0, 1, size=(state.size(0), self.latent_dim))).to(device).clamp(-0.5, 0.5)

a = F.relu(self.d1(torch.cat([state, z], 1)))
a = F.relu(self.d2(a))
@@ -103,19 +102,18 @@ def decode(self, state, z=None):
class BCQ(object):
def __init__(self, state_dim, action_dim, max_action):

latent_dim = action_dim * 2

self.actor = Actor(state_dim, action_dim, max_action).to(device)
self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor_optimizer = torch.optim.Adam(self.actor.parameters())

self.critic = Critic(state_dim, action_dim).to(device)
self.critic_target = Critic(state_dim, action_dim).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = torch.optim.Adam(self.critic.parameters())

self.value = Value(state_dim, action_dim).to(device)
self.value_target = Value(state_dim, action_dim).to(device)
self.value_target.load_state_dict(self.value.state_dict())
self.value_optimizer = torch.optim.Adam(self.value.parameters())

self.vae = VAE(state_dim, action_dim, latent_dim, max_action).to(device)
self.vae_optimizer = torch.optim.Adam(self.vae.parameters())

@@ -125,9 +123,10 @@ def __init__(self, state_dim, action_dim, max_action):

def select_action(self, state):
with torch.no_grad():
state = torch.FloatTensor(state.reshape(1, -1)).repeat(100, 1).to(device)
state = torch.FloatTensor(state.reshape(1, -1)).repeat(10, 1).to(device)
action = self.actor(state, self.vae.decode(state))
ind = self.critic(state, action).max(0)[1]
q1 = self.critic.q1(state, action)
ind = q1.max(0)[1]
return action[ind].cpu().data.numpy().flatten()
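
With these changes, select_action tiles the input state 10 times (down from 100), decodes 10 candidate actions from the VAE, perturbs each with the actor, and returns the candidate with the highest Q1 value. A hedged usage sketch; the environment name and surrounding setup are illustrative assumptions, not part of this commit:

import numpy as np
import gym

env = gym.make("HalfCheetah-v2")                    # hypothetical environment choice
policy = BCQ(state_dim=env.observation_space.shape[0],
             action_dim=env.action_space.shape[0],
             max_action=float(env.action_space.high[0]))
state = env.reset()
action = policy.select_action(np.array(state))      # argmax over Q1 of the 10 perturbed VAE candidates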


@@ -136,10 +135,10 @@ def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.
for it in range(iterations):

# Sample replay buffer / batch
state_np, next_state, action, reward, done = replay_buffer.sample(batch_size)
state_np, next_state_np, action, reward, done = replay_buffer.sample(batch_size)
state = torch.FloatTensor(state_np).to(device)
action = torch.FloatTensor(action).to(device)
next_state = torch.FloatTensor(next_state).to(device)
next_state = torch.FloatTensor(next_state_np).to(device)
reward = torch.FloatTensor(reward).to(device)
done = torch.FloatTensor(1 - done).to(device)

@@ -148,7 +147,7 @@ def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.
recon, mean, std = self.vae(state, action)
recon_loss = F.mse_loss(recon, action)
KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
vae_loss = recon_loss + KL_loss
vae_loss = recon_loss + 0.5 * KL_loss

self.vae_optimizer.zero_grad()
vae_loss.backward()
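
The generative model keeps the standard VAE objective, but this commit weights the KL term by 0.5: the reconstruction error of the behavioral action plus the KL divergence between the encoder distribution and a unit Gaussian prior (both terms are averaged rather than summed, matching the code above):

\mathcal{L}_{\mathrm{VAE}} = \lVert a - \hat{a} \rVert^{2} + \tfrac{1}{2}\, D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^{2}) \,\Vert\, \mathcal{N}(0, I)\big), \qquad D_{\mathrm{KL}} = -\tfrac{1}{2}\big(1 + \log \sigma^{2} - \mu^{2} - \sigma^{2}\big)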
@@ -157,59 +156,42 @@ def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.

# Critic Training
with torch.no_grad():
target_Q = reward + done * discount * self.value_target(next_state)

current_Q = self.critic(state, action)
critic_loss = F.mse_loss(current_Q, target_Q)
# Duplicate state 10 times
state_rep = torch.FloatTensor(np.repeat(next_state_np, 10, axis=0)).to(device)

# Compute value of perturbed actions sampled from the VAE
target_Q1, target_Q2 = self.critic_target(state_rep, self.actor_target(state_rep, self.vae.decode(state_rep)))

# Soft Clipped Double Q-learning
target_Q = 0.75 * torch.min(target_Q1, target_Q2) + 0.25 * torch.max(target_Q1, target_Q2)
target_Q = target_Q.view(batch_size, -1).max(1)[0].view(-1, 1)

target_Q = reward + done * discount * target_Q

current_Q1, current_Q2 = self.critic(state, action)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
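
The critic target now uses the soft clipped double-Q estimate: each next state s' is duplicated 10 times, candidate actions are decoded from the VAE and perturbed by the target actor, both target critics score every candidate, and the bootstrap value is the best candidate under a 0.75/0.25 convex combination of the two estimates:

y = r + \gamma (1 - d) \max_{i} \Big[ 0.75 \min_{j=1,2} Q_{\theta_j'}(s', a_i) + 0.25 \max_{j=1,2} Q_{\theta_j'}(s', a_i) \Big], \qquad a_i = \xi_{\phi'}(s', \hat{a}_i), \;\; \hat{a}_i \sim G_{\omega}(s'), \;\; i = 1, \dots, 10

where \xi_{\phi'} is the target perturbation model (actor_target) and G_{\omega} the VAE decoder; both current critic heads are then regressed onto y.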


# Actor Training
# Perturbation Model / Action Training
sampled_actions = self.vae.decode(state)
perturbed_actions = self.actor(state, sampled_actions)
actor_loss = -(self.critic(state, perturbed_actions)).mean()

# Update through DPG
actor_loss = -self.critic.q1(state, perturbed_actions).mean()

self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
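
The perturbation model is trained by the deterministic policy gradient through the first critic head only: actions sampled from the VAE are perturbed, and the perturbation parameters \phi are updated to maximize Q1 of the result (in the paper's notation, where \xi_\phi is the perturbation produced by the Actor):

\phi \leftarrow \arg\max_{\phi} \; \mathbb{E}_{s \sim \mathcal{B},\, a \sim G_{\omega}(s)} \Big[ Q_{\theta_1}\big(s, a + \xi_{\phi}(s, a)\big) \Big]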


# Value Training
current_V = self.value(state)
with torch.no_grad():
# Duplicate state 10 times
state = torch.FloatTensor(np.repeat(state_np, 10, axis=0)).to(device)

# Compute value of perturbed actions sampled from the VAE
target_V = self.critic(state, self.actor(state, self.vae.decode(state)))

# Select the max action (out of 10) for each state
target_V = target_V.view(batch_size, -1).max(1)[0].view(-1, 1)

value_loss = F.mse_loss(current_V, target_V)

self.value_optimizer.zero_grad()
value_loss.backward()
self.value_optimizer.step()


# Update the frozen target models
for param, target_param in zip(self.value.parameters(), self.value_target.parameters()):
# Update Target Networks
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)


def save(self, filename, directory):
torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))
torch.save(self.value.state_dict(), '%s/%s_value.pth' % (directory, filename))
torch.save(self.vae.state_dict(), '%s/%s_vae.pth' % (directory, filename))


def load(self, filename, directory):
self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename)))
self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename)))
self.value.load_state_dict(torch.load('%s/%s_value.pth' % (directory, filename)))
self.vae.load_state_dict(torch.load('%s/%s_vae.pth' % (directory, filename)))
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
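
Both target networks are then updated by Polyak averaging with the rate tau passed to train(), replacing the old value-network-only target update:

\theta_j' \leftarrow \tau \theta_j + (1 - \tau)\, \theta_j', \qquad \phi' \leftarrow \tau \phi + (1 - \tau)\, \phi'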
