In [17]:
import kornia
import torch
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.filters.kernels import _unpack_2d_ks, get_gaussian_kernel2d
from kornia.filters.median import _compute_zero_padding
from kornia.core import Tensor, pad
from typing import Union, Tuple
import numpy as np
from PIL import Image

In [2]:
def joint_bilateral_blur(
    inp: Tensor,
    guidance: Union[Tensor, None],
    kernel_size: Union[Tuple[int, int], int],
    sigma_color: Union[float, Tensor],
    sigma_space: Union[Tuple[float, float], Tensor],
    border_type: str = 'reflect',
    color_distance_type: str = 'l1',
) -> Tensor:
    """Single implementation for both Bilateral Filter and Joint Bilateral Filter.

    Every output pixel is a normalised weighted mean of its ``ky x kx`` neighbourhood
    in ``inp``. Each weight is a spatial Gaussian (``sigma_space``) times a colour
    Gaussian (``sigma_color``) of the distance between the neighbour and the centre
    pixel measured in ``guidance``. With ``guidance=None`` the input guides itself,
    which is the ordinary bilateral filter.

    Args:
        inp: image batch of shape (B, C, H, W).
        guidance: tensor whose colour distances drive the range weights, or ``None``
            to use ``inp``. After unfolding it must broadcast against ``inp``, so it
            needs the same H and W; its channel count may differ because distances
            are summed over channels.
        kernel_size: ``(ky, kx)`` window size (height first, the order
            ``get_gaussian_kernel2d`` uses) or a single int for a square window.
        sigma_color: standard deviation of the colour Gaussian, in the same units as
            the ``guidance`` values; a tensor of shape (B,) gives one value per item.
        sigma_space: spatial standard deviation(s), forwarded to
            ``get_gaussian_kernel2d``.
        border_type: padding mode forwarded to ``pad``.
        color_distance_type: ``'l1'`` squares the summed absolute channel
            differences; ``'l2'`` sums the squared channel differences.

    Returns:
        The filtered tensor, shape (B, C, H, W) for odd kernel sizes.

    Raises:
        ValueError: if ``color_distance_type`` is neither ``'l1'`` nor ``'l2'``.
    """
    if isinstance(sigma_color, Tensor):
        KORNIA_CHECK_SHAPE(sigma_color, ['B'])
        # (B,) -> (B, 1, 1, 1, 1) so it broadcasts against the per-pixel distances below.
        sigma_color = sigma_color.to(device=inp.device, dtype=inp.dtype).view(-1, 1, 1, 1, 1)

    # kornia orders 2-D kernel sizes as (height, width). The unfold order below must
    # match the (ky, kx) layout of get_gaussian_kernel2d; otherwise non-square kernels
    # are silently transposed and every spatial weight hits the wrong neighbour.
    ky, kx = _unpack_2d_ks(kernel_size)
    pad_y, pad_x = (ky - 1) // 2, (kx - 1) // 2

    def _unfold(t: Tensor) -> Tensor:
        # (B, C, H, W) -> (B, C, H, W, ky * kx): every ky x kx window, flattened row-major.
        padded = pad(t, (pad_x, pad_x, pad_y, pad_y), mode=border_type)  # (left, right, top, bottom)
        return padded.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2)

    unfolded_input = _unfold(inp)  # (B, C, H, W, ky * kx)

    if guidance is None:
        guidance = inp
        unfolded_guidance = unfolded_input  # reuse instead of unfolding the same tensor twice
    else:
        unfolded_guidance = _unfold(guidance)

    # Neighbour minus centre pixel, per guidance channel.
    diff = unfolded_guidance - guidance.unsqueeze(-1)
    if color_distance_type == "l1":
        color_distance_sq = diff.abs().sum(1, keepdim=True).square()
    elif color_distance_type == "l2":
        color_distance_sq = diff.square().sum(1, keepdim=True)
    else:
        raise ValueError(f"color_distance_type only accepts 'l1' or 'l2', got {color_distance_type!r}")
    color_kernel = (-0.5 / sigma_color**2 * color_distance_sq).exp()  # (B, 1, H, W, ky * kx)

    # Row-major flatten of the (ky, kx) Gaussian lines up with the unfolded windows.
    space_kernel = get_gaussian_kernel2d(kernel_size, sigma_space, device=inp.device, dtype=inp.dtype)
    space_kernel = space_kernel.view(-1, 1, 1, 1, ky * kx)

    kernel = space_kernel * color_kernel
    # Normalised weighted average over each window.
    return (unfolded_input * kernel).sum(-1) / kernel.sum(-1)

In [35]:
# Unsharp mask with kernel_size=(5, 5), sigma=(1.5, 1.5).
unsharp_filter = kornia.filters.UnsharpMask((5, 5), (1.5, 1.5))


def guided_blur(inp: Tensor, gui: Tensor) -> Tensor:
    """Joint bilateral blur of ``inp`` guided by ``gui`` (5x5 window, spatial sigma 1.5, colour sigma 0.1).

    ``sigma_color=0.1`` is an absolute intensity scale, so inputs are expected in
    [0, 1]. On 0-255 values the colour weight exp(-50 * d**2) is ~0 for any
    neighbour that differs at all, and the blur degenerates to (almost) the identity.
    """
    return joint_bilateral_blur(inp, gui, (5, 5), 0.1, (1.5, 1.5))

In [25]:
# Read the image, force 3-channel RGB, resize to 256x256 and keep it as an
# (H, W, 3) uint8 NumPy array.
img = np.array(
    Image.open('he.png')
    .convert('RGB')
    .resize((256, 256))
)

In [26]:
# Convert the (H, W, 3) array to a (3, H, W) tensor. This is guarded so that re-running
# the cell does not pass the already-converted tensor back into image_to_tensor,
# which expects a NumPy image.
if not isinstance(img, torch.Tensor):
    img = kornia.image_to_tensor(img, keepdim=True)

In [36]:
# guided_blur uses sigma_color=0.1, an absolute intensity scale that assumes values in
# [0, 1]. On raw 0-255 pixels the colour weight exp(-50 * d**2) vanishes for any
# neighbour that differs by even one level, so the "blur" would be a no-op.
# Scale to [0, 1], filter, then scale back.
img_unit = img.expand(1, -1, -1, -1).to(torch.float32) / 255.0  # (1, 3, H, W) in [0, 1]
blurred_unit = guided_blur(img_unit, img_unit)  # self-guided, i.e. a plain bilateral blur
# Round and clip before the uint8 cast: astype truncates and does not saturate.
blurred = kornia.tensor_to_image((blurred_unit * 255.0).round().clamp(0, 255))
blurred = Image.fromarray(blurred.astype(np.uint8))
blurred.show()

In [27]:
# UnsharpMask takes a batched float tensor: add a batch axis of size 1 and cast the
# uint8 pixels to float32 (values stay on the 0-255 scale).
unsharped = unsharp_filter(
    img.expand(1, -1, -1, -1).to(dtype=torch.float32)
)

In [28]:
unsharped.size()  # batch of one 3-channel image: (1, 3, H, W)

torch.Size([1, 3, 256, 256])

In [29]:
# Tensor back to an image array for display: (1, 3, H, W) -> (H, W, 3).
# Guarded so a re-run does not pass the already-converted array back into
# tensor_to_image.
if isinstance(unsharped, torch.Tensor):
    unsharped = kornia.tensor_to_image(unsharped)

In [30]:
np.shape(unsharped)  # channels-last image array: (H, W, 3)

(256, 256, 3)

In [31]:
# Unsharp masking adds high-frequency detail back on top of the image, so values can
# overshoot [0, 255]. A float -> uint8 cast does not saturate (out-of-range values
# typically wrap, e.g. 260 -> 4), which shows up as speckle around edges. Round to
# the nearest level and clip first.
unsharped = Image.fromarray(np.clip(np.rint(unsharped), 0, 255).astype(np.uint8))

In [33]:
# Hands the sharpened image to PIL's viewer; as the saved cell shows, nothing is
# captured in the notebook output.
unsharped.show()