In [1]:
import torch
from PlanRealizeGenerator import PlanRealizeGenerator
from WebNLGDatasetReader import Benchmark, select_test_files
from Utils import webnlg_entry_to_examples, validate_plan
from tqdm import tqdm

# Pick the compute device once: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

Prepare test data

In [2]:
# Read the selected WebNLG test files into a benchmark, then flatten it:
# each entry is expanded by webnlg_entry_to_examples and all resulting
# examples are collected into one list, in entry order.
files = select_test_files('webnlg_v3/en/test')
b = Benchmark()
b.fill_benchmark(files)

test_examples = [
    example
    for entry in b.entries
    for example in webnlg_entry_to_examples(entry)
]

In [3]:
# Display the first example to see which fields each test example carries.
test_examples[0]

{'category': 'SportsTeam',
 'eid': 'Id1',
 'size': '5',
 'triples_map': {0: '<S> Estádio_Municipal_Coaracy_da_Mata_Fonseca <P> location <O> Arapiraca',
  1: '<S> Agremiação_Sportiva_Arapiraquense <P> league <O> Campeonato_Brasileiro_Série_C',
  2: '<S> Campeonato_Brasileiro_Série_C <P> country <O> Brazil',
  3: '<S> Agremiação_Sportiva_Arapiraquense <P> nickname <O> "\'\'Alvinegro"',
  4: '<S> Agremiação_Sportiva_Arapiraquense <P> ground <O> Estádio_Municipal_Coaracy_da_Mata_Fonseca'},
 'input': '<S> Estádio_Municipal_Coaracy_da_Mata_Fonseca <P> location <O> Arapiraca <S> Agremiação_Sportiva_Arapiraquense <P> league <O> Campeonato_Brasileiro_Série_C <S> Campeonato_Brasileiro_Série_C <P> country <O> Brazil <S> Agremiação_Sportiva_Arapiraquense <P> nickname <O> "\'\'Alvinegro" <S> Agremiação_Sportiva_Arapiraquense <P> ground <O> Estádio_Municipal_Coaracy_da_Mata_Fonseca',
 'lid': 'Id1',
 'text': 'Estádio Municipal Coaracy da Mata Fonseca is the name of the ground of Agremiação Sportiva A

Load plan-realize model

In [3]:
# Build the plan-then-realize generator from its two checkpoints and put it
# in plan mode, since planning is the first stage that gets run.
model = PlanRealizeGenerator(
    planner_path='planner',
    realizer_path='royeis/t5-realizer',
)
model.plan_mode()

In [4]:
# Generate one plan per run of examples sharing an eid: only the first example
# of each consecutive run is sent to the planner, the rest are skipped.
# (Presumably examples with the same eid differ only in reference text -- confirm.)
outputs = []
prev_eid = -1  # sentinel that matches no real eid

for example in tqdm(test_examples):
    eid = example['eid']
    if eid != prev_eid:
        outputs.append({'eid': eid, 'plan': model.generate_plan(example)})
        prev_eid = eid

100%|██████████| 5150/5150 [04:07<00:00, 20.84it/s]


In [12]:
# Spot-check the generated plan stored at index 55.
outputs[55]

{'eid': 'Id56',
 'plan': '<sentence> <S> Nord_(Year_of_No_Light_album) <P> followedBy <O> Live_at_Roadburn_<S> Nord_(Year_of_No_Light_album) <P> releaseDate <O> 2006-09-06008_(Year_of_No_Light_album) <S> Nord_(Year_of_No_Light_album) <P> runtime <O> 58.<S> Nord_(Year_of_No_Light_album) <P> precededBy <O> Demo_2004_(Year_of_No_Light_album)1 <sentence> <S> Nord_(Year_of_No_Light_album) <P> precededBy <O> Demo_2004_(Year_of_No_Light_album) <S> Nord_(Year_of_No_Light_album) <P> artist <O> Year_of_No_Light <sentence> <S> Nord_(Year_of_No_Light_album) <P> releaseDate <O> 2006-09-06'}

In [None]:
# Switch the generator to realize mode, then attach a realization to every
# planned record in place under the 'generation' key.
model.realize_mode()
for record in tqdm(outputs):
    plan = record['plan']
    record['generation'] = model.generate_realization(plan)

In [15]:
# Measure how reliable the generator's plans are, one check per run of
# examples sharing an eid:
#   n_i -- entries inspected
#   n_f -- entries where model._validate_plan returned something different
#          from the raw plan (printed below as the 'fixed plan')
#   n_e -- entries whose plan, after model._validate_plan, is still rejected
#          by validate_plan
#
# The cell directly above switches the generator to realize mode. This loop
# calls the planner, so select plan mode explicitly instead of relying on
# whatever mode the kernel happens to be in; otherwise a top-to-bottom run
# would plan while in realize mode.
# NOTE(review): assumes plan_mode() is safe to call repeatedly -- confirm.
model.plan_mode()

n_i = 0
n_e = 0
n_f = 0
prev_eid = -1  # sentinel that matches no real eid
for e in tqdm(test_examples):
    eid = e['eid']
    if prev_eid == eid:
        # Same entry as the previous example; already checked.
        continue

    n_i += 1
    prev_eid = eid
    size = int(e['size'])  # 'size' is stored as a string in the examples
    plan = model._plan(e['input'])
    v_plan = model._validate_plan(plan, size)
    if v_plan != plan:
        print(f'plan: {plan}')
        print(f'fixed plan: {v_plan}')
        n_f += 1

    if not validate_plan(v_plan, size):
        print(plan)
        print(v_plan, size)
        n_e += 1

  3%|▎         | 164/5150 [00:08<04:56, 16.83it/s]

In [11]:
# Show n_i, the entry counter incremented by the preceding validation loop.
n_i

1779

In [15]:
from transformers import T5ForConditionalGeneration, T5Tokenizer

In [21]:
# Load the standalone planner checkpoint and move it to DEVICE. The tokenizer
# is the stock 't5-base' one rather than one shipped with the planner.
planner = T5ForConditionalGeneration.from_pretrained('royeis/T5-FlowNLG-Planner')
planner = planner.to(DEVICE)
tokenizer = T5Tokenizer.from_pretrained('t5-base')
# NOTE(review): the <S>/<P>/<O> markers are deliberately not registered as
# extra tokens here -- confirm this matches how the planner was trained.
# tokenizer.add_tokens(['<S>', '<P>', '<O>'])

In [22]:
# Validate plans produced by the standalone planner checkpoint, one check per
# run of examples sharing an eid:
#   n_i -- entries inspected
#   n_e -- entries whose decoded plan is rejected by validate_plan

# generate() stops at a default max_length of 20 tokens unless the checkpoint's
# generation config overrides it; a plan covering several triples is longer
# than that and would be cut off, then counted as invalid. Give it explicit room.
MAX_PLAN_TOKENS = 512

n_i = 0
n_e = 0
prev_eid = -1  # sentinel that matches no real eid
for e in tqdm(test_examples):
    eid = e['eid']
    if prev_eid == eid:
        # Same entry as the previous example; already checked.
        continue

    n_i += 1
    prev_eid = eid
    size = int(e['size'])  # 'size' is stored as a string in the examples
    # NOTE(review): the example displayed earlier shows an 'input' key; its
    # printout is truncated, so confirm 'planner_input' is actually present.
    encoded = tokenizer(e['planner_input'], return_tensors='pt').to(DEVICE)
    out = planner.generate(**encoded, max_length=MAX_PLAN_TOKENS)
    plan = tokenizer.decode(out[0], skip_special_tokens=True)
    if not validate_plan(plan, size):
        n_e += 1

100%|██████████| 4928/4928 [03:59<00:00, 20.62it/s]


In [24]:
# Show n_i, the entry counter left by the preceding standalone-planner loop.
n_i

1862