In [1]:
from jupyterthemes import jtplot
jtplot.style(theme='chesterish')

In [69]:
import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import math

In [28]:
"""
Corners Dictionary maps the every image to the bounding box coordinates: x0, y0, x1, y1
"""
image_bounding_box = {
    "IMG_0854.jpg": [879,1061,295,568],
    "IMG_0855.jpg": [127,669,846,1300],
    "IMG_0856.jpg": [381,747,888,1196],
    "IMG_0857.jpg": [349,577,748,1068],
    "IMG_0858.jpg": [339,630,681,1062],
    "IMG_0859.jpg": [383,670,816,1170],
    "IMG_0860.jpg": [833,639,1383,1232],
    "IMG_0861.jpg": [965,611,1373,1049],
    "IMG_0862.jpg": [926,712,1336,1124],
    "IMG_0863.jpg": [1190,1088,610,438],
}
data_dir = "data/image_segmentation_dataset/"

In [122]:
def draw_rectangle(img_path, rect_coordinates):
    """
    Rect co-ordinates in [x0, y0, x1, y1]
    """
    im = Image.open(img_path)
    draw = ImageDraw.Draw(im)
    draw.rectangle(rect_coordinates)
    im.show()
    
def get_dataset(data_dir):
    """
    Returns:
        colors_box        : N1x3 RGB data within bounding box
        colors_background : N2x3 RGB data for background
    """
    colors_box = []
    colors_background = []
    for img_name, box in tqdm(image_bounding_box.items()):
        image_path = os.path.join(data_dir, img_name)
        img = Image.open(image_path)
        draw_rectangle(image_path, box)
        values = img.load()
        for i in range(im.size[0]): #width x-coordinates
            for j in range(im.size[1]): #height y-coordinates
                if i >= box[0] and i <= box[2] and j >= box[1] and j <= box[3]:
                    colors_box.append(values[i, j])
                else:
                    colors_background.append(values[i, j])
                    
    return np.asarray(colors_box), np.asarray(colors_background)


In [123]:
# Generate the dataset for computation
colors_box, colors_bg = get_dataset(data_dir)
print(colors_box.shape, colors_bg.shape)

100%|██████████| 10/10 [00:15<00:00,  1.59s/it]


(1761694, 3) (19263746, 3)


In [118]:
# Optimize this function if memory errors or taking too long
def covariance_over_N(values):
    V = np.array([np.outer(x, x) for x in values])
    V = np.sum(V, axis=0)
    V = V/len(values)
    return V

In [119]:
# CAREFUL BEFORE RERUNNING THIS: TAKES LONG.
# Find mean and variance for colors_bg and colors_box
mean_box, mean_bg = np.mean(colors_box, axis=0), np.mean(colors_bg, axis=0)
covariance_box, covariance_bg = covariance_over_N(colors_box), covariance_over_N(colors_bg)

In [130]:
# On Test Time, we need to call only this function
def get_likelihood(c, mean, covariance):
    c_minus_mean = c - mean
    exp_value = math.exp((-1./2)*(c_minus_mean.T.dot(np.linalg.inv(covariance)).dot(c_minus_mean)))
    likelihood = (1./(math.pow(math.pi*2,3/2)*math.pow(np.linalg.det(covariance), 1/2)))*exp_value
    return likelihood
    
def predict(c, mean_ball, covariance_ball, mean_bg, covariance_bg):
    """
    Given an RGB pixel value, predict wether the color is of the ball or not.
    This is currentl based on the assumption that the ball color is a Gaussian random
    variable with the estimated mean and covariance provided in the parameters.
    Params:
        c               : 3x1 vector of color values of a pixel
        mean_ball       : 3x1 vector of mean ball color
        covariance_ball : 3x3 matrix of covariance of ball
        mean_bg         : 3x1 vector of mean background color
        covariance_bg   : 3x3 matrix of covariance of background
    Returns:
        is_ball : True if ball else False
    """
    likelihood_ball = get_likelihood(c, mean_ball, covariance_ball)
    likelihood_bg = get_likelihood(c, mean_bg, covariance_bg)
    neuman_test = likelihood_ball/likelihood_bg
    return True if neuman_test >= 1 else False

In [115]:
c = [192, 161, 122]
print(mean_box, covariance_box)
print(mean_bg, covariance_bg)
predict(c, mean_ball=mean_box, mean_bg=mean_bg, covariance_ball=covariance_box, covariance_bg=covariance_bg)

[192.89465878 161.11108058 122.47488043] [[40903.41178718 34572.39653652 26573.13959519]
 [34572.39653652 29438.7280958  22785.92881511]
 [26573.13959519 22785.92881511 17905.51499296]]
[98.19560666 86.25705525 72.70906489] [[15605.9879892  13563.11986968 11125.28629343]
 [13563.11986968 12047.85838429 10141.88069948]
 [11125.28629343 10141.88069948  8919.89747747]]
1.7096645440771482e-06
1.7200743981727961e-06


False

In [132]:
im = Image.open(os.path.join(data_dir, "IMG_0862.jpg"))
pixels = im.load() # create the pixel map
for i in range(im.size[0]): # for every pixel:
    for j in range(im.size[1]):
        if predict(c=pixels[i,j], mean_ball=mean_box, mean_bg=mean_bg, covariance_ball=covariance_box, covariance_bg=covariance_bg):
            pixels[i,j] = (0, 0 ,0)
im.show()
