Types update
cnheider committed Jun 6, 2020
1 parent f3bd89f commit 940b5ec
Showing 48 changed files with 708 additions and 295 deletions.
53 changes: 45 additions & 8 deletions neodroidagent/agents/agent.py
@@ -14,6 +14,7 @@
ObservationSpace,
SignalSpace,
)
+from neodroidagent.utilities import IntrinsicSignalProvider

__author__ = "Christian Heider Nielsen"
__doc__ = r"""
@@ -22,7 +23,7 @@

__all__ = ["Agent"]

-ClipFeature = namedtuple("ClipFeature", ("enabled", "low", "high"))
+TogglableLowHigh = namedtuple("ClipFeature", ("enabled", "low", "high"))


class Agent(ABC):
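For orientation: the rename keeps the tuple's runtime tag ("ClipFeature") but names the binding for what it encodes, an enable flag plus bounds. A minimal sketch of how such a triple gates clipping (illustrative only, not code from this commit):

import numpy
from collections import namedtuple

TogglableLowHigh = namedtuple("ClipFeature", ("enabled", "low", "high"))

action_clipping = TogglableLowHigh(True, -1.0, 1.0)
action = numpy.array([0.3, -2.5, 1.7])
if action_clipping.enabled:  # only clip when the feature is switched on
    action = numpy.clip(action, action_clipping.low, action_clipping.high)
# action is now [0.3, -1.0, 1.0]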
@@ -37,8 +38,9 @@ def __init__(
input_shape: Sequence = None,
output_shape: Sequence = None,
divide_by_zero_safety: float = 1e-6,
-action_clipping: ClipFeature = ClipFeature(False, -1.0, 1.0),
-signal_clipping: ClipFeature = ClipFeature(False, -1.0, 1.0),
+action_clipping: TogglableLowHigh = TogglableLowHigh(False, -1.0, 1.0),
+signal_clipping: TogglableLowHigh = TogglableLowHigh(False, -1.0, 1.0),
+intrinsic_signal_provider_arch: IntrinsicSignalProvider = None,
**kwargs,
):
self._sample_i = 0
@@ -54,8 +56,9 @@
self._action_clipping = action_clipping
self._signal_clipping = signal_clipping

+self._intrinsic_signal_provider_arch = intrinsic_signal_provider_arch

self._divide_by_zero_safety = divide_by_zero_safety
-self._intrinsic_signal = lambda *a: 0  # TODO: ICM

self.__set_protected_attr(**kwargs)

@@ -131,6 +134,34 @@ def __infer_io_shapes(
highlight=True,
)

+def __build_intrinsic_module(
+self,
+observation_space: ObservationSpace,
+action_space: ActionSpace,
+signal_space: SignalSpace,
+**kwargs,
+):
+"""
+@param observation_space:
+@type observation_space:
+@param action_space:
+@type action_space:
+@param signal_space:
+@type signal_space:
+@param kwargs:
+@type kwargs:
+"""
+if self._intrinsic_signal_provider_arch is None:
+self._intrinsic_signal_provider = lambda *a: 0
+else:
+self._intrinsic_signal_provider = self._intrinsic_signal_provider_arch(
+observation_space=observation_space,
+action_space=action_space,
+signal_space=signal_space,
+**kwargs,
+)

# endregion

# region Public
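The method added above falls back to a zero lambda when no provider architecture is given and otherwise instantiates it with the three spaces. A hypothetical provider matching that constructor and the snapshot call site in extract_signal further down; the real IntrinsicSignalProvider base class in neodroidagent.utilities may impose a different interface:

class ConstantCuriosityBonus:
    """Placeholder provider: a fixed bonus per snapshot, standing in for e.g. an ICM."""

    def __init__(self, *, observation_space, action_space, signal_space, bonus=0.01, **kwargs):
        self._bonus = bonus  # the spaces are accepted but unused in this sketch

    def __call__(self, snapshot):
        return self._bonus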
@@ -155,6 +186,12 @@ def build(
self.action_space = action_space
self.signal_space = signal_space
self.__infer_io_shapes(observation_space, action_space, signal_space)
+self.__build_intrinsic_module(
+observation_space=observation_space,
+action_space=action_space,
+signal_space=signal_space,
+**kwargs,
+)
self.__build__(
observation_space=observation_space,
action_space=action_space,
@@ -245,19 +282,19 @@ def extract_action(self, sample: Any) -> numpy.ndarray:
"""
return numpy.array(sample)

-def extract_signal(self, snapshot: EnvironmentSnapshot, **kwargs) -> numpy.ndarray:
+def extract_signal(self, snapshot: EnvironmentSnapshot) -> numpy.ndarray:
"""
Allows for modulation of signal based on for example an Intrinsic Curiosity signal
-@param signal:
+@param snapshot:
+@type snapshot:
@param kwargs:
@return:
"""

signal_out = numpy.array(snapshot.signal)

-if self._intrinsic_signal:
-signal_out += self._intrinsic_signal()
+signal_out += self._intrinsic_signal_provider(snapshot)

return signal_out
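A small check of why dropping the old `if self._intrinsic_signal:` guard is safe: a lambda is always truthy, so the branch always ran, and the default provider's 0 is a no-op under addition:

import numpy

default_provider = lambda *a: 0  # the fallback installed by __build_intrinsic_module
extrinsic = numpy.array([1.0, -0.5])
assert (extrinsic + default_provider(None) == extrinsic).all()  # unchanged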

@@ -5,7 +5,7 @@
from neodroid.environments.droid_environment import SingleUnityEnvironment
from neodroid.utilities import Displayable
from neodroidagent.agents.numpy_agents.numpy_agent import NumpyAgent
-from neodroidagent.utilities.exploration.ucb1 import UCB1
+from neodroidagent.utilities.exploration.sampling.ucb1 import UCB1

__author__ = "Christian Heider Nielsen"
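For context on the relocated import: UCB1 selects the arm maximising the empirical mean plus the exploration bound sqrt(2 ln t / n). A generic sketch of that rule, not the repository's UCB1 API:

import math

def ucb1_select(counts, values, t):
    """Return the arm index maximising mean reward + sqrt(2 ln t / n)."""
    for arm, n in enumerate(counts):
        if n == 0:  # play each arm once before applying the bound
            return arm
    scores = [
        total / n + math.sqrt(2.0 * math.log(t) / n)
        for total, n in zip(values, counts)
    ]
    return max(range(len(scores)), key=scores.__getitem__)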

4 changes: 2 additions & 2 deletions neodroidagent/agents/random_agent.py
@@ -21,7 +21,7 @@ def _sample(
*args,
deterministic: bool = False,
metric_writer: Writer = MockWriter(),
-**kwargs,
+**kwargs
) -> Any:
"""
@@ -46,7 +46,7 @@ def __build__(
observation_space: ObservationSpace = None,
action_space: ActionSpace = None,
signal_space: SignalSpace = None,
-**kwargs,
+**kwargs
) -> None:
"""
@@ -46,7 +46,7 @@ class DQNAgent(TorchAgent):
def __init__(
self,
value_arch_spec: Architecture = GDKC(DuelingQMLP),
-exploration_spec: GDKC = ExplorationSpecification(
+exploration_spec: ExplorationSpecification = ExplorationSpecification(
start=0.95, end=0.05, decay=3000
),
memory_buffer: Memory = TransitionPointPrioritisedBuffer(int(1e5)),
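The (start=0.95, end=0.05, decay=3000) specification suggests an annealed epsilon-greedy schedule. One common reading, an assumption here since ExplorationSpecification may anneal differently, is exponential decay:

import math

def epsilon(step: int, start: float = 0.95, end: float = 0.05, decay: float = 3000) -> float:
    """Exponentially annealed exploration rate."""
    return end + (start - end) * math.exp(-step / decay)

# epsilon(0) ~= 0.95, epsilon(3000) ~= 0.38, epsilon(30000) ~= 0.05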
@@ -3,27 +3,27 @@


import copy
-import itertools
-from typing import Any, Dict, Sequence, Tuple

+import itertools
import numpy
import torch
import torch.nn as nn
from torch.nn.functional import mse_loss
from tqdm import tqdm
+from typing import Any, Dict, Sequence, Tuple

-from draugr.writers import MockWriter, Writer
from draugr.torch_utilities import freeze_model, frozen_parameters, to_tensor
+from draugr.writers import MockWriter, Writer
from neodroid.utilities import ActionSpace, ObservationSpace, SignalSpace
from neodroidagent.agents.torch_agents.torch_agent import TorchAgent
from neodroidagent.common import (
Architecture,
ConcatInputMLP,
+Memory,
SamplePoint,
ShallowStdNormalMLP,
TransitionPoint,
TransitionPointBuffer,
-Memory,
)
from neodroidagent.utilities import (
ActionSpaceNotSupported,
@@ -71,7 +71,7 @@ def __init__(
),
critic_arch_spec: GDKC = GDKC(ConcatInputMLP),
critic_criterion: callable = mse_loss,
-**kwargs,
+**kwargs
):
"""
@@ -110,7 +110,15 @@ def __init__(
self.inner_update_i = 0

@drop_unused_kws
-def _remember(self, *, signal, terminated, state, successor_state, sample):
+def _remember(
+self,
+*,
+signal: Any,
+terminated: Any,
+state: Any,
+successor_state: Any,
+sample: Any
+) -> None:
"""
@param signal:
@@ -146,8 +154,8 @@ def _sample(
state: Any,
*args,
deterministic: bool = False,
-metric_writer: Writer = MockWriter(),
-) -> Tuple[Sequence, Any]:
+metric_writer: Writer = MockWriter()
+) -> Tuple[torch.Tensor, Any]:
"""
@param state:
@@ -177,7 +185,7 @@ def __build__(
action_space: ActionSpace,
signal_space: SignalSpace,
metric_writer: Writer = MockWriter(),
-print_model_repr=True,
+print_model_repr: bool = True,
) -> None:
"""
@@ -186,12 +194,6 @@
@param signal_space:
@param metric_writer:
@param print_model_repr:
-@param critic_1:
-@param critic_1_optimizer:
-@param critic_2:
-@param critic_2_optimizer:
-@param actor:
-@param actor_optimiser:
@return:
"""
if action_space.is_discrete:
@@ -301,7 +303,9 @@ def update_critics(

return out_loss

-def update_actor(self, tensorised, metric_writer: Writer = None) -> float:
+def update_actor(
+self, tensorised: torch.Tensor, metric_writer: Writer = None
+) -> float:
"""
@param tensorised:
@@ -344,9 +348,13 @@ def update_actor(self, tensorised: torch.Tensor, metric_writer: Writer = None) -> float:

return out_loss

-def update_alpha(self, log_prob, metric_writer: Writer = None) -> float:
+def update_alpha(
+self, log_prob: torch.Tensor, metric_writer: Writer = None
+) -> float:
"""
+@param log_prob:
+@type log_prob:
@param tensorised:
@param metric_writer:
@return:
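update_alpha tunes the SAC entropy temperature. Its exact loss lies outside this hunk; a sketch of the standard formulation, assuming a learned log-alpha and a fixed entropy target:

import torch

log_alpha = torch.zeros(1, requires_grad=True)
target_entropy = -4.0  # commonly -dim(action_space)
alpha_optimiser = torch.optim.Adam([log_alpha], lr=3e-4)

def update_alpha_sketch(log_prob: torch.Tensor) -> float:
    """Standard SAC temperature loss: -(log_alpha * (log_pi + H_target).detach()).mean()."""
    loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
    alpha_optimiser.zero_grad()
    loss.backward()
    alpha_optimiser.step()
    return loss.item()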
@@ -405,7 +413,7 @@ def _update(self, *args, metric_writer: Writer = MockWriter(), **kwargs) -> float:

if metric_writer:
metric_writer.scalar("Accum_loss", accum_loss)
metric_writer.scalar("_num_inner_updates", i)
metric_writer.scalar("num_inner_updates_i", i)

return accum_loss

@@ -422,6 +430,8 @@ def update_targets(
where \rho is polyak. (Always between 0 and 1, usually close to 1.)
@param metric_writer:
@type metric_writer:
+@param copy_percentage:
@return:
"""
@@ -72,7 +72,7 @@ def __init__(
copy_percentage: float = 0.005,
actor_optimiser_spec: GDKC = GDKC(constructor=torch.optim.Adam, lr=1e-4),
critic_optimiser_spec: GDKC = GDKC(constructor=torch.optim.Adam, lr=1e-2),
-**kwargs,
+**kwargs
):
"""
@@ -168,7 +168,7 @@ def models(self) -> Dict[str, Architecture]:
return {"_actor": self._actor, "_critic": self._critic}

def update_targets(
-self, update_percentage, *, metric_writer: Writer = None
+self, update_percentage: float, *, metric_writer: Writer = None
) -> None:
"""
@@ -19,6 +19,7 @@
from neodroidagent.agents.torch_agents.torch_agent import TorchAgent
from neodroidagent.common import (
CategoricalMLP,
+Memory,
MultiDimensionalNormalMLP,
SamplePoint,
SampleTrajectoryBuffer,
@@ -46,12 +47,14 @@ class PGAgent(TorchAgent):

def __init__(
self,
-evaluation_function=torch.nn.CrossEntropyLoss(),
-policy_arch_spec=GDKC(CategoricalMLP),
-discount_factor=0.95,
-optimiser_spec=GDKC(torch.optim.Adam, lr=1e-4),
-scheduler_spec=GDKC(torch.optim.lr_scheduler.StepLR, step_size=100, gamma=0.65),
-memory_buffer=SampleTrajectoryBuffer(),
+evaluation_function: callable = torch.nn.CrossEntropyLoss(),
+policy_arch_spec: GDKC = GDKC(CategoricalMLP),
+discount_factor: float = 0.95,
+optimiser_spec: GDKC = GDKC(torch.optim.Adam, lr=1e-4),
+scheduler_spec: GDKC = GDKC(
+torch.optim.lr_scheduler.StepLR, step_size=100, gamma=0.65
+),
+memory_buffer: Memory = SampleTrajectoryBuffer(),
**kwargs,
) -> None:
r"""
@@ -5,6 +5,7 @@

import numpy
import torch
+from torch.distributions import Distribution
from torch.nn.functional import mse_loss
from tqdm import tqdm

@@ -13,7 +14,7 @@

from draugr import mean_accumulator, shuffled_batches
from neodroid.utilities import ActionSpace, ObservationSpace, SignalSpace
-from neodroidagent.agents.agent import ClipFeature
+from neodroidagent.agents.agent import TogglableLowHigh
from neodroidagent.agents.torch_agents.torch_agent import TorchAgent
from neodroidagent.common import (
ActorCriticMLP,
@@ -23,7 +24,6 @@
)
from neodroidagent.utilities import (
ActionSpaceNotSupported,
-Distribution,
is_none_or_zero_or_negative_or_mod_zero,
torch_compute_gae,
update_target,
@@ -65,8 +65,8 @@ def __init__(
optimiser_spec: GDKC = GDKC(constructor=torch.optim.Adam, lr=3e-4),
continuous_arch_spec: GDKC = GDKC(ActorCriticMLP),
discrete_arch_spec: GDKC = GDKC(CategoricalActorCriticMLP),
-gradient_norm_clipping: ClipFeature = ClipFeature(True, 0, 0.5),
-**kwargs,
+gradient_norm_clipping: TogglableLowHigh = TogglableLowHigh(True, 0, 0.5),
+**kwargs
) -> None:
"""
@@ -204,7 +204,7 @@ def _remember(
terminated: Any,
state: Any,
successor_state: Any,
-sample: Any,
+sample: Any
) -> None:
self._memory_buffer.add_transition_point(
ValuedTransitionPoint(
@@ -330,7 +330,7 @@ def _policy_loss(
log_prob_batch_old,
adv_batch,
*,
-metric_writer: Writer = None,
+metric_writer: Writer = None
):
action_log_probs_new = self.get_log_prob(new_distribution, action_batch)
ratio = torch.exp(action_log_probs_new - log_prob_batch_old)
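The ratio computed here is the importance weight of PPO's clipped surrogate. A sketch of the objective it typically feeds, with an assumed epsilon of 0.2; the agent's actual surrogate lies outside this hunk:

import torch

def clipped_surrogate_loss(ratio: torch.Tensor, advantage: torch.Tensor, eps: float = 0.2) -> torch.Tensor:
    """PPO-Clip: minimise the negative of min(r * A, clamp(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantage
    return -torch.min(unclipped, clipped).mean()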