In [1]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_sequence
import json
import xmltodict
import glob
        
class webNLG_DATASET(Dataset):
    """WebNLG entries read from XML files in the "challenge" layout.

    Indexing yields ``(category, triple_parts, text)`` where

    * ``category`` is the entry's ``category`` XML attribute,
    * ``triple_parts`` is a flat list of stripped strings obtained by
      splitting every ``mtriple`` of the entry on ``'|'``,
    * ``text`` is the text of the entry's first ``lex`` element.

    Note: an earlier, commented-out variant of the loader targeted the v2.0
    layout (``otriple`` for test files, a ``text`` key under ``lex``); only the
    challenge layout is implemented here.
    """

    def __init__(self, data_path):
        """Parse every file matched by ``data_path + '*'``.

        Args:
            data_path: Path prefix; typically a directory path ending with a
                separator. Matches that are directories are searched
                recursively, so a split organised as ``1triples/``,
                ``2triples/`` ... sub-folders loads as well.

        Raises:
            KeyError: If a file does not follow the expected
                ``benchmark/entries/entry`` structure or an entry lacks
                ``lex`` / ``modifiedtripleset``.
        """
        self.category_list = []
        self.modifiedtripleset_list = []
        self.text_list = []

        for xml_file in self._list_files(data_path):
            # WebNLG text contains non-ASCII characters; do not depend on the
            # platform's default encoding.
            with open(xml_file, 'r', encoding='utf-8') as f:
                xml_string = f.read()
            entries = xmltodict.parse(xml_string)['benchmark']['entries']['entry']
            # A file with a single entry is parsed as a dict, not a list.
            if not isinstance(entries, list):
                entries = [entries]

            # challenge version
            for entry in entries:
                self.category_list.append(entry['@category'])
                self.modifiedtripleset_list.append(entry['modifiedtripleset']['mtriple'])

                lex = entry['lex']
                if isinstance(lex, list):
                    # Several references: keep only the first one.
                    lex = lex[0]
                # xmltodict yields a plain string for an element without
                # attributes and a dict with '#text' otherwise.
                self.text_list.append(lex if isinstance(lex, str) else lex['#text'])

    @staticmethod
    def _list_files(data_path):
        """Return the files under ``data_path + '*'`` in a deterministic order."""
        import os  # function-scope import: the notebook's import cell does not provide it

        files = []
        # glob returns matches in arbitrary order; sort so sample indices are
        # reproducible between runs.
        for match in sorted(glob.glob(data_path + '*')):
            if os.path.isdir(match):
                for root, dirs, names in os.walk(match):
                    dirs.sort()  # in-place sort makes os.walk descend in order
                    files.extend(os.path.join(root, name) for name in sorted(names))
            else:
                files.append(match)
        return files

    def __len__(self):
        """Number of loaded entries."""
        return len(self.category_list)

    def __getitem__(self, idx):
        """Return ``(category, triple_parts, text)`` for entry ``idx``."""
        mtriples = self.modifiedtripleset_list[idx]
        # An entry with one triple stores a single string instead of a list.
        if not isinstance(mtriples, list):
            mtriples = [mtriples]

        triple = [part.strip() for mtriple in mtriples for part in mtriple.split('|')]

        return self.category_list[idx], triple, self.text_list[idx]

In [2]:
# Alternative input (v2.0 training split):
# data_path = '/data/private/dataset/webnlg/data/v2.0/en/train/'
data_path = '/data/private/WebNLG-models/chimera-master/data/WebNLG/raw/test/'

# One sample per batch, in file order, loaded by four worker processes.
webNLG_data = webNLG_DATASET(data_path)
dataloader = DataLoader(
    dataset=webNLG_data,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

In [8]:
# Build the generator on the GPU, restore the saved weights and switch to
# inference mode.
import torch  # explicit: `torch` is used below but was not imported by name in this notebook

# NOTE(review): star import kept because other cells may rely on names it
# exports; prefer `from model import webmodel` once that is confirmed.
from model import *

model_path = '/data/private/WebNLG-models/simple_model/pretrained/try_1/1'
my_model = webmodel().cuda()
my_model.load_state_dict(torch.load(model_path + '/model.bin'))
my_model.eval()
print('ok')

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at /data/private/GPT/openai-gpt2/base/ and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ok


In [4]:
# Display the tokenizer's special-token ids, their string forms, and the BOS token id.
my_model.tokenizer.all_special_ids, my_model.tokenizer.all_special_tokens, my_model.tokenizer.bos_token_id

([50256, 50259, 50256, 50257, 50258],
 ['<|endoftext|>', '<tr>', '<|endoftext|>', '<S>', '<c>'],
 50257)

In [3]:
# Fetch the first sample. The recorded run raised IndexError, meaning one of the
# dataset's lists had no element 0.
# NOTE(review): presumably no entries were loaded for the data_path in use — confirm.
webNLG_data[0]

IndexError: list index out of range

In [28]:
# Generate a response for the first two batches and print it next to the
# reference text. With batch_size=1 the collated batch wraps every string in a
# one-element container, hence cate[0] / text[0].
k = 0
for k, (cate, triple, text) in enumerate(dataloader, start=1):
    print('cate: ', cate[0])
    print('triple: ', triple)
    print('text: ', text[0])
    print('#####################################')

    # Empty target string: only category and triples are given to the model.
    input_tensor = my_model.make_tensor(cate, triple, '').squeeze(0)
    response = my_model.generate(input_tensor)

    print("Response text: ", response)
    print('')

    if k == 2:
        break

cate:  Airport
triple:  [('Abilene_Regional_Airport',), ('cityServed',), ('Abilene,_Texas',)]
text:  Abilene, Texas is served by the Abilene regional airport.
#####################################
Response text:  Abilene Regional Airport serves the city of Abilene, Texas. The airport is located in the city of Abilene, Texas and is located in the city of Abilene, Texas

cate:  Airport
triple:  [('Adolfo_Suárez_Madrid–Barajas_Airport',), ('location',), ('"Madrid, Paracuellos de Jarama, San Sebastián de los Reyes and Alcobendas"',)]
text:  Adolfo Suárez Madrid–Barajas Airport can be found in Madrid, Paracuellos de Jarama, San Sebastián de los Reyes and Alcobendas.
#####################################
Response text:  Adolfo Suárez Madrid–Barajas Airport is located in the city of Madrid, Paracuellos de Jarama, San Sebastián de los Reyes and Alcobendas



In [20]:
# Everything matched directly under data_path, in lexicographic order.
xml_folders = sorted(glob.glob(data_path + '*'))
xml_folders

['/data/private/dataset/webnlg/data/v2.0/en/test/1triples',
 '/data/private/dataset/webnlg/data/v2.0/en/test/2triples',
 '/data/private/dataset/webnlg/data/v2.0/en/test/3triples',
 '/data/private/dataset/webnlg/data/v2.0/en/test/4triples',
 '/data/private/dataset/webnlg/data/v2.0/en/test/5triples',
 '/data/private/dataset/webnlg/data/v2.0/en/test/6triples',
 '/data/private/dataset/webnlg/data/v2.0/en/test/7triples']

In [21]:
# List the contents of each folder found above.
# NOTE(review): xml_roots / xml_files are overwritten on every iteration, so
# after the loop they describe only the last folder — accumulate into a list
# if all files are needed.
for xml_folder in xml_folders:
    xml_roots = xml_folder+'/*'
    xml_files = glob.glob(xml_roots)

In [11]:
# Print every entry under the try_1 run directory. glob returns matches in
# arbitrary order and does not filter by type; the recorded output includes a
# 'runs' entry alongside the numbered folders.
model_pathes = '/data/private/WebNLG-models/simple_model/pretrained/try_1/*'
model_folders = glob.glob(model_pathes)

for model_folder in model_folders:
    print(model_folder)

/data/private/WebNLG-models/simple_model/pretrained/try_1/runs
/data/private/WebNLG-models/simple_model/pretrained/try_1/1
/data/private/WebNLG-models/simple_model/pretrained/try_1/2
/data/private/WebNLG-models/simple_model/pretrained/try_1/3
/data/private/WebNLG-models/simple_model/pretrained/try_1/4
/data/private/WebNLG-models/simple_model/pretrained/try_1/5
/data/private/WebNLG-models/simple_model/pretrained/try_1/6


In [12]:
# Number of batches the dataloader yields.
len(dataloader)

1862

## Check predictions

In [1]:
# Load predictions and references line by line (trailing newlines are kept).
# The context manager closes both files even if reading fails; the encoding is
# explicit because the texts contain non-ASCII characters.
with open('prediction/prediction_1.txt', encoding='utf-8') as f, \
        open('prediction/reference.txt', encoding='utf-8') as f2:
    texts = f.readlines()
    refs = f2.readlines()

In [9]:
# Length, in tokenizer tokens, of the longest reference line (0 when there are
# no references). Only the references are tokenized: the predictions were
# previously encoded as well, but that result was never used.
max_len = max(
    (len(my_model.tokenizer.encode(ref.strip())) for ref in refs),
    default=0,
)
print(max_len)

92


In [17]:
# Token ids the tokenizer produces for a single period.
my_model.tokenizer.encode('.')

[13]

In [19]:
# Display the model's END_idx_list.
# NOTE(review): presumably the token ids that terminate generation — confirm in model.py.
my_model.END_idx_list

[50258, 50256, 50257, 50256, 50259]

In [21]:
# EOS token id, and the string that id 50256 decodes to.
my_model.tokenizer.eos_token_id,my_model.tokenizer.decode(50256)

(50256, '<|endoftext|>')

## Prepare files for evaluation

In [12]:
# Read the reference lines (trailing newlines kept); the file is closed even
# if reading fails.
with open('prediction/reference.txt', encoding='utf-8') as f:
    texts = f.readlines()

In [13]:
# Write every line followed by an extra newline. Lines read with readlines()
# already end in '\n', so the output has a blank line after each reference.
with open('prediction/enter_reference.txt', 'w', encoding='utf-8') as f2:
    for text in texts:
        f2.write(text+'\n')

In [2]:
# Read the prediction lines (trailing newlines kept); the file is closed even
# if reading fails.
with open('prediction/prediction_1.txt', encoding='utf-8') as f:
    texts = f.readlines()

In [3]:
# Write a cleaned copy of the predictions: underscores become spaces and '@'
# characters are removed. Line endings come from the original lines.
with open('prediction/modify_prediction_1.txt', 'w', encoding='utf-8') as f2:
    for text in texts:
        f2.write(text.replace('_', ' ').replace('@',''))
f2.close()