Skip to content

Commit

Permalink
[Bugfix] Fix reshape (apache#5739)
Browse files Browse the repository at this point in the history
* Fix reshape

* fix doc warning

* fix ci

* address comments
  • Loading branch information
comaniac authored and Trevor Morris committed Jun 12, 2020
1 parent b6bbed0 commit eaaefd4
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from . import _make
from ..expr import TupleWrapper, const
from ...tir import expr as _expr


def cast(data, dtype):
Expand Down Expand Up @@ -212,7 +213,16 @@ def reshape(data, newshape):
if isinstance(newshape, int):
newshape = const([newshape])
if isinstance(newshape, (tuple, list)):
newshape = const(list(newshape))
tempshape = []
for shape in newshape:
if isinstance(shape, _expr.IntImm):
tempshape.append(shape.value)
else:
try:
tempshape.append(int(shape))
except ValueError as err:
raise RuntimeError('Unrecognized shape type: %s' % err)
newshape = const(tempshape)
return _make.reshape(data, newshape)

def argwhere(condition):
Expand Down

0 comments on commit eaaefd4

Please sign in to comment.