In [1]:
import time
import cv2
import numpy as np
import torchvision.transforms.functional as TF

In [2]:
def blend(img_path, mask_path, mask_opacity):
    """Blend mask and image.
    
    Args:
        img_path (str): path to the image.
        mask_path (str): path to the mask.
        mask_opacity (float): opacity of the mask.
    
    Returns:
        np.ndarray: blended image.
    """
    mask_img = cv2.imread(mask_path)
    bkgd_img = cv2.imread(img_path)
    
    # blend part of background
    mask = cv2.cvtColor(mask_img, cv2.COLOR_BGR2GRAY)
    bkgd_blend = cv2.bitwise_and(bkgd_img, bkgd_img, mask=mask)
    
    # non-blend part of background
    inv_mask = (mask == 0).astype(np.uint8)
    bkgd_non_blend = cv2.bitwise_and(bkgd_img, bkgd_img, mask=inv_mask)
    
    mask_ovelay = cv2.addWeighted(mask_img, mask_opacity, bkgd_blend, 1.0 - mask_opacity, 0)
    whole_img = cv2.addWeighted(mask_ovelay, 1.0, bkgd_non_blend, 1.0, 0)
    return whole_img

In [3]:
def torch_blend(img_path, mask_path, mask_opacity):
    """Blend mask and image using PyTorch.
    
    Args:
        img_path (str): path to the image.
        mask_path (str): path to the mask.
        mask_opacity (float): opacity of the mask.
    
    Returns:
        np.ndarray: blended image.
    """
    # load image and mask
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    
    # convert to tensor and move to GPU
    img_tensor = TF.to_tensor(img).unsqueeze(0).cuda()
    mask_tensor = TF.to_tensor(mask).unsqueeze(0).cuda()
    
    # blend image and mask
    blended_tensor = img_tensor * (1 - mask_opacity) + mask_tensor * mask_opacity
    
    # convert back to numpy array
    blended_img = blended_tensor.squeeze().cpu().numpy()
    blended_img = np.transpose(blended_img, (1, 2, 0))
    blended_img = (blended_img * 255).astype(np.uint8)
    
    return blended_img


In [4]:
def _test():
    img_path = '/home/tom/github/niceview/db/cache/gt-iz-p9-rep2-wsi-img.tiff'
    mask_path = '/home/tom/github/niceview/db/cache/gt-iz-p9-rep2-mask-cell-gene-img.png'
    mask_opacity = 0.5
    
    time_start = time.time()
    blend(img_path, mask_path, mask_opacity)
    time_end = time.time()
    print('time cost: ', time_end - time_start, 's')
    
    time_start = time.time()
    torch_blend(img_path, mask_path, mask_opacity)
    time_end = time.time()
    print('time cost: ', time_end - time_start, 's')

_test()

time cost:  2.378056764602661 s
time cost:  6.381999254226685 s
