64 changes: 55 additions & 9 deletions src/algorithms/a2c.py
@@ -4,6 +4,8 @@
import torch.nn as nn
import torch.nn.functional as F

from src.utils.experience_buffer import unpack_batch

Env = TypeVar("Env")
Optimizer = TypeVar("Optimizer")
LossFunction = TypeVar("LossFunction")
@@ -53,6 +55,8 @@ def __init__(self):
self.n_iterations_per_episode: int = 100
self.optimizer: Optimizer = None
self.loss_function: LossFunction = None
self.batch_size: int = 0
self.device: str = 'cpu'


class A2C(Generic[Optimizer]):
@@ -63,15 +67,15 @@ def __init__(self, config: A2CConfig, a2c_net: A2CNet):
self.tau = config.tau
self.n_workers = config.n_workers
self.n_iterations_per_episode = config.n_iterations_per_episode
self.batch_size = config.batch_size
self.optimizer = config.optimizer
self.device = config.device
self.loss_function = config.loss_function
self.a2c_net = a2c_net
self.rewards = []
self.memory = []
self.name = "A2C"

def select_action(self, env: Env, observation: State) -> Action:
"""
Select an action
@@ -81,17 +85,43 @@ def select_action(self, env: Env, observation: State) -> Action:
"""
return env.sample_action()

def update_policy_network(self):
"""
Update the policy network
:return:
"""
pass

def calculate_loss(self):
"""
Calculate the loss
:return:
"""
pass

def accumulate_batch(self):
"""
Accumulate the memory items
:return:
"""
pass

def train(self, env: Env) -> None:
"""
Train the agent on the given environment
:param env:
:return:
"""

# reset the environment and obtain the initial time step
time_step: TimeStep = env.reset()

observation = time_step.observation

# the batch to process
batch = []

# learn over the episode
for iteration in range(1, self.n_iterations_per_episode + 1):

@@ -102,11 +132,27 @@ def train(self, env: Env) -> None:
# to the selected action
next_time_step = env.step(action=action)

batch.append(next_time_step.observation)

if len(batch) < self.batch_size:
continue

# unpack the batch in order to process it
states_v, actions_t, vals_ref = unpack_batch(batch=batch, net=self.a2c_net, device=self.device)
batch.clear()

# we reached the end of the episode; episode
# termination is currently ignored while a batch is accumulated
#if next_time_step.last():
#    break

policy_val, v_val = self.a2c_net.forward(x=states_v)

self.optimizer.zero_grad()

# calculate the loss
loss = self.calculate_loss()
loss.backward()
self.optimizer.step()

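`calculate_loss` is left as a stub in this change. For reference, here is a minimal sketch of the standard A2C objective it might eventually compute, assuming `self.a2c_net(states_v)` returns an `(action_logits, state_values)` pair and that `states_v`, `actions_t` and `vals_ref` come from `unpack_batch` as in the loop above; the helper name `a2c_loss` and the `ENTROPY_BETA` coefficient are illustrative, not part of this PR:

```python
# Illustrative sketch only; not part of this PR.
import torch
import torch.nn.functional as F

ENTROPY_BETA = 0.01  # assumed entropy-bonus coefficient


def a2c_loss(net, states_v, actions_t, vals_ref):
    logits_v, values_v = net(states_v)

    # critic term: regress V(s) onto the bootstrapped value targets
    value_loss = F.mse_loss(values_v.squeeze(-1), vals_ref)

    # actor term: policy gradient weighted by the advantage A(s, a);
    # the critic value is detached so it acts as a fixed baseline here
    log_probs_v = F.log_softmax(logits_v, dim=1)
    advantages_v = vals_ref - values_v.squeeze(-1).detach()
    chosen_log_probs = log_probs_v[range(len(actions_t)), actions_t]
    policy_loss = -(advantages_v * chosen_log_probs).mean()

    # entropy bonus keeps the policy from collapsing too early
    probs_v = F.softmax(logits_v, dim=1)
    entropy = -(probs_v * log_probs_v).sum(dim=1).mean()

    return policy_loss + value_loss - ENTROPY_BETA * entropy
```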
8 changes: 8 additions & 0 deletions src/utils/experience_buffer.py
@@ -0,0 +1,8 @@

from typing import TypeVar

Net = TypeVar('Net')
Batch = TypeVar('Batch')

def unpack_batch(batch: Batch, net: Net, device: str = 'cpu'):
"""
Unpack the given batch of experience into the
tensors needed by the A2C update
:param batch:
:param net:
:param device:
:return:
"""
pass
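
`unpack_batch` is also a stub. Below is a minimal sketch of one way to fill it in, assuming each batch item is an experience tuple with `state`, `action`, `reward`, `next_state` and `done` fields (the train loop currently appends bare observations, so this is an assumption about the intended buffer format), that `net` returns an `(action_logits, state_values)` pair, and an assumed discount factor `GAMMA`; it builds one-step bootstrapped value targets:

```python
# Illustrative sketch only; field names and GAMMA are assumptions.
import numpy as np
import torch

GAMMA = 0.99  # assumed discount factor


def unpack_batch(batch, net, device: str = 'cpu'):
    states_v = torch.FloatTensor(np.array([e.state for e in batch])).to(device)
    actions_t = torch.LongTensor([e.action for e in batch]).to(device)

    # one-step TD targets: r + gamma * V(s'), with V(s') zeroed at episode end
    rewards = np.array([e.reward for e in batch], dtype=np.float32)
    next_states_v = torch.FloatTensor(np.array([e.next_state for e in batch])).to(device)
    with torch.no_grad():
        _, next_vals_v = net(next_states_v)
    next_vals = next_vals_v.squeeze(-1).cpu().numpy()
    next_vals[np.array([e.done for e in batch], dtype=bool)] = 0.0

    vals_ref = torch.FloatTensor(rewards + GAMMA * next_vals).to(device)
    return states_v, actions_t, vals_ref
```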
2 changes: 1 addition & 1 deletion src/utils/serial_hierarchy.py
@@ -43,7 +43,7 @@ class SerialHierarchy(HierarchyBase):
that are applied one after the other. Applications should explicitly
provide the list of the ensuing transformations. For example, assume that
the data field has the value 'foo'; the ensuing transformations could then
be given as the list ['fo*', 'f**', '***']
"""
def __init__(self, values: List) -> None:
"""
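For illustration, a hypothetical construction sketch matching the docstring example; each transformation in the list progressively masks the value 'foo':

```python
from src.utils.serial_hierarchy import SerialHierarchy

# successive anonymization steps: 'foo' -> 'fo*' -> 'f**' -> '***'
hierarchy = SerialHierarchy(values=['fo*', 'f**', '***'])
```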