In [None]:
import itertools as it
import numpy as np
import jax.numpy as jnp
import jax.random as random
import ginjax.geometric as geom
import ginjax.utils as utils
%load_ext autoreload
%autoreload 2

In [None]:
# Configuration: spatial dimension D, side length N (used both for the
# invariant filters and for the image grid below), the tensor order of the
# input image, and the largest filter tensor order that is searched.
D, N = 2, 3
img_k, max_k = 1, 2

# Build the operators for dimension D (presumably the symmetry group used to
# define "invariant" filters -- confirm in ginjax) and report how many there are.
group_operators = geom.make_all_operators(D)
print(len(group_operators))

In [None]:
# Build every invariant filter for side length N, tensor orders 0..max_k and
# parities 0/1, returned as a dict keyed by 4-tuples (named D, M, k, parity
# below), together with `maxn`, which plot_filters consumes per (D, M).
allfilters, maxn = geom.get_invariant_filters(
    [N],
    range(max_k + 1),
    [0, 1],
    D,
    group_operators,
    scale='one',
    return_type='dict',
    return_maxn=True,
)

# Plot each family of filters. The key is unpacked into fresh names so the
# loop does not overwrite the notebook-level `D` (used by later cells) or `key`.
for (filt_D, filt_M, filt_k, filt_parity), filters in allfilters.items():
    names = ["{} {}".format(geom.tensor_name(filt_k, filt_parity), i) for i in range(len(filters))]
    utils.plot_filters(filters, names, maxn[(filt_D, filt_M)])

In [None]:
# Under contractions, the 3 scalar filters are the same as the first 3 k=2 filters. Thus we ignore them.
# Build a NEW dict instead of aliasing `allfilters`: assigning
# `trimmed_filters = allfilters` and then mutating it would silently empty the
# scalar families in `allfilters` too, so re-running the plotting cell would
# no longer show them. Overriding keys in a dict display keeps the original
# key order (new keys, if absent, go at the end), matching the old mutation.
trimmed_filters = {
    **allfilters,
    (D, N, 0, 0): [],
    (D, N, 0, 1): [],
}

# Flatten every remaining filter family into a single list.
filter_list = list(it.chain.from_iterable(trimmed_filters.values()))
print(len(filter_list))

In [None]:
# Make a batch of random, parity=0 geometric images with side length N and
# tensor order img_k on a D-torus. The batch size depends on N.
# NOTE(review): the rationale for these batch sizes (3 for N=3, 7 for N=5) is
# not stated here -- confirm and document.
NUM_IMAGES_BY_N = {3: 3, 5: 7}
if N not in NUM_IMAGES_BY_N:
    # Fail loudly here instead of with a NameError on `num_images` below.
    raise ValueError(
        f"no batch size configured for N={N}; expected one of {sorted(NUM_IMAGES_BY_N)}"
    )
num_images = NUM_IMAGES_BY_N[N]

key = random.PRNGKey(0)
key, subkey = random.split(key)
# (batch,) + D spatial axes of length N + img_k tensor axes of length D
img_shape = (num_images,) + (N,) * D + (D,) * img_k
geom_img = geom.BatchGeometricImage(random.normal(subkey, shape=img_shape), 0, D)

In [None]:
def quadratic_filter(img, c1, c2, c3):
    """Apply the quadratic filter ((img conv c1) * (img conv c2)) conv c3.

    Args:
        img: an image object exposing `convolve_with` and supporting `*`
            with another image of its kind.
        c1, c2: filters applied to `img` before the product (c1's response is
            the left operand of `*`).
        c3: filter applied to the product.

    Returns:
        The result of convolving the product with `c3`.
    """
    left = img.convolve_with(c1)
    right = img.convolve_with(c2)
    product = left * right
    return product.convolve_with(c3)

In [None]:
def getVectorImgs(vector_image, filters=None, verbose=False):
    """Apply every admissible quadratic filter to `vector_image` and stack the results.

    For each triple (c1, c2, c3) produced by `it.combinations` over `filters`,
    the triple is kept only when both the summed tensor order and the summed
    parity of (c1, c2, c3, vector_image) are even. For each kept triple, the
    image is passed through `quadratic_filter`. Every contraction pattern from
    `geom.get_contraction_indices` is then applied. Each contracted image must
    have the same shape as `vector_image` and is flattened into one output row.

    Args:
        vector_image: the geometric image to filter; must expose `.k`,
            `.parity` and `.shape()`.
        filters: sequence of filters exposing `.k` and `.parity`. Defaults to
            the notebook-level `filter_list` (the original behavior).
        verbose: if True, print the indices of each admissible triple
            (previously printed unconditionally).

    Returns:
        A jnp array with one flattened contracted image per row, or an empty
        array of shape (0,) when no triple is admissible.
    """
    if filters is None:
        filters = filter_list

    rows = []
    # NOTE(review): it.combinations yields only index-sorted triples of
    # DISTINCT filters, so c1 == c2 (a squared response) is never tried and c3
    # is always the highest-index filter of the triple. Confirm that this
    # restriction is intended.
    for (i1, c1), (i2, c2), (i3, c3) in it.combinations(enumerate(filters), 3):
        # conditions suitable for a sequence of kronecker contractions:
        # summed tensor order and summed parity must both be even
        if (c1.k + c2.k + c3.k + vector_image.k) % 2 != 0:
            continue
        if (c1.parity + c2.parity + c3.parity + vector_image.parity) % 2 != 0:
            continue

        if verbose:
            print(i1, i2, i3)
        img = quadratic_filter(vector_image, c1, c2, c3)

        for idxs in geom.get_contraction_indices(img.k, vector_image.k):
            img_contracted = img.multicontract(idxs)
            assert img_contracted.shape() == vector_image.shape()
            rows.append(img_contracted.data.flatten())

    return jnp.array(rows)

In [None]:
# One row per (admissible filter triple, contraction pattern) applied to geom_img.
datablock = getVectorImgs(vector_image=geom_img)

In [None]:
# Compare the total number of rows with the number of distinct rows.
total_shape = datablock.shape
print(total_shape)
unique_shape = jnp.unique(datablock, axis=0).shape
print(unique_shape)

In [None]:
# Count the singular values of the deduplicated data block above a tolerance,
# i.e. its numerical rank (the number of linearly independent rows). Only the
# singular values are needed, so skip computing the U and V factors, which
# by default are the full square matrices.
# NOTE(review): 100*geom.TINY is an absolute tolerance; confirm it suits the
# magnitude of `datablock` (a relative tolerance may be more robust).
singular_values = jnp.linalg.svd(jnp.unique(datablock, axis=0), compute_uv=False)
print("there are", np.sum(singular_values > 100*geom.TINY), "different images")