In [5]:
# -*- coding: utf-8 -*-
"""
Inference of the MRI-SAM on the nifti datasets.
"""

import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F
import torch

In [3]:
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) :
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

In [40]:
def preprocess_mask(mask, target_size=256):
    """
    Preprocess masks from original size to target size.

    Args:
        mask (np.array): the mask with shape of BxCxHxW
        target_size (int, optional): The target size. Defaults to 256.
        
    Returns:
        (np.array): the mask with the target shape.
        
    """
    resize_long = get_preprocess_shape(mask.shape[-2], mask.shape[-1], target_size)
    print(resize_long)
    resized_mask = F.interpolate(mask, 
                           resize_long, 
                           mode="nearest")
    # Pad
    print(resized_mask.shape)
    h, w = resized_mask.shape[-2:]
    padh = target_size - h
    padw = target_size - w
    x = F.pad(resized_mask, (0, padw, 0, padh), value=0)
    print(x.shape)
    return x

In [36]:
mask = torch.ones((128,255))

In [37]:
mask.shape

torch.Size([128, 255])

In [39]:
get_preprocess_shape(mask.shape[0], mask.shape[1], 256)

(129, 256)

In [41]:
re = preprocess_mask(mask.type(torch.float)[None,None,:,:],target_size=256)

(129, 256)
torch.Size([1, 1, 129, 256])
torch.Size([1, 1, 256, 256])


In [42]:
re

tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])

In [25]:
re.shape

torch.Size([1, 1, 256, 256])