Updated parameter files
sharif1093 committed Feb 19, 2020
1 parent 221dc63 commit 4d071d6
Showing 18 changed files with 253 additions and 183 deletions.
2 changes: 1 addition & 1 deletion digideep/agent/agent_base.py
@@ -48,7 +48,7 @@ def reset_hidden_state(self, num_workers):
return np.zeros((num_workers, hidden_size), dtype=np.float32)

def random_action_generator(self, envs, num_workers):
actions = np.array([envs.action_space.spaces[self.params["name"]].sample() for i in range(num_workers)])
actions = np.array([envs.action_space.spaces[self.params["name"]].sample() for i in range(num_workers)], dtype=np.float32)
hidden_state = self.reset_hidden_state(num_workers)
return dict(actions=actions, hidden_state=hidden_state)

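The only substantive change in agent_base.py is the explicit dtype=np.float32 when stacking the sampled random actions. A minimal standalone sketch of why the cast matters, with hypothetical per-worker samples standing in for envs.action_space.spaces[...].sample() (not taken from any repo config): without the explicit dtype, the stacked batch inherits whatever dtype the space produces (often float64), which would then have to be down-cast before feeding the float32 torch models.

import numpy as np

num_workers = 4
# Hypothetical per-worker action samples, as an action_space.sample() call might return them.
samples = [np.random.uniform(-1.0, 1.0, size=6) for _ in range(num_workers)]

actions_default = np.array(samples)                    # inherits float64 from the samples
actions_float32 = np.array(samples, dtype=np.float32)  # what the updated code produces

print(actions_default.dtype, actions_float32.dtype)    # float64 float32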
60 changes: 36 additions & 24 deletions digideep/agent/ddpg/agent.py
@@ -13,8 +13,9 @@
from digideep.utility.monitoring import monitor

# from digideep.agent.sampler_common import check_shape
from digideep.agent.sampler_common import Compose
from digideep.agent.agent_base import AgentBase
from .sampler import sampler_re

from .policy import Policy

# torch.utils.backcompat.broadcast_warning.enabled = True
@@ -60,6 +61,10 @@ def __init__(self, session, memory, **params):
self.optimizer["actor"] = optimclass_actor(self.policy.model["actor"].parameters(), **self.params["optimargs_actor"])
self.optimizer["critic"] = optimclass_critic(self.policy.model["critic"].parameters(), **self.params["optimargs_critic"])

# Build the sampler from sampler list:
sampler_list = [get_class(k) for k in self.params["sampler_list"]]
self.sampler = Compose(sampler_list)

noiseclass = get_class(self.params["noisename"])
self.noise = noiseclass(**self.params["noiseargs"])

@@ -114,46 +119,52 @@ def step(self):
The first three keys are generated by the :class:`~digideep.environment.explorer.Explorer`
and the last key is added by the sampler.
"""

with KeepTime("sampler"):
info = deepcopy(self.params["sampler"])
batch = sampler_re(data=self.memory, info=info)
info = deepcopy(self.params["sampler_args"])
batch = self.sampler(data=self.memory, info=info)
if batch is None:
return


with KeepTime("to_torch"):
# ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2']

o1 = torch.from_numpy(batch["/obs_with_key"]).to(self.device)
r1 = torch.from_numpy(batch["/rewards"]).to(self.device)
a1 = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device)
o2 = torch.from_numpy(batch["/obs_with_key_2"]).to(self.device)
masks = torch.from_numpy(batch["/masks"]).to(self.device).view(-1)

# o1.clamp_(min=-self.params["trainer"]["clamp_obs"], max= self.params["trainer"]["clamp_obs"])
# o2.clamp_(min=-self.params["trainer"]["clamp_obs"], max= self.params["trainer"]["clamp_obs"])
# ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2', ...]
o1 = torch.from_numpy(batch["/observations"+ self.params["observation_path"]]).to(self.device).float()
a1 = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device).float()
r1 = torch.from_numpy(batch["/rewards"]).to(self.device).float()
o2 = torch.from_numpy(batch["/observations"+self.params["observation_path"]+"_2"]).to(self.device).float()
masks = torch.from_numpy(batch["/masks"]).to(self.device)
# .view(-1).float()

# with KeepTime("to_torch"):
# # ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2']
# o1 = torch.from_numpy(batch["/obs_with_key"]).to(self.device).float()
# r1 = torch.from_numpy(batch["/rewards"]).to(self.device).float()
# a1 = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device).float()
# o2 = torch.from_numpy(batch["/obs_with_key_2"]).to(self.device).float()
# masks = torch.from_numpy(batch["/masks"]).to(self.device).view(-1).float()
# # o1.clamp_(min=-self.params["trainer"]["clamp_obs"], max= self.params["trainer"]["clamp_obs"])
# # o2.clamp_(min=-self.params["trainer"]["clamp_obs"], max= self.params["trainer"]["clamp_obs"])


with KeepTime("loss/critic"):
# ---------------------- optimize critic ----------------------
# Use target actor exploitation policy here for loss evaluation
a2 = self.policy.model["actor_target"](o2).detach()
next_val = torch.squeeze(self.policy.model["critic_target"](o2, a2).detach())
next_val = self.policy.model["critic_target"](o2, a2).detach()

# y_target = r + gamma * Q'( s2, pi'(s2))
# NOTE: THIS SENTENCE IS VERY IMPORTANT!
r1 = torch.squeeze(r1)
r1 = r1
y_target = r1 + masks * next_val * float(self.params["methodargs"]["gamma"])

# TODO: THIS WASN'T IN THE ORIGINAL IMPLEMENTATION, BUT IT IS IN THE HER IMPLEMENTATION.
# y_target.clamp_(min=-self.params["methodargs"]["clamp_return"], max=0)

# y_pred = Q( s1, a1)
y_predicted = torch.squeeze(self.policy.model["critic"](o1, a1))
y_predicted = self.policy.model["critic"](o1, a1)
# compute critic loss, and update the critic
# smooth_l1_loss: Calculates l2 norm near zero and l1 elsewhere


# NOTE: The following is in DDPG+HER implementation.
# loss_critic = F.mse_loss(y_predicted, y_target, reduction='sum')
# NOTE: The following was used in the original!
Expand All @@ -180,14 +191,15 @@ def step(self):
def update(self):
# Update the networks for n times
for i in range(self.params["methodargs"]["n_update"]):
# Step
with KeepTime("step"):
self.step()

with KeepTime("targets"):
# Update actor/critic targets
self.policy.averager["actor"].update_target()
self.policy.averager["critic"].update_target()
with KeepTime("targets"):
# Update actor/critic targets
self.policy.averager["actor"].update_target()
self.policy.averager["critic"].update_target()

## For debugging
# for p, ptar in zip(self.policy.model["actor"].parameters(), self.policy.model["actor_target"].parameters()):
# print(p.mean(), ptar.mean())
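For reference, the critic update in the hunk above forms the standard DDPG TD target, y = r + gamma * mask * Q'(s2, pi'(s2)), with the mask zeroing the bootstrap term at terminal transitions; the agent also now builds its sampler by composing the functions listed in params["sampler_list"] instead of importing the hard-coded sampler_re. Below is a minimal, self-contained sketch of the target computation only; the shapes, the gamma value, and the use of smooth_l1_loss (suggested by the comments, since the actual loss line sits outside the visible hunk) are assumptions.

import torch
import torch.nn.functional as F

batch_size = 8
gamma = 0.99  # stand-in for float(self.params["methodargs"]["gamma"])

# Stand-ins for the tensors built in the hunk above.
y_predicted = torch.randn(batch_size, 1)                     # critic(o1, a1)
next_val    = torch.randn(batch_size, 1).detach()            # critic_target(o2, actor_target(o2)), detached
r1          = torch.randn(batch_size, 1)                     # rewards
masks       = torch.randint(0, 2, (batch_size, 1)).float()   # 0 at terminal transitions, 1 otherwise

# y = r + gamma * mask * Q'(s2, pi'(s2)); keeping every tensor (batch, 1) avoids the
# squeeze/broadcast mismatches that the commit removes (note the dropped torch.squeeze calls).
y_target = r1 + masks * next_val * gamma
loss_critic = F.smooth_l1_loss(y_predicted, y_target)        # L2-like near zero, L1 elsewhere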
23 changes: 12 additions & 11 deletions digideep/agent/ddpg/policy.py
@@ -36,11 +36,12 @@ def __init__(self, device, **params):
state_size = self.params["obs_space"]["dim"][0]
action_size = self.params["act_space"]["dim"] if np.isscalar(self.params["act_space"]["dim"]) else self.params["act_space"]["dim"][0]
action_gain = self.params["act_space"]["lim"][1][0]
hidden_size = self.params['hidden_size']

self.model["actor"] = ActorModel(state_size=state_size, action_size=action_size, action_gain=action_gain, **self.params["actor_args"])
self.model["actor"] = ActorModel(state_size=state_size, action_size=action_size, action_gain=action_gain, hidden_size=hidden_size, **self.params["actor_args"])
self.model["actor_target"] = deepcopy(self.model["actor"])

self.model["critic"] = CriticModel(state_size=state_size, action_size=action_size, **self.params["critic_args"])
self.model["critic"] = CriticModel(state_size=state_size, action_size=action_size, hidden_size=hidden_size, **self.params["critic_args"])
self.model["critic_target"] = deepcopy(self.model["critic"])

self.averager = {}
@@ -99,21 +100,21 @@ def __init__(self, **params):
# init_ = init_easy()
# self.bn1 = nn.BatchNorm1d(num_features=self.params['state_size'])

self.fcs1 = nn.Linear(self.params['state_size'], 256)
self.fcs1 = nn.Linear(self.params['state_size'], self.params["hidden_size"])
self.fcs1.weight.data = fanin_init(self.fcs1.weight.data)

self.fcs2 = nn.Linear(256,128)
self.fcs2 = nn.Linear(self.params["hidden_size"], int(self.params["hidden_size"]/2))
self.fcs2.weight.data = fanin_init(self.fcs2.weight.data)

self.fca1 = nn.Linear(self.params['action_size'], 128)
self.fca1 = nn.Linear(self.params['action_size'], (self.params["hidden_size"]-int(self.params["hidden_size"]/2)))
self.fca1.weight.data = fanin_init(self.fca1.weight.data)

# self.fc2 = nn.Linear(256,256)
self.fc2 = nn.Linear(256,128)
self.fc2 = nn.Linear(self.params["hidden_size"],self.params["hidden_size"])
self.fc2.weight.data = fanin_init(self.fc2.weight.data)

# self.fc3 = nn.Linear(256, 1)
self.fc3 = nn.Linear(128, 1)
self.fc3 = nn.Linear(self.params["hidden_size"], 1)
self.fc3.weight.data.uniform_(-self.params['eps'], self.params['eps'])

def forward(self, state, action):
@@ -155,19 +156,19 @@ def __init__(self, **params):

# self.bn1 = nn.BatchNorm1d(num_features=self.params['state_size'])

self.fc1 = nn.Linear(self.params['state_size'], 256)
self.fc1 = nn.Linear(self.params['state_size'], self.params["hidden_size"])
self.fc1.weight.data = fanin_init(self.fc1.weight.data)

# self.fc2 = nn.Linear(256,256)
self.fc2 = nn.Linear(256,128)
self.fc2 = nn.Linear(self.params["hidden_size"],self.params["hidden_size"])
self.fc2.weight.data = fanin_init(self.fc2.weight.data)

# self.fc3 = nn.Linear(256,256)
self.fc3 = nn.Linear(128,64)
self.fc3 = nn.Linear(self.params["hidden_size"],self.params["hidden_size"])
self.fc3.weight.data = fanin_init(self.fc3.weight.data)

# self.fc4 = nn.Linear(256, self.params['action_size'])
self.fc4 = nn.Linear(64, self.params['action_size'])
self.fc4 = nn.Linear(self.params["hidden_size"], self.params['action_size'])
self.fc4.weight.data.uniform_(-self.params['eps'], self.params['eps'])
self.tanh = nn.Tanh()

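The rewritten CriticModel derives every layer width from a single hidden_size: the state branch narrows to hidden_size // 2 and the action branch to hidden_size - hidden_size // 2, so their concatenation is exactly the hidden_size that fc2 expects (matching the old 256/128/128 split when hidden_size is 256). A minimal sketch of that dimension bookkeeping, with assumed sizes and a simplified forward pass (not the repo's exact code, which also applies fanin_init and the eps-uniform output init):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoBranchCritic(nn.Module):
    def __init__(self, state_size=10, action_size=3, hidden_size=256):
        super().__init__()
        half = hidden_size // 2
        self.fcs1 = nn.Linear(state_size, hidden_size)
        self.fcs2 = nn.Linear(hidden_size, half)                # state branch -> half
        self.fca1 = nn.Linear(action_size, hidden_size - half)  # action branch -> hidden_size - half
        self.fc2  = nn.Linear(hidden_size, hidden_size)         # concat: half + (hidden_size - half)
        self.fc3  = nn.Linear(hidden_size, 1)

    def forward(self, state, action):
        s = F.relu(self.fcs2(F.relu(self.fcs1(state))))
        a = F.relu(self.fca1(action))
        x = torch.cat([s, a], dim=-1)                           # shape: (batch, hidden_size)
        return self.fc3(F.relu(self.fc2(x)))

q = TwoBranchCritic()(torch.randn(8, 10), torch.randn(8, 3))    # -> (8, 1)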
83 changes: 59 additions & 24 deletions digideep/agent/ddpg/sampler.py
@@ -7,7 +7,7 @@
from digideep.utility.logging import logger
from digideep.utility.profiling import KeepTime

def get_sample_memory(buffer, info):
def get_sample_memory(memory, info):
"""Sampler function for DDPG-like algorithms where we want to sample data from an experience replay buffer.
This function adds the following key to the buffer:
@@ -19,46 +19,81 @@ def get_sample_memory(buffer, info):
key in the output batch will be: ``(batch_size, *key_shape[2:])``
"""
# Get information from info
batch_size = info["batch_size"]
num_workers = info["num_workers"]
observation_path = info["observation_path"]
# Whether to use CER or not:
use_cer = info.get("use_cer", False)

# Get the main data from the memory
buffer = memory.get_buffer()

# Get some constants from the memory
num_workers = memory.get_num_batches()
N = memory.get_last_trans_index() - 1 # We don't want to consider the last "incomplete" record, hence "-1"

record_arr = memory.get_index_valid_elements()
worker_arr = np.arange(0, num_workers)

N = info["num_records"] - 1 # We don't want to consider the last "incomplete" record, hence "-1"
num_records = len(record_arr) * num_workers

with KeepTime("mask_array"):
masks_arr = buffer["/masks"][:,:N]
masks_arr = masks_arr.reshape(-1)
with KeepTime("total_array"):
total_arr = np.arange(0,num_workers*N)
## This is if we want to mask final states (mask equals 0 for final state, 1 otherwise.)
# valid_arr = total_arr[masks_arr.astype(bool)]
valid_arr = total_arr
# with KeepTime("mask_array"):
# masks_arr = buffer["/masks"][:,record_arr]
# masks_arr = masks_arr.reshape(-1)

if batch_size >= len(valid_arr):
if batch_size >= num_records:
# We don't have enough data in the memory yet.
logger.debug("batch_size ({}) should be smaller than total number of records (~ {}={}x{}).".format(batch_size, num_workers*N, num_workers, N))
logger.debug("batch_size ({}) should be smaller than total number of records (~ {}={}x{}).".format(batch_size, num_records, num_workers, len(record_arr)))
return None

with KeepTime("sampling_by_choice"):
# Sampling with replacement:
sample_indices = np.random.choice(valid_arr, batch_size, replace=True)
# NOTE: Never ever use sampling without replacement: its time scales up with the array size.
# sample_indices = np.random.choice(valid_arr, batch_size, replace=False)
if use_cer:
last_chunk_indices = memory.get_index_valid_last_chunk()
available_batch_size = len(last_chunk_indices) * num_workers
if available_batch_size <= batch_size:
# Take every transition from the most recent chunk (the previous step), ...
# ... then sample the rest of the batch from the replay buffer.
sample_record_recent = np.repeat(last_chunk_indices, num_workers) # 10 10 10 10 11 11 11 11 ...
sample_worker_recent = np.tile(worker_arr, len(last_chunk_indices)) # 0 1 2 3 0 1 2 3 ...

with KeepTime("tabular_index_extraction"):
sample_tabular = [[sample_indices // N], [sample_indices % N]]
sample_tabular_2 = [[sample_indices // N], [sample_indices % N + 1]]
batch_size_prime = batch_size - available_batch_size

# Select the rest ...
sample_record_prime = np.random.choice(record_arr, batch_size_prime, replace=True)
sample_worker_prime = np.random.choice(worker_arr, batch_size_prime, replace=True)

# Combine
sample_record = np.concatenate([sample_record_recent, sample_record_prime])
sample_worker = np.concatenate([sample_worker_recent, sample_worker_prime])

else:
# The latest chunk alone already covers the whole batch, so sample only from it.
logger.warn("CER: Latest transitions exceed the batch size. Sampling from the last transitions only.")

sample_record = np.random.choice(last_chunk_indices, batch_size, replace=True)
sample_worker = np.random.choice(worker_arr, batch_size, replace=True)

else:
# NOTE: NEVER ever use sampling WITHOUT replacement: its time scales up with the array size.
# Sampling with replacement:
sample_record = np.random.choice(record_arr, batch_size, replace=True)
sample_worker = np.random.choice(worker_arr, batch_size, replace=True)

# Move the next step samples
sample_record_2 = memory.get_index_move_n_steps(sample_record, 1)
# Make a table of indices to extract transitions
sample_tabular = [[sample_worker], [sample_record]]
sample_tabular_2 = [[sample_worker], [sample_record_2]]

with KeepTime("tabular_index_extraction"):
# Extracting the indices
batch = {}
for key in buffer:
batch[key] = buffer[key][sample_tabular[0],sample_tabular[1]]

with KeepTime("post_key_generation"):
observation_path = "/observations" + observation_path
batch["/obs_with_key"] = batch[observation_path]
# Adding predictive keys
batch["/obs_with_key_2"] = buffer[observation_path][sample_tabular_2[0],sample_tabular_2[1]]
batch[observation_path+"_2"] = buffer[observation_path][sample_tabular_2[0],sample_tabular_2[1]]

with KeepTime("flatten_first_two"):
batch = flatten_first_two(batch)
@@ -71,7 +106,7 @@ def get_sample_memory(buffer, info):

# Sampler with replay buffer
sampler_re = Compose([flatten_memory_to_train_key, # Must be present: It flattens the memory dict to the "train" key.
get_memory_params, # Must be present: It gets the memory parameters and passes them to the rest of functions through "info".
# get_memory_params, # Must be present: It gets the memory parameters and passes them to the rest of functions through "info".
get_sample_memory, # Sample
# check_shape, # It prints the shapes of the existing keys in the chunk.
# check_nan, # It complains about existing NaNs in the chunk.
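The new use_cer branch in get_sample_memory implements the Combined Experience Replay idea: every batch is forced to contain the transitions of the most recent chunk (across all workers), and only the remainder is drawn uniformly from the whole buffer; if the recent chunk alone already exceeds the batch size, sampling is restricted to it. A minimal index-level sketch of that selection logic, where the plain arrays below are illustrative stand-ins for memory.get_index_valid_elements() and memory.get_index_valid_last_chunk():

import numpy as np

batch_size  = 32
num_workers = 4
record_arr  = np.arange(0, 500)     # stand-in for memory.get_index_valid_elements()
last_chunk  = np.arange(495, 500)   # stand-in for memory.get_index_valid_last_chunk()
worker_arr  = np.arange(num_workers)

available = len(last_chunk) * num_workers
if available <= batch_size:
    # Always include every (record, worker) pair from the most recent chunk ...
    sample_record_recent = np.repeat(last_chunk, num_workers)    # 495 495 495 495 496 ...
    sample_worker_recent = np.tile(worker_arr, len(last_chunk))  # 0 1 2 3 0 1 2 3 ...
    # ... and fill the rest of the batch uniformly from the whole buffer (with replacement).
    rest = batch_size - available
    sample_record = np.concatenate([sample_record_recent, np.random.choice(record_arr, rest)])
    sample_worker = np.concatenate([sample_worker_recent, np.random.choice(worker_arr, rest)])
else:
    # The recent chunk alone exceeds the batch; sample only from it.
    sample_record = np.random.choice(last_chunk, batch_size)
    sample_worker = np.random.choice(worker_arr, batch_size)

assert len(sample_record) == len(sample_worker) == batch_size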
17 changes: 12 additions & 5 deletions digideep/agent/sac/agent.py
@@ -138,12 +138,18 @@ def step(self):

with KeepTime("to_torch"):
# ['/obs_with_key', '/masks', '/agents/agent/actions', '/agents/agent/hidden_state', '/rewards', '/obs_with_key_2', ...]
state = torch.from_numpy(batch["/obs_with_key"]).to(self.device)
action = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device)
reward = torch.from_numpy(batch["/rewards"]).to(self.device)
next_state = torch.from_numpy(batch["/obs_with_key_2"]).to(self.device)
# masks = torch.from_numpy(batch["/masks"]).to(self.device).view(-1)
state = torch.from_numpy(batch["/observations"+ self.params["observation_path"]]).to(self.device).float()
action = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device).float()
reward = torch.from_numpy(batch["/rewards"]).to(self.device).float()
next_state = torch.from_numpy(batch["/observations"+self.params["observation_path"]+"_2"]).to(self.device).float()
masks = torch.from_numpy(batch["/masks"]).to(self.device)

# state = torch.from_numpy(batch["/obs_with_key"]).to(self.device)
# action = torch.from_numpy(batch["/agents/"+self.params["name"]+"/actions"]).to(self.device)
# reward = torch.from_numpy(batch["/rewards"]).to(self.device)
# next_state = torch.from_numpy(batch["/obs_with_key_2"]).to(self.device)
# # masks = torch.from_numpy(batch["/masks"]).to(self.device).view(-1)
# masks = torch.from_numpy(batch["/masks"]).to(self.device)

with KeepTime("loss"):
expected_q_value = self.policy.model["softq"](state, action)
@@ -159,6 +165,7 @@
value_loss = self.criterion["value"](expected_value, next_value.detach())

log_prob_target = expected_new_q_value - expected_value
# TODO: Apparently the calculation of actor_loss is problematic: none of its ingredients have gradients, so backprop does nothing.
actor_loss = (log_prob * (log_prob - log_prob_target).detach()).mean()

mean_loss = float(self.params["methodargs"]["mean_lambda"]) * mean.pow(2).mean()
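Both the DDPG and SAC agents now read observations from the generic "/observations" + observation_path key, with a "_2" suffix for the next state added by the sampler, and cast everything to float32 on the way to the device. A minimal sketch of that conversion step; the batch contents, the observation_path value, and the tensor shapes are illustrative assumptions:

import numpy as np
import torch

device = torch.device("cpu")
observation_path = "/demogoal"  # hypothetical value of params["observation_path"]

# Hypothetical sampler output following the key convention used above.
batch = {
    "/observations" + observation_path:        np.random.randn(32, 10),
    "/observations" + observation_path + "_2": np.random.randn(32, 10),  # added by the sampler
    "/agents/agent/actions":                   np.random.randn(32, 3),
    "/rewards":                                np.random.randn(32, 1),
    "/masks":                                  np.ones((32, 1), dtype=np.float32),
}

state      = torch.from_numpy(batch["/observations" + observation_path]).to(device).float()
next_state = torch.from_numpy(batch["/observations" + observation_path + "_2"]).to(device).float()
action     = torch.from_numpy(batch["/agents/agent/actions"]).to(device).float()
reward     = torch.from_numpy(batch["/rewards"]).to(device).float()
masks      = torch.from_numpy(batch["/masks"]).to(device)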
