In [8]:

import torch
from torchvision.io import read_image, write_png
from lichi import LIChI
import argparse
import time
import matplotlib.pyplot as plt
import numpy as np

In [19]:
# Command-line style options. Inside Jupyter we parse an empty argument list so the
# kernel's own argv is ignored and every option takes its default value.
parser = argparse.ArgumentParser()
parser.add_argument("--sigma", type=float, dest="sigma",
                    help="Standard deviation of the noise (noise level). Should be between 0 and 50.", default=30)
parser.add_argument("--in", type=str, dest="path_in",
                    help="Path to the image to denoise (PNG or JPEG).", default="./test_images/cameraman.png")
parser.add_argument("--out", type=str, dest="path_out",
                    help="Path to save the denoised image.", default="./img_lichi.png")
# `store_true` combined with `default=True` could never be switched off. BooleanOptionalAction
# keeps `--add_noise` working and adds `--no-add_noise` to disable the artificial noise.
parser.add_argument("--add_noise", action=argparse.BooleanOptionalAction,
                    help="Add artificial Gaussian noise to the image.", default=True)

# To avoid conflicts with Jupyter's arguments, pass an empty list to parse_args()
args = parser.parse_args(args=[])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [20]:
# Read the image as a batched float tensor ([None] adds a leading batch axis) and, if
# requested, add white Gaussian noise with standard deviation `args.sigma`.
# The RNG is seeded so the noise realization, and hence every PSNR reported below,
# is reproducible across kernel restarts.
NOISE_SEED = 0
torch.manual_seed(NOISE_SEED)
img = read_image(args.path_in)[None, :, :, :].float().to(device)
img_noisy = img + args.sigma * torch.randn_like(img) if args.add_noise else img

In [21]:
def lichi(img_noisy, sigma):
    """Denoise `img_noisy` with LIChI, picking the `p1` / `M` settings from `sigma`.

    Args:
        img_noisy: noisy image tensor, passed unchanged to the LIChI model.
        sigma: noise standard deviation; selects one of three (p1, M) presets.

    Returns:
        The LIChI output clipped to the range [0, 255].
    """
    model = LIChI()

    # (largest sigma the preset applies to, p1, M); the first matching row wins.
    presets = ((10, 9, 6), (30, 11, 9))
    p1, m = 13, 11  # fallback when sigma exceeds every bound above
    for max_sigma, preset_p1, preset_m in presets:
        if sigma <= max_sigma:
            p1, m = preset_p1, preset_m
            break

    denoised = model(img_noisy, sigma=sigma, constraints='affine', method='n2n',
                     p1=p1, p2=6, k1=16, k2=64, w=65, s=3, M=m)
    return denoised.clip(0, 255)

def LR(img_noisy, threshold=0.99):
    """Low-rank approximation of a whole image via a truncated SVD.

    Keeps the fewest singular components whose cumulative energy (sum of squared
    singular values) reaches `threshold` times the total energy, rebuilds the
    image from them and clips it to [0, 255].

    Args:
        img_noisy: image tensor; singleton dimensions are squeezed away and the
            remainder is decomposed as a matrix (assumes a single-channel image).
        threshold: fraction of the spectral energy to retain.

    Returns:
        The clipped low-rank reconstruction.
    """
    matrix = torch.squeeze(img_noisy)
    U, S, Vh = torch.linalg.svd(matrix, full_matrices=False)

    energy = S ** 2
    cumulative = torch.cumsum(energy, dim=0)
    # searchsorted yields the first index whose cumulative energy reaches the target;
    # adding one turns that index into a component count.
    rank = torch.searchsorted(cumulative, threshold * torch.sum(energy)).item() + 1

    approx = U[:, :rank] @ torch.diag(S[:rank]) @ Vh[:rank, :]
    return approx.clip(0, 255)

def im2row(im, img_size, block_size):
    """Extract every overlapping (stride-1) block of a 2-D image as a matrix column.

    Column ``y * patch_h + x`` holds the block whose top-left corner is at row ``x``,
    column ``y`` (block positions are enumerated column-major). Each block is
    flattened row-major into ``block_h * block_w`` entries.

    Args:
        im: 2-D tensor; only its top-left ``img_size`` region is scanned.
        img_size: ``(H, W)`` of the region to scan.
        block_size: ``(block_h, block_w)`` block dimensions.

    Returns:
        Tensor of shape ``(block_h * block_w, (H - block_h + 1) * (W - block_w + 1))``
        on ``im``'s device. Its dtype is ``im``'s dtype promoted with the default
        float dtype, so integer images become float and float64 stays float64.

    Raises:
        ValueError: if the block does not fit inside ``img_size``.
    """
    H, W = img_size
    block_h, block_w = block_size
    if not (0 < block_h <= H and 0 < block_w <= W):
        raise ValueError(
            f"block_size {tuple(block_size)} must fit inside img_size {tuple(img_size)}"
        )

    patch_h = H - block_h + 1
    patch_w = W - block_w + 1

    region = im[:H, :W]
    # windows[x, y, i, j] == region[x + i, y + j]; a strided view, nothing is copied yet.
    windows = region.unfold(0, block_h, 1).unfold(1, block_w, 1)

    out_dtype = torch.promote_types(region.dtype, torch.get_default_dtype())
    result = torch.empty(
        (block_h * block_w, patch_h * patch_w), dtype=out_dtype, device=region.device
    )
    # Reorder to (i, j, y, x): row = i * block_w + j, column = y * patch_h + x.
    # Copying into a fresh buffer guarantees the result never aliases `im`.
    result.view(block_h, block_w, patch_w, patch_h).copy_(windows.permute(2, 3, 1, 0))
    return result


def row2im(patches, img_size, block_size):
    """Reassemble a 2-D image from overlapping stride-1 blocks, averaging overlaps.

    Column ``y * patch_h + x`` of ``patches`` is a row-major flattened
    ``block_h x block_w`` block placed with its top-left corner at row ``x``,
    column ``y``. Each output pixel is the mean of every block value covering it.

    Args:
        patches: tensor of shape
            ``(block_h * block_w, (H - block_h + 1) * (W - block_w + 1))``.
        img_size: ``(H, W)`` of the output image.
        block_size: ``(block_h, block_w)`` block dimensions.

    Returns:
        ``(H, W)`` tensor on ``patches``' device. Its dtype is ``patches``' dtype
        promoted with the default float dtype.

    Raises:
        ValueError: if the block does not fit inside ``img_size`` or ``patches``
            has the wrong shape.
    """
    H, W = img_size
    block_h, block_w = block_size
    if not (0 < block_h <= H and 0 < block_w <= W):
        raise ValueError(
            f"block_size {tuple(block_size)} must fit inside img_size {tuple(img_size)}"
        )
    patch_h = H - block_h + 1
    patch_w = W - block_w + 1
    block_length = block_h * block_w
    expected = (block_length, patch_h * patch_w)
    if tuple(patches.shape) != expected:
        raise ValueError(f"patches must have shape {expected}, got {tuple(patches.shape)}")

    dtype = torch.promote_types(patches.dtype, torch.get_default_dtype())
    # fold() enumerates block positions row-major (column x * patch_w + y), while the
    # columns here are column-major (y * patch_h + x): swap the two position axes.
    cols = (
        patches.to(dtype)
        .reshape(block_length, patch_w, patch_h)
        .transpose(1, 2)
        .reshape(1, block_length, patch_h * patch_w)
    )
    fold_kwargs = dict(output_size=(H, W), kernel_size=(block_h, block_w))
    summed = torch.nn.functional.fold(cols, **fold_kwargs)
    # Number of blocks covering each pixel; every pixel is covered at least once
    # because blocks are taken with stride 1 and fit inside the image.
    counts = torch.nn.functional.fold(torch.ones_like(cols), **fold_kwargs)
    return (summed / counts)[0, 0]

def LLR(img_noisy, k, block_size=(8, 8)):
    """Denoise a single-channel image by truncating the SVD of its patch matrix.

    All overlapping stride-1 blocks are stacked as columns (``im2row``). One SVD of
    that matrix is truncated to its ``k`` leading components. The image is rebuilt by
    averaging the overlapping reconstructed blocks (``row2im``), then clipped to [0, 255].

    Args:
        img_noisy: image tensor; after squeezing singleton dimensions it must be 2-D.
        k: number of singular components to keep.
        block_size: ``(block_h, block_w)`` block dimensions. A tuple replaces the
            former mutable list default; any 2-element sequence is accepted.

    Returns:
        The reconstructed 2-D image on the same device as ``img_noisy``.

    Raises:
        ValueError: if ``img_noisy`` is not 2-D after squeezing (e.g. a color image).
    """
    img_noisy = torch.squeeze(img_noisy)
    if img_noisy.dim() != 2:
        raise ValueError(
            f"LLR expects a single-channel image; got shape {tuple(img_noisy.shape)} after squeeze"
        )
    img_size = img_noisy.shape
    patches = im2row(img_noisy, img_size, block_size)
    U, S, Vh = torch.linalg.svd(patches, full_matrices=False)
    # Keep only the k dominant components of the patch matrix.
    U_k = U[:, :k]
    S_k = torch.diag(S[:k])
    Vh_k = Vh[:k, :]
    patch_lr = U_k @ S_k @ Vh_k
    img_llr = row2im(patch_lr, img_size, block_size)
    # Return on the caller's device even if the helpers produced a CPU tensor, so the
    # result can be compared directly with the (possibly CUDA) reference image.
    img_llr = img_llr.to(img_noisy.device).clip(0, 255)
    return img_llr


In [23]:
def _timed(label, fn, *fn_args, **fn_kwargs):
    """Call ``fn(*fn_args, **fn_kwargs)``, print its wall-clock time and return the result."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(*fn_args, **fn_kwargs)
    if torch.cuda.is_available():
        # CUDA work is asynchronous; wait for it so the timing reflects real compute.
        torch.cuda.synchronize()
    print(f"Time elapsed {label}:", round(time.perf_counter() - start, 3), "seconds")
    return out


def _to_png_tensor(image):
    """Round, clamp to [0, 255] and convert to a uint8 CPU tensor of shape (C, H, W).

    Rounding (instead of the truncation done by ``.byte()``) avoids a systematic
    downward bias of up to one gray level.
    """
    if image.dim() == 2:
        image = image.unsqueeze(0)
    return image.round().clamp(0, 255).to(torch.uint8).cpu()


img_lichi = _timed("LiChI", lichi, img_noisy, args.sigma)
write_png(_to_png_tensor(img_lichi[0, :, :, :]), args.path_out)

img_lr = _timed("LR", LR, img_noisy)
write_png(_to_png_tensor(img_lr), "./img_lr.png")

img_llr = _timed("LLR", LLR, img_noisy, k=20)
write_png(_to_png_tensor(img_llr), "./img_llr.png")

Time elapsed LiChI: 12.871 seconds
Time elapsed LR: 0.023 seconds
Time elapsed LLR: 5.408 seconds


In [24]:
def compute_psnr(estimate, reference, peak=255.0):
    """Peak signal-to-noise ratio, in dB, of ``estimate`` against ``reference``.

    ``estimate`` is first moved to ``reference``'s device so reconstructions computed
    elsewhere (e.g. on CPU) can still be scored. The two tensors are broadcast against
    each other before the mean squared error is taken, e.g. an (H, W) reconstruction
    against a (1, 1, H, W) reference.
    """
    estimate = estimate.to(reference.device)
    mse = torch.mean((estimate - reference) ** 2)
    return 10 * torch.log10(peak ** 2 / mse)


for label, estimate in (("LiChI", img_lichi), ("LR", img_lr), ("LLR", img_llr)):
    psnr = compute_psnr(estimate, img)
    print(f"PSNR {label}:", round(float(psnr), 2), "dB")

PSNR LiChI: 28.86 dB
PSNR LR: 19.7 dB
PSNR LLR: 23.84 dB


In [17]:
# Display the shapes of the LR and LLR reconstructions (both are squeezed to 2-D).
tuple(t.shape for t in (img_lr, img_llr))

(torch.Size([256, 256]), torch.Size([256, 256]))