In [1]:
%matplotlib widget
import array
import imageio
import matplotlib.pyplot as plt
import math
import numpy as np
import random
import sys
import time
from scipy.stats import wasserstein_distance
from PIL import Image

In [2]:
def clamp(val, min_val, max_val):
    return min(max_val, max(val, min_val))

def build(data):
    # Need to build our data structure
    print("Building data structure")
    c2c = {}
    start = time.time()
    xmin = 0; ymin = 0
    xmax = 0; ymax = 0
    i = 0
    while i < len(data):
        x = math.floor(data[i])
        y = math.floor(data[i + 1])
        r = clamp(data[i + 2], 0.0, 1.0)
        g = clamp(data[i + 3], 0.0, 1.0)
        b = clamp(data[i + 4], 0.0, 1.0)
        xmin = math.floor(min(xmin, x))
        xmax = math.ceil(max(xmax, x))
        ymin = math.floor(min(ymin, y))
        ymax = math.ceil(max(ymax, y))        
        if (x, y) not in c2c:
            # Create new tuple
            # The structure is:
            # (X/Y coordinate pair, corresponding and the RGB color value)
            c2c[(x,y)] = [(r, g, b)]
        else:
            # Entry exists. Update bucket
            #print("Append %s to %d %d" % ((r,g,b), x, y))
            c2c[(x,y)].append((r, g, b))
        i += 6
    end = time.time()
    print("Done building data structure (%d)" % (end - start))
    return c2c, xmin, xmax, ymin, ymax

with open("cbox_dump", "rb") as binary_file:
    data = np.fromfile(binary_file, dtype=np.float32)
    print("Size of data array:", len(data))
    print("Number of samples:", int(len(data) / 6))
    c2c, x_min, x_max, y_min, y_max = build(data)
    print("Number of keys:", len(c2c))
    print("Smallest X coordinate:", x_min)
    print("Largest X coordinate:", x_max)
    print("Smallest Y coordinate:", y_min)
    print("Largest Y coordinate:", y_max)

Size of data array: 138215424
Number of samples: 23035904
Building data structure
Done building data structure (224)
Number of keys: 640002
Smallest X coordinate: 0
Largest X coordinate: 800
Smallest Y coordinate: 0
Largest Y coordinate: 799


In [7]:
# A naive implementation of a box filter
def box_filter(x_coord, y_coord, samples):
    # Iterate over all samples and average
    # each channels values
    r = 0.0
    g = 0.0
    b = 0.0
    num_samples = len(samples)
    for sample in samples:
        r += sample[0]
        g += sample[1]
        b += sample[2]
    # Average up each channel and convert to 8-bit colorspace
    # Also gamma correcting
    gamma = 1 / 2.2
    r_avg = int(((r / num_samples) ** gamma) * 255)
    b_avg = int(((b / num_samples) ** gamma) * 255)
    g_avg = int(((g / num_samples) ** gamma) * 255)
    return (r_avg, g_avg, b_avg) 

def reconstruct(filename):
    output = Image.new("RGB", (x_max, y_max))
    pixels = output.load()
    for x_coord in range(x_min + 1, x_max - 1):
        for y_coord in range(y_min + 1, y_max - 1):
            c1 = c2c[(x_coord - 1, y_coord - 1)]
            c2 = c2c[(x_coord, y_coord - 1)]
            c3 = c2c[(x_coord + 1, y_coord - 1)]
            c4 = c2c[(x_coord - 1, y_coord)]
            c5 = c2c[(x_coord + 1, y_coord)]
            c6 = c2c[(x_coord - 1, y_coord + 1)]
            c7 = c2c[(x_coord, y_coord + 1)]
            c8 = c2c[(x_coord + 1, y_coord + 1)]
            samples = c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8
            pixels[x_coord, y_coord] = box_filter(x_coord, y_coord, samples)
    output.save(filename)

recon = True
if recon:
    start = time.time()
    reconstruct("reconstructed.png")
    end = time.time()
    print("Finished (%d)" % (end - start))

Finished (43)


In [8]:
def create_hist(x, y):
    # Grab intensity values around given pixel
    # Should/Could adjust this to be a more random
    # distribution.
    c1 = c2c[(x - 1, y - 1)]
    c2 = c2c[(x, y - 1)]
    c3 = c2c[(x + 1, y - 1)]
    c4 = c2c[(x - 1, y)]
    c5 = c2c[(x + 1, y)]
    c6 = c2c[(x - 1, y + 1)]
    c7 = c2c[(x, y + 1)]
    c8 = c2c[(x + 1, y + 1)]
    samples = c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8
    hist, bins = np.histogram(samples, 256, [0,256], density=True)
    return ((x, y), hist, bins)

def onclick(event):
    x, y = math.floor(event.xdata), math.floor(event.ydata)
    # Generate random coordinates and create histograms
    _, hist_a, bins_a = create_hist(x, y)
    pixel_boundary = 20
    num_candidates = 10
    histograms = []
    for i in range(num_candidates):
        # Can't click pixels that are too close to image borders
        rand_x = random.randint(x - pixel_boundary, x + pixel_boundary)
        rand_y = random.randint(y - pixel_boundary, y + pixel_boundary)
        histograms.append(create_hist(rand_x, rand_y))
    
    # Calculate Wasserstein distance from every sampled pixel to our clicked pixel
    distances = [wasserstein_distance(hist_a, hist[1]) for hist in histograms]
    wd_idx = np.argmin(distances)
    _, hist_b, bins_b = histograms[wd_idx]

    # Plot two histograms
    hist_ax_a.clear()
    hist_ax_a.bar(bins_a[:-1], hist_a, width=np.diff(bins_a), ec="k", align="edge")
    hist_ax_a.set_title("Recon. Pixel Histogram")
    
    hist_ax_b.clear()
    hist_ax_b.bar(bins_b[:-1], hist_b, width=np.diff(bins_b), ec="k", align="edge")
    hist_ax_b.set_title("Histogram Of Pixel With Smallest WD")
    
    # Plot pixels that were picked
    for coordinate, _, _ in histograms:
        r_axis.scatter(coordinate[0], coordinate[1], c='blue', s=0.5)
        r_axis1.scatter(coordinate[0], coordinate[1], c='blue', s=0.5)
    # Plot clicked and chosen pixels in different color
    r_axis.scatter(x, y, c='red', s=0.5)
    r_axis1.scatter(x, y, c='red', s=0.5)
    r_axis.scatter(histograms[wd_idx][0][0], histograms[wd_idx][0][1], c='white', s=0.5)
    r_axis1.scatter(histograms[wd_idx][0][0], histograms[wd_idx][0][1], c='white', s=0.5)
    r_fig.suptitle("Wasserstein Distance: %f" % (distances[wd_idx]))
    
    # Plot color values of clicked and selected pixels
    p_axis.clear()
    norm_rgb = img[y][x] / 255
    p_axis.scatter(1, 1, c=[norm_rgb], s=500.0)
    
    s_x, s_y = histograms[wd_idx][0]
    norm_rgb = r_img[s_y][s_x] / 255
    
    p_axis.scatter(1.5, 1, c=[norm_rgb], s=500.0)
    p_fig.suptitle("Clicked RGB: %s - Selected RGB: %s" % (img[y][x], r_img[s_y][s_x]))
    plt.show()
    
r_fig, (r_axis, r_axis1) = plt.subplots(1, 2, dpi=150, sharey=True)
img = imageio.imread("reconstructed_cbox.png")
r_img = imageio.imread("reconstructed.png")
imgplot = r_axis.imshow(img)
imgplot1 = r_axis1.imshow(r_img)
r_fig.canvas.mpl_connect("button_press_event", onclick)

# Setup plot for histograms for two corresponding images
hist_fig, (hist_ax_a, hist_ax_b) = plt.subplots(1, 2, sharey=True)

# Setup plot for displaying pixel color values
p_fig, p_axis = plt.subplots()

FigureCanvasNbAgg()

FigureCanvasNbAgg()

FigureCanvasNbAgg()

In [43]:
# Some constants
patch_w = 7
pm_iters = 5
rs_max = sys.maxsize
use_wd = False
def dist(a, b, ax, ay, bx, by, cutoff=sys.maxsize):
    ans = 0
    for dy in range(0, patch_w):
        arow = a[ax + dy]
        brow = b[bx + dy]
        for dx in range(0, patch_w):
            ac = arow[dx]
            bc = brow[dx]
            print("ac", ac)
            print("bc", bc)
            dr = (ac[0] & 255) - (bc[0] & 255);
            dg = ((ac[1] >> 8) & 255) - ((bc[1] >> 8) & 255);
            db = (ac[2] >> 16) - (bc[2] >> 16);
            ans += dr * dr + dg * dg + db * db
            print("ar: %d ag: %d ab: %d" % ((ac[0]&255), ((ac[1]>>8)&255), (ac[2]>>16)))
            print("br: %d bg: %d bb: %d" % ((bc[0]&255), ((bc[1]>>8)&255), (bc[2]>>16)))
            print("dr: %d dg: %d db: %d" % (dr, dg, db))
            print("ans", ans)
            raise KeyboardInterrupt
        if ans >= cutoff:
            return cutoff
    return ans

def dist_wd(a, b, ax, ay, bx, by, cutoff=sys.maxsize):
    print("Implement")

def improve_guess(a, b, ax, ay, xbest, ybest, dbest, bx, by):
    # Put a switch in here to determine which distance
    # function to use
    d = 0.0
    if use_wd:
        d = dist_wd(a, b, ax, ay, bx, by, dbest)
    else:
        d = dist(a, b, ax, ay, bx, by, dbest)
    if d < dbest:
        dbest = d
        xbest = bx
        ybest = by
        print("Improved guess: (%d, %d, %d)" % (dbest, xbest, ybest))
    return dbest, xbest, ybest

def xy_to_int(x, y):
    return (int(y) << 12) | int(x)

def int_to_x(v):
    return int(v[0]) & ((1 << 12) - 1)

def int_to_y(v):
    return int(v[0]) >> 12

def patchmatch(a, b, ann, annd):
    # Get effective width and height of images
    aew = a.shape[1] - patch_w + 1
    aeh = a.shape[0] - patch_w + 1
    bew = b.shape[1] - patch_w + 1
    beh = b.shape[0] - patch_w + 1
    for ay in range(0, aeh):
        for ax in range(0, aew):
            #bx = np.random.randint(sys.maxsize) % bew
            #by = np.random.randint(sys.maxsize) % beh
            bx = 607
            by = 460
            print("bx", bx)
            print("by", by)
            ann[ay][ax] = xy_to_int(bx, by)
            if use_wd:
                annd[ay][ax] = dist_wd(a, b, ax, bx, by)
            else:
                annd[ay][ax] = dist(a, b, ax, ay, bx, by)
    #for iter in range(0, pm_iters):
    #for iter in range(0, pm_iters):
    #    ystart = 0; yend = aeh; ychange = 1
    #    xstart = 0; xend = aew; xchange = 1
    #    if iter % 2 == 1:
    #        xstart = xend - 1; xend = -1; xchange = -1
    #        ystart = yend - 1; yend = -1; ychange = -1
    #    ay = ystart
    #    ax = xstart
    #    while ay != yend:
    #        while ax != xend:
    #            v = ann[ay][ax]
    #            xbest = int_to_x(v)
    #            ybest = int_to_y(v)
    #            dbest = annd[ay][ax]
    #            
    #            # Improve current guess
    #            if ax - xchange < aew:
    #                vp = ann[ay][ax - xchange]
    #                xp = int_to_x(vp) + xchange
    #                yp = int_to_y(vp)
    #                if xp < bew:
    #                    dbest, xbest, ybest = improve_guess(a, b, ax, ay, xbest, ybest, dbest, xp, yp)
    #                    
    #            if ay - ychange < aeh:
    #                vp = ann[ay - ychange][ax]
    #                xp = int_to_x(vp)
    #                yp = int_to_y(vp) + ychange
    #                if yp < beh:
    #                    dbest, xbest, ybest = improve_guess(a, b, ax, ay, xbest, ybest, dbest, xp, yp)
    #                
    #            # Random search
    #            rs_start = rs_max
    #            if rs_start > max(b.shape[0], b.shape[1]):
    #                rs_start = max(b.shape[0], b.shape[1])
    #            mag = rs_start
    #            while mag >= 1:
    #                xmin = max(xbest - mag, 0); xmax = min(xbest + mag + 1, bew)
    #                ymin = max(ybest - mag, 0); ymax = min(ybest + mag + 1, beh)
    #                xp = xmin + np.random.rand() % (xmax - xmin)
    #                yp = ymin + np.random.rand() % (ymax - ymin)
    #                dbest, xbest, ybest = improve_guess(a, b, ax, ay, xbest, ybest, dbest, xp, yp)
    #                mag /= 2
    #            ann[ay][ax] = xy_to_int(xbest, ybest)
    #            annd[ay][ax] = dbest
    #            ax += xchange
    #        ay += ychange

def save_ann(bitmap):
    img = Image.new("RGB", (bitmap.shape[0], bitmap.shape[1]))
    out = img.load()
    for x_coord in range(1, bitmap.shape[0] - 1):
        for y_coord in range(1, bitmap.shape[1] - 1):
            data = bitmap[x_coord][y_coord]
            r = (data >> 16) & 255
            g = (data >> 8) & 255
            b = data & 255
            out[x_coord, y_coord] = (r, g, b)
    img.save("ann.jpg")
    print("ANN saved")            

def save_annd(bitmap):
    img = Image.new("RGB", (bitmap.shape[0], bitmap.shape[1]))
    out = img.load()
    for x_coord in range(1, bitmap.shape[0] - 1):
        for y_coord in range(1, bitmap.shape[1] - 1):
            data = bitmap[x_coord][y_coord]
            r = (data >> 16) & 255
            g = (data >> 8) & 255
            b = data & 255
            out[x_coord, y_coord] = (r, g, b)
    img.save("annd.jpg")
    print("ANND saved")

f = open("reconstructed_cbox.png", "rb")
i = np.fromfile(f, dtype=np.ubyte)
j = np.array(i, dtype=np.int)
    
a = imageio.imread("reconstructed_cbox.png")
b = imageio.imread("reconstructed_cbox1.png")
img = np.array(a, dtype=np.int)
img1 = np.array(b, dtype=np.int)

ann = np.zeros((img.shape[0], img.shape[1], 1), dtype=int)
annd = np.zeros((img.shape[0], img.shape[1], 1), dtype=int)

np.random.seed(0)
patchmatch(img, img1, ann, annd)

#save_ann(ann)
save_annd(annd)

(799965,)
102
bx 607
by 460
ac [0 0 0]
bc [0 0 0]
ar: 0 ag: 0 ab: 0
br: 0 bg: 0 bb: 0
dr: 0 dg: 0 db: 0
ans 0


KeyboardInterrupt: 