In [None]:
import math
import random

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import filters, gaussian_filter, zoom
from scipy.signal import savgol_filter

In [None]:
SIZE = 32  # Side length of the square occlusion / light grid, in cells.

# Seed both RNGs used in this notebook (numpy for the scene and model weights,
# `random` for the per-epoch shuffle) so Restart & Run All is reproducible.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

In [None]:
# Every cell is independently occluded (True) with probability 0.2.
occlusion_map = np.random.choice([False, True], (SIZE, SIZE), p=[0.8, 0.2])

In [None]:
# Occluded cells are True, so they render bright in the default colormap.
plt.imshow(occlusion_map)
plt.title("Occlusion map (bright = occluded)");

In [None]:
NUM_LIGHTS = 10

# light_map marks light cells; lights lists their (row, col) indices.
light_map = np.zeros(shape=(SIZE, SIZE), dtype=bool)
lights = []

# Rejection-sample NUM_LIGHTS distinct, unoccluded cells. The light_map check keeps
# two lights from landing on the same cell (which would double-count its irradiance).
assert np.count_nonzero(~occlusion_map) >= NUM_LIGHTS, "not enough free cells for the lights"
while len(lights) < NUM_LIGHTS:
    row, col = np.random.randint(low=0, high=SIZE, size=2)
    if occlusion_map[row, col] or light_map[row, col]:
        continue
    light_map[row, col] = True
    lights.append((row, col))

In [None]:
# Light cells are True, so they render bright in the default colormap.
plt.imshow(light_map)
plt.title("Light positions (bright = light)");

In [None]:
def ray_occlusion_test(start, end, occlusion=None):
    """Return True if the segment from `start` to `end` is NOT blocked by an occluded cell.

    Coordinates are (x, y) in cell units: x indexes columns and y indexes rows, so
    point (x, y) lies in cell ``grid[floor(y), floor(x)]``. The segment is walked one
    column at a time along its major axis (x and y are swapped for steep segments so
    that |slope| <= 1). Within each column the minor coordinate is sampled at the
    segment's start/end or the column centre, plus the column's entry and exit edges.
    Samples falling outside the grid are ignored (treated as unoccluded).

    Args:
        start: (x, y) start point.
        end: (x, y) end point.
        occlusion: 2-D boolean grid, True = occluded. Defaults to the global
            ``occlusion_map``, looked up at call time.

    Returns:
        True if no sampled cell is occluded (clear line of sight), False otherwise.
        A zero-length segment returns True.
    """
    grid = occlusion_map if occlusion is None else occlusion
    n_rows, n_cols = grid.shape

    sx, sy = start
    ex, ey = end
    # Walk along the major axis: for steep segments swap x and y.
    steep = abs(ey - sy) > abs(ex - sx)
    if steep:
        sx, sy = sy, sx
        ex, ey = ey, ex
    # Always walk in increasing (swapped) x.
    if sx > ex:
        sx, ex = ex, sx
        sy, ey = ey, sy
    dx = ex - sx
    dy = ey - sy
    # After the swaps dx >= |dy| >= 0, so dx == 0 only for a zero-length segment.
    # Check before dividing to avoid ZeroDivisionError.
    if dx == 0:
        return True
    slope = dy / dx

    def y_at(x):
        # Minor coordinate of the segment at major coordinate x.
        return sy + slope * (x - sx)

    def blocked(ix, iy):
        # Map walk coordinates back to (row, col) and bounds-check on the right axes.
        row, col = (ix, iy) if steep else (iy, ix)
        return 0 <= row < n_rows and 0 <= col < n_cols and bool(grid[row, col])

    samples = [sx, ex] + [ix + 0.5 for ix in range(math.floor(sx) + 1, math.floor(ex))]
    for x in samples:
        ix = math.floor(x)
        # Endpoint / column-centre sample.
        if blocked(ix, math.floor(y_at(x))):
            return False
        # Entry edge of the column, if the segment starts before it.
        if ix > sx and blocked(ix, math.floor(y_at(ix))):
            return False
        # Exit edge of the column, if the segment continues past it.
        if ix + 1 < ex and blocked(ix, math.floor(y_at(ix + 1))):
            return False
    return True

In [None]:
%%time

RASTER_SIZE = 8         # Pixels per grid cell along each axis.
ATTENUATION = 1         # Numerator of the inverse-square falloff.
TRUNCATION = 1 / 255.0  # Contributions dimmer than one 8-bit grey level are dropped.
# Explicit dtype: the `np.float` alias was removed in NumPy 1.24.
label_map = np.zeros(shape=(RASTER_SIZE * SIZE, RASTER_SIZE * SIZE), dtype=np.float64)

# Light centres as continuous (x, y) grid coordinates (lights store (row, col)).
light_centres = [(light[1] + 0.5, light[0] + 0.5) for light in lights]

# Precompute labels: each pixel centre accumulates the inverse-square irradiance of
# every light it can see, saturating at 1.
for row in range(label_map.shape[0]):
    for col in range(label_map.shape[1]):
        # Continuous position of the pixel centre: cell index + offset within the cell.
        px = col // RASTER_SIZE + (0.5 + (col % RASTER_SIZE)) / RASTER_SIZE
        py = row // RASTER_SIZE + (0.5 + (row % RASTER_SIZE)) / RASTER_SIZE
        value = label_map[row, col]
        for lx, ly in light_centres:
            distance_sq = (px - lx)**2 + (py - ly)**2
            if distance_sq == 0:
                # Pixel centre coincides with the light (possible for odd RASTER_SIZE).
                # Lights are only placed in unoccluded cells, so it is fully lit.
                value = 1
                continue
            irradiance = ATTENUATION / distance_sq
            # Cheap truncation test first; only trace the ray if it could contribute.
            if irradiance < TRUNCATION:
                continue
            if not ray_occlusion_test((px, py), (lx, ly)):
                continue
            value = min(1, value + irradiance)
        label_map[row, col] = value

In [None]:
# Soften the hard shadow edges of the ray-traced labels; this is the training target.
# (scipy.ndimage.filters is a deprecated namespace; use the top-level function.)
blurred_label_map = gaussian_filter(label_map, sigma=2.0)
plt.imshow(blurred_label_map)
plt.title("Blurred irradiance labels");

In [None]:
def make_color_map(pixel_map):
    """Render a per-pixel brightness map as an RGB image annotated with the scene.

    Pixels whose grid cell is occluded are painted red. Each light's cell gets a
    blue square inset by 2 pixels. All other pixels are grey levels of `pixel_map`,
    clipped to [0, 1] and scaled to 0-255 with truncation.

    Args:
        pixel_map: 2-D array of brightness values, one per raster pixel, laid out
            like ``label_map`` (RASTER_SIZE pixels per grid cell).

    Returns:
        uint8 array of shape ``pixel_map.shape + (3,)``.
    """
    pixel_map = np.asarray(pixel_map)
    n_rows, n_cols = pixel_map.shape

    # Grey levels; clip first so out-of-range values cannot wrap around in uint8.
    grey = (np.clip(pixel_map, 0, 1) * 255).astype(np.uint8)
    color_map = np.repeat(grey[:, :, np.newaxis], 3, axis=2)

    # Expand the cell-level occlusion map to pixel resolution and paint it red.
    cell_rows = np.arange(n_rows) // RASTER_SIZE
    cell_cols = np.arange(n_cols) // RASTER_SIZE
    occluded = occlusion_map[cell_rows[:, np.newaxis], cell_cols[np.newaxis, :]]
    color_map[occluded] = [255, 0, 0]

    # Mark each light with a blue square inside its cell.
    for light_row, light_col in lights:
        top, left = RASTER_SIZE * light_row, RASTER_SIZE * light_col
        color_map[top + 2 : top + RASTER_SIZE - 2, left + 2 : left + RASTER_SIZE - 2] = [0, 0, 255]

    return color_map

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(make_color_map(blurred_label_map))
ax.set_title("Target lighting (red = occluded, blue = lights)");

In [None]:
def clamp(arr, max_norm):
    """Rescale `arr` so its L2 (Frobenius for matrices) norm is at most `max_norm`.

    Arrays already within the limit are returned unchanged (same object).
    """
    arr_norm = np.linalg.norm(arr)
    if arr_norm > max_norm:
        return arr / arr_norm * max_norm
    return arr

def cross_entropy(label, prediction, eps=1e-7):
    """Binary cross-entropy between `label` and `prediction` (elementwise).

    `prediction` is clipped to [eps, 1 - eps] so predictions of exactly 0 or 1
    give a finite loss instead of inf/nan from log(0).
    """
    prediction = np.clip(prediction, eps, 1 - eps)
    return -label * np.log(prediction) - (1 - label) * np.log(1 - prediction)

In [None]:
# Define the model:
# - Every grid face has four corner vertices. Each vertex casts a ray to each of the
#   MAX_LIGHTS lights nearest the face; compute_light_ray encodes each ray as a
#   LIGHT_DIMENSION-long feature vector, and the per-light vectors are concatenated.
# - A (RADIANCE_DIMENSION x MAX_LIGHTS * LIGHT_DIMENSION) maps a vertex's ray
#   features onto a radiance embedding.
# - The four corner embeddings are bilinearly interpolated over the face's
#   RASTER_SIZE x RASTER_SIZE pixels using the s-t weights in ST (below).
# - B (RADIANCE_DIMENSION,) maps each pixel's embedding onto a logit; a sigmoid
#   turns the logit into a predicted brightness in [0, 1].
MAX_LIGHTS = 4
RADIANCE_DIMENSION = 16
LIGHT_DIMENSION = 17  # Must equal the length of compute_light_ray's feature list.
BASIS_NUM = 64
BASIS_SIZE = RASTER_SIZE

# Atlas of pixel maps for a face.
# NOTE(review): basis_map and C are not used by the training loop; they are kept so
# the RNG draw order (and hence A and B) is unchanged.
basis_map = 0.1 * (np.random.random(size=(BASIS_NUM, BASIS_SIZE, BASIS_SIZE)) - 0.5)

# Small uniform initialisation in [-0.005, 0.005).
A = 0.01 * (np.random.random(size=(RADIANCE_DIMENSION, MAX_LIGHTS * LIGHT_DIMENSION)) - 0.5)
B = 0.01 * (np.random.random(size=(RADIANCE_DIMENSION,)) - 0.5)
C = 0.01 * (np.random.random(size=(BASIS_NUM, 4 * LIGHT_DIMENSION * MAX_LIGHTS)) - 0.5)

# Labels and example weights.
pixel_labels = blurred_label_map.copy()
# Weight pixels by how much blurring changed their label (shadow / light edges).
# NOTE(review): pixel_weights is not used by the training loop yet.
pixel_weights = np.absolute(label_map - blurred_label_map)**0.5
pixel_weights = 5.0 * np.clip(0.1 + gaussian_filter(pixel_weights, sigma=2.0), 0, 1)

# Learning rate: decreases by ALPHA_DECAY * ALPHA_BASE per loop, floored at ALPHA_FINAL.
ALPHA_BASE = 0.01
ALPHA_DECAY = 0.01
ALPHA_FINAL = 0.00001

# Gradient clipping threshold. NOTE(review): not applied by the training loop.
MAX_GRAD = 2.0

# How many times to loop over each face.
LOOP_COUNT = 100

# Initialize training state.
errors = []
ALPHA = ALPHA_BASE
iteration = 0
# (x, y) indices of every unoccluded face; the training loop shuffles this list.
face_coords = [(x, y) for x in range(SIZE) for y in range(SIZE) if not occlusion_map[y, x]]

# Bilinear weights of the four face corners at every pixel centre:
# ST[k, row, col] for corners k = (x, y), (x+1, y), (x, y+1), (x+1, y+1).
pixel_centres = (0.5 + np.arange(RASTER_SIZE)) / RASTER_SIZE
s_grid = pixel_centres[np.newaxis, :]  # s varies along columns.
t_grid = pixel_centres[:, np.newaxis]  # t varies along rows.
ST = np.stack([
    (1 - s_grid) * (1 - t_grid),
    s_grid * (1 - t_grid),
    (1 - s_grid) * t_grid,
    s_grid * t_grid,
]).astype(np.float32)

In [None]:
%%time
%matplotlib notebook

# Live figures: ax1 shows the model's current prediction, ax2 the learning curve.
fig1, ax1 = plt.subplots(figsize=(8, 8))
fig2, ax2 = plt.subplots(figsize=(8, 3))

# Per-face cache of ray features, keyed by the (x, y) face index.
light_ray_cache = dict()

def compute_light_ray(x, y, lx, ly, dx, dy):
    """Feature vector for the ray from grid vertex (x, y) to a light centred at (lx, ly).

    Features (17 entries, matching LIGHT_DIMENSION), in order:
      [0]      bias, always 1
      [1]      inverse-square falloff 1 / distance**2
      [2:7]    distance < 1, 2, 4, 6, 8 (1/0 flags)
      [7:12]   distance > 1, 2, 4, 6, 8 (1/0 flags)
      [12:16]  occlusion of the four cells touching the vertex, as (row, col):
               (y-1, x-1), (y-1, x), (y, x-1), (y, x); cells outside the grid count
               as occluded (1)
      [16]     1 if ray_occlusion_test reports a clear path to the light, else 0

    `dx` and `dy` are accepted for call compatibility but are currently unused.
    """
    distance = np.linalg.norm([lx - x, ly - y, 0])
    assert distance > 0.01
    n_rows, n_cols = occlusion_map.shape

    def cell(row, col):
        # Valid cell indices are 0..n-1; vertices range 0..n, so bounds-check both sides.
        if 0 <= row < n_rows and 0 <= col < n_cols:
            return occlusion_map[row, col]
        return 1

    return [
        1,
        1.0 / distance**2,
        1 if distance < 1 else 0,
        1 if distance < 2 else 0,
        1 if distance < 4 else 0,
        1 if distance < 6 else 0,
        1 if distance < 8 else 0,
        1 if distance > 1 else 0,
        1 if distance > 2 else 0,
        1 if distance > 4 else 0,
        1 if distance > 6 else 0,
        1 if distance > 8 else 0,
        cell(y - 1, x - 1),
        cell(y - 1, x),
        cell(y, x - 1),
        cell(y, x),
        int(ray_occlusion_test((x, y), (lx, ly))),
    ]


def forward_propagation(x, y):
    """Predict the RASTER_SIZE x RASTER_SIZE brightness patch of face (x, y).

    Returns (L1, L2, L3, L4, R, T, P):
      L1..L4  ray-feature vectors (MAX_LIGHTS * LIGHT_DIMENSION,) for the corners
              (x, y), (x+1, y), (x, y+1), (x+1, y+1), cached in light_ray_cache.
      R       (RASTER_SIZE, RASTER_SIZE, RADIANCE_DIMENSION) embeddings A @ L_k
              bilinearly combined with the weights ST[k].
      T       per-pixel logits R @ B.
      P       sigmoid(T), clipped to [0, 1].
    """
    key = (x, y)
    if key not in light_ray_cache:
        # Corner vertices in ST order, with the direction pair forwarded to compute_light_ray.
        corners = ((x, y, 1, 1), (x + 1, y, -1, 1), (x, y + 1, 1, -1), (x + 1, y + 1, -1, -1))
        rays = tuple(np.zeros(shape=MAX_LIGHTS * LIGHT_DIMENSION, dtype=np.float32) for _ in corners)
        nearest = sorted(lights, key=lambda light: (light[1] - x)**2 + (light[0] - y)**2)[:MAX_LIGHTS]
        width = LIGHT_DIMENSION
        for slot, (light_row, light_col) in enumerate(nearest):
            centre_x, centre_y = light_col + 0.5, light_row + 0.5
            for ray, (vx, vy, dir_x, dir_y) in zip(rays, corners):
                ray[width * slot : width * (slot + 1)] = compute_light_ray(vx, vy, centre_x, centre_y, dir_x, dir_y)
        light_ray_cache[key] = rays

    L1, L2, L3, L4 = light_ray_cache[key]

    # Bilinearly blend each corner's radiance embedding across the patch.
    R = np.zeros(shape=(RASTER_SIZE, RASTER_SIZE, B.shape[0]), dtype=np.float32)
    for weights, ray in zip(ST, (L1, L2, L3, L4)):
        R += weights[:, :, np.newaxis] * A.dot(ray)

    T = np.tensordot(R, B, axes=1)
    P = np.clip(1.0 / (1 + np.exp(-T)), 0, 1)

    return L1, L2, L3, L4, R, T, P

for loop_count in range(LOOP_COUNT):
    # Linear learning-rate decay, floored at ALPHA_FINAL.
    ALPHA = max(ALPHA_FINAL, ALPHA - ALPHA_DECAY * ALPHA_BASE)

    # One pass over every unoccluded face, in a fresh random order each loop.
    random.shuffle(face_coords)
    for x, y in face_coords:
        col_s, row_s = x * RASTER_SIZE, y * RASTER_SIZE
        col_e, row_e = (x + 1) * RASTER_SIZE, (y + 1) * RASTER_SIZE
        L = pixel_labels[row_s : row_e, col_s : col_e]

        # Forward propagate the model.
        L1, L2, L3, L4, R, T, P = forward_propagation(x, y)

        # With P = sigmoid(T) and a cross-entropy loss, -dloss/dT = L - P per pixel.
        scale = L - P

        # Compute B gradient: dT/dB = R, averaged over the patch.
        B_grad = (scale[:, :, np.newaxis] * R).mean(axis=(0, 1))

        # Compute A gradient: each corner k contributes mean(scale * ST[k]) * outer(B, L_k),
        # damped by a constant 0.25.
        A_grad = np.zeros(shape=A.shape, dtype=np.float32)
        for weights, ray in zip(ST, (L1, L2, L3, L4)):
            A_grad += 0.25 * (scale * weights).mean() * np.outer(B, ray)

        # Gradient step. NOTE(review): MAX_GRAD / clamp() are defined but not applied here.
        B += ALPHA * B_grad
        A += ALPHA * A_grad

        # Log statistics periodically.
        if iteration % 100 == 0 or iteration < 100:
            # Quantise to 8-bit grey levels, but keep int64 so the difference and
            # square cannot wrap around modulo 256 as they would in uint8.
            bL = (0.5 + 255 * L).astype(np.int64)
            bP = (0.5 + 255 * P).astype(np.int64)
            errors.append(int(np.sum((bL - bP)**2)))
        iteration += 1

    # Plot the current lighting of the trained model.
    train_map = np.zeros(shape=(RASTER_SIZE * SIZE, RASTER_SIZE * SIZE), dtype=np.float32)
    for y in range(SIZE):
        for x in range(SIZE):
            row_s, row_e = y * RASTER_SIZE, (y + 1) * RASTER_SIZE
            col_s, col_e = x * RASTER_SIZE, (x + 1) * RASTER_SIZE
            L1, L2, L3, L4, R, T, P = forward_propagation(x, y)
            train_map[row_s : row_e, col_s : col_e] = P
    ax1.clear()
    ax1.imshow(make_color_map(train_map))

    # Plot learning curve; savgol_filter requires window_length > polyorder (3).
    window = min(2 * (len(errors) // 4) + 1, 11)
    plot_errors = savgol_filter(errors, window, 3) if window > 3 else errors
    ax2.clear()
    ax2.plot(range(len(plot_errors)), plot_errors, marker="")

    fig1.canvas.draw()
    fig2.canvas.draw()
    plt.pause(0.05)
    print(
        "loop =", loop_count,
        "ALPHA =", ALPHA,
    )