In [None]:
import torch

In [None]:
# from https://www.geeksforgeeks.org/python-import-from-sibling-directory/
import sys
 
# Make the current working directory importable. Note that "." is the
# working directory itself, not its parent. The membership guard keeps this
# cell idempotent: re-running it no longer stacks duplicate "." entries onto
# sys.path.
if "." not in sys.path:
    sys.path.append(".")

In [None]:
# Run from the project checkout: the checkpoint paths loaded further down
# ("models/...") are relative to the working directory.
# NOTE(review): hardcoded, machine-specific location; adjust for your setup.
%cd ~/gordon

In [None]:
from perm_equivariant_seq2seq.equivariant_models import EquiSeq2Seq
from perm_equivariant_seq2seq.data_utils import get_scan_split, get_equivariant_scan_languages
from perm_equivariant_seq2seq.symmetry_groups import get_permutation_equivariance
from perm_equivariant_seq2seq.utils import tensors_from_pair, tensor_from_sentence

In [None]:
# Checkpoint key: "model3" was trained with alternating permutations for
# 400,000 iterations; "model1" was trained with circle shifts.
#
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the train/test pairs for the 'add_jump' split.
train_pairs, test_pairs = get_scan_split(split='add_jump')

# Words handed on as the input-side and output-side equivariances when the
# two languages are built.
in_equivariances = [
    'jump',
    'run',
    'walk',
    'look',
]
out_equivariances = [
    'JUMP',
    'RUN',
    'WALK',
    'LOOK',
]

In [None]:
# Build the command and action languages from the training pairs, then get
# the permutation equivariance for each of them.
equivariant_commands, equivariant_actions = get_equivariant_scan_languages(
    pairs=train_pairs,
    input_equivariances=in_equivariances,
    output_equivariances=out_equivariances,
)
input_symmetry_group = get_permutation_equivariance(equivariant_commands)
output_symmetry_group = get_permutation_equivariance(equivariant_actions)

# Architecture settings passed to EquiSeq2Seq. They mirror the checkpoint
# directory name used in this notebook ("rnn_GLSTM_hidden_64_directions_2").
hidden_size, layer_type = 64, 'GLSTM'  # both are the defaults
use_attention = bidirectional = True

In [None]:
# Saved weights of the alternating-permutation model ("model3").
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was written from a GPU run still loads when this notebook falls back to CPU.
ap_state_dict = torch.load(
    "models/add_jump/rnn_GLSTM_hidden_64_directions_2/model3/model_fully_trained.pt",
    map_location=device,
)

In [None]:
# Construct the network; its saved weights are loaded in a later cell.
# The same hidden size is used for the encoder and the decoder.
ap_model = EquiSeq2Seq(
    input_language=equivariant_commands,
    output_language=equivariant_actions,
    input_symmetry_group=input_symmetry_group,
    output_symmetry_group=output_symmetry_group,
    encoder_hidden_size=hidden_size,
    decoder_hidden_size=hidden_size,
    layer_type=layer_type,
    use_attention=use_attention,
    bidirectional=bidirectional,
)

In [None]:
# Copy the saved "model3" weights into the constructed network.
ap_model.load_state_dict(ap_state_dict)

In [None]:
# Move the model onto `device` and switch it to evaluation mode. Both calls
# return the module, so the cell still ends by displaying the model.
ap_model.to(device).eval()

In [None]:
# Saved weights of the circle-shift model ("model1").
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was written from a GPU run still loads when this notebook falls back to CPU.
cs_state_dict = torch.load(
    "models/add_jump/rnn_GLSTM_hidden_64_directions_2/model1/model_fully_trained.pt",
    map_location=device,
)

In [None]:
SOS_token = 0  # hard-coded in language
EOS_token = 1

# Single-example check: run one command through the model and print the
# trimmed prediction next to the target tensor.
pair = ['jump', 'JUMP']
input_t, output_t = tensors_from_pair(pair, equivariant_commands, equivariant_actions)
model_sentence = ap_model(input_t)

# Index of the largest entry along the last dimension at every position.
_, sentence_ints = model_sentence.data.topk(1)

# Cut the prediction just after the first EOS. The "no EOS" case is checked
# explicitly rather than through a bare try/except, so unrelated errors
# (and KeyboardInterrupt) are not silently swallowed.
eos_hits = (sentence_ints == EOS_token).nonzero()
if eos_hits.shape[0] > 0:
    eos_location = eos_hits[0][0]
else:
    # No EOS produced: this fallback keeps every position except the last.
    eos_location = len(sentence_ints) - 2
model_sentence = sentence_ints[:eos_location+1]
print(model_sentence)
print(output_t)

In [None]:
def sentence_correct(target, model_sentence):
    """Return 1 if the model's predicted sentence equals ``target``, else 0.

    The prediction is the top-1 index of ``model_sentence`` at every
    position, truncated just after the first ``EOS_token``. Reads the
    module-level names ``EOS_token`` and ``device``.

    Args:
        target: Tensor of expected token indices. Assumed to have the same
            layout as the top-1 indices (e.g. ``(length, 1)``); TODO confirm
            against ``tensors_from_pair``.
        model_sentence: Tensor of model outputs; ``topk(1)`` is taken over
            its last dimension.

    Returns:
        A 0-dim tensor on ``device``: 0 when the trimmed prediction and
        ``target`` differ in length, otherwise the product of the
        element-wise equality (1 only when every position matches).
    """
    # Index of the largest entry along the last dimension at every position.
    _, sentence_ints = model_sentence.data.topk(1)

    # Locate the first EOS. The "no EOS" case is checked explicitly rather
    # than through a bare try/except, so unrelated errors (and
    # KeyboardInterrupt) are not silently swallowed.
    eos_hits = (sentence_ints == EOS_token).nonzero()
    if eos_hits.shape[0] > 0:
        eos_location = eos_hits[0][0]
    else:
        # NOTE(review): with no EOS this keeps every position except the
        # last one (slice end is len - 1), not the complete list. Kept as-is
        # so scores stay unchanged.
        eos_location = len(sentence_ints) - 2
    trimmed = sentence_ints[:eos_location+1]

    # A length mismatch can never be an exact match.
    if len(trimmed) != len(target):
        return torch.tensor(0, device=device)

    correct = trimmed == target
    return torch.prod(correct).to(device)

In [None]:
# Score the model's prediction for the example pair against its target.
sentence_correct(target=output_t, model_sentence=ap_model(input_t))