In [None]:
# Auto-reload imported modules (e.g. evo, data, run, trans) before each cell executes,
# so edits to their source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# NOTE(review): star import hides which names come from evo (Evolver,
# PointerStyleEvolver, and presumably train_evolver used below); switch to
# explicit imports once the full set used by this notebook is known.
from evo import *

# Shared transformer hyperparameters, passed unchanged to both evolver variants.
params = dict(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    encoder_layers=3,
    decoder_layers=3,
    max_len=64,
)

evolver, ps_evolver = Evolver(**params), PointerStyleEvolver(**params)

In [None]:
from data import *
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Settings common to the training and evaluation loaders.
shared_loader_kwargs = dict(
    max_len=10,
    tokenizer=tokenizer,
    batch_size=4,
    sampler=StratifiedInfiniteSampler,
)

# Supervised loader over the UD training file; `limit=20` presumably caps the
# number of examples read, keeping this a quick smoke test.
train_loader = supervised_loader(
    path='data/ud/ud_train_3.0.jsonl',
    cache_prefix=None,
    all_tokens=True,
    limit=20,
    **shared_loader_kwargs,
)

# Unsupervised loader over the toy dataset, used for evaluation below.
eval_loader = unsupervised_loader(
    path='data/toy/toy.jsonl',
    **shared_loader_kwargs,
)

In [None]:
from run import apply_edits

# Pull one training batch and replay the edit targets for trajectory index 0.
first_batch = next(iter(train_loader))
traj_input_ids, _, traj_edit_tgts, _ = first_batch

# Select index 0 along axis 1 of every edit-target tensor.
step0_edit_tgts = tuple(tgt[:, 0] for tgt in traj_edit_tgts)
apply_edits(traj_input_ids[:, 0], step0_edit_tgts)

In [None]:
from torch.optim import AdamW

# Smoke-test settings: a single training step with the eval/checkpoint
# cadences below.
kwargs = dict(
    train_loader=train_loader,
    eval_loader=eval_loader,
    train_steps=1,
    eval_steps=2,
    grad_accum_steps=1,
    clip_gradients=False,
    checkpoint_at=20,
    eval_at=1,
)

# Train each evolver variant with identical settings and a fresh AdamW optimizer.
# NOTE(review): the third positional argument of train_evolver is passed as None
# (presumably an optional scheduler) — confirm against its signature.
for run_banner, run_model in (
    ('STARTING REGULAR EVOLVER', evolver),
    ('STARTING PS EVOLVER', ps_evolver),
):
    print(run_banner)
    train_evolver(run_model, AdamW(run_model.parameters(), lr=3e-4), None, **kwargs)

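## Multihead pointer

Smoke test: run `MultiheadPointer` on random target/memory tensors with a source padding mask and inspect the returned weights.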

In [None]:
import torch
from trans import MultiheadPointer

# Toy dimensions for the smoke test.
SEED = 0
BATCH, SRC_LEN, TGT_LEN = 3, 10, 5
D_MODEL, N_HEADS = 512, 8
N_VALID = 7  # leading source positions left unmasked in every row

# Seed so the random inputs (and the pointer's parameter init, assuming it draws
# from torch's global RNG) are reproducible across Restart & Run All.
torch.manual_seed(SEED)

pointer = MultiheadPointer(D_MODEL, N_HEADS)

mem = torch.randn(BATCH, SRC_LEN, D_MODEL)
tgt = torch.randn(BATCH, TGT_LEN, D_MODEL)

# Bool mask, True at source positions >= N_VALID (the padded tail) and False
# elsewhere. Building it by comparison guarantees a bool dtype.
src_pad_mask = torch.arange(SRC_LEN).expand(BATCH, SRC_LEN) >= N_VALID

idx_weights = pointer(tgt, mem, key_padding_mask=src_pad_mask)
idx_weights

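## Regressions

End-to-end `evo.py` runs on the toy configs. The cells below execute most of these commands (some with `--local`).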

```
python evo.py --config=configs/toy/sup-toy.json
python evo.py --config=configs/toy/sup-toy-epoch.json
python evo.py --config=configs/toy/ps-sup-toy.json
python evo.py --config=configs/toy/ps-unsup-toy.json
python evo.py --config=configs/toy/noshare-sup-toy-e1d1.json
python evo.py --config=configs/toy/den-toy.json
```

In [None]:
# Regression: run evo.py end to end on the sup-toy config (with --local).
!python evo.py --config=configs/toy/sup-toy.json --local

In [None]:
# Regression: run evo.py end to end on the sup-toy-epoch config (with --local).
!python evo.py --config=configs/toy/sup-toy-epoch.json --local

In [None]:
# Regression: run evo.py end to end on the ps-unsup-toy config.
!python evo.py --config=configs/toy/ps-unsup-toy.json

In [None]:
# Regression: run evo.py end to end on the ps-sup-toy config.
!python evo.py --config=configs/toy/ps-sup-toy.json

In [None]:
# Regression: run evo.py end to end on the den-toy config.
!python evo.py --config=configs/toy/den-toy.json

In [4]:
# Regression: run evo.py end to end on the ar-d-toy config (not in the markdown list above).
!python evo.py --config=configs/toy/ar-d-toy.json

INFO:train:using <class 'data.StratifiedInfiniteSampler'> in loader
INFO:train:using ar loaders
INFO:data:tokenizing sequence dataset...
INFO:data:done in 0.00 seconds!
tokenizing trajectories: 100%|██████████████████| 4/4 [00:00<00:00, 2262.60it/s]
INFO:train:starting run for 500 steps
INFO:train:eval every 20 steps
INFO:train:checkpoint every 500 steps
INFO:train:using <class '__main__.Transformer'>
INFO:train:starting new run
[34m[1mwandb[0m: Currently logged in as: [33mtjbai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.17.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/Users/bai/argo/projects/ddm/evolver/wandb/run-20240821_221147-j1dborig[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mar-d-toy_20240821_221146

In [None]:
from data import unsupervised_loader, StratifiedInfiniteSampler
from transformers import BertTokenizer

# Small toy loader (batch size 2) used below to inspect a single batch.
toy_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
toy_loader_kwargs = dict(
    max_len=10,
    tokenizer=toy_tokenizer,
    batch_size=2,
    sampler=StratifiedInfiniteSampler,
)
loader = unsupervised_loader('data/toy/toy.jsonl', **toy_loader_kwargs)

In [None]:
# Take a single batch from the toy loader. Unlike `for batch in loader: ... break`,
# `next(iter(...))` raises StopIteration if the loader yields nothing, instead of
# silently leaving `traj_input_ids` bound to the tensor from the earlier
# train-loader cell.
batch = next(iter(loader))
traj_input_ids = batch[0]
print(traj_input_ids.shape)

# Indices 0 and 1 along axis 1; presumably consecutive trajectory steps used as
# an input/output pair — confirm against the loader in data.py.
input_ids = traj_input_ids[:, 0]
output_ids = traj_input_ids[:, 1]

In [None]:
# Unpack the 2-D shape of the step-0 ids; presumably B = batch size and
# N = padded sequence length — TODO confirm against the loader.
B, N = input_ids.shape