<a href="https://colab.research.google.com/github/skosch/YinYangFit/blob/master/YinYangFit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set up TF2 and import dependencies

In [0]:
try:
    %tensorflow_version 2.x
except Exception:
    pass

import tensorflow as tf
import os

print("TF version:", tf.__version__)
print("GPU is", "available" if tf.test.is_gpu_available() else "NOT AVAILABLE")

if tf.test.is_gpu_available():
    device_name = tf.test.gpu_device_name()
    if device_name != '/device:GPU:0':
      raise SystemError('GPU device not found')
    print('Found GPU at: {}'.format(device_name))
else:
    tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
    
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(cluster_resolver)
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)


In [0]:
import itertools
import os

import numpy as np
pi = np.pi

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.cm as cm
from mpl_toolkits.mplot3d import Axes3D
from scipy.ndimage.interpolation import affine_transform
import tensorflow as tf
import random; random.seed()
import math
import pickle
import os
from tqdm import tqdm as tqdm
import sys
from functools import reduce
import random
from itertools import cycle, islice, product
import operator

!pip install --quiet --upgrade git+git://github.com/simoncozens/tensorfont.git
!pip install --quiet fonttools
!pip install --upgrade git+git://github.com/simoncozens/fontParts.git@d444bde6e2a0adbcd9a16593a615a99823089c70
!pip install booleanOperations
!pip install --quiet --upgrade ufo-extractor
!pip install --quiet --upgrade defcon
!pip install --quiet --upgrade ufo2ft
import fontParts
import extractor
import defcon
from ufo2ft import compileOTF

from tensorfont import Font

print("✓ Dependencies imported.")



### Download font files

In [0]:
# Candidate fonts; only the Crimson Roman download below is active.
# The `?raw=true` suffix makes GitHub serve the file bytes instead of the HTML page.
#!wget -q -O OpenSans-Regular.ttf https://github.com/googlefonts/opensans/blob/master/ttfs/OpenSans-Regular.ttf?raw=true
#!wget -q -O Roboto.ttf https://github.com/google/fonts/blob/master/apache/roboto/Roboto-Regular.ttf?raw=true
#!wget -q -O Roboto.otf https://github.com/AllThingsSmitty/fonts/blob/master/Roboto/Roboto-Regular/Roboto-Regular.otf?raw=true
#!wget -q -O DroidSerif.ttf https://github.com/datactivist/sudweb/blob/master/fonts/droid-serif-v6-latin-regular.ttf?raw=true
!wget -q -O CrimsonItalic.otf https://github.com/skosch/Crimson/blob/master/Desktop%20Fonts/OTF/Crimson-Italic.otf?raw=true
#!wget -q -O CrimsonBold.otf https://github.com/skosch/Crimson/blob/master/Desktop%20Fonts/OTF/Crimson-Bold.otf?raw=true 
# NOTE: the next (disabled) line would save the *Amiri* font under the name CrimsonRoman.otf.
#!wget -q -O CrimsonRoman.otf https://github.com/alif-type/amiri/blob/master/Amiri-Regular.ttf?raw=true

# CrimsonRoman.otf is the file the next cell loads (filename = "CrimsonRoman.otf").
!wget -q -O CrimsonRoman.otf https://github.com/skosch/Crimson/blob/master/Desktop%20Fonts/OTF/Crimson-Roman.otf?raw=true
# NOTE: `wget -q` failures are silent; this message is printed regardless of success.
print("✓ Font file(s) downloaded.")

## Load font data and set up global parameters

In [0]:
# Glyphs that InputGenerator rasterizes and samples pairs from.
glyph_char_list = "abcdeghijlmnopqrstuzywvxkf"
#glyph_char_list = "abgjqrst"
#glyph_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#glyph_char_list = "OO"

# ==== Create Font ====
factor = 1.14 #1.539  # This scales the size of everything
filename = "CrimsonRoman.otf"
f = Font(filename, 34 * factor) # Roboto.ttf CrimsonRoman.otf # 34 for lowercase
box_height = f.full_height_px
box_width = int(121 * factor)
# Force box_width to be odd: adds 1 when box_width is even, 0 when it is already odd.
box_width += (box_width + 1) % 2
print("Box size:", box_height, "×", box_width)

# 37067520 allocated, maxAllocSize: 873707520

# NOTE: TPU runs need a multiple of 8, which 1 is not; NormalizationPool.call also
# currently only processes batch element 0, i.e. it relies on batch_size == 1.
batch_size = 1  # must be divisible by 8 to work on TPU
n_sample_distances = 3 # should be an odd number

n_sizes = 18  # number of filter scales (see get_sigmas)
n_pseudo_orientations = 4  # not referenced elsewhere in this section
n_orientations = 4  # filter orientations, spaced pi / n_orientations apart

## Create the filter bank (Gaussian-derivative quadrature pairs, Gabor-like)

In [0]:
def get_sigmas(skip_scales=0):
    """
    Return the Gaussian sigma (in pixels) for each of the `n_sizes` filter scales.

    Sigmas are spaced linearly:
        sigma_s = min_sigma + (max_sigma - min_sigma) * s / n_sizes,   s = 0 .. n_sizes - 1
    with min_sigma = 0.6 and max_sigma = box_width / 12. The largest returned sigma is
    therefore one step below max_sigma.

    @param skip_scales: unused (only referenced by a disabled quadratic spacing); kept so
                        existing calls keep working.
    @return: float64 numpy array of shape (n_sizes,)
    """
    smallest_sigma = 0.6
    largest_sigma = box_width / 12
    scale_indices = np.arange(n_sizes)
    return (largest_sigma - smallest_sigma) * scale_indices / n_sizes + smallest_sigma
# Show the per-scale sigmas the filter bank will be built from.
print("Sigmas are", get_sigmas())

def get_3n_filters(skip_scales, display_filters=False):
    """
    Build the frequency-domain filter banks, one filter per (scale, orientation).

    Each filter is sampled on a 2*box_height x 2*box_width grid (the padded size used by
    apply_filter_bank) and transformed with np.fft.fft2.

    @param skip_scales: unused (kept for call compatibility).
    @param display_filters: if True, plot for each filter |fftshift(d1)| and imag(ifft2(d1)).
    @return: (d1_bank, d2_bank), both complex64 of shape
             <n_sizes, n_orientations, 2*box_height, 2*box_width>.
             d1_bank holds the spectrum of the complex spatial filter
             (-first derivative) + 1j * (sigma**1.5 * -second derivative) of a Gaussian;
             d2_bank holds the sigma**1.5-scaled spectrum of the (-second derivative) alone.
    """
    def rotated_mgrid(oi):
        """
        Return pixel-centred coordinates (x', y'), shape <2, 2*box_height, 2*box_width>,
        rotated by pi * oi / n_orientations radians (oi is an orientation index).
        """
        rotation = np.array([[ np.cos(pi*oi/n_orientations), np.sin(pi*oi/n_orientations)],
                             [-np.sin(pi*oi/n_orientations), np.cos(pi*oi/n_orientations)]])
        hh = box_height # / 2
        bw = box_width # / 2
        y, x = np.mgrid[-hh:hh, -bw:bw].astype(np.float32)
        # Half-pixel offset so the grid is symmetric around the origin.
        y += 0.5 # 0 if box_height % 2 == 0 else 0.5
        x += 0.5 # 0 if box_width % 2 == 0 else 0.5
        # x' = cos*x + sin*y, y' = -sin*x + cos*y
        return np.einsum('ji, mni -> jmn', rotation, np.dstack([x, y]))

    def get_filter(sigma, theta):
        """
        Return (d1c, d2) spectra for one filter; `theta` is an orientation *index*
        (forwarded to rotated_mgrid), not an angle.
        """
        x, y = rotated_mgrid(theta)

        # To minimize ringing etc., we create the filter as is, then run it through the DFT.
        # The spatial filter is centred in the grid rather than at index 0; apply_filter_bank
        # presumably compensates for this with an fftshift of its result.
        a1 = 0.25 # See Georgeson et al. 2007 (currently unused)
        s1 = sigma #a1 * sigma
        # First derivative along x' of a normalized 2D Gaussian with std s1.
        d1_space = -np.exp(-(x**2+y**2)/(2*s1**2))*x/(2*pi*s1**4)
        d1 = np.fft.fft2(d1_space + 1j * np.zeros_like(d1_space)) #, [box_height, box_width])

        # Second derivative (negated) along x' of the same Gaussian, with std s2:
        s2 = sigma #np.sqrt(1. - a1**2) * sigma # See Georgeson et al. (2007)
        d2_space = np.exp(-(x**2+y**2)/(2*s2**2))/(2*pi*s2**4) - np.exp(-(x**2+y**2)/(2*s2**2))*x**2/(2*pi*s2**6)
        d2 = sigma**1.5 * np.fft.fft2(d2_space + 1j * np.zeros_like(d2_space)) #, [box_height, box_width])

        # d1c = -d1 + 1j*d2: by linearity of the DFT this is the spectrum of the complex
        # spatial filter (-d1_space) + 1j * (sigma**1.5 * d2_space).
        d1c = 1j * (d2 + 1j*d1) #/ sigma #np.sqrt(sigma)
        return (d1c, d2)

    d1_bank = np.zeros((n_sizes, n_orientations, 2*box_height, 2*box_width)).astype(np.complex64)
    d2_bank = np.zeros((n_sizes, n_orientations, 2*box_height, 2*box_width)).astype(np.complex64)

    if display_filters:
        sizediv = 60
        # Two rows per scale: spectrum magnitude, then the imaginary spatial component.
        fig, ax = plt.subplots(nrows=n_sizes*2, ncols=n_orientations, gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(box_width * n_orientations / sizediv, box_height * n_sizes * 2 / sizediv))

    sigmas = get_sigmas()
    for s in range(n_sizes):
        sigma = sigmas[s]
        for o in range(n_orientations):
            (d1, d2) = get_filter(sigma, o)
            if display_filters:
                #ax[s*2, o].imshow(np.real(np.fft.ifft2(d1)), cmap="inferno")
                ax[s*2, o].imshow(np.abs(np.fft.fftshift(d1)), cmap="inferno")
                ax[s*2, o].set_aspect("auto")
                ax[s*2, o].set_yticklabels([])
                ax[s*2+1, o].imshow(np.imag(np.fft.ifft2(d1)), cmap="inferno")
                ax[s*2+1, o].set_aspect("auto")
                ax[s*2+1, o].set_yticklabels([])
            # Assigning into the complex64 banks downcasts the complex128 spectra.
            d1_bank[s, o, :, :] = d1
            d2_bank[s, o, :, :] = d2

    if display_filters:
        plt.show()

    return (d1_bank.astype(np.complex64), d2_bank)

# Module-level filter banks used by the glyph-filtering code below.
d1_filter_bank, d2_filter_bank = get_3n_filters(0, False)

## Rasterize and filter the glyphs, extract their ink widths in pixels, and generate shifted glyph-pair inputs


In [0]:
def apply_filter_bank(input_image, filter_bank):
    """
    Filter an image (or a stack of images) with a frequency-domain filter bank via FFT.

    Input image should have dimensions <h, w> or <s, o, h, w> or <b, s, o, h, w, d>,
    with h, w == box_height, box_width.
    Filter bank should have dimensions <s, o, 2*box_height, 2*box_width> and be given in
    the frequency domain (e.g. d1_filter_bank).

    The image is zero-padded to the filter size (to limit circular wrap-around), multiplied
    with the filters in the frequency domain, transformed back, re-centred with fftshift
    and cropped back to <box_height, box_width>.

    Returns a tuple (filtered, input_in_freqdomain):
      filtered            -- <s, o, h, w> for 2D/4D input, <b, s, o, h, w, d> for 6D input
      input_in_freqdomain -- FFT of the padded input, laid out <b, d, s, o, 2h, 2w>
                             (b = d = 1 for 2D/4D input, and s = o = 1 for 2D input)
    Raises ValueError for inputs of any other rank.
    """
    rank = len(input_image.shape)
    # Bring every supported input into the common layout <b, d, s, o, h, w>.
    if rank == 2:
        bdsohw_input_image = input_image[None, None, None, None, :, :]
    elif rank == 4:
        bdsohw_input_image = input_image[None, None, :, :, :, :]
    elif rank == 6:
        bdsohw_input_image = tf.einsum("bsohwd->bdsohw", input_image)
    else:
        raise ValueError("apply_filter_bank expects an input of rank 2, 4 or 6, got rank {}".format(rank))

    # pad image to filter size, which is 2*box_height, 2*box_width (to prevent too much wrapping)
    pad_top, pad_bottom = int(np.ceil(box_height / 2)), int(box_height / 2)
    pad_left, pad_right = int(np.ceil(box_width / 2)), int(box_width / 2)
    padded_input = tf.pad(bdsohw_input_image, [[0, 0], [0, 0], [0, 0], [0, 0],
                                               [pad_top, pad_bottom],
                                               [pad_left, pad_right]], mode='CONSTANT')

    input_in_freqdomain = tf.signal.fft2d(tf.dtypes.cast(tf.complex(padded_input, tf.zeros_like(padded_input)), tf.complex64))

    padded_result = tf.signal.ifft2d(input_in_freqdomain * filter_bank[None, None, :, :, :, :])

    # Re-centre the spatial axes -- h and w are axes 4 and 5 of <b, d, s, o, h, w> -- and crop
    # the padding away. (The 6D path used to shift axes [2, 3], i.e. s and o, and to crop the
    # w and d axes instead of h and w.)
    centered_result = tf.signal.fftshift(padded_result, axes=[4, 5])
    cropped_result = centered_result[:, :, :, :,
                                     pad_top:pad_top + box_height,
                                     pad_left:pad_left + box_width]

    if rank == 6:
        # Return <b, s, o, h, w, d>
        return (tf.einsum("bdsohw->bsohwd", cropped_result), input_in_freqdomain)
    # Return <s, o, h, w>
    return (cropped_result[0, 0], input_in_freqdomain)

def get_glyph_data_with_filtered_as_dict(glyph_char):
    """
    Rasterize one glyph and run it through the first-derivative filter bank.

    @param glyph_char: string of length 1
    @return: dict with keys
        'glyph_char'                the character itself,
        'glyph_image'               float32 raster from tensorfont, padded via
                                    with_padding_to_constant_box_width(box_width),
        'glyph_ink_width'           the glyph's ink width as reported by the font,
        'glyph_d1_filtered_images'  apply_filter_bank output for d1_filter_bank
                                    (presumably <s, o, h, w> for a 2D raster).
    """
    raster = (f.glyph(glyph_char)
              .as_matrix(normalize=True)
              .with_padding_to_constant_box_width(box_width)
              .astype(np.float32))
    # apply_filter_bank also returns the padded input's spectrum, which is not needed here.
    responses, _ = apply_filter_bank(raster, d1_filter_bank)
    ink_width = f.glyph(glyph_char).ink_width

    return {
        'glyph_char': glyph_char,
        'glyph_image': raster,
        'glyph_ink_width': ink_width,
        'glyph_d1_filtered_images': responses,
    }

def get_sample_distances_and_translations(gd1, gd2, target_ink_distance, distances=None):
    """
    Compute how far each glyph box must be shifted to place the pair at several ink distances.

    @param gd1, gd2: glyph data dicts for the left and right glyph; only 'glyph_char' and
                     'glyph_ink_width' are read.
    @param target_ink_distance: ink distance (px) around which samples are placed when
                                `distances` is None.
    @param distances: optional explicit sequence of ink distances (px) to sample instead.
    @return: dict with
        'sample_distances'           numpy array of absolute ink distances,
        'relative_sample_distances'  offsets from target_ink_distance (or `distances` as given),
        'left_translations'          int32 shifts for gd1's box,
        'right_translations'         int32 shifts for gd2's box
                                     (positive = rightwards, as in shift_sohw1_into_sohwd),
        'zero_index'                 index of the target distance within the samples
                                     (0 when `distances` is given; not meaningful then).
    """
    char1, char2 = gd1['glyph_char'], gd2['glyph_char']
    total_ink_width = gd1['glyph_ink_width'] + gd2['glyph_ink_width']
    total_width_at_minimum_ink_distance = total_ink_width - f.minimum_ink_distance(char1, char2)

    if distances is None:
        # Symmetric, evenly spaced offsets [-2k, ..., -2, 0, 2, ..., 2k] around the target,
        # with k = (n_sample_distances - 1) // 2. (An earlier variant that dropped negative
        # offsets close to the minimum ink distance had been permanently disabled by `or True`.)
        half_count = int((n_sample_distances - 1) / 2)
        step = 2
        relative_sample_distances = [step * k for k in range(-half_count, half_count + 1)]
        zero_index_val = half_count
        sample_distances = np.array(relative_sample_distances) + target_ink_distance
    else:
        relative_sample_distances = distances
        zero_index_val = 0   # doesn't really have a meaning in this case
        sample_distances = np.array(distances)

    # Split each distance between the two glyphs; an odd pixel goes to the left glyph.
    sample_distances_left = np.ceil(sample_distances / 2)
    sample_distances_right = np.floor(sample_distances / 2)

    ink_width_left = np.floor(total_ink_width / 4)
    ink_width_right = np.ceil(total_ink_width / 4)

    left_translations = (-(np.ceil(total_width_at_minimum_ink_distance/2) + sample_distances_left) - (-ink_width_left)).astype(np.int32)
    right_translations = ((np.floor(total_width_at_minimum_ink_distance/2) + sample_distances_right) - ink_width_right).astype(np.int32)

    return {
        'sample_distances': sample_distances,
        'relative_sample_distances': relative_sample_distances,
        'left_translations': left_translations,
        'right_translations': right_translations,
        'zero_index': zero_index_val,
    }

def shift_sohw1_into_sohwd(input_images, translations):
    """Shifts images to left/right and back-fills with zeros.

    One shifted copy is produced per entry of `translations` and stacked along the last
    axis. Positive translations shift to the right, negative ones to the left (whole pixels).

    @param input_images: <sizes, orientations, height, width, 1>
    @param translations: <len(translations)> integer shifts; `.shape[0]` is read, so this
                         is expected to be an array (e.g. the int32 arrays from
                         get_sample_distances_and_translations)
    @output        <sizes, orientations, height, width, len(translations)>
    """
    images = tf.tile(input_images, [1, 1, 1, 1, translations.shape[0]]) # create len(shifts) channel copies
    fill_constant = 0
    left = tf.maximum(0, tf.reduce_max(translations)) # positive numbers are shifts to the right, for which we need to add zeros on the left
    right = -tf.minimum(0, tf.reduce_min(translations)) # negative numbers are shifts to the left, for which we need to add zeros on the right
    left_mask = tf.ones(shape=(tf.shape(images)[0], tf.shape(images)[1], tf.shape(images)[2], left, tf.shape(images)[4]), dtype=images.dtype) * fill_constant
    right_mask = tf.ones(shape=(tf.shape(images)[0], tf.shape(images)[1], tf.shape(images)[2], right, tf.shape(images)[4]), dtype=images.dtype) * fill_constant
    padded_images = tf.concat([left_mask, images, right_mask], axis=3) # pad on axis 3 (i.e. width-wise)

    # Now that the images are all padded, we need to crop them to implement the shifts.
    def crop_image_widthwise(image_and_shift):
        # One <s, o, h, padded_w> copy and its translation. Output column x reads padded
        # column left - shift + x, i.e. original column x - shift.
        image = image_and_shift[0] # sohw
        shift = image_and_shift[1] # 
        return image[:, :, :, left-shift:left-shift+input_images.shape[3]] # positive shift: left-shift

    # map_fn iterates over the leading axis, so move the copies axis to the front and back.
    return tf.einsum("dsohw->sohwd", tf.map_fn(
        crop_image_widthwise,
        (tf.einsum("sohwd->dsohw", padded_images), translations),
        dtype=images.dtype))

def shift_hwso1_into_hwsod(input_images, translations):
    """Shifts images to left/right and back-fills with zeros.

    Same as shift_sohw1_into_sohwd, but for height-first layouts (width is axis 1).
    Positive translations shift to the right, negative ones to the left (whole pixels).

    @param input_images: <height, width, sizes, orientations, 1>
    @param translations: <len(translations)> integer shifts; `.shape[0]` is read
    @output        <height, width, sizes, orientations, len(translations)>
    """
    images = tf.tile(input_images, [1, 1, 1, 1, translations.shape[0]]) # create len(shifts) channel copies
    fill_constant = 0
    left = tf.maximum(0, tf.reduce_max(translations)) # positive numbers are shifts to the right, for which we need to add zeros on the left
    right = -tf.minimum(0, tf.reduce_min(translations)) # negative numbers are shifts to the left, for which we need to add zeros on the right
    left_mask = tf.ones(shape=(tf.shape(images)[0], left, tf.shape(images)[2], tf.shape(images)[3], tf.shape(images)[4]), dtype=images.dtype) * fill_constant
    right_mask = tf.ones(shape=(tf.shape(images)[0], right, tf.shape(images)[2], tf.shape(images)[3], tf.shape(images)[4]), dtype=images.dtype) * fill_constant
    padded_images = tf.concat([left_mask, images, right_mask], axis=1) # pad on axis 1 (i.e. width-wise)

    # Now that the images are all padded, we need to crop them to implement the shifts.
    def crop_image_widthwise(image_and_shift):
        # One <h, padded_w, s, o> copy and its translation; output column x reads original
        # column x - shift.
        image = image_and_shift[0]
        shift = image_and_shift[1]
        return image[:, left-shift:left-shift+input_images.shape[1], :, :] # positive shift: left-shift

    # map_fn iterates over the leading axis, so move the copies axis to the front and back.
    return tf.einsum("dhwso->hwsod", tf.map_fn(
        crop_image_widthwise,
        (tf.einsum("hwsod->dhwso", padded_images), translations),
        dtype=images.dtype))

def shift_fft_image(input_image, translations):
    """
    Translate images horizontally by multiplying their spectra with a linear phase ramp.

    For each translation t, the fftshift-ed spectrum is multiplied by
    exp(-2j*pi*k*t / (2*box_width)) with k running over -box_width .. box_width-1,
    then ifftshift-ed back. The result stays in the frequency domain (no inverse FFT here).

    NOTE(review): the comment below says the input is <b, d, s, o, h, w> (6D), but the
    5-element tile multiples, the fftshift axes [3, 4] and the "dsohw->sohwd" einsum all
    assume a 5D <d, s, o, h, w> tensor with d == 1 -- TODO confirm the intended input.
    NOTE(review): np.exp(...) yields a complex128 numpy array; confirm the multiply with a
    (presumably complex64) tensor is accepted. Not called elsewhere in this section.

    @param translations: numpy array of shifts in pixels (positive presumably = rightwards).
    @return: <s, o, h, w, len(translations)>
    """
    # Input is directly from apply_filter_bank, in format <bdsohw>, padded to 2*box_height, 2*box_width

    xshift_array = np.mgrid[-box_width:box_width][None, None, None, None, :] * translations[:, None, None, None, None] / (2 * box_width)
    tiled_centered_input_image = tf.tile(tf.signal.fftshift(input_image, [3, 4]), [translations.shape[0], 1, 1, 1, 1])
    shifted_images = tf.signal.ifftshift(tiled_centered_input_image * np.exp(-2j * np.pi * xshift_array), [3, 4])

    return tf.einsum("dsohw->sohwd", shifted_images)

def shift_and_overlay_pair_data(gd1, gd2, distances=None):
    """
    Lay out a glyph pair at several ink distances.

    @param gd1, gd2: glyph data dicts (left and right glyph).
    @param distances: optional explicit ink distances; by default, samples are placed
                      around the font's own spacing for this pair.
    @return: dict with
        'shifted_gd1_d1_filtered_images'  <s, o, h, w, d> shifted filter responses of gd1,
        'shifted_gd2_d1_filtered_images'  <s, o, h, w, d> shifted filter responses of gd2,
        'ink_distance'                    int target ink distance
                                          (pair_distance + minimum_ink_distance),
        'pair_images'                     <h, w, d> overlaid raw rasters (for display only),
        'zero_index'                      index of the target distance among the samples,
        'sample_distances'                the sampled ink distances.
    """
    left_char = gd1['glyph_char']
    right_char = gd2['glyph_char']
    target_ink_distance = int(f.pair_distance(left_char, right_char) + f.minimum_ink_distance(left_char, right_char))
    layout = get_sample_distances_and_translations(gd1, gd2, target_ink_distance, distances)
    left_shifts = layout['left_translations']
    right_shifts = layout['right_translations']

    shifted_left = shift_sohw1_into_sohwd(gd1['glyph_d1_filtered_images'][..., None], left_shifts)
    shifted_right = shift_sohw1_into_sohwd(gd2['glyph_d1_filtered_images'][..., None], right_shifts)

    # Overlay the shifted raw rasters as well (display only); drop the dummy s/o axes.
    overlay = (shift_hwso1_into_hwsod(gd1['glyph_image'][..., None, None, None], left_shifts)
               + shift_hwso1_into_hwsod(gd2['glyph_image'][..., None, None, None], right_shifts))

    return {
        'shifted_gd1_d1_filtered_images': shifted_left,
        'shifted_gd2_d1_filtered_images': shifted_right,
        'ink_distance': target_ink_distance,
        'pair_images': overlay[:, :, 0, 0, :],
        'zero_index': layout['zero_index'],
        'sample_distances': layout['sample_distances'],
    }

class InputGenerator(tf.keras.utils.Sequence):
    """
    Keras Sequence producing randomly sampled glyph-pair examples.

    Every glyph in `glyph_char_list` is rasterized and filtered once at construction time.
    Each batch then consists of `batch_size` randomly drawn (left, right) glyph pairs, laid
    out at several sample distances by `shift_and_overlay_pair_data`.
    """

    def __init__(self, batch_size):
        self.batch_size = batch_size
        print("Creating glyph images ...", flush=True)
        self.glyph_data = []
        for glyph_char in tqdm(glyph_char_list):
            self.glyph_data.append(get_glyph_data_with_filtered_as_dict(glyph_char))
        # Number of ordered glyph pairs; only used to define a nominal epoch length.
        self.n_pairs = len(glyph_char_list) ** 2
        # Unused while per-pair caching is disabled (see __getitem__); kept for kill().
        self.cached_pair_data = {}

    def kill(self):
        """Drop the stored glyph data to free memory."""
        del self.glyph_data
        del self.cached_pair_data

    def __len__(self):
        """Number of batches in the dataset"""
        return math.ceil(self.n_pairs / self.batch_size)

    def _sample_pair(self):
        """
        Draw random (left, right) glyph data dicts until the pair's target ink distance
        (pair_distance + minimum_ink_distance) exceeds 2 px.

        NOTE: loops forever if no pair of glyphs satisfies that condition.
        """
        while True:
            g1 = random.choice(self.glyph_data)
            g2 = random.choice(self.glyph_data)
            if f.pair_distance(g1['glyph_char'], g2['glyph_char']) + f.minimum_ink_distance(g1['glyph_char'], g2['glyph_char']) > 2:
                return g1, g2

    def __getitem__(self, idx):
        """Return the content of batch idx.
        Instead of always providing the same data for batch idx,
        we just pick self.batch_size random glyph pairs and return their glyph data
        (so `idx` is ignored).

        This is run on the CPU, because otherwise the calculations sit in GPU memory,
        are never released, and lead to annoying out-of-memory issues all the time.

        Output: (inputs, outputs) with
            inputs  = [shifted_gd1_d1_filtered_images  <b, s, o, h, w, d>,
                       shifted_gd2_d1_filtered_images  <b, s, o, h, w, d>,
                       sample_distances                <b, d>,
                       pair_images                     <b, h, w, d>,
                       zero_index                      <b>]
            outputs = ink_distance                     <b>  (float)
        """
        with tf.device('/CPU:0'):
            g_shifted_gd1_d1_filtered_images = []
            g_shifted_gd2_d1_filtered_images = []
            g_sample_distances = []
            g_ink_distance = []
            g_pair_images = []
            g_zero_index = []
            # Use this generator's own batch size, not the notebook-level global.
            for _ in range(self.batch_size):
                g1, g2 = self._sample_pair()

                # Pair data is recomputed on every draw rather than cached, to save RAM.
                cpd = shift_and_overlay_pair_data(g1, g2)

                g_shifted_gd1_d1_filtered_images.append(cpd['shifted_gd1_d1_filtered_images'])
                g_shifted_gd2_d1_filtered_images.append(cpd['shifted_gd2_d1_filtered_images'])
                g_sample_distances.append(cpd['sample_distances'])
                g_ink_distance.append(cpd['ink_distance'] * 1.0)
                g_pair_images.append(cpd['pair_images'])
                g_zero_index.append(cpd['zero_index'])

            inputs = [
                tf.stack(g_shifted_gd1_d1_filtered_images),
                tf.stack(g_shifted_gd2_d1_filtered_images),
                tf.stack(g_sample_distances),
                tf.stack(g_pair_images),
                tf.stack(g_zero_index),
            ]
            outputs = tf.stack(g_ink_distance)

            return inputs, outputs


## Model evaluator

In [0]:
tf.keras.backend.clear_session()  # For easy reset of notebook state.
# Shape of one example's filtered images: <sizes, orientations, height, width, sample distances>.
full_shape = ( n_sizes, n_orientations, box_height, box_width, n_sample_distances)

# Smallest positive normal float32; added to denominators and sqrt arguments below to
# avoid division by zero / infinite gradients.
eps = np.finfo(np.float32).tiny

@tf.function
def nd_softmax(target, axis, name=None):
    """
    Numerically stable softmax of `target` along `axis`.

    The per-slice maximum is subtracted before exponentiating, and `eps` is added to the
    normalizer to avoid division by zero. `name` is accepted but unused.
    """
    stabilized = target - tf.reduce_max(target, axis, keepdims=True)
    exponentials = tf.exp(stabilized)
    denominator = tf.reduce_sum(exponentials, axis, keepdims=True) + eps
    return exponentials / denominator


@tf.function
def tilo(t):
    """
    Expand axis 2 of a 6D tensor from 3 stored orientation channels to 4.

    Returns channels [0, 1, 2, 1] along axis 2, i.e. orientation 3 reuses channel 1.
    """
    channel_order = (0, 1, 2, 1)
    return tf.concat([t[:, :, k:k + 1, :, :, :] for k in channel_order], axis=2)
@tf.function
def tilop(t):
    """Softplus of `tilo(t)`: 4-orientation expansion followed by a positivity transform."""
    expanded = tilo(t)
    return tf.nn.softplus(expanded)
@tf.function
def ptilo(t):
    """7D counterpart of `tilo`: returns channels [0, 1, 2, 1] along axis 2 of a 7D tensor."""
    channel_order = (0, 1, 2, 1)
    return tf.concat([t[:, :, k:k + 1, :, :, :, :] for k in channel_order], axis=2)
@tf.function
def ptilop(t):
    """Softplus of `ptilo(t)`."""
    expanded = ptilo(t)
    return tf.nn.softplus(expanded)
@tf.function
def tiloa(t):
    """Like `ptilo`, but expands axis 4 of a 7D tensor: returns its channels [0, 1, 2, 1]."""
    channel_order = (0, 1, 2, 1)
    return tf.concat([t[:, :, :, :, k:k + 1, :, :] for k in channel_order], axis=4)
@tf.function
def tiloap(t):
    """Softplus of `tiloa(t)`."""
    expanded = tiloa(t)
    return tf.nn.softplus(expanded)
def invsp(t):
    """
    Inverse softplus: return x such that log(1 + exp(x)) == t.

    Computed as t + log(-expm1(-t)), which equals log(exp(t) - 1) algebraically but does
    not overflow for large t (exp(t) overflows float64 above ~709, float32 above ~88) and
    stays accurate for small t.

    @param t: scalar; should be >= 0 (negative values have no real inverse and give nan).
    @return: -1e10 for t == 0 (a large negative stand-in for -inf), otherwise an np.float32.
    """
    if t == 0:
        return -1e10
    return (t + np.log(-np.expm1(-t))).astype(np.float32)

def rectify_phases(inputs):
    """
    Split complex responses into four half-wave-rectified phase channels.

    Inputs: <b, s, o, h, w, d> complex
    Output: <b, s, o, p, h, w, d> real, where p indexes
            [relu(re), relu(im), relu(-re), relu(-im)].
    """
    real_part = tf.math.real(inputs)
    imag_part = tf.math.imag(inputs)
    phase_channels = [real_part, imag_part, -real_part, -imag_part]
    return tf.stack([tf.nn.relu(channel) for channel in phase_channels], axis=3)

class BiasedAbs(tf.keras.layers.Layer):
    """
    Phase-biased magnitude of complex filter responses, raised to a learned power.

    There are two trainable weights per (size, stored orientation):
      - bias_weights: squashed by a sigmoid into a mix factor m in (0, 1) that weights the
        squared real part by (1 - m) and the squared imaginary part by m;
      - exponent: passed through softplus and used as the output exponent.
    Only n_orientations - 1 orientations are stored; `tilo` reuses channel 1 for the last one.
    """

    def __init__(self, **kwargs):
        super(BiasedAbs, self).__init__(**kwargs)

        per_channel_shape = (1, n_sizes, n_orientations - 1, 1, 1, 1)
        self.bias_weights = self.add_weight(shape=per_channel_shape,
                                            initializer=tf.keras.initializers.Constant(0.),
                                            name='bias_weights',
                                            trainable=True)
        self.exponent = self.add_weight(shape=per_channel_shape,
                                        initializer=tf.keras.initializers.Constant(2.2),
                                        name='exponent',
                                        trainable=True)

    def print_weights(self):
        """Print the squashed mix factors and plot the effective (softplus) exponents."""
        mix = tf.nn.sigmoid(tilo(self.bias_weights))
        print(mix[0, :, :, 0, 0, 0])
        print("Exponents:")
        plt.imshow(tilop(self.exponent)[0, :, :, 0, 0, 0])
        plt.colorbar()
        plt.show()

    def call(self, inputs):
        """
        @param inputs: complex tensor broadcastable against the 6D weights
                       (presumably <b, s, o, h, w, d> -- TODO confirm)
        @return: (sqrt(2) * sqrt(eps + (1 - m) re^2 + m im^2)) ** softplus(exponent)
        """
        # m = 0 -> sqrt(2)|re|,  m = 0.5 -> |inputs|,  m = 1 -> sqrt(2)|im|
        mix = tf.nn.sigmoid(tilo(self.bias_weights))
        real_sq = tf.math.real(inputs)**2
        imag_sq = tf.math.imag(inputs)**2
        magnitude = tf.sqrt(eps + (1. - mix) * real_sq + mix * imag_sq) * tf.sqrt(2.)
        return magnitude ** tilop(self.exponent)

class NormalizationPool(tf.keras.layers.Layer):
    # This layer computes the unscaled normalization pool for each neuron.
    # It effectively performs a blur over space, frequency, and orientation, using
    # Gaussian blur via 3D FFT over (scale, y, x) (instead of a convolutional layer) followed by
    # matrix multiplications over the orientation and phase axes (arbitrary kernels).
    #
    # The input is a 7D tensor of reals <b, s, o, p, y, x, d> (batch, scale_index, orientation_index, phase_index, vertical coordinate, horizontal coordinate, d -- presumably the sample-distance axis)
    # The output is a 7D tensor of reals <b, s, o, p, y, x, d> with the same layout.
    #
    # NOTE: call() only processes batch element 0, i.e. it relies on batch_size == 1.
    #
    # For more information, see Sawada & Petrov (2017)

    def __init__(self, **kwargs):
        super(NormalizationPool, self).__init__(**kwargs)

        self.sigmas = get_sigmas()
        self.wavelengths = self.get_wavelengths()
        self.r2grid, self.sgrid = self.get_distgrids()

        # Raw (pre-softplus) multipliers of the spatial and scale pooling widths.
        self.spatial_pool_size_factor = self.add_weight(shape=(),
                                                        initializer=tf.keras.initializers.Constant(1.),
                                                        name='spatial_pool_size_factor',
                                                        trainable=True)
        self.scale_pool_size_factor = self.add_weight(shape=(),
                                                      initializer=tf.keras.initializers.Constant(4.),
                                                      name='scale_pool_size_factor',
                                                      trainable=True)

        # Inputs are multiplied by exp(sigma * rescale_factor) per scale; frozen at 0 (no-op).
        self.rescale_factor = self.add_weight(shape=(),
                                       initializer=tf.keras.initializers.Constant(-0.0),
                                       name='npool_rescale_factor',
                                       trainable=False)
        #self.phase_congruency_coefficient = self.add_weight(shape=(),
        #                                                      initializer=tf.keras.initializers.Constant(2.),
        #                                                      name="phase_congruency_coefficient",
        #                                                      trainable=True)

        # Pairwise orientation differences (radians) for 4 orientations 45 degrees apart.
        # If this turns out to be symmetric, replace it with a von-Mises distribution
        # factor = exp(1.22 * cos(angle-diff)**2)  -- 1.22 coefficient taken from Sawada
        self.orientation_diff_matrix = np.array([[0, 45, 90, -45], [-45, 0, 45, 90], [90, -45, 0, 45], [45, 90, -45, 0]]).astype(np.float32) * np.pi / 180
        self.orientation_inhibition_coefficient = self.add_weight(shape=(),
                                                             initializer=tf.keras.initializers.Constant(1.22),
                                                             name='orientation_inhibition_coefficient',
                                                             trainable=True)

        # Pairwise phase differences (radians) for the 4 rectified phase channels.
        # factor = exp(1.22 * sin(0.5 angle-diff)**2)  -- 1.22 coefficient taken from Sawada
        self.phase_diff_matrix = np.array([[0, 90, 180, -90], [-90, 0, 90, 180], [180, -90, 0, 90], [90, 180, -90, 0]]).astype(np.float32) * np.pi / 180
        self.phase_inhibition_coefficient = self.add_weight(shape=(),
                                                             initializer=tf.keras.initializers.Constant(1.0),
                                                             name='phase_inhibition_coefficient',
                                                             trainable=True)
        
        # Instead of an orientation inhibition matrix, we use a combined spatial/orientation filter.
        # Inhibition should be highest for orientations either in the opposition direction, or in the same direction but
        # aligned parallel.

        


    def print_weights(self):
        """Print the effective (softplus) pool sizes and plot the inhibition matrices."""
        print("Spatial pool size factor", tf.nn.softplus(self.spatial_pool_size_factor.numpy()))
        print("Scale pool size factor", tf.nn.softplus(self.scale_pool_size_factor.numpy()))
        print("Rescale factor", self.rescale_factor)

        orientation_inhibition_matrix = tf.exp(self.orientation_inhibition_coefficient * tf.cos(self.orientation_diff_matrix)**2)
        phase_inhibition_matrix = tf.exp(self.phase_inhibition_coefficient * tf.sin(0.5 * self.phase_diff_matrix)**2)
        print("Orientation inhibition matrix")
        plt.imshow(orientation_inhibition_matrix.numpy())
        plt.colorbar()
        plt.show()
        print("Phase inhibition matrix")
        plt.imshow(phase_inhibition_matrix.numpy())
        plt.colorbar()
        plt.show()

    def get_wavelengths(self):
        """
        Return the filter sigmas edge-padded to 2*n_sizes entries (first/last value repeated),
        shaped <1, 1, 1, 1, 2*n_sizes, 1, 1>. Despite the name these are the sigmas from
        get_sigmas, not wavelengths.
        """
        sigmas = self.sigmas
        # We are padding to twice n_sizes
        padded_sigmas = [sigmas[0]] * int(np.ceil(n_sizes / 2)) + list(sigmas) + [sigmas[-1]] * int(n_sizes / 2)
        # Used for the distgrids, which are sorted <b, d, o, p, s, y, x> (because FFT works on innermost axes)
        return np.array(padded_sigmas).astype(np.float32)[None, None, None, None, :, None, None]
    
    def get_distgrids(self):
        # Computes the squared distance, spatially and in terms of scale index, between two points
        # in <b, d, o, p, s, y, x> space, on a grid that is spatially twice the size (will be zero-padded)
        # in x and y, and twice the size in terms of scale as well.
        # Returns (r2grid <1,1,1,1,1,2h,2w>, sgrid <1,1,1,1,2*n_sizes,1,1>).
        y, x = np.mgrid[-box_height:box_height,
                        -box_width:box_width].astype(np.float32)
        r2grid = (y**2 + x**2)[None, None, None, None, None, :, :]
        sd = np.mgrid[-n_sizes:n_sizes].astype(np.float32)[None, None, None, None, :, None, None]
        sgrid = sd**2 
            # TODO: this may be incorrect; larger wavelengths may be less susceptible to
            # the neighbouring octaves (the absolute difference may count more than the logarithmic)
        return r2grid, sgrid

    def call(self, inputs):
        # Create the 4d Gaussian blur filter.
        # The total distance is computed based on r2grid *and* sgrid.
        # Sawada assumes that we can take the product of both filters.
        psppsf = tf.nn.softplus(self.spatial_pool_size_factor)
        s_spatial_filter = (tf.exp(-self.r2grid/(eps + psppsf*self.wavelengths**2)) / 
                            (tf.sqrt(2*np.pi)*(psppsf*self.wavelengths)**2 + eps))
                            # SPATIAL FILTER: self.wavelengths need to be twice as big

        pscpsf = tf.nn.softplus(self.scale_pool_size_factor)
        # NOTE(review): unlike the spatial filter, eps is added to the whole term here rather
        # than to the denominator -- confirm this is intended.
        s_scale_filter = (tf.exp(-self.sgrid/(eps + pscpsf**2)) / (tf.sqrt(2*np.pi)*pscpsf**2) + eps)
                            
        # <1, 1, 1, 1, 2*n_sizes, 2*box_height, 2*box_width>
        s_filters = s_spatial_filter * s_scale_filter

        # The same filters, in frequency space (fftshift over all axes moves the centre to index 0)
        f_filters = tf.signal.fft3d(tf.signal.fftshift(tf.complex(s_filters, tf.zeros_like(s_filters))))

        # Reshape inputs, so that <s, y, x> are the innermost dimensions, because fft3d works
        # on the innermost dims

        sigma_scale_factors = tf.exp(self.sigmas * self.rescale_factor)[None, :, None, None, None, None, None]
        rescaled_inputs = inputs * sigma_scale_factors

        r_s_inputs = tf.einsum("bsopyxd->bdopsyx", rescaled_inputs)

        # Pad the inputs (except for the phases, which we don't mind if the DFT wraps them)
        #pr_s_inputs1 = tf.pad(r_s_inputs,
        #                    [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
        #                    [int(np.ceil(box_height / 2)), int(box_height / 2)],
        #                    [int(np.ceil(box_width / 2)), int(box_width / 2)]], mode='CONSTANT')
        #pr_s_inputs = tf.pad(pr_s_inputs1,
        #                    [[0, 0], [0, 0], [0, 0], [0, 0],
        #                    [int(np.ceil(n_sizes / 2)), int(n_sizes / 2)],
        #                    [0, 0], [0, 0]], mode='CONSTANT') # we would use edge, but tf only has symmetric

        # Can use the above again as soon as TF2.1 comes out and supports padding above 6 dimensions ... jeez.
        # For now, just use the first entry on dimension 0, relying on batch_size = 1

        # Zero-pad y and x to 2*box_height, 2*box_width ...
        pr_s_inputs1 = tf.pad(r_s_inputs[0, :, :, :, :, :, :],
                           [[0, 0], [0, 0], [0, 0], [0, 0], 
                            [int(np.ceil(box_height / 2)), int(box_height / 2)],
                            [int(np.ceil(box_width / 2)), int(box_width / 2)]], mode='CONSTANT') #[None, :, :, :, :, :, :]
        # ... and the scale axis to 2*n_sizes, then restore the batch axis.
        pr_s_inputs = tf.pad(pr_s_inputs1, #[0, :, :, :, :, :, :],
                            [[0, 0], [0, 0], [0, 0],
                            [int(np.ceil(n_sizes / 2)), int(n_sizes / 2)],
                            [0, 0], [0, 0]], mode='CONSTANT')[None, :, :, :, :, :, :]

        # Convert inputs to frequency domain
        pr_f_inputs = tf.signal.fft3d(tf.complex(pr_s_inputs, tf.zeros_like(pr_s_inputs)))

        # Perform the filtering and convert back to space domain
        pr_s_filtered = tf.math.real(tf.signal.ifft3d(pr_f_inputs * f_filters))

        # Crop away the padding
        r_s_filtered = pr_s_filtered[:, :, :, :,
                                     int(np.ceil(n_sizes / 2)):int(n_sizes + np.ceil(n_sizes / 2)),
                                     int(np.ceil(box_height / 2)):int(box_height + np.ceil(box_height / 2)),
                                     int(np.ceil(box_width / 2)):int(box_width + np.ceil(box_width / 2))]
        
        # Perform cross-orientation blurring

        # factor = exp(1.22 * cos(angle-diff)**2)  -- 1.22 coefficient taken from Sawada
        orientation_inhibition_matrix = tf.exp(self.orientation_inhibition_coefficient * tf.cos(self.orientation_diff_matrix)**2)
        phase_inhibition_matrix = tf.exp(self.phase_inhibition_coefficient * tf.sin(0.5 * self.phase_diff_matrix)**2)

        # Perform cross-orientation blurring (mix the o axis), then cross-phase blurring (mix the p axis)
        r_s_ob_filtered = tf.einsum("bdkpsyx,kq->bdqpsyx", r_s_filtered,
                                    (orientation_inhibition_matrix))
        r_s_obpb_filtered = tf.einsum("bdoksyx,kq->bdoqsyx", r_s_ob_filtered,
                                    (phase_inhibition_matrix))

        # Reorder dimensions
        s_obpb_filtered = tf.einsum("bdopsyx->bsopyxd", r_s_obpb_filtered) # r_s_obpb
         
        return s_obpb_filtered

class ApplyCsf(tf.keras.layers.Layer):
    """Optionally scales each filter size by a contrast-sensitivity-style gain curve.

    The curve is ``b * (s * a) ** c * exp(-d * s) * sigmas ** e`` with
    ``s = box_width / sigmas``, normalized so that its maximum is 1.0. ``sigmas`` comes
    from ``get_sigmas()``; ``s`` plays the role of a frequency, whereas ``sigmas`` are
    more like wavelengths.

    While ``self.active`` is False (its value after ``__init__``) ``call`` returns its
    input unchanged and the curve weights are created as non-trainable.
    """

    def __init__(self, **kwargs):
        super(ApplyCsf, self).__init__(**kwargs)
        # Master switch: read once here to set `trainable`, and on every call().
        self.active = False

        self.a = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(0.4),
                                 name='a',
                                 trainable=self.active)
        self.b = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(1.0),
                                 name='b',
                                 trainable=self.active)
        self.c = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(2.0),
                                 name='c',
                                 trainable=self.active)
        self.d = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(0.12),  # ~0.1
                                 name='d',
                                 trainable=self.active)
        # `e` is always frozen (not tied to self.active); at 0.0 the sigmas ** e term is 1.
        self.e = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(0.0),  # ~0.1
                                 name='e',
                                 trainable=False)
        # Sigmas placed on axis 1 (the size axis) so they broadcast over the other axes;
        # assumes get_sigmas() returns a 1-D array with one entry per size -- TODO confirm.
        self.sigmas = get_sigmas()[None, :, None, None, None, None]

    def get_factors(self):
        """Compute the normalized per-size gain curve.

        Returns:
            ``(factors, (a, b, c, d, e, s))``: ``factors`` broadcasts like ``self.sigmas``
            and peaks at 1.0; the tuple holds the raw weights and the frequency-like
            ``s = box_width / sigmas``.
        """
        a, b, c, d, e = self.a, self.b, self.c, self.d, self.e
        # s is a frequency, whereas sigmas are more like wavelengths.
        s = box_width / self.sigmas

        factors = b * (s * a) ** c * tf.exp(-d * s) * (self.sigmas ** e)
        factors /= tf.reduce_max(factors)  # Scale so that the max is at 1.0
        return (factors, (a, b, c, d, e, s))

    def print_weights(self):
        """Print the factors, the frequency-like values and the raw weights; plot factors vs. sigma."""
        factors, (a, b, c, d, e, s) = self.get_factors()
        print("Factors are:", factors[0, :, 0, 0, 0, 0])
        # This is box_width / sigmas (the frequency-like `s`), not the sigmas themselves.
        print("box_width / sigmas are", s[0, :, 0, 0, 0, 0])
        print("A, B, C, D, E:", a, b, c, d, e)

        plt.plot(self.sigmas[0, :, 0, 0, 0, 0], factors[0, :, 0, 0, 0, 0])
        plt.show()

    def call(self, inputs):
        """Return ``inputs`` scaled by the per-size factors, or unchanged when inactive."""
        if not self.active:
            # Don't build the factor ops at all: they would not influence the output.
            return inputs
        factors, _ = self.get_factors()
        return inputs * factors

class Exponentiate(tf.keras.layers.Layer):
    """Applies a learned power law to phase-rectified filter responses.

    ``call([rectified_inputs, absvals])`` returns
    ``(rectified_inputs + 1e-6) * (absvals + 1e-6) ** p / (1e-6 + absvals)`` with
    ``p = tilop(self.exponents)``; i.e. each rectified response is multiplied by
    ``(absvals + 1e-6) ** (p - 1)`` computed from the unrectified magnitudes.
    """

    def __init__(self, **kwargs):
        super(Exponentiate, self).__init__(**kwargs)

        # 7-D so it broadcasts against <b, s, o, p, h, w, d> tensors: one stored value per
        # entry along the orientation axis, shared across sizes. Initialized via invsp(2.5),
        # presumably so that tilop(...) starts at 2.5 -- TODO confirm against tilop.
        self.exponents = self.add_weight(shape=(1, 1, n_orientations - 1, 1, 1, 1, 1),
                                         initializer=tf.keras.initializers.Constant(invsp(2.5)),
                                         name='exponents',
                                         trainable=True)

    def print_weights(self):
        """Display the effective exponents (after ``tilop``) as an image with a colorbar."""
        # The weight is 7-D, so index all seven axes to hand imshow a 2-D
        # (size x orientation) array; indexing only six left a trailing singleton axis.
        # Assumes tilop preserves rank -- TODO confirm.
        plt.imshow(tilop(self.exponents)[0, :, :, 0, 0, 0, 0])
        plt.colorbar()
        plt.show()

    def call(self, inputs):
        rectified_inputs, absvals = inputs
        # Insert a phase axis so the <b, s, o, h, w, d> magnitudes broadcast against
        # the <b, s, o, p, h, w, d> rectified responses.
        pabsvals = absvals[:, :, :, None, :, :, :]
        factor = (pabsvals + 1.e-6) ** tilop(self.exponents) / (1.e-6 + pabsvals)
        return (rectified_inputs + 1.e-6) * factor

class DivisiveNormalization(tf.keras.layers.Layer):
    """Divisive normalization: ``factors * stimulus / (eps + beta + normalization_pool)``.

    ``factors`` (initialized to 1, frozen) and ``beta`` (initialized to 2.46, trainable)
    are both passed through ``ptilop`` before use. ``call`` takes
    ``[stimulus, normalization_pool]``.
    """

    def __init__(self,  **kwargs):
        super(DivisiveNormalization, self).__init__(**kwargs)

        # Per-size, per-orientation output gain (7-D to broadcast against
        # <b, s, o, p, h, w, d> tensors).
        self.factors = self.add_weight(shape=(1, n_sizes, n_orientations - 1, 1, 1, 1, 1),
                                       initializer=tf.keras.initializers.Constant(1.),
                                       name='dn_factors',
                                       trainable=False)
        # Semi-saturation constant, one per orientation entry, shared across sizes.
        self.beta = self.add_weight(shape=(1, 1, n_orientations - 1, 1, 1, 1, 1), # n_sizes
                                    initializer=tf.keras.initializers.Constant(2.46),
                                    name='dn_beta',
                                    trainable=True)

    def print_weights(self):
        """Print the effective beta values along the orientation axis."""
        # beta is 7-D: index every singleton axis so a flat vector is printed rather than
        # a column array. Assumes ptilop preserves rank -- TODO confirm.
        print("Beta:", ptilop(self.beta)[0, 0, :, 0, 0, 0, 0])

    def call(self, inputs):
        stimulus, normalization_pool = inputs
        return ptilop(self.factors) * stimulus / (eps + ptilop(self.beta) + normalization_pool)

class PenalizeZero(tf.keras.layers.Layer):
    """Converts pair-vs-sum energy differences into signed per-pixel penalties.

    ``call([original_sums, diffs])`` splits ``diffs`` into positive parts ("gap gains")
    and negative parts ("edge losses"). Each part is raised to a learned power; entries
    with orientation index >= 1 (axis 2, called "vertical" below) get an extra learned
    orientation factor; and each part is scaled by a per-size curve
    ``b * (0.4 * s) ** c * exp(-d * s)`` with ``s = box_width / sigmas``. Edge losses
    are further scaled by ``(original_sums + eps) ** softplus(elosf)``. The output
    ``gap_gains - edge_losses`` is divided by the total absolute gain + loss at index 1
    of the last (sample-distance) axis.

    Scalars initialized with ``invsp(x)`` are read through ``tf.nn.softplus`` (presumably
    so they start at ``x`` and stay positive); the per-size/orientation exponents are read
    through ``tilop``.
    """

    def __init__(self,  **kwargs):
        super(PenalizeZero, self).__init__(**kwargs)

        # Gap-gain curve parameters (softplus-parameterized) and exponents.
        self.gb = self.add_weight(shape=(),
                                  initializer=tf.keras.initializers.Constant(invsp(1.0)),
                                  name='gb', trainable=True)
        self.gc = self.add_weight(shape=(),
                                  initializer=tf.keras.initializers.Constant(invsp(2.0)),
                                  name='gc', trainable=True)
        self.gd = self.add_weight(shape=(),
                                  initializer=tf.keras.initializers.Constant(invsp(0.12)),
                                  name='gd', trainable=True)
        self.ge = self.add_weight(shape=(1, n_sizes, n_orientations - 1, 1, 1, 1),
                                  initializer=tf.keras.initializers.Constant(invsp(1.0)),
                                  name='ge', trainable=True)

        self.gain_orientation_factors= self.add_weight(shape=(1, 1, n_orientations - 1, 1, 1, 1),
                                                      initializer=tf.keras.initializers.Constant(-0.1),
                                                      name='gap_orientation_factors', trainable=True)

        # Edge-loss curve parameters (softplus-parameterized) and exponents.
        self.lb = self.add_weight(shape=(),
                                  initializer=tf.keras.initializers.Constant(invsp(1.0)),
                                  name='lb', trainable=True)
        self.lc = self.add_weight(shape=(),
                                  initializer=tf.keras.initializers.Constant(invsp(2.0)),
                                  name='lc', trainable=True)
        self.ld = self.add_weight(shape=(),
                                  initializer=tf.keras.initializers.Constant(invsp(0.12)),
                                  name='ld', trainable=True)
        self.le = self.add_weight(shape=(1, n_sizes, n_orientations - 1, 1, 1, 1),
                                  initializer=tf.keras.initializers.Constant(invsp(1.0)),
                                  name='le', trainable=True)

        # Exponent on original_sums that weights edge losses; read via _elosf().
        self.elosf = self.add_weight(shape=(),
                                  initializer=tf.keras.initializers.Constant(invsp(1.0)),
                                  name='elosf', trainable=True)

        self.loss_orientation_factors= self.add_weight(shape=(1, 1, n_orientations - 1, 1, 1, 1),
                                                      initializer=tf.keras.initializers.Constant(1.0),
                                                      name='loss_orientation_factors', trainable=True)

        self.sigmas = get_sigmas()[None, :, None, None, None, None]

    def _elosf(self):
        """Effective (positive) exponent applied to ``original_sums`` for edge losses.

        ``elosf`` is initialized with ``invsp(1.0)`` exactly like the other
        softplus-parameterized scalars, so it is read through ``tf.nn.softplus`` too;
        using it raw gave an unconstrained exponent starting near 0.54 instead of 1.0.
        """
        return tf.nn.softplus(self.elosf)

    def get_factors(self):
        """Return ``(gfactors, lfactors)``: per-size curves for gap gains and edge losses."""
        s = box_width / self.sigmas
        gb, gc, gd = tf.nn.softplus(self.gb), tf.nn.softplus(self.gc), tf.nn.softplus(self.gd)
        lb, lc, ld = tf.nn.softplus(self.lb), tf.nn.softplus(self.lc), tf.nn.softplus(self.ld)

        gfactors = gb * (.4 * s) ** gc * tf.exp(-gd * s)
        lfactors = lb * (.4 * s) ** lc * tf.exp(-ld * s)
        return (gfactors, lfactors)

    def print_weights(self):
        """Print exponents/orientation factors and plot both per-size curves against sigma."""
        gfactors, lfactors = self.get_factors()
        print("Gap gain curve, factor=", tilop(self.ge)[0, :, 0, 0, 0, 0])
        print("Gap vertical orientation factors:", tilo(self.gain_orientation_factors)[0, 0, 1:, 0, 0, 0])
        plt.plot(self.sigmas[0, :, 0, 0, 0, 0], gfactors[0, :, 0, 0, 0, 0])
        plt.show()
        print("Edge loss curve, factor=", tilop(self.le)[0, :, 0, 0, 0, 0])
        print("Loss vertical orientation factors:", tilo(self.loss_orientation_factors)[0, 0, 1:, 0, 0, 0])
        plt.plot(self.sigmas[0, :, 0, 0, 0, 0], lfactors[0, :, 0, 0, 0, 0])
        plt.show()
        print("ELOSF:", self._elosf())

    def call(self, inputs):
        original_sums, diffs = inputs
        gfactors, lfactors = self.get_factors()
        gain_exponents = tilop(self.ge)
        loss_exponents = tilop(self.le)
        gain_orientation = tilo(self.gain_orientation_factors)
        loss_orientation = tilo(self.loss_orientation_factors)

        positive_diffs = tf.nn.relu(diffs)
        negative_diffs = tf.nn.relu(-diffs)

        # Orientation index 0 ("horizontal") only gets the exponent; the remaining
        # orientations are additionally multiplied by a learned orientation factor.
        horizontal_gap_gains = (positive_diffs[:, :, 0:1, :, :, :] + eps) ** gain_exponents[:, :, 0:1, :, :, :]
        vertical_gap_gains = ((positive_diffs[:, :, 1:, :, :, :] + eps) ** gain_exponents[:, :, 1:, :, :, :]
                              * gain_orientation[:, :, 1:, :, :, :])
        gap_gains = (eps + tf.concat([horizontal_gap_gains, vertical_gap_gains], axis=2)) * gfactors

        horizontal_edge_losses = (negative_diffs[:, :, 0:1, :, :, :] + eps) ** loss_exponents[:, :, 0:1, :, :, :]
        vertical_edge_losses = ((negative_diffs[:, :, 1:, :, :, :] + eps) ** loss_exponents[:, :, 1:, :, :, :]
                                * loss_orientation[:, :, 1:, :, :, :])
        edge_losses = ((eps + tf.concat([horizontal_edge_losses, vertical_edge_losses], axis=2))
                       * lfactors * ((original_sums + eps) ** self._elosf()))

        # Normalize by the total magnitude at sample-distance index 1, to try to discourage
        # the CSF from just shrinking everything.
        normalizer = eps + tf.reduce_sum(
            tf.abs(gap_gains[:, :, :, :, :, 1:2]) + tf.abs(edge_losses[:, :, :, :, :, 1:2]),
            axis=[1, 2, 3, 4], keepdims=True)
        return (gap_gains - edge_losses) / normalizer


class PenalizeZeroOld(tf.keras.layers.Layer):
    """Earlier penalty layer (``get_model`` in this notebook instantiates ``PenalizeZero``).

    In ``call``, only orientation index 0 contributes. Gap gains ``relu(diffs)`` are scaled
    by ``orientation_gain`` and raised to ``scale_exponent``. Edge losses ``relu(-diffs)``
    are scaled by ``orientation_loss`` and raised to ``scale_beta``. Other orientations are
    multiplied by 0, so only a constant ``eps ** exponent`` floor remains for them. The
    result ``(gap_gains + eps) + (edge_losses + eps)`` is divided by the total absolute
    edge loss at index 1 of the last (sample-distance) axis.
    """

    def __init__(self,  **kwargs):
        super(PenalizeZeroOld, self).__init__(**kwargs)

        self.scale_exponent = self.add_weight(shape=(1, n_sizes, n_orientations - 1, 1, 1, 1),
                                              initializer=tf.keras.initializers.Constant(0.55),
                                              name='scale_exponent', trainable=True)
        self.scale_beta = self.add_weight(shape=(1, n_sizes, n_orientations - 1, 1, 1, 1),
                                          initializer=tf.keras.initializers.Constant(0.55),
                                          name='scale_beta', trainable=True)
        # The three frozen scalars below are returned by getw() but not used in call().
        self.edge_loss_relative_factor = self.add_weight(shape=(),
                                          initializer=tf.keras.initializers.Constant(0.62),
                                          name='edge_loss_relative_factor', trainable=False)
        self.scale_mean = self.add_weight(shape=(),
                                          initializer=tf.keras.initializers.Constant(6),
                                          name='scale_mean', trainable=False)
        self.scale_sigma = self.add_weight(shape=(),
                                          initializer=tf.keras.initializers.Constant(2),
                                          name='scale_sigma', trainable=False)
        self.orientation_gain = self.add_weight(shape=(1, n_sizes, n_orientations - 1, 1, 1, 1),
                                          initializer=tf.keras.initializers.Constant(-0.62),
                                          name='orientation_gain', trainable=True)
        self.orientation_loss = self.add_weight(shape=(1, n_sizes, n_orientations - 1, 1, 1, 1),
                                          initializer=tf.keras.initializers.Constant(-0.62),
                                          name='orientation_loss', trainable=True)

        # Per-row weighting, currently all ones (was 18:70) and not used in call().
        yfactor_default = np.zeros((1, 1, 1, box_height, 1, 1))
        yfactor_default[0, 0, 0, :, 0, 0] = 1.
        self.yfactor = self.add_weight(shape=(1, 1, 1, box_height, 1, 1),
                                       initializer=tf.keras.initializers.Constant(yfactor_default),
                                       name='yfactor', trainable=False)

        self.sigmas = get_sigmas().astype(np.float32)[None, :, None, None, None, None]

    def print_weights(self):
        """Plot, per orientation, gain (blue) and loss (red) weights against sigma.

        Solid lines are ``orientation_gain`` / ``orientation_loss``; dotted lines are
        ``scale_exponent`` / ``scale_beta``; all are shown after ``tilop``.
        """
        se, sb, _elrs, _mu, _sigma, og, ol, _yfactor = self.getw()
        print("Gain/loss factors:")
        sigma_axis = self.sigmas[0, :, 0, 0, 0, 0]
        # One panel per orientation (previously hard-coded to 4 panels); squeeze=False
        # keeps `ax` 2-D even when n_orientations == 1.
        fig, ax = plt.subplots(1, n_orientations, squeeze=False)
        for oi in range(n_orientations):
            panel = ax[0, oi]
            panel.plot(sigma_axis, og[0, :, oi, 0, 0, 0], color='b')
            panel.plot(sigma_axis, ol[0, :, oi, 0, 0, 0], color='r')
            panel.plot(sigma_axis, se[0, :, oi, 0, 0, 0], color='b', linestyle='dotted')
            panel.plot(sigma_axis, sb[0, :, oi, 0, 0, 0], color='r', linestyle='dotted')
        plt.show()

    def getw(self):
        """Return ``(se, sb, elrs, mu, sigma, og, ol, yfactor)``; per-size/orientation weights go through ``tilop``."""
        return (
            tilop(self.scale_exponent),
            tilop(self.scale_beta),
            (self.edge_loss_relative_factor),
            self.scale_mean,
            self.scale_sigma,
            tilop(self.orientation_gain),
            tilop(self.orientation_loss),
            self.yfactor,
        )

    def call(self, inputs):
        # diffs / original_sums are laid out as <b, s, o, y, x, d>.
        original_sums, diffs = inputs

        se, sb, elrs, mu, sigma, og, ol, yfactor = self.getw()

        # Non-horizontal orientations are currently zeroed out (was: og[:, :, 1:]).
        horizontal_gap_gains = (tf.nn.relu(diffs[:, :, 0:1, :, :, :]) + eps) * og[:, :, 0:1, :, :, :]
        vertical_gap_gains = (tf.nn.relu(diffs[:, :, 1:, :, :, :]) + eps) * 0.
        gap_gains = (eps + tf.concat([horizontal_gap_gains, vertical_gap_gains], axis=2)) ** se

        # Same for edge losses (was: ol[:, :, 1:]).
        horizontal_edge_losses = (tf.nn.relu(-diffs[:, :, 0:1, :, :, :]) + eps) * (ol[:, :, 0:1, :, :, :])
        vertical_edge_losses = (tf.nn.relu(-diffs[:, :, 1:, :, :, :]) + eps) * 0.
        edge_losses = (eps + tf.concat([horizontal_edge_losses, vertical_edge_losses], axis=2)) ** sb

        # Small gaps are actually good, because they separate the two letters.
        # It's not clear how exactly they accomplish that.

        # Normalizing by the edge loss at sample-distance index 1 tries to discourage the
        # CSF from just shrinking everything.
        penalties = (
          ((gap_gains + eps) + (edge_losses + eps)) / (tf.reduce_sum(tf.abs(edge_losses[:, :, :, :, :, 1:2]), axis=[1,2,3,4], keepdims=True))
        )
        return penalties

class DistanceEstimator(tf.keras.layers.Layer):
    """Soft arg-max of ``y`` along axis 1, returning the matching weighted ``x``.

    ``y`` is divided by its per-row range (max - min along axis 1, plus ``eps``) and
    multiplied by 1e3 before ``nd_softmax`` over axis 1. Assuming ``nd_softmax`` is a
    softmax over the given axes (it is defined elsewhere), the weights concentrate on
    the largest ``y``. The result is the weighted sum of ``x`` over axis 1.
    """

    def __init__(self, **kwargs):
        super(DistanceEstimator, self).__init__(**kwargs)

    def call(self, inputs):
        y, x = inputs
        # Normalizing by the range makes the softmax sharpness independent of y's scale.
        yrange = (tf.reduce_max(y, axis=[1], keepdims=True)
                  - tf.reduce_min(y, axis=[1], keepdims=True)) + eps
        estimate_validities = nd_softmax(1e3 * y / yrange, axis=[1])
        estimated_distances = tf.reduce_sum(estimate_validities * x, axis=[1], name='estimated_distances')
        return estimated_distances

def get_model():
    """Build the Keras model mapping two glyphs' filter responses to per-pixel penalties.

    Inputs, in order:
    - two complex64 filter-response tensors of shape ``full_shape``;
    - ``sample_distances``, ``pair_images`` and ``zero_indices``. These are part of the
      input signature (presumably so the generator's 5-tuple can be fed unchanged) but
      are not used by the computation.

    Output: the ``total_pixel_penalties`` tensor produced by ``PenalizeZero``.

    Intermediate tensors are wrapped in ``tf.identity`` with explicit names so the
    monitoring callback can look them up by name suffix; those names must stay stable.
    """
    shifted_gd1_filtered_images = tf.keras.Input(shape=full_shape, name='shifted_gd1_filtered_images', dtype=tf.complex64)
    shifted_gd2_filtered_images = tf.keras.Input(shape=full_shape, name='shifted_gd2_filtered_images', dtype=tf.complex64)

    # --- Phase-rectified / normalized branch. Computed for experimentation but not
    # --- connected to the model output below.
    # Go from <b, s, o, h, w, d> to <b, s, o, p, h, w, d>
    pr_gd1 = rectify_phases(shifted_gd1_filtered_images)
    pr_gd2 = rectify_phases(shifted_gd2_filtered_images)
    pr_pair = rectify_phases(shifted_gd1_filtered_images + shifted_gd2_filtered_images)

    # Then, exponentiate with a power.
    exponentiate = Exponentiate()
    pr_p_gd1 = tf.identity(exponentiate([pr_gd1, tf.abs(shifted_gd1_filtered_images)]), name="pr_p_gd1")
    pr_p_gd2 = tf.identity(exponentiate([pr_gd2, tf.abs(shifted_gd2_filtered_images)]), name="pr_p_gd2")
    pr_p_pair = tf.identity(exponentiate([pr_pair, tf.abs(shifted_gd1_filtered_images + shifted_gd2_filtered_images)]), name="pr_p_pair")

    # Then, calculate the normalization pools. (Named norm_pool, not `np`, so numpy
    # is not shadowed inside this function.)
    norm_pool = NormalizationPool()
    pr_np_gd1 = tf.identity(norm_pool(pr_p_gd1), name="pr_np_gd1")
    pr_np_gd2 = tf.identity(norm_pool(pr_p_gd2), name="pr_np_gd2")
    pr_np_pair = tf.identity(norm_pool(pr_p_pair), name="pr_np_pair")

    # Then, perform the divisive normalization
    dn = DivisiveNormalization()
    pr_dn_gd1 = tf.identity(dn([pr_p_gd1, pr_np_gd1]), name="pr_dn_gd1")
    pr_dn_gd2 = tf.identity(dn([pr_p_gd2, pr_np_gd2]), name="pr_dn_gd2")
    pr_dn_pair = tf.identity(dn([pr_p_pair, pr_np_pair]), name="pr_dn_pair")

    # --- Active branch: energies of the raw filter responses.
    ba = BiasedAbs()
    apply_csf = ApplyCsf()
    penalize = PenalizeZero()

    # Alternative (disabled): feed pr_dn_* into `ba` instead of the raw responses.
    eba_gd1 = ba(shifted_gd1_filtered_images)
    eba_gd2 = ba(shifted_gd2_filtered_images)
    eba_pair = ba(shifted_gd1_filtered_images + shifted_gd2_filtered_images)

    e_dn_gd1 = tf.identity(apply_csf(eba_gd1), "e_dn_gd1")
    e_dn_gd2 = tf.identity(apply_csf(eba_gd2), "e_dn_gd2")
    e_dn_pair = tf.identity(apply_csf(eba_pair), "e_dn_pair")

    # Then, compute the differences between pair and (gd1 + gd2)
    original_sums = tf.identity(e_dn_gd1 + e_dn_gd2, name="original_sums")
    diffs = tf.identity(e_dn_pair - original_sums, name="diffs")
    penalties = tf.identity(penalize([original_sums, diffs]), name="total_pixel_penalties")

    # Not used for anything, but part of the model's input signature.
    # `shape` must be a tuple: `(n_sample_distances)` is just an int.
    sample_distances = tf.keras.Input(shape=(n_sample_distances,), name='sample_distances')
    pair_images = tf.keras.Input(shape=(box_height, box_width, n_sample_distances), name='pair_images')
    zero_indices = tf.keras.Input(shape=(), name='zero_indices')

    return tf.keras.Model(inputs=[shifted_gd1_filtered_images,
                                  shifted_gd2_filtered_images,
                                  sample_distances, pair_images, zero_indices],
                          outputs=(penalties))

@tf.function
def compute_loss(_target, penalties, mode="zero_crossing"):
    """Loss on the per-sample-distance total penalty curve (the Keras target is ignored).

    ``penalties`` is summed over axes 1-4 (size, orientation, y, x for the
    <b, s, o, y, x, d> layout), giving ``d`` of shape (batch, n_sample_distances).
    Indices 0, 1 and 2 of the last axis are read, so at least three sample distances
    are required.

    Args:
        _target: unused; present to match Keras' ``loss(y_true, y_pred)`` signature.
        penalties: model output.
        mode: which loss to compute (previously selected by ``if False/elif True``):
            - ``"zero_crossing"`` (default; the previously active branch): linearly
              interpolate the zero crossing of ``d`` on the two segments between the
              samples at x = -2, 0, +2, weight each by a soft "relevance", and penalize
              the squared result, pushing the crossing toward the center sample.
            - ``"monotonic"``: older model penalizing ``d`` at the center sample plus
              monotonicity / sign constraints.
            - ``"center_minimum"``: penalize unless the center sample is the minimum.

    Raises:
        ValueError: if ``mode`` is not one of the above.
    """
    d = tf.reduce_sum(penalties, axis=[1, 2, 3, 4]) + eps

    if mode == "zero_crossing":
        # Sample distances are assumed equally spaced, 2 units apart, at x = -2, 0, +2.
        constrained_slopes = (d[:, 1:] - d[:, 0:-1]) / 2
        # Zero of each segment's line: x0 - d(x0) / slope, with x0 in {-2, 0}.
        # NOTE(review): a zero slope yields inf/nan here.
        segment_starts = np.array([-2., 0.], dtype=np.float32)
        predicted_zeros = segment_starts - d[:, 0:-1] / constrained_slopes  # (batch, 2)

        first_crosses_up = tf.nn.sigmoid(1000 * d[:, 1]) * tf.nn.sigmoid(-1000 * d[:, 0])
        second_crosses_up = tf.nn.sigmoid(1000 * d[:, 2]) * tf.nn.sigmoid(-1000 * d[:, 1])
        first_above_zero = tf.nn.sigmoid(1000 * d[:, 0])
        last_below_zero = tf.nn.sigmoid(-1000 * d[:, 2])
        relevance_first = tf.nn.tanh(1000 * (first_crosses_up + first_above_zero))  # (batch,)
        relevance_last = tf.nn.tanh(1000 * (second_crosses_up + last_below_zero))  # (batch,)

        # Select the segment BEFORE weighting. Multiplying the (batch,) relevances with the
        # (batch, 2) predicted_zeros broadcasts along the segment axis instead of the batch
        # axis: a shape error unless batch is 1 or 2, and wrong per-example weights at 2.
        predicted_zero = (relevance_first * predicted_zeros[:, 0]
                          + relevance_last * predicted_zeros[:, 1])
        return tf.reduce_sum(predicted_zero ** 2)

    if mode == "monotonic":
        # Old zero-finding model, which minimizes the vertical distance from 0, plus some extra constraints for monotonicity
        total_gap_gains = tf.reduce_sum(tf.nn.relu(penalties), axis=[1, 2, 3, 4])
        total_edge_losses = tf.reduce_sum(tf.nn.relu(-penalties), axis=[1, 2, 3, 4])
        gap_gain_increase = 2 + tf.nn.elu(total_gap_gains[:, 0] - total_gap_gains[:, 1]) + tf.nn.elu(total_gap_gains[:, 1] - total_gap_gains[:, 2])
        edge_loss_decrease = 2 + tf.nn.elu(total_edge_losses[:, 1] - total_edge_losses[:, 0]) + tf.nn.elu(total_edge_losses[:, 2] - total_edge_losses[:, 1])
        first_negative = 1 + tf.nn.elu(d[:, 0])
        last_positive = 1 + tf.nn.elu(-d[:, 2])
        first_increase = 1 + tf.nn.elu(d[:, 0] - d[:, 1])
        last_increase = 1 + tf.nn.elu(d[:, 1] - d[:, 2])

        losses = tf.identity(1000 * d[:, 1]**2 + first_negative + last_positive + first_increase + last_increase + gap_gain_increase + edge_loss_decrease, name="losses") # Gap is negative
        return tf.reduce_sum(losses)

    if mode == "center_minimum":
        # Penalize based on whether or not the center one is the minimum
        first_down = tf.nn.softplus(d[:, 1] - d[:, 0])
        second_up = tf.nn.softplus(d[:, 1] - d[:, 2])
        # Both terms shrink only when the center sample is below both of its neighbours.
        return tf.reduce_sum(first_down * tf.abs(first_down) + second_up * tf.abs(second_up))

    raise ValueError("Unknown compute_loss mode: %r" % (mode,))




## Monitoring

In [0]:
from matplotlib.patches import Circle

class MonitorProgressCallback(tf.keras.callbacks.Callback):
    def __init__(self, data_generator):
        """Remember the data source and the model's Input tensors for later inspection.

        Args:
            data_generator: indexable by batch index; ``data_generator[i][0]`` is the
                list of model inputs for batch ``i`` (see ``on_test_batch_begin``).
        """
        # Initialize the Keras Callback base class (sets self.model etc.); it was skipped.
        super(MonitorProgressCallback, self).__init__()
        self.data_generator = data_generator
        # NOTE(review): reads the module-level `model` (self.model is only set later by
        # set_model()), so this callback must be created after the global model exists.
        self.model_inputs = [l.input for l in model.layers if isinstance(l, tf.keras.layers.InputLayer)]
        # Inputs of the batch currently being evaluated; set in on_test_batch_begin.
        self.current_data = None

    def get_val(self, name):
        """Evaluate the first model layer whose name ends with ``name`` on the current batch.

        The layer's output is computed from ``self.current_data`` through the model's
        Input tensors. Raises IndexError when no layer name matches.
        """
        matching_layers = [layer for layer in model.layers if layer.name.endswith(name)]
        target_layer = matching_layers[0]
        evaluate = tf.keras.backend.function(self.model_inputs, [target_layer.output])
        outputs = evaluate(self.current_data)
        return outputs[0]

    def get_weights(self, name):
        """Return the weight values of the first model layer whose name ends with ``name``."""
        candidates = [layer for layer in model.layers if layer.name.endswith(name)]
        return candidates[0].get_weights()

    def print_weights(self, name):
        """Delegate to ``print_weights()`` of the first model layer whose name ends with ``name``."""
        target_layer = [layer for layer in model.layers if layer.name.endswith(name)][0]
        target_layer.print_weights()

    def on_test_batch_begin(self, batch_index, logs=None):
        dataset = self.data_generator[batch_index]
        current_data = dataset[0]
        self.current_data = current_data
        shifted_gd1_d1_filtered_images, shifted_gd2_d1_filtered_images, sample_distances, pair_images, zero_indices = current_data

        iix = 0
        
        if False:
            plt.imshow(pair_images[iix, :, :, 0])
            plt.colorbar()
            plt.show()
            plt.imshow(pair_images[iix, :, :, zero_indices[iix]])
            plt.colorbar()
            plt.show()
            plt.imshow(pair_images[iix, :, :, 2])
            plt.colorbar()
            plt.show()

        if True:
            self.print_weights("apply_csf")
            #self.print_weights("exponentiate")
            #self.print_weights("normalization_pool")
            #self.print_weights("divisive_normalization")
            self.print_weights("biased_abs")
            self.print_weights("penalize_zero")

        if False:
            e_gd1 = self.get_val("e_gd1")
            e_pair = self.get_val("e_pair")

            i_gd1 = self.get_val("shifted_gd1_filtered_images")
            i_gd2 = self.get_val("shifted_gd2_filtered_images")
            i_total_angle = tf.math.angle(i_gd1 + i_gd2)

            dn_gd1 = self.get_val("e_dn_gd1")
            dn_gd2 = self.get_val("e_dn_gd2")
            dn_pair = self.get_val("e_dn_pair")
            diffs = self.get_val("diffs")
            for si in range(n_sizes):
                fig, ax = plt.subplots(nrows=1, ncols=5, gridspec_kw={'wspace':0, 'hspace':0}, figsize=(4 * 4 * box_width / 100, 4 * 1 * box_height / 100))
                ax[0].imshow(tf.math.angle(i_gd1)[iix, si, 0, :, :, zero_indices[iix]])
                ax[1].imshow(i_total_angle[iix, si, 0, :, :, zero_indices[iix]])
                ax[2].imshow((dn_gd1 + dn_gd2)[iix, si, 0, :, :, zero_indices[iix]])
                ax[3].imshow(dn_pair[iix, si, 0, :, :, zero_indices[iix]])
                ax[4].imshow(diffs[iix, si, 0, :, :, zero_indices[iix]])
                plt.show()

        if False:
            increases = self.get_val("increases")
            dsvps = self.get_val("dsvps")
            #print("DPSS:", dsvps)
            pcs = self.get_val("total_penalties")
            for pix in range(batch_size):
                plt.plot(np.arange(n_sample_distances) - zero_indices[pix], pcs[pix, :])
                plt.plot(0.5 + np.arange(n_sample_distances - 1) - zero_indices[pix], dsvps[pix, :])
                plt.plot(0.5 + np.arange(n_sample_distances - 1) - zero_indices[pix], increases[pix, :])
            plt.show()

        if False:
            p = self.get_val("pr_dn_pair")
            pa = tf.math.angle(tf.complex(p[:, :, :, 0, :, :, :] - p[:, :, :, 2, :, :, :], p[:, :, :, 1, :, :, :] - p[:, :, :, 3, :, :, :]))
            pv = tf.abs(tf.complex(p[:, :, :, 0, :, :, :] - p[:, :, :, 2, :, :, :], p[:, :, :, 1, :, :, :] - p[:, :, :, 3, :, :, :]))
            for si in range(n_sizes):
                plt.imshow((pa)[iix, si, 0, :, :, zero_indices[iix]]) #, vmin=-lvmax, vmax=lvmax)
                plt.colorbar()
                plt.show()
                plt.imshow((pv)[iix, si, 0, :, :, zero_indices[iix]]) #, vmin=-lvmax, vmax=lvmax)
                plt.colorbar()
                plt.show()

        if False:
            if False:
                va = self.get_val("pr_np_gd1")
                size_factor = 5
                fig, ax = plt.subplots(nrows=n_sizes, ncols=4,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
                for si in range(n_sizes):
                    for pi in range(4):
                        ax[si, pi].imshow(va[iix, si, 0, pi, :, :, zero_indices[iix]], cmap='RdBu') #, vmin=-lvmax, vmax=lvmax)
                        ax[si, pi].set_xticklabels([])
                        ax[si, pi].set_yticklabels([])
                plt.show()
    
                va = self.get_val("pr_dn_gd1")
                size_factor = 5
                fig, ax = plt.subplots(nrows=n_sizes, ncols=4,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
                for si in range(n_sizes):
                    for pi in range(4):
                        ax[si, pi].imshow(va[iix, si, 0, pi, :, :, zero_indices[iix]], cmap='RdBu') #, vmin=-lvmax, vmax=lvmax)
                        ax[si, pi].set_xticklabels([])
                        ax[si, pi].set_yticklabels([])
                plt.show()
    
                va = self.get_val("e_dn_gd1")
                size_factor = 5
                fig, ax = plt.subplots(nrows=n_sizes, ncols=1,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
                for si in range(n_sizes):
                    ax[si].imshow(va[iix, si, 0, :, :, zero_indices[iix]], cmap='RdBu') #, vmin=-lvmax, vmax=lvmax)
                    ax[si].set_xticklabels([])
                    ax[si].set_yticklabels([])
                plt.show()

            va = self.get_val("e_dn_pair")
            pp = self.get_val("e_dn_gd1") + self.get_val("e_dn_gd2")
            size_factor = 5
            #fig, ax = plt.subplots(nrows=n_sizes, ncols=1,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
            for si in range(n_sizes):
                print("size", si, "pair then originalsums")
                plt.imshow(va[iix, si, 0, :, :, zero_indices[iix]]) #, vmin=-lvmax, vmax=lvmax)
                plt.colorbar()
                plt.show()
                plt.imshow(pp[iix, si, 0, :, :, zero_indices[iix]]) #, vmin=-lvmax, vmax=lvmax)
                plt.colorbar()
                plt.show()

        if False:
            if True:
                print("Normalization pool, pair, 2, 0:")
                pixel_penalties = tf.math.real(self.get_val("e_np_pair"))
                plt.imshow(pixel_penalties[iix, 2, 0, :, :, zero_indices[iix]])
                plt.colorbar()
                plt.show()
                print("Normalizataion pool, pair, total:")
                pixel_penalties = tf.math.real(self.get_val("e_np_pair"))
                plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, 0, :, :, zero_indices[iix]], axis=[0]))
                plt.colorbar()
                plt.show()
            if True:
                print("Divisive Normalization, pair, 2, 0")
                pixel_penalties = tf.math.real(self.get_val("e_dn_pair"))
                plt.imshow(pixel_penalties[iix, 2, 0, :, :, zero_indices[iix]])
                plt.colorbar()
                plt.show()
                print("Divisive normalization, pair, total:")
                pixel_penalties = tf.math.real(self.get_val("e_dn_pair"))
                plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, 0, :, :, zero_indices[iix]], axis=[0]))
                plt.colorbar()
                plt.show()
            if True:
                print("Original sums, pair, 2, 0")
                pixel_penalties = tf.math.real(self.get_val("original_sums"))
                plt.imshow(pixel_penalties[iix, 2, 0, :, :, zero_indices[iix]])
                plt.colorbar()
                plt.show()
                print("Original sums, pair, total:")
                pixel_penalties = tf.math.real(self.get_val("original_sums"))
                plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, 0, :, :, zero_indices[iix]], axis=[0]))
                plt.colorbar()
                plt.show()
            if True:
                print("Diffs, pair, 2, 0")
                pixel_penalties = tf.math.real(self.get_val("diffs"))
                plt.imshow(pixel_penalties[iix, 2, 0, :, :, zero_indices[iix]])
                plt.colorbar()
                plt.show()
        if False:
            diffs = tf.math.real(self.get_val("diffs"))
            for di in range(n_sample_distances):
                for si in range(n_sizes):
                    print("distance", di, "size", si)
                    plt.imshow(diffs[iix, si, 0, :, :, di])
                    plt.colorbar()
                    plt.show()
        if True:

            if False:
                # Visualize the normalization pools for each size/orientation with and without phase congruency
                e_dn_gd1 = tf.reduce_sum(self.get_val("e_dn_gd1"), axis=[2], keepdims=True)
                e_dn_gd2 = tf.reduce_sum(self.get_val("e_dn_gd2"), axis=[2], keepdims=True)
                e_dn_pair = tf.reduce_sum(self.get_val("e_dn_pair"), axis=[2], keepdims=True)
    
                size_factor = 5
                fig, ax = plt.subplots(nrows=n_sizes, ncols=5,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * 5 * box_width / 100, size_factor * n_sizes * box_height / 100))
                vm = tf.reduce_max(e_dn_pair[iix, :, :, :, :, zero_indices[iix]])
                for si in range(n_sizes):
                    ax[si, 0].imshow(e_dn_gd1[iix, si, 0, :, :, zero_indices[iix]], vmin=0, vmax=vm)
                    ax[si, 0].set_xticklabels([])
                    ax[si, 0].set_yticklabels([])
                    ax[si, 1].imshow(e_dn_gd2[iix, si, 0, :, :, zero_indices[iix]], vmin=0, vmax=vm)
                    ax[si, 1].set_xticklabels([])
                    ax[si, 1].set_yticklabels([])
                    ax[si, 2].imshow((e_dn_gd1 + e_dn_gd2)[iix, si, 0, :, :, zero_indices[iix]], vmin=0, vmax=vm)
                    ax[si, 2].set_xticklabels([])
                    ax[si, 2].set_yticklabels([])
                    ax[si, 3].imshow(e_dn_pair[iix, si, 0, :, :, zero_indices[iix]], vmin=0, vmax=vm)
                    ax[si, 3].set_xticklabels([])
                    ax[si, 3].set_yticklabels([])
                    ax[si, 4].imshow((e_dn_pair - e_dn_gd1 - e_dn_gd2)[iix, si, 0, :, :, zero_indices[iix]])
                    ax[si, 4].set_xticklabels([])
                    ax[si, 4].set_yticklabels([])
                plt.show()

            print("Diffs, pair, total:")
            diffs = tf.math.real(self.get_val("diffs"))
            plt.imshow(-tf.reduce_sum(diffs[iix, :, :, :, :, zero_indices[iix]], axis=[0, 1]))
            plt.colorbar()
            plt.show()
            print("Penalties:")
            pixel_penalties = tf.math.real(self.get_val("total_pixel_penalties"))
            sigmas = get_sigmas()

            plt.plot(sigmas, np.zeros_like(sigmas), color='k')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(diffs[iix, :, :, :, :, 0]), axis=[1,2,3]), color='r')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-diffs[iix, :, :, :, :, 0]), axis=[1,2,3]), color='r')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(diffs[iix, :, :, :, :, 1]), axis=[1,2,3]), color='b')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-diffs[iix, :, :, :, :, 1]), axis=[1,2,3]), color='b')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(diffs[iix, :, :, :, :, 2]), axis=[1,2,3]), color='g')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-diffs[iix, :, :, :, :, 2]), axis=[1,2,3]), color='g')
            plt.show()

            plt.plot(sigmas, np.zeros_like(sigmas), color='k')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(pixel_penalties[iix, :, :, :, :, 0]), axis=[1,2,3]), color='r')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-pixel_penalties[iix, :, :, :, :, 0]), axis=[1,2,3]), color='r')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(pixel_penalties[iix, :, :, :, :, 1]), axis=[1,2,3]), color='b')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-pixel_penalties[iix, :, :, :, :, 1]), axis=[1,2,3]), color='b')
            plt.plot(sigmas, tf.reduce_sum(tf.nn.relu(pixel_penalties[iix, :, :, :, :, 2]), axis=[1,2,3]), color='g')
            plt.plot(sigmas, tf.reduce_sum(-tf.nn.relu(-pixel_penalties[iix, :, :, :, :, 2]), axis=[1,2,3]), color='g')
            plt.show()

            d = tf.reduce_sum(pixel_penalties, [1,2,3,4])
    
            first_negative = tf.nn.relu(d[:, 0]) * 100
            last_positive = tf.nn.relu(-d[:, 2]) * 100
            constrained_slopes = (d[:, 1:] - d[:, 0:-1]) / 2 # right now delta-x is equal spacing
            predicted_zeros = np.array([-2, 0]) - d[:, 0:-1] / constrained_slopes
    
            first_crosses_up = tf.nn.sigmoid(1000 * d[:, 1]) * tf.nn.sigmoid(-1000 * d[:, 0])
            second_crosses_up = tf.nn.sigmoid(1000 * d[:, 2]) * tf.nn.sigmoid(-1000 * d[:, 1])
            first_above_zero = tf.nn.sigmoid(1000 * d[:, 0])
            last_below_zero = tf.nn.sigmoid(-1000 * d[:, 2])
            relevance_first = tf.nn.tanh(1000 * (first_crosses_up + first_above_zero))
            relevance_last = tf.nn.tanh(1000 * (second_crosses_up + last_below_zero))
            print("relevance first", relevance_first, "relevance last", relevance_last, "firstcrosses up or above zero:", first_crosses_up, first_above_zero, "lastcrossesup or below zero", second_crosses_up, last_below_zero)

            predicted_zero = (relevance_first * predicted_zeros)[:, 0] + (relevance_last * predicted_zeros)[:, 1]

            print("predicted zero", predicted_zero, "squared:", predicted_zero**2, "loss:", predicted_zero**2 + first_negative + last_positive)

            for di in range(n_sample_distances):
                print("DISTANCE INDEX", di)
                if (di == zero_indices[iix]):
                    print("best distance:")
                vm = max(tf.reduce_max(pixel_penalties[iix, :, :, :, :, di]), -tf.reduce_min(pixel_penalties[iix, :, :, :, :, di]))
                print("max:", vm, "sum:", tf.reduce_sum(pixel_penalties[iix, :, :, :, :, di]))
                if True:
                    # Show the image with penalties overlaid:
                    plt.imshow(pair_images[iix, :, :, di], alpha=0.7, cmap='gray')
                    plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, :, :, :, di], [0, 1]), alpha=0.7)
                    plt.colorbar()
                    plt.show()
                if di == 1 and False:
                    for si in range(n_sizes):
                        print("SIZE", si)
                        fig, ax = plt.subplots(1, 2,  gridspec_kw = {'wspace':0, 'hspace':0},  figsize=(5 * 2 * box_width / 100, 5 * box_height / 100))
                        ax[0].imshow(pair_images[iix, :, :, di], alpha=0.7, cmap='gray')
                        pp = tf.reduce_sum(pixel_penalties[iix, si, :, :, :, di], [0])
                        aim = ax[0].imshow(pp, alpha=0.7)
                        max_y, max_x = np.unravel_index(pp.numpy().argmax(), pp.shape)
                        min_y, min_x = np.unravel_index(pp.numpy().argmin(), pp.shape)
                        ax[0].add_patch(Circle((max_x,max_y),sigmas[si]*4,edgecolor='w',facecolor=None,fill=False))
                        ax[0].add_patch(Circle((min_x,min_y),sigmas[si]*4,edgecolor='w',facecolor=None,fill=False))
                        fig.colorbar(aim, ax=ax[0])
                        ax[1].imshow(pair_images[iix, :, :, di], alpha=0.7, cmap='gray')
                        dim = ax[1].imshow(tf.reduce_sum(diffs[iix, si, :, :, :, zero_indices[iix]], axis=[0]))
                        fig.colorbar(dim, ax=ax[1])
                        plt.show()

                plt.imshow(tf.reduce_sum(pixel_penalties[iix, :, :, :, :, di], [2, 3]))
                plt.colorbar()
                plt.show()


        if False:
            print("Sigmas")
            sigmas = self.get_weights("SpatialAverage")[0]
            plt.plot(sigmas[0, 0, :, 0, 0])
            plt.show()

        if False:
            print("After s/o dense convolution: lines, edges")
            blurred = self.get_val("SpatialAverage")
            subbed = self.get_val("pow")
            hra_total = self.get_val("hr_atotal")
            size_factor = 2
            fig, ax = plt.subplots(nrows=n_sizes, ncols=4,  gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(size_factor * n_orientations * box_width / 100, size_factor * n_sizes * box_height / 100))
            for si in range(n_sizes):
                ax[si, 0].imshow(tf.reduce_sum(abstotal, axis=[2])[iix, si, :, :, zero_indices[iix]])
                ax[si, 0].set_xticklabels([])
                ax[si, 0].set_yticklabels([])
                ax[si, 1].imshow(blurred[iix, si, :, :, zero_indices[iix]])
                ax[si, 1].set_xticklabels([])
                ax[si, 1].set_yticklabels([])
                ax[si, 2].imshow(subbed[iix, si, :, :, zero_indices[iix]])
                ax[si, 2].set_xticklabels([])
                ax[si, 2].set_yticklabels([])
                ax[si, 3].imshow(hra_total[iix, si, :, :, zero_indices[iix]])
                ax[si, 3].set_xticklabels([])
                ax[si, 3].set_yticklabels([])

                print("size", si, tf.reduce_sum(hra_total[iix, si, :, :, zero_indices[iix]]))
            plt.show()

        if False:
            hravars = self.get_weights("hr_atotal")
            print("hravars", hravars)
            print("Hyperbolic ratio variables.")
            print("Exponents")
            plt.plot(tf.nn.softplus(hravars[0][0, :, 0, 0, 0]))
            #plt.colorbar()
            plt.show()
            print("Alphas")
            plt.plot(tf.nn.softplus(hravars[1][0, :, 0, 0, 0]))
            #plt.colorbar()
            plt.show()
            print("m-scale")
            plt.plot(hravars[2][0, :, 0, 0, 0])
            #plt.colorbar()
            plt.show()
    
            print("Total for each pixel")
            out = self.get_val("Sum_1")
            plt.imshow(out[iix, :, :, 0])
            plt.colorbar()
            plt.show()
            plt.imshow(out[iix, :, :, zero_indices[iix]])
            plt.colorbar()
            plt.show()
            plt.imshow(out[iix, :, :, -1])
            plt.colorbar()
            plt.show()
            #for i in range(n_sample_distances):
            #    plt.imshow(tf.reduce_sum(out, axis=[3,4], keepdims=True)[iix, :, :, 0, 0, i])
            #    plt.colorbar()
            #    plt.show()
            #    print("Index:", i, "total is", tf.reduce_sum(out, axis=[1,2,3,4])[iix, i])
        if False:
            # Display what the penalties look like
            pass

        if True:
            self.print_weights("penalize_zero")

        if False:             
            print("Diglyphiness values")
            out = self.get_val("Sum")
            pred = self.get_val("distance_estimator")
            print("OUT PRED", out.shape, pred.shape)
            plt.plot(sample_distances[iix, :], out[iix, :])
            print("Pred distance is", pred[iix], "correct distance is", sample_distances[iix, zero_indices[iix]])
            plt.scatter([pred[iix]], [0])
            plt.show()


## Model fitting

In [0]:
tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session()

# Training hyperparameters.
LEARNING_RATE = 0.05
EPOCHS = 1000
STEPS_PER_EPOCH = 20
VALIDATION_FREQ = 40  # run validation once every VALIDATION_FREQ epochs

ig = InputGenerator(batch_size)
try:
    model = get_model()
    model.compile(loss=compute_loss,
                  optimizer=tf.keras.optimizers.Adagrad(LEARNING_RATE))
    # In TF2, Model.fit accepts generators/Sequences directly; fit_generator is
    # deprecated (and removed from recent Keras releases).
    history = model.fit(ig,
                        callbacks=[MonitorProgressCallback(ig)],
                        validation_data=ig,
                        validation_steps=1,
                        validation_freq=VALIDATION_FREQ,
                        epochs=EPOCHS,
                        steps_per_epoch=STEPS_PER_EPOCH,
                        use_multiprocessing=False)
finally:
    # Release the input generator via its kill() hook even when training fails
    # or is interrupted (e.g. KeyboardInterrupt), so it is not left running in
    # the kernel. NOTE(review): presumably this stops background data loading;
    # confirm against InputGenerator.
    ig.kill()
# NOTE(review): the prediction cell below still uses `model` after this reset.
tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session()

## Use trained model to predict letter distances

In [0]:
def predict_zero_crossing(d, sample_distances, sharpness=1000.0):
    """Estimate where a pair's summed penalty curve crosses zero.

    The penalty is sampled at exactly three distances. Each of the two
    segments between neighbouring samples is linearly interpolated to find its
    zero crossing. Each crossing is then weighted by a soft (sigmoid/tanh)
    indicator of whether that segment is relevant:
    - the first segment, when the curve crosses upwards there or already
      starts above zero;
    - the second segment, when it crosses upwards there or still ends below
      zero.

    Args:
        d: tensor of shape (batch, 3) with summed penalties, one column per
            entry of ``sample_distances``.
        sample_distances: sequence of three increasing sample distances. The
            spacing does not have to be uniform.
        sharpness: steepness of the soft indicators.

    Returns:
        Tensor of shape (batch,) with the estimated zero crossing, expressed
        relative to ``sample_distances[1]``.

    Raises:
        ValueError: if ``sample_distances`` does not have exactly 3 entries.
    """
    if len(sample_distances) != 3:
        raise ValueError("expected exactly 3 sample distances, got %d" % len(sample_distances))
    d = tf.convert_to_tensor(d)
    xs = tf.constant(sample_distances, dtype=d.dtype)

    # Slope of each segment between neighbouring samples.
    slopes = (d[:, 1:] - d[:, :-1]) / (xs[1:] - xs[:-1])
    # Zero crossing of each segment's line, relative to the middle sample.
    # divide_no_nan keeps a flat segment (slope 0) from yielding inf/NaN.
    predicted_zeros = (xs[:-1] - xs[1]) - tf.math.divide_no_nan(d[:, :-1], slopes)

    first_crosses_up = tf.nn.sigmoid(sharpness * d[:, 1]) * tf.nn.sigmoid(-sharpness * d[:, 0])
    second_crosses_up = tf.nn.sigmoid(sharpness * d[:, 2]) * tf.nn.sigmoid(-sharpness * d[:, 1])
    first_above_zero = tf.nn.sigmoid(sharpness * d[:, 0])
    last_below_zero = tf.nn.sigmoid(-sharpness * d[:, 2])
    relevance_first = tf.nn.tanh(sharpness * (first_crosses_up + first_above_zero))
    relevance_last = tf.nn.tanh(sharpness * (second_crosses_up + last_below_zero))

    # Index before multiplying so (batch,) pairs with (batch,) for any batch size.
    return relevance_first * predicted_zeros[:, 0] + relevance_last * predicted_zeros[:, 1]


# Sample distances at which every glyph pair is evaluated; the zero crossing
# of the penalty curve is interpolated between them.
distances = [0, 8, 16]
predicted_kerning_dict = {}
glyph_data = {}
for glyph_char in tqdm(glyph_char_list):
    glyph_data[glyph_char] = get_glyph_data_with_filtered_as_dict(glyph_char)

for gl in glyph_char_list:
    for gr in glyph_char_list:
        with tf.device('/CPU:0'):
            cpd = shift_and_overlay_pair_data(glyph_data[gl], glyph_data[gr], distances)
            # Prepend a batch axis of size 1 to every model input.
            inputs = [
                cpd['shifted_gd1_d1_filtered_images'][None, :, :, :, :, :],
                cpd['shifted_gd2_d1_filtered_images'][None, :, :, :, :, :],
                cpd['sample_distances'][None, :],
                cpd['pair_images'][None, :, :, :],
                np.array(cpd['zero_index'])[None],
            ]

            # Collapse every axis except batch (0) and the last one. The last
            # axis is indexed per sample distance by predict_zero_crossing.
            d = tf.reduce_sum(model.predict(inputs), axis=[1, 2, 3, 4])
            print("RESULT:", gl, gr, d[0, :].numpy())

            predicted_zero = float(predict_zero_crossing(d, distances)[0])
            min_ink_distance = f.minimum_ink_distance(gl, gr)
            # predicted_zero is relative to distances[1]; shift it back to an
            # absolute distance. Round rather than truncate toward zero, which
            # biased the result.
            predicted_kerning_dict[gl, gr] = round(predicted_zero + distances[1]) - min_ink_distance
            print("Predicted dist,", gl, gr, ":", predicted_kerning_dict[gl, gr])
            print("Predicted zero was:", predicted_zero, "mindist:", min_ink_distance)


In [0]:
# Report the factor used to convert predicted distances into font units.
print("Font scaling factor is " + str(f.scale_factor))

In [0]:
import string

OUTPUT_OTF_PATH = 'example-output.otf'

ufo = defcon.Font()
# Extract the source font without its kerning; kerning pairs are written below.
extractor.extractUFO(filename, ufo, doKerning=False)

# Only lowercase letters are written, because their characters double as glyph
# names. Letters without predictions are skipped, so a partial glyph_char_list
# does not raise KeyError.
predicted_glyphs = {left for left, _ in predicted_kerning_dict}
kerned_glyphs = [g for g in string.ascii_lowercase if g in predicted_glyphs]

print("Adding kerning data for", "".join(kerned_glyphs), "...")
for gl in kerned_glyphs:
    glyph = ufo.layers.defaultLayer[gl]
    # Zero both sidebearings so that spacing comes entirely from the kerning
    # values below. NOTE(review): assumes predicted_kerning_dict holds the full
    # desired gap; confirm.
    glyph.leftMargin = 0
    glyph.rightMargin = 0
    for gr in kerned_glyphs:
        if (gl, gr) in predicted_kerning_dict:
            # Presumably converts from model units to font units; confirm f.scale_factor.
            ufo.kerning[(gl, gr)] = predicted_kerning_dict[(gl, gr)] / f.scale_factor

print("Compiling to OTF ...")
otf = compileOTF(ufo)
otf.save(OUTPUT_OTF_PATH)
print("Compilation done.")
