In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from evo import *

# Shared architecture hyper-parameters for both evolver variants.
params = dict(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    encoder_layers=3,
    decoder_layers=3,
    max_len=64,
)

# Build a plain evolver and a pointer-style evolver from identical settings
# so the two training smoke tests below are directly comparable.
evolver = Evolver(**params)
ps_evolver = PointerStyleEvolver(**params)

In [None]:
from data import *
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Settings shared by the train and eval loaders.
# NOTE(review): max_len=10 here differs from the model's max_len=64 in
# `params`; confirm the two are meant to be independent.
_shared_loader_kwargs = dict(
    max_len=10,
    tokenizer=tokenizer,
    batch_size=4,
    sampler=StratifiedInfiniteSampler,
)

# Small supervised slice of the UD training file, uncached.
# limit=20 presumably caps the number of examples read; confirm in data.py.
train_loader = supervised_loader(
    path='data/ud/ud_train_3.0.jsonl',
    cache_prefix=None,
    all_tokens=True,
    limit=20,
    **_shared_loader_kwargs,
)

# Unsupervised loader over the toy corpus, used for evaluation.
eval_loader = unsupervised_loader(
    path='data/toy/toy.jsonl',
    **_shared_loader_kwargs,
)

In [None]:
from run import apply_edits

# Grab one training batch; only the input ids and edit targets are used here.
traj_input_ids, _, traj_edit_tgts, _ = next(iter(train_loader))

# Take index 0 along dim 1 of the ids and of every edit-target tensor
# (presumably the first step of each trajectory; confirm against data.py),
# then apply those edits and display the result.
first_input_ids = traj_input_ids[:, 0]
first_edit_tgts = tuple(tgt_part[:, 0] for tgt_part in traj_edit_tgts)
apply_edits(first_input_ids, first_edit_tgts)

In [None]:
from torch.optim import AdamW

# One train step and two eval steps: a smoke test that the loop runs end to end.
# checkpoint_at (20) exceeds train_steps (1), presumably so no checkpoint is written.
kwargs = dict(
    train_loader=train_loader,
    eval_loader=eval_loader,
    train_steps=1,
    eval_steps=2,
    grad_accum_steps=1,
    clip_gradients=False,
    checkpoint_at=20,
    eval_at=1,
)

# Run the identical smoke test on both evolver variants, each with its own
# freshly constructed optimizer.
for banner, evo_model in (
    ('STARTING REGULAR EVOLVER', evolver),
    ('STARTING PS EVOLVER', ps_evolver),
):
    print(banner)
    # NOTE(review): third positional argument is None; presumably "no LR
    # scheduler". Confirm against train_evolver's signature.
    train_evolver(evo_model, AdamW(evo_model.parameters(), lr=3e-4), None, **kwargs)

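## Multihead pointer

Sanity check of `trans.MultiheadPointer` on random inputs. Source positions 7–9 are flagged in the padding mask.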

In [None]:
import torch
from trans import MultiheadPointer

# Seed first so both the pointer's initial weights and the random inputs, and
# therefore the displayed `idx_weights`, are reproducible across runs.
torch.manual_seed(0)

# Model dims passed to MultiheadPointer (same 512 / 8 as `params` above).
D_MODEL, N_HEADS = 512, 8
# Toy batch size, source length, target length, and number of non-padded
# source positions per row.
BATCH, SRC_LEN, TGT_LEN, N_REAL = 3, 10, 5, 7

pointer = MultiheadPointer(D_MODEL, N_HEADS)

mem = torch.randn(BATCH, SRC_LEN, D_MODEL)
tgt = torch.randn(BATCH, TGT_LEN, D_MODEL)

# True marks source positions to ignore; only the first N_REAL are kept.
# NOTE(review): assumes True == padded (PyTorch key_padding_mask convention);
# confirm in trans.MultiheadPointer.
src_pad_mask = torch.full((BATCH, SRC_LEN), True)
src_pad_mask[:, :N_REAL] = False

idx_weights = pointer(tgt, mem, key_padding_mask=src_pad_mask)
idx_weights

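## Regressions

Toy-config runs of `evo.py`, collected for copy-paste into a terminal; most are also executed in the cells below.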

```
python evo.py --config=configs/toy/sup-toy.json
python evo.py --config=configs/toy/sup-toy-epoch.json
python evo.py --config=configs/toy/ps-sup-toy.json
python evo.py --config=configs/toy/ps-unsup-toy.json
python evo.py --config=configs/toy/noshare-sup-toy-e1d1.json
python evo.py --config=configs/toy/den-toy.json
python evo.py --config=configs/toy/ar-d-toy.json
```

In [None]:
!python evo.py --config=configs/toy/sup-toy.json --local

In [None]:
!python evo.py --config=configs/toy/sup-toy-epoch.json --local

In [None]:
!python evo.py --config=configs/toy/ps-unsup-toy.json

In [None]:
!python evo.py --config=configs/toy/ps-sup-toy.json

In [None]:
!python evo.py --config=configs/toy/den-toy.json

In [None]:
!python evo.py --config=configs/toy/ar-d-toy.json

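## Sampling

Load a trained `PointerStyleEvolver` checkpoint and sample edits starting from empty inputs. Then sample a second round from the edited result.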

In [None]:
import torch
from evo import PointerStyleEvolver

# Local checkpoint file; torch.load unpickles it, so only load trusted files.
CHECKPOINT_PATH = 'ps-sup-imdb-pattn_20240822_235245-9900.pt'

# NOTE(review): the model is built from constructor defaults (not `params`
# above), so those defaults must match the architecture in the checkpoint.
model = PointerStyleEvolver(pointer_attn=True)

# Kept as the last expression so the load result (key-matching report) is shown.
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location='cpu')['model']
)

In [None]:
import torch
from run import sample, particle_filter
from data import get_input_ids
from utils import BT

# Three empty source sequences to sample from. `outputs` are reference strings;
# they are encoded but not passed to the sampling call below.
# NOTE(review): 512 is presumably the max sequence length for get_input_ids.
inputs = ['', '', '']
outputs = ['hello my', 'ac', 'test test test']
input_ids = get_input_ids(inputs, 512, tokenizer=BT)
output_ids = get_input_ids(outputs, 512, tokenizer=BT)

model.eval()
# Inference only: disable autograd so no graph or activations are retained
# (harmless if `sample` already does this internally).
# NOTE(review): resample_at=1e9 presumably disables resampling; confirm in run.sample.
with torch.no_grad():
    edit_tgts, tgt = sample(model, input_ids, None, M=1, threshold=0, resample_at=1e9)

In [None]:
from data import elaborate

# Display the first round of sampled edit targets as this cell's output
# (`elaborate` presumably renders them in readable form; see data.elaborate).
elaborate(edit_tgts)

In [None]:
import torch
from run import apply_edits

# Apply the first round of sampled edits, then sample a second round
# (M=20, threshold=10) starting from the edited sequences.
input_ids_2 = apply_edits(input_ids, edit_tgts)

# Inference only: disable autograd so no graph or activations are retained
# (harmless if `sample` already does this internally).
with torch.no_grad():
    edit_tgts_2, tgt = sample(model, input_ids_2, None, M=20, threshold=10, resample_at=1e9)

In [None]:
# Print the second round of sampled edit targets.
print(elaborate(edit_tgts_2))

# Ad-hoc decode of a few hard-coded vocabulary ids (displayed as cell output).
probe_token_ids = [1045, 2031, 2464]
BT.decode(probe_token_ids)