Skip to content

Commit

Permalink
Add shape attribute to StochasticTensor to fix the error when running…
Browse files Browse the repository at this point in the history
… vae_conv
  • Loading branch information
csy530216 committed Oct 10, 2018
1 parent bfb2ad5 commit 22fb905
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tests/model/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

class TestStochasticTensor(tf.test.TestCase):
def test_init(self):
samples = Mock()
static_shape = Mock()
get_shape_func = Mock(return_value=static_shape)
samples = Mock(get_shape=get_shape_func)
log_probs = Mock()
probs = Mock()
sample_func = Mock(return_value=samples)
Expand All @@ -35,6 +37,8 @@ def test_init(self):
self.assertTrue(s_tensor.tensor is samples)
self.assertTrue(s_tensor.log_prob(None) is log_probs)
self.assertTrue(s_tensor.prob(None) is probs)
self.assertTrue(s_tensor.get_shape() is static_shape)
self.assertTrue(s_tensor.shape is static_shape)

obs_int32 = tf.placeholder(tf.int32, None)
obs_float32 = tf.placeholder(tf.float32, None)
Expand Down
4 changes: 4 additions & 0 deletions zhusuan/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def tensor(self):
self._tensor = self.sample(self._n_samples)
return self._tensor

@property
def shape(self):
return self.get_shape()

def get_shape(self):
return self.tensor.get_shape()

Expand Down

0 comments on commit 22fb905

Please sign in to comment.