Thais Lima de Sousa

# Efficient Graph-Based Image Segmentation
## Felzenszwalb and Huttenlocher's algorithm

In [2]:
from vpi.io import *
from vpi.filters import convolve, normalize
import numpy as np
import sys
import networkx as nx

In [21]:
# gaussian filter for image smoothing

def create_gauss_kernel(sigma):
    """Return a normalized 1-D Gaussian kernel as a ``(1, n)`` row vector.

    The taps sit at the integer offsets ``-r .. r`` with ``r = ceil(4 * sigma)``.
    The kernel therefore always has odd length ``n = 2r + 1``, is symmetric,
    and peaks at its middle tap. The taps sum to 1.

    Raises:
        ValueError: if ``sigma`` is not positive.
    """
    if sigma <= 0:
        raise ValueError('sigma must be positive, got %r' % (sigma,))
    # An integer half-width keeps the offsets symmetric about 0.
    # np.arange(-4*sigma, 4*sigma + 1) was off-centre for non-integer sigma.
    radius = int(np.ceil(4 * sigma))
    x = np.arange(-radius, radius + 1, dtype=float)
    # exp(-x^2 / (2 sigma^2)). The parentheses matter: the old expression
    # divided by 2 and then multiplied by sigma^2. The constant prefactor is
    # dropped because it cancels in the normalization below.
    k = np.exp(-x ** 2 / (2.0 * sigma ** 2))
    k = k / k.sum()
    return k.reshape(1, -1)

def smooth(img, sigma=0.8):
    """Blur ``img`` with a separable Gaussian of standard deviation ``sigma``.

    The image is convolved first with the ``(1, n)`` kernel from
    ``create_gauss_kernel`` and then with its ``(n, 1)`` transpose.
    """
    row_kernel = create_gauss_kernel(sigma)
    col_kernel = np.transpose(row_kernel)
    blurred_once = convolve(img, row_kernel)
    return convolve(blurred_once, col_kernel)

In [5]:
# disjoint-set forest with union by rank and path compression

class UF:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Uses union by rank and full path compression. ``num`` holds the current
    number of disjoint sets.
    """

    # Class-level default kept for backward compatibility; every instance
    # overrides it in __init__.
    num = 0

    def __init__(self, n):
        # Store the count on the instance. The original `num = n` only
        # created a local variable.
        self.num = n
        self.parent = list(range(n))
        self.rank = [0] * n
        # The sizes live under `_size` so they do not shadow the size() method.
        self._size = [1] * n

    def find(self, v):
        """Return the root of the set containing ``v``, compressing the path."""
        root = v
        while root != self.parent[root]:
            root = self.parent[root]
        # Point every node on the walked path directly at the root.
        while v != root:
            nxt = self.parent[v]
            self.parent[v] = root
            v = nxt
        return root

    def join(self, x, y):
        """Merge the sets containing ``x`` and ``y``; a no-op if already merged."""
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return
        if self.rank[x] > self.rank[y]:
            self.parent[y] = x
            self._size[x] += self._size[y]
        else:
            self.parent[x] = y
            self._size[y] += self._size[x]
            if self.rank[x] == self.rank[y]:
                self.rank[y] += 1
        self.num -= 1

    def size(self, x):
        """Return the number of elements in the set containing ``x``."""
        return self._size[self.find(x)]
        

In [None]:
# graph segmentation

def threshold(size, c):
    """Return the merge tolerance ``c / size`` for a component of ``size`` vertices.

    The tolerance shrinks as the component grows. As a result, ``c`` mostly
    influences how readily small regions are merged.
    """
    tolerance = c / size
    return tolerance

def segment_graph(G, k):
    """Merge graph vertices into components with the Felzenszwalb-Huttenlocher criterion.

    Edges are visited in non-decreasing weight order. An edge joins its two
    components when its weight is no larger than both components' current
    thresholds. After a join, the merged component's threshold becomes that
    edge's weight plus ``threshold(size, k)``.

    Args:
        G: weighted networkx graph; each edge weight is read from its
           ``'weight'`` attribute. Node labels are used directly as ``UF``
           indices, so they are assumed to be the integers ``0 .. len(G)-1``.
        k: scale constant passed to ``threshold`` as ``c``.

    Returns:
        The ``UF`` forest describing the resulting components.
    """
    num_vertices = len(G)
    u = UF(num_vertices)

    # Graph.edges(data='weight') yields (u, v, weight) triples. edges_iter no
    # longer exists in networkx 2.x. sorted() returns the list, whereas
    # list.sort() returns None.
    edges = sorted(G.edges(data='weight'), key=lambda e: e[2])

    # Every vertex starts as a singleton component.
    thresholding = [threshold(1.0, k) for _ in range(num_vertices)]

    # Iterate over all edges, not over range(num_vertices).
    for v1, v2, w in edges:
        a = u.find(v1)
        b = u.find(v2)
        if a != b and w <= thresholding[a] and w <= thresholding[b]:
            u.join(a, b)
            a = u.find(a)
            thresholding[a] = w + threshold(u.size(a), k)

    return u
    
    

In [11]:
def diff(r, g, b, x1, y1, x2, y2):
    """Return the Euclidean RGB distance between pixels (x1, y1) and (x2, y2).

    ``x`` is the column and ``y`` the row, so each channel array is indexed as
    ``[y, x]``. This matches segment_image, where x ranges over the image
    width. Channel values are converted to float before subtracting, so
    unsigned-integer inputs cannot wrap around.
    """
    dr = float(r[y1, x1]) - float(r[y2, x2])
    dg = float(g[y1, x1]) - float(g[y2, x2])
    db = float(b[y1, x1]) - float(b[y2, x2])
    return np.sqrt(dr * dr + dg * dg + db * db)

def segment_image(im, sigma, k, min_size, num_ccs):
    """Segment an RGB image with the Felzenszwalb-Huttenlocher algorithm.

    Args:
        im: image array indexed ``[row, col, channel]``; channels 0, 1 and 2
            are used as R, G and B.
        sigma: standard deviation of the Gaussian pre-smoothing.
        k: scale constant of the merge threshold ``k / size``. A larger k
            allows more merges.
        min_size: during post-processing, components smaller than this are
            merged into a neighbour along the cheapest connecting edge.
        num_ccs: kept for signature compatibility. A Python int cannot be
            updated in place, so the component count is returned instead.

    Returns:
        ``(output, num_ccs)``: ``output`` is an ``(H, W, 3)`` uint8 image in
        which each component is painted one random colour, and ``num_ccs`` is
        the number of components.
    """
    H, W = im.shape[0:2]
    # Honour the caller's sigma; previously smooth's default was always used.
    r = smooth(im[:, :, 0], sigma)
    g = smooth(im[:, :, 1], sigma)
    b = smooth(im[:, :, 2], sigma)

    G = _build_pixel_graph(r, g, b, H, W)
    u = segment_graph(G, k)
    _merge_small_components(G, u, min_size)
    return _colorize_components(u, H, W)


def _build_pixel_graph(r, g, b, H, W):
    """Build the 8-connected pixel graph; node ``y*W + x`` is pixel (x, y)."""
    G = nx.Graph()
    # Add every pixel explicitly so pixels without edges (e.g. a 1x1 image)
    # still get a node and len(G) == H*W.
    G.add_nodes_from(range(H * W))
    for y in range(H):
        for x in range(W):
            p = y * W + x
            # Right, down, down-right and up-right neighbours. The edges are
            # undirected, so this covers all 8 neighbours without duplicates.
            if x < W - 1:
                G.add_edge(p, p + 1, weight=diff(r, g, b, x, y, x + 1, y))
            if y < H - 1:
                G.add_edge(p, p + W, weight=diff(r, g, b, x, y, x, y + 1))
            if x < W - 1 and y < H - 1:
                G.add_edge(p, p + W + 1, weight=diff(r, g, b, x, y, x + 1, y + 1))
            if x < W - 1 and y > 0:
                G.add_edge(p, p - W + 1, weight=diff(r, g, b, x, y, x + 1, y - 1))
    return G


def _merge_small_components(G, u, min_size):
    """Walk edges by increasing weight, joining any component smaller than ``min_size``."""
    for v1, v2, _w in sorted(G.edges(data='weight'), key=lambda e: e[2]):
        a = u.find(v1)
        b = u.find(v2)
        if a != b and (u.size(a) < min_size or u.size(b) < min_size):
            u.join(a, b)


def _colorize_components(u, H, W):
    """Paint each component a random colour; return ``(image, component count)``."""
    labels = np.array([u.find(p) for p in range(H * W)], dtype=np.intp).reshape(H, W)
    palette = np.random.randint(0, 256, size=(H * W, 3)).astype(np.uint8)
    output = palette[labels]
    # Count the distinct roots directly rather than relying on a counter in UF.
    num_ccs = len(np.unique(labels))
    return output, num_ccs
    
    

In [None]:
class InputException(Exception):
    """Error raised for malformed command-line input.

    ``str()`` of the exception is exactly the stored ``msg``.
    """

    def __init__(self, msg):
        # Text reported to the user, e.g. a usage line.
        self.msg = msg

    def __str__(self):
        text = self.msg
        return text


# read input

def main():
    """Command-line entry point: segment the image named on the command line.

    Usage: ``prog sigma k min_size input_name output_name``.

    Returns:
        The coloured segmentation produced by ``segment_image``.

    Raises:
        InputException: if the number of arguments is wrong.
    """
    if len(sys.argv) != 6: raise InputException('usage: %s sigma k min_size input_name output_name' % sys.argv[0])
    sigma = float(sys.argv[1])
    k = float(sys.argv[2])
    min_size = int(sys.argv[3])
    # Renamed from `input` to avoid shadowing the builtin.
    image = read_image(sys.argv[4])
    output_name = sys.argv[5]

    num_ccs = 0
    output, num_ccs = segment_image(image, sigma, k, min_size, num_ccs)
    print('got %d components' % num_ccs)
    # TODO(review): save `output` to output_name. No image writer from vpi.io
    # is visible in this notebook, so confirm its name before wiring it in.
    return output

        
        

if __name__ == '__main__':
    # NOTE: inside a Jupyter/IPython kernel __name__ is also '__main__'.
    # Executing this cell therefore calls main() with the kernel's own
    # sys.argv, which will normally fail the argument-count check.
    main()

In [23]:
# f = read_image("img1.png")
# print(smooth(f[:,:,0]))

# Scratch check: edges come back as (u, v, weight) triples, and sorting by
# the third element orders them by weight.
T = nx.Graph()
T.add_edge(0, 1, weight=1)
T.add_edge(1, 2, weight=2)
T.add_edge(0, 2, weight=0.5)
# Graph.edges_iter was removed in networkx 2.0. edges(data='weight') yields
# the same (u, v, weight) triples.
edges = list(T.edges(data='weight'))

for e in edges:
    print(e)

# list.sort() sorts in place and returns None, so its result is not bound.
edges.sort(key=lambda e: e[2])

print('sort')
for e in edges:
    print(e)

(0, 1, 1)
(0, 2, 0.5)
(1, 2, 2)
sort
(0, 2, 0.5)
(0, 1, 1)
(1, 2, 2)
