# Thresholding

In [32]:
import tensorflow as tf
import numpy as np
from typing import List

In [3]:
%run ../grid.py

In [8]:
# Load the saved Keras model that is used as `filler` by the cells below.
# The HDF5 file lives under the training-hole-filling data directory.
# NOTE(review): the file name says "encoder" although it is loaded as the
# filler — confirm this is the intended checkpoint.
filler_filepath = "../../03-psychophysics/data/training-hole-filling/training-encoders/2019-07-18_13-57_encoder_48_48.h5"
filler = tf.keras.models.load_model(filler_filepath)

W0801 11:43:03.353553  4004 hdf5_format.py:171] No training configuration found in save file: the model was *not* compiled. Compile it manually.


In [9]:
# NOTE(review): this path is only the parent directory, not a model file, and
# the recorded cell output is a Model repr — the cell appears to have been
# edited after it was last run. TODO: restore the intended encoder path/load.
encoder_filepath = "../"

<tensorflow.python.keras.engine.training.Model at 0x23ac3d187f0>

In [15]:
def difference(rendered: np.ndarray, filled: np.ndarray) -> float:
    """
    Compute the sum of squared pixel-wise differences between two images.

    The filled image is subtracted element-wise from the rendered image, the
    residual is squared, and all entries are summed into a single scalar.
    """
    residual = rendered - filled
    squared_residual = residual ** 2
    return np.sum(squared_residual)

In [24]:
def binarise(encoding: np.ndarray, threshold: float) -> np.ndarray:
    """
    Map an encoding to a float32 array containing only 0.0 and 1.0.

    Elements strictly greater than `threshold` become 1.0; all others
    (including values equal to the threshold) become 0.0.
    """
    above_threshold = encoding > threshold
    binary = np.where(above_threshold, 1.0, 0.0)
    return binary.astype(np.float32)

In [33]:
def iterate_diffs(encoding: np.ndarray, grid: AbstractGrid) -> List[float]:
    """
    Sweep binarisation thresholds and score each one against the filled render.

    The un-thresholded encoding is rendered once on the grid and passed through
    `filler` to obtain a reference image. The encoding is then binarised at
    each of 100 evenly spaced thresholds (0.00, 0.01, ..., 0.99), rendered,
    and compared with the reference using `difference`.

    Returns the differences as a list ordered by increasing threshold.
    """
    reference = filler(render(grid, encoding))

    return [
        difference(render(grid, binarise(encoding, cutoff)), reference)
        for cutoff in np.linspace(0, 0.99, 100)
    ]

In [34]:
# Sanity check: np.argmin returns the index of the smallest element
# (here 3, the position of 0.5), not the element itself.
np.argmin([1, 2, 3, 0.5, 4])

3

In [35]:
def find_best_threshold(encoder: tf.keras.Model, grid: AbstractGrid, seed: tf.Tensor) -> List[float]:
    """
    Given an encoder and a grid, find the best threshold for each digit for a given seed.

    For each digit 0-9 the encoder is called with (seed, one-hot digit), the
    resulting encoding is swept through `iterate_diffs`, and the threshold
    whose difference is smallest is kept.

    Returns a list of ten thresholds, indexed by digit.
    """
    best_thresholds = []

    for digit in range(10):
        # NOTE(review): depth is 11 although only digits 0-9 are iterated;
        # presumably the encoder expects an extra class — confirm.
        one_hot_digit = tf.one_hot(digit, depth=11)
        encoding = encoder((seed, one_hot_digit))
        diffs = iterate_diffs(encoding, grid)

        # np.argmin yields the *position* of the smallest diff, not the diff
        # itself. iterate_diffs steps its thresholds by 0.01 starting at 0,
        # so position i corresponds to threshold 0.01 * i.
        best_index = int(np.argmin(diffs))
        best_thresholds.append(0.01 * best_index)

    return best_thresholds

In [38]:
def render_and_process(encoder: tf.keras.Model, digit: int, seed: tf.Tensor, best_thresholds: List[float]) -> np.ndarray:
    """
    Encode a digit, binarise the encoding at that digit's threshold, and render it.

    The encoder is called with (seed, one-hot digit); the encoding is then
    binarised at `best_thresholds[digit]` and the binarised encoding is
    rendered.

    `best_thresholds` must be indexable by digit (0-9).

    NOTE(review): `grid` is not a parameter; it is read from the enclosing
    (notebook) scope and must be defined before this function is called.

    Returns whatever `render` produces for the binarised encoding.
    """
    one_hot_digit = tf.one_hot(digit, depth=11)
    encoding = encoder((seed, one_hot_digit))
    binarised = binarise(encoding, best_thresholds[digit])

    # Render the thresholded encoding; rendering the raw `encoding` here would
    # make the binarisation step (and `best_thresholds`) have no effect.
    return render(grid, binarised)