Skip to content

Commit

Permalink
Refactor tfmdp.model.cell.basic_cell to use tfmdp.model.utils module …
Browse files Browse the repository at this point in the history
…functions
  • Loading branch information
thiagopbueno committed Apr 16, 2019
1 parent 20106cb commit 72b25e3
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions tfmdp/model/cell/basic_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import rddl2tf

from tfmdp.model import utils
from tfmdp.policy.drp import DeepReactivePolicy

import collections
Expand All @@ -39,13 +40,6 @@
OutputTuple = collections.namedtuple('OutputTuple', 'state action interms reward')


def cell_size(sizes: Sequence[Shape]) -> Sequence[Union[Shape, int]]:
return tuple(sz if sz != () else (1,) for sz in sizes)


def to_tensor(fluents):
return tuple(f.tensor for f in fluents)


class BasicMarkovCell(tf.nn.rnn_cell.RNNCell):
'''BasicMarkovCell implements a 1-step MDP transition function
Expand Down Expand Up @@ -74,17 +68,17 @@ def graph(self) -> tf.Graph:
@property
def state_size(self) -> Sequence[Shape]:
'''Returns the MDP state size.'''
return cell_size(self.compiler.rddl.state_size)
return utils.cell_size(self.compiler.rddl.state_size)

@property
def action_size(self) -> Sequence[Shape]:
'''Returns the MDP action size.'''
return cell_size(self.compiler.rddl.action_size)
return utils.cell_size(self.compiler.rddl.action_size)

@property
def interm_size(self) -> Sequence[Shape]:
'''Returns the MDP intermediate state size.'''
return cell_size(self.compiler.rddl.interm_size)
return utils.cell_size(self.compiler.rddl.interm_size)

@property
def output_size(self) -> Tuple[Sequence[Shape], Sequence[Shape], Sequence[Shape], int]:
Expand Down Expand Up @@ -122,8 +116,8 @@ def __call__(self,
reward = self.compiler.reward(state, action, next_state)

# outputs
next_state = to_tensor(next_state)
interms = to_tensor(interms)
next_state = utils.to_tensor(next_state)
interms = utils.to_tensor(interms)
output = OutputTuple(next_state, action, interms, reward)

return (output, next_state)

0 comments on commit 72b25e3

Please sign in to comment.