Refactoring agent structure to be more method oriented.
sharif1093 committed Aug 20, 2019
1 parent 11c8d99 commit a11b9a5
Showing 27 changed files with 319 additions and 318 deletions.
3 changes: 0 additions & 3 deletions digideep/agent/__init__.py
@@ -1,3 +0,0 @@
-from .ppo import PPO
-from .ddpg import DDPG
-from .sac import SAC
File renamed without changes.
@@ -1 +1,2 @@
from .agent import Agent
+from .policy import Policy
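
Taken together, the two hunks above capture the point of this refactor: the flat imports in digideep/agent/__init__.py are gone, and each method now exposes its classes from its own package. A sketch of how downstream imports change (the alias names here are illustrative, not from the repository):

    # Before this commit (flat, per-class layout):
    # from digideep.agent import DDPG, PPO, SAC

    # After this commit (method-oriented packages):
    from digideep.agent.ddpg import Agent as DDPGAgent, Policy as DDPGPolicy
    from digideep.agent.ppo import Agent as PPOAgent

Since digideep resolves classes by dotted path (see the get_class imports below), parameter files that previously pointed at digideep.agent.DDPG would presumably now point at digideep.agent.ddpg.Agent.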
14 changes: 6 additions & 8 deletions digideep/agent/ddpg.py → digideep/agent/ddpg/agent.py
@@ -7,21 +7,19 @@
import torch.nn.functional as F
import torch.utils.data

-from digideep.agent.samplers.ddpg import sampler_re
-# from digideep.agent.samplers.common import check_shape

from digideep.utility.toolbox import get_class
from digideep.utility.logging import logger
from digideep.utility.profiling import KeepTime
from digideep.utility.monitoring import monitor


-from .base import AgentBase
-from digideep.agent.policy.deterministic import Policy
+# from digideep.agent.sampler_common import check_shape
+from digideep.agent.agent_base import AgentBase
+from .sampler import sampler_re
+from .policy import Policy

# torch.utils.backcompat.broadcast_warning.enabled = True

-class DDPG(AgentBase):
+class Agent(AgentBase):
"""This is an implementation of the Deep Deterministic Policy Gradient (`DDPG <https://arxiv.org/abs/1509.02971>`_) method.
Args:
Expand All @@ -46,7 +44,7 @@ class DDPG(AgentBase):
"""

    def __init__(self, session, memory, **params):
-        super(DDPG, self).__init__(session, memory, **params)
+        super(Agent, self).__init__(session, memory, **params)

        self.device = self.session.get_device()

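
The relocated agent still implements standard DDPG: a deterministic actor trained by ascending a learned Q-critic, while the critic is regressed toward a bootstrapped Bellman target. A minimal PyTorch sketch of that update, with hypothetical actor/critic modules and batch layout; it is not digideep's exact code:

    import torch
    import torch.nn.functional as F

    def ddpg_update(actor, critic, target_actor, target_critic,
                    actor_opt, critic_opt, batch, gamma=0.99):
        # Batch fields are assumed: observations, actions, rewards,
        # next observations, and masks (0 at terminal steps).
        o, a, r, o2, mask = batch

        # Critic: regress Q(o, a) toward r + gamma * Q'(o2, mu'(o2)).
        with torch.no_grad():
            q_target = r + gamma * mask * target_critic(o2, target_actor(o2))
        critic_loss = F.mse_loss(critic(o, a), q_target)
        critic_opt.zero_grad()
        critic_loss.backward()
        critic_opt.step()

        # Actor: maximize Q(o, mu(o)) by minimizing its negation.
        actor_loss = -critic(o, actor(o)).mean()
        actor_opt.zero_grad()
        actor_loss.backward()
        actor_opt.step()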
@@ -6,8 +6,8 @@
from digideep.utility.toolbox import get_class
from digideep.utility.logging import logger

-from digideep.agent.policy.base import PolicyBase
-from digideep.agent.policy.common import Averager
+from digideep.agent.policy_base import PolicyBase
+from digideep.agent.policy_common import Averager

from copy import deepcopy

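
The Averager imported above (now from digideep.agent.policy_common) is the usual DDPG machinery for keeping target networks as slow-moving copies of the live ones. A sketch of Polyak averaging under that assumption; the real class's interface may differ:

    import torch

    class PolyakAverager:
        # Illustrative soft-update helper; not necessarily the
        # interface of digideep's actual Averager class.
        def __init__(self, source, target, tau=0.001):
            self.source, self.target, self.tau = source, target, tau

        @torch.no_grad()
        def update(self):
            # target <- (1 - tau) * target + tau * source, parameter-wise.
            for p_t, p_s in zip(self.target.parameters(), self.source.parameters()):
                p_t.mul_(1.0 - self.tau).add_(self.tau * p_s)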
@@ -1,17 +1,16 @@
import numpy as np
import warnings

-from .common import Compose, flatten_memory_to_train_key, get_memory_params, check_nan, check_shape, check_stats, print_line
-from .common import flatten_first_two
-from digideep.utility.logging import logger
+from digideep.agent.sampler_common import Compose, flatten_memory_to_train_key, get_memory_params, check_nan, check_shape, check_stats, print_line
+from digideep.agent.sampler_common import flatten_first_two

+from digideep.utility.logging import logger
from digideep.utility.profiling import KeepTime

def get_sample_memory(buffer, info):
"""Sampler function for DDPG-like algorithms where we want to sample data from an experience replay buffer.
This function does not sample from final steps where mask equals ``0`` (as they don't have subsequent observations.)
This function adds the following key to the memory:
This function adds the following key to the buffer:
* ``/observations_2``
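
A sketch of the behavior that docstring describes: sample uniformly from the replay buffer, skip terminal steps (mask equal to 0, which have no successor observation), and attach the next observation under /observations_2. The key names and (workers, steps, ...) buffer layout here are assumptions inferred from the docstring, not the function's verbatim body:

    import numpy as np

    def sample_transitions(buffer, batch_size, rng=np.random):
        # Flatten assumed (workers, steps, ...) arrays into (workers*steps, ...).
        flat = {k: v.reshape(-1, *v.shape[2:]) for k, v in buffer.items()}
        masks = flat["/masks"].reshape(-1)
        # Keep only non-terminal steps that have a successor observation.
        # (A real implementation would also avoid crossing trajectory boundaries.)
        valid = np.flatnonzero(masks[:-1] != 0)
        idx = rng.choice(valid, size=batch_size)
        batch = {k: v[idx] for k, v in flat.items()}
        batch["/observations_2"] = flat["/observations"][idx + 1]
        return batch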
Empty file removed digideep/agent/policy/__init__.py
117 changes: 0 additions & 117 deletions digideep/agent/policy/stochastic/blocks.py

This file was deleted.

28 changes: 0 additions & 28 deletions digideep/agent/policy/stochastic/common.py

This file was deleted.

122 changes: 0 additions & 122 deletions digideep/agent/policy/stochastic/distributions.py

This file was deleted.

File renamed without changes.
File renamed without changes.
@@ -1 +1,2 @@
from .agent import Agent
+from .policy import Policy
8 changes: 4 additions & 4 deletions digideep/agent/ppo.py → digideep/agent/ppo/agent.py
@@ -7,15 +7,15 @@
import torch.nn.functional as F
import torch.utils.data

-from digideep.agent.samplers.ppo import sampler_ff, sampler_rn
+from .sampler import sampler_ff, sampler_rn
# from digideep.agent.samplers.common import check_shape

from digideep.utility.toolbox import get_class
from digideep.utility.logging import logger
from digideep.utility.profiling import KeepTime
from digideep.utility.monitoring import monitor

-from .base import AgentBase
+from digideep.agent.agent_base import AgentBase


# def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
@@ -25,7 +25,7 @@
# param_group['lr'] = lr


-class PPO(AgentBase):
+class Agent(AgentBase):
"""The implementation of the Proximal Policy Optimization (`PPO <https://arxiv.org/abs/1707.06347>`_) method.
Args:
Expand All @@ -50,7 +50,7 @@ class PPO(AgentBase):
"""
    def __init__(self, session, memory, **params):
-        super(PPO, self).__init__(session, memory, **params)
+        super(Agent, self).__init__(session, memory, **params)

        self.device = self.session.get_device()

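
As with DDPG, the PPO move is purely structural; the algorithm beneath it remains the clipped surrogate objective from the paper linked in the docstring, which bounds each policy update by clipping the importance ratio. A compact sketch assuming precomputed advantages and old log-probabilities (not digideep's exact implementation):

    import torch

    def ppo_clipped_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
        # ratio = pi_new(a|s) / pi_old(a|s), computed in log space for stability.
        ratio = torch.exp(log_probs - old_log_probs)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
        # Take the pessimistic (min) surrogate and negate it for gradient descent.
        return -torch.min(unclipped, clipped).mean()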
