<a href="https://colab.research.google.com/github/skosch/YinYangFit/blob/master/YinYangFit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set up TF2 and import dependencies

In [0]:
try:
    %tensorflow_version 2.x
except Exception:
    pass

import tensorflow as tf
import os

print("TF version:", tf.__version__)
print("GPU is", "available" if tf.test.is_gpu_available() else "NOT AVAILABLE")

if tf.test.is_gpu_available():
    device_name = tf.test.gpu_device_name()
    if device_name != '/device:GPU:0':
      raise SystemError('GPU device not found')
    print('Found GPU at: {}'.format(device_name))
else:
    tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
    
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(cluster_resolver)
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)


In [0]:
import itertools
import os

import numpy as np
pi = np.pi

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.cm as cm
from mpl_toolkits.mplot3d import Axes3D
from scipy.ndimage.interpolation import affine_transform
import tensorflow as tf
import random; random.seed()
import math
import pickle
import os
from tqdm import tqdm as tqdm
import sys
from functools import reduce
import random
from itertools import cycle, islice, product
import operator

!pip install --quiet --upgrade git+git://github.com/simoncozens/tensorfont.git
!pip install --quiet fonttools
!pip install --upgrade git+git://github.com/simoncozens/fontParts.git@d444bde6e2a0adbcd9a16593a615a99823089c70
!pip install booleanOperations
import fontParts

from tensorfont import Font

print("✓ Dependencies imported.")



### Download font files

In [0]:
#!wget -q -O OpenSans-Regular.ttf https://github.com/googlefonts/opensans/blob/master/ttfs/OpenSans-Regular.ttf?raw=true
#!wget -q -O Roboto.ttf https://github.com/google/fonts/blob/master/apache/roboto/Roboto-Regular.ttf?raw=true
#!wget -q -O Roboto.otf https://github.com/AllThingsSmitty/fonts/blob/master/Roboto/Roboto-Regular/Roboto-Regular.otf?raw=true
#!wget -q -O DroidSerif.ttf https://github.com/datactivist/sudweb/blob/master/fonts/droid-serif-v6-latin-regular.ttf?raw=true
#!wget -q -O CrimsonItalic.otf https://github.com/skosch/Crimson/blob/master/Desktop%20Fonts/OTF/Crimson-Italic.otf?raw=true
#!wget -q -O CrimsonBold.otf https://github.com/skosch/Crimson/blob/master/Desktop%20Fonts/OTF/Crimson-Bold.otf?raw=true 
#!wget -q -O CrimsonRoman.otf https://github.com/alif-type/amiri/blob/master/Amiri-Regular.ttf?raw=true

!wget -q -O CrimsonRoman.otf https://github.com/skosch/Crimson/blob/master/Desktop%20Fonts/OTF/Crimson-Roman.otf?raw=true
print("✓ Font file(s) downloaded.")

## Load font data and set up global parameters

In [0]:
# Glyphs to train on. Only the uncommented assignment is used; the alternatives
# are kept for quick switching.
#glyph_char_list = "abcdeghijlmnopqrstuzywvxkf"  # lowercase (see the font-size note below)
#glyph_char_list = "abgjqrst"
glyph_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#glyph_char_list = "OO"

# ==== Create Font ====
factor = 1.14 #1.539  # This scales the size of everything
f = Font("CrimsonRoman.otf", 24 * factor) # Roboto.ttf CrimsonRoman.otf # 34 for lowercase
box_height = f.full_height_px
box_width = int(141 * factor)
box_width += (box_width + 1) % 2  # bump an even width by one so box_width is always odd
print("Box size:", box_height, "×", box_width)

# 37067520 allocated, maxAllocSize: 873707520

batch_size = 1  # on TPU this would have to be divisible by 8
n_sample_distances = 3  # samples are placed symmetrically around the target distance
# get_sample_distances_and_translations produces 2*((n-1)//2)+1 distances and n-1
# slope signs; these only line up when n is odd.
assert n_sample_distances % 2 == 1, "n_sample_distances must be an odd number"

n_sizes = 14
n_pseudo_orientations = 4
n_orientations = 4

## Create the filter bank (Gabor-like filters built from first and second derivatives of Gaussians)

In [0]:
def get_sigmas(skip_scales=0):
    """Return the n_sizes filter scales (Gaussian sigmas, in pixels) as a float64 array.

    The values start at 1.5 px and grow linearly in steps of
    (box_width / 12 - 1.5) / n_sizes, so the largest value is one step short
    of box_width / 12.

    @param skip_scales: not used by the linear spacing (it was only referenced
                        by an earlier quadratic spacing).
    """
    smallest = 1.5
    largest = box_width / 12
    # Same arithmetic order as an element-wise `(largest - smallest) * s / n_sizes + smallest`.
    step_indices = np.arange(n_sizes)
    return (largest - smallest) * step_indices / n_sizes + smallest
print("Sigmas are", get_sigmas())

def get_3n_filters(skip_scales, display_filters=False):
    """Build the frequency-domain filter banks used with `apply_filter_bank`.

    For every sigma in get_sigmas() and every orientation index o in
    range(n_orientations), Gaussian-derivative kernels are sampled on a
    (2*box_height, 2*box_width) grid rotated by pi*o/n_orientations and
    transformed with np.fft.fft2.

    @param skip_scales: not used by the current implementation.
    @param display_filters: if True, plot |fftshift(d1c)| and imag(ifft2(d1c))
                            for every filter (two subplot rows per scale).
    @return: (d1_bank, d2_bank), both complex64 arrays of shape
             (n_sizes, n_orientations, 2*box_height, 2*box_width).
             d2_bank[s, o] is sigma**1.5 * FFT(d2_space);
             d1_bank[s, o] is 1j * (d2 + 1j * d1) = -d1 + 1j * d2 (see get_filter).
    """
    def rotated_mgrid(oi):
        """Return a (2, 2*box_height, 2*box_width) grid of pixel-centre coordinates rotated by pi*oi/n_orientations radians.

        Output[0] is the rotated x coordinate (cos*x + sin*y), output[1] the
        rotated y coordinate (-sin*x + cos*y).
        """
        rotation = np.array([[ np.cos(pi*oi/n_orientations), np.sin(pi*oi/n_orientations)],
                             [-np.sin(pi*oi/n_orientations), np.cos(pi*oi/n_orientations)]])
        hh = box_height # / 2
        bw = box_width # / 2
        y, x = np.mgrid[-hh:hh, -bw:bw].astype(np.float32)
        # Shift by half a pixel so the grid is symmetric about 0 ([-h+0.5, h-0.5]).
        y += 0.5 # 0 if box_height % 2 == 0 else 0.5
        x += 0.5 # 0 if box_width % 2 == 0 else 0.5
        return np.einsum('ji, mni -> jmn', rotation, np.dstack([x, y]))

    def get_filter(sigma, theta):
        """Return (d1c, d2) in the frequency domain for scale `sigma` and orientation index `theta`."""
        x, y = rotated_mgrid(theta)

        # To minimize ringing etc., we create the filter as is, then run it through the DFT.
        a1 = 0.25 # See Georgeson et al. 2007 -- currently unused (s1 = sigma, not a1 * sigma)
        s1 = sigma #a1 * sigma
        # First derivative along the rotated x axis of a normalized 2-D Gaussian.
        d1_space = -np.exp(-(x**2+y**2)/(2*s1**2))*x/(2*pi*s1**4)
        # Adding 1j * zeros only promotes the array to complex before the FFT.
        d1 = np.fft.fft2(d1_space + 1j * np.zeros_like(d1_space)) #, [box_height, box_width])

        # Second derivative: d2_space is the NEGATED second x-derivative of a normalized 2-D Gaussian.
        s2 = sigma #np.sqrt(1. - a1**2) * sigma # See Georgeson et al. (2007)
        d2_space = np.exp(-(x**2+y**2)/(2*s2**2))/(2*pi*s2**4) - np.exp(-(x**2+y**2)/(2*s2**2))*x**2/(2*pi*s2**6)
        d2 = sigma**1.5 * np.fft.fft2(d2_space + 1j * np.zeros_like(d2_space)) #, [box_height, box_width])

        # Combine both responses: 1j * (d2 + 1j*d1) evaluates to -d1 + 1j*d2
        # (d1 and d2 are already complex FFT arrays at this point).
        d1c = 1j * (d2 + 1j*d1) #/ sigma #np.sqrt(sigma)
        return (d1c, d2)

    d1_bank = np.zeros((n_sizes, n_orientations, 2*box_height, 2*box_width)).astype(np.complex64)
    d2_bank = np.zeros((n_sizes, n_orientations, 2*box_height, 2*box_width)).astype(np.complex64)

    if display_filters:
        sizediv = 60
        fig, ax = plt.subplots(nrows=n_sizes*2, ncols=n_orientations, gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(box_width * n_orientations / sizediv, box_height * n_sizes * 2 / sizediv))

    sigmas = get_sigmas()
    for s in range(n_sizes):
        sigma = sigmas[s]
        for o in range(n_orientations):
            # Here `d1` receives d1c from get_filter.
            (d1, d2) = get_filter(sigma, o)
            if display_filters:
                #ax[s*2, o].imshow(np.real(np.fft.ifft2(d1)), cmap="inferno")
                ax[s*2, o].imshow(np.abs(np.fft.fftshift(d1)), cmap="inferno")
                ax[s*2, o].set_aspect("auto")
                ax[s*2, o].set_yticklabels([])
                ax[s*2+1, o].imshow(np.imag(np.fft.ifft2(d1)), cmap="inferno")
                ax[s*2+1, o].set_aspect("auto")
                ax[s*2+1, o].set_yticklabels([])
            # Assigning into the complex64 banks down-casts from complex128.
            d1_bank[s, o, :, :] = d1
            d2_bank[s, o, :, :] = d2

    if display_filters:
        plt.show()

    return (d1_bank.astype(np.complex64), d2_bank)

d1_filter_bank, d2_filter_bank = get_3n_filters(0, False)

## Rasterize the glyphs into NumPy arrays, extract their ink widths in pixels, filter them, and build the training input generator


In [0]:
def apply_filter_bank(input_image, filter_bank):
    """Filter an image (or a stack of images) with a frequency-domain filter bank.

    Input image should have dimensions <h, w>, <s, o, h, w> or <b, s, o, h, w, d>,
    with h == box_height and w == box_width.
    Filter bank should have dimensions <s, o, 2*box_height, 2*box_width>
    (as produced by get_3n_filters).

    The image is zero-padded to the filter size (so the circular convolution
    wraps around less), multiplied with the filter bank in the frequency domain,
    transformed back, fftshift-ed over (h, w) and cropped to <box_height, box_width>.

    @return: complex64 tensor <s, o, h, w> for rank-2 or rank-4 input,
             <b, s, o, h, w, d> for rank-6 input.
    @raises ValueError: if the input is not of rank 2, 4 or 6.
    """
    rank = len(input_image.shape)
    if rank == 2:
        bdsohw_input_image = input_image[None, None, None, None, :, :]
    elif rank == 4:
        bdsohw_input_image = input_image[None, None, :, :, :, :]
    elif rank == 6:
        bdsohw_input_image = tf.einsum("bsohwd->bdsohw", input_image)
    else:
        raise ValueError("apply_filter_bank expects an input of rank 2, 4 or 6, "
                         "got shape {}".format(input_image.shape))

    # Padding that grows <box_height, box_width> to the filter size <2*box_height, 2*box_width>.
    top, bottom = int(np.ceil(box_height / 2)), int(box_height / 2)
    left, right = int(np.ceil(box_width / 2)), int(box_width / 2)
    padded_input = tf.pad(bdsohw_input_image,
                          [[0, 0], [0, 0], [0, 0], [0, 0], [top, bottom], [left, right]],
                          mode='CONSTANT')

    input_in_freqdomain = tf.signal.fft2d(
        tf.dtypes.cast(tf.complex(padded_input, tf.zeros_like(padded_input)), tf.complex64))
    padded_result = tf.signal.ifft2d(input_in_freqdomain * filter_bank[None, None, :, :, :, :])

    # In <b, d, s, o, h, w> layout, h and w are axes 4 and 5. (The old rank-6 path
    # shifted axes 2, 3 (s, o) and then cropped w and d after reordering.)
    shifted = tf.signal.fftshift(padded_result, axes=[4, 5])
    cropped = shifted[..., top:top + box_height, left:left + box_width]

    if rank == 6:
        # Return <b, s, o, h, w, d>
        return tf.einsum("bdsohw->bsohwd", cropped)
    # Return <s, o, h, w>
    return cropped[0, 0]

def get_glyph_data_with_filtered_as_dict(glyph_char):
    """
    Returns a dict containing relevant glyph data, including filtered images.

    Keys: 'glyph_char', 'glyph_image' (float32 matrix from the font, padded to
    box_width), 'glyph_ink_width', and 'glyph_d1_filtered_images' (the image
    run through d1_filter_bank by apply_filter_bank).

    @param glyph_char: string of length 1
    """
    # Look the glyph up once; it was previously fetched twice.
    glyph = f.glyph(glyph_char)
    glyph_image = glyph.as_matrix(normalize=True).with_padding_to_constant_box_width(box_width).astype(np.float32)

    return {
        'glyph_char': glyph_char,
        'glyph_image': glyph_image,
        'glyph_ink_width': glyph.ink_width,
        'glyph_d1_filtered_images': apply_filter_bank(glyph_image, d1_filter_bank),
    }

def get_sample_distances_and_translations(gd1, gd2, target_ink_distance):
    """Return the sample distances for a glyph pair and the box shifts that realize them.

    The pair is sampled at ink distances spaced 2 px apart and centred on
    `target_ink_distance`. For sample i, the left glyph box is shifted by
    left_translations[i] and the right box by right_translations[i] (pixels;
    positive shifts move right) before the two are overlaid.

    @param gd1: glyph data dict of the left glyph (uses 'glyph_char', 'glyph_ink_width').
    @param gd2: glyph data dict of the right glyph.
    @param target_ink_distance: ink distance (px) the samples are centred on.
    @return: dict with
        'sample_distances': np.ndarray, target_ink_distance + relative offsets
        'relative_sample_distances': sorted list of int offsets, e.g. [-2, 0, 2]
        'left_translations', 'right_translations': np.int32 arrays
        'zero_index': index of offset 0 in relative_sample_distances
        'desired_penalty_slope_sign': n_sample_distances - 1 entries,
            -1 before zero_index and +1 from zero_index on
    """
    total_ink_width = gd1['glyph_ink_width'] + gd2['glyph_ink_width']
    total_width_at_minimum_ink_distance = total_ink_width - f.minimum_ink_distance(gd1['glyph_char'], gd2['glyph_char'])

    # Offsets [-2*simax, ..., -2, 0, 2, ..., 2*simax], all ints. (The previous
    # incremental loop had an always-true "room for a negative" check and turned
    # positive offsets beyond +4 into floats.)
    step = 2
    simax = int((n_sample_distances - 1) / 2)
    relative_sample_distances = list(range(-step * simax, step * simax + 1, step))

    zero_index_val = relative_sample_distances.index(0)
    desired_penalty_slope_sign = np.ones((n_sample_distances - 1))
    desired_penalty_slope_sign[:zero_index_val] = -1

    sample_distances = (np.array(relative_sample_distances) + target_ink_distance)
    # Split each distance between the two glyphs; the left one takes the odd pixel.
    sample_distances_left = np.ceil(sample_distances / 2)
    sample_distances_right = np.floor(sample_distances / 2)

    ink_width_left = np.floor(total_ink_width / 4)
    ink_width_right = np.ceil(total_ink_width / 4)

    left_translations = (-(np.ceil(total_width_at_minimum_ink_distance/2) + sample_distances_left) - (-ink_width_left)).astype(np.int32)
    right_translations = ((np.floor(total_width_at_minimum_ink_distance/2) + sample_distances_right) - ink_width_right).astype(np.int32)

    return {
        'sample_distances': sample_distances,
        'relative_sample_distances': relative_sample_distances,
        'left_translations': left_translations,
        'right_translations': right_translations,
        'zero_index': zero_index_val,
        'desired_penalty_slope_sign': desired_penalty_slope_sign,
    }

def shift_sohw1_into_sohwd(input_images, translations):
    """Make one horizontally shifted copy of `input_images` per translation.

    Copy i is moved along the width axis (axis 3) by translations[i]; positive
    values move it right, negative values move it left. Pixels entering from
    the edge are filled with zeros.

    @param input_images: <sizes, orientations, height, width, 1>
    @param translations: 1-D integer array <len(translations)>
    @output <sizes, orientations, height, width, len(translations)>
    """
    n_shifts = translations.shape[0]
    width = input_images.shape[3]
    copies = tf.tile(input_images, [1, 1, 1, 1, n_shifts])

    # The left margin absorbs the largest right shift, the right margin the largest left shift.
    left_margin = tf.maximum(0, tf.reduce_max(translations))
    right_margin = -tf.minimum(0, tf.reduce_min(translations))
    dims = tf.shape(copies)
    left_zeros = tf.zeros((dims[0], dims[1], dims[2], left_margin, dims[4]), dtype=copies.dtype)
    right_zeros = tf.zeros((dims[0], dims[1], dims[2], right_margin, dims[4]), dtype=copies.dtype)
    widened = tf.concat([left_zeros, copies, right_zeros], axis=3)

    def crop_one(args):
        # args: (one copy <s, o, h, padded w>, its shift). A window starting
        # further left makes the content appear shifted right.
        image, shift = args
        start = left_margin - shift
        return image[:, :, :, start:start + width]

    per_shift = tf.map_fn(crop_one,
                          (tf.einsum("sohwd->dsohw", widened), translations),
                          dtype=copies.dtype)
    return tf.einsum("dsohw->sohwd", per_shift)

def shift_hwso1_into_hwsod(input_images, translations):
    """Make one horizontally shifted copy of height-first images per translation.

    Copy i is moved along the width axis (axis 1) by translations[i]; positive
    values move it right, negative values move it left. Pixels entering from
    the edge are filled with zeros.

    @param input_images: <height, width, sizes, orientations, 1>
    @param translations: 1-D integer array <len(translations)>
    @output <height, width, sizes, orientations, len(translations)>
    """
    n_shifts = translations.shape[0]
    width = input_images.shape[1]
    copies = tf.tile(input_images, [1, 1, 1, 1, n_shifts])

    # The left margin absorbs the largest right shift, the right margin the largest left shift.
    left_margin = tf.maximum(0, tf.reduce_max(translations))
    right_margin = -tf.minimum(0, tf.reduce_min(translations))
    dims = tf.shape(copies)
    left_zeros = tf.zeros((dims[0], left_margin, dims[2], dims[3], dims[4]), dtype=copies.dtype)
    right_zeros = tf.zeros((dims[0], right_margin, dims[2], dims[3], dims[4]), dtype=copies.dtype)
    widened = tf.concat([left_zeros, copies, right_zeros], axis=1)  # width is axis 1 here

    def crop_one(args):
        # args: (one copy <h, padded w, s, o>, its shift). A window starting
        # further left makes the content appear shifted right.
        image, shift = args
        start = left_margin - shift
        return image[:, start:start + width, :, :]

    per_shift = tf.map_fn(crop_one,
                          (tf.einsum("hwsod->dhwso", widened), translations),
                          dtype=copies.dtype)
    return tf.einsum("dhwso->hwsod", per_shift)

def shift_and_overlay_pair_data(gd1, gd2):
    """Place glyph gd1 (left) and gd2 (right) at every sample distance around their target spacing.

    @param gd1, gd2: glyph data dicts (see get_glyph_data_with_filtered_as_dict).
    @return: dict with
        'shifted_gd1_d1_filtered_images', 'shifted_gd2_d1_filtered_images':
            filtered images of each glyph, one shifted copy per sample distance,
            <sizes, orientations, height, width, distances>
        'ink_distance': int(pair_distance + minimum_ink_distance) from the font
        'pair_images': both raw glyph images shifted and summed, <height, width, distances>
            (display only)
        'zero_index', 'desired_penalty_slope_sign', 'sample_distances':
            passed through from get_sample_distances_and_translations
    """
    left_char, right_char = gd1['glyph_char'], gd2['glyph_char']
    target_ink_distance = int(f.pair_distance(left_char, right_char) + f.minimum_ink_distance(left_char, right_char))
    sdt = get_sample_distances_and_translations(gd1, gd2, target_ink_distance)
    left_shifts = sdt['left_translations']
    right_shifts = sdt['right_translations']

    shifted_left = shift_sohw1_into_sohwd(gd1['glyph_d1_filtered_images'][..., None], left_shifts)
    shifted_right = shift_sohw1_into_sohwd(gd2['glyph_d1_filtered_images'][..., None], right_shifts)

    # Raw images get dummy size/orientation axes so the same shifting helper applies;
    # those axes are dropped again after overlaying.
    overlaid = (shift_hwso1_into_hwsod(gd1['glyph_image'][..., None, None, None], left_shifts) +
                shift_hwso1_into_hwsod(gd2['glyph_image'][..., None, None, None], right_shifts))

    return {
        'shifted_gd1_d1_filtered_images': shifted_left,
        'shifted_gd2_d1_filtered_images': shifted_right,
        'ink_distance': target_ink_distance,
        'pair_images': overlaid[:, :, 0, 0, :],
        'zero_index': sdt['zero_index'],
        'desired_penalty_slope_sign': sdt['desired_penalty_slope_sign'],
        'sample_distances': sdt['sample_distances'],
    }

class InputGenerator(tf.keras.utils.Sequence):
    """Keras Sequence yielding batches of randomly drawn glyph pairs.

    Every glyph in glyph_char_list is rasterized and filtered once in __init__.
    Each batch then draws `batch_size` (left, right) pairs uniformly at random,
    with replacement, so the batch index does not determine the content.
    """

    def __init__(self, batch_size):
        super(InputGenerator, self).__init__()
        self.batch_size = batch_size
        print("Creating glyph images ...", flush=True)
        self.glyph_data = [get_glyph_data_with_filtered_as_dict(glyph_char)
                           for glyph_char in tqdm(glyph_char_list)]
        self.n_pairs = len(glyph_char_list) ** 2
        # Reserved for a pair-data cache; caching is currently disabled (see __getitem__).
        self.cached_pair_data = {}

    def kill(self):
        """Drop the precomputed glyph data to free memory."""
        del self.glyph_data
        del self.cached_pair_data

    def __len__(self):
        """Number of batches in the dataset"""
        return math.ceil(self.n_pairs / self.batch_size)

    def __getitem__(self, idx):
        """Return one batch of self.batch_size random glyph pairs (`idx` is ignored).

        This is run on the CPU, because otherwise the calculations sit in GPU memory,
        are never released, and lead to out-of-memory issues.

        Output:
        ([shifted_gd1_d1_filtered_images, shifted_gd2_d1_filtered_images, sample_distances,
          pair_images, zero_index, desired_penalty_slope_sign], ink_distance)
        with every entry stacked along a new leading batch axis.
        """
        with tf.device('/CPU:0'):
            pair_data = []
            # Use the instance's batch size, not the notebook-global `batch_size`.
            for _ in range(self.batch_size):
                g1 = random.choice(self.glyph_data)
                g2 = random.choice(self.glyph_data)
                # Recomputed instead of cached in self.cached_pair_data, to save RAM.
                pair_data.append(shift_and_overlay_pair_data(g1, g2))

            def stacked(key):
                return tf.stack([cpd[key] for cpd in pair_data])

            inputs = [
                stacked('shifted_gd1_d1_filtered_images'),
                stacked('shifted_gd2_d1_filtered_images'),
                stacked('sample_distances'),
                stacked('pair_images'),
                stacked('zero_index'),
                stacked('desired_penalty_slope_sign'),
            ]
            # The integer target distance becomes a float regression target.
            outputs = tf.stack([cpd['ink_distance'] * 1.0 for cpd in pair_data])

            return inputs, outputs


## Model evaluator

In [0]:
tf.keras.backend.clear_session()  # For easy reset of notebook state.

# Smallest positive normal float32; added to denominators and square-root
# arguments below so they never hit an exact zero.
eps = np.finfo(np.float32).tiny
full_shape = (n_sizes, n_orientations, box_height, box_width, n_sample_distances)


@tf.function
def nd_softmax(target, axis, name=None):
    """Numerically stable softmax of `target` along `axis`, for tensors of any rank.

    The maximum along `axis` is subtracted before exponentiating; eps keeps the
    normalizer strictly positive. `name` is accepted but unused.
    """
    exps = tf.exp(target - tf.reduce_max(target, axis, keepdims=True))
    return exps / (tf.reduce_sum(exps, axis, keepdims=True) + eps)


def _take_entries_0121(t, axis):
    """Gather entries [0, 1, 2, 1] of `t` along `axis`.

    Expands three learned entries into four by reusing entry 1 as entry 3.
    """
    return tf.gather(t, [0, 1, 2, 1], axis=axis)

@tf.function
def tilo(t):
    """6-D input: expand axis 2 to four entries ([0, 1, 2, 1])."""
    return _take_entries_0121(t, 2)
@tf.function
def tilop(t):
    """softplus(tilo(t))."""
    return tf.nn.softplus(_take_entries_0121(t, 2))
@tf.function
def ptilo(t):
    """7-D input: expand axis 2 to four entries ([0, 1, 2, 1])."""
    return _take_entries_0121(t, 2)
@tf.function
def ptilop(t):
    """softplus(ptilo(t))."""
    return tf.nn.softplus(_take_entries_0121(t, 2))
@tf.function
def tiloa(t):
    """7-D input: expand axis 4 to four entries ([0, 1, 2, 1])."""
    return _take_entries_0121(t, 4)
@tf.function
def tiloap(t):
    """softplus(tiloa(t))."""
    return tf.nn.softplus(_take_entries_0121(t, 4))

def rectify_phases(inputs):
    """Split a complex response into four half-wave-rectified phase channels.

    Inputs: <b, s, o, h, w, d> (complex)
    Output: <b, s, o, p, h, w, d>, where p indexes
            [relu(Re), relu(Im), relu(-Re), relu(-Im)].
    """
    real_part = tf.math.real(inputs)
    imag_part = tf.math.imag(inputs)
    channels = [tf.nn.relu(c) for c in (real_part, imag_part, -real_part, -imag_part)]
    return tf.stack(channels, axis=3)

class BiasedAbs(tf.keras.layers.Layer):
    """Magnitude of a complex response with a learned real/imaginary weighting.

    With w = sigmoid(tilo(bias_weights)) (one value per size and orientation), the
    layer returns sqrt(eps + (1 - w) * Re(z)**2 + w * Im(z)**2) * sqrt(2):
    w = 0.5 gives |z| (up to eps), w near 0 approaches sqrt(2)*|Re(z)|, and
    w near 1 approaches sqrt(2)*|Im(z)|.
    """

    def __init__(self, **kwargs):
        super(BiasedAbs, self).__init__(**kwargs)

        # n_orientations - 1 learned entries; tilo expands them to four.
        self.bias_weights = self.add_weight(
            shape=(1, n_sizes, n_orientations - 1, 1, 1, 1),
            initializer=tf.keras.initializers.Constant(0.),
            name='bias_weights',
            trainable=True)

    def _imag_weight(self):
        # Weight w on the imaginary part, squashed into (0, 1).
        return tf.nn.sigmoid(tilo(self.bias_weights))

    def print_weights(self):
        print(self._imag_weight()[0, :, :, 0, 0, 0])

    def call(self, inputs):
        imag_weight = self._imag_weight()
        real_weight = 1. - imag_weight
        weighted_square = eps + real_weight * tf.math.real(inputs)**2 + imag_weight * tf.math.imag(inputs)**2
        return tf.sqrt(weighted_square) * tf.sqrt(2.)

class NormalizationPool(tf.keras.layers.Layer):
    # This layer computes the unscaled normalization pool for each neuron.
    # It effectively performs a blur over space, frequency, and orientation, using
    # Gaussian blur via 3D FFT (instead of a convolutional layer) followed by
    # matrix multiplications over the orientation and phase axes (arbitrary kernels).
    #
    # The input is a 7D tensor of reals <b, s, o, p, y, x, d> (batch, scale_index, orientation_index, phase_index, vertical coordinate, horizontal coordinate, sample distance)
    # The output is a 7D tensor of reals <b, s, o, p, y, x, d> with the same layout.
    #
    # Every batch element is pooled. Batch and distance axes are merged before
    # padding, so only a 6-D tensor is padded.
    #
    # For more information, see Sawada & Petrov (2017)

    def __init__(self, **kwargs):
        super(NormalizationPool, self).__init__(**kwargs)

        self.sigmas = get_sigmas()
        self.wavelengths = self.get_wavelengths()
        self.r2grid, self.sgrid = self.get_distgrids()

        self.spatial_pool_size_factor = self.add_weight(shape=(),
                                                        initializer=tf.keras.initializers.Constant(1.),
                                                        name='spatial_pool_size_factor',
                                                        trainable=True)
        self.scale_pool_size_factor = self.add_weight(shape=(),
                                                      initializer=tf.keras.initializers.Constant(1.),
                                                      name='scale_pool_size_factor',
                                                      trainable=True)

        self.rescale_factor = self.add_weight(shape=(),
                                       initializer=tf.keras.initializers.Constant(-0.0),
                                       name='npool_rescale_factor',
                                       trainable=True)

        # If this turns out to be symmetric, replace it with a von-Mises distribution
        # factor = exp(1.22 * cos(angle-diff)**2)  -- 1.22 coefficient taken from Sawada
        ro = np.eye(n_orientations)
        self.orientation_inhibition_matrix = self.add_weight(shape=(n_orientations, n_orientations),
                                                             initializer=tf.keras.initializers.Constant(ro),
                                                             name='orientation_inhibition_matrix',
                                                             trainable=True)
        po = 2. - np.eye(4)
        self.phase_inhibition_matrix = self.add_weight(shape=(4, 4),
                                                             initializer=tf.keras.initializers.Constant(po),
                                                             name='phase_inhibition_matrix',
                                                             trainable=True)

    def print_weights(self):
        """Print the pool-size factors and plot the orientation/phase inhibition matrices (after softplus)."""
        print("Spatial pool size factor", tf.nn.softplus(self.spatial_pool_size_factor.numpy()))
        print("Scale pool size factor", tf.nn.softplus(self.scale_pool_size_factor.numpy()))
        print("Rescale factor", self.rescale_factor)
        print("Orientation inhibition matrix")
        plt.imshow(tf.nn.softplus(self.orientation_inhibition_matrix.numpy()))
        plt.colorbar()
        plt.show()
        print("Phase inhibition matrix")
        plt.imshow(tf.nn.softplus(self.phase_inhibition_matrix.numpy()))
        plt.colorbar()
        plt.show()

    def get_wavelengths(self):
        """Return the sigmas edge-padded to 2*n_sizes entries, shaped for <b, d, o, p, s, y, x> broadcasting."""
        sigmas = self.sigmas
        # We are padding to twice n_sizes
        padded_sigmas = [sigmas[0]] * int(np.ceil(n_sizes / 2)) + list(sigmas) + [sigmas[-1]] * int(n_sizes / 2)
        # Used for the distgrids, which are sorted <b, d, o, p, s, y, x> (because FFT works on innermost axes)
        return np.array(padded_sigmas).astype(np.float32)[None, None, None, None, :, None, None]

    def get_distgrids(self):
        """Return (r2grid, sgrid): squared spatial and scale-index distances from the centre.

        r2grid is <1, 1, 1, 1, 1, 2*box_height, 2*box_width>; sgrid is <1, 1, 1, 1, 2*n_sizes, 1, 1>.
        The grids are twice the size of the data, which is padded accordingly in call().
        """
        y, x = np.mgrid[-box_height:box_height,
                        -box_width:box_width].astype(np.float32)
        r2grid = (y**2 + x**2)[None, None, None, None, None, :, :]
        sd = np.mgrid[-n_sizes:n_sizes].astype(np.float32)[None, None, None, None, :, None, None]
        sgrid = sd**2
            # TODO: this may be incorrect; larger wavelengths may be less susceptible to
            # the neighbouring octaves (the absolute difference may count more than the logarithmic)
        return r2grid, sgrid

    def call(self, inputs):
        # Create the 4d Gaussian blur filter.
        # The total distance is computed based on r2grid *and* sgrid.
        # Sawada assumes that we can take the product of both filters.
        psppsf = tf.nn.softplus(self.spatial_pool_size_factor)
        s_spatial_filter = (tf.exp(-self.r2grid/(eps + psppsf*self.wavelengths**2)) /
                            (tf.sqrt(2*np.pi)*(psppsf*self.wavelengths)**2 + eps))
                            # SPATIAL FILTER: self.wavelengths need to be twice as big

        pscpsf = tf.nn.softplus(self.scale_pool_size_factor)
        # NOTE(review): unlike the spatial filter, eps is added after the division here.
        s_scale_filter = (tf.exp(-self.sgrid/(eps + pscpsf**2)) / (tf.sqrt(2*np.pi)*pscpsf**2) + eps)

        s_filters = s_spatial_filter * s_scale_filter

        # The same filters, in frequency space
        f_filters = tf.signal.fft3d(tf.signal.fftshift(tf.complex(s_filters, tf.zeros_like(s_filters))))

        sigma_scale_factors = tf.exp(self.sigmas * self.rescale_factor)[None, :, None, None, None, None, None]
        rescaled_inputs = inputs * sigma_scale_factors

        # Reorder so that <s, y, x> are the innermost dimensions, because fft3d works
        # on the innermost dims
        r_s_inputs = tf.einsum("bsopyxd->bdopsyx", rescaled_inputs)

        # Zero-pad s, y and x to the filter size (the phases are not padded; DFT
        # wrap-around is acceptable there). Batch and distance are merged first so
        # the padded tensor is 6-D; the previous code padded only r_s_inputs[0] and
        # therefore silently dropped every batch element but the first.
        rsi_shape = tf.shape(r_s_inputs)  # [b, d, o, p, s, y, x]
        merged_inputs = tf.reshape(r_s_inputs, tf.concat([[-1], rsi_shape[2:]], axis=0))
        padded_merged = tf.pad(merged_inputs,
                               [[0, 0], [0, 0], [0, 0],
                                [int(np.ceil(n_sizes / 2)), int(n_sizes / 2)],
                                [int(np.ceil(box_height / 2)), int(box_height / 2)],
                                [int(np.ceil(box_width / 2)), int(box_width / 2)]], mode='CONSTANT')
        pr_s_inputs = tf.reshape(padded_merged,
                                 tf.concat([rsi_shape[:2], tf.shape(padded_merged)[1:]], axis=0))

        # Convert inputs to frequency domain
        pr_f_inputs = tf.signal.fft3d(tf.complex(pr_s_inputs, tf.zeros_like(pr_s_inputs)))

        # Perform the filtering and convert back to space domain
        pr_s_filtered = tf.math.real(tf.signal.ifft3d(pr_f_inputs * f_filters))

        # Crop away the padding
        r_s_filtered = pr_s_filtered[:, :, :, :,
                                     int(np.ceil(n_sizes / 2)):int(n_sizes + np.ceil(n_sizes / 2)),
                                     int(np.ceil(box_height / 2)):int(box_height + np.ceil(box_height / 2)),
                                     int(np.ceil(box_width / 2)):int(box_width + np.ceil(box_width / 2))]

        # Perform cross-orientation blurring, then cross-phase blurring
        r_s_ob_filtered = tf.einsum("bdkpsyx,kq->bdqpsyx", r_s_filtered,
                                    tf.nn.softplus(self.orientation_inhibition_matrix))
        r_s_obpb_filtered = tf.einsum("bdoksyx,kq->bdoqsyx", r_s_ob_filtered,
                                    tf.nn.softplus(self.phase_inhibition_matrix))

        # Reorder dimensions back to <b, s, o, p, y, x, d>
        s_obpb_filtered = tf.einsum("bdopsyx->bsopyxd", r_s_obpb_filtered)

        return s_obpb_filtered

class ApplyCsf(tf.keras.layers.Layer):
    """Weight each (size, orientation) channel by a learnable sensitivity curve.

    For every filter size the gain is ``exp(-(z + exp(-z)))`` with
    ``z = (10 / sigma - a) / b`` (a Gumbel-shaped curve over ``10 / sigma``).
    It is multiplied by per-orientation factors, which are divided by the
    factor of the first orientation. ``a`` and ``b`` pass through softplus,
    and the orientation factors pass through ``tilop``, before use.
    """

    def __init__(self, **kwargs):
        super(ApplyCsf, self).__init__(**kwargs)

        # Location of the curve along 10 / sigma (softplus applied before use).
        self.a = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(5.0),
                                 name='a',
                                 trainable=True)
        # Width of the curve (softplus applied before use).
        self.b = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(5.0),
                                 name='b',
                                 trainable=True)
        # Stored with n_orientations - 1 entries. print_weights indexes all
        # n_orientations, so tilop() presumably adds the missing orientation -- TODO confirm.
        self.orientation_factors = self.add_weight(shape=(1, 1, n_orientations - 1, 1, 1, 1),
                                                   initializer=tf.keras.initializers.Constant(0.62),
                                                   name='orientation_factors', trainable=True)
        # Filter scales laid out along the size axis (axis 1). Cast to float32
        # so arithmetic with the float32 weights never mixes dtypes.
        self.sigmas = get_sigmas().astype(np.float32)[None, :, None, None, None, None]

    def _factors(self):
        """Return the per-(size, orientation) gains, shaped (1, n_sizes, n_orient, 1, 1, 1)."""
        a = tf.nn.softplus(self.a)
        b = tf.nn.softplus(self.b)
        of = tilop(self.orientation_factors)
        z = (10. / self.sigmas - a) / b
        # Normalize so the first orientation's factor is exactly 1.
        return tf.exp(-(z + tf.exp(-z))) * of / of[0, 0, 0, 0, 0, 0]

    def print_weights(self):
        """Plot the gain curve against sigma, with one panel per orientation."""
        factors = self._factors()
        # squeeze=False keeps `ax` 2-D even when there is a single orientation.
        fig, ax = plt.subplots(1, n_orientations, squeeze=False)
        for oi in range(n_orientations):
            ax[0, oi].plot(self.sigmas[0, :, 0, 0, 0, 0], factors[0, :, oi, 0, 0, 0])
        plt.show()

    def call(self, inputs):
        """Multiply `inputs` by the gains.

        Sizes are on axis 1 and orientations on axis 2; the gains broadcast
        over the remaining axes.
        """
        #factors = (a/b) * (self.sigmas/b)**(a-1) * tf.exp(-(self.sigmas/b)**a)
        #factors = tf.exp(-((self.sigmas-a)/b + tf.exp(-(self.sigmas-a)/b) ))
        return inputs * self._factors()

class Exponentiate(tf.keras.layers.Layer):
    """Raise every input element to a per-(size, orientation) power.

    The exponents are a frozen weight (initialized to 2.5) stored with one
    fewer orientation than ``n_orientations``. They pass through ``tilop``
    every time they are used.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        exponent_shape = (1, n_sizes, n_orientations - 1, 1, 1, 1)
        self.exponents = self.add_weight(
            shape=exponent_shape,
            initializer=tf.keras.initializers.Constant(2.5),
            name='exponents',
            trainable=False,
        )

    def print_weights(self):
        """Show the exponents as a size-by-orientation heat map."""
        print("Exponents:")
        heat = tilop(self.exponents)[0, :, :, 0, 0, 0]
        plt.imshow(heat)
        plt.colorbar()
        plt.show()

    def call(self, inputs):
        # Disabled variant that applied the power to the magnitudes of rectified inputs:
        #rectified_inputs, absvals = inputs
        #pabsvals = absvals[:, :, :, None, :, :, :]
        #factor = (pabsvals + 1.e-6) ** ptilop(self.exponents) / (1.e-6 + pabsvals)
        #return (rectified_inputs + 1.e-6) * factor
        powers = tilop(self.exponents)
        # eps offsets the base before exponentiation.
        return (inputs + eps) ** powers

class DivisiveNormalization(tf.keras.layers.Layer):
    """Divisive normalization with learnable gain and offset per size and orientation.

    Call with ``[stimulus, normalization_pool]``. The result is
    ``ptilop(factors) * stimulus / (eps + ptilop(beta) + normalization_pool)``.
    """

    def __init__(self,  **kwargs):
        super().__init__(**kwargs)

        weight_shape = (1, n_sizes, n_orientations - 1, 1, 1, 1, 1)
        # Numerator gain.
        self.factors = self.add_weight(shape=weight_shape,
                                       initializer=tf.keras.initializers.Constant(2.46),
                                       name='dn_factors',
                                       trainable=True)
        # Additive constant in the denominator.
        self.beta = self.add_weight(shape=weight_shape,
                                    initializer=tf.keras.initializers.Constant(2.46),
                                    name='dn_beta',
                                    trainable=True)

    def print_weights(self):
        """Show the gain and then the offset as size-by-orientation heat maps."""
        for title, weight in (("Factors:", self.factors), ("Beta:", self.beta)):
            print(title)
            plt.imshow(ptilop(weight)[0, :, :, 0, 0, 0, 0])
            plt.colorbar()
            plt.show()

    def call(self, inputs):
        stimulus, pool = inputs
        gain = ptilop(self.factors)
        denominator = eps + ptilop(self.beta) + pool
        return gain * stimulus / denominator



class PenalizeZero(tf.keras.layers.Layer):
    """Turn per-pixel energy differences into signed penalties.

    Called with ``[original_sums, diffs]``. ``call`` currently returns
    ``(relu(diffs) + 2*eps) - (relu(-diffs) + 2*eps)``, which is ``diffs``
    up to rounding. The positive part is named "gap gains" and the negative
    part "edge losses". ``original_sums`` and the frozen weights below are
    used only by the commented-out Gaussian-masked / hyperbolic-ratio variants.
    """

    def __init__(self,  **kwargs):
        super(PenalizeZero, self).__init__(**kwargs)

        self.scale_exponent = self.add_weight(shape=(1, n_sizes, n_orientations - 1, 1, 1, 1),
                                              initializer=tf.keras.initializers.Constant(0.62),
                                              name='scale_exponent', trainable=False) 
        self.scale_beta = self.add_weight(shape=(1, n_sizes, n_orientations - 1, 1, 1, 1),
                                          initializer=tf.keras.initializers.Constant(0.62),
                                          name='scale_beta', trainable=False) 
        self.edge_loss_relative_factor = self.add_weight(shape=(),
                                          initializer=tf.keras.initializers.Constant(0.62),
                                          name='edge_loss_relative_factor', trainable=False) 
        self.scale_mean = self.add_weight(shape=(),
                                          initializer=tf.keras.initializers.Constant(6),
                                          name='scale_mean', trainable=False) 
        self.scale_sigma = self.add_weight(shape=(),
                                          initializer=tf.keras.initializers.Constant(2),
                                          name='scale_sigma', trainable=False) 

        # Filter scales laid out along the size axis (axis 1), as float32.
        self.sigmas = get_sigmas().astype(np.float32)[None, :, None, None, None, None]

    def print_weights(self):
        """Plot the scale exponents against sigma, with one panel per orientation."""
        se, sb, elrs, mu, sigma = self.getw()
        print("Diff exponent:")
        # One panel per orientation; squeeze=False keeps `ax` 2-D even for a single orientation.
        fig, ax = plt.subplots(1, n_orientations, squeeze=False)
        for oi in range(n_orientations):
            # Disabled overlays:
            #    gaussian_mask = tf.exp(-(self.sigmas - mu)**2/(2*tf.nn.softplus(sigma)**2))
            #    ax[0, oi].plot(self.sigmas[0, :, 0, 0, 0, 0], gaussian_mask[0, :, 0, 0, 0, 0], color='b')
            #    ax[0, oi].plot(self.sigmas[0, :, 0, 0, 0, 0], sb[0, :, oi, 0, 0, 0])
            ax[0, oi].plot(self.sigmas[0, :, 0, 0, 0, 0], se[0, :, oi, 0, 0, 0], linestyle='dotted')
        plt.show()
        #print("edge loss relative factor:", elrs.numpy())

    def getw(self):
        """Return (scale_exponent, scale_beta, edge_loss_relative_factor, scale_mean, scale_sigma).

        The two per-size/orientation weights pass through ``tilop``; the
        scalars are returned as-is.
        """
        return (
            tilop(self.scale_exponent),
            tilop(self.scale_beta),
            (self.edge_loss_relative_factor),
            self.scale_mean,
            self.scale_sigma,
        )

    def call(self, inputs):
        # Flat query coordinates: <1, n_query_coordinates, 4>
        # We need to convert <b, s, o, y, x, d> 
        original_sums, diffs = inputs

        # Only needed by the commented-out variants below.
        se, sb, elrs, mu, sigma = self.getw()

        gap_gains = (tf.nn.relu(diffs) + eps)
        edge_losses = (tf.nn.relu(-diffs) + eps)

        # Both gains and losses are masked with the same Gaussian.
        #gaussian_mask = tf.exp(-(self.sigmas - mu)**2/(2*tf.nn.softplus(sigma)**2))

        #gap_gains_hra = gaussian_mask * (gap_gains + eps) ** se #/ (eps + sb ** se + (gap_gains + eps) ** se)
        #edge_losses_hra = elrs * gaussian_mask * (edge_losses + eps) ** se #/ (eps + sb ** se + (edge_losses + eps) ** se)

        #no_vertical_gap = np.array([1., 0.5, 0., 0.5])[None, None, :, None, None, None]
        #sf_gaussian = tf.exp(-(self.sigmas - sf[:, 0:1, :, :, :, :])**2/(2*tf.nn.softplus(sf[:, 1:2, :, :, :, :]**2)))*tf.nn.softplus(sf[:, 2:3, :, :, :, :])
        #sf_gaussian = tf.exp(-(self.sigmas - self.sigmas[0, 6, 0, 0, 0, 0])**2/(2*tf.nn.softplus(sf[:, 1:2, :, :, :, :]**2)))*tf.nn.softplus(sf[:, 2:3, :, :, :, :])
        #gap_gaussian = tf.exp(-(self.sigmas - sf[:, 0:1, :, :, :, :])**2/(2*tf.nn.softplus(sf[:, 1:2, :, :, :, :]**2)))*tf.nn.softplus(gsf[:, 2:3, :, :, :, :])
        #gap_gaussian = tf.exp(-(self.sigmas - self.sigmas[0, 6, 0, 0, 0, 0])**2/(2*tf.nn.softplus(gsf[:, 1:2, :, :, :, :]**2)))*tf.nn.softplus(gsf[:, 2:3, :, :, :, :])

        # The eps terms cancel, so this is diffs up to rounding.
        penalties = (
        #    gap_gains_hra - edge_losses_hra
          (gap_gains + eps) - (edge_losses + eps)
        )
        return penalties

class TotalPenalty(tf.keras.layers.Layer):
    """Sum per-pixel penalties over axes 1-4 and add eps.

    The frozen factor/exponent/beta weights parameterize a hyperbolic-ratio
    squashing of the total. That squashing is currently disabled in ``call``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        def frozen_scalar(name, value):
            # Non-trainable scalar weight with a constant initial value.
            return self.add_weight(shape=(),
                                   initializer=tf.keras.initializers.Constant(value),
                                   name=name, trainable=False)

        self.total_factor = frozen_scalar('total_factor', 1.)
        self.total_exponent = frozen_scalar('total_exponent', 1.1)
        self.total_beta = frozen_scalar('total_beta', 1.0)

    def print_weights(self):
        labels = ("Total factor:", "Total exponent:", "Total beta:")
        for label, value in zip(labels, self.getw()):
            print(label, value)

    def getw(self):
        """Return the softplus-transformed (factor, exponent, beta)."""
        raw = (self.total_factor, self.total_exponent, self.total_beta)
        return tuple(tf.nn.softplus(w) for w in raw)

    def call(self, inputs):
        summed = tf.reduce_sum(inputs, axis=[1, 2, 3, 4])
        # Weights for the disabled squashing:
        #   tfa * (total + eps) ** te / (eps + tb ** te + (total + eps) ** te)
        tfa, te, tb = self.getw()
        return summed + eps


class DistanceEstimator(tf.keras.layers.Layer):
    """Estimate a distance as a softmax-weighted average of `x`.

    Called with ``[y, x]``, where the sample axis of both is axis 1.
    ``y`` is divided by its range along axis 1 and scaled by 1e3 before
    ``nd_softmax``, so the weights concentrate on the largest `y`
    (assuming ``nd_softmax`` is a softmax over the given axes -- TODO confirm).
    """

    def __init__(self, **kwargs):
        super(DistanceEstimator, self).__init__(**kwargs)

    def call(self, inputs):
        y, x = inputs

        # Range of y along the sample axis; eps avoids division by zero when y is flat.
        yrange = (tf.reduce_max(y, axis=[1], keepdims=True) - tf.reduce_min(y, axis=[1], keepdims=True)) + eps
        estimate_validities = nd_softmax(1e3 * y/yrange, axis=[1])
        estimated_distances = tf.reduce_sum(estimate_validities * x, axis=[1], name='estimated_distances')
        return estimated_distances

def get_model():
    """Build the Keras model that scores glyph-pair spacings.

    Inputs, in order:
      * ``shifted_gd1_filtered_images`` / ``shifted_gd2_filtered_images``:
        complex64 tensors of shape ``full_shape`` (described below as
        <b, s, o, h, w, d>).
      * ``sample_distances``, ``pair_images``, ``zero_indices``,
        ``desired_penalty_slope_sign``: not used to compute the output. They
        are model inputs so that MonitorProgressCallback can feed them.

    Output: the tensor named ``losses``, one value per batch entry. It reads
    the total penalties at sample-distance indices 0, 1 and 2, so at least
    three sample distances are required.

    Intermediate tensors are wrapped in ``tf.identity`` with fixed names
    (``e_dn_gd1``, ``diffs``, ``total_pixel_penalties``, ``total_penalties``,
    ``losses``, ...). MonitorProgressCallback looks them up by name.
    """
    shifted_gd1_filtered_images = tf.keras.Input(shape=full_shape, name='shifted_gd1_filtered_images', dtype=tf.complex64)
    shifted_gd2_filtered_images = tf.keras.Input(shape=full_shape, name='shifted_gd2_filtered_images', dtype=tf.complex64)

    # Go from <b, s, o, h, w, d> to <b, s, o, p, h, w, d>
    # (only used by the commented-out normalization stages below)
    pr_gd1 = rectify_phases(shifted_gd1_filtered_images)
    pr_gd2 = rectify_phases(shifted_gd2_filtered_images)
    pr_pair = rectify_phases(shifted_gd1_filtered_images + shifted_gd2_filtered_images)

    # Then, exponentiate with a power.
    exp = Exponentiate()
    #pr_p_gd1 = tf.identity(exp([pr_gd1, tf.abs(shifted_gd1_filtered_images)]), name="pr_p_gd1")
    #pr_p_gd2 = tf.identity(exp([pr_gd2, tf.abs(shifted_gd2_filtered_images)]), name="pr_p_gd2")
    #pr_p_pair = tf.identity(exp([pr_pair, tf.abs(shifted_gd1_filtered_images + shifted_gd2_filtered_images)]), name="pr_p_pair")

    # Then, calculate the normalization pools (disabled)
    #np = NormalizationPool()
    #pr_np_gd1 = tf.identity(np(pr_p_gd1), name="pr_np_gd1")
    #pr_np_gd2 = tf.identity(np(pr_p_gd2), name="pr_np_gd2")
    #pr_np_pair = tf.identity(np(pr_p_pair), name="pr_np_pair")

    # Then, perform the divisive normalization (disabled)
    #dn = DivisiveNormalization()
    #pr_dn_gd1 = tf.identity(dn([pr_p_gd1, pr_np_gd1]), name="pr_dn_gd1")
    #pr_dn_gd2 = tf.identity(dn([pr_p_gd2, pr_np_gd2]), name="pr_dn_gd2")
    #pr_dn_pair = tf.identity(dn([pr_p_pair, pr_np_pair]), name="pr_dn_pair")

    # Then, compute the absolute energy values, because that's what we care about.
    ba = BiasedAbs()

    apply_csf = ApplyCsf()
    #e_dn_gd1 = tf.identity(ba(pr_dn_gd1), "e_dn_gd1")
    #e_dn_gd2 = tf.identity(ba(pr_dn_gd2), "e_dn_gd2")
    #e_dn_pair = tf.identity(ba(pr_dn_pair), "e_dn_pair")
    # Active path: BiasedAbs -> Exponentiate -> ApplyCsf, applied to each glyph
    # and to the superimposed pair. The same layer instances are reused for all
    # three, so they share weights.
    e_dn_gd1 = tf.identity(apply_csf(exp(ba((shifted_gd1_filtered_images)))), "e_dn_gd1")
    e_dn_gd2 = tf.identity(apply_csf(exp(ba((shifted_gd2_filtered_images)))), "e_dn_gd2")
    e_dn_pair = tf.identity(apply_csf(exp(ba((shifted_gd1_filtered_images + shifted_gd2_filtered_images)))), "e_dn_pair")

    #le_dn_gd1 = tf.math.imag(pr_dn_gd1)
    #le_dn_gd2 = tf.math.imag(pr_dn_gd2)
    #le_dn_pair = tf.math.imag(pr_dn_pair)
    #ee_dn_gd1 = tf.math.real(pr_dn_gd1)
    #ee_dn_gd2 = tf.math.real(pr_dn_gd2)
    #ee_dn_pair = tf.math.real(pr_dn_pair)

    #line_original_sums = tf.identity(tf.abs(le_dn_gd1 + le_dn_gd2), name="line_original_sums")
    #line_diffs = tf.identity(tf.abs(le_dn_pair) - line_original_sums, name="line_original_sums")

    # Then, compute the differences between pair and (gd1 + gd2)

    original_sums = tf.identity(e_dn_gd1 + e_dn_gd2, name="original_sums")
    diffs = tf.identity(e_dn_pair - original_sums, name="diffs")

    # Then, use a 4D polyharmonic spline to figure out the best way to penalize certain diffs
    # (Use <s, o, original_sum, diffs>) TODO: perhaps use 5D by including y-factor
    penalize = PenalizeZero()
    penalties = tf.identity(penalize([original_sums, diffs]), name="total_pixel_penalties")

    totalPenalty = TotalPenalty()
    d = tf.identity(totalPenalty(penalties), name="total_penalties")
    #d = tf.identity(tf.nn.elu((d[:, 1] - d[:, 0])) + tf.nn.elu((d[:, 1] - d[:, 2])), name="losses")

    total_gap_gains = tf.reduce_sum(tf.nn.relu(penalties), axis=[1, 2, 3, 4])
    total_edge_losses = tf.reduce_sum(tf.nn.relu(-penalties), axis=[1, 2, 3, 4])

    gap_gain_increase = tf.nn.elu(total_gap_gains[:, 0] - total_gap_gains[:, 1]) + tf.nn.elu(total_gap_gains[:, 1] - total_gap_gains[:, 2])
    edge_loss_decrease = tf.nn.elu(total_edge_losses[:, 1] - total_edge_losses[:, 0]) + tf.nn.elu(total_edge_losses[:, 2] - total_edge_losses[:, 1])

    # Loss per batch entry: the squared total penalty at index 1, plus ELU terms.
    # The ELU terms are smallest when totals and gap gains increase, and edge
    # losses decrease, across sample indices 0 -> 1 -> 2.
    d = tf.identity(d[:, 1]**2 + tf.nn.elu(d[:, 0] - d[:, 1]) + tf.nn.elu(d[:, 1] - d[:, 2]) + gap_gain_increase + edge_loss_decrease, name="losses") # Gap is negative
    #d = tf.reduce_sum(penalties, axis=[1, 2], name="pixel_penalties") # sum over orientations: <b, s1, o, h, w, d>
    #d = tf.identity(tf.reduce_sum(d, axis=[1, 2]), name="total_penalties") # / tf.reduce_sum(tf.nn.relu(-penalties), axis=[1,2,3,4]) # <b, h, w, d> → <b, d>

    # Monitoring-only inputs. Shapes are written as 1-tuples; a parenthesized int is not a tuple.
    sample_distances = tf.keras.Input(shape=(n_sample_distances,), name='sample_distances')
    pair_images = tf.keras.Input(shape=(box_height, box_width, n_sample_distances), name='pair_images')
    zero_indices = tf.keras.Input(shape=(), name='zero_indices')
    desired_penalty_slope_sign = tf.keras.Input(shape=(n_sample_distances - 1,), name='desired_penalty_slope_sign')


    # Format: [n_batch_size, n_distances]
    #increases = tf.identity(d[:, 1:] - d[:, :-1], name="increases") # The increase during every interval.
    # We need increases to be negative *before* the zero index, and positive after
    # zero_indices is shape [batch_size]
    #ki = (increases * [1, 1, 1, -1, -1, -1])
    #desired_slope_variation_penalty = tf.identity(tf.nn.elu(ki) + 1., name="dsvps") #(-desired_penalty_slope_sign

    # Step 10. Estimate best distance
    #predicted_ink_distances = DistanceEstimator()([d, sample_distances])

    # Step 10. Deviation from zero

    return tf.keras.Model(inputs=[shifted_gd1_filtered_images,
                                  shifted_gd2_filtered_images,
                                  sample_distances, pair_images, zero_indices, desired_penalty_slope_sign],
                                  outputs=(d))
                          #outputs=predicted_ink_distances)

#@tf.function
#def compute_loss(target_ink_distance, predicted_ink_distance):
#    return tf.sqrt(1.e-7 + tf.reduce_sum((target_ink_distance - predicted_ink_distance) ** 2) / batch_size, name='sqrt') #* 100 / factor

@tf.function
def compute_loss(_target, deviation):
    """Keras loss that ignores the target and returns the sum of the model output.

    Args:
        _target: the y_true Keras passes in; unused.
        deviation: the model output (y_pred).
    """
    total = tf.reduce_sum(deviation)
    # Alternative kept for reference: (tf.reduce_sum(deviation**2)) / batch_size
    return total


## Monitoring

In [0]:
from matplotlib.patches import Circle

class MonitorProgressCallback(tf.keras.callbacks.Callback):
    """Draw debug plots at the start of every validation (test) batch.

    The callback fetches the batch again from ``data_generator``, evaluates
    named intermediate tensors of the module-level ``model`` on it, and plots
    weights and penalties. Each section is switched on or off with a literal
    ``if True:`` / ``if False:``.

    NOTE(review): several ``if False:`` sections appear to look up tensors
    that ``get_model`` does not create (e.g. ``e_gd1``, ``pr_dn_pair``,
    ``e_np_pair``, ``increases``, ``SpatialAverage``). One section also uses
    ``abstotal``, which is not defined in this cell. Expect those sections to
    fail if re-enabled as-is.
    """

    def __init__(self, data_generator):
        # Let the Keras base class initialize its own attributes (model, validation_data, ...).
        super(MonitorProgressCallback, self).__init__()
        self.data_generator = data_generator
        # Reads the module-level `model`, not self.model, which Keras sets only
        # when fit starts. The callback must therefore be constructed after `model`.
        self.model_inputs = [l.input for l in model.layers if isinstance(l, tf.keras.layers.InputLayer)]
        #print("MODEL INPUTS ARE", [(l.name, l.input.shape) for l in model.layers if isinstance(l, tf.keras.layers.InputLayer)])
        self.current_data = None

    def _find_layer(self, name):
        """Return the first layer of `model` whose name ends with `name`.

        Raises:
            ValueError: if no layer name ends with `name`.
        """
        for layer in model.layers:
            if layer.name.endswith(name):
                return layer
        raise ValueError("No model layer whose name ends with {!r}".format(name))

    def get_val(self, name):
        """Evaluate the output of layer `name` on the current batch inputs."""
        l = self._find_layer(name)
        #print("getting value", l.name, l.output.shape)
        #print("values", tf.keras.backend.function(self.model_inputs, [l.output])(self.current_data))
        output = tf.keras.backend.function(self.model_inputs, [l.output])(self.current_data)[0]
        return output

    def get_weights(self, name):
        """Return the weight arrays of layer `name`."""
        return self._find_layer(name).get_weights()

    def print_weights(self, name):
        """Call the custom `print_weights()` method of layer `name`."""
        self._find_layer(name).print_weights()

    def on_test_batch_begin(self, batch_index, logs=None):
        """Fetch batch `batch_index` from the generator again and draw the enabled plots."""
        dataset = self.data_generator[batch_index]
        # Element 0 holds the model inputs, in the same order as the model's inputs.
        current_data = dataset[0]
        self.current_data = current_data
        shifted_gd1_d1_filtered_images, shifted_gd2_d1_filtered_images, sample_distances, pair_images, zero_indices, desired_penalty_slope_sign = current_data

        # Index of the batch entry shown in all plots below.
        iix = 0
        
        if False:
            plt.imshow(pair_images[iix, :, :, 0])
            plt.colorbar()
            plt.show()
            plt.imshow(pair_images[iix, :, :, zero_indices[iix]])
            plt.colorbar()
            plt.show()
            plt.imshow(pair_images[iix, :, :, 2])
            plt.colorbar()
            plt.show()

        if True:
            self.print_weights("apply_csf")
            self.print_weights("exponentiate")
            #self.print_weights("normalization_pool")
            #self.print_weights("divisive_normalization")
            #self.print_weights("biased_abs")
            self.print_weights("total_penalty")

        if False:
            e_gd1 = self.get_val("e_gd1")
            e_pair = self.get_val("e_pair")

            i_gd1 = self.get_val("shifted_gd1_filtered_images")
            i_gd2 = self.get_val("shifted_gd2_filtered_images")
            i_total_angle = tf.math.angle(i_gd1 + i_gd2)

            dn_gd1 = self.get_val("e_dn_gd1")
            dn_gd2 = self.get_val("e_dn_gd2")
            dn_pair = self.get_val("e_dn_pair")
            diffs = self.get_val("diffs")
            for si in range(n_sizes):
                fig, ax = plt.subplots(nrows=1, ncols=5, gridspec_kw={'wspace':0, 'hspace':0}, figsize=(4 * 4 * box_width / 100, 4 * 1 * box_height / 100))
                ax[0].imshow(tf.math.angle(i_gd1)[iix, si, 0, :, :, zero_indices[iix]])
                ax[1].imshow(i_total_angle[iix, si, 0, :, :, zero_indices[iix]])
                ax[2].imshow((dn_gd1 + dn_gd2)[iix, si, 0, :, :, zero_indices[iix]])
                ax[3].imshow(dn_pair[iix, si, 0, :, :, zero_indices[iix]])
                ax[4].imshow(diffs[iix, si, 0, :, :, zero_indices[iix]])
                plt.show()

        if False:
            increases = self.get_val("increases")
            dsvps = self.get_val("dsvps")
            #print("DPSS:", dsvps)
            pcs = self.get_val("total_penalties")
            for pix in range(batch_size):
                plt.plot(np.arange(n_sample_distances) - zero_indices[pix], pcs[pix, :])
                plt.plot(0.5 + np.arange(n_sample_distances - 1) - zero_indices[pix], dsvps[pix, :])
                plt.plot(0.5 + np.arange(n_sample_distances - 1) - zero_indices[pix], increases[pix, :])
            plt.show()

        if False:
            p = self.get_val("pr_dn_pair")
            pa = tf.math.angle(tf.complex(p[:, :, :, 0, :, :, :] - p[:, :, :, 2, :, :, :], p[:, :, :, 1, :, :, :] - p[:, :, :, 3, :, :, :]))
            pv = tf.abs(tf.complex(p[:, :, :, 0, :, :, :] - p[:, :, :, 2, :, :, :], p[:, :, :, 1, :, :, :] - p[:, :, :, 3, :, :, :]))
            for si in range(n_sizes):
                plt.imshow((pa)[iix, si, 0, :, :, zero_indices[iix]]) #, vmin=-lvmax, vmax=lvmax)
                plt.colorbar()
                plt.show()
                plt.imshow((pv)[iix, si, 0, :, :, zero_indices[iix]]) #, vmin=-lvmax, vmax=lvmax)
                plt.colorbar()
                plt.show()

        if False:
            if False:
                va = self.get_val("pr_np_gd1")
                size_factor = 5
                fig, ax = plt.subplots(nrows=n_sizes, ncols=4,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
                for si in range(n_sizes):
                    for pi in range(4):
                        ax[si, pi].imshow(va[iix, si, 0, pi, :, :, zero_indices[iix]], cmap='RdBu') #, vmin=-lvmax, vmax=lvmax)
                        ax[si, pi].set_xticklabels([])
                        ax[si, pi].set_yticklabels([])
                plt.show()
    
                va = self.get_val("pr_dn_gd1")
                size_factor = 5
                fig, ax = plt.subplots(nrows=n_sizes, ncols=4,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
                for si in range(n_sizes):
                    for pi in range(4):
                        ax[si, pi].imshow(va[iix, si, 0, pi, :, :, zero_indices[iix]], cmap='RdBu') #, vmin=-lvmax, vmax=lvmax)
                        ax[si, pi].set_xticklabels([])
                        ax[si, pi].set_yticklabels([])
                plt.show()
    
                va = self.get_val("e_dn_gd1")
                size_factor = 5
                fig, ax = plt.subplots(nrows=n_sizes, ncols=1,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
                for si in range(n_sizes):
                    ax[si].imshow(va[iix, si, 0, :, :, zero_indices[iix]], cmap='RdBu') #, vmin=-lvmax, vmax=lvmax)
                    ax[si].set_xticklabels([])
                    ax[si].set_yticklabels([])
                plt.show()

            va = self.get_val("e_dn_pair")
            pp = self.get_val("e_dn_gd1") + self.get_val("e_dn_gd2")
            size_factor = 5
            #fig, ax = plt.subplots(nrows=n_sizes, ncols=1,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
            for si in range(n_sizes):
                print("size", si, "pair then originalsums")
                plt.imshow(va[iix, si, 0, :, :, zero_indices[iix]]) #, vmin=-lvmax, vmax=lvmax)
                plt.colorbar()
                plt.show()
                plt.imshow(pp[iix, si, 0, :, :, zero_indices[iix]]) #, vmin=-lvmax, vmax=lvmax)
                plt.colorbar()
                plt.show()

        if False:
            if True:
                print("Normalization pool, pair, 2, 0:")
                pixel_penalties = tf.math.real(self.get_val("e_np_pair"))
                plt.imshow(pixel_penalties[iix, 2, 0, :, :, zero_indices[iix]])
                plt.colorbar()
                plt.show()
                print("Normalizataion pool, pair, total:")
                pixel_penalties = tf.math.real(self.get_val("e_np_pair"))
                plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, 0, :, :, zero_indices[iix]], axis=[0]))
                plt.colorbar()
                plt.show()
            if True:
                print("Divisive Normalization, pair, 2, 0")
                pixel_penalties = tf.math.real(self.get_val("e_dn_pair"))
                plt.imshow(pixel_penalties[iix, 2, 0, :, :, zero_indices[iix]])
                plt.colorbar()
                plt.show()
                print("Divisive normalization, pair, total:")
                pixel_penalties = tf.math.real(self.get_val("e_dn_pair"))
                plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, 0, :, :, zero_indices[iix]], axis=[0]))
                plt.colorbar()
                plt.show()
            if True:
                print("Original sums, pair, 2, 0")
                pixel_penalties = tf.math.real(self.get_val("original_sums"))
                plt.imshow(pixel_penalties[iix, 2, 0, :, :, zero_indices[iix]])
                plt.colorbar()
                plt.show()
                print("Original sums, pair, total:")
                pixel_penalties = tf.math.real(self.get_val("original_sums"))
                plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, 0, :, :, zero_indices[iix]], axis=[0]))
                plt.colorbar()
                plt.show()
            if True:
                print("Diffs, pair, 2, 0")
                pixel_penalties = tf.math.real(self.get_val("diffs"))
                plt.imshow(pixel_penalties[iix, 2, 0, :, :, zero_indices[iix]])
                plt.colorbar()
                plt.show()
        if False:
            diffs = tf.math.real(self.get_val("diffs"))
            for di in range(n_sample_distances):
                for si in range(n_sizes):
                    print("distance", di, "size", si)
                    plt.imshow(diffs[iix, si, 0, :, :, di])
                    plt.colorbar()
                    plt.show()
        if True:
            print("Diffs, pair, total:")
            diffs = tf.math.real(self.get_val("diffs"))
            plt.imshow(-tf.reduce_sum(diffs[iix, :, :, :, :, zero_indices[iix]], axis=[0, 1]))
            plt.colorbar()
            plt.show()
            print("Penalties:")
            pixel_penalties = tf.math.real(self.get_val("total_pixel_penalties"))
            total_penalties = tf.math.real(self.get_val("total_penalties"))
            losses = tf.math.real(self.get_val("losses"))
            sigmas = get_sigmas()

            # Per-size sums of the positive and negative parts of the diffs at
            # sample distances 0 (red), 1 (blue) and 2 (green).
            plt.plot(sigmas, np.zeros_like(sigmas), color='k')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(diffs[iix, :, :, :, :, 0]), axis=[1,2,3]), color='r')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-diffs[iix, :, :, :, :, 0]), axis=[1,2,3]), color='r')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(diffs[iix, :, :, :, :, 1]), axis=[1,2,3]), color='b')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-diffs[iix, :, :, :, :, 1]), axis=[1,2,3]), color='b')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(diffs[iix, :, :, :, :, 2]), axis=[1,2,3]), color='g')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-diffs[iix, :, :, :, :, 2]), axis=[1,2,3]), color='g')
            plt.show()

            # Same plot for the pixel penalties.
            plt.plot(sigmas, np.zeros_like(sigmas), color='k')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(pixel_penalties[iix, :, :, :, :, 0]), axis=[1,2,3]), color='r')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-pixel_penalties[iix, :, :, :, :, 0]), axis=[1,2,3]), color='r')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(pixel_penalties[iix, :, :, :, :, 1]), axis=[1,2,3]), color='b')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-pixel_penalties[iix, :, :, :, :, 1]), axis=[1,2,3]), color='b')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(pixel_penalties[iix, :, :, :, :, 2]), axis=[1,2,3]), color='g')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-pixel_penalties[iix, :, :, :, :, 2]), axis=[1,2,3]), color='g')
            plt.show()

            for di in range(n_sample_distances):
                print("DISTANCE INDEX", di)
                if (di == zero_indices[iix]):
                    print("best distance:")
                vm = max(tf.reduce_max(pixel_penalties[iix, :, :, :, :, di]), -tf.reduce_min(pixel_penalties[iix, :, :, :, :, di]))
                print("max:", vm, "sum:", tf.reduce_sum(pixel_penalties[iix, :, :, :, :, di]), "total_sum:", total_penalties[iix, di], "losses:", losses[iix])
                plt.imshow(pair_images[iix, :, :, di], alpha=0.7, cmap='gray')
                plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, :, :, :, di], [0, 1]), alpha=0.7)
                plt.colorbar()
                plt.show()
                if di == 1 and True:
                    for si in range(n_sizes):
                        print("SIZE", si)
                        fig, ax = plt.subplots(1, 2,  gridspec_kw = {'wspace':0, 'hspace':0},  figsize=(5 * 2 * box_width / 100, 5 * box_height / 100))
                        ax[0].imshow(pair_images[iix, :, :, di], alpha=0.7, cmap='gray')
                        pp = tf.reduce_sum(pixel_penalties[iix, si, :, :, :, di], [0])
                        aim = ax[0].imshow(pp, alpha=0.7)
                        max_y, max_x = np.unravel_index(pp.numpy().argmax(), pp.shape)
                        min_y, min_x = np.unravel_index(pp.numpy().argmin(), pp.shape)
                        ax[0].add_patch(Circle((max_x,max_y),sigmas[si]*4,edgecolor='w',facecolor=None,fill=False))
                        ax[0].add_patch(Circle((min_x,min_y),sigmas[si]*4,edgecolor='w',facecolor=None,fill=False))
                        fig.colorbar(aim, ax=ax[0])
                        ax[1].imshow(pair_images[iix, :, :, di], alpha=0.7, cmap='gray')
                        # Overlay the diffs for the same distance `di` as the background pair image.
                        dim = ax[1].imshow(tf.reduce_sum(diffs[iix, si, :, :, :, di], axis=[0]))
                        fig.colorbar(dim, ax=ax[1])
                        plt.show()

                plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, :, :, :, di], [2, 3]))
                plt.colorbar()
                plt.show()


        if False:
            print("Sigmas")
            sigmas = self.get_weights("SpatialAverage")[0]
            plt.plot(sigmas[0, 0, :, 0, 0])
            plt.show()

        if False:
            print("After s/o dense convolution: lines, edges")
            blurred = self.get_val("SpatialAverage")
            subbed = self.get_val("pow")
            hra_total = self.get_val("hr_atotal")
            size_factor = 2
            fig, ax = plt.subplots(nrows=n_sizes, ncols=4,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
            for si in range(n_sizes):
                ax[si, 0].imshow(tf.reduce_sum(abstotal, axis=[2])[iix, si, :, :, zero_indices[iix]])
                ax[si, 0].set_xticklabels([])
                ax[si, 0].set_yticklabels([])
                ax[si, 1].imshow(blurred[iix, si, :, :, zero_indices[iix]])
                ax[si, 1].set_xticklabels([])
                ax[si, 1].set_yticklabels([])
                ax[si, 2].imshow(subbed[iix, si, :, :, zero_indices[iix]])
                ax[si, 2].set_xticklabels([])
                ax[si, 2].set_yticklabels([])
                ax[si, 3].imshow(hra_total[iix, si, :, :, zero_indices[iix]])
                ax[si, 3].set_xticklabels([])
                ax[si, 3].set_yticklabels([])

                print("size", si, tf.reduce_sum(hra_total[iix, si, :, :, zero_indices[iix]]))
            plt.show()

        if False:
            hravars = self.get_weights("hr_atotal")
            print("hravars", hravars)
            print("Hyperbolic ratio variables.")
            print("Exponents")
            plt.plot(tf.nn.softplus(hravars[0][0, :, 0, 0, 0]))
            #plt.colorbar()
            plt.show()
            print("Alphas")
            plt.plot(tf.nn.softplus(hravars[1][0, :, 0, 0, 0]))
            #plt.colorbar()
            plt.show()
            print("m-scale")
            plt.plot(hravars[2][0, :, 0, 0, 0])
            #plt.colorbar()
            plt.show()
    
            print("Total for each pixel")
            out = self.get_val("Sum_1")
            plt.imshow(out[iix, :, :, 0])
            plt.colorbar()
            plt.show()
            plt.imshow(out[iix, :, :, zero_indices[iix]])
            plt.colorbar()
            plt.show()
            plt.imshow(out[iix, :, :, -1])
            plt.colorbar()
            plt.show()
            #for i in range(n_sample_distances):
            #    plt.imshow(tf.reduce_sum(out, axis=[3,4], keepdims=True)[iix, :, :, 0, 0, i])
            #    plt.colorbar()
            #    plt.show()
            #    print("Index:", i, "total is", tf.reduce_sum(out, axis=[1,2,3,4])[iix, i])
        if False:
            # Display what the penalties look like
            pass

        if True:
            self.print_weights("penalize_zero")

        if False:
            print("Diglyphiness values")
            out = self.get_val("Sum")
            pred = self.get_val("distance_estimator")
            print("OUT PRED", out.shape, pred.shape)
            plt.plot(sample_distances[iix, :], out[iix, :])
            print("Pred distance is", pred[iix], "correct distance is", sample_distances[iix, zero_indices[iix]])
            plt.scatter([pred[iix]], [0])
            plt.show()


## Keras pipeline setup

In [0]:
# Start from a clean Keras/TF state. This also resets Keras's automatic layer
# names, which MonitorProgressCallback matches by suffix (e.g. "apply_csf").
tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session()

ig = InputGenerator(batch_size)
try:
    if True:  # set to False to skip training
        # `model` must stay a module-level name: MonitorProgressCallback reads it.
        model = get_model()
        model.compile(loss=compute_loss,
                      optimizer=tf.keras.optimizers.Adam(0.15))
        #model.summary()
        history = model.fit_generator(ig,
                                      callbacks=[MonitorProgressCallback(ig)],
                                      validation_data=ig,
                                      validation_steps=1,
                                      validation_freq=10,
                                      epochs=3000,
                                      steps_per_epoch=10, use_multiprocessing=False)
finally:
    # Always stop the generator, including when training is interrupted
    # (e.g. by stopping this cell). Without the finally, ig.kill() is skipped
    # and whatever the generator runs stays alive in the kernel.
    ig.kill()
tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session()

In [0]:
from fontParts.world import *
import string

# Open the source font, set placeholder sidebearings and kerning on the
# lowercase glyphs, and save the result.
font = OpenFont("DroidSerif.ttf", showInterface=False)
first_layer = font.layers[0]
for gl in string.ascii_lowercase:
    # Placeholder values; a real fitter would supply them, e.g.
    #   lsb, rsb = yourletterfitter.find_sidebearings(g)
    first_layer[gl].leftMargin = 0   # lsb
    first_layer[gl].rightMargin = 0  # rsb
    for gr in string.ascii_lowercase:
        # Public item assignment normalizes/validates the pair and the value
        # before storing them (unlike the private _setItem).
        font.kerning[(gl, gr)] = 10
font.save("Autospaced.otf")
