Skip to content

Commit

Permalink
fixed optimization bugs triggered by propagation of shape 1 to unbroadcastable dims - hint: prefer broadcast_like() to alloc()
Browse files Browse the repository at this point in the history
  • Loading branch information
James Bergstra committed Jan 6, 2011
1 parent d99d6d3 commit 9a45333
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions theano/tensor/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def broadcast_like(value, template, env):
if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
rval = T.alloc(T.cast(value, template.dtype), *shape_of[template])
# the template may have 1s in its shape without being broadcastable
if rval.broadcastable != template.broadcastable:
rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim) if rval.broadcastable[i]
and not template.broadcastable[i]])
assert rval.type == template.type
return rval

Expand Down Expand Up @@ -663,14 +667,20 @@ def local_fill_to_alloc(node):
elif v.type.broadcastable == node.outputs[0].type.broadcastable:
# this is a cast
rval = [T.cast(v, node.outputs[0].type.dtype)]
elif r.type.broadcastable == node.outputs[0].type.broadcastable:
# we are broadcasting v somehow, but not r
rval = [broadcast_like(v, r, node.env)]
else:
# we are broadcasting v somehow
shape_of = node.env.shape_feature.shape_of
# we are broadcasting both v and r,
# the output shape must be computed
#
# TODO: implement this case (including a test!)
#
# I think the strategy should be to extend the shorter shape vector
# with 1s (how?) and then take the elementwise max of the two.
# - how to flag an error of shape mismatch where broadcasting should be illegal?
return
# TODO: cut out un-necessary dimshuffles of v
rval = [T.alloc(T.cast(v, node.outputs[0].dtype), *shape_of[node.outputs[0]])]

#if rval[0].type != node.outputs[0].type:
#print >> sys.stderr, theano.printing.debugprint(node.outputs[0], file='str')

assert rval[0].type == node.outputs[0].type, ('rval', rval[0].type,
'orig', node.outputs[0].type,
Expand Down Expand Up @@ -2259,8 +2269,7 @@ def local_mul_specialize(node):
neg ^= True #toggles
elif N.all(y == 0.0):
# if we find any zero, we just return right away
return [T.alloc(numpy.asarray(0, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
return [broadcast_like(0, node.outputs[0], node.env)]
else:
new_inputs.append(input)

Expand All @@ -2277,21 +2286,14 @@ def local_mul_specialize(node):
else:
rval = T.mul(*new_inputs)

return [T.alloc(T.cast(rval, node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
return [broadcast_like(rval, node.outputs[0], node.env)]
else:
# there are no variable inputs to mul
# N.B. this could have been constant-folded...
if neg:
# return output's worth of -1
return [T.alloc(
numpy.asarray(-1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
return [broadcast_like(-1, node.outputs[0], node.env)]
else:
# return output's worth of 1
return [T.alloc(
numpy.asarray(1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
return [broadcast_like(1, node.outputs[0], node.env)]

register_specialize(local_mul_specialize)

Expand Down

0 comments on commit 9a45333

Please sign in to comment.