In [172]:
import numpy as np
from PIL import Image, ImageDraw, ImageColor
import matplotlib.pyplot as plt
import scipy

# Tiny offset used when taking logs of spectra, to keep zero bins away from
# log(0) = -inf.
eps = 1.e-30

# Frequency-domain filter to apply: "HIGH_PASS" or "LOW_PASS".
filter_type = "HIGH_PASS"

In [None]:
# Load the test image. convert('RGB') guarantees exactly three 8-bit channels
# even if the file is greyscale, palette, CMYK or RGBA, so the channel split
# below and the mode='RGB' display are always valid. The `with` block closes
# the underlying file deterministically.
with Image.open('pops_512.jpg') as img:
    image = np.array(img.convert('RGB'))
display(Image.fromarray(image, mode='RGB'))

# Split into per-channel 2-D arrays (rows x cols) for the per-channel FFTs.
r, g, b = image[:, :, 0], image[:, :, 1], image[:, :, 2]

In [174]:
def normalise(arr, clamp_zero=False, var_arr=None):
    """Linearly rescale ``arr`` to the 0..255 range and return it as uint8.

    Parameters
    ----------
    arr : array_like
        Values to rescale. Any shape; it is converted to float64 first so
        integer inputs cannot overflow while the range is computed.
    clamp_zero : bool, optional
        If True, the bottom of the scale is fixed at 0 instead of the data
        minimum; values below 0 are clipped to 0.
    var_arr : array_like, optional
        Extra values that contribute to the min/max range (for example, so a
        filtered spectrum is scaled on the same range as the unfiltered one).
        It does not need to have the same shape as ``arr``.

    Returns
    -------
    numpy.ndarray
        uint8 array with the shape of ``arr``. It is all zeros when the range
        is empty or degenerate: no finite values, or max <= min.

    Notes
    -----
    The min/max range is taken over finite values only, so ``-inf`` entries
    (e.g. from ``np.log(0)``) map to 0 and ``+inf`` to 255 instead of
    poisoning the scale. NaN entries map to 0.
    """
    values = np.asarray(arr, dtype=np.float64)

    # Collect finite samples from arr (and var_arr, if given) to define the range.
    finite_parts = [values[np.isfinite(values)]]
    if var_arr is not None:
        reference = np.asarray(var_arr, dtype=np.float64)
        finite_parts.append(reference[np.isfinite(reference)])
    finite = np.concatenate(finite_parts)

    if finite.size == 0:
        return np.zeros(values.shape, dtype=np.uint8)

    min_val = 0.0 if clamp_zero else float(finite.min())
    max_val = float(finite.max())

    # Degenerate range: constant data, or (with clamp_zero) nothing above 0.
    # Every value sits at or below the floor, so everything maps to 0.
    if max_val <= min_val:
        return np.zeros(values.shape, dtype=np.uint8)

    scale_factor = 255 / (max_val - min_val)
    norm_arr = (values - min_val) * scale_factor
    # Infinities land on the ends of the scale; NaN has no position, so use 0.
    norm_arr = np.nan_to_num(norm_arr, nan=0.0, posinf=255.0, neginf=0.0)

    return np.clip(np.round(norm_arr), 0, 255).astype(np.uint8)

In [None]:
# Gaussian transfer function laid out like an fftshift-ed spectrum.
# Row and column profiles are built separately so non-square images work.
# Each profile peaks at index n // 2, which is where fftshift places the DC
# term for both even and odd n. That keeps the filter symmetric about DC, so
# the filtered spectrum stays conjugate-symmetric and the inverse FFT is real.
gauss_std = 20      # spread of the Gaussian, in frequency bins
gauss_gain = 1e3    # overall scale of the mask
gauss_floor = 1e-2  # small offset so the low-pass never fully zeroes a bin

n_rows, n_cols = image.shape[:2]
row_offsets = np.arange(n_rows) - n_rows // 2
col_offsets = np.arange(n_cols) - n_cols // 2
gauss_rows = np.exp(-0.5 * (row_offsets / gauss_std) ** 2)
gauss_cols = np.exp(-0.5 * (col_offsets / gauss_std) ** 2)

gaussian = np.outer(gauss_rows, gauss_cols)

# can modify the gaussian for different results
gaussian = gaussian * gauss_gain + gauss_floor

if filter_type == "HIGH_PASS":
    # Invert the mask: exactly 0 at DC, rising towards high frequencies.
    gaussian = np.max(gaussian) - gaussian
elif filter_type != "LOW_PASS":
    raise ValueError(f"filter_type must be 'HIGH_PASS' or 'LOW_PASS', got {filter_type!r}")

display(Image.fromarray(normalise(gaussian), mode="L"))

In [176]:
# Forward FFT of each colour channel, shifted so the zero-frequency term sits in
# the middle of the array, where the Gaussian mask peaks.
F_r, F_g, F_b = (np.fft.fftshift(np.fft.fft2(channel)) for channel in (r, g, b))

# Apply the frequency-domain mask element-wise.
F_r_filt, F_g_filt, F_b_filt = (spectrum * gaussian for spectrum in (F_r, F_g, F_b))

# Undo the shift, invert, keep only the real part, and rescale to 0..255 uint8.
r_filt, g_filt, b_filt = (
    normalise(np.fft.ifft2(np.fft.ifftshift(spectrum)).real)
    for spectrum in (F_r_filt, F_g_filt, F_b_filt)
)

In [None]:
fig, axs = plt.subplots(3, 2)

fig.patch.set_visible(False)

for ax in axs.ravel():
    ax.axis('off')

axs[0, 0].set_title('original')
axs[0, 1].set_title('filtered')

# Log-magnitude spectra, one row per colour channel.
# np.abs is taken on the complex spectrum (not just .real), so this is the true
# magnitude, as in the combined RGB spectrum cell. eps is added *after* the
# magnitude, so exact zeros (e.g. the DC bin zeroed by the high-pass mask) give
# log(eps) instead of -inf.
spectrum_rows = (
    (F_r, F_r_filt, 'Reds_r'),
    (F_g, F_g_filt, 'Greens_r'),
    (F_b, F_b_filt, 'Blues_r'),
)
for row, (spectrum, spectrum_filt, cmap) in enumerate(spectrum_rows):
    axs[row, 0].imshow(normalise(np.log(np.abs(spectrum) + eps)), cmap=cmap, vmin=0, vmax=255)
    axs[row, 1].imshow(normalise(np.log(np.abs(spectrum_filt) + eps)), cmap=cmap, vmin=0, vmax=255)

plt.tight_layout()
plt.show()

# this plot isn't great, the RGB one at the bottom is better

In [None]:
fig, axs = plt.subplots(3, 2)
fig.patch.set_visible(False)

for ax in axs.flat:
    ax.axis('off')

# Left column: original channel; right column: the same channel after filtering.
channel_pairs = (
    (r, r_filt, 'Reds_r'),
    (g, g_filt, 'Greens_r'),
    (b, b_filt, 'Blues_r'),
)
for (ax_orig, ax_filt), (chan, chan_filt, cmap) in zip(axs, channel_pairs):
    ax_orig.imshow(chan, cmap=cmap)
    ax_filt.imshow(chan_filt, cmap=cmap)

plt.tight_layout()
plt.show()

In [None]:
# Per-channel log-magnitude spectra stacked as rows x cols x 3, so they display as RGB.
log_mag_orig = np.log(np.abs(np.dstack((F_r, F_g, F_b))))
log_mag_filt = np.log(np.abs(np.dstack((F_r_filt, F_g_filt, F_b_filt))))

# clamp_zero fixes the bottom of the scale at 0. The filtered spectrum is scaled
# using the combined range of itself and the original (var_arr), so the two
# images share a brightness scale.
fft_orig = normalise(log_mag_orig, clamp_zero=True)
fft_filt = normalise(log_mag_filt, clamp_zero=True, var_arr=log_mag_orig)
image_filt = normalise(np.dstack((r_filt, g_filt, b_filt)), clamp_zero=True)

for composite in (fft_orig, fft_filt, image_filt):
    display(Image.fromarray(composite, mode='RGB'))