
Commit d858b10: items optimiser load
cnheider committed Aug 6, 2020
1 parent 220f3bd commit d858b10
Showing 4 changed files with 25 additions and 20 deletions.
@@ -222,16 +222,18 @@ def _remember(self, *, signal, terminated, transition):
         @param terminated:
         @return:
         """
-
-        a = [TransitionPoint(*s) for s in zip(*transition, signal, terminated)]
-        if self._use_per:
-            with torch.no_grad():
-                td_error, *_ = self._td_error(zip(*a))
-            for a_, e_ in zip(a, td_error.detach().squeeze(-1).cpu().numpy()):
-                self._memory_buffer.add_transition_point(a_, e_)
+        if transition:
+            a = [TransitionPoint(*s) for s in zip(*transition, signal, terminated)]
+            if self._use_per:
+                with torch.no_grad():
+                    td_error, *_ = self._td_error(zip(*a))
+                for a_, e_ in zip(a, td_error.detach().squeeze(-1).cpu().numpy()):
+                    self._memory_buffer.add_transition_point(a_, e_)
+            else:
+                for a_ in a:
+                    self._memory_buffer.add_transition_point(a_)
         else:
-            for a_ in a:
-                self._memory_buffer.add_transition_point(a_)
+            raise ValueError('Missing transition')

     @drop_unused_kws
     def _sample_model(self, state: Any) -> numpy.ndarray:
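
For reference, a minimal runnable sketch of what the reworked guard and zip construction in _remember do. The namedtuple definitions, field order, and sample values below are illustrative assumptions, not the library's own types:

from collections import namedtuple

# Assumed, simplified field layouts for illustration only.
Transition = namedtuple("Transition", ("state", "action", "successor_state"))
TransitionPoint = namedtuple(
    "TransitionPoint", ("state", "action", "successor_state", "signal", "terminal")
)

transition = Transition(state=(0, 1), action=(1, 0), successor_state=(1, 2))
signal = (1.0, 0.0)
terminated = (False, True)

if transition:  # a missing (falsy) transition now raises instead of silently storing nothing
    # zip(*transition, signal, terminated) pairs the i-th entry of every field,
    # yielding one (state, action, successor_state, signal, terminal) tuple per step.
    points = [TransitionPoint(*s) for s in zip(*transition, signal, terminated)]
else:
    raise ValueError("Missing transition")

assert points[0] == TransitionPoint(0, 1, 1, 1.0, False)
assert points[1] == TransitionPoint(1, 0, 2, 0.0, True)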
neodroidagent/agents/torch_agents/torch_agent.py (8 changes: 5 additions & 3 deletions)
@@ -71,6 +71,7 @@ def post_process_gradients(self, parameters: Iterable[Parameter]) -> None:
         @param model:
         @return:
+        :param parameters:
         """
         if self._gradient_clipping.enabled:
             for params in parameters:
@@ -113,6 +114,7 @@ def build(
         @param print_model_repr:
         @param kwargs:
         @return:
+        :param verbose:
         """
         super().build(
             observation_space,
@@ -204,9 +206,9 @@ def load(self, *, save_directory: Path, evaluation: bool = False) -> bool:
         """
         loaded = True
         if save_directory.exists():
-            print("Loading models froms: " + str(save_directory))
+            print(f"Loading models from: {str(save_directory)}")
             for (model_key, model), (optimiser_key, optimiser) in zip(
-                self.models.items(), self.optimisers.values()
+                self.models.items(), self.optimisers.items()
             ):
                 model_identifier = self.model_name(model_key, model)
                 (model, optimiser), loaded = load_latest_model_parameters(
@@ -217,7 +219,7 @@ def load(self, *, save_directory: Path, evaluation: bool = False) -> bool:
                 )
                 if loaded:
                     model = model.to(self._device)
-                    optimiser = optimiser.to(self._device)
+                    #optimiser = optimiser.to(self._device)
                     if evaluation:
                         model = model.eval()
                         model.train(False) # Redundant
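
A small sketch of the two load() fixes above: zipping .items() on both mappings restores the ((key, model), (key, optimiser)) pairs the unpacking expects, and the optimiser.to(...) call is commented out because torch.optim optimisers expose no .to() method. The dictionaries and the helper below are illustrative assumptions, not code from the repository:

import torch

# Toy stand-ins for self.models / self.optimisers; names and sizes are made up.
models = {"policy": torch.nn.Linear(4, 2), "target": torch.nn.Linear(4, 2)}
optimisers = {
    "policy": torch.optim.Adam(models["policy"].parameters()),
    "target": torch.optim.Adam(models["target"].parameters()),
}

# Zipping .items() of both dicts yields ((model_key, model), (optimiser_key, optimiser))
# pairs; zipping .values() on the second dict would hand back bare optimiser objects
# and break the two-tuple unpacking used in load().
for (model_key, model), (optimiser_key, optimiser) in zip(
    models.items(), optimisers.items()
):
    assert model_key == optimiser_key  # relies on both dicts sharing insertion order


def optimiser_state_to(optimiser: torch.optim.Optimizer, device: torch.device) -> None:
    # Hypothetical helper, not in the repository: optimisers have no .to(), so if
    # their state must follow the model, move the state tensors one by one.
    for state in optimiser.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)

Note that pairing by position in the zip only lines up model and optimiser keys when both dictionaries were populated in the same order.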
@@ -91,13 +91,14 @@ def rollout_off_policy(

        state = successor_state

-    if use_episodic_buffer:
-        t = TransitionPoint(*zip(*episode_buffer))
-        agent.remember(
-            signal=t.signal,
-            terminated=t.terminal,
-            transition=Transition(t.state, t.action, t.successor_state),
-        )
+    if train_agent:
+        if use_episodic_buffer:
+            t = TransitionPoint(*zip(*episode_buffer))
+            agent.remember(
+                signal=t.signal,
+                terminated=t.terminal,
+                transition=Transition(t.state, t.action, t.successor_state),
+            )

    if step_i > 0:
        if train_agent:
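
As a side note on the flush above, which now only runs when train_agent is set, a tiny sketch (field layout and values assumed for illustration) of what TransitionPoint(*zip(*episode_buffer)) hands to agent.remember:

from collections import namedtuple

# Assumed simplified layout of the episode buffer entries; the field order is a guess.
TransitionPoint = namedtuple(
    "TransitionPoint", ("state", "action", "successor_state", "signal", "terminal")
)

episode_buffer = [
    TransitionPoint(0, 1, 1, 1.0, False),
    TransitionPoint(1, 0, 2, 0.0, True),
]

# zip(*episode_buffer) transposes the per-step tuples into per-field tuples, so the
# resulting TransitionPoint holds whole-episode sequences ready for batched remembering.
t = TransitionPoint(*zip(*episode_buffer))
assert t.signal == (1.0, 0.0)
assert t.terminal == (False, True)
assert t.state == (0, 1)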
neodroidagent/entry_points/cli.py (2 changes: 1 addition & 1 deletion)
@@ -49,7 +49,7 @@ def train(self, **overrides) -> None:
         self.agent_callable(config=default_config)

     def run(self):
-        pass
+        self.train(train_agent=False,render_frequency=1,save=False)


 class NeodroidAgentCLI:
