In [1]:
# Reload the project modules imported below (data, model, run, train, ...) before
# each cell executes, so edits to those .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from transformers import BertTokenizer
# Uncased BERT WordPiece tokenizer, used by the alignment, dataset and inspection cells below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Alignment

In [None]:
from simalign import SentenceAligner
# Word aligner over BERT subword ('bpe') embeddings. Per the simalign README,
# matching_methods='m' selects the max-weight-matching ("mwmf") method only.
aligner = SentenceAligner(model='bert', token_type='bpe', matching_methods='m')

In [None]:
from data import generate_edits, to_str

# Example sentence pair: b reorders a's content and adds words.
a = 'Jo was great at the Best Western.'
b = 'At the Best Western, the hotel manager Jo was really good.'

# Tokenized source sentence; to_str receives it alongside each edit when formatting.
a_toks = tokenizer.tokenize(a)
# Print every edit generate_edits yields for a -> b (presumably derived from the
# aligner's word alignment — TODO confirm in data.generate_edits).
for edit in generate_edits(a, b, tokenizer, aligner):
    print(to_str(*edit, a_toks, tokenizer))

# Toy Dataset

## Forced Training

In [None]:
from data import EvolverDataset
from torch.utils.data import DataLoader

# One toy trajectory per lowercase letter: the letter repeated 1, 2, 4 and 8 times,
# e.g. ['a', 'a a', 'a a a a', 'a a a a a a a a'].
trajectory_list = [
    [' '.join([letter] * (2 ** i)) for i in range(4)]
    for letter in 'abcdefghijklmnopqrstuvwxyz'
]

dataset = EvolverDataset(trajectory_list, max_len=10, force_targets=True, name='toy')
loader = DataLoader(dataset, batch_size=4)

In [None]:
from data import elaborate
# Pull the first batch from the toy loader and pass its edit targets to elaborate
# (presumably a readable rendering of the targets — TODO confirm in data.elaborate).
input_ids, traj_edit_tgts = next(iter(loader))
elaborate(traj_edit_tgts)

In [None]:
from run import train_forced
from model import Evolver
from torch.optim import AdamW

# Evolver sized for the toy dataset above; max_len=10 matches the dataset's max_len.
# NOTE(review): include_sub=False presumably disables substitution edits — TODO confirm.
evolver = Evolver(d_model=512, max_len=10, include_sub=False)
optim = AdamW(evolver.parameters(), lr=1e-3)

# NOTE(review): the positional arguments (10, 10, None, 'test') are opaque at the call
# site — presumably epoch/step counts, an optional component and a run name. TODO
# confirm against run.train_forced and pass them as keywords.
train_forced(evolver, optim, loader, 10, 10, None, 'test')

# Particle Filtering

In [None]:
from model import Evolver

evolver = Evolver(nhead=8)

# Tokenize the source and target sentences, padded to length 10, as 1-D id tensors.
s1, s2 = (
    tokenizer(text, return_tensors='pt', max_length=10, padding='max_length')['input_ids'].squeeze()
    for text in ('a b c d', 'b c d a')
)

# Encoder inputs for the source; only the padding mask is needed for the target.
src, src_pad_mask = evolver.get_src(s1)
_, tgt_pad_mask = evolver.get_src(s2)

In [None]:
from run import particle_filter

# Switch the model to inference mode (disables dropout) before sampling.
evolver.eval()

# Run the particle filter between s1 and s2; only the first return value is kept.
# NOTE(review): the positional arguments (5, 2, 1.0) presumably are particle count,
# a resampling threshold and a temperature — TODO confirm against run.particle_filter.
res, *_ = particle_filter(
    evolver, s1, s2,
    src, src_pad_mask, tgt_pad_mask,
    5, 2, 1.0, device='cpu'
)

In [None]:
from data import elaborate
# Render the first value returned by particle_filter in the previous cell.
elaborate(res)

# MCEM

In [None]:
from data import TrainLoader

# Training trajectories for a-d: each letter repeated 1, 2, 4 and 8 times.
traj_list = [[' '.join([c] * (2 ** i)) for i in range(4)] for c in 'abcd']

train_loader = TrainLoader(traj_list, bsz=1, max_len=10, tokenizer=tokenizer).to('cpu')

In [None]:
from data import EvalLoader, get_input_ids

# Evaluation targets: each of w, x, y, z repeated 8 times, e.g. 'w w w w w w w w'.
traj_list = [' '.join([c] * 8) for c in 'wxyz']

eval_loader = EvalLoader(traj_list, num_samples=3, max_len=10, tokenizer=tokenizer).to('cpu')

In [None]:
from model import Evolver
from train import train_evolver
from torch.optim import AdamW

# Fresh 3-encoder/3-decoder-layer Evolver on CPU for the MCEM toy run.
evolver = Evolver(encoder_layers=3, decoder_layers=3, device='cpu')
optim = AdamW(evolver.parameters(), lr=1e-4)

In [None]:
# NOTE(review): this call uses a different train_evolver signature (epochs=...,
# prefix=..., loaders as 3rd/4th positional args) than the later cells in this
# notebook (an extra positional None, train_steps=..., eval_steps=...,
# grad_accum_steps=...). At most one can match train.train_evolver — TODO confirm
# which and update the other.
train_evolver(
    evolver, optim, train_loader, eval_loader,
    epochs=25, checkpoint_at=50, eval_at=1,
    num_particles=10, threshold=0, temperature=1,
    prefix='test-3-larger', device='cpu'
)

In [None]:
import torch
from run import sample_trajectory

# Compare a sample from the trained evolver with one from an untrained,
# randomly initialised 1+1-layer baseline on the same evaluation trajectory.
evolver.eval()

rand = Evolver(encoder_layers=1, decoder_layers=1)
rand.eval()

traj_input_ids = eval_loader.traj_input_ids[0]

# Print the second element of sample_trajectory's result for each model, trained first.
for sampler in (evolver, rand):
    print(sample_trajectory(
        sampler, traj_input_ids,
        1, 0, 1, device='cpu'
    )[1])

## Seq2Seq Example

In [None]:
from torch.utils.data import DataLoader
from data import Seq2SeqDataset, StratifiedInfiniteSampler
from transformers import BertTokenizer  # not used in this cell; `tokenizer` comes from an earlier cell

# Toy next-word task: each output is its input extended by one word.
dataset = Seq2SeqDataset(
    inputs=['hello', 'hello my', 'hello my name', 'hello my name is'],
    outputs=['hello my', 'hello my name', 'hello my name is', 'hello my name is TJ'],
    max_len=10,
    tokenizer=tokenizer
)

# Training draws via StratifiedInfiniteSampler (the second argument presumably
# matches batch_size=4 — TODO confirm); evaluation iterates the dataset in shuffled order.
train_loader = DataLoader(dataset, batch_size=4, sampler=StratifiedInfiniteSampler(dataset, 4))
eval_loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [None]:
from model import Transformer
from torch.optim import AdamW
from constants import VOCAB_SIZE

# Encoder-decoder Transformer for the Seq2Seq toy task (trained by train_ar in the
# next cell); max_len=10 matches the Seq2SeqDataset above.
model = Transformer(
    d_model=512,
    nhead=2,
    max_len=10,
    dropout=0.1,
    vocab_size=VOCAB_SIZE,
    encoder_layers=2,
    decoder_layers=2
)

optim = AdamW(model.parameters(), lr=1e-4)

In [None]:
from train import train_ar

# NOTE(review): everything after the loaders is positional — (100, 1, 2000, 20) are
# presumably step/eval/checkpoint settings, followed by device and run name. TODO
# confirm against train.train_ar and pass them as keywords.
train_ar(
    model, optim, None,
    train_loader, eval_loader,
    100, 1, 2000, 20,
    'cpu', 'toy'
)

In [None]:
# Run the training script as a subprocess on CPU with the configs/ud-2.0.0.json config.
!python train.py --config=configs/ud-2.0.0.json --device=cpu

## Batched Particle Filter

In [None]:
from data import TrajectoryDataset, StratifiedInfiniteSampler, collate_unsupervised
from torch.utils.data import DataLoader

# Toy trajectories that start from the empty string and double in length,
# e.g. ['', 'a', 'a a', 'a a a a', 'a a a a a a a a'] for each of a-d.
traj_list = [[''] + [' '.join([c] * (2 ** i)) for i in range(4)] for c in 'abcd']

dataset = TrajectoryDataset(
    traj_list=traj_list,
    # One log-prob per trajectory, derived from traj_list so the two lists always
    # have the same length (a hard-coded range(2) covered only 2 of the 4 trajectories).
    log_probs=[0 for _ in traj_list],
    max_len=10, tokenizer=tokenizer
)

# Unsupervised batches of 2, drawn through the stratified sampler.
train_loader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=collate_unsupervised,
    sampler=StratifiedInfiniteSampler(dataset, 2)
)

In [None]:
import json
from pathlib import Path

# Persist the toy trajectories as JSON Lines: one [trajectory, 0] pair per line.
toy_path = Path('data/toy/toy.jsonl')
# Create data/toy/ on a fresh checkout instead of failing with FileNotFoundError.
toy_path.parent.mkdir(parents=True, exist_ok=True)

with toy_path.open('w', encoding='utf-8') as f:
    for trajectory in traj_list:
        json.dump((trajectory, 0), f)
        f.write('\n')

In [None]:
from model import Evolver
from torch.optim import AdamW

# Small Evolver (d_model=128, 4+4 layers, dropout disabled) for the batched
# particle-filter toy run; max_len=10 matches the TrajectoryDataset above.
evolver = Evolver(
    d_model=128,
    nhead=2,
    max_len=10,
    encoder_layers=4,
    decoder_layers=4,
    dropout=0,
    dim_feedforward=512
)

optim = AdamW(evolver.parameters(), lr=1e-3)

In [None]:
from train import train_evolver

# Train on the toy trajectories, reusing the training loader as the evaluation loader.
# NOTE(review): this cell passes prefix='test', while the otherwise identical call in
# the supervised section passes name='test'. Only one keyword can match
# train.train_evolver's signature — TODO confirm which.
train_evolver(
    evolver, optim, None, train_loader, train_loader,
    train_steps=100, eval_steps=1, grad_accum_steps=1, checkpoint_at=200, eval_at=10,
    num_particles=5, threshold=3, temperature=1, resample_at=1,
    device='cpu', prefix='test'
)

## Supervised (Best-of-1) Training

In [None]:
from model import Evolver
from torch.optim import AdamW

# Re-create a fresh Evolver and optimizer (same hyperparameters as the batched
# particle-filter section), so supervised training starts from new weights
# instead of continuing from the previous run.
evolver = Evolver(
    d_model=128,
    nhead=2,
    max_len=10,
    encoder_layers=4,
    decoder_layers=4,
    dropout=0,
    dim_feedforward=512
)

optim = AdamW(evolver.parameters(), lr=1e-3)

In [None]:
from data import TrajectoryDataset, SupervisedTrajectoryDataset, StratifiedInfiniteSampler
from torch.utils.data import DataLoader

# Two toy trajectories of different lengths back both datasets: the supervised
# dataset for training and a plain TrajectoryDataset for evaluation.
# NOTE(review): the trajectory literal and log_probs are duplicated between the two
# datasets; keep them in sync (log_probs presumably needs one entry per trajectory).
dataset = SupervisedTrajectoryDataset(
    traj_list=[['', 'a', 'a a a'], ['', 'b', 'b b', 'b b b b b', 'b b b b b b b b']],
    log_probs=[0 for _ in range(2)],
    max_len=10, tokenizer=tokenizer
)

eval_dataset = TrajectoryDataset(
    traj_list=[['', 'a', 'a a a'], ['', 'b', 'b b', 'b b b b b', 'b b b b b b b b']],
    log_probs=[0 for _ in range(2)],
    max_len=10, tokenizer=tokenizer
)

In [None]:
from data import collate_supervised, collate_unsupervised

# Supervised training batches of 2, drawn through the stratified sampler.
train_loader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=collate_supervised,
    sampler=StratifiedInfiniteSampler(dataset, 2)
)

# Evaluation uses a plain sequential DataLoader: one pass over eval_dataset,
# one trajectory per batch, with the unsupervised collate function.
eval_loader = DataLoader(
    eval_dataset,
    batch_size=1,
    collate_fn=collate_unsupervised,
)

In [None]:
from train import train_evolver

# Supervised training, evaluated on the separate unsupervised eval_loader.
# NOTE(review): this cell passes name='test', while the otherwise identical call in
# the batched particle-filter section passes prefix='test'. Only one keyword can match
# train.train_evolver's signature — TODO confirm which.
train_evolver(
    evolver, optim, None, train_loader, eval_loader,
    train_steps=100, eval_steps=1, grad_accum_steps=1, checkpoint_at=200, eval_at=10,
    num_particles=5, threshold=3, temperature=1, resample_at=1,
    device='cpu', name='test'
)

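## Inspection

Load the `checkpoints/sup-ud-3.0.pt` checkpoint and sample edits for a few prefixes of "hello my name is".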

In [None]:
import torch
from model import Evolver
from data import get_input_ids

# Sequence length used for both the model and the tokenized inputs, so the two stay in sync.
MAX_LEN = 64
CHECKPOINT_PATH = 'checkpoints/sup-ud-3.0.pt'

model = Evolver(max_len=MAX_LEN)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted sources.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu')['model'])
model.eval()

inputs = ['hello', 'hello my', 'hello my name is']
input_ids = get_input_ids(inputs, MAX_LEN, tokenizer)

src, src_pad_mask = model.get_src(input_ids)

In [None]:
from run import fast_sample

# Sample edit targets (and their log-probs) from the loaded checkpoint for the inputs above.
# NOTE(review): the positional arguments (3, 0, 100) are opaque at the call site —
# TODO confirm their meaning against run.fast_sample and pass them as keywords.
edit_tgts, log_probs = fast_sample(model, input_ids, src, src_pad_mask, 3, 0, 100)

In [None]:
from data import elaborate

# Render the edit targets sampled by fast_sample in the previous cell.
elaborate(edit_tgts)

In [None]:
import matplotlib.pyplot as plt
import torch

# NOTE(review): `probs` and `sup_probs` are not defined anywhere in this notebook,
# so this cell only runs with leftover kernel state from an earlier session.
# Presumably they hold per-step log-probabilities from an unsupervised and a
# supervised model — TODO define them explicitly in a cell above this one.
vals = torch.exp(probs[1][2].squeeze())
sup_vals = torch.exp(sup_probs[1][2].squeeze())

# Top-5 values and their token indices for each distribution, kept for inspection.
val, idxs = torch.topk(vals, k=5)
sup_val, sup_idxs = torch.topk(sup_vals, k=5)

fig, axs = plt.subplots(2)

# Plot each distribution against token index. The x range comes from the tensor
# itself (VOCAB_SIZE for a full-vocabulary distribution), so the cell no longer
# depends on VOCAB_SIZE having been imported by an unrelated earlier cell.
for ax, dist, title in ((axs[0], vals, 'probs[1][2]'), (axs[1], sup_vals, 'sup_probs[1][2]')):
    ax.plot(torch.arange(len(dist)), dist.detach())
    ax.set_title(title)
    ax.set_xlabel('token id')

fig.tight_layout()

## Playground