# RobustFill Demo

In [1]:
from heapq import nlargest
import torch
import torch.nn as nn
from beam import beam_search
from tokens import Tokenizer
from train import full_config

## Download checkpoint and load it

In [2]:
# Build the tokenizer and the model described by the full training config.
tokenizer = Tokenizer.create()
config = full_config()

# The model is wrapped in DataParallel *before* the weights are loaded, so the
# state dict is applied to the wrapper rather than to the bare module.
model = nn.DataParallel(config.model)

# Get the checkpoint from: https://huggingface.co/eddyyeo/robustfill
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a source you trust.
cpu = torch.device('cpu')
checkpoint = torch.load('./checkpoint.pth', map_location=cpu)
state_dict = checkpoint['model_state_dict']

# Kept as the last expression so the key-matching report is shown as output.
model.load_state_dict(state_dict)

<All keys matched successfully>

## Architecture of the Model

In [3]:
# Bare expression: the notebook renders the repr of the wrapped model as output.
model

DataParallel(
  (module): RobustFill(
    (embedding): Embedding(75, 128)
    (input_encoder): LSTM(128, 512)
    (output_encoder): AttentionLSTM(
      (rnn): SingleAttention(
        (attention): LuongAttention(
          (linear): Linear(in_features=512, out_features=512, bias=True)
        )
        (lstm): LSTM(640, 512)
      )
    )
    (program_decoder): ProgramDecoder(
      (program_lstm): SingleAttention(
        (attention): LuongAttention(
          (linear): Linear(in_features=512, out_features=512, bias=True)
        )
        (lstm): LSTM(1050, 512)
      )
      (max_pool_linear): Linear(in_features=512, out_features=512, bias=True)
      (softmax_linear): Linear(in_features=512, out_features=538, bias=True)
    )
  )
)

## Example input-output pairs

In [4]:
# Input -> desired output for each demonstration; converted to a list of
# (input, output) tuples, which is what gets passed to beam_search as `strings`.
example_strings = list({
    'Jacob Devlin': 'Devlin, J.',
    'Eddy Yeo': 'Yeo, E.',
    'Andrej Karpathy': 'Karpathy, A.',
    'Anatoly Yakovenko': 'Yakovenko, A.',
}.items())

## Strings to transform with the generated program

In [5]:
# Names that do not appear in example_strings; each one is fed to the
# synthesized program's eval() in a later cell.
test_strings = ['Elon Musk', 'Joe Rogan', 'Balaji Srinivasan']

In [6]:
# Run beam search over candidate programs for the example pairs.
# `model.module` unwraps the DataParallel container so the search operates on
# the underlying RobustFill network directly.
#
# This is pure inference, so autograd is switched off to avoid recording a
# graph for every decoding step. The guard is harmless if beam_search already
# disables gradients internally (not visible from this notebook).
with torch.no_grad():
    topk = beam_search(
        model=model.module,
        tokenizer=tokenizer,
        width=100,
        max_program_length=64,
        strings=example_strings)

## Transformed strings

In [7]:
# Select the single highest-ranked candidate from the beam-search results.
# Only the top entry is ever used, so max() replaces ranking five candidates
# with nlargest and discarding four of them; like nlargest, it compares the
# entries themselves.
best_candidate = max(topk, default=None)
if best_candidate is None:
    # Fail with a clear message instead of an IndexError on an empty result.
    raise ValueError('beam search returned no candidate programs')

# The second field of the candidate is what the tokenizer parses into a program.
program = tokenizer.parse_program(best_candidate[1])

# Apply the synthesized program to each unseen string and show the result.
for ts in test_strings:
    print(f'{ts} --> {program.eval(ts)}')

Elon Musk --> Musk, E.
Joe Rogan --> Rogan, J.
Balaji Srinivasan --> Srinivasan, B.


## Generated program

In [8]:
# Bare expression: the notebook renders the repr of the parsed program as output.
program

Concat(
    Compose(
        Trim(),
        GetFrom(<Type.LOWER: 6>)
    ),
    ConstStr(','),
    ConstStr(' '),
    GetUpto(<Type.CHAR: 8>),
    ConstStr('.')
)