# Statistical Pattern Recognition - Exercise 10: Inference in graphical models

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.pyplot import imshow
import time
%matplotlib inline


## $\star$ $\star$  Part 1: Message passing for disparity estimation

Implement the message passing method in the disparity estimation setting. Each pixel can take 16 labels (disparities in the Tsukuba pair range from 0 to 15).

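For the unary cost, use the Euclidean distance between the color of a pixel in the left image and the color of the corresponding pixel in the right image, i.e. the pixel shifted by the candidate disparity. For each pixel in the left image, you must compute 16 costs, one for each of the 16 disparity options.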

Consider the image as an undirected graph in which each pixel is connected only to its direct left and right neighbors, so every image row forms a chain. For the pairwise potential, use the Potts model (cost `0` if the two disparities are equal, `LAMBDA` otherwise).

Run the message passing algorithm to find the solution with the highest probability (equivalently, the lowest total cost). Visualize the result and experiment with different values of the parameter `LAMBDA`.


### Load and visualize input images


In [None]:
def _load_rgb(path):
    """Load an image file as an RGB numpy array of shape (h, w, 3).

    The ``with`` block releases the file handle as soon as the pixels have
    been copied into the array.
    """
    with Image.open(path) as img:
        return np.array(img.convert("RGB"))


imgL = _load_rgb("../data/tsukubaL.ppm")
imgR = _load_rgb("../data/tsukubaR.ppm")
# Stereo matching below compares pixels row by row, so both views must align.
if imgL.shape != imgR.shape:
    raise ValueError(
        "left/right image shapes differ: {} vs {}".format(imgL.shape, imgR.shape)
    )


In [None]:
# Display each input view in its own figure. Calling plt.show() after every
# image means both render, not only the first, on non-inline backends as well.
for title, image in (("ImageL", imgL), ("ImageR", imgR)):
    plt.title(title)
    plt.imshow(image)
    plt.axis('off')
    plt.show()


### Disparity estimation via message passing


In [None]:
# Number of disparity labels per pixel (labels 0 .. MAX_DISP - 1).
MAX_DISP = 16
# Penalty charged by the Potts model when neighbouring labels differ.
LAMBDA = 1e2

class Timer:
    """
    Context manager that reports how long its ``with`` block took.

    On exit it prints the elapsed time in seconds. Exceptions raised inside
    the block are not suppressed (``__exit__`` returns ``None``).

    Attributes:
        start: ``time.perf_counter()`` reading taken on entry (``None`` until entered).
        end: ``time.perf_counter()`` reading taken on exit (``None`` until exited).
    """

    def __init__(self):
        self.start = None
        self.end = None

    @property
    def elapsed(self):
        """Seconds between entry and exit, or ``None`` if the block has not finished."""
        if self.start is None or self.end is None:
            return None
        return self.end - self.start

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.end = time.perf_counter()
        print('elapsed secs: {:.2f}'.format(self.elapsed))


def compute_unary_costs(imgL, imgR, max_disp=MAX_DISP):
    """
    Computes unary costs for disparity matching between two images.

    The cost of assigning disparity ``d`` to left-image pixel ``(y, x)`` is the
    Euclidean distance between its color and the color of right-image pixel
    ``(y, x - d)``. Left pixels with ``x < d`` have no partner inside the right
    image. For those pixels, the right image's first column is used instead
    (edge replication).

    Args:
        imgL: Left image as numpy array with shape (h,w,3).
        imgR: Right image as numpy array with shape (h,w,3).
        max_disp: Number of considered disparities (labels ``0 .. max_disp - 1``).

    Returns:
        Unary costs for all considered disparities as numpy array of shape (h,w,max_disp).
    """
    h, w, _ = imgL.shape
    # Work in floating point: subtracting raw uint8 images would silently wrap around.
    left = np.asarray(imgL, dtype=np.float64)
    right = np.asarray(imgR, dtype=np.float64)
    cost = np.zeros((h, w, max_disp))
    columns = np.arange(w)
    for d in range(max_disp):
        # Right-image column matched by each left-image column, clamped at the border.
        src_cols = np.clip(columns - d, 0, None)
        cost[:, :, d] = np.linalg.norm(left - right[:, src_cols], axis=-1)
    return cost


def potts_model(d1, d2, l=LAMBDA):
    """
    Return the Potts pairwise cost for two disparity labels.

    Identical labels cost nothing; any mismatch costs the constant penalty ``l``.
    """
    return 0. if d1 == d2 else l


def compute_msg_fwd(unary_costs, max_disp=MAX_DISP, l=LAMBDA):
    """
    Computes pairwise costs for disparity matching via message passing in the forward direction.

    Each image row is treated as a chain linked by Potts potentials with penalty ``l``.
    ``msgs[y, x, d]`` is the minimal total cost (unary plus pairwise) of pixels
    ``0 .. x-1`` in row ``y``, given that pixel ``x`` takes disparity ``d``. It is
    computed by the min-sum recursion

        msgs[y, x, d] = min_d' (unary[y, x-1, d'] + msgs[y, x-1, d'] + potts(d', d)),

    with ``msgs[:, 0, :] == 0``.

    Args:
        unary_costs: Unary costs as numpy array of shape (h,w,max_disp).
        max_disp: Number of disparity labels; must match ``unary_costs.shape[-1]``.
        l: Potts penalty for neighbouring pixels with different disparities.

    Returns:
        Pairwise costs in the forward direction for all considered disparities as numpy array
        of shape (h,w,max_disp).

    Raises:
        ValueError: If the label dimension of ``unary_costs`` differs from ``max_disp``.
    """
    h, w, n_labels = unary_costs.shape
    if n_labels != max_disp:
        raise ValueError(
            "unary_costs has {} labels, expected max_disp={}".format(n_labels, max_disp)
        )
    msgs = np.zeros((h, w, max_disp))
    for x in range(1, w):
        # Cost of the left neighbour (and everything before it) for each of its labels.
        belief = unary_costs[:, x - 1] + msgs[:, x - 1]
        # For the Potts model, min_d' belief(d') + potts_model(d', d) reduces to
        # min(belief(d), min(belief) + l). This is O(D) per pixel instead of O(D^2).
        msgs[:, x] = np.minimum(belief, belief.min(axis=-1, keepdims=True) + l)
    return msgs


def compute_msg_bwd(unary_costs, max_disp=MAX_DISP, l=LAMBDA):
    """
    Computes pairwise costs for disparity matching via message passing in the backward direction.

    Each image row is treated as a chain linked by Potts potentials with penalty ``l``.
    ``msgs[y, x, d]`` is the minimal total cost (unary plus pairwise) of pixels
    ``x+1 .. w-1`` in row ``y``, given that pixel ``x`` takes disparity ``d``. It is
    computed by the min-sum recursion

        msgs[y, x, d] = min_d' (unary[y, x+1, d'] + msgs[y, x+1, d'] + potts(d', d)),

    with ``msgs[:, w-1, :] == 0``.

    Args:
        unary_costs: Unary costs as numpy array of shape (h,w,max_disp).
        max_disp: Number of disparity labels; must match ``unary_costs.shape[-1]``.
        l: Potts penalty for neighbouring pixels with different disparities.

    Returns:
        Pairwise costs in the backward direction for all considered disparities as numpy array
        of shape (h,w,max_disp).

    Raises:
        ValueError: If the label dimension of ``unary_costs`` differs from ``max_disp``.
    """
    h, w, n_labels = unary_costs.shape
    if n_labels != max_disp:
        raise ValueError(
            "unary_costs has {} labels, expected max_disp={}".format(n_labels, max_disp)
        )
    msgs = np.zeros((h, w, max_disp))
    for x in range(w - 2, -1, -1):
        # Cost of the right neighbour (and everything after it) for each of its labels.
        belief = unary_costs[:, x + 1] + msgs[:, x + 1]
        # For the Potts model, min_d' belief(d') + potts_model(d', d) reduces to
        # min(belief(d), min(belief) + l). This is O(D) per pixel instead of O(D^2).
        msgs[:, x] = np.minimum(belief, belief.min(axis=-1, keepdims=True) + l)
    return msgs


def compute_disparity(imgL, imgR, max_disp=MAX_DISP):
    """
    Computes the disparity between two images via message passing.

    The per-pixel cost of each disparity is the unary cost plus the forward and
    backward messages. Each pixel is assigned the disparity that minimises this
    cost. Every stage is timed and reported with ``Timer``.

    Args:
        imgL: Left image as numpy array with shape (h,w,3).
        imgR: Right image as numpy array with shape (h,w,3).
        max_disp: Number of considered disparities, forwarded to the cost and
            message helpers.

    Returns:
        Disparity between two images as numpy array with shape (h,w).
    """
    imgL = imgL.astype(np.float32)
    imgR = imgR.astype(np.float32)

    print("Precomputing data costs:")
    with Timer():
        unary_costs = compute_unary_costs(imgL, imgR, max_disp=max_disp)

    print("Precomputing forward messages:")
    with Timer():
        msg_fwd = compute_msg_fwd(unary_costs, max_disp=max_disp)
    print("Precomputing backward messages:")
    with Timer():
        msg_bwd = compute_msg_bwd(unary_costs, max_disp=max_disp)

    print("Computing disparity map:")
    with Timer():
        cost = unary_costs + msg_fwd + msg_bwd
        disparity_map = cost.argmin(axis=-1)
    return disparity_map


def show_disparity_map(disparity_map):
    """Plot ``disparity_map`` as an image with a colour bar and hidden axes."""
    image = plt.imshow(disparity_map)
    plt.colorbar(image)
    plt.axis("off")


### Compute disparity map without belief propagation


In [None]:
# Baseline without smoothing: every pixel independently takes its cheapest disparity.
data_cost = compute_unary_costs(imgL, imgR)
disparity_map = np.argmin(data_cost, axis=-1)
show_disparity_map(disparity_map)


### Now compute disparity map with belief propagation


In [None]:
# Disparity from unary costs combined with forward and backward messages
# (compute_disparity). This overwrites the baseline map from the previous cell.
disparity_map = compute_disparity(imgL, imgR)
show_disparity_map(disparity_map)


### Some things to explore further

* Vectorize the solution and measure the speedup, or even implement it in C/C++.
* Play around with the regularization term.
* Play around with the data term.
