diff --git a/opendeep/models/utils/pooling.py b/opendeep/models/utils/pooling.py index a6093ef..f528888 100755 --- a/opendeep/models/utils/pooling.py +++ b/opendeep/models/utils/pooling.py @@ -22,15 +22,29 @@ def _pool_out_size(imgshape, ds, st, padding, ignore_border=True): assert len(ds) == len(st), "stride and size need to have the same number of dimensions" ndims = len(ds) pooldims = list(imgshape[-ndims:]) - pooldims = [dim + pad*2 for dim, pad in zip(pooldims, padding)] + pooldims = [dim + pad*2 if dim is not None and pad is not None + else None + for dim, pad in zip(pooldims, padding)] if ignore_border: - outdims = [(dim - size) // stride + 1 for dim, size, stride in zip(pooldims, ds, st)] - outdims = [max(outdim, 0) for outdim in outdims] - else: - outdims = [(dim - 1) // stride + 1 if stride >= size - else max(0, (dim - 1 - size) // stride + 1) + 1 + outdims = [(dim - size) // stride + 1 + if dim is not None and size is not None and stride is not None + else None for dim, size, stride in zip(pooldims, ds, st)] + outdims = [max(outdim, 0) + if outdim is not None + else None + for outdim in outdims] + else: + outdims = [] + for dim, size, stride in zip(pooldims, ds, st): + if dim is not None and size is not None and stride is not None: + if stride >= size: + outdims.append((dim - 1) // stride + 1) + else: + outdims.append(max(0, (dim - 1 - size) // stride + 1) + 1) + else: + outdims.append(None) rval = list(imgshape[:-ndims]) + outdims return rval