Skip to content

Commit

Permalink
fix Pool2D output_size calculation to work with unknown (None) image …
Browse files Browse the repository at this point in the history
…dimensions.
  • Loading branch information
mbeissinger committed Nov 6, 2015
1 parent e8d25b4 commit 523d152
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions opendeep/models/utils/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,29 @@ def _pool_out_size(imgshape, ds, st, padding, ignore_border=True):
assert len(ds) == len(st), "stride and size need to have the same number of dimensions"
ndims = len(ds)
pooldims = list(imgshape[-ndims:])
pooldims = [dim + pad*2 for dim, pad in zip(pooldims, padding)]
pooldims = [dim + pad*2 if dim is not None and pad is not None
else None
for dim, pad in zip(pooldims, padding)]

if ignore_border:
outdims = [(dim - size) // stride + 1 for dim, size, stride in zip(pooldims, ds, st)]
outdims = [max(outdim, 0) for outdim in outdims]
else:
outdims = [(dim - 1) // stride + 1 if stride >= size
else max(0, (dim - 1 - size) // stride + 1) + 1
outdims = [(dim - size) // stride + 1
if dim is not None and size is not None and stride is not None
else None
for dim, size, stride in zip(pooldims, ds, st)]
outdims = [max(outdim, 0)
if outdim is not None
else None
for outdim in outdims]
else:
outdims = []
for dim, size, stride in zip(pooldims, ds, st):
if dim is not None and size is not None and stride is not None:
if stride >= size:
outdims.append((dim - 1) // stride + 1)
else:
outdims.append(max(0, (dim - 1 - size) // stride + 1) + 1)
else:
outdims.append(None)

rval = list(imgshape[:-ndims]) + outdims
return rval
Expand Down

0 comments on commit 523d152

Please sign in to comment.