In [None]:
import numpy as np
from matplotlib import pyplot as plt
import cv2
import seaborn as sns

In [None]:
from skimage import color, data, restoration
from scipy.signal import convolve2d

In [None]:
# Flag -1 is cv2.IMREAD_UNCHANGED: load the file as-is (native channels / depth).
# NOTE: cv2.imread returns None instead of raising if the file cannot be read.
# This `img` is reassigned in the "[1]" cell below before it is ever used.
img = cv2.imread('data/SampleImage.jpg', cv2.IMREAD_UNCHANGED)

In [None]:
# blur_img_noise = blur_img + 0.1 * img.std() * np.random.standard_normal(img.shape)

In [None]:
def no_axis(axes):
    """Hide the x- and y-axis of every matplotlib Axes in ``axes``.

    Parameters
    ----------
    axes : Axes, list/tuple of Axes, or numpy.ndarray of Axes
        The axes to strip. A (possibly 2-D) ndarray, such as the one returned
        by ``plt.subplots(nrows, ncols)``, is flattened first; a single Axes
        object is wrapped in a list.
    """
    if hasattr(axes, "get_xaxis"):
        # A single Axes object rather than a collection of them.
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        # plt.subplots() returns an ndarray of Axes for multi-panel grids.
        axes = axes.ravel()
    for ax in axes:
        # `.axes` resolves to the owning Axes (the Axes itself for an Axes).
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)

In [None]:
def image_display(img, title=None):
    """Show ``img`` in a native OpenCV (HighGUI) window and block until a key is pressed.

    The window is named ``title``, or ``"image"`` when no title is given.
    After the key press, all OpenCV windows are destroyed.
    """
    window_name = "image" if title is None else title
    cv2.imshow(window_name, img)
    cv2.waitKey(0)  # delay 0 -> wait indefinitely for a key press
    cv2.destroyAllWindows()

## [1] `skimage.restoration.wiener` 이용 (Wiener deconvolution with scikit-image)

In [None]:
from skimage import color, data, restoration
from scipy.signal import convolve2d

# Fixed seed so the added noise (and therefore the restoration) is identical
# on every Restart & Run All.
NOISE_SEED = 0
noise_rng = np.random.default_rng(NOISE_SEED)

# Grayscale version of scikit-image's sample astronaut image.
img = color.rgb2gray(data.astronaut())
# 5x5 uniform (box) point-spread function.
psf = np.ones((5, 5)) / 25

# 'same' keeps blur_img the same shape as img, so original / blurred /
# restored images line up pixel-for-pixel; 'full' would enlarge the image by
# psf.shape - 1 and offset it relative to img.
blur_img = convolve2d(img, psf, 'same')
# Additive Gaussian noise with std equal to 10% of the blurred image's std.
blur_img_noise = blur_img + 0.1 * blur_img.std() * noise_rng.standard_normal(blur_img.shape)

# `balance` is the Wiener regularisation weight: larger values suppress noise
# at the cost of sharpness.
deconvolved_img = restoration.wiener(blur_img_noise, psf, balance=0.1)

In [None]:
# Show the three stages one after another; each window waits for a key press.
image_display(img, title='original')
image_display(blur_img, title='blur')
image_display(deconvolved_img, title='ret')

## [2] NumPy Fourier transform (`numpy.fft`) 이용 (hand-written Wiener filter)

In [None]:
import numpy as np
from numpy.fft import fft2, ifft2

def wiener_filter(img, kernel, K = 10):
    """Deblur a 2-D image with a frequency-domain Wiener filter.

    The image spectrum is multiplied by ``conj(H) / (|H|**2 + K)``, where ``H``
    is the 2-D DFT of ``kernel`` zero-padded to the image size. Because the
    filtering is done with the DFT, the blur is modelled as a *circular*
    convolution (e.g. ``scipy.ndimage.convolve(img, kernel, mode='wrap')``).

    Parameters
    ----------
    img : 2-D array_like
        Blurred image, assumed to be on a 0-255 intensity scale (the output
        is 8-bit).
    kernel : 2-D array_like
        Point-spread function, no larger than ``img`` along either axis.
        Element ``(kh // 2, kw // 2)`` is treated as the kernel origin, the
        same convention as ``scipy.ndimage.convolve`` with ``origin=0``.
    K : float, optional
        Constant added to ``|H|**2`` in the denominator. Larger values
        suppress noise amplification at the cost of sharpness; ``K = 0`` is
        the plain inverse filter.

    Returns
    -------
    numpy.ndarray
        ``uint8`` image with the shape of ``img``, rounded and clipped to
        ``[0, 255]``.

    Raises
    ------
    ValueError
        If ``kernel`` is larger than ``img`` along either axis.
    """
    img = np.asarray(img)
    kernel = np.asarray(kernel, dtype=float)
    kh, kw = kernel.shape
    ih, iw = img.shape[0], img.shape[1]
    if kh > ih or kw > iw:
        raise ValueError(
            f"kernel shape {kernel.shape} exceeds image shape {img.shape[:2]}")
    padded = np.pad(kernel, [(0, ih - kh), (0, iw - kw)], 'constant')
    # Move the kernel's centre element to index (0, 0). The DFT treats (0, 0)
    # as the origin, so padding at the top-left alone would translate the
    # restored image by (kh // 2, kw // 2) pixels.
    padded = np.roll(padded, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    H = fft2(padded)
    restored = ifft2(fft2(img) * (np.conj(H) / (np.abs(H) ** 2 + K)))
    # The imaginary part is only floating-point residue. Use the real part,
    # not abs(): abs() would mirror negative ringing into bright pixels.
    # Round before casting (a plain cast truncates, e.g. 199.9999 -> 199) and
    # clip so out-of-range values saturate instead of wrapping around in uint8.
    return np.clip(np.rint(restored.real), 0, 255).astype(np.uint8)

In [None]:
from scipy.signal import convolve2d, deconvolve
# The window functions (including `gaussian`) were deprecated and later
# removed from the top-level scipy.signal namespace; scipy.signal.windows is
# their supported home.
from scipy.signal.windows import gaussian
from scipy.ndimage import convolve

def blur(img, mode = 'box', block_size = 3):
    """Blur a 2-D image with a normalised ``block_size`` x ``block_size`` kernel.

    Parameters
    ----------
    img : 2-D array_like
        Input image (8-bit intensity scale assumed, since the output is uint8).
    mode : {'box', 'gaussian', 'motion'}
        'box'      -- uniform averaging kernel;
        'gaussian' -- separable Gaussian with std ``block_size / 3``;
        'motion'   -- diagonal line (identity matrix) kernel.
    block_size : int
        Side length of the square kernel.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The blurred image as uint8 and the kernel used. Convolution uses
        ``mode='valid'``, so the output is ``block_size - 1`` pixels smaller
        than ``img`` along each axis.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported names.
    """
    if mode == 'box':
        h = np.ones((block_size, block_size)) / block_size ** 2
    elif mode == 'gaussian':
        g = gaussian(block_size, block_size / 3).reshape(block_size, 1)
        h = g @ g.T  # outer product -> separable 2-D Gaussian
        h /= np.sum(h)
    elif mode == 'motion':
        h = np.eye(block_size) / block_size
    else:
        # Previously fell through and failed later with UnboundLocalError on `h`.
        raise ValueError(
            f"unknown blur mode {mode!r}; expected 'box', 'gaussian' or 'motion'")
    blurred = convolve2d(img, h, mode = 'valid')
    return np.uint8(blurred), h

def gaussian_add(img, sigma = 5, rng=None):
    """Add zero-mean Gaussian noise to an 8-bit image.

    Parameters
    ----------
    img : array_like
        Input image on a 0-255 scale.
    sigma : float
        Standard deviation of the noise, in intensity units.
    rng : numpy.random.Generator, optional
        Source of randomness. Pass a seeded Generator for reproducible noise.
        When omitted, the global ``np.random`` state is used (the original
        behaviour).

    Returns
    -------
    numpy.ndarray
        uint8 image of the same shape: noise added, rounded to the nearest
        integer and saturated to ``[0, 255]``.
    """
    if rng is None:
        rng = np.random  # legacy global RNG, as before
    noise = rng.normal(0, sigma, np.shape(img))
    noisy = np.round(np.asarray(img, dtype=float) + noise)
    # Saturate so the uint8 cast cannot wrap values below 0 or above 255.
    return np.uint8(np.clip(noisy, 0, 255))

In [None]:
# Read as a single-channel 8-bit image: the convolution and wiener_filter
# below operate on 2-D arrays. cv2.imread returns None (no exception) when the
# file is missing or unreadable, so fail loudly here instead of deep inside
# the filtering code.
img = cv2.imread('./data/SampleImage_grey.jpg', cv2.IMREAD_GRAYSCALE)
if img is None:
    raise FileNotFoundError("could not read './data/SampleImage_grey.jpg'")
# _img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

In [None]:
# 5x5 uniform (box) point-spread function.
psf = np.ones((5, 5)) / 25

# mode='wrap' blurs with periodic boundaries, matching the circular
# convolution that the FFT-based wiener_filter inverts. scipy.ndimage keeps
# the input dtype by default, so blur_img has img's dtype.
blur_img = convolve(img, psf, mode='wrap')

# K is the constant added to |H|^2 in wiener_filter's denominator
# (larger -> less noise amplification, less sharpening).
ret_img = wiener_filter(blur_img, psf, K=0.4)

In [None]:
# Original, blurred and restored images, shown one at a time (press a key to advance).
image_display(img, title='original')
image_display(blur_img, title='blur')
image_display(ret_img, title='ret')

## [Additional] FFT-based convolution and naive inverse-filter deconvolution

In [None]:
from scipy import fftpack, stats

def _convolve(star, psf):
    star_fft = fftpack.fftshift(fftpack.fftn(star))
    psf_fft = fftpack.fftshift(fftpack.fftn(psf))
    return fftpack.fftshift(fftpack.ifftn(fftpack.ifftshift(star_fft*psf_fft)))

def _deconvolve(star, psf):
    star_fft = fftpack.fftshift(fftpack.fftn(star))
    psf_fft = fftpack.fftshift(fftpack.fftn(psf))
    return fftpack.fftshift(fftpack.ifftn(fftpack.ifftshift(star_fft/psf_fft)))

# Synthetic scene on a 100x100 grid: a narrow radial Gaussian "star"
# (sigma = 4 px) and a wider radial Gaussian PSF (sigma = 10 px), both
# centred at (sx/2, sy/2).
sx, sy = 100, 100
X, Y = np.ogrid[0:sx, 0:sy]
radius = np.sqrt((X - sx/2)**2 + (Y - sy/2)**2)  # distance from the grid centre
star = stats.norm.pdf(radius, 0, 4)
psf = stats.norm.pdf(radius, 0, 10)

star_conv = _convolve(star, psf)
# Naive inverse filter: may amplify round-off noise wherever the PSF
# spectrum is close to zero.
star_deconv = _deconvolve(star_conv, psf)

fig, axes = plt.subplots(2, 2)
panels = [
    (axes[0, 0], star, 'star'),
    (axes[0, 1], psf, 'psf'),
    (axes[1, 0], np.real(star_conv), 'star * psf'),
    (axes[1, 1], np.real(star_deconv), 'deconvolved'),
]
for ax, image, title in panels:
    ax.imshow(image)
    ax.set_title(title)
plt.show()

In [None]:
# fig, ((ax1,ax2), (ax3,ax4)) = plt.subplots(nrows=2, ncols=2)
# no_axis([ax1,ax2,ax3,ax4])
# ax1.imshow(img)
# ax2.imshow(blur_img)
# ax3.imshow(blur_img_noise)
# ax4.imshow(restoration.wiener(blur_img, psf, 10))
# plt.show()