In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up on the next cell execution without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from evo import *

# Constructor keyword arguments shared by both model variants built below.
params = dict(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    encoder_layers=3,
    decoder_layers=3,
    max_len=64,
)

# Both models are constructed from the exact same keyword arguments.
evolver = Evolver(**params)
ps_evolver = PointerStyleEvolver(**params)

In [None]:
from data import *
from transformers import BertTokenizer

# Pretrained 'bert-base-uncased' tokenizer; handed to both loaders below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Keyword arguments that the train and eval loaders have in common
# (same file, same max_len, same batch size, same sampler class).
_shared_loader_args = dict(
    path='data/toy/toy.jsonl',
    max_len=10,
    tokenizer=tokenizer,
    batch_size=4,
    sampler=StratifiedInfiniteSampler,
)

# The supervised loader additionally gets cache_prefix, all_tokens and limit.
train_loader = supervised_loader(
    cache_prefix=None,
    all_tokens=True,
    limit=100,
    **_shared_loader_args,
)

# The unsupervised loader takes only the shared arguments.
eval_loader = unsupervised_loader(**_shared_loader_args)

In [None]:
from torch.optim import AdamW

# Settings forwarded as keyword arguments to every train_evolver call below.
kwargs = dict(
    train_loader=train_loader,
    eval_loader=eval_loader,
    train_steps=1,
    eval_steps=2,
    grad_accum_steps=1,
    clip_gradients=False,
    checkpoint_at=20,
    eval_at=1,
)

# Train each model in turn: print its banner, then call train_evolver with a
# fresh AdamW optimizer (lr=3e-4) over that model's parameters. The third
# positional argument is None; its meaning is defined by train_evolver, which
# is not visible in this notebook.
for _banner, _model in (
    ('STARTING REGULAR EVOLVER', evolver),
    ('STARTING PS EVOLVER', ps_evolver),
):
    print(_banner)
    train_evolver(_model, AdamW(_model.parameters(), lr=3e-4), None, **kwargs)

## multihead pointer

In [None]:
import torch
from trans import MultiheadPointer

# Pointer module built with positional arguments (512, 8); the trailing
# dimension of the random inputs below is also 512.
pointer = MultiheadPointer(512, 8)

# Random inputs, drawn in the same order as before: mem is (3, 10, 512),
# tgt is (3, 5, 512).
mem = torch.randn(3, 10, 512)
tgt = torch.randn(3, 5, 512)

# Boolean (3, 10) mask: columns 0-6 are False, columns 7-9 are True.
# NOTE(review): presumably True marks padded memory positions -- confirm
# against MultiheadPointer.
src_pad_mask = torch.zeros(3, 10, dtype=torch.bool)
src_pad_mask[:, 7:] = True

idx_weights = pointer(tgt, mem, key_padding_mask=src_pad_mask)
idx_weights

## regressions

```
python evo.py --config=configs/toy/sup-toy.json
python evo.py --config=configs/toy/sup-toy-epoch.json
python evo.py --config=configs/toy/ps-sup-toy.json
python evo.py --config=configs/toy/noshare-sup-toy-e1d1.json
```

In [None]:
# Run evo.py in a subshell with the sup-toy config.
# NOTE(review): the effect of --local is defined in evo.py, not visible here.
!python evo.py --config=configs/toy/sup-toy.json --local

In [None]:
# Run evo.py in a subshell with the sup-toy-epoch config.
# NOTE(review): the effect of --local is defined in evo.py, not visible here.
!python evo.py --config=configs/toy/sup-toy-epoch.json --local

In [None]:
# Run evo.py in a subshell with the ps-sup-toy config.
# NOTE(review): the effect of --local is defined in evo.py, not visible here.
!python evo.py --config=configs/toy/ps-sup-toy.json --local