# Make figures for the paper on geometric images

## Authors:
- **David W. Hogg** (NYU) (MPIA) (Flatiron)
- **Soledad Villar** (JHU)

## To-do items and bugs:
- Make plots (and maybe a LaTeX table) that illustrate the group $B_d$.
- Figure out a better way to plot in `D=3`.

In [None]:
# Pin this notebook to one GPU: number CUDA devices by PCI bus ID, then expose
# only device 3. These are set before jax is imported below, presumably so they
# take effect when JAX initializes its backend. Edit or remove on other machines.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=3

import matplotlib.pyplot as plt
import numpy as np

import jax.numpy as jnp
from jax import random

import ginjax.geometric as geom
import ginjax.utils as utils

# Automatically reload edited modules (e.g. ginjax) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Notebook-wide settings.
D = 2  # dimension of the images (number of image axes)
group_operators = geom.make_all_operators(D)
dpi = 300

plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': 'STIXGeneral',
})

# Flip save_plots to True to write the figures below as PDFs into save_folder.
save_folder = '../../images/paper_images/'
save_plots = False

In [None]:
# Plot the 3x3 invariant filters for tensor orders k = 0, 1, 2 and parities 0, 1.
N = 3
max_k = 2

allfilters_N3, maxn = geom.get_invariant_filters_dict(
    [N],
    range(max_k + 1),
    [0, 1],
    D,
    group_operators,
    scale='one',
)
# Grid width shared by every row below (presumably the largest number of
# filters in any (k, parity) group for this D and N).
maxlen = maxn[(D, N)]


def filter_labels(filter_list):
    """Label each filter as '<tensor name> <index within filter_list>'."""
    return [
        f'{geom.tensor_name(filt.k, filt.parity)} {idx}'
        for idx, filt in enumerate(filter_list)
    ]


filter_rows = [
    allfilters_N3[(D, N, 0, 0)],  # 3 scalar filters
    allfilters_N3[(D, N, 1, 0)] + allfilters_N3[(D, N, 1, 1)],  # 2 vector + 2 pseudovector filters
    allfilters_N3[(D, N, 2, 0)],  # 5 tensor filters
    allfilters_N3[(D, N, 2, 1)],  # 5 pseudotensor filters
]
for row in filter_rows:
    names = filter_labels(row)
    utils.plot_grid(row, names, maxlen)

In [None]:
# Plot the 5x5 invariant filters for tensor orders k = 0, 1 and parities 0, 1,
# one saved figure per row.
N = 5
max_k = 1

allfilters_N5, maxn = geom.get_invariant_filters_dict(
    [N], 
    range(max_k+1), 
    [0,1], 
    D, 
    group_operators, 
    scale='one', 
)

# Grid width passed to utils.plot_grid for this row.
maxlen = 7
# 6 scalar filters + 1 pseudoscalar filter. Each parity group is numbered from 0
# on its own, so take the indices from the groups themselves rather than a
# hard-coded range(6) + [0] (zip would silently mislabel or drop filters if the
# group sizes ever differed).
scalar_groups_N5 = (allfilters_N5[(D,N,0,0)], allfilters_N5[(D,N,0,1)])
filters_scalar_N5 = scalar_groups_N5[0] + scalar_groups_N5[1]
names = [
    f'{geom.tensor_name(image.k, image.parity)} {i}'
    for group in scalar_groups_N5
    for i, image in enumerate(group)
]
utils.plot_grid(filters_scalar_N5, names, maxlen)
if save_plots:
    plt.savefig(save_folder + 'filters_m5_row1.pdf')

maxlen = 6
# 6 vector filters
names = [f'{geom.tensor_name(image.k, image.parity)} {i}' for i, image in enumerate(allfilters_N5[(D,N,1,0)])]
utils.plot_grid(allfilters_N5[(D,N,1,0)], names, maxlen)
if save_plots:
    plt.savefig(save_folder + 'filters_m5_row2.pdf')

# 6 pseudovector filters
names = [f'{geom.tensor_name(image.k, image.parity)} {i}' for i, image in enumerate(allfilters_N5[(D,N,1,1)])]
fig = utils.plot_grid(allfilters_N5[(D,N,1,1)], names, maxlen)
if save_plots:
    plt.savefig(save_folder + 'filters_m5_row3.pdf')

In [None]:
# Make a sensible smooth scalar image on a 2-torus: standard-normal noise
# smoothed twice by the sum of the three 3x3 scalar filters.
N = 16
D = 2
key = random.PRNGKey(42)
image = random.normal(key, shape=(N,)*D)

# Parity 0 so this is a true scalar image (parity 1 would make it a
# pseudoscalar and flip the parity of every s * C_v, s * C_pv, s * C_ps image
# built from it below, contradicting their labels).
# NOTE(review): assumes GeometricImage's positional signature is
# (data, parity, D), as in ginjax — confirm.
scalar_image = geom.GeometricImage(image, 0, D)
smoothing_filter = allfilters_N3[(D,3,0,0)][0] + allfilters_N3[(D,3,0,0)][1] + allfilters_N3[(D,3,0,0)][2]
scalar_image = scalar_image.convolve_with(smoothing_filter).convolve_with(smoothing_filter)

In [None]:
# Example filters paired with their LaTeX labels: a 5x5 scalar, a 3x3 vector,
# a 3x3 pseudovector and a 5x5 pseudoscalar filter.
# Spec columns: (filter bank, N, k, parity, index within group, label).
filters = [
    (bank[(D, n, k, parity)][idx], label)
    for bank, n, k, parity, idx, label in (
        (allfilters_N5, 5, 0, 0, 4, r"C_{s}"),
        (allfilters_N3, 3, 1, 0, 0, r"C_{v}"),
        (allfilters_N3, 3, 1, 1, 1, r"C_{pv}"),
        (allfilters_N5, 5, 0, 1, 0, r"C_{ps}"),
    )
]

In [None]:
# monomials[1]: the normalized scalar image s, followed by s convolved with
# each example filter (each result normalized), with LaTeX labels.
monomials = {
    1: [(scalar_image.normalize(), r"s")]
    + [
        (scalar_image.convolve_with(filt).normalize(), r"s\ast " + label)
        for filt, label in filters
    ],
}

In [None]:
# The smoothed, normalized scalar image s.
fig = monomials[1][0][0].plot()
if save_plots:
    plt.savefig(f'{save_folder}scalar_img.pdf', bbox_inches='tight')

In [None]:
# s * C_s: the scalar image convolved with the scalar filter.
fig = monomials[1][1][0].plot()
if save_plots:
    plt.savefig(f'{save_folder}scalar_img_convolved.pdf', bbox_inches='tight')

In [None]:
# s * C_v: the scalar image convolved with the vector filter.
fig = monomials[1][2][0].plot()
if save_plots:
    plt.savefig(f'{save_folder}img_convolved_vector_ff.pdf', bbox_inches='tight')

In [None]:
# s * C_pv: the scalar image convolved with the pseudovector filter.
fig = monomials[1][3][0].plot()
if save_plots:
    plt.savefig(f'{save_folder}img_convolved_pseudovector_ff.pdf', bbox_inches='tight')

In [None]:
# s * C_ps: the scalar image convolved with the pseudoscalar filter.
fig = monomials[1][4][0].plot()
if save_plots:
    plt.savefig(f'{save_folder}img_convolved_pseudoscalar_ff.pdf', bbox_inches='tight')

In [None]:
# Scalar filter C_s. Taken from `filters` (not a repeated dict key/index) so
# this figure always shows the filter used for the s * C_s image.
fig = filters[0][0].plot()
if save_plots:
    plt.savefig(save_folder + 'scalar_filter.pdf', bbox_inches='tight')

In [None]:
# Vector filter C_v. Taken from `filters` (not a repeated dict key/index) so
# this figure always shows the filter used for the s * C_v image.
fig = filters[1][0].plot()
if save_plots:
    plt.savefig(save_folder + 'vector_filter.pdf', bbox_inches='tight')

In [None]:
# Pseudovector filter C_pv. Taken from `filters` (not a repeated dict
# key/index) so this figure always shows the filter used for the s * C_pv image.
fig = filters[2][0].plot()
if save_plots:
    plt.savefig(save_folder + 'pseudovector_filter.pdf', bbox_inches='tight')

In [None]:
# Pseudoscalar filter C_ps. Taken from `filters` (not a repeated dict
# key/index) so this figure always shows the filter used for the s * C_ps image.
fig = filters[3][0].plot()
if save_plots:
    plt.savefig(save_folder + 'pseudoscalar_filter.pdf', bbox_inches='tight')

In [None]:
# Plot the action of B_2 on a vector

def plot_vec(original_arrow, rotated_arrow, title, ax):
    """Draw a 2-D vector and its transformed image as arrows from the origin.

    The axes show the square [-1, 1] x [-1, 1]. The left and bottom spines are
    moved to pass through zero, and the right and top spines are hidden, so the
    frame reads as a pair of coordinate axes. All tick marks are removed.

    Args:
        original_arrow: (x, y) components of the untransformed vector; drawn
            with matplotlib's default arrow colours.
        rotated_arrow: (x, y) components of the transformed vector; drawn with
            a red face and black edge.
        title: title for the axes (e.g. the name of the group element).
        ax: matplotlib Axes to draw into.
    """
    ax.set_title(title)
    # Left/bottom spines through the origin; right/top spines hidden.
    ax.spines['left'].set_position('zero')
    ax.spines['right'].set_color('none')
    ax.spines['bottom'].set_position('zero')
    ax.spines['top'].set_color('none')

    # Remove all ticks (and hence tick labels) from both axes.
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)

    # Shared head geometry so the two arrows differ only in colour.
    arrow_style = dict(
        length_includes_head=True,
        head_width=0.24 * 0.33,
        head_length=0.72 * 0.33,
    )
    ax.arrow(0, 0, original_arrow[0], original_arrow[1], **arrow_style)
    ax.arrow(
        0,
        0,
        rotated_arrow[0],
        rotated_arrow[1],
        facecolor='red',
        edgecolor='black',
        **arrow_style,
    )

# Display order for the 8 operators: identity, the three rotations, then the
# four reflections (matching `names` below).
# NOTE(review): these indices assume the ordering returned by
# geom.make_all_operators(2) — recheck if that function changes.
sorted_operators = np.stack(group_operators)[[0,5,3,6,1,2,7,4]]
original_arrow = jnp.array([2,1])/jnp.linalg.norm(jnp.array([2,1]))
rotated_arrows = [gg @ original_arrow for gg in sorted_operators]
# Raw strings so the LaTeX '\circ' is never parsed as a (invalid) escape sequence.
names = [
    'Identity',
    r'Rot $90^{\circ}$',
    r'Rot $180^{\circ}$',
    r'Rot $270^{\circ}$',
    'Flip X',
    'Flip Y',
    r'Rot $90^{\circ}$, Flip X',
    r'Rot $270^{\circ}$, Flip X',
]

num_rows = 2
num_cols = 4
bar = 8.  # figure width in inches (matplotlib figsize units)
fig, axes = plt.subplots(
    num_rows,
    num_cols,
    figsize=(bar, 1.15 * bar * num_rows / num_cols),  # height factor tuned by eye
    squeeze=False,
)
axes = axes.flatten()
plt.subplots_adjust(left=0.001/num_cols, right=1-0.001/num_cols, wspace=0.5/num_cols,
                    bottom=0.001/num_rows, top=1-0.1/num_rows, hspace=0.5/num_rows)

# One panel per group element; fail loudly rather than let zip() drop panels.
if not len(rotated_arrows) == len(names) == len(axes):
    raise ValueError(
        f'expected matching counts, got {len(rotated_arrows)} operators, '
        f'{len(names)} names and {len(axes)} axes'
    )
for ax, rotated_arrow, name in zip(axes, rotated_arrows, names):
    plot_vec(original_arrow, rotated_arrow, name, ax)
    
