In [1]:
import sys
import os

# Inject two CSS rules into the notebook page: keep <pre> output unwrapped (so long
# lines scroll horizontally) and give scrolling output areas a height of 44em.
from IPython.display import display, HTML
display(HTML("<style>pre { white-space: pre !important; }</style>"))
display(HTML("<style>div.output_scroll { height: 44em; }</style>"))

# The directory to import from is the parent of the current working directory.
# NOTE(review): this relies on the kernel being started inside the notebook's own folder.
notebook_dir = os.getcwd()
package_dir = os.path.dirname(notebook_dir)

# Append that directory to sys.path once, so re-running the cell does not add duplicates.
if package_dir not in sys.path:
    sys.path.append(package_dir)

In [8]:
from config import parse_args
from main import *

# Flag string that is parsed into `args` below (presumably the flags of the training run
# whose checkpoint is inspected in this notebook).
# NOTE(review): `--optim_beta2` appears twice (.99 and then .9); one of them was probably
# meant to be a different flag (beta1?) — confirm against config.parse_args.
commandline_args = '--task_path=shaw_geo/v1/standard --dtm_layers=20 --steps=4e4 --ctrl_hidden_dim=256 --train_log_freq=-1 --max_tree_depth=11 --d_filler=128 --num_workers=4 --sparse --max_filled_roles=2048 --custom_memory=1 --output_lowercase=0 --add_eob_to_memory=1 --cons_only=1 --early_stop_epochs=100 --wandb_group=geo --use_wandb --learn_filler_embed=1 --validate_every_num_epochs=10 --tied_io_languages=1 --filler_noise_location=input --filler_noise_std=0 --use_vocab_info=0 --is_agent_universal=1 --positional_embedding_type=sinusoidal --wandb_name=shaw_geo_longer --root_prediction_type=QK_ATTN_OVER_INPUTS --gclip=10 --lr=5e-5 --optim_beta2=.99 --optim_beta2=.9 --wd=.1 --batch_size=32 --transformer_nheads=4 --router_dropout=.1 --num_extra_tokens_in_memory=4'
# XT needs the <NT> token wrapped in "", remove it if it's there
commandline_args = commandline_args.replace('"<NT>"', '<NT>')
args = parse_args(commandline_args.split())
# Override the --batch_size=32 from the flag string: the notebook works on one example at a time.
args.batch_size = 1

# Identifier of the run to inspect and the location of its best checkpoint.
# NOTE(review): hard-coded path under /tmp; the checkpoint must already be there.
run = '1213.0'
checkpoint_file = f'/tmp/run{run}/run{run}/mirrored/out/best_checkpoint.pt'

# Resolve the task files for `args` (prepare_data_files is presumably provided by
# `from main import *`).
task_path = prepare_data_files(args)
# Everything in this notebook runs on CPU, in a single non-distributed process.
device = 'cpu'
is_ddp = False
main_process = True
# Build the data loaders plus the input and output vocabularies. The cells below index
# `data_loaders` with the keys 'train', 'valid' and 'test'.
data_loaders, input_lang, output_lang = data.prepare_data_loaders(
    task_path,
    args.max_tree_depth,
    args.add_eob_tokens,
    is_ddp,
    args.batch_size,
    args.num_workers,
    data_filter=args.data_filter,
    max_train_examples=args.max_train_examples,
    output_lowercase=args.output_lowercase,
    add_eob_to_memory=args.add_eob_to_memory,
    num_extra_tokens_in_memory=args.num_extra_tokens_in_memory,
)

# With tied languages, fold every output word into the input vocabulary and let both
# names refer to that single merged language object.
if args.tied_io_languages:
    for output_word in output_lang.ind2vocab.values():
        input_lang.add_word(output_word)
    output_lang = input_lang
    # Point each existing (truthy) loader's dataset at the merged language as well.
    for data_loader in data_loaders.values():
        if not data_loader:
            continue
        data_loader.dataset.output_lang = output_lang
            
# Iterator over the test split; later cells pull batches from it with next(test_iter).
test_iter = iter(data_loaders['test'])

print(f'Input language size: {len(input_lang.ind2vocab)}')
print(f'Output language size: {len(output_lang.ind2vocab)}')

# Longest input over all available splits; stays at -1 when no loader is present.
split_input_lengths = [
    split_loader.dataset.max_input_length
    for split_loader in data_loaders.values()
    if split_loader
]
max_input_length = max([-1] + split_input_lengths)


using LOCAL dataset/task files found at: /Users/psoulos/.data/shaw_geo/v1/standard
Max depth seen in file: 11
600 training examples
Max depth seen in file: 11
600 valid examples
Max depth seen in file: 17
279 test examples
Input language size: 204
Output language size: 204


In [9]:
from TPR_utils import TPR, decoded_tpr_to_tree_fn
from models import DiffTreeMachine
from main import convert_args_to_config

# Vocabulary grouping used when trees are printed: either read it from the task files,
# or fall back to "no unary/binary symbols, only <EOB> as terminal".
if not args.use_vocab_info:
    vocab_info = dict(unary=(), binary=(), terminal=('<EOB>',))
else:
    vocab_info = data.get_vocab_info(
        args.task_path,
        output_lang.ind2vocab.values(),
    )

# Build the TPR module and move it to `device`. Filler counts are the sizes of the
# input and output vocabularies; the number of roles is 2 ** max_tree_depth.
tpr = TPR(
    args,
    num_input_fillers=len(input_lang.ind2vocab),
    num_output_fillers=len(output_lang.ind2vocab),
    num_roles=2 ** args.max_tree_depth,
    d_filler=args.d_filler,
    d_role=args.d_role,
    filler_emb_gain=args.filler_emb_gain,
    learn_empty_filler=args.learn_empty_filler,
    tied_io_languages=args.tied_io_languages,
    empty_filler_initialization=args.empty_filler_initialization,
    device=device,
    sparse=args.sparse
).to(device=device)


# Translate the optional --hardcode_cons_root_token setting into a vocabulary index.
# It stays None when the option is unset or empty.
hardcode_cons_root_index = None
if args.hardcode_cons_root_token:
    if args.hardcode_cons_root_token == '-1':
        # The literal string '-1' is passed through as index -1 instead of being looked up.
        hardcode_cons_root_index = -1
    else:
        vocab2index = output_lang.vocab2ind
        # The message previously read a misspelled attribute (`harcode_...`), which turned a
        # failing assertion into an AttributeError instead of this explanation.
        assert args.hardcode_cons_root_token in vocab2index, (
            f'The token {args.hardcode_cons_root_token} is not in the '
            f'vocab.')
        hardcode_cons_root_index = vocab2index[args.hardcode_cons_root_token]
    logger.info(
        f'Hardcoding the root token to {args.hardcode_cons_root_token} with index'
        f' {hardcode_cons_root_index}'
    )


# NOTE(review): presumably stores the vocabularies, the TPR, the root index and the maximum
# input length on `args`, since DiffTreeMachine below is built from `args` alone and the
# Trainer later reads args.max_input_length — confirm in main.py.
convert_args_to_config(args, input_lang, output_lang, tpr, hardcode_cons_root_index, max_input_length)

dtm = DiffTreeMachine(args).to(device=device)

# Restore the weights stored under the checkpoint's 'model' key, mapped onto `device`,
# then switch the model to evaluation mode.
# NOTE(review): `torch` is only imported explicitly in a later cell; at this point it is
# presumably in scope through `from main import *`.
map_location = device
dtm.load_state_dict(torch.load(checkpoint_file,
                               map_location=map_location)['model'])
dtm.eval()

# The Trainer built in this cell takes an optimizer and a scheduler as arguments.
optimizer, scheduler = setup_optimizer_and_scheduler(dtm, args)

# Wrap the loaded model in a Trainer; the cells below only use trainer.process_batch to run
# single batches through it. The positional arguments follow the Trainer signature (defined
# outside this notebook), so their order must not be changed here.
trainer = Trainer(
    dtm,
    tpr,
    data_loaders['train'],
    data_loaders['valid'],
    data_loaders['test'],
    optimizer,
    args.epoch,
    args.steps,
    args.num_warmup_steps,
    main_process,
    is_ddp,
    decoded_tpr_to_tree_fn(args.tpr_loss_type, sparse=args.sparse),
    torch.nn.CrossEntropyLoss(),
    device,
    output_lang.ind2vocab,
    vocab_info,
    args.use_wandb,
    args.validate_every_num_epochs,
    args.train_log_freq,
    early_stop_epochs=args.early_stop_epochs,
    pad_idx=0,
    sparse=args.sparse,
    scheduler=scheduler,
    gclip=args.gclip,
    lr=args.lr,
    out_dir=args.out_dir,
    best_checkpoint_file=args.best_checkpoint_file,
    most_recent_checkpoint_file=args.most_recent_checkpoint_file,
    use_custom_memory=args.custom_memory,
    cross_entropy_weighting=args.cross_entropy_weighting,
    entropy_regularization_coefficient=args.entropy_regularization_coefficient,
    max_input_length=args.max_input_length,
)

Trainable params: 1507741


In [10]:
import torch

def print_correct_output(batch):
    """Pretty-print the gold output tree of the first example in ``batch``.

    Reads ``batch['output_fillers']`` and ``batch['output_roles']``, drops every entry whose
    role index is 0, packs the remaining (example, role) -> filler entries into a sparse
    tensor and renders the first example as a text tree.

    Assumes both tensors have the batch as their first dimension (``shape[0]`` is used as
    the batch size) — TODO confirm for hand-built inputs.
    Relies on the notebook globals ``tpr``, ``output_lang`` and ``vocab_info``.
    """
    fillers = batch['output_fillers']
    roles = batch['output_roles']
    num_examples = fillers.shape[0]

    # First coordinate of every non-zero role entry, i.e. the example it belongs to.
    example_index = torch.nonzero(roles, as_tuple=True)[0]
    occupied = roles != 0
    coordinates = torch.stack((example_index, roles[occupied]))
    target = torch.sparse_coo_tensor(
        indices=coordinates,
        values=fillers[occupied],
        size=(num_examples, tpr.num_roles),
    ).coalesce()

    gold_tpr = SparseTPR(target.indices(), target.values())
    # terminal_vocab is deliberately left empty here (vocab_info['terminal'] is not used).
    gold_tree = batch_symbols_to_node_tree(
        gold_tpr,
        output_lang.ind2vocab,
        terminal_vocab=(),
        unary_vocab=vocab_info['unary'],
    )[0]
    formatted_tree = TreePrettyPrinter(Tree.fromstring(gold_tree.str()))
    print('Correct output:\n{}'.format(formatted_tree.text()))

In [11]:
def make_output_lowercase(args, string):
    """Map SCAN action tokens to their lowercase command words.

    Args:
        args: object with an ``output_lowercase`` attribute.
        string: iterable of tokens.

    Returns:
        When ``args.output_lowercase`` is truthy, a new list with the lowercase word for
        each recognised action token; tokens without a mapping are dropped. Otherwise
        ``string`` itself, unchanged.
    """
    if not args.output_lowercase:
        return string

    lowercase_of = {
        'I_JUMP': 'jump',
        'I_WALK': 'walk',
        'I_LOOK': 'look',
        'I_RUN': 'run',
        'I_TURN_RIGHT': 'right',
        'I_TURN_LEFT': 'left',
    }
    return [lowercase_of[token] for token in string if token in lowercase_of]

In [45]:
# Draw the next batch from the shared test iterator and run it with step-by-step debug output.
batch = next(test_iter)
#print(batch)
# `_` keeps whatever process_batch returns; element 3 is what is reported as "Correct?" below.
_ = trainer.process_batch(batch, debug=True)
print_correct_output(batch)
print(f'Correct? {_[3]}')

Step 0:
Blackboard:
[cons_l, cons_r, root filler]
 0	[.01 .00 .02]    (  )
 1	[.00 .16 .03]    ( <TOKEN_0> )
 2	[.50 .02 .05]    ( <TOKEN_1> )
 3	[.00 .01 .20]    ( <TOKEN_2> )
 4	[.01 .00 .12]    ( <TOKEN_3> )
 5	[.00 .56 .07]    ( how )
 6	[.01 .02 .08]    ( many )
 7	[.14 .00 .05]    ( states )
 8	[.04 .01 .05]    ( do )
 9	[.16 .04 .13]    ( not )
10	[.11 .17 .09]    ( have )
11	[.02 .01 .12]    ( rivers )
Output: 
```          count    
     _______|____   
<TOKEN_1>       how
```
Step 1:
Blackboard:
[cons_l, cons_r, root filler]
 0	[.01 .00 .03]    (  )
 1	[.00 .01 .02]    ( <TOKEN_0> )
 2	[.30 .00 .04]    ( <TOKEN_1> )
 3	[.00 .00 .17]    ( <TOKEN_2> )
 4	[.01 .00 .20]    ( <TOKEN_3> )
 5	[.00 .01 .05]    ( how )
 6	[.02 .00 .06]    ( many )
 7	[.10 .00 .05]    ( states )
 8	[.05 .00 .04]    ( do )
 9	[.27 .00 .13]    ( not )
10	[.18 .01 .07]    ( have )
11	[.05 .00 .13]    ( rivers )
12	[.01 .96]  0. ( count <TOKEN_1> how )
Output: 
```      count            
  ______|_______  

In [76]:
# Re-run the current `batch` with debug output; no new batch is drawn from test_iter here.
_ = trainer.process_batch(batch, debug=True)
print_correct_output(batch)
print(f'Correct? {_[3]}')

Step 0:
Blackboard:
 0	[.18 .05]    (  )
 1	[.09 .06]    ( run )
 2	[.15 .27]    ( around )
 3	[.12 .05]    ( right )
 4	[.06 .09]    ( twice )
 5	[.12 .04]    ( and )
 6	[.08 .08]    ( run )
 7	[.15 .20]    ( around )
 8	[.06 .15]    ( left )
Output: 
```       <NT>       
   _____|_____     
around      around
```
Step 1:
Blackboard:
 0	[.02 .01]    (  )
 1	[.01 .10]    ( run )
 2	[.08 .01]    ( around )
 3	[.05 .00]    ( right )
 4	[.35 .01]    ( twice )
 5	[.01 .01]    ( and )
 6	[.03 .30]    ( run )
 7	[.33 .02]    ( around )
 8	[.11 .01]    ( left )
 9	[.02 .54]  0. ( <NT> around around )
Output: 
```            <NT>      
      _______|_____    
I_TURN_LEFT      I_RUN
```
Step 2:
Blackboard:
 0	[.04 .01]    (  )
 1	[.01 .04]    ( run )
 2	[.18 .01]    ( around )
 3	[.06 .01]    ( right )
 4	[.20 .03]    ( twice )
 5	[.02 .01]    ( and )
 6	[.01 .24]    ( run )
 7	[.23 .08]    ( around )
 8	[.04 .05]    ( left )
 9	[.03 .13]  0. ( <NT> around around )
10	[.19 .39]  1. ( <NT> I_TU

In [None]:
# Collect the input sentence of every remaining test example the model gets wrong.
# The previous version checked positions 0 and 1 explicitly, which goes out of range with
# the batch size of 1 set above; this loop handles any batch size.
# NOTE: `test_iter` is shared with the debugging cells, so batches already consumed with
# next(test_iter) are not revisited here.
incorrect_samples = []
for batch in test_iter:
    batch_result = trainer.process_batch(batch, debug=False)
    # Element 3 of the result holds one correctness flag per example in the batch.
    for example_idx, is_correct in enumerate(batch_result[3]):
        if not is_correct:
            # Map the filler indices of this example back to vocabulary strings.
            input_tokens = [
                input_lang.ind2vocab[token_idx.item()]
                for token_idx in batch['input_fillers'][example_idx]
            ]
            incorrect_samples.append(' '.join(input_tokens))

In [None]:
# One misclassified input sentence per line.
for sample_text in incorrect_samples:
    print(sample_text)

In [28]:
from data import text_tree_to_node, BinaryT2TDataset
from preprocessing.preprocess_scan import build_leaves_tree, build_separate_tree

# Hand-written example: build the input tree from a whitespace-split question and the
# output tree from a bracketed logical form, using the same memory options as the loaders.
in_tree = text_tree_to_node(build_separate_tree('how large is the largest city in m0'.split()), add_eob_to_memory=args.add_eob_to_memory, add_eob_tokens=args.add_eob_tokens, num_extra_tokens_in_memory=args.num_extra_tokens_in_memory)
out_tree = text_tree_to_node('( answer ( size ( largest ( intersection city ( loc_2 m0 ) ) ) ) )', add_eob_to_memory=False, add_eob_tokens=args.add_eob_tokens)
example = {"input": in_tree, "output": out_tree}
item = example

# Convert both trees to role/filler tensors with the training dataset's converter.
input_roles, input_fillers = data_loaders['train'].dataset.text_to_tensors(item['input'], language=input_lang)
output_roles, output_fillers = data_loaders['train'].dataset.text_to_tensors(item['output'], language=output_lang)
# Only the input tensors get a leading dimension via unsqueeze(0).
# NOTE(review): the output tensors are passed as returned, while print_correct_output reads
# output_fillers.shape[0] as the batch size — confirm the printed gold tree is complete.
input_ = {
    'input_fillers': input_fillers.unsqueeze(0),
    'input_roles': input_roles.unsqueeze(0),
    'output_fillers': output_fillers,
    'output_roles': output_roles,
}

_ = trainer.process_batch(input_, debug=True)
print_correct_output(input_)

Step 0:
Blackboard:
[cons_l, cons_r, root filler]
 0	[.01 .00 .02]    (  )
 1	[.01 .17 .03]    ( <TOKEN_0> )
 2	[.70 .01 .04]    ( <TOKEN_1> )
 3	[.00 .01 .16]    ( <TOKEN_2> )
 4	[.01 .00 .12]    ( <TOKEN_3> )
 5	[.00 .66 .08]    ( how )
 6	[.06 .06 .07]    ( large )
 7	[.03 .02 .06]    ( is )
 8	[.00 .00 .02]    ( the )
 9	[.04 .01 .11]    ( largest )
10	[.13 .01 .08]    ( city )
11	[.00 .01 .14]    ( in )
12	[.00 .04 .09]    ( m0 )
Output: 
```          city    
     ______|____   
<TOKEN_1>      how
```
Step 1:
Blackboard:
[cons_l, cons_r, root filler]
 0	[.01 .00 .03]    (  )
 1	[.01 .02 .03]    ( <TOKEN_0> )
 2	[.44 .00 .03]    ( <TOKEN_1> )
 3	[.01 .00 .11]    ( <TOKEN_2> )
 4	[.02 .00 .14]    ( <TOKEN_3> )
 5	[.00 .04 .08]    ( how )
 6	[.11 .01 .08]    ( large )
 7	[.05 .00 .03]    ( is )
 8	[.01 .00 .01]    ( the )
 9	[.04 .00 .13]    ( largest )
10	[.28 .00 .08]    ( city )
11	[.01 .00 .16]    ( in )
12	[.02 .01 .09]    ( m0 )
13	[.00 .91]  0. ( city <TOKEN_1> how )
Output: 

In [None]:
from data import text_tree_to_node, BinaryT2TDataset
from preprocessing.preprocess_scan import build_leaves_tree, build_separate_tree

# Hand-written SCAN-style example.
# NOTE(review): unlike the data loaders, this call does not pass num_extra_tokens_in_memory,
# so the memory layout may differ from the one the checkpoint saw — confirm.
in_tree = text_tree_to_node(build_separate_tree('turn around left thrice'.split()), add_eob_to_memory=args.add_eob_to_memory, add_eob_tokens=args.add_eob_tokens)

# The target is written as a bracketed tree and split on whitespace, so brackets and <NT>
# become separate tokens.
# NOTE(review): the command says "left" but every action below is I_TURN_RIGHT — verify.
out_string = '( <NT> ( <NT> ( <NT> ( <NT> I_TURN_RIGHT I_TURN_RIGHT ) ( <NT> I_TURN_RIGHT I_TURN_RIGHT ) ) ( <NT> ( <NT> I_TURN_RIGHT I_TURN_RIGHT ) ( <NT> I_TURN_RIGHT I_TURN_RIGHT ) ) ) ( <NT> ( <NT> I_TURN_RIGHT I_TURN_RIGHT ) ( <NT> I_TURN_RIGHT I_TURN_RIGHT ) ) )'.split()
out_string = make_output_lowercase(args, out_string)
out_tree = text_tree_to_node(build_leaves_tree(out_string), add_eob_to_memory=False, add_eob_tokens=args.add_eob_tokens)
example = {"input": in_tree, "output": out_tree}
item = example

# Convert both trees to role/filler tensors with the training dataset's converter.
input_roles, input_fillers = data_loaders['train'].dataset.text_to_tensors(item['input'], language=input_lang)
output_roles, output_fillers = data_loaders['train'].dataset.text_to_tensors(item['output'], language=output_lang)
# Only the input tensors get a leading dimension via unsqueeze(0); the outputs are passed as returned.
input_ = {
    'input_fillers': input_fillers.unsqueeze(0),
    'input_roles': input_roles.unsqueeze(0),
    'output_fillers': output_fillers,
    'output_roles': output_roles,
}

_ = trainer.process_batch(input_, debug=True)
print_correct_output(input_)

In [None]:
# Decode the last element of the most recent process_batch result into a printable tree.
output = _[-1]
tree_decoder = decoded_tpr_to_tree_fn(args.tpr_loss_type, sparse=args.sparse)
output_tpr = SparseTPR(output.indices(), output.values())
unbound_output = tpr.unbind(output_tpr, decode=True, type_='output')
x_decoded = tree_decoder(unbound_output)

# No terminal or unary vocabulary is supplied here (vocab_info is deliberately not used).
decoded_trees = batch_symbols_to_node_tree(
    x_decoded,
    output_lang.ind2vocab,
    terminal_vocab=(),
    unary_vocab=(),
)
debug_tree = decoded_trees[0]
pretty_tree = TreePrettyPrinter(Tree.fromstring(debug_tree.str()))
print(pretty_tree.text())

In [None]:
# Show which indices are filled in `output`, then display the index of the output filler
# that scores highest against the values row 0b11 (= 3).
print(output.indices())
(output.values()[int('11',2)]@tpr.out.weight.T).argmax()

In [None]:
# Inspect a single role, written in binary for readability.
role = 0b1110000
print(role)

# The values row looked at is role - 1.
# NOTE(review): confirm that this offset matches how `output` orders its rows.
row = role - 1
print(output.role_indices()[row])

# Score the stored filler against every output embedding and report the best match,
# first as an index and then as a vocabulary entry.
best_filler = (output.values()[row] @ tpr.out.weight.T).argmax()
print(best_filler)
print(output_lang.ind2vocab[best_filler.item()])

In [None]:
# Display the vocabulary entry stored at index 2 of the output language.
output_lang.ind2vocab[2]

In [None]:
from data import text_tree_to_node, BinaryT2TDataset
from preprocessing.preprocess_scan import build_leaves_tree, build_separate_tree

# Hand-written SCAN-style example: command -> input tree, action sequence -> output tree.
# NOTE(review): unlike the data loaders, this call does not pass num_extra_tokens_in_memory,
# so the memory layout may differ from the one the checkpoint saw — confirm.
in_tree = text_tree_to_node(build_separate_tree('walk around right twice and jump opposite left twice'.split()), add_eob_to_memory=args.add_eob_to_memory, add_eob_tokens=args.add_eob_tokens)
out_string = 'I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_LEFT I_TURN_LEFT I_JUMP I_TURN_LEFT I_TURN_LEFT I_JUMP'.split()
out_string = make_output_lowercase(args, out_string)
out_tree = text_tree_to_node(build_leaves_tree(out_string), add_eob_to_memory=False, add_eob_tokens=args.add_eob_tokens)
example = {"input": in_tree, "output": out_tree}
item = example

# Convert both trees to role/filler tensors with the training dataset's converter.
input_roles, input_fillers = data_loaders['train'].dataset.text_to_tensors(item['input'], language=input_lang)
output_roles, output_fillers = data_loaders['train'].dataset.text_to_tensors(item['output'], language=output_lang)
# Only the input tensors get a leading dimension via unsqueeze(0); the outputs are passed as returned.
input_ = {
    'input_fillers': input_fillers.unsqueeze(0),
    'input_roles': input_roles.unsqueeze(0),
    'output_fillers': output_fillers,
    'output_roles': output_roles,
}

_ = trainer.process_batch(input_, debug=True)
print_correct_output(input_)

In [None]:
from data import text_tree_to_node, BinaryT2TDataset
from preprocessing.preprocess_scan import build_leaves_tree, build_separate_tree

# Hand-written SCAN-style example: command -> input tree, action sequence -> output tree.
# NOTE(review): unlike the data loaders, this call does not pass num_extra_tokens_in_memory,
# so the memory layout may differ from the one the checkpoint saw — confirm.
in_tree = text_tree_to_node(build_separate_tree('look around right thrice after walk around left twice'.split()), add_eob_to_memory=args.add_eob_to_memory, add_eob_tokens=args.add_eob_tokens)
# NOTE(review): the command contains "look" but this sequence has no I_LOOK token; it looks
# copied from another example, so the printed "correct output" is probably not the gold
# answer for this command.
out_string = 'I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_LEFT I_TURN_LEFT I_JUMP I_TURN_LEFT I_TURN_LEFT I_JUMP I_TURN_LEFT I_TURN_LEFT I_JUMP'.split()
out_string = make_output_lowercase(args, out_string)
out_tree = text_tree_to_node(build_leaves_tree(out_string), add_eob_to_memory=False, add_eob_tokens=args.add_eob_tokens)
example = {"input": in_tree, "output": out_tree}
item = example

# Convert both trees to role/filler tensors with the training dataset's converter.
input_roles, input_fillers = data_loaders['train'].dataset.text_to_tensors(item['input'], language=input_lang)
output_roles, output_fillers = data_loaders['train'].dataset.text_to_tensors(item['output'], language=output_lang)
# Only the input tensors get a leading dimension via unsqueeze(0); the outputs are passed as returned.
input_ = {
    'input_fillers': input_fillers.unsqueeze(0),
    'input_roles': input_roles.unsqueeze(0),
    'output_fillers': output_fillers,
    'output_roles': output_roles,
}

_ = trainer.process_batch(input_, debug=True)
print_correct_output(input_)

In [None]:
from data import text_tree_to_node, BinaryT2TDataset
from preprocessing.preprocess_scan import build_leaves_tree, build_separate_tree

# Hand-written SCAN-style example: command -> input tree, action sequence -> output tree.
# NOTE(review): unlike the data loaders, this call does not pass num_extra_tokens_in_memory,
# so the memory layout may differ from the one the checkpoint saw — confirm.
in_tree = text_tree_to_node(build_separate_tree('walk around left twice and look around right thrice'.split()), add_eob_to_memory=args.add_eob_to_memory, add_eob_tokens=args.add_eob_tokens)
# NOTE(review): the command contains "look" but this sequence has no I_LOOK token; it looks
# copied from another example, so the printed "correct output" is probably not the gold
# answer for this command.
out_string = 'I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_LEFT I_TURN_LEFT I_JUMP I_TURN_LEFT I_TURN_LEFT I_JUMP I_TURN_LEFT I_TURN_LEFT I_JUMP'.split()
out_string = make_output_lowercase(args, out_string)
out_tree = text_tree_to_node(build_leaves_tree(out_string), add_eob_to_memory=False, add_eob_tokens=args.add_eob_tokens)
example = {"input": in_tree, "output": out_tree}
item = example

# Convert both trees to role/filler tensors with the training dataset's converter.
input_roles, input_fillers = data_loaders['train'].dataset.text_to_tensors(item['input'], language=input_lang)
output_roles, output_fillers = data_loaders['train'].dataset.text_to_tensors(item['output'], language=output_lang)
# Only the input tensors get a leading dimension via unsqueeze(0); the outputs are passed as returned.
input_ = {
    'input_fillers': input_fillers.unsqueeze(0),
    'input_roles': input_roles.unsqueeze(0),
    'output_fillers': output_fillers,
    'output_roles': output_roles,
}

_ = trainer.process_batch(input_, debug=True)
print_correct_output(input_)

In [None]:
from data import text_tree_to_node, BinaryT2TDataset
from preprocessing.preprocess_scan import build_leaves_tree, build_separate_tree

# Hand-written SCAN-style example: command -> input tree, action sequence -> output tree.
# NOTE(review): unlike the data loaders, this call does not pass num_extra_tokens_in_memory,
# so the memory layout may differ from the one the checkpoint saw — confirm.
in_tree = text_tree_to_node(build_separate_tree('walk around right thrice and jump left twice'.split()), add_eob_to_memory=args.add_eob_to_memory, add_eob_tokens=args.add_eob_tokens)
# NOTE(review): this appears to be the same action sequence used for other commands in this
# notebook, so it is probably not the gold output for this command (the gold-tree print at
# the end of the cell is disabled).
out_string = 'I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_LEFT I_TURN_LEFT I_JUMP I_TURN_LEFT I_TURN_LEFT I_JUMP I_TURN_LEFT I_TURN_LEFT I_JUMP'.split()
out_string = make_output_lowercase(args, out_string)
out_tree = text_tree_to_node(build_leaves_tree(out_string), add_eob_to_memory=False, add_eob_tokens=args.add_eob_tokens)
example = {"input": in_tree, "output": out_tree}
item = example

# Convert both trees to role/filler tensors with the training dataset's converter.
input_roles, input_fillers = data_loaders['train'].dataset.text_to_tensors(item['input'], language=input_lang)
output_roles, output_fillers = data_loaders['train'].dataset.text_to_tensors(item['output'], language=output_lang)
# Only the input tensors get a leading dimension via unsqueeze(0); the outputs are passed as returned.
input_ = {
    'input_fillers': input_fillers.unsqueeze(0),
    'input_roles': input_roles.unsqueeze(0),
    'output_fillers': output_fillers,
    'output_roles': output_roles,
}

_ = trainer.process_batch(input_, debug=True)
#print_correct_output(input_)

In [None]:
from data import text_tree_to_node, BinaryT2TDataset
from preprocessing.preprocess_scan import build_leaves_tree, build_separate_tree

# Hand-written SCAN-style example: command -> input tree, action sequence -> output tree.
# NOTE(review): unlike the data loaders, this call does not pass num_extra_tokens_in_memory,
# so the memory layout may differ from the one the checkpoint saw — confirm.
in_tree = text_tree_to_node(build_separate_tree('walk around right thrice and jump opposite left twice'.split()), add_eob_to_memory=args.add_eob_to_memory, add_eob_tokens=args.add_eob_tokens)
# NOTE(review): this appears to be the same action sequence used for other commands in this
# notebook, so it is probably not the gold output for this command (the gold-tree print at
# the end of the cell is disabled).
out_string = 'I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_LEFT I_TURN_LEFT I_JUMP I_TURN_LEFT I_TURN_LEFT I_JUMP I_TURN_LEFT I_TURN_LEFT I_JUMP'.split()
out_string = make_output_lowercase(args, out_string)
out_tree = text_tree_to_node(build_leaves_tree(out_string), add_eob_to_memory=False, add_eob_tokens=args.add_eob_tokens)
example = {"input": in_tree, "output": out_tree}
item = example

# Convert both trees to role/filler tensors with the training dataset's converter.
input_roles, input_fillers = data_loaders['train'].dataset.text_to_tensors(item['input'], language=input_lang)
output_roles, output_fillers = data_loaders['train'].dataset.text_to_tensors(item['output'], language=output_lang)
# Only the input tensors get a leading dimension via unsqueeze(0); the outputs are passed as returned.
input_ = {
    'input_fillers': input_fillers.unsqueeze(0),
    'input_roles': input_roles.unsqueeze(0),
    'output_fillers': output_fillers,
    'output_roles': output_roles,
}

_ = trainer.process_batch(input_, debug=True)
#print_correct_output(input_)

In [None]:
from data import text_tree_to_node, BinaryT2TDataset
from preprocessing.preprocess_scan import build_leaves_tree, build_separate_tree

# Hand-written SCAN-style example: command -> input tree, action sequence -> output tree.
# NOTE(review): unlike the data loaders, this call does not pass num_extra_tokens_in_memory,
# so the memory layout may differ from the one the checkpoint saw — confirm.
in_tree = text_tree_to_node(build_separate_tree('walk around right thrice and jump around left thrice'.split()), add_eob_to_memory=args.add_eob_to_memory, add_eob_tokens=args.add_eob_tokens)
out_string = 'I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_RIGHT I_WALK I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP I_TURN_LEFT I_JUMP'.split()
out_string = make_output_lowercase(args, out_string)
out_tree = text_tree_to_node(build_leaves_tree(out_string), add_eob_to_memory=False, add_eob_tokens=args.add_eob_tokens)
example = {"input": in_tree, "output": out_tree}
item = example

# Convert both trees to role/filler tensors with the training dataset's converter.
input_roles, input_fillers = data_loaders['train'].dataset.text_to_tensors(item['input'], language=input_lang)
output_roles, output_fillers = data_loaders['train'].dataset.text_to_tensors(item['output'], language=output_lang)
# Only the input tensors get a leading dimension via unsqueeze(0); the outputs are passed as returned.
input_ = {
    'input_fillers': input_fillers.unsqueeze(0),
    'input_roles': input_roles.unsqueeze(0),
    'output_fillers': output_fillers,
    'output_roles': output_roles,
}

_ = trainer.process_batch(input_, debug=True)
print_correct_output(input_)

In [None]:
# Print "<row>. <token> <norm>" for every output filler embedding.
# ind2vocab is a dict keyed by index, so iterating it directly yields the integer keys; the
# words are taken from .values() so the token column shows the token rather than its index.
filler_norms = tpr.out.weight.norm(dim=-1)
for i, (token, norm) in enumerate(zip(output_lang.ind2vocab.values(), filler_norms)):
    print(f'{i}. {token} {norm}')


In [None]:
# Score a 70/30 blend of input filler embeddings 14 and 4 against each output embedding row.
blended_filler = tpr.filler_emb.weight[14]*.7 + tpr.filler_emb.weight[4]*.3
num_output_rows = tpr.out.weight.shape[0]
for row_index in range(num_output_rows):
    row_score = blended_filler @ tpr.out.weight[row_index]
    print(f'{row_index}. {row_score}')

In [None]:
# Score input filler embedding 7 against each output embedding row.
probe_filler = tpr.filler_emb.weight[7]
num_output_rows = tpr.out.weight.shape[0]
for row_index in range(num_output_rows):
    row_score = probe_filler @ tpr.out.weight[row_index]
    print(f'{row_index}. {row_score}')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap



# Three-colour map interpolated over 100 levels (N=100 below).
colors = ["green", "white", "blue"]  # green at the low end, white in the middle, blue at the high end of the plotted range
n_bins = [3]  # not used below; the number of colour levels is set by N=100 in from_list
cmap_name = 'my_list'
cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=100)



# Pairwise dot products between all rows of the output embedding matrix
dot_product_matrix = tpr.out.weight @ tpr.out.weight.T

# Plot the matrix as a heatmap. No vmin/vmax is given, so the colours span the data's own
# min..max and white only coincides with zero if that range happens to be symmetric.
plt.figure(figsize=(8, 6))
plt.imshow(dot_product_matrix.detach(), cmap=cm, interpolation='nearest')
plt.title("Heatmap of Dot Products Between Vectors")
plt.colorbar()
plt.show()

In [None]:
# For every output token, print its embedding's dot product with all output embeddings.
vocab_size = len(output_lang.ind2vocab)
for filler_index in range(vocab_size):
    token = output_lang.ind2vocab[filler_index]
    similarities = tpr.out.weight[filler_index] @ tpr.out.weight.T
    print(f'{filler_index}.{token}: {similarities}')

In [None]:
# Blend fillers 14, 9 and 2 (weight .7) with filler 8 (weight .3) and score each blend
# against all output embeddings; the three results are separated by blank lines.
secondary_part = tpr.filler_emb.weight[8]*.3
for position, primary_index in enumerate((14, 9, 2)):
    if position > 0:
        print()
    print((tpr.filler_emb.weight[primary_index]*.7 + secondary_part) @ tpr.out.weight.T)

In [None]:
# Display the dot product of input filler embedding 7 with every output embedding.
tpr.filler_emb.weight[7] @ tpr.out.weight.T

In [None]:
# Hard-coded vector (presumably pasted from earlier debug output) scored against every
# output embedding.
# NOTE(review): the vector has 64 entries while this notebook sets --d_filler=128; the
# matmul below fails if tpr.out.weight has 128 columns — likely left over from another run.
a = torch.tensor([ 0.0764,  0.5909, -0.9669,  0.0574, -0.9311, -1.0875, -0.3789, -0.1703,
        -0.1916,  1.4145, -0.1150,  0.5652, -0.4569,  0.3671, -1.0746, -1.1743,
         0.1241, -1.1443, -1.3893, -0.8754, -1.3105,  0.0818, -0.1234,  0.3814,
         0.5174, -0.9131, -0.7015,  0.2231, -1.2841, -0.0058,  0.7574, -0.4444,
         0.1757, -0.2283,  0.0047, -0.2394, -0.2723, -0.1466,  0.4142, -0.1406,
        -0.1118, -1.0626, -0.1752,  0.1124, -0.5517, -0.6271, -0.3438, -0.3370,
        -0.1784,  1.4131, -1.1570,  0.4874, -1.0664,  0.3588,  0.0939,  0.1241,
        -0.6608,  1.2563,  0.7050,  0.7262, -0.2961, -1.1610,  0.1476, -0.6420])
f = a @ tpr.out.weight.T
# One "<output index>: <score>" line per output embedding.
for i, v in enumerate(f):
    print(f'{i}: {v}')


In [None]:
# Display the norm of every row of the output embedding matrix.
tpr.out.weight.norm(dim=-1)

In [30]:
# List the input vocabulary with the norm of each token's filler embedding, then the
# output vocabulary.
print('INPUT')
for index, token in input_lang.ind2vocab.items():
    embedding_norm = tpr.filler_emb.weight[index].norm()
    print(f'{index}: {token}\t{embedding_norm}')

print('OUTPUT')
for index, token in output_lang.ind2vocab.items():
    print(f'{index}: {token}')

INPUT
0: <PAD>	0.0
1: <EOB>	9.14475154876709
2: <TOKEN_0>	8.799625396728516
3: <TOKEN_1>	7.769104957580566
4: <TOKEN_2>	8.615045547485352
5: <TOKEN_3>	8.349431037902832
6: give	8.492795944213867
7: me	8.643167495727539
8: the	8.588788032531738
9: cities	9.896332740783691
10: in	7.988222122192383
11: m0	9.697813034057617
12: what	8.416818618774414
13: are	8.168783187866211
14: high	8.860904693603516
15: points	10.132843017578125
16: of	9.519013404846191
17: states	9.757287979125977
18: surrounding	9.1204833984375
19: name	10.013427734375
20: rivers	8.625320434570312
21: can	8.815765380859375
22: you	8.57943058013916
23: tell	9.366765975952148
24: capital	9.518416404724121
25: could	9.52619457244873
26: is	9.410720825195312
27: highest	9.981000900268555
28: point	9.248022079467773
29: state	10.932914733886719
30: all	8.359414100646973
31: which	8.631967544555664
32: lakes	9.239540100097656
33: largest	9.9051513671875
34: longest	10.648015022277832
35: river	11.098204612731934
36: that	9.

In [None]:
# Print the norm of every row of the input filler embedding matrix.
print(tpr.filler_emb.weight.norm(dim=-1))

In [None]:
# Display the matrix rank of the output embedding matrix.
torch.linalg.matrix_rank(tpr.out.weight)

## Visualize Embeddings

In [None]:
# Project the output embeddings onto their first two principal components and label each
# point with its vocabulary entry.
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# detach().cpu() replaces the legacy `.data` access and also works if the model is ever
# moved off the CPU.
embeddings = tpr.out.weight.detach().cpu().numpy()

pca = PCA(n_components=2)  # for 2D visualization
reduced_embeddings = pca.fit_transform(embeddings)

plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1])
plt.xlabel('PCA Component 1')
plt.ylabel('PCA Component 2')
# Annotate each point. As before, the i-th vocabulary entry labels embedding row i.
for i, label in enumerate(output_lang.ind2vocab.values()):
    plt.annotate(label, (reduced_embeddings[i, 0], reduced_embeddings[i, 1]))
plt.title('Embedding Visualization using PCA')
plt.show()

In [None]:
# Display the matrix rank of the input filler embedding matrix.
torch.linalg.matrix_rank(tpr.filler_emb.weight)

In [None]:
# Parse a small bracketed tree and pretty-print it (quick check of the Tree helper).
Tree.fromstring('(NT I_WALK I_JUMP)').pretty_print()

In [None]:
# NOTE(review): the string holds two bracketed trees separated by a tab (it looks like one
# raw input/output line of a dataset file). If Tree is nltk's Tree, fromstring expects a
# single tree and will probably raise here — confirm, or parse the two halves separately.
Tree.fromstring('( COMMAND ( PHRASE ( VERB ( ACTION jump ) ) ) )	( CX ( PX ( VX ( AX I_JUMP ) ) ) )')