## Demosaicing

In this assignment, we are going to 'demosaic' an image encoded with the Bayer pattern. Some cameras capture images through a Bayer pattern color filter, so each pixel stores only one of its three color values: 50% of the pixels keep their green value, 25% keep their red value and 25% keep their blue value. The Bayer encoding takes an RGB image and encodes it as shown in the image below. 
<img src="bayer_patterns.PNG" alt="Drawing" style="height: 300px;"/>




In this lab, we are going to 'demosaic' an encoded image in the **RGGB** pattern.   
<img src="bayer_rggb.PNG" alt="Drawing" style="width: 300px;"/>



We will implement a very simple algorithm which, for each pixel, fills in the two missing channels by averaging the values of the pixel's nearest neighbors (1, 2 or 4) in the corresponding channel.  
<img src="interpolation.PNG" alt="Drawing" style="width: 500px;"/>


To complete this task, we have to:
- read the encoded image (crayons_mosaic.bmp)
- recreate the green, red and blue channels by copying the values into the corresponding positions of each channel
- interpolate the missing values in each channel

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv

In [None]:
# Load the Bayer mosaic and promote it to float32 so that the neighbour
# averages computed later are done in floating point.
encoded_image = plt.imread("crayons_mosaic.bmp").astype(np.float32)
# Show the raw mosaic as a grayscale picture.
plt.imshow(encoded_image.astype(np.uint8), cmap='gray')
print(f'encoded_image shape = {encoded_image.shape}')

In [None]:
def split_r_b_g(encoded_image):
    """
    Split an RGGB Bayer mosaic into three sparse R, G and B planes.

    The image was encoded using the following pattern:
    R G
    G B
    The encoded image looks like:
    R G R G
    G B G B
    R G R G
    G B G B

    Each returned plane has the same height/width as ``encoded_image``. It
    holds the mosaic value at the positions sampled by that colour and 0
    everywhere else:
      - red:   even rows, even columns
      - green: even rows / odd columns and odd rows / even columns
      - blue:  odd rows, odd columns

    Parameters
    ----------
    encoded_image : 2-D array
        Single-channel Bayer mosaic of shape (height, width).

    Returns
    -------
    r_channel, g_channel, b_channel : 2-D float32 arrays
        The sparse colour planes, in R, G, B order.
    """
    height, width = encoded_image.shape
    r_channel = np.zeros((height, width), np.float32)
    g_channel = np.zeros((height, width), np.float32)
    b_channel = np.zeros((height, width), np.float32)
    # Red sites: (even, even).
    r_channel[0::2, 0::2] = encoded_image[0::2, 0::2]
    # Green sites: (even, odd) and (odd, even). Green has twice as many samples.
    g_channel[0::2, 1::2] = encoded_image[0::2, 1::2]
    g_channel[1::2, 0::2] = encoded_image[1::2, 0::2]
    # Blue sites: (odd, odd).
    b_channel[1::2, 1::2] = encoded_image[1::2, 1::2]
    return r_channel, g_channel, b_channel

In [None]:
# Split the mosaic into its sparse colour planes and stack them as an
# (height, width, 3) RGB image so the raw samples can be visualised.
r_channel, g_channel, b_channel = split_r_b_g(encoded_image)
planes = [r_channel, g_channel, b_channel]
color_image = np.stack(planes, axis=-1)
print(color_image.shape)
plt.imshow(color_image.astype(np.uint8))
# Compare the top-left corner of the mosaic with the extracted red plane.
print(encoded_image[0:5, 0:5])
print(r_channel[0:5, 0:5])

In [None]:
# Show the top-left 6x6 corner of the red plane.
print(r_channel[:6, :6])

In [None]:
# Show the top-left 6x6 corner of the green plane.
print(g_channel[:6, :6])

In [None]:
# Show the top-left 6x6 corner of the blue plane.
print(b_channel[:6, :6])

In [None]:
# define 4 types of interpolation based on the shape of neighboring pixels

def interpolate_4_points_plus(channel, i, j):
    """Set channel[i, j] in place to the mean of its up/right/down/left neighbours.

    Leaves ``channel`` untouched when (i, j) is on the image border, because
    one of the four neighbours would fall outside the array. Returns None.
    """
    n_rows, n_cols = channel.shape[0], channel.shape[1]
    is_interior = 0 < i < n_rows - 1 and 0 < j < n_cols - 1
    if not is_interior:
        return
    neighbours = [
        channel[i - 1, j],  # up
        channel[i, j + 1],  # right
        channel[i + 1, j],  # down
        channel[i, j - 1],  # left
    ]
    channel[i, j] = np.mean(neighbours)
    
def interpolate_4_points_diag(channel, i, j):
    """Set channel[i, j] in place to the mean of its four diagonal neighbours.

    Leaves ``channel`` untouched when (i, j) is on the image border, because
    one of the corners would fall outside the array. Returns None.
    """
    n_rows, n_cols = channel.shape[0], channel.shape[1]
    if 1 <= i <= n_rows - 2 and 1 <= j <= n_cols - 2:
        corners = [
            channel[i - 1, j - 1],  # top-left
            channel[i - 1, j + 1],  # top-right
            channel[i + 1, j + 1],  # bottom-right
            channel[i + 1, j - 1],  # bottom-left
        ]
        channel[i, j] = np.mean(corners)
    
def interpolate_2_points_horizontal(channel, i, j):
    """Set channel[i, j] in place to the mean of its left and right neighbours.

    Skipped (channel unchanged) when j is the first or last column. The row
    index ``i`` is not range-checked. Returns None.
    """
    last_col = channel.shape[1] - 1
    if 0 < j < last_col:
        left, right = channel[i, j - 1], channel[i, j + 1]
        channel[i, j] = np.mean([left, right])
        
def interpolate_2_points_vertical(channel, i, j):
    """Set channel[i, j] in place to the mean of its upper and lower neighbours.

    Skipped (channel unchanged) when i is the first or last row. The column
    index ``j`` is not range-checked. Returns None.
    """
    if not 0 < i < channel.shape[0] - 1:
        return
    above, below = channel[i - 1, j], channel[i + 1, j]
    channel[i, j] = np.mean([above, below])


In [None]:
def interpolate_red_channel(red_channel):
    """
    Fill in the missing values of the sparse red plane of an RGGB mosaic.

    The red channel looks like:
    R 0 R 0 R 0 R 0
    0 0 0 0 0 0 0 0
    R 0 R 0 R 0 R 0
    0 0 0 0 0 0 0 0

    Red samples sit at (even row, even column). Every other position is set
    to the mean of its nearest red samples:
      - (odd, odd)   -> the 4 diagonal neighbours
      - (even, odd)  -> the left and right neighbours
      - (odd, even)  -> the upper and lower neighbours
    The helper functions skip positions whose neighbours would fall outside
    the image. When the height (width) is even, the last row (column) holds no
    red samples at all, so it is filled by copying the row (column) before
    it. When it is odd, that row/column contains real samples and is kept.

    Parameters
    ----------
    red_channel : 2-D array
        Sparse red plane, e.g. the first output of ``split_r_b_g``.

    Returns
    -------
    2-D array
        A new array with the same shape and dtype; the input is not modified.
    """
    interpolated_red_channel = red_channel.copy()
    height, width = interpolated_red_channel.shape
    # Diagonal: each (odd, odd) site has red samples on its 4 diagonals.
    for i in range(1, height, 2):
        for j in range(1, width, 2):
            interpolate_4_points_diag(interpolated_red_channel, i, j)
    # Now the red channel looks like:
    # R 0 R 0 R 0 R 0
    # 0 R 0 R 0 R 0 0
    # R 0 R 0 R 0 R 0
    # 0 0 0 0 0 0 0 0
    # Horizontal: each (even, odd) site lies between two red samples in its row.
    for i in range(0, height, 2):
        for j in range(1, width, 2):
            interpolate_2_points_horizontal(interpolated_red_channel, i, j)
    # Vertical: each (odd, even) site lies between two red samples in its column.
    for i in range(1, height, 2):
        for j in range(0, width, 2):
            interpolate_2_points_vertical(interpolated_red_channel, i, j)
    # 'Interpolate' the last line / last column by replication. This only
    # applies when they are odd-indexed, i.e. contain no red samples.
    # Otherwise the copy would overwrite real red values.
    if height % 2 == 0:
        interpolated_red_channel[height - 1] = interpolated_red_channel[height - 2]
    if width % 2 == 0:
        interpolated_red_channel[:, width - 1] = interpolated_red_channel[:, width - 2]
    return interpolated_red_channel
     

In [None]:
def interpolate_blue_channel(blue_channel):
    """
    Fill in the missing values of the sparse blue plane of an RGGB mosaic.

    The blue channel looks like:
    0 0 0 0 0 0 0 0
    0 B 0 B 0 B 0 B
    0 0 0 0 0 0 0 0
    0 B 0 B 0 B 0 B

    Blue samples sit at (odd row, odd column). Every other position is set
    to the mean of its nearest blue samples:
      - (even, even) -> the 4 diagonal neighbours
      - (odd, even)  -> the left and right neighbours
      - (even, odd)  -> the upper and lower neighbours
    The helper functions skip positions whose neighbours would fall outside
    the image. The first row and first column never contain blue samples, so
    they are filled by copying row 1 / column 1. If the height (width) is odd,
    the last row (column) contains no blue samples either and is filled by
    copying the row (column) before it.

    Parameters
    ----------
    blue_channel : 2-D array
        Sparse blue plane, e.g. the third output of ``split_r_b_g``.

    Returns
    -------
    2-D array
        A new array with the same shape and dtype; the input is not modified.
    """
    interpolated_blue_channel = blue_channel.copy()
    height, width = interpolated_blue_channel.shape
    # Diagonal: each (even, even) site has blue samples on its 4 diagonals.
    for i in range(0, height, 2):
        for j in range(0, width, 2):
            interpolate_4_points_diag(interpolated_blue_channel, i, j)
    # Now the blue channel looks like:
    # 0 0 0 0 0 0 0 0
    # 0 B 0 B 0 B 0 B
    # 0 0 B 0 B 0 B 0
    # 0 B 0 B 0 B 0 B
    # Horizontal: each (odd, even) site lies between two blue samples in its row.
    for i in range(1, height, 2):
        for j in range(0, width, 2):
            interpolate_2_points_horizontal(interpolated_blue_channel, i, j)
    # Vertical: each (even, odd) site lies between two blue samples in its column.
    for i in range(0, height, 2):
        for j in range(1, width, 2):
            interpolate_2_points_vertical(interpolated_blue_channel, i, j)
    # 'Interpolate' the first line and first column by replication. The row is
    # copied first so that the column copy also fixes the (0, 0) corner.
    interpolated_blue_channel[0] = interpolated_blue_channel[1]
    interpolated_blue_channel[:, 0] = interpolated_blue_channel[:, 1]
    # For odd sizes the last line/column is even-indexed (no blue samples and
    # no neighbour beyond it), so replicate it from its predecessor as well.
    if height % 2 == 1:
        interpolated_blue_channel[height - 1] = interpolated_blue_channel[height - 2]
    if width % 2 == 1:
        interpolated_blue_channel[:, width - 1] = interpolated_blue_channel[:, width - 2]

    return interpolated_blue_channel

In [None]:
def interpolate_green_channel(green_channel):
    """
    Fill in the missing values of the sparse green plane of an RGGB mosaic.

    The green channel looks like:
    0 G 0 G 0 G 0 G
    G 0 G 0 G 0 G 0
    0 G 0 G 0 G 0 G
    G 0 G 0 G 0 G 0

    Green samples sit at (even, odd) and (odd, even). The missing sites,
    (even, even) and (odd, odd), each have 4 green samples arranged in a
    'plus' (up/right/down/left) and are set to their mean.
    By now, we are going to ignore the margins: ``interpolate_4_points_plus``
    skips border positions, so missing values on the outermost rows/columns
    stay 0.

    Parameters
    ----------
    green_channel : 2-D array
        Sparse green plane, e.g. the second output of ``split_r_b_g``.

    Returns
    -------
    2-D array
        A new array with the same shape and dtype; the input is not modified.
    """
    interpolated_green_channel = green_channel.copy()
    height, width = interpolated_green_channel.shape
    # 'Plus' interpolation at (even, even) sites (red positions in the mosaic).
    for i in range(0, height, 2):
        for j in range(0, width, 2):
            interpolate_4_points_plus(interpolated_green_channel, i, j)
    # 'Plus' interpolation at (odd, odd) sites (blue positions in the mosaic).
    # These reads only touch original green samples, so loop order is irrelevant.
    for i in range(1, height, 2):
        for j in range(1, width, 2):
            interpolate_4_points_plus(interpolated_green_channel, i, j)

    return interpolated_green_channel

In [None]:
# Fill the missing red samples and inspect the top-left 5x5 corner.
interpolated_red_channel = interpolate_red_channel(r_channel)
print(interpolated_red_channel[0:5, 0:5])

In [None]:
# OpenCV stores images in BGR order, so reverse the RGB planes before writing.
bgr_image = color_image[:, :, ::-1].copy()
cv.imwrite('color_image.png', bgr_image)
print(color_image.dtype)

In [None]:
plt.imshow(color_image.astype(np.uint8))

In [None]:
# Work on a copy of the mosaic and pad it by one pixel on every side.
# 'reflect' mirrors about the edge pixel ([a b c ...] -> [b a b c ...]), so the
# padded border has the same Bayer colour parity as the row/column it stands
# in for. Neighbour averages at the edges therefore still mix same-colour samples.
I = encoded_image.copy()
m, n = I.shape
I_padded = np.pad(I.astype(np.float64), 1, mode='reflect')
print(I_padded[:10, :10])

In [None]:
# Neighbour-average maps with the same (m, n) shape as I, built from shifted
# slices of the padded mosaic (row/col offset 1 in I_padded == pixel in I).
up = I_padded[:-2, 1:-1]
down = I_padded[2:, 1:-1]
left = I_padded[1:-1, :-2]
right = I_padded[1:-1, 2:]
# Diagonal: mean of the 4 corner neighbours.
X = (I_padded[:-2, :-2] + I_padded[2:, :-2] + I_padded[:-2, 2:] + I_padded[2:, 2:]) / 4
# Horizontal: mean of left and right neighbours.
H = (left + right) / 2
# Vertical: mean of upper and lower neighbours.
V = (up + down) / 2
# Plus: mean of the 4 edge neighbours.
P = (H + V) / 2

In [None]:
# Assemble the full-colour (m, n, 3) image C. For every channel, each of the
# four Bayer phases -- keyed by (row offset, column offset) modulo 2 -- is taken
# either directly from the mosaic I or from one of the neighbour-average maps.
phase_sources = (
    # red: sampled at (0, 0)
    {(0, 0): I, (0, 1): H, (1, 0): V, (1, 1): X},
    # green: sampled at (1, 0) and (0, 1); 'plus' average elsewhere
    {(1, 0): I, (0, 1): I, (0, 0): P, (1, 1): P},
    # blue: sampled at (1, 1)
    {(1, 1): I, (1, 0): H, (0, 1): V, (0, 0): X},
)
C = np.zeros((m, n, 3))
for channel, sources in enumerate(phase_sources):
    for (r0, c0), src in sources.items():
        C[r0::2, c0::2, channel] = src[r0::2, c0::2]

In [None]:
cv.imwrite('result.png', C[:, :, ::-1].copy())  # RGB -> BGR for OpenCV