<a href="https://colab.research.google.com/github/skosch/YinYangFit/blob/master/YinYangFit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
try:
    %tensorflow_version 2.x
except Exception:
    pass

import tensorflow as tf
import os

print("TF version:", tf.__version__)
print("GPU is", "available" if tf.test.is_gpu_available() else "NOT AVAILABLE")

if tf.test.is_gpu_available():
    device_name = tf.test.gpu_device_name()
    if device_name != '/device:GPU:0':
      raise SystemError('GPU device not found')
    print('Found GPU at: {}'.format(device_name))
elif False:
    tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
    
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(cluster_resolver)
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    tpu_strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)

In [None]:
import itertools
import os

import numpy as np
pi = np.pi

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.cm as cm
import tensorflow as tf
import tensorflow_probability as tfp
import random; random.seed()
import math
import pickle
import os
from tqdm import tqdm as tqdm
import sys
from functools import reduce
import random
from itertools import cycle, islice, product
import operator
from scipy.linalg import toeplitz
from scipy.optimize import minimize_scalar

!pip install --quiet tensorfont
!pip install --quiet fonttools
!pip install --quiet --upgrade fontParts
!pip install booleanOperations
!pip install --quiet --upgrade ufo-extractor
!pip install --quiet --upgrade defcon
!pip install --quiet --upgrade ufo2ft
import fontParts
import extractor
import defcon
from ufo2ft import compileOTF

from tensorfont import Font

print("✓ Dependencies imported.")

In [None]:
#!wget -q -O OpenSans-Regular.ttf https://github.com/googlefonts/opensans/blob/master/ttfs/OpenSans-Regular.ttf?raw=true
#!wget -q -O Roboto.ttf https://github.com/google/fonts/blob/master/apache/roboto/Roboto-Regular.ttf?raw=true
#!wget -q -O Roboto.otf https://github.com/AllThingsSmitty/fonts/blob/master/Roboto/Roboto-Regular/Roboto-Regular.otf?raw=true
#!wget -q -O DroidSerif.ttf https://github.com/datactivist/sudweb/blob/master/fonts/droid-serif-v6-latin-regular.ttf?raw=true
!wget -q -O CrimsonItalic.otf https://github.com/skosch/Crimson/blob/master/Desktop%20Fonts/OTF/Crimson-Italic.otf?raw=true
#!wget -q -O CrimsonBold.otf https://github.com/skosch/Crimson/blob/master/Desktop%20Fonts/OTF/Crimson-Bold.otf?raw=true 
#!wget -q -O CrimsonRoman.otf https://github.com/alif-type/amiri/blob/master/Amiri-Regular.ttf?raw=true

!wget -q -O CrimsonRoman.otf https://github.com/skosch/Crimson/blob/master/Desktop%20Fonts/OTF/Crimson-Roman.otf?raw=true
print("✓ Font file(s) downloaded.")

In [None]:
glyph_char_list = "abcdefghijklmnopqrstuvwxyz"
#glyph_char_list = "bdghijlmnopqu" # straight letters only
#glyph_char_list = "abgjqrst"
#glyph_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#glyph_char_list = "OO"
#glyph_char_list = "abc"

# ==== Create Font ====
factor = 1.0 #1.539  # This scales the size of everything
filename = "CrimsonRoman.otf"
f = Font(filename, 34 * factor) # Roboto.ttf CrimsonRoman.otf # 34 for lowercase
box_height = int(f.full_height_px)
box_width = int(161 * factor) # 121
# Force an odd width so the box has a single centre column.
box_width += (box_width + 1) % 2
print("Box size:", box_height, "×", box_width)

batch_size = 2
# Sample distances around each pair's optimal distance: additive pixel offsets
# (deltas) or multiplicative factors.
sample_distance_deltas = [-2, 0, 2]
sample_distance_factors = [.5, 1., 2.0]
# Pairs are assembled with `sample_distance_factors` (passed to get_pair_translations
# when building pairs), so the per-pair translation vectors -- and the dataset's
# static translation shape -- have the length of that list.
n_sample_distances = len(sample_distance_factors)

n_v1_scales = 5
n_b_scales = 1
n_v1_orientations = 4
n_v4_scales = 8

In [None]:
def get_sigmas(skip_scales=0):
    """Gaussian widths (in pixels) of the V1 filters, one per scale.

    The widths grow quadratically in the scale index, from 0.7 px at index 0 up to
    box_width / 15 at index n_v1_scales - 1. `skip_scales` offsets the index so
    that coarser scales are produced.
    """
    smallest = 0.7
    largest = box_width / 15
    span = largest - smallest
    denom = (n_v1_scales - 1)**2
    return np.array([span * (idx + skip_scales)**2 / denom + smallest
                     for idx in range(n_v1_scales)])

print("Spatial frequency scales:", get_sigmas())

def get_v1_filter_bank(skip_scales, display_filters=False):
    """Build the bank of complex V1 filters in the frequency domain.

    For every scale (Gaussian width taken from get_sigmas(skip_scales)) and every
    orientation index o (angle pi * o / n_v1_orientations), an odd filter (first
    derivative of a Gaussian) and an even filter (second derivative, "Mexican hat")
    are built in space on a (2*box_height, 2*box_width) grid and transformed with
    fft2. They are combined as d1 + 1j*d2 and normalized so the largest
    frequency-domain magnitude is 1.

    Args:
        skip_scales: scale-index offset forwarded to get_sigmas().
        display_filters: if True, plot the real and imaginary parts of each
            filter's inverse FFT.

    Returns:
        np.complex64 array of shape
        (n_v1_scales, n_v1_orientations, 2*box_height, 2*box_width).
    """
    def rotated_mgrid(oi):
        """Return (x, y) pixel-centre coordinates on a (2*box_height, 2*box_width)
        grid, rotated by pi * oi / n_v1_orientations radians."""
        theta = pi * oi / n_v1_orientations
        rotation = np.array([[ np.cos(theta), np.sin(theta)],
                             [-np.sin(theta), np.cos(theta)]])
        y, x = np.mgrid[-box_height:box_height, -box_width:box_width].astype(np.float32)
        # Shift by half a pixel so the grid is symmetric about the origin.
        y += 0.5
        x += 0.5
        return np.einsum('ji, mni -> jmn', rotation, np.dstack([x, y]))

    def get_filter(s, oi):
        """Return the complex frequency-domain filter for Gaussian width s and orientation index oi."""
        x, y = rotated_mgrid(oi)

        # To minimize ringing etc., we create the filter in space, then run it through the DFT.

        # First derivative (odd filter/up-down)
        d1_space = np.exp(-(x**2+y**2)/(2*s**2))*x/(2*pi*s**4)
        d1_relu_sum = np.sum(d1_space * (d1_space > 0))
        d1 = np.fft.fft2(np.fft.ifftshift(d1_space + 1j * np.zeros_like(d1_space)))

        # Second derivative (even filter/mexican hat), slightly narrower so both
        # filters end up about the same width.
        s2 = s * .85
        d2_space = np.exp(-(x**2+y**2)/(2*s2**2))/(2*pi*s2**4) - np.exp(-(x**2+y**2)/(2*s2**2))*x**2/(2*pi*s2**6)
        d2_relu_sum = np.sum(d2_space * (d2_space > 0))
        # Rescale so the positive lobes of both filters carry the same total weight.
        d2 = (d1_relu_sum / d2_relu_sum) * np.fft.fft2(np.fft.ifftshift(d2_space + 1j * np.zeros_like(d2_space)))

        combined = d1 + 1j*d2
        # Normalize so the largest frequency-domain magnitude is exactly 1 (pure numpy).
        return combined / np.max(np.abs(combined))

    filter_bank = np.zeros((n_v1_scales, n_v1_orientations, 2*box_height, 2*box_width), dtype=np.complex64)

    if display_filters:
        sizediv = 20
        fig, ax = plt.subplots(nrows=n_v1_scales*2, ncols=n_v1_orientations, gridspec_kw = {'wspace':0, 'hspace':0}, figsize=(box_width * n_v1_orientations / sizediv, box_height * n_v1_scales * 2 / sizediv))

    # Forward skip_scales so callers can actually select coarser scales.
    sigmas = get_sigmas(skip_scales)
    for s in range(n_v1_scales):
        sigma = sigmas[s]
        for o in range(n_v1_orientations):
            # Named `filt` (not `f`) to avoid confusion with the global font object.
            filt = get_filter(sigma, o)
            if display_filters:
                spatial = np.fft.fftshift(np.fft.ifft2(filt))
                mx = np.max(np.abs(np.imag(spatial)))
                for row, part in ((s*2, np.real(spatial)), (s*2+1, np.imag(spatial))):
                    ax[row, o].imshow(part, cmap="RdBu", vmin=-mx, vmax=mx)
                    ax[row, o].set_aspect("auto")
                    ax[row, o].set_yticklabels([])
            filter_bank[s, o, :, :] = filt

    if display_filters:
        plt.show()

    return filter_bank

# Build the V1 filter bank once (no skipped scales) and plot its filters.
filter_bank = get_v1_filter_bank(skip_scales=0, display_filters=True)


def apply_filter_bank(input_image, filter_bank):
    """Convolve one 2-D image with every filter of a frequency-domain filter bank.

    Args:
        input_image: real floating-point image of shape <h, w> = (box_height, box_width).
        filter_bank: complex frequency-domain filters of shape
            <s, o, 2*box_height, 2*box_width> (e.g. from get_v1_filter_bank).

    Returns:
        Complex tensor of shape <s, o, box_height, box_width>: the filter
        responses in the spatial domain, cropped back to the input box.

    Raises:
        ValueError: if input_image is not 2-D.
    """
    if len(input_image.shape) != 2:
        raise ValueError("apply_filter_bank expects a 2-D <h, w> image, got shape {}".format(tuple(input_image.shape)))

    # Padding that grows the image to the filter size (2*box_height, 2*box_width);
    # the same top/left offsets are used to crop the result back.
    pad_top, pad_bottom = int(np.ceil(box_height / 2)), int(box_height / 2)
    pad_left, pad_right = int(np.ceil(box_width / 2)), int(box_width / 2)

    # Add singleton <b, d, s, o> axes so the image broadcasts against the bank.
    bdsohw_input_image = input_image[None, None, None, None, :, :]

    # Zero-pad to the filter size to limit circular wrap-around.
    padded_input = tf.pad(bdsohw_input_image, [[0, 0]] * 4 + [[pad_top, pad_bottom], [pad_left, pad_right]], mode='CONSTANT')

    input_in_freqdomain = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(padded_input, tf.zeros_like(padded_input))))

    padded_result = tf.signal.ifft2d(input_in_freqdomain * filter_bank[None, None, :, :, :, :])

    presult = tf.signal.fftshift(padded_result[0, 0, :, :, :, :], [2, 3])

    return presult[:, :, pad_top:pad_top + box_height, pad_left:pad_left + box_width]

In [None]:
# Grid for the V4 blur filters: pf times the glyph box in each dimension, so glyph
# responses can be zero-padded to this size before FFT convolution.
pf = 3

eps = np.finfo(np.float32).tiny  # smallest positive normal float32; used as a tiny additive guard in 1/(r + eps)
u, v = np.mgrid[-box_height*pf/2:box_height*pf/2,-box_width*pf/2:box_width*pf/2].astype(np.float32)
# Distance from the grid centre, normalized so the top/bottom edge is at 1.
r = np.sqrt(u**2 + v**2)[None, None, :, :] / (box_height*pf/2)
# Replace any exact zero (only possible if the grid hits the origin) so 1/r stays finite.
r[r == 0] = 0.5
angle = np.arctan2(u, v)[None, None, :, :] # shape (1, 1, H, W) over the padded grid
# Orientation fractions k / n_v1_orientations (multiplied by pi where used), shape (1, n_v1_orientations, 1, 1).
angles = np.arange(n_v1_orientations)[None, :, None, None].astype(np.float32)/n_v1_orientations
# Von Mises concentration per (scale, orientation); all 4 here.
angle_mask_widths = 4. * np.ones((n_v1_scales, n_v1_orientations)).astype(np.float32)

def make_blur_filters(e):
    """Build the spatial V4 blur filters for every scale and orientation.

    Reads the module-level grid defined in this cell (r, angle, angles,
    angle_mask_widths, eps). Those names are rebound later in the notebook with
    different shapes, so this must be called while this cell's values are live.

    Args:
        e: radial falloff exponent; the first filter decays as r**-e and the
            second as r**-(e+1).

    Returns:
        (x1, x2): tensors of shape (n_v1_scales, 2*n_v1_orientations, H, W), where
        (H, W) is the padded grid size. Along axis 1 the first n_v1_orientations
        channels use the angular lobe centred on pi*angles, the rest the
        opposite lobe.
    """
    radial_mask = (1/(r + eps)) ** e
    distance_mask = (1/(r + eps)) ** (e+1)

    # Von Mises angular weighting, normalized via the Bessel function I0; the
    # normalizer is shared by both lobes, so compute it once.
    kappa = angle_mask_widths[:, :, None, None]
    von_mises_norm = 2*pi*tf.math.bessel_i0(kappa)
    # exp(-k*cos(x - pi - mu)) == exp(k*cos(x - mu)): lobe centred on mu = pi*angles.
    bp_angle_masks = tf.exp(-kappa * tf.cos(angle - pi - pi * angles)) / von_mises_norm
    # exp(-k*cos(x - mu)): the opposite lobe, centred on mu + pi.
    bn_angle_masks = tf.exp(-kappa * tf.cos(angle - pi * angles)) / von_mises_norm

    x1 = tf.concat([radial_mask * bp_angle_masks, radial_mask * bn_angle_masks], axis=1)
    x2 = tf.concat([distance_mask * bp_angle_masks, distance_mask * bn_angle_masks], axis=1)

    return (x1, x2)

blur_e = 2.
(x1_filter, x2_filter) = make_blur_filters(blur_e)
# Move both spatial filters to the frequency domain (centre shifted to index 0 first).
x1_filter_fft, x2_filter_fft = (
    tf.signal.fft2d(tf.signal.ifftshift(tf.complex(spatial, 0.), [2, 3]))
    for spatial in (x1_filter, x2_filter)
)
print(x1_filter_fft.shape)

In [None]:
# 1. Render glyphs

def get_glyph_image(glyph_char):
    """Render `glyph_char` with the global font `f` as a normalized float32 matrix,
    padded to a constant width of `box_width` pixels ([box_height, box_width])."""
    glyph_matrix = f.glyph(glyph_char).as_matrix(normalize=True)
    boxed = glyph_matrix.with_padding_to_constant_box_width(box_width)
    return boxed.astype(np.float32)

def get_glyph_ink_width(glyph_char):
    """Ink width (in pixels) of `glyph_char` as rendered by the global font `f`."""
    glyph = f.glyph(glyph_char)
    return glyph.ink_width

def get_v1_response(glyph_image):
    """Filter one glyph image with the global V1 `filter_bank`.

    Returns the spatial-domain responses as a complex tensor (a tf.Tensor, not an
    np.array) of shape [n_v1_scales, n_v1_orientations, box_height, box_width].
    The computation is placed on /gpu:0.
    """
    with tf.device("/gpu:0"):
        return apply_filter_bank(glyph_image, filter_bank)

def get_v4_strength(glyph_v1_response):
    """Blur the magnitude of a glyph's V1 response with the V4 blur filters.

    |V1| is zero-padded to the pf-times-larger grid of x1_filter_fft/x2_filter_fft,
    duplicated along axis 1 (to pair with the two angular lobes of the filters),
    convolved with both filters via FFT, and cropped back to the glyph box.

    Returns a float tensor of shape [n_v1_scales, 2*n_v1_orientations, box_height, box_width].
    """
    pad_h = int((pf-1)*box_height/2)
    pad_w = int((pf-1)*box_width/2)
    top = int(np.ceil(box_height * (pf-1) / 2))
    left = int(np.ceil(box_width * (pf-1) / 2))

    def crop(full):
        # Cut the glyph-sized window back out of the padded result.
        return full[:, :, top:top + box_height, left:left + box_width]

    with tf.device("/gpu:0"):
        magnitude = tf.pad(tf.abs(glyph_v1_response), [[0, 0], [0, 0], [pad_h, pad_h], [pad_w, pad_w]])
        magnitude_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(tf.concat([magnitude, magnitude], axis=1), 0.), [2, 3]))
        x1 = crop(tf.math.real(tf.signal.fftshift(tf.signal.ifft2d(magnitude_fft * x1_filter_fft), [2, 3])))
        x2 = crop(tf.math.real(tf.signal.fftshift(tf.signal.ifft2d(magnitude_fft * x2_filter_fft), [2, 3])))

        distances = x1 / x2
        fullnesses = x1 * distances ** (blur_e - 1)
        return fullnesses / distances


# Precompute per-glyph data, keyed by character: rendered image, ink width,
# V1 filter responses and V4 blur strengths.
glyph_images = {c: get_glyph_image(c) for c in tqdm(glyph_char_list)}
print("  ✓", len(glyph_char_list), "glyphs rendered.", flush=True)
glyph_ink_widths = {c: get_glyph_ink_width(c) for c in tqdm(glyph_char_list)}
print("  ✓", len(glyph_char_list), "glyphs measured.", flush=True)
glyph_v1_responses = {c: get_v1_response(glyph_images[c]) for c in tqdm(glyph_char_list)}
print("  ✓", len(glyph_char_list), "glyphs filtered.", flush=True)
glyph_v4_strengths = {c: get_v4_strength(glyph_v1_responses[c]) for c in tqdm(glyph_char_list)}
print("  ✓", len(glyph_char_list), "glyphs strengthed.", flush=True)
# Shape sanity check. NOTE(review): assumes "a" is in glyph_char_list.
print("glyph", glyph_v4_strengths["a"].shape)
print("glyph", glyph_v1_responses["a"].shape)

# 1a. Show an example of filtered glyphs
# 1a. Show the V1 response of glyph "a" at each scale: |response| summed over
#     orientations, then squared.
for si in range(n_v1_scales):
    print("Scale:", si)
    #plt.imshow(glyph_images["b"], cmap="gray")
    plt.imshow(np.sum(np.abs(glyph_v1_responses["a"][si, :, :, :]), (0))**2, cmap="Reds", alpha=1.0)
    plt.colorbar()
    plt.show()

# 1b. Show the V4 strength of glyph "a" at each scale, summed over its
#     2*n_v1_orientations channels.
for si in range(n_v1_scales):
    print("SCALE:", si)
    #plt.imshow(glyph_images["b"], cmap="gray")
    plt.imshow(tf.reduce_sum(glyph_v4_strengths["a"][si, :, :, :], [0]), cmap="Reds", alpha=1.0)
    plt.colorbar()
    plt.show()

# Reference: the rendered glyph "a" itself.
plt.imshow(glyph_images["a"], cmap="gray")
plt.show()



In [None]:
# 2. Assemble pairs

def get_pair_translations(char1, char2, distance_deltas=None, distance_factors=None):
    """Return horizontal shifts that place char1 (left) and char2 (right) at several distances.

    The reference distance is
    int(f.pair_distance(char1, char2) + f.minimum_ink_distance(char1, char2)).
    Sample distances are that reference plus each of `distance_deltas`, or the
    reference times each of `distance_factors`; exactly one of the two must be given.

    Example: distance_deltas = [-2, 0, 2] or distance_factors=[0.7, 1.0, 1.5]

    Returns:
        (left_translations, right_translations): int32 arrays with one entry per
        sample distance, in pixels (positive = shift right), for the left and
        right glyph images.

    Raises:
        ValueError: if neither or both of distance_deltas and distance_factors are given.
    """
    # Query the font once; the minimum ink distance is needed twice below.
    pair_distance = f.pair_distance(char1, char2)
    min_ink_distance = f.minimum_ink_distance(char1, char2)
    optimal_distance = int(pair_distance + min_ink_distance)

    if distance_factors is None:
        if distance_deltas is None:
            raise ValueError("Must provide either distance_deltas or distance_factors")

        sample_distances = optimal_distance + np.array(distance_deltas)
    else:
        if distance_deltas is not None:
            raise ValueError("Must provide either distance_deltas or distance_factors, not both")

        sample_distances = optimal_distance * np.array(distance_factors)

    left_ink_width = glyph_ink_widths[char1]
    right_ink_width = glyph_ink_widths[char2]
    total_width_at_minimum_ink_distance = left_ink_width + right_ink_width - min_ink_distance
    total_ink_width = left_ink_width + right_ink_width
    ink_width_left = np.floor(total_ink_width / 4)
    ink_width_right = np.ceil(total_ink_width / 4)
    # Split each sample distance between the two glyphs (the left side takes the ceil).
    sample_distances_left = np.ceil(sample_distances / 2)
    sample_distances_right = np.floor(sample_distances / 2)

    left_translations = (-(np.ceil(total_width_at_minimum_ink_distance/2) + sample_distances_left) - (-ink_width_left)).astype(np.int32)
    right_translations = ((np.floor(total_width_at_minimum_ink_distance/2) + sample_distances_right) - ink_width_right).astype(np.int32)

    return (left_translations, right_translations)
    
# Flat, index-aligned lists with one entry per ordered glyph pair (c1 left, c2 right).
left_images = []
right_images = []
left_v1_responses = []
right_v1_responses = []
left_v4_strengths = []
right_v4_strengths = []
left_translations = []
right_translations = []

for c1 in tqdm(glyph_char_list):
    # NOTE(review): the right-hand glyph iterates in reverse order; the set of pairs is unchanged.
    for c2 in reversed(glyph_char_list):
        left_images.append(glyph_images[c1])
        right_images.append(glyph_images[c2])
        left_v1_responses.append(glyph_v1_responses[c1])
        right_v1_responses.append(glyph_v1_responses[c2])
        left_v4_strengths.append(glyph_v4_strengths[c1])
        right_v4_strengths.append(glyph_v4_strengths[c2])

        # Shifts placing this pair at each of sample_distance_factors times its optimal distance.
        lt, rt = get_pair_translations(c1, c2, None, sample_distance_factors) #sample_distance_deltas
        left_translations.append(lt)
        right_translations.append(rt)

print("  ✓", len(glyph_char_list)**2, "pairs assembled.")

# 3. Set up generator to yield pairs, and wrap generator in a tf.Dataset

def return_pair():
    """Yield one example per glyph pair, cycling through all pairs forever.

    Each example is a dict holding the left/right images, V1 responses, V4
    strengths and translations at the current index of the global pair lists.
    Yields nothing if there are no pairs.
    """
    idx = 0
    while idx < len(left_images):
        columns = {
            "left_image": left_images,
            "right_image": right_images,
            "left_v1_response": left_v1_responses,
            "right_v1_response": right_v1_responses,
            "left_v4_strength": left_v4_strengths,
            "right_v4_strength": right_v4_strengths,
            "left_translations": left_translations,
            "right_translations": right_translations,
        }
        yield {key: values[idx] for key, values in columns.items()}
        idx = (idx + 1) % len(left_images)

# Wrap the (endless) pair generator in a tf.data.Dataset. The two dicts give each
# field's dtype and static shape and must match what return_pair yields.
dataset = tf.data.Dataset.from_generator(
     return_pair,
     # output types, per field
     {
      "left_image": tf.float32,
      "right_image": tf.float32,
      "left_v1_response": tf.complex64,
      "right_v1_response": tf.complex64,
      "left_v4_strength": tf.float32,
      "right_v4_strength": tf.float32,
      "left_translations": tf.int32,
      "right_translations": tf.int32,
     },
     # output shapes, per field
     {
      "left_image": tf.TensorShape([box_height, box_width]),
      "right_image": tf.TensorShape([box_height, box_width]),
      "left_v1_response": tf.TensorShape([n_v1_scales, n_v1_orientations, box_height, box_width]),
      "right_v1_response": tf.TensorShape([n_v1_scales, n_v1_orientations, box_height, box_width]),
      "left_v4_strength": tf.TensorShape([n_v1_scales, 2*n_v1_orientations, box_height, box_width]),
      "right_v4_strength": tf.TensorShape([n_v1_scales, 2*n_v1_orientations, box_height, box_width]),
      # one shift per sample distance
      "left_translations": tf.TensorShape([n_sample_distances,]),
      "right_translations": tf.TensorShape([n_sample_distances,])
     },
)

print("\n  ✓ Dataset ready.")

In [None]:
# 4. Apply horizontal translations in the dataset

def translate_4d_image(input_image, translations):
    """Shift a 4-D stack horizontally by several amounts and back-fill with zeros.

    @param input_image: <sizes, orientations, height, width>
    @param translations: 1-D integer tensor of shifts in pixels; positive values
        shift right, negative values shift left. Its length must be statically
        known (it sets the number of copies).
    @output <len(translations), sizes, orientations, height, width>, same dtype
        as input_image; entry i is input_image shifted by translations[i].
    """
    width = input_image.shape[3]

    # One copy of the stack per shift, stacked on a new trailing axis.
    images = tf.tile(input_image[:, :, :, :, None], [1, 1, 1, 1, translations.shape[0]])
    fill_constant = 0
    left = tf.maximum(0, tf.reduce_max(translations)) # positive numbers are shifts to the right, for which we need to add zeros on the left
    right = -tf.minimum(0, tf.reduce_min(translations)) # negative numbers are shifts to the left, for which we need to add zeros on the right
    left_mask = tf.ones(shape=(tf.shape(images)[0], tf.shape(images)[1], tf.shape(images)[2], left, tf.shape(images)[4]), dtype=images.dtype) * fill_constant
    right_mask = tf.ones(shape=(tf.shape(images)[0], tf.shape(images)[1], tf.shape(images)[2], right, tf.shape(images)[4]), dtype=images.dtype) * fill_constant
    padded_images = tf.concat([left_mask, images, right_mask], axis=3) # pad on axis 3 (i.e. width-wise)

    # Now that the images are all padded, crop each copy to implement its shift.
    def crop_image_widthwise(image_and_shift):
        image, shift = image_and_shift  # image: <s, o, h, padded_w>, shift: scalar
        return image[:, :, :, left-shift:left-shift+width]

    result = tf.map_fn(
        crop_image_widthwise,
        (tf.einsum("sohwd->dsohw", padded_images), translations),
        dtype=images.dtype)

    # The data-dependent slice loses the static width; restore it from the input
    # itself (the same width used for cropping) instead of a global constant.
    s = list(result.shape)
    s[-1] = width
    result.set_shape(s)

    return result

def apply_translations(d):
    """Shift every left/right tensor of a dataset element by that side's translations.

    Mutates `d` in place and returns (d, 0.). Images become
    <n_translations, h, w>; V1 responses and V4 strengths gain a leading
    <n_translations> axis; the two "*_translations" entries are removed.
    The constant 0. target only exists because model.fit expects a dataset of
    (features, target) 2-tuples.
    """
    sides = (
        ("left_image", "left_v1_response", "left_v4_strength", "left_translations"),
        ("right_image", "right_v1_response", "right_v4_strength", "right_translations"),
    )
    for image_key, v1_key, v4_key, shift_key in sides:
        shifts = d[shift_key]
        # Lift the 2-D image to <1, 1, h, w>, shift it, then drop the singleton axes again.
        d[image_key] = translate_4d_image(d[image_key][None, None, :, :], shifts)[:, 0, 0, :, :]
        d[v1_key] = translate_4d_image(d[v1_key], shifts)
        d[v4_key] = translate_4d_image(d[v4_key], shifts)
    for _, _, _, shift_key in sides:
        del d[shift_key]
    return (d, 0.)

translated_dataset = dataset.map(apply_translations)

print("dataset shapes:", translated_dataset.element_spec)

In [None]:
# 5. Utility functions
eps = np.finfo(np.float32).tiny

def invspa(t):
    """Inverse softplus, log(exp(t) - 1), cast to float32 (zeros are not special-cased)."""
    return np.log(np.exp(t) - 1).astype(np.float32)

def invsp(t):
    """Inverse softplus, log(exp(t) - 1), as float32, with zeros mapped to -1e10.

    Accepts a scalar or any array-like. Scalars keep the legacy behaviour
    (a plain -1e10 for zero, otherwise an np.float32). Arrays and lists are
    handled element-wise, so zero entries become -1e10 instead of -inf.
    """
    values = np.asarray(t)
    if values.ndim == 0:
        if values == 0:
            return -1e10
        return np.log(np.exp(t) - 1).astype(np.float32)
    # `if t == 0` is ambiguous for arrays and log(0) gives -inf, so mask zeros instead.
    with np.errstate(divide="ignore"):
        result = np.log(np.exp(values) - 1)
    return np.where(values == 0, -1e10, result).astype(np.float32)

def sp(t):
    """Softplus, log(1 + exp(t)), via tf.nn.softplus."""
    return tf.nn.softplus(t)

# 6. Generating G-cell fragments

# NOTE(review): this rebinds the module-level grid (u, v, r, angle, angles,
# angle_mask_widths) with a different size, normalization and axis layout than
# the earlier blur-filter cell; anything reading these globals afterwards sees these values.
# Grid of size (2*box_height, 2*box_width), coordinates in units of box_width.
u, v = np.mgrid[-box_height:box_height,-box_width:box_width].astype(np.float32)
u = u / (box_width)
v = v / (box_width)
r = np.sqrt(u**2 + v**2)[None, None, :, :]
# Replace the exact zero at the grid centre so 1/r stays finite.
r[r == 0] = 0.5
angle = np.arctan2(u, v)[None, None, :, :] # shape (1, 1, 2*box_height, 2*box_width)
# Orientation fractions k / n_v1_orientations, shape (n_v1_orientations, 1, 1, 1).
angles = np.arange(n_v1_orientations)[:, None, None, None].astype(np.float32)/n_v1_orientations
# Von Mises concentration per (orientation, V4 scale); all 4 here.
angle_mask_widths = 4. * np.ones((n_v1_orientations, n_v4_scales)).astype(np.float32)

def make_v4_filters(k, spa, sn, hp, hn, cp, cn): # Returns masks of shape <2o, c, h, w>
    """Build radial x angular V4 (G-cell) filter masks on the module-level grid.

    The radial profile is a set of n_v4_scales triangular-distribution bumps. Bump i
    rises from a_i to a peak at c_i and falls to b_i, using consecutive break points of
    tf.linspace(cp, 1, n_v4_scales + 2) ** k * cn. Each bump is multiplied by two
    opposite von Mises angular lobes per orientation.

    Args:
        k: exponent applied to the linearly spaced break points.
        spa, sn, hp, hn: NOTE(review): not used in the body as written.
        cp: start of the break-point linspace (which ends at 1).
        cn: multiplier applied to the break points.

    Returns:
        Tensor of shape (2*n_v1_orientations, n_v4_scales, 2*box_height, 2*box_width),
        given the grid globals of this cell. The first n_v1_orientations entries use
        the lobe centred on pi*angles, and the rest use the opposite lobe.
    """
    x_n = n_v4_scales + 2
    
    xs = tf.linspace(cp, 1, x_n) ** k * cn # k can be one or above (or below)
    
    # Consecutive break-point triples (a, c, b) per V4 scale, shaped to broadcast over the grid.
    a = xs[:-2][None, :, None, None]
    c = xs[1:-1][None, :, None, None]
    b = xs[2:][None, :, None, None]
    
    # Triangular density: rising on (a, c], falling on (c, b), eps elsewhere.
    triangles = tf.where(tf.reduce_all([r > a, r <= c], axis=0), 2*(r-a)/((b-a)*(c-a)),
                        tf.where(tf.reduce_all([r > c, r < b], axis=0), 2*(b-r)/((b-a)*(b-c)), eps))

    radial_mask = (tf.nn.relu(triangles) + eps) / (eps + 2*(b-a))

    #flat_indices = tf.reshape((r * box_width).astype(np.int), [4 * box_height * box_width])
    #radial_mask = tf.reshape(tf.gather(sp(k), flat_indices), [1, 1, 2 * box_height, 2*box_width])

    # Uses von-Mises distribution (via Bessel function)
    bp_angle_masks = tf.exp(-angle_mask_widths[:, :, None, None] * tf.cos(angle - pi - pi * angles)) / (2*pi*tf.math.bessel_i0(angle_mask_widths[:, :, None, None]))
    bn_angle_masks = tf.exp(-angle_mask_widths[:, :, None, None] * tf.cos(angle - pi * angles)) / (2*pi*tf.math.bessel_i0(angle_mask_widths[:, :, None, None]))

    bp_masks = radial_mask * bp_angle_masks
    bn_masks = radial_mask * bn_angle_masks

    # Each bp/bn_mask fragment (the positive part) should add up to exactly one.
    # NOTE(review): the code below divides 4*mask by the sum of SQUARES over axes
    # [0, 2, 3] (orientations and space, per V4 scale), which does not make each
    # fragment sum to one -- confirm which normalization is intended.
    bp_masks_normed = 4*bp_masks / (eps + tf.reduce_sum((bp_masks)**2, [0, 2, 3], keepdims=True))
    bn_masks_normed = 4*bn_masks / (eps + tf.reduce_sum((bn_masks)**2, [0, 2, 3], keepdims=True))

    return tf.concat([bp_masks_normed, bn_masks_normed], axis=0)


def make_blur_filters(exp1, exp2): # Returns (strength, distance) masks with lobes stacked on axis 0
    """Build V4 "strength" and "distance" blur masks on the module-level grid.

    NOTE: this redefines (and shadows) the single-argument make_blur_filters(e)
    from the V1 blur-filter cell.

    Args:
        exp1: radial falloff exponent of the strength mask, (1 / r) ** exp1.
        exp2: radial falloff exponent of the distance mask, (1 / r) ** exp2.

    Returns:
        (strength, distance): with this cell's grid globals, each has shape
        (2*n_v1_orientations, n_v4_scales, 2*box_height, 2*box_width). Along
        axis 0, the lobe centred on pi*angles comes first, then the opposite lobe.
    """
    # TODO: add anisotropy
    inv_r = 1/(r + eps)
    strength_falloff = inv_r ** exp1
    distance_falloff = inv_r ** exp2

    # Von Mises angular lobes (normalized via the Bessel function I0).
    kappa = angle_mask_widths[:, :, None, None]
    norm = 2*pi*tf.math.bessel_i0(kappa)
    toward = tf.exp(-kappa * tf.cos(angle - pi - pi * angles)) / norm
    away = tf.exp(-kappa * tf.cos(angle - pi * angles)) / norm

    strength = tf.concat([strength_falloff * toward, strength_falloff * away], axis=0)
    distance = tf.concat([distance_falloff * toward, distance_falloff * away], axis=0)
    return (strength, distance)

def make_losses_filters(g_spreads):
    """Radially symmetric Gaussian profiles on the module-level grid, one per spread.

    Args:
        g_spreads: 1-D collection of Gaussian standard deviations, in the units of
            the global grid `r`.

    Returns:
        Tensor of shape (1, 1, len(g_spreads), H, W), where (H, W) is the grid size,
        holding exp(-r**2 / (2 g**2)) / (g * sqrt(2*pi)) for each spread g.
    """
    rs = r[None, ...]  # r already has two leading singleton axes -> (1, 1, 1, H, W)
    gs = g_spreads[None, None, :, None, None]

    # Use the exact constant: the previous literal 3.14159276 was a mistyped pi.
    return tf.exp(-rs**2 / (2*gs**2)) / (gs * tf.math.sqrt(2. * pi))


# 7. V4 layer
class V4Layer(tf.keras.layers.Layer):
    """V4-style contour stage built on top of V1 complex-cell responses.

    ``call`` takes a 6-D tensor of V1 responses (indexed ``b, d, s, o, h, w`` in
    the einsum strings below). It convolves it in the Fourier domain with:
      * the ring/rim filters from ``make_v4_filters``, and
      * the polar blur filters from ``make_blur_filters``.
    The results are then pooled over V1 scales and orientations.

    Many weights created in ``__init__`` (HRA and divisive-normalization
    parameters, feedback strengths, ...) are not read by ``call`` at present.
    They are still created, in the same order, so the layer's variable list
    (and any saved weights) stays unchanged.
    """
    def __init__(self, skip_v4_convolution=False, **kwargs):
        super(V4Layer, self).__init__(**kwargs)

        # Stored for callers; `call` does not consult it at present.
        self.skip_v4_convolution = skip_v4_convolution

        # --- V1 stage ---------------------------------------------------------
        # Contrast-sensitivity weights per (V1 scale, orientation); the 5-entry
        # profile assumes n_v1_scales == 5.
        self.csf = self.add_weight(shape=((n_v1_scales, n_v1_orientations)),
                                 initializer=tf.keras.initializers.Constant(np.tile(np.array([3, 15, 25, 3, 1])[:, None], [1, n_v1_orientations]).astype(np.float32) / 200.),
                                 name="csf",
                                 trainable=True)

        # Exponent k / half-point b for the (currently disabled) V1 response
        # curve in hra_v1; stored in invsp space and mapped back with sp.
        self.v1_hra_k = self.add_weight(shape=((n_v1_scales, n_v1_orientations)),
                                 initializer=tf.keras.initializers.Constant(np.tile(invsp([2.,2.,2.,2.,2.])[:, None], [1, n_v1_orientations])), # 1 to .02
                                 name="v1_hra_k",
                                 trainable=True)
        self.v1_hra_b = self.add_weight(shape=((n_v1_scales, n_v1_orientations)),
                                 initializer=tf.keras.initializers.Constant(np.tile(invsp([.5,.5,.5,.5,.5])[:, None], [1, n_v1_orientations])), # 1 to .02
                                 name="v1_hra_b",
                                 trainable=True)

        # Radial falloff exponents for the blur filters (see make_blur_filters).
        self.exp1 = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(3.0),
                                 name="exp1",
                                 trainable=False)
        self.exp2 = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(4.0),
                                 name="exp2",
                                 trainable=False)

        # --- V4 ring filters (arguments of make_v4_filters) -------------------
        # How far away from the center we are
        self.v4_cp = self.add_weight(shape=(), initializer=tf.keras.initializers.Constant(0.2), name="v4_cp", trainable=False)
        self.v4_cn = self.add_weight(shape=(), initializer=tf.keras.initializers.Constant(0.25), name="v4_cn", trainable=False)
        # How wide the rims are
        self.v4_sp = self.add_weight(shape=(n_v4_scales), initializer=tf.keras.initializers.Constant(invspa(np.array([0.00015, 0.00026, 0.00057, 0.00109, .00216, .0036, .0060, .0097]).astype(np.float32))), name="v4_sp", trainable=True)
        self.v4_sn = self.add_weight(shape=(n_v4_scales), initializer=tf.keras.initializers.Constant(invspa(np.array([0.00015, 0.00026, 0.00057, 0.00109, .00216, .0036, .0060, .0097]).astype(np.float32) * .25)), name="v4_sn", trainable=True)
        # How deep the rims are
        self.v4_hp = self.add_weight(shape=(n_v4_scales), initializer=tf.keras.initializers.Constant(invspa(np.array([.8, .8, .8, .8, .8, .8, .8, .8]).astype(np.float32))), name="v4_hp", trainable=True)
        self.v4_hn = self.add_weight(shape=(n_v4_scales), initializer=tf.keras.initializers.Constant(invspa(np.array([.4, .4, .4, .4, .4, .4, .4, .4]).astype(np.float32)*4.)), name="v4_hn", trainable=True)

        self.kk = self.add_weight(shape=(), initializer=tf.keras.initializers.Constant(2.80), name="v4_kk", trainable=False)

        self.v4_angle_mask_widths = self.add_weight(shape=(n_v1_orientations, n_v4_scales),
                                                 initializer=tf.keras.initializers.Constant(4.),
                                                 name="v4_angle_mask_widths", trainable=False)
        self.v4_filter_strengths = self.add_weight(shape=(n_v1_orientations, n_v4_scales),
                                               initializer=tf.keras.initializers.Constant(1.),
                                               name="v4_filter_strengths", trainable=False)

        # Parameters of the (currently disabled) V4 response curve in hra_v4.
        self.v4_hra_k = self.add_weight(shape=(2*n_v1_orientations, n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(np.tile(invsp([1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4])[None, :], [2*n_v1_orientations, 1])),
                                 name="v4_hra_k",
                                 trainable=True)
        self.v4_hra_b = self.add_weight(shape=(2*n_v1_orientations, n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(np.tile(invsp([1.5,3,2,1,1, 1,1,1])[None, :], [2*n_v1_orientations, 1])),
                                 name="v4_hra_b",
                                 trainable=True)

        # Mixing weights from V1 scales (rows) into V4 scales (columns);
        # relu'd and sum-normalized in `call`.
        self.v1_v4_scale_weights = self.add_weight(shape=(n_v1_scales, n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(np.tile(np.array([.04, .1, .13, .04, .01])[:, None], [1, n_v4_scales]).astype(np.float32)),
                                 name="v1_v4_scale_weights",
                                 trainable=True)
        # Pooling weights over polarity-doubled orientations, per V4 scale.
        self.v1_v4_orientation_weights = self.add_weight(shape=(2*n_v1_orientations, n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(1.),
                                 name="v1_v4_orientation_weights",
                                 trainable=True)

        self.v1_blur_orientation_weights = self.add_weight(shape=(2*n_v1_orientations, n_v1_scales),
                                 initializer=tf.keras.initializers.Constant(1.),
                                 name="v1_blur_orientation_weights",
                                 trainable=True)

        self.v4_b_scale_weights = self.add_weight(shape=(1, n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(invsp(1.)),
                                 name="v4_b_scale_weights",
                                 trainable=True)
        self.v4_b_orientation_weights = self.add_weight(shape=(2*n_v1_orientations, n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(1.),
                                 name="v4_b_orientation_weights",
                                 trainable=True)

        # Using this for filtering
        self.g_hra_k = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(np.array([0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4]).astype(np.float32) / 20.),
                                 name="g_hra_k",
                                 trainable=True)
        self.g_hra_b = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(invsp([.4]).astype(np.float32)),
                                 name="g_hra_b",
                                 trainable=True)

        self.g_scale_score_factor = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(0.4),
                                 name="g_scale_score_factor",
                                 trainable=True)

        self.feedback_modulation_strength = self.add_weight(shape=(1, 2*n_v1_orientations, 1, 1),
                                 initializer=tf.keras.initializers.Constant(invsp(1.)),
                                 name="feedback_modulation_strength",
                                 trainable=True)

        # --- Divisive normalization among B cells -----------------------------
        self.b_dn_b = self.add_weight(shape=(1, 2*n_v1_orientations),
                                 initializer=tf.keras.initializers.Constant(invsp(0.25)),
                                 name="b_dn_b",
                                 trainable=True)
        self.b_dn_k = self.add_weight(shape=(1, 2*n_v1_orientations),
                                 initializer=tf.keras.initializers.Constant(invsp(2.2)),
                                 name="b_dn_k",
                                 trainable=True)
        self.b_dn_k_pool = self.add_weight(shape=(1, 2*n_v1_orientations),
                                 initializer=tf.keras.initializers.Constant(invsp(3.9)),
                                 name="b_dn_k_pool",
                                 trainable=True)

        # We would want the normalization pool to mostly include, for each size/orientation, the opposite orientation.
        # We would also want to include smaller sizes. But perhaps that's not so important?

        basic_dn_matrix = np.zeros((1, 2*n_v1_orientations, 1, 2*n_v1_orientations)).astype(np.float32)

        # The first ones are the ones that count towards the second.
        # With a single scale, exp(-(s1 - s2)**2) == 1, so every entry is 1.
        for s1 in range(1):
            for o1 in range(2*n_v1_orientations):
                for s2 in range(1):
                    for o2 in range(2*n_v1_orientations):
                        s_distance = np.exp(-(s1 - s2)**2)
                        basic_dn_matrix[s1, o1, s2, o2] = s_distance

        self.b_dn_weights = self.add_weight(shape=((1, 2*n_v1_orientations, 1, 2*n_v1_orientations)),
                                 initializer=tf.keras.initializers.Constant(invspa(basic_dn_matrix.astype(np.float32))),
                                 name="b_dn_weights",
                                 trainable=True)

        # --- Divisive normalization among G cells -----------------------------
        self.g_dn_b = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(invsp(1.5)),
                                 name="g_dn_b",
                                 trainable=True)
        self.g_dn_k = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(invsp(2.2)),
                                 name="g_dn_k",
                                 trainable=True)
        self.g_dn_k_pool = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(invsp(1.5)),
                                 name="g_dn_k_pool",
                                 trainable=True)

        basic_g_dn_matrix = np.zeros((n_v4_scales, n_v4_scales)).astype(np.float32)

        # The first ones are the ones that count towards the second
        for s1 in range(n_v4_scales):
            for s2 in range(n_v4_scales): # the smaller ones should always suppress the bigger ones
                s_distance = 1. if s1 <= s2 else 0.0001
                basic_g_dn_matrix[s1, s2] = s_distance

        self.g_dn_weights = self.add_weight(shape=((n_v4_scales, n_v4_scales)),
                                 initializer=tf.keras.initializers.Constant(invspa(basic_g_dn_matrix.astype(np.float32))),
                                 name="g_dn_weights",
                                 trainable=True)

        self.g = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(0.5),
                                 name="g",
                                 trainable=True)
        self.k = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(np.array([0.01, 0.1, 0.2, 0.3, .25, 0.125, 0.0325, 0.005]).astype(np.float32)),
                                 name="k",
                                 trainable=True)

    def print_weights(self):
        """Plot the blur orientation weights (the only diagnostic currently enabled)."""
        print("Blur orientation weights:")
        plt.imshow(self.v1_blur_orientation_weights[: ,:])
        plt.colorbar()
        plt.show()

    def hra_v1(self, i):
        """V1 response nonlinearity: currently disabled (identity).

        The disabled form was the saturating curve
        ``(i + eps)**k / (eps + b**k + (i + eps)**k)`` with
        ``k = sp(v1_hra_k)`` and ``b = sp(v1_hra_b)``.

        Design note: complex cells respond nonlinearly in real life (e.g. stems
        are active in the center, and not as much on the outside). Total energy
        should nevertheless be kept the same and only redistributed; it is still
        open how to normalize this.
        """
        return i

    def hra_v4(self, i):
        """V4 response nonlinearity: currently disabled (identity).

        The disabled form was ``(i + eps)**k`` with ``k = sp(v4_hra_k)``.
        """
        return i

    def hra_g(self, i):
        """Power nonlinearity ``(i + eps) ** sp(g_hra_k)`` (not called from ``call``).

        NOTE(review): ``g_hra_k`` has shape ``(n_v4_scales,)`` and is not
        reshaped here. Assuming ``sp`` is elementwise, it therefore broadcasts
        against the LAST axis of ``i``, whereas ``score_g`` reshapes to
        ``[None, None, :, None, None]``. Confirm which axis is intended.
        """
        k = sp(self.g_hra_k)
        return ((i + eps) ** k)

    def score_g(self, i):
        """Per-scale power score ``g_hra_b * i ** g_hra_k``.

        Expects ``i`` indexed ``[b, d, c, h, w]``; both weights are broadcast
        along axis 2.
        """
        return (self.g_hra_b[None, None, :, None, None] * i ** self.g_hra_k[None, None, :, None, None])

    def call(self, inputs):
        """Run the V4 stage.

        Args:
            inputs: 6-D V1 response tensor, indexed ``b, d, s, o, h, w`` below
                (spatial axes 4 and 5).

        Returns:
            Tuple ``(G_0_cropped, eps + G_0, v1_strengths, v1_strengths)``, where:
              * ``G_0`` is the V4 response pooled over V1 scales and over
                polarity-doubled orientations;
              * the first entry is ``G_0`` cropped back out of the padded canvas.
        """
        v1c = self.hra_v1(tf.abs(inputs))
        pf = 3  # padding factor for the FFT canvas

        # Duplicate the orientation axis (one copy per filter polarity) and
        # zero-pad the spatial axes. The padding presumably keeps the circular
        # FFT convolutions below from wrapping around.
        b_balanced = tf.pad(tf.concat([v1c]*2, axis=3), [[0, 0], [0, 0], [0, 0], [0, 0],
                                [int(np.ceil(box_height * (pf-1) / 2)), int(box_height * (pf-1)/ 2)],
                                [int(np.ceil(box_width * (pf-1)/ 2)), int(box_width * (pf-1)/ 2)]], mode='CONSTANT')

        b_balanced_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(b_balanced, 0.), [4, 5]))

        v4_filters = make_v4_filters(self.kk, self.v4_sp, self.v4_sn, self.v4_hp, self.v4_hn, self.v4_cp, self.v4_cn)[None, None, None, ...]
        v4_filters_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(v4_filters, 0.), [5, 6]))

        # <b, d, s, o, c, h, w>: every V1 scale/orientation convolved with every V4 filter.
        v4_activations_0_by_s_o = tf.nn.relu(tf.math.real(tf.signal.fftshift(tf.signal.ifft2d(v4_filters_fft * b_balanced_fft[:, :, :, :, None, :, :]), [5, 6])))

        # Pool over V1 scales using relu'd, sum-normalized mixing weights.
        v4_activations_0 = self.hra_v4(tf.einsum("bdsochw,sc->bdochw", v4_activations_0_by_s_o, eps + tf.nn.relu(self.v1_v4_scale_weights) / (eps + tf.reduce_sum(tf.nn.relu(self.v1_v4_scale_weights))) ))

        # Pool over orientations; the weights are divided by their mean over axis 0.
        G_0 = tf.einsum("bdochw,oc->bdchw", v4_activations_0, tf.nn.relu(self.v1_v4_orientation_weights) / (eps + tf.reduce_mean(eps + tf.nn.relu(self.v1_v4_orientation_weights), axis=[0], keepdims=True) ))

        # Two polar blur banks with radial falloffs (r + eps)**-exp1 and (r + eps)**-(exp1 + 1).
        # NOTE(review): self.exp2 (default 4.0) exists but is not used here;
        # exp1 + 1 matches it only at the default values. Confirm which is intended.
        x1_filters, x2_filters = make_blur_filters(self.exp1, self.exp1 + 1)
        x1_filters_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(x1_filters[None, None, None, ...], 0.), [5, 6]))
        x2_filters_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(x2_filters[None, None, None, ...], 0.), [5, 6]))
        v1_x1 = tf.nn.relu(tf.math.real(tf.signal.fftshift(tf.signal.ifft2d(x1_filters_fft * b_balanced_fft[:, :, :, :, None, :, :]), [5, 6])))
        v1_x2 = tf.nn.relu(tf.math.real(tf.signal.fftshift(tf.signal.ifft2d(x2_filters_fft * b_balanced_fft[:, :, :, :, None, :, :]), [5, 6])))

        # Ratio of the two blurred responses, summed over axes s and o.
        # Both operands are relu'd, so the denominator can be exactly 0.
        # divide_no_nan returns 0 there instead of NaN/inf; elsewhere it is a
        # plain division.
        v1_strengths = tf.einsum("bdsochw->bdchw", tf.math.divide_no_nan(v1_x1, v1_x2))

        # Crop the padded canvas. NOTE(review): for pf == 3 the window starts at
        # box_height and spans box_height + ceil(box_height / 2) rows (likewise
        # for columns). Confirm that this extent is intended.
        return (G_0[:, :, :, int(np.ceil(box_height * (pf-1) / 2)):int(box_height * (pf-1) + np.ceil(box_height / 2)),
                              int(np.ceil(box_width * (pf-1) / 2)):int(box_width * (pf-1) + np.ceil(box_width / 2))], eps + G_0, v1_strengths, v1_strengths)


class RelevanceLayer(tf.keras.layers.Layer):
    """Score how balanced two responses are.

    For inputs ``(l, r)`` the output is ``1 - (|l - r| / (l + r)) ** gamma``
    with ``gamma = 1 + relu(relevance_strictness)``.
    The output is 1 where ``l == r`` and shrinks as one side dominates
    (assuming non-negative inputs; TODO confirm).
    A constant 1e-12 is added to the denominator to avoid dividing by zero.
    """
    def __init__(self, **kwargs):
        super(RelevanceLayer, self).__init__(**kwargs)
        # Trainable offset of the exponent; relu keeps the exponent >= 1.
        self.relevance_strictness = self.add_weight(
            shape=(),
            initializer=tf.keras.initializers.Constant(0.01),
            name="relevance_strictness",
            trainable=True,
        )

    def print_weights(self):
        print("Relevance strictness:", self.relevance_strictness.numpy())

    def call(self, inputs):
        left, right = inputs
        exponent = 1. + tf.nn.relu(self.relevance_strictness)
        imbalance = tf.abs(left - right) / (1e-12 + left + right)
        return 1. - imbalance ** exponent


class GroupingStrengthLayer(tf.keras.layers.Layer):
    """Select the strongest grouping signal with a sharp softmax.

    ``call`` takes ``(l, r, p, relevance)`` and works in three steps:
      1. Score each location with ``l * relu(p - l) + r * relu(p - r)``.
      2. Weight those scores by a softmax over axes 2-4 whose logits are the
         scores scaled by ``10 ** n``.
      3. Crop the spatial axes to a ``box_height x box_width`` window.

    ``relevance`` and most of the weights below are not read by ``call`` at
    present. They are still created, in the same order, so the layer's variable
    list (and any saved weights) stays unchanged.
    """
    def __init__(self, **kwargs):
        super(GroupingStrengthLayer, self).__init__(**kwargs)
        # Todo: we can try adding height here
        self.channel_attention_attraction_rate = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(np.array([0.1, 0.2, 0.3, 0.4, 0.3, 0.2, 0.1, 0.01]).astype(np.float32)),
                                 name="channel_attention_attraction_rate",
                                 trainable=True)
        self.attention_g = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(0.1),
                                 name="attention_g",
                                 trainable=True)
        self.attention_attraction_beta = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(0.1),
                                 name="attention_attraction_beta",
                                 trainable=True)
        self.b = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(np.array([0.1, 0.2, 0.3, 0.4, 0.3, 0.2, 0.1, 0.05]).astype(np.float32)),
                                 name="b",
                                 trainable=True)
        # Softmax sharpness: the logits in `call` are scaled by 10 ** n.
        self.n = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(6.),
                                 name="n",
                                 trainable=True)
        self.yfactors = self.add_weight(shape=(n_v4_scales,box_height),
                                       initializer=tf.keras.initializers.Constant(1.0),
                                       name="yfactor",
                                       trainable=True)

    def print_weights(self):
        print("N:", self.n.numpy())

    def call(self, inputs):
        (l, r, p, _relevance) = inputs

        # Score every location; the strongest one should win via the softmax below.
        relevant_strengths = l * tf.nn.relu(p - l) + r * tf.nn.relu(p - r)

        # Softmax over axes 2-4, computed stably. With the default n = 6 the
        # logits are scaled by 1e6, so tf.exp of the raw logits overflows
        # (float32 exp overflows above ~88) and gives inf / inf = NaN.
        # Subtracting the per-(axis 0, axis 1) max leaves the softmax unchanged
        # up to the tiny eps term, and keeps every exponent <= 0.
        ns = relevant_strengths * (10 ** self.n)
        ns_max = tf.stop_gradient(tf.reduce_max(ns, [2, 3, 4], keepdims=True))
        exp_ns = tf.exp(ns - ns_max)
        softmax_weights = exp_ns / (eps + tf.reduce_sum(exp_ns, [2, 3, 4], keepdims=True))
        strongest_strength = relevant_strengths * softmax_weights

        return strongest_strength[:, :, :, int(np.ceil(box_height / 2)):int(box_height + np.ceil(box_height / 2)),
                              int(np.ceil(box_width / 2)):int(box_width + np.ceil(box_width / 2))]


# 8. Cost layer
class SkeletonCostLayer(tf.keras.layers.Layer):
    """Frequency-domain filtering of per-pixel skeleton losses, scaled by per-scale penalties.

    Weights (created in this order, which is the order `get_weights()` returns):
        wp: per-scale loss penalties, shape (n_v1_scales,), trainable.
        blurwidths: per-scale widths passed to `make_losses_filters`, not trainable.
        we: per-V4-scale values, shape (n_v4_scales,); only plotted by `print_weights`,
            not used in `call`.
        blurwidthfactor: global multiplier on `blurwidths`, not trainable.
        n: softmax sharpness exponent (weights use 10 ** n); only used when
            `return_worst=True`.

    Args:
        return_worst: if False (default), `call` returns the penalized filtered losses.
            If True, it returns them multiplied by softmax weights taken over axes
            [2, 3, 4], which emphasizes the largest penalties. (Previously this softmax
            was computed on every call and then discarded.)
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """
    def __init__(self, return_worst=False, **kwargs):
        super(SkeletonCostLayer, self).__init__(**kwargs)
        self.return_worst = return_worst

        # Penalties for losses. NOTE(review): the initializer has 5 values, so this
        # assumes n_v1_scales == 5.
        # Earlier per-V4-scale variant: shape=(n_v4_scales), init [2., 1.7, 1.53, 0.4, 0.11, 0.01, 0.01, 0.01].
        self.wp = self.add_weight(shape=(n_v1_scales),
                                 initializer=tf.keras.initializers.Constant(np.array([2., 1.7, 1.53, 0.4, 0.11]).astype(np.float32)),
                                 name="wp",
                                 trainable=True)
        self.blurwidths = self.add_weight(shape=(n_v1_scales),
                                 initializer=tf.keras.initializers.Constant(0.01),
                                 name="blurwidths",
                                 trainable=False)
        self.we = self.add_weight(shape=(n_v4_scales),
                                 initializer=tf.keras.initializers.Constant(1.0),
                                 name="we",
                                 trainable=True)
        self.blurwidthfactor = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(1.),
                                 name="blurwidthfactor",
                                 trainable=False)
        self.n = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(1.),
                                 name="n",
                                 trainable=True)

    def print_weights(self):
        """Plot the rectified per-scale weights and print the softmax exponent `n`."""
        print("Skeleton loss cost:")
        plt.plot(tf.nn.relu(self.wp).numpy())
        plt.show()
        print("Skeleton loss exps:")
        plt.plot(tf.nn.relu(self.we).numpy())
        plt.show()
        print("Blurwidths:")
        plt.plot(tf.nn.relu(self.blurwidths).numpy() * tf.nn.relu(self.blurwidthfactor))
        plt.show()
        print("Skeleton loss softmax n:", self.n.numpy())

    def call(self, inputs):
        """Filter `inputs` with `make_losses_filters` via FFT and scale each scale by relu(wp).

        Args:
            inputs: real per-pixel losses. The FFT shifts on axes [3, 4] and the
                `[None, None, :, None, None]` penalty broadcast imply a 5-D tensor with
                scales on axis 2, presumably <b, d, scale, h, w>; TODO confirm.

        Returns:
            Tensor shaped like `inputs`: eps + relu(filtered losses), times relu(wp).
            With `return_worst=True` it is additionally multiplied by softmax weights
            over axes [2, 3, 4].
        """
        losses = inputs
        filter_widths = eps + tf.nn.relu(self.blurwidths) * tf.nn.relu(self.blurwidthfactor)

        # Circular convolution in the frequency domain; ifftshift/fftshift keep the
        # spatial origin at the centre of the (h, w) plane.
        losses_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(losses, eps), [3, 4]))
        losses_filters = make_losses_filters(filter_widths)
        losses_filters_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(losses_filters, eps), [3, 4]))
        filtered = tf.signal.fftshift(tf.signal.ifft2d(losses_fft * losses_filters_fft), [3, 4])
        losses_filtered = eps + tf.nn.relu(tf.math.real(filtered))

        penalties = tf.nn.relu(self.wp)[None, None, :, None, None]
        penalized_losses = losses_filtered * penalties
        if not self.return_worst:
            return penalized_losses

        # Numerically stable softmax over (scale, h, w), sharpened by 10 ** n.
        pln = penalized_losses * 10 ** self.n
        expd = tf.exp(pln - tf.reduce_max(pln, [2, 3, 4], keepdims=True))
        softmax_weights = expd / (eps + tf.reduce_sum(expd, axis=[2, 3, 4], keepdims=True))
        return softmax_weights * penalized_losses


# 8b. Loss layer
class LossLayer(tf.keras.layers.Layer):
    """Scalar "well" loss: the middle sample distance should have the lowest cost.

    From per-sample-distance skeleton costs it computes
    ys = (skeleton_loss_cost - target_strength) ** 2 and returns the batch mean of
    max(ys[:, 1] - ys[:, 0], ys[:, 1] - ys[:, 2]) + eps. That is how far the middle
    distance's cost lies above the lower of its two neighbours (negative when it is below
    both). Only indices 0, 1 and 2 of axis 1 are read.

    Weights (creation order, as returned by `get_weights()`):
        target_strength (init 5.), used in `call`.
        target_grouping (init 0.0002), currently unused in `call`.
        skeleton_loss_weight (init -3.), a log10-scale weight (the effective weight is
            10 ** skeleton_loss_weight), currently unused in `call`.
    """
    def __init__(self, **kwargs):
        super(LossLayer, self).__init__(**kwargs)
        self.target_strength = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(5.),
                                 name="target_strength",
                                 trainable=True)
        self.target_grouping = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(0.0002),
                                 name="target_grouping",
                                 trainable=True)
        self.skeleton_loss_weight = self.add_weight(shape=(),
                                 initializer=tf.keras.initializers.Constant(-3.),
                                 name="skeleton_loss_weight",
                                 trainable=True)

    def print_weights(self):
        """Print the targets and the effective (10 ** w) skeleton loss weight."""
        print("Target strength:", self.target_strength.numpy())
        print("Target grouping:", self.target_grouping.numpy())
        # The weight is stored on a log10 scale; it was previously printed as 10 * w.
        print("Skeleton loss weight:", (10 ** self.skeleton_loss_weight).numpy())

    def call(self, inputs):
        """Return the scalar worst-violation loss.

        Args:
            inputs: `(grouping_strength, skeleton_loss_cost)`. Only `skeleton_loss_cost`
                is used; it is indexed as [:, i], presumably <b, d>; TODO confirm.
        """
        (grouping_strength, skeleton_loss_cost) = inputs

        # Earlier objectives (disabled):
        #   (grouping_strength - target_strength) ** 2 + skeleton_loss_cost * (eps + sp(skeleton_loss_weight))
        #   -grouping_strength + skeleton_loss_cost * (eps + 10 ** skeleton_loss_weight)
        #   ... + (grouping_strength - target_grouping) ** 2
        ys = (skeleton_loss_cost - self.target_strength) ** 2

        # Worst violation of the well: the middle distance should beat both neighbours.
        up_first_ness = ys[:, 1] - ys[:, 0]
        down_second_ness = ys[:, 1] - ys[:, 2]
        worst = tf.reduce_max(tf.stack([up_first_ness, down_second_ness], axis=0), axis=[0])
        return tf.reduce_mean(worst + eps, name="worst_violation")

        # Notes from an earlier experiment: the ys often came out flat, i.e. the strongest
        # strength was exactly constant between two sample distances, even without training.
        # The disabled alternative predicted where the piecewise-linear cost curve crosses
        # target_strength and penalized its distance from the middle sample distance:
        #slopes = (ys[:, 1:]-ys[:, :-1])/(xs[:, 1:]-xs[:, :-1]) - 1e-8
        #intercepts = ys[:, 1:] - xs[:, 1:] * slopes
        #predicted_intersections = self.target_strength/slopes - ys[:, 1:]/slopes + xs[:, 1:]    # should be somewhere between -100 and 100
        ## what happens if the predicted intersection is at infinity?
        #use_second = tf.nn.sigmoid(100. * (predicted_intersections[:, 1] - xs[:, 1]))  # Predicted intersections may be very large
        #use_first = 1. - use_second
        #predicted_intersection = use_first * predicted_intersections[:, 0] + use_second * predicted_intersections[:, 1]
        #deviations_from_target = (predicted_intersection - xs[0, 1]) ** 2
        #return (deviations_from_target, predicted_intersection)
        # (xs was tf.constant(sample_distance_factors)[None, :])




# 9. Set up the actual math

def get_pair_violation_from_strength(left_v4_strength, right_v4_strength, target_strength=90000):
    """Scalar "well violation" loss computed directly from left/right V4 strengths.

    The strengths are summed over axes 2 and 3. They are then combined per pixel as
    l * r / (eps + l + r), half their harmonic mean, which for non-negative strengths never
    exceeds min(l, r). The maximum over the two spatial axes gives one peak value per
    sample distance. With ys = (peak - target_strength) ** 2, the loss is the batch mean
    of max(ys[:, 1] - ys[:, 0], ys[:, 1] - ys[:, 2]) + eps. That is how far the middle
    sample distance's cost lies above the lower of its two neighbours.

    Args:
        left_v4_strength, right_v4_strength: as passed by `get_keras_model`, i.e.
            <batch, n_sample_distances, n_v1_scales, 2*n_v1_orientations, h, w>.
        target_strength: desired peak combined strength (previously hard-coded as 90000).

    Returns:
        Scalar tensor (created with name="worst_violation").
    """
    print("ORIGINAL STRENGTH", left_v4_strength)
    left_strength = tf.reduce_sum(left_v4_strength, [2, 3], keepdims=True)    # <b, d, 1, 1, h, w>
    right_strength = tf.reduce_sum(right_v4_strength, [2, 3], keepdims=True)  # <b, d, 1, 1, h, w>

    # eps keeps pixels where both strengths are zero at 0 instead of 0/0 = NaN, which
    # would otherwise propagate through every reduction below into the loss.
    total_strength_loss = tf.identity(left_strength * right_strength / (eps + left_strength + right_strength), "strengthloss")

    skeleton_loss_cost_image = tf.reduce_sum(total_strength_loss, [2, 3], name="skeleton_loss_cost_image")  # <b, d, h, w>
    skeleton_loss_cost = tf.reduce_max(skeleton_loss_cost_image, [2, 3])  # <b, d>

    ys = (skeleton_loss_cost - target_strength) ** 2

    # Find worst violation of the well
    up_first_ness = ys[:, 1] - ys[:, 0]
    down_second_ness = ys[:, 1] - ys[:, 2]
    worst = tf.reduce_max(tf.stack([up_first_ness, down_second_ness], axis=0), axis=[0])
    return tf.reduce_mean(worst + eps, name="worst_violation")

def get_pair_violation(left_v1_response, right_v1_response):
    """Runs the V1 responses through the V4 layer, computes relevance and grouping strength,
    and returns the LossLayer's worst violation of the "cost-must-be-lowest-for-optimal-distance"
    principle, which is then passed to the optimizer.

    NOTE(review): `get_keras_model` does not call this as currently written (its call there
    is commented out in favour of `get_pair_violation_from_strength`).

    Args:
        left_v1_response, right_v1_response: complex V1 responses (the code takes `tf.abs` of
            them and of their sum); presumably the `*_v1_response` model inputs. TODO confirm.

    Returns:
        The scalar returned by `LossLayer`, identity-named "pair_total_cost". The added
        `0. * tf.reduce_sum(total_strength)` term only keeps `total_strength` connected to
        the output graph.
    """

    # Feed V1 responses through the V4 layer. A single V4Layer instance is applied three
    # times, so left, right and pair share its weights. (Meaning of the False argument is
    # not visible here.)
    v4 = V4Layer(False)
    (left_v4_filtered, left_v4_response, left_fullness, left_strength) = v4(tf.abs(left_v1_response))
    (right_v4_filtered, right_v4_response, right_fullness, right_strength) = v4(tf.abs(right_v1_response))
    (pair_v4_filtered, pair_v4_response, pair_fullness, pair_strength) = v4(tf.abs(left_v1_response + right_v1_response))

    # Named identities so monitoring callbacks can look these tensors up by name.
    left_v4_filtered = tf.identity(left_v4_filtered, "left_v4_filtered")

    left_fullness = tf.identity(left_fullness, "left_fullness")
    left_strength = tf.identity(left_strength, "left_strength")
    right_fullness = tf.identity(right_fullness, "right_fullness")
    right_strength = tf.identity(right_strength, "right_strength")

    # The zero-weighted terms only tie the strength tensors into the graph; for finite
    # values they do not change the responses.
    left_v4_response = tf.identity(left_v4_response, "left_v4_response") + 0. * tf.reduce_sum(left_strength)
    right_v4_response = tf.identity(right_v4_response, "right_v4_response") + 0. * tf.reduce_sum(right_strength)
    pair_v4_response = tf.identity(pair_v4_response, "pair_v4_response")

    # What can we do with these things?

    # NOTE(review): despite its name, total_strength is only left_strength. right_strength is
    # not included; the trailing comment hints that a distance-weighted combination was intended.
    total_strength = tf.identity(tf.identity(left_strength, "total_strength"), "skeleton_loss_cost_image") #  * (left_distance + right_distance)
    # This approximates the strength of the G-cells that define the gap.

    # We need a function that converts the local fullness value, combined with the distance, to a strength.

    rl = RelevanceLayer()
    relevance = tf.identity(rl((left_v4_response, right_v4_response)), "relevance")

    gsl = GroupingStrengthLayer()
    grouping_strength_image = tf.identity(gsl((left_v4_response, right_v4_response, pair_v4_response, relevance)), "grouping_local_strengths")
    grouping_strength = tf.identity(tf.reduce_sum(grouping_strength_image, [2,3,4]), "grouping_strengths") # <b, d>

    # We need a penalty for the maximal skeleton loss.
    # NOTE(review): `sll` is instantiated but never called below.
    sll = SkeletonCostLayer()
    #skeleton_loss_cost_image = tf.identity(sll((left_v4_response, right_v4_response, pair_v4_response)), "skeleton_loss_cost_image") # <b, d>
    #skeleton_loss_cost_image = tf.identity(sll(total_strength), "skeleton_loss_cost_image") # <b, d>
    #skeleton_loss_cost = tf.identity(tf.reduce_sum(skeleton_loss_cost_image, [2,3,4]), "skeleton_losscosts")

    # Per-example, per-distance cost: total_strength summed over axes 2-4.
    skeleton_loss_cost = tf.reduce_sum(total_strength, [2,3,4])

    ll = LossLayer()
    total_pair_cost = tf.identity(ll((grouping_strength, skeleton_loss_cost)), "pair_total_cost")  # Scalar worst-violation loss from LossLayer.


    return total_pair_cost + 0. * tf.reduce_sum(total_strength)
    # Find worst violation of the well
    #up_first_ness = (total_pair_cost[:, 1] - total_pair_cost[:, 0]) #[1,2,3,4,5,6]
    #down_second_ness = (total_pair_cost[:, 1] - total_pair_cost[:, 2]) # has shape <batch_size>, [2,3,4,5,6,7,]
    #worst_violation_sum = tf.reduce_mean((tf.reduce_max(tf.stack([up_first_ness, down_second_ness], axis=0), axis=[0]) + eps), name="worst_violation")
    #return worst_violation_sum 

    # Worst violation sum: results in extremely flat cost lines, but often far away from the target line.
    # We want to minimize the zero, but we also don't want to flatten things.
    # We want to total pair cost 
    #return tf.math.sqrt(eps + tf.reduce_sum(total_pair_cost[:, 1] / (eps + tf.reduce_sum(total_pair_cost, axis=[1], keepdims=True)))) 

    # Instead of the worst violation, we directly predict the deviation from the correct x
    #return tf.math.sqrt(eps + tf.reduce_mean(total_pair_cost))

# 10. Set up Keras model and run

def get_keras_model():
    """Build the Keras training model.

    Six inputs are declared. Model input order: left/right images, left/right V4
    strengths, left/right complex V1 responses. Only the V4 strengths reach the output,
    through `get_pair_violation_from_strength`. The images are kept for visualization
    and the V1 responses for the alternative `get_pair_violation` path.

    Returns:
        tf.keras.Model whose single output is the "total_violation" loss tensor.
    """
    image_shape = (n_sample_distances, box_height, box_width)
    v1_shape = (n_sample_distances, n_v1_scales, n_v1_orientations, box_height, box_width)
    v4_shape = (n_sample_distances, n_v1_scales, 2*n_v1_orientations, box_height, box_width)

    # Translated raw glyph images (visualization only).
    left_image = tf.keras.Input(shape=image_shape, name='left_image')
    right_image = tf.keras.Input(shape=image_shape, name='right_image')

    # Translated complex V1 responses and precomputed V4 strengths.
    left_v1_response = tf.keras.Input(shape=v1_shape, name='left_v1_response', dtype=tf.complex64)
    right_v1_response = tf.keras.Input(shape=v1_shape, name='right_v1_response', dtype=tf.complex64)
    left_v4_strength = tf.keras.Input(shape=v4_shape, name='left_v4_strength')
    right_v4_strength = tf.keras.Input(shape=v4_shape, name='right_v4_strength')

    lvs = tf.identity(left_v4_strength, "lvs")
    print("IN:", left_v4_strength.shape, lvs.shape)

    violation = get_pair_violation_from_strength(left_v4_strength, right_v4_strength)
    # The zero-weighted term only ties `lvs` into the output graph (no effect for finite values).
    total_violation = tf.identity(violation, "total_violation") + 0. * tf.reduce_sum(lvs)

    model_inputs = [left_image, right_image,
                    left_v4_strength, right_v4_strength,
                    left_v1_response, right_v1_response]
    return tf.keras.Model(inputs=model_inputs, outputs=total_violation)

class MonitorProgressCallback(tf.keras.callbacks.Callback):
    """Periodically visualizes intermediate model tensors during training.

    Every `interval` epochs, one batch is pulled from `dataset`. For each sample distance,
    the "skeleton_loss_cost_image" tensor of the first example is drawn over the
    corresponding glyph-pair image.

    Args:
        dataset: iterator yielding `(features, label)` pairs, where `features` is a dict
            keyed by the model's input names (e.g. `prepared_dataset.as_numpy_iterator()`).
        interval: plot on epochs where `epoch % interval == 0`.
        show_v4_diagnostics: additionally run `_plot_v4_diagnostics` (the formerly
            unreachable relevance / V4-response plots). Off by default.
    """

    def __init__(self, dataset, interval, show_v4_diagnostics=False):
        super(MonitorProgressCallback, self).__init__()
        self.dataset = dataset
        self.interval = interval
        self.show_v4_diagnostics = show_v4_diagnostics
        self.current_data = None  # features dict of the last batch pulled from `dataset`

    def _find_layer(self, name):
        """Return the first model layer whose name contains `name` as a substring."""
        for layer in self.model.layers:
            if name in layer.name:
                return layer
        raise ValueError("No model layer whose name contains {!r}".format(name))

    def get_val(self, name):
        """Evaluate the output of the layer matching `name` on `self.current_data`.

        Inputs are fed in the order of `self.model.input_names`, so the feed always lines
        up with `self.model.inputs`. Previously a hard-coded order was used that did not
        match the model's input order.
        """
        layer = self._find_layer(name)
        feed = [self.current_data[input_name] for input_name in self.model.input_names]
        return tf.keras.backend.function(self.model.inputs, [layer.output])(feed)[0]

    def get_weights(self, name):
        """Return `get_weights()` of the first layer whose name contains `name`."""
        return self._find_layer(name).get_weights()

    def print_weights(self, name):
        """Call `print_weights()` on the first layer whose name contains `name`."""
        self._find_layer(name).print_weights()

    def on_epoch_end(self, epoch, logs=None):
        """Every `interval` epochs, plot the skeleton-loss cost image for each sample distance."""
        if epoch % self.interval != 0:
            return

        self.current_data, _ = next(self.dataset)
        pair_images = self.current_data["left_image"] + self.current_data["right_image"]

        print("Skeleton losses:")
        local_skeleton_loss_cost = self.get_val("skeleton_loss_cost_image")
        for nb in range(n_sample_distances):
            plt.imshow(pair_images[0, nb, :, :], cmap="gray")
            plt.imshow(local_skeleton_loss_cost[0, nb, :, :], alpha=.8)
            plt.colorbar()
            plt.show()

        if self.show_v4_diagnostics:
            self._plot_v4_diagnostics(pair_images)

    def _plot_v4_diagnostics(self, pair_images):
        """Plot relevance / V4-response panels, layer weights and the loss-layer cost curve.

        This is the code that previously sat, unreachable, after an early `return` in
        `on_epoch_end`. It looks up tensors/layers named "relevance", "pair_v4_response",
        "v4_layer", "grouping_strength_layer", "loss_layer" and "skeleton_losscosts". The
        named tensors and (presumably auto-named) layers come from `get_pair_violation`,
        which the current `get_keras_model` does not call, so with that model
        `_find_layer` raises.
        NOTE(review): "skeleton_losscosts" is not created by `get_pair_violation` as
        written (its definition there is commented out); confirm before enabling.
        """
        fs = 16
        figsize = (fs * 1 * box_width / 100, fs * n_v4_scales * box_height / 100)

        pair_diff = self.get_val("relevance")
        ex = max(abs(tf.reduce_max(pair_diff)), abs(tf.reduce_min(pair_diff)))
        pair_v4_response = self.get_val("pair_v4_response")

        # Relevance at the middle sample distance, on a symmetric colour scale.
        print("RAW LOSSES for size")
        fig, ax = plt.subplots(1, n_v4_scales, gridspec_kw={'wspace': 0, 'hspace': 0}, figsize=figsize)
        if n_v4_scales > 1:
            for j in range(n_v4_scales):
                ax[j].imshow(pair_images[0, 1, :, :], alpha=1)
                ax[j].imshow(pair_diff[0, 1, j, :, :], alpha=0.7, vmin=-ex, vmax=ex)
        else:
            ax.imshow(pair_images[0, 1, :, :], alpha=1)
            ax.imshow(pair_diff[0, 1, 0, :, :], alpha=0.7, vmin=-ex, vmax=ex)
        plt.show()

        # Relevance-weighted pair response for sample distances 0 (thin), 1 and 2 (wide).
        gain = pair_diff * pair_v4_response
        for nb in range(3):
            fig, ax = plt.subplots(1, n_v4_scales, gridspec_kw={'wspace': 0, 'hspace': 0}, figsize=figsize)
            print("RAW V4 PAIR TIMES RELEVANCE for size, without min/max limit")
            if n_v4_scales > 1:
                for j in range(n_v4_scales):
                    print("Scale", j, "max gain:", tf.reduce_max(tf.nn.relu(gain[0, nb, j, :, :])),
                          "total gain:", tf.reduce_sum(tf.nn.relu(gain[0, nb, j, :, :])))
                    ax[j].imshow(pair_images[0, nb, :, :], alpha=1)
                    ax[j].imshow(gain[0, nb, j, :, :], alpha=0.7)
            else:
                ax.imshow(pair_images[0, nb, :, :], alpha=1)
                ax.imshow(tf.nn.relu(-pair_diff[0, nb, 0, :, :]), alpha=0.7)
            plt.show()

        # Raw pair V4 response at the middle sample distance.
        fig, ax = plt.subplots(1, n_v4_scales, gridspec_kw={'wspace': 0, 'hspace': 0}, figsize=figsize)
        print("PAIR v4 response for size, without min/max limit, 0-3")
        if n_v4_scales > 1:
            for j in range(n_v4_scales):
                print("Channel", j, "max", tf.reduce_max(pair_v4_response[0, 1, j, :, :]), "sum", tf.reduce_sum(pair_v4_response[0, 1, j, :, :]))
                ax[j].imshow(pair_images[0, 1, :, :], alpha=1)
                ax[j].imshow(pair_v4_response[0, 1, j, :, :], alpha=0.8)
        else:
            ax.imshow(pair_images[0, 1, :, :], alpha=1)
            ax.imshow(pair_v4_response[0, 1, 0, :, :], alpha=0.8)
        plt.show()

        self.print_weights("v4_layer")
        self.print_weights("grouping_strength_layer")
        self.print_weights("loss_layer")

        # LossLayer weights in creation order: [target_strength, target_grouping, skeleton_loss_weight].
        loss_weights = self.get_weights("loss_layer")
        target_strength = loss_weights[0]
        skeleton_loss_weight = 10 ** loss_weights[2]
        print("Skeleton loss weight, extracted:", skeleton_loss_weight)
        plt.plot(np.transpose(self.get_val("skeleton_losscosts") - target_strength) ** 2)
        plt.show()
    


# Build, compile and train. The model output already is the loss, so the compiled loss
# ignores y_true. A truthy `testing` runs a smoke test: zero learning rate, one 1-step
# epoch, plots every epoch.
testing = 0

model = get_keras_model()
model.compile(
    loss=(lambda y_true, violation: violation),
    optimizer=tf.keras.optimizers.Adam(.005 if not testing else 0.0),
)

prepared_dataset = (
    translated_dataset
    .shuffle(10*batch_size)
    .batch(batch_size)
    .prefetch(batch_size)
)

history = model.fit(
    prepared_dataset,
    callbacks=[MonitorProgressCallback(prepared_dataset.as_numpy_iterator(), 1 if testing else 12)],
    epochs=(1 if testing else 1000),
    steps_per_epoch=(1 if testing else 3),
    use_multiprocessing=False,
)

In [None]:
# PARAMETERS

# Arguments for make_v4_filters below: kk is its first argument; cp and cn its last two.
kk = 2.8
cp = 0.2
cn = 0.25

# Pooling weights used by g_response. Scale weights have shape <v1 scale, v4 channel>.
# Orientation weights have shape <2 * v1 orientations, v4 channel>.
# CSF-shaped alternative (one weight per V1 scale, so it assumes n_v1_scales == 5):
# v1_v4_scale_weights = tf.tile(tf.constant([.04, .1, .13, .04, .01])[:, None], [1, n_v4_scales]) # includes the CSF
v1_v4_scale_weights = tf.ones((n_v1_scales, n_v4_scales)) # Flat weighting, used instead of the CSF.
v1_v4_orientation_weights = tf.ones((2*n_v1_orientations, n_v4_scales))

k = tf.constant([2.4] * n_v4_scales)
b = (0.7 ** (2.4 * tf.range(n_v4_scales, dtype=tf.float32)))

# V4 filters and their ifftshift-ed 2-D spectra (consumed by g_response). Uses kk rather
# than repeating its literal value.
v4_filters = make_v4_filters(kk, 0, 0, 0, 0, cp, cn)[None, ...]
v4_filters_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(v4_filters, 0.), [3, 4]))

# Skeleton loss penalty. NOTE(review): 8 entries each, presumably one per V4 scale;
# TODO confirm n_v4_scales == 8.
blurwidths = tf.constant([0.0025, 0.002, 0.0026, 0.0067, 0.003, 0.0001, 0.0001, 0.0001])
loss_penalties = tf.constant([0.5, 1.6, 1.5, 0.37, 0.1, 0.01, 0.01, 0.01])
skeleton_loss_target = 10

#channel_attention_attraction_rate = tf.constant([.15, .18, .25, .48, .25, .14, .05, 0.002])[:, None, None]
#g = tf.constant([1.07, 1.122, 1.156, 1.158, 1.162, 1.152, 1.159, 1.143])[:, None, None]
#beta = tf.constant([0., 0.04, 0.06, 0.05, 0.135, 0.126, 0.117, 0.1])[:, None, None]
#w_local = 0.06 * tf.ones((n_v4_scales, n_v4_scales)) # actually a whole thing
#w_global = 0.04 * tf.ones((n_v4_scales, n_v4_scales)) # actually a whole thing
#fk = tf.constant([1.001, 1.055, 1.08, 1.041, 1.08, 1.091, 1.09, 1.08])[:, None, None]
#fb = tf.constant([0.24, 0.25, .332, .47, .32, .2, .12, .08])[:, None, None]
#yfn = tf.ones((n_v4_scales, box_height))[:, :, None]

def g_response(v1c):
    """Pool V1 complex-cell magnitudes into per-channel G-cell activation maps.

    The orientation axis (axis 1) is duplicated. The two spatial axes are zero-padded by
    ceil(size/2) before and int(size/2) after. The result is convolved via FFT with the
    global `v4_filters_fft`. The rectified V4 responses are then collapsed over scales
    with `v1_v4_scale_weights` and over orientations with `v1_v4_orientation_weights`.

    Args:
        v1c: 4-D real tensor (the padding spec has four entries); presumably
            <scale, orientation, height, width>. TODO confirm.

    Returns:
        Tensor <channel, padded_h, padded_w> (einsum "ochw,oc->chw"). The maps are NOT
        cropped back to the box; callers such as pair_at_distance crop them.
    """
    pad_top, pad_bottom = int(np.ceil(box_height / 2)), int(box_height / 2)
    pad_left, pad_right = int(np.ceil(box_width / 2)), int(box_width / 2)

    doubled = tf.concat([v1c, v1c], axis=1)
    padded = tf.pad(doubled, [[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]], mode='CONSTANT')

    # FFT convolution with the precomputed V4 filters; a new filter axis is broadcast in at axis 2.
    spectrum = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(padded, 0.), [2, 3]))
    convolved = tf.signal.ifft2d(v4_filters_fft * spectrum[:, :, None, :, :])
    v4_by_scale_orientation = tf.nn.relu(tf.math.real(tf.signal.fftshift(convolved, [3, 4])))

    v4_by_orientation = tf.einsum("sochw,sc->ochw", v4_by_scale_orientation, v1_v4_scale_weights)
    return tf.einsum("ochw,oc->chw", v4_by_orientation, v1_v4_orientation_weights)

def _crop_to_glyph_box(x):
    """Crop axes 1 and 2 of ``x`` from the padded canvas back to the glyph box.

    Undoes the half-box padding applied in g_response (ceil(h/2) rows before,
    ceil(w/2) columns before), keeping ``box_height`` rows and ``box_width``
    columns. Axis 0 is kept whole.
    """
    rows = slice(int(np.ceil(box_height / 2)), int(box_height + np.ceil(box_height / 2)))
    cols = slice(int(np.ceil(box_width / 2)), int(box_width + np.ceil(box_width / 2)))
    return x[:, rows, cols]


def pair_at_distance(d, l, r, debug_glyph="c"):
    """Score a glyph pair placed at distance ``d`` against the skeleton-loss target.

    Each glyph's V1 response is translated to its position in the pair. The
    G-cell responses of left, right and combined pair are computed. The
    "skeleton loss" is where either single glyph responds more strongly than
    the pair (ReLU of the difference). That loss is blurred per channel with
    ``make_losses_filters(eps + relu(blurwidths))`` and weighted by
    ``loss_penalties``.

    Args:
        d: Pair distance, forwarded to ``get_pair_translations``.
        l: Key of the left glyph in ``glyph_v1_responses`` / ``glyph_images``.
        r: Key of the right glyph.
        debug_glyph: When ``l == debug_glyph``, overlay the unblurred skeleton
            losses on the translated glyph pair and show the plot. Defaults to
            ``"c"``, the previously hard-coded behaviour. Pass ``None`` to never
            plot. Note: this fires on every optimizer evaluation.

    Returns:
        Scalar TF tensor ``(max(weighted blurred losses) - skeleton_loss_target) ** 2``.
    """
    (lt, rt) = get_pair_translations(l, r, d)
    v1_l = translate_4d_image(glyph_v1_responses[l], np.array([lt]))[0, ...]
    v1_r = translate_4d_image(glyph_v1_responses[r], np.array([rt]))[0, ...]
    v1_p = v1_l + v1_r

    v1c_l, v1c_r, v1c_p = tf.abs(v1_l), tf.abs(v1_r), tf.abs(v1_p)
    v4g_l, v4g_r, v4g_p = g_response(v1c_l), g_response(v1c_r), g_response(v1c_p)

    # Skeleton losses: response a single glyph has that the pair does not.
    losses = tf.nn.relu(v4g_l - v4g_p) + tf.nn.relu(v4g_r - v4g_p)

    # Blur each loss channel by frequency-domain multiplication with its filter,
    # then crop back to the glyph box. NOTE(review): the imaginary part is `eps`
    # rather than 0 here (unlike g_response); the effect is on the order of
    # eps**2, and it presumably avoids a numerical edge case — TODO confirm intent.
    losses_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(losses, eps), [1, 2]))
    losses_filters = make_losses_filters(eps + tf.nn.relu(blurwidths))[0, 0, ...]
    losses_filters_fft = tf.signal.fft2d(tf.signal.ifftshift(tf.complex(losses_filters, eps), [1, 2]))
    losses_blurred = tf.math.real(tf.signal.fftshift(tf.signal.ifft2d(losses_fft * losses_filters_fft), [1, 2]))
    losses_filtered = _crop_to_glyph_box(losses_blurred)

    penalized_losses = losses_filtered * loss_penalties[:, None, None]

    loss_penalty = (tf.reduce_max(penalized_losses) - skeleton_loss_target) ** 2

    if debug_glyph is not None and l == debug_glyph:
        # The glyph bitmaps are only needed for this debug overlay, so translate
        # them here rather than on every objective evaluation.
        g_l = translate_4d_image(glyph_images[l][None, None, :, :], np.array([lt]))[0, 0, 0, :, :]
        g_r = translate_4d_image(glyph_images[r][None, None, :, :], np.array([rt]))[0, 0, 0, :, :]
        plt.imshow(g_l + g_r, cmap="gray")
        plt.imshow(tf.reduce_sum(_crop_to_glyph_box(losses), [0]), alpha=0.8)
        plt.colorbar()
        plt.title("Skeleton losses for {}{} at d = {:.2f}".format(l, r, float(d)))
        plt.show()

    # Disabled attention/grouping-feedback model, kept for reference:
#    relevance = 1. - (tf.abs(v4g_l - v4g_r)/(eps + v4g_l + v4g_r)) ** 1.3
#
#    attention_attraction_rates = v4g_p * channel_attention_attraction_rate * relevance * yfn / (eps + tf.reduce_min(channel_attention_attraction_rate))
#
#    attention_attraction_pool_local = tf.einsum("chw,qc->qhw", attention_attraction_rates ** g, w_local)
#    attention_attraction_pool_global = tf.einsum("chw,qc->q", attention_attraction_rates ** g, w_global)[:, None, None]
#    probability_of_grouping_feedback = attention_attraction_rates ** g / (beta ** g + attention_attraction_pool_local + attention_attraction_pool_global)
#
#    feedback_strength = v4g_p ** fk * fb
#
#    mean_grouping_strength = tf.reduce_sum(probability_of_grouping_feedback * feedback_strength)
    return loss_penalty

def find_best_distance(l, r, maxiter=1):
    """Search for the distance of pair (l, r) that minimizes ``pair_at_distance``.

    Uses ``scipy.optimize.minimize_scalar`` with its default (Brent) method and
    no bracket. The result is printed, as before, and now also returned.

    Args:
        l: Key of the left glyph.
        r: Key of the right glyph.
        maxiter: Brent iteration limit. Defaults to 1, matching the previous
            hard-coded value. That only barely refines the initial bracket, so
            raise it for a real fit.

    Returns:
        The distance ``minimize_scalar`` reports (``OptimizeResult.x``).
    """
    def objective(d):
        # minimize_scalar compares and combines objective values arithmetically;
        # give it a plain Python float instead of a TF scalar tensor.
        return float(pair_at_distance(d, l, r))

    result = minimize_scalar(objective, options={"maxiter": maxiter})
    best_distance = result.x

    print("Best distance for", l, r, ":", best_distance)
    return best_distance


# Glyph pairs to fit, in order. Uncomment entries to include more pairs.
PAIRS_TO_FIT = [
    ("n", "n"),
    # ("l", "l"),
    # ("p", "o"),
    # ("d", "b"),
    # ("d", "h"),
    ("c", "x"),
    # ("m", "i"),
    # ("a", "a"),
]
for left_glyph, right_glyph in PAIRS_TO_FIT:
    find_best_distance(left_glyph, right_glyph)
