Skip to content

Commit

Permalink
hotfix: tflite scalar tensor parsing error
Browse files Browse the repository at this point in the history
  • Loading branch information
dboyliao committed Nov 10, 2020
1 parent 7d18e68 commit 3bacf8b
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions utensor_cgen/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def _build_tensor_map(self, fb_model, ugraph):

if isinstance(tensor.ShapeAsNumpy(), np.ndarray):
shape = tensor.ShapeAsNumpy().tolist()
elif isinstance(tensor.ShapeAsNumpy(), int):
logger.warning(f"{tensor.Name().decode('utf8')} is scalar, convert to tensor as shape [1]")
shape = [1]
else:
shape = list(fb_model.Buffers(12).DataAsNumpy().view(dtype).shape)

Expand Down

0 comments on commit 3bacf8b

Please sign in to comment.