In [None]:
import torch

In [None]:
from perm_equivariant_seq2seq.equivariant_models import EquiSeq2Seq
from perm_equivariant_seq2seq.data_utils import get_scan_split, get_equivariant_scan_languages
from perm_equivariant_seq2seq.symmetry_groups import get_permutation_equivariance, CircularShift
from perm_equivariant_seq2seq.utils import tensors_from_pair, tensor_from_sentence

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
train_pairs, test_pairs = get_scan_split(split='add_jump')
# Input-side equivariant words; the matching output-side words are simply
# their upper-case forms.
in_equivariances = ['jump', 'run', 'walk', 'look']
out_equivariances = [word.upper() for word in in_equivariances]

In [None]:
# Build the command (input) and action (output) languages from the training
# pairs, marking the words chosen above as equivariant.
equivariant_commands, equivariant_actions = get_equivariant_scan_languages(
    pairs=train_pairs,
    input_equivariances=in_equivariances,
    output_equivariances=out_equivariances,
)

In [None]:
# Permutation-equivariance groups built from the command and action languages.
input_symmetry_group, output_symmetry_group = (
    get_permutation_equivariance(equivariant_commands),
    get_permutation_equivariance(equivariant_actions),
)

In [None]:
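# Make sure that these groups actually satisfy the group axioms:
# 1. check that the identity is correct
# 2. check that the inverses are correct
# 3. check closure under composition
# (Associativity needs no check: it is inherited from matrix multiplication.)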

In [None]:
def check_identity(perm_group):
    """Assert that ``perm_group.e`` sits at index 0 and acts as a left identity.

    Args:
        perm_group: group object exposing ``e`` (the identity matrix) and a
            mapping-like ``index2mat`` from element index to matrix.

    Raises:
        AssertionError: if ``index2mat[0]`` is not exactly ``e``, or if
            ``e @ g != g`` (exact comparison) for some element ``g``.
    """
    identity = perm_group.e
    elements = perm_group.index2mat
    # The checker requires the identity to occupy slot 0 of the element index.
    assert torch.equal(elements[0], identity), "indexing of identity is incorrect"
    for key in elements:
        element = elements[key]
        # Left-multiplying by the identity must leave every element unchanged.
        assert torch.equal(identity @ element, element), "identity behavior incorrect"

In [None]:
def check_inverses(perm_group):
    """Assert that the group's inverse bookkeeping is consistent.

    For every element index ``idx`` this verifies that:

    * ``index2mat``, ``index2inverse`` and ``index2inverse_indices`` have the
      same number of entries;
    * for every ``idy``, ``index2inverse_indices[idx][idy]`` is the index of
      the product ``index2inverse[idx] @ index2mat[idy]``;
    * ``index2inverse[idx] @ index2mat[idx]`` equals the identity ``e``.

    Matrices are compared with ``torch.allclose`` at default tolerances.

    Args:
        perm_group: group object exposing ``e``, ``index2mat``,
            ``index2inverse`` and ``index2inverse_indices``.

    Raises:
        AssertionError: on the first inconsistency found; the message names
            the offending element index (or pair of indices).
    """
    mats = perm_group.index2mat
    inverses = perm_group.index2inverse
    inverse_indices = perm_group.index2inverse_indices
    assert len(mats) == len(inverses), "dictionary sizes inconsistent"
    assert len(inverses) == len(inverse_indices), "dictionary sizes inconsistent part 2"
    for idx in mats:
        inv = inverses[idx]
        inv_prods = inverse_indices[idx]
        for idy in mats:
            # inv_prods[idy] is a tensor holding the index of inv(g_idx) @ g_idy.
            expected = mats[inv_prods[idy].item()]
            assert torch.allclose(inv @ mats[idy], expected), (
                f"index2inverse_indices book-keeping incorrect (elements {idx}, {idy})"
            )
        assert torch.allclose(inv @ mats[idx], perm_group.e), (
            f"inverse behavior incorrect (element {idx})"
        )

In [None]:
def check_closure(perm_group):
    """Assert that the group is closed under matrix multiplication.

    For every ordered pair of elements ``(g_idx, g_idy)``, the product
    ``g_idx @ g_idy`` must match (``torch.isclose`` default tolerances,
    elementwise) some element of ``perm_group.index2mat``.

    Args:
        perm_group: group object exposing a mapping-like ``index2mat`` from
            element index to matrix; all matrices must share one shape and
            dtype (they are stacked into a single tensor).

    Raises:
        AssertionError: naming the first pair whose product is not in the group.
    """
    mats = perm_group.index2mat
    if len(mats) == 0:
        return  # nothing to check; torch.stack would reject an empty list
    # Stack the members once so each product is compared against the whole
    # group in one batched isclose instead of one call per member.
    members = torch.stack([mats[idz] for idz in mats])
    for idx in mats:
        for idy in mats:
            prod = mats[idx] @ mats[idy]
            # prod broadcasts against members: (n, d, d) -> per-member match flags (n,)
            matches = torch.isclose(prod, members).flatten(start_dim=1).all(dim=1)
            assert bool(matches.any()), (
                f"product of elements not in group (elements {idx}, {idy})"
            )

In [None]:
# Also run the group checks on a circular-shift group built over the same
# command vocabulary.
cyclic = CircularShift(
    num_letters=equivariant_commands.n_words,
    num_equivariant=equivariant_commands.num_equivariant_words,
    first_equivariant=equivariant_commands.num_fixed_words + 1,
)

In [None]:
def test_group(gp):
    """Run every group-axiom check on ``gp``.

    The identity is checked first, then inverses, then closure; the first
    failing check raises ``AssertionError``.
    """
    for check in (check_identity, check_inverses, check_closure):
        check(gp)

In [None]:
# Check the circular-shift group and both permutation-equivariance groups;
# any violated axiom raises AssertionError.
for _group in (cyclic, input_symmetry_group, output_symmetry_group):
    test_group(_group)

In [None]:
# Testing einsum code: compare the einsum contraction against the older
# broadcast-and-sum version (old_ver) on random inputs.
# Every axis has a distinct size (batch=4, |G|=3, K=2): with batch=1 a mix-up
# of the batch axis would go unnoticed. A seeded local generator keeps the
# printed values reproducible without touching the global RNG.
_einsum_rng = torch.Generator().manual_seed(0)
ipt = torch.randn(4, 3, 2, generator=_einsum_rng)  # batch x |G| x K
conv_filter = torch.randn(3, 3, 2, 2, generator=_einsum_rng)  # |G| x |G| x K x K
# expect output to have shape: batch x |G| x K

In [None]:
# out[b, g, l] = sum over h, k of ipt[b, h, k] * conv_filter[g, h, k, l]
use_ein = torch.einsum("bhk,ghkl->bgl", ipt, conv_filter)
# The expected output shape (batch x |G| x K) is documented above; enforce it.
assert use_ein.shape == (ipt.shape[0], conv_filter.shape[0], conv_filter.shape[3]), (
    f"unexpected einsum output shape {tuple(use_ein.shape)}"
)
print(use_ein)

In [None]:
# Reference implementation: broadcast both operands to
# batch x |G| x |G| x K x K, multiply, then sum out the h and k axes.
ip = ipt.unsqueeze(1).unsqueeze(-1)  # batch x 1 x |G| x K x 1
conv_fil = conv_filter.unsqueeze(0)  # 1 x |G| x |G| x K x K
old_ver = (ip * conv_fil).sum(2).sum(2)  # sum over h, then over k

print(old_ver)

In [None]:
# Fail loudly (e.g. under Restart & Run All) if the two implementations
# disagree, while still displaying the comparison result.
einsum_matches = torch.allclose(use_ein, old_ver)
assert einsum_matches, "einsum output disagrees with the broadcast-and-sum reference (old_ver)"
einsum_matches