# Fourier Transform for detecting defects on images with regular patterns

In [None]:
# To autoreload external functions: IPython's autoreload extension re-imports
# edited project modules (e.g. src.visualization.*) without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from PIL import Image, ImageChops
import numpy as np
from scipy.fft import rfft, rfftfreq, fft2, fftshift, ifft2, ifftshift
import cv2
from skimage import morphology
from skimage import measure
from skimage import segmentation
from bokeh.plotting import figure, show
from bokeh.models import HoverTool, VBar, Block
from bokeh.layouts import row, column

import rootutils
root = rootutils.setup_root(Path.cwd(), dotenv=True, pythonpath=True, cwd=False)

from src.visualization.utils import save_plot_from_notbook_for_jekyll, bokeh_notebook_setup, save_plot_from_notebook_to_html
from src.visualization.image import plot_img_rgba, plot_img_scalar, add_bboxes_on_img

## Setup

In [None]:
# Project helper from src.visualization.utils; presumably configures Bokeh to
# render inline in this notebook — TODO confirm.
bokeh_notebook_setup()

In [None]:
# Input images (the train/good and test/broken sub-folders are used below) and
# the directory for exported plots. Both paths are relative to the current
# working directory (rootutils is set up with cwd=False, so it is not changed).
data_path = Path("../data/raw/grid")
output_path = Path("./logs")

## Helper Functions

In [None]:
def line_plot(
    x: np.ndarray,
    y: np.ndarray,
    title: str = "",
    width: int = 400,
    height: int = 400,
) -> figure:
    """Create a Bokeh line plot of ``y`` against ``x`` with a hover tooltip.

    Args:
        x: Sample positions for the horizontal axis.
        y: Sample values, same length as ``x``.
        title: Plot title.
        width: Plot width passed to ``figure`` (defaults to the previous fixed 400).
        height: Plot height passed to ``figure`` (defaults to the previous fixed 400).

    Returns:
        The configured figure. The caller decides whether to ``show`` it or
        combine it into a layout (e.g. ``row``).
    """
    p = figure(width=width, height=height, title=title)
    # Tooltip showing the data coordinates of the hovered point.
    p.add_tools(HoverTool(tooltips=[("(x,y)", "(@x, @y)")]))
    p.line(x, y, line_width=2)
    p.toolbar.logo = None  # hide the Bokeh logo in the toolbar

    return p

## Normal data sample with regular pattern

In [None]:
# Display a defect-free training sample to show the regular grid pattern.
img_path = data_path / "train/good/000.png"

img = Image.open(img_path)
img = img.convert("RGBA")  # RGBA for display via the project plotting helper

p = plot_img_rgba(img)
show(p)

## Data sample with a defect

In [None]:
# Display a test sample from the "broken" folder, i.e. a pattern with a defect.
img_path = data_path / "test/broken/000.png"

img = Image.open(img_path)
img = img.convert("RGBA")  # RGBA for display via the project plotting helper

p = plot_img_rgba(img)
show(p)

## Image Fourier Transform

### Idea

In [None]:
# 1-D illustration: a pure sine of frequency `freq` appears as a single peak
# at `freq` in the magnitude spectrum.
x_max = 1.0
N_samples = 100
sample_rate = N_samples / x_max
freq = 3.0

# endpoint=False makes the sample spacing exactly 1 / sample_rate, which is the
# spacing passed to rfftfreq below. It also makes the window hold exactly
# freq * x_max periods. Including the endpoint changes the spacing to
# x_max / (N_samples - 1) and smears the peak across neighbouring bins.
x = np.linspace(0.0, x_max, N_samples, endpoint=False)
y = np.sin(freq * 2 * np.pi * x)

p_left = line_plot(x, y, title="Original Signal")

# One-sided spectrum of the real-valued signal; keep only the magnitudes.
yf = rfft(y)
yf = np.abs(yf)
xf = rfftfreq(N_samples, 1 / sample_rate)

p_right = line_plot(xf, yf, title="Fourier Transform")

p = row(p_left, p_right)
show(p)

In [None]:
# Counterpart to the sine example: a single impulse at x[0] contains every
# frequency equally, so its magnitude spectrum is flat (all ones).
# Reuses x, N_samples and sample_rate from the previous cell.
y = np.zeros_like(x)
y[0] = 1.0

p_left = line_plot(x, y, title="Original Signal")

yf = rfft(y)
yf = np.abs(yf)
xf = rfftfreq(N_samples, 1 / sample_rate)

p_right = line_plot(xf, yf, title="Fourier Transform")

p = row(p_left, p_right)
show(p)

### Step-by-step

In [None]:
# Reload the defective sample as single-channel grayscale ("L") so that
# np.array(img) below is 2-D and fft2 transforms the image plane.
img_path = data_path / "test/broken/000.png"

img = Image.open(img_path)
img = img.convert("L")

In [None]:
# prepare image: 8-bit grayscale values scaled to [0, 1]
img_np = np.array(img)
img_np = img_np / 255.0

# transform to Fourier space; fftshift moves the zero-frequency term to the
# centre of the spectrum, which is easier to view and mask
f = fft2(img_np)
fshift = fftshift(f)

# frequency magnitudes on a log scale to compress the dynamic range for
# display; note that an exactly-zero coefficient would yield -inf here
mag_img = np.log(np.abs(fshift))

p = plot_img_scalar(mag_img)
show(p)

In [None]:
# thresholding: mark the dominant frequency components. The threshold lies
# `mag_thresh` of the way from the smallest to the largest log-magnitude.
mag_thresh = 0.65  # relative position within the min..max range of mag_img
# np.log gives -inf for exactly-zero coefficients; computing the range over
# finite values only keeps a single zero from turning thresh_val into NaN.
finite_mag = mag_img[np.isfinite(mag_img)]
min_val = finite_mag.min()
max_val = finite_mag.max()
thresh_val = min_val + mag_thresh * (max_val - min_val)
# THRESH_BINARY: 1.0 where mag_img > thresh_val, else 0.0 (-inf entries -> 0)
ret, mag_img_mask = cv2.threshold(mag_img, thresh_val, 1.0, cv2.THRESH_BINARY)

p = plot_img_scalar(mag_img_mask)
show(p)

In [None]:
# masking: keep only the dominant frequencies (mask == True). The boolean mask
# is stored back into `mag_img_mask` because a later cell inverts it with `~`.
mag_img_mask = mag_img_mask.astype(bool)
fshift_proc = fshift * mag_img_mask

# transform back: undo the centring shift, apply the inverse FFT, then take
# the magnitude of the complex result
f_ishift = ifftshift(fshift_proc)
img_proc = ifft2(f_ishift)
img_proc = np.abs(img_proc)

# convert to image. Clip first: the filtered reconstruction is not guaranteed
# to stay within [0, 1], and casting out-of-range floats to uint8 does not
# saturate.
img_proc = np.clip(img_proc * 255.0, 0, 255).astype(np.uint8)
img_proc = Image.fromarray(img_proc)

p_proc = plot_img_rgba(img_proc, title="Reconstruction after masking")
p_org = plot_img_rgba(img, title="Original image")

p = column(p_proc, p_org)

show(p)

In [None]:
# Optional: uncomment to export the figure above via the project's Jekyll helper.
# plot_path = output_path / "defect_image_reconstructed.html"
# save_plot_from_notbook_for_jekyll(p, plot_path)

In [None]:
# Pixel-wise absolute difference between the original and the reconstruction
# from dominant frequencies only; content the periodic reconstruction does
# not reproduce (presumably the defect) should stand out.
diff = ImageChops.difference(img, img_proc)
p = plot_img_rgba(diff)
show(p)

### Alternative Approach

In [None]:
# invert mask: keep everything EXCEPT the dominant frequencies, which removes
# the regular pattern and leaves the irregular content. `~` requires the
# boolean mask created in the masking cell above.
mag_img_mask_inv = ~mag_img_mask

# masking
fshift_proc = fshift * mag_img_mask_inv

# transform back: undo the centring shift, apply the inverse FFT, then take
# the magnitude of the complex result
f_ishift = ifftshift(fshift_proc)
img_proc = ifft2(f_ishift)
img_proc = np.abs(img_proc)

# convert to image. Clip first: the filtered reconstruction is not guaranteed
# to stay within [0, 1], and casting out-of-range floats to uint8 does not
# saturate.
img_proc = np.clip(img_proc * 255.0, 0, 255).astype(np.uint8)
img_proc = Image.fromarray(img_proc)

p = plot_img_rgba(img_proc)
show(p)

### Putting the anomaly extraction together

In [None]:
def ft_extract_anomalies(img: Image.Image, mag_thresh: float = 0.5) -> Image.Image:
    """Remove the dominant frequencies of an image to expose irregular content.

    The image is transformed with a 2-D FFT. Every frequency component whose
    log-magnitude is strictly above the point ``mag_thresh`` of the way from the
    smallest to the largest (finite) log-magnitude is zeroed, and the result is
    transformed back. What remains is the content that the strongest
    frequencies (presumably the regular pattern) do not explain.

    Args:
        img: Single-channel 8-bit image (e.g. PIL mode ``"L"``). Multi-channel
            input is not supported, because ``fft2`` transforms the last two
            array axes.
        mag_thresh: Relative threshold in [0, 1] on the log-magnitude range.
            Lower values remove more frequencies.

    Returns:
        8-bit image of the magnitude of the filtered reconstruction.
    """
    # prepare image: scale 8-bit values to [0, 1]
    img_np = np.array(img) / 255.0

    # transform to Fourier space, zero frequency shifted to the centre
    fshift = fftshift(fft2(img_np))

    # frequency log-magnitudes. Exactly-zero coefficients become -inf; they are
    # excluded from the range below and never exceed the threshold (and they
    # contribute nothing to the reconstruction either way).
    with np.errstate(divide="ignore"):
        mag_img = np.log(np.abs(fshift))
    finite_mag = mag_img[np.isfinite(mag_img)]
    if finite_mag.size == 0:
        # Every coefficient is zero (all-black image): nothing to reconstruct.
        return Image.fromarray(np.zeros(img_np.shape, dtype=np.uint8))

    # thresholding: True for dominant frequencies (strictly above threshold,
    # the same rule as cv2.THRESH_BINARY)
    min_val = finite_mag.min()
    max_val = finite_mag.max()
    thresh_val = min_val + mag_thresh * (max_val - min_val)
    dominant = mag_img > thresh_val

    # masking: keep everything except the dominant frequencies
    fshift_proc = fshift * ~dominant

    # transform back and take the magnitude of the complex result
    img_proc = np.abs(ifft2(ifftshift(fshift_proc)))

    # convert to image. Clip first: the filtered reconstruction is not bounded
    # to [0, 1], and casting out-of-range floats to uint8 does not saturate.
    img_proc = np.clip(img_proc * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(img_proc)

In [None]:
# Run the packaged extraction on the grayscale defective sample, using the
# same relative threshold (0.65) as the step-by-step walkthrough.
img_path = data_path / "test/broken/000.png"

img = Image.open(img_path)
img = img.convert("L")

img_proc = ft_extract_anomalies(img, mag_thresh=0.65)
p = plot_img_rgba(img_proc)
show(p)

## Post Processing

In [None]:
# Binarise the anomaly map: keep pixels brighter than its 99th percentile.
img_proc_np = np.array(img_proc)

max_val = img_proc_np.max()  # NOTE(review): appears unused
thresh_val = np.percentile(img_proc_np, 99)
# THRESH_BINARY: 1 where pixel > int(thresh_val), else 0
ret, img_thresh = cv2.threshold(img_proc_np, int(thresh_val), 1.0, cv2.THRESH_BINARY)
img_thresh = img_thresh > 0  # boolean mask for the morphology step

p = plot_img_scalar(img_thresh.astype(np.uint8))
show(p)

In [None]:
# combine neighboring mask regions (dilation with a 7x7 square footprint)
img_morph = morphology.binary_dilation(img_thresh, np.ones([7,7]))

# remove artifacts due to blurring at the edges: clear_border drops connected
# regions that touch the image border
img_morph = segmentation.clear_border(img_morph)

# assign label to each connected region (background stays 0)
img_lab = measure.label(img_morph) 

p = plot_img_scalar(img_lab)
show(p)

In [None]:
# Optional: uncomment to export the label map via the project's Jekyll helper.
# plot_path = output_path / "defects_label_map.html"
# save_plot_from_notbook_for_jekyll(p, plot_path)

In [None]:
regions = measure.regionprops(img_lab)
# filter out small ones: keep regions covering at least 0.1% of the image area
area_thresh = img_proc_np.shape[0] * img_proc_np.shape[1] * 0.001
# bbox per skimage regionprops: (min_row, min_col, max_row, max_col)
defects_bboxes = [reg.bbox for reg in regions if reg.area >= area_thresh]

In [None]:
# Overlay the detected boxes on the original image. add_bboxes_on_img is a
# project helper; presumably it expects regionprops-style boxes — TODO confirm.
p = plot_img_rgba(img, title="Original image with defect bounding boxes")
p = add_bboxes_on_img(p, defects_bboxes)

show(p)

In [None]:
# Optional: uncomment to export the bounding-box figure via the Jekyll helper.
# plot_path = output_path / "defects_bboxes.html"
# save_plot_from_notbook_for_jekyll(p, plot_path)

### Putting the post-processing together

In [None]:
def find_bounding_boxes(
    img: Image.Image,
    perc_thresh: float = 99,
    area_thresh: float = 0.001,
    dilation_size: int = 7,
) -> list:
    """Locate bright blobs in an anomaly map and return their bounding boxes.

    Args:
        img: Anomaly map, e.g. the output of ``ft_extract_anomalies``. Brighter
            pixels are treated as more anomalous.
        perc_thresh: Percentile (0-100) of the pixel values used as the
            binarisation threshold; pixels strictly above it are kept.
        area_thresh: Minimum region area as a fraction of this image's area.
            Smaller regions are discarded.
        dilation_size: Side length of the square footprint used to merge
            neighbouring foreground pixels.

    Returns:
        List of ``regionprops`` bounding boxes, one per retained region.
    """
    img_np = np.array(img)

    # binarise: keep pixels above the percentile, truncated to a whole grey
    # level (same result as cv2.threshold(..., THRESH_BINARY) > 0)
    thresh_val = np.percentile(img_np, perc_thresh)
    img_thresh = img_np > int(thresh_val)

    # combine neighboring mask regions
    img_morph = morphology.binary_dilation(
        img_thresh, np.ones([dilation_size, dilation_size])
    )

    # remove artifacts due to blurring at the edges (regions touching the border)
    img_morph = segmentation.clear_border(img_morph)

    # assign label to each connected region
    img_lab = measure.label(img_morph)

    # filter out small regions; the minimum area is relative to THIS image
    min_area = img_np.shape[0] * img_np.shape[1] * area_thresh
    return [
        region.bbox
        for region in measure.regionprops(img_lab)
        if region.area >= min_area
    ]

## Defect Detection

In [None]:
def ft_defect_detection(
    img: Image.Image,
    mag_thresh: float = 0.5,
    perc_thresh: int = 99,
    area_thresh: float = 0.001,
    dilation_size: int = 7,
) -> list:
    """Detect defects in a regularly patterned image.

    This is a two-stage pipeline. ``ft_extract_anomalies`` removes the dominant
    frequencies, and ``find_bounding_boxes`` turns the remaining bright regions
    into boxes.

    Args:
        img: Single-channel image to inspect.
        mag_thresh: Relative log-magnitude threshold for the frequency filter.
        perc_thresh: Percentile used to binarise the anomaly map.
        area_thresh: Minimum region area as a fraction of the image area.
        dilation_size: Side length of the square dilation footprint.

    Returns:
        Bounding boxes as produced by ``find_bounding_boxes``.
    """
    anomaly_map = ft_extract_anomalies(img, mag_thresh=mag_thresh)
    return find_bounding_boxes(
        anomaly_map,
        perc_thresh=perc_thresh,
        area_thresh=area_thresh,
        dilation_size=dilation_size,
    )

In [None]:
# iterator through defect images, in sorted file-name order so that repeated
# runs step through the same sequence (iterdir order is not guaranteed)
path_it = iter(sorted((data_path / "test/broken/").iterdir()))

In [None]:
# Each run of this cell advances to the next defect image (next() raises
# StopIteration once the iterator is exhausted) and plots its detections.
img_path = next(path_it)

img = Image.open(img_path)
img = img.convert("L")

defect_bboxes = ft_defect_detection(img, mag_thresh=0.65)

p = plot_img_rgba(img)
p = add_bboxes_on_img(p, defect_bboxes)
show(p)

In [None]:
# Optional: uncomment to export the current example via the Jekyll helper.
# plot_path = output_path / "defects_example.html"
# save_plot_from_notbook_for_jekyll(p, plot_path)