# Filter effects

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy
import skimage

import sys

sys.path.append('..')
from tests import test_filter_effects

# Task 1: Depth of field effect

1. Load a source image $s$ from `../data/track.jpg`.
2. Load a focus mask $f$ from `../data/track_mask.png`.
3. Blur the source image $s$ using a Gaussian filter with appropriate $\sigma_s$ to produce blurred image $b$.
4. Replace the nonzero values in $f$ with a vertical linear ramp that goes from zero (next to the zero-valued pixels) to one (farthest from them). You can use the [distance transform](https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html) for this. We'll denote the result as the weights $w$.
5. Combine (blend) the source image $s$ and the blurred version $b$ by using the weights $w$ computed in point 4 as
   $$
   d = w \cdot s + (1 - w) \cdot b
   $$

<figure class="image">
  <img src="../figures/filter_effects-expected_dof_outputs.png" alt="" style="width: 12.8in;"/>
  <figcaption>Figure 1: Expected outputs of the fake depth of field effect</figcaption>
</figure>

In [None]:
def depth_of_field(img: np.ndarray, mask: np.ndarray, sigma: float = 0.0) -> np.ndarray:
    """
    Fake a depth-of-field effect by blending an image with a blurred copy of itself.

    The blend weight ``w`` is 0 on the zero (unfocused) part of `mask`. It grows linearly
    with the Euclidean distance from that region and reaches 1 at the nonzero pixels
    farthest from it. The result is ``w * img + (1 - w) * blurred``. If `mask` has no zero
    pixels, the whole image is treated as in focus.

    Args:
        img (np.ndarray): Input image of shape (H, W, C) or (H, W).
        mask (np.ndarray): Focus mask of shape (H, W). Nonzero pixels mark the focused area
            and zero pixels the unfocused area. The transition between the two is a linear
            ramp computed with a Euclidean distance transform.
        sigma (float): Standard deviation of the Gaussian kernel used to blur the image.
            Higher values result in more blur; 0 disables blurring.
    Returns:
        np.ndarray: Float image with the depth-of-field effect applied, same shape as `img`.
    Raises:
        ValueError: If `mask` does not have the same height and width as `img`.
    """
    from scipy import ndimage

    img = np.asarray(img, dtype=float)
    focus = np.asarray(mask) != 0
    if focus.shape != img.shape[:2]:
        raise ValueError(
            f"mask shape {focus.shape} does not match image shape {img.shape[:2]}"
        )

    # Blur only the two spatial axes. A zero sigma on the trailing (channel) axes
    # makes gaussian_filter skip them, so colors are not mixed.
    sigmas = (sigma, sigma) + (0.0,) * (img.ndim - 2)
    blurred = ndimage.gaussian_filter(img, sigma=sigmas)

    if focus.all():
        # No zero pixels: every pixel is "infinitely far" from the unfocused region.
        weights = np.ones(focus.shape)
    else:
        # Distance of each focused pixel to the nearest unfocused one (0 outside the focus).
        dist = ndimage.distance_transform_edt(focus)
        peak = dist.max()
        # An all-zero mask gives peak == 0; dist is then all zeros, i.e. fully blurred.
        weights = dist / peak if peak > 0 else dist

    # Broadcast the (H, W) weights over any channel axes.
    weights = weights.reshape(weights.shape + (1,) * (img.ndim - 2))
    dof = weights * img + (1.0 - weights) * blurred

    return dof

In [None]:
rgb = skimage.util.img_as_float(skimage.io.imread('../data/track.jpg'))
print(rgb.shape, rgb.dtype, rgb.min(), rgb.max())
mask = skimage.util.img_as_float(skimage.io.imread('../data/track_mask.png'))
# The mask PNG may be stored single-channel (H, W) or with channels (H, W, C).
# Only take the first channel when there is a channel axis; otherwise [..., 0]
# would silently select the first column of a 2-D mask.
if mask.ndim == 3:
    mask = mask[..., 0]
print(mask.shape, mask.dtype, mask.min(), mask.max())

In [None]:
# Apply the effect with a fairly strong blur (sigma=6) and summarise the result.
dof = depth_of_field(rgb, mask, sigma=6.0)
(dof.shape, dof.dtype, dof.min(), dof.max())

In [None]:
# Show the source image (left) next to the depth-of-field result (right).
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 6))
left, right = axes
left.imshow(rgb)
right.imshow(dof);

In [None]:
# Run the provided TestDepthOfField checks against the depth_of_field implementation above.
test_filter_effects.TestDepthOfField.eval(depth_of_field_fn=depth_of_field)

# Task 2: Jitter effect

- The jitter effect is an example of a *non-homogeneous filter*, which means that the filter kernel differs from one image position to another.
- The kernel is a matrix of size $K \times K$ with zeros everywhere except a single randomly placed unit impulse.
- For example, for $K = 5$, the kernel might be (notice the single "1" at top right)
  $$
  \boldsymbol{h} = \begin{bmatrix}
    0 & 0 & 0 & 1 & 0 \\
    0 & 0 & 0 & 0 & 0 \\
    0 & 0 & 0 & 0 & 0 \\
    0 & 0 & 0 & 0 & 0 \\
    0 & 0 & 0 & 0 & 0
  \end{bmatrix}
  $$
- The position of the "1" is uniformly and independently random for each pixel, i.e. it can be anywhere in the $K \times K$ kernel with equal probability.
- Size $K$ of the kernel is fixed to the same value for all pixels.

<figure class="image">
  <img src="../figures/filter_effects-expected_jitter_outputs.png" alt="" style="width: 6.4in;"/>
  <figcaption>Figure 2: Expected output of jitter effect</figcaption>
</figure>

**Task**
1. Load the source image `../data/fruits.jpg`.
2. Implement `jitter_filter` function.
3. Apply the function to the source image.

In [None]:
# Load the fruits image and convert it to a floating-point image.
rgb = skimage.util.img_as_float(skimage.io.imread('../data/fruits.jpg'))

In [None]:
# Preview the source image; the trailing ';' suppresses the notebook's repr output.
plt.imshow(rgb);

In [None]:
def jitter_filter(img: np.ndarray, jitter: int = 1) -> np.ndarray:
    """
    Apply a jitter (random pixel displacement) effect.

    Each output pixel is copied from a source pixel chosen uniformly at random from the
    (2*jitter+1) x (2*jitter+1) window centred on it. This is equivalent to filtering
    with a per-pixel kernel that holds a single unit impulse at a random position.
    All channels of a pixel share the same offset, so colors are not mixed.
    Source positions that fall outside the image are clamped to the nearest border pixel.

    Offsets are drawn from NumPy's global random state, so results can be made
    reproducible with ``np.random.seed``.

    Args:
        img: input image of shape (H, W) or (H, W, C), e.g. RGB
        jitter: jitter strength and half of the kernel_size, i.e. 2*jitter+1=kernel_size;
            must be non-negative (0 returns an unchanged copy)
    Returns:
        out: image processed by the jitter filter of strength `jitter`,
            same shape and dtype as `img`
    Raises:
        ValueError: if `jitter` is negative
    """
    if jitter < 0:
        raise ValueError(f"jitter must be non-negative, got {jitter}")

    img = np.asarray(img)
    height, width = img.shape[:2]

    # Independent, uniformly random offsets in [-jitter, jitter] for every pixel.
    dy = np.random.randint(-jitter, jitter + 1, size=(height, width))
    dx = np.random.randint(-jitter, jitter + 1, size=(height, width))

    # Source coordinates, clamped so they stay inside the image.
    rows = np.clip(np.arange(height)[:, None] + dy, 0, height - 1)
    cols = np.clip(np.arange(width)[None, :] + dx, 0, width - 1)

    # Fancy indexing over the first two axes gathers whole pixels, so every
    # channel of an output pixel comes from the same source location.
    out = img[rows, cols]
    return out

In [None]:
# Apply a strong jitter: jitter=10 corresponds to a kernel size of 2*10+1 = 21.
rgb_jit = jitter_filter(rgb, jitter=10)

In [None]:
# Display the jittered image; the trailing ';' suppresses the notebook's repr output.
plt.imshow(rgb_jit);

In [None]:
# Run the provided TestJitterEffect checks against the jitter_filter implementation above.
test_filter_effects.TestJitterEffect.eval(jitter_filter_fn=jitter_filter)