In [1]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
torch.set_printoptions(linewidth=10000)
tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
model = AutoModelForQuestionAnswering.from_pretrained('pgajo/mbert-xl-wa-tasteset-recipe-aligner').to('cuda:1')

create bilingual version

In [2]:
import json
path = '/home/pgajo/working/food/data/TASTEset/data/TASTEset_updated.json'
with open(path, encoding='utf8') as f:
    data = json.load(f)
print(data)

with open('/home/pgajo/working/food/data/TASTEset/data/TASTEset_raw.it', 'r', encoding='utf8') as f:
    italian_recipes = f.readlines()
bilingual_data = []
for i, entry in enumerate(data['annotations']):
    new_entry = {}
    new_entry['text_en'] = entry['text']
    new_entry['entities_en'] = entry['entities']
    new_entry['text_it'] = italian_recipes[i].strip()
    new_entry['entities_it'] = []
    bilingual_data.append(new_entry)
bilingual_dataset = {'classes': data['classes'], 'annotations': bilingual_data}
# save bilingual dataset to json
with open(path.replace('.json', '_en-it.json'), 'w', encoding='utf8') as f:
    json.dump(bilingual_dataset, f, ensure_ascii=False)

{'classes': ['FOOD', 'QUANTITY', 'UNIT', 'PROCESS', 'PHYSICAL_QUALITY', 'COLOR', 'TASTE', 'PURPOSE', 'PART'], 'annotations': [{'text': '5 ounces rum 4 ounces triple sec 3 ounces Tia Maria 20 ounces orange juice', 'entities': [[0, 1, 'QUANTITY'], [2, 8, 'UNIT'], [9, 12, 'FOOD'], [13, 14, 'QUANTITY'], [15, 21, 'UNIT'], [22, 32, 'FOOD'], [33, 34, 'QUANTITY'], [35, 41, 'UNIT'], [42, 51, 'FOOD'], [52, 54, 'QUANTITY'], [55, 61, 'UNIT'], [62, 74, 'FOOD']]}, {'text': '2 tubes cinnamon roll, refrigerated, with icing 4 tablespoons butter, melted 6 eggs 1/2 cup milk 2 teaspoons cinnamon 2 teaspoons vanilla 1 cup maple syrup', 'entities': [[0, 1, 'QUANTITY'], [2, 7, 'UNIT'], [8, 21, 'FOOD'], [23, 35, 'PROCESS'], [37, 41, 'FOOD'], [42, 47, 'FOOD'], [48, 49, 'QUANTITY'], [50, 61, 'UNIT'], [62, 68, 'FOOD'], [70, 76, 'PROCESS'], [77, 78, 'QUANTITY'], [79, 83, 'FOOD'], [84, 87, 'QUANTITY'], [88, 91, 'UNIT'], [92, 96, 'FOOD'], [97, 98, 'QUANTITY'], [99, 108, 'UNIT'], [109, 117, 'FOOD'], [118, 119, 'QUAN

predict

In [7]:
import json
from tqdm.auto import tqdm
with open('/home/pgajo/working/food/data/TASTEset/data/TASTEset_updated_en-it.json', encoding='utf8') as f:
    data = json.load(f)
# print(data['annotations'][0])

for entry in tqdm(data['annotations'], total=len(data['annotations'])):
    for i, entity in enumerate(entry['entities_en']):
        print('----------------------------------------------')
        # print(entry['text_en'][entity[0]:entity[1]])
        # print(entry['text_it'])
        input = tokenizer(entry['text_en'][entity[0]:entity[1]], entry['text_it'], return_tensors='pt').to('cuda:1')
        input_ids = input['input_ids'].squeeze()
        with torch.no_grad():
            outputs = model(**input)
        start_scores = outputs.start_logits
        end_scores = outputs.end_logits
        start_index_token = int(torch.argmax(start_scores))
        print('start_index_token', start_index_token)
        end_index_token = int(torch.argmax(end_scores))
        print('end_index_token', end_index_token)
        # print('len(input_ids)', len(input_ids))
        if start_index_token >= len(input_ids) - 1 or end_index_token >= len(input_ids) - 1:
            continue
        print('encoding:', input_ids)
        decoded_input = tokenizer.decode(input_ids)
        print('decoded:', decoded_input)
        for j, id in enumerate(input_ids):
            print(j, int(id), tokenizer.decode([id]), end='\t\t')
        print()
        print('prediction_tokens:', input['input_ids'].squeeze()[start_index_token:end_index_token])
        print('prediction:', tokenizer.decode(input['input_ids'].squeeze()[start_index_token:end_index_token]))
        print('gold:', [entry['text_en'][entity[0]:entity[1]]])
        char_span_start = input.token_to_chars(start_index_token)
        print('char_span_start', char_span_start)
        char_span_end = input.token_to_chars(end_index_token-1)
        print('char_span_prediction', entry['text_it'][char_span_start[0]:char_span_end[1]])
        print('char_span_end', char_span_end)
        char_span = (char_span_start[0], char_span_end[1])
        # print('char_span', char_span)
        # print('char_span[0]', char_span[0])
        # print('char_span[1]', char_span[1])
        if not char_span[0] > char_span[1]:
            entry['entities_it'].append([char_span[0], char_span[1], entry['entities_en'][i][2]])
            # print(entry['entities_it'])
        # else:
            # print('skipping')
print(data)

  0%|          | 0/700 [00:00<?, ?it/s]

----------------------------------------------
start_index_token 3
end_index_token 4
len(input_ids) 28
encoding: tensor([  101,   126,   102,   126, 14907, 10120, 52522,   125, 14907, 10120, 40159, 37913,   124, 14907, 10120, 29033, 10113, 11066, 10197, 14907, 10120, 10198, 20493,   172,   112, 13785, 13212,   102], device='cuda:1')
decoded: [CLS] 5 [SEP] 5 once di rum 4 once di triple sec 3 once di Tia Maria 20 once di succo d'arancia [SEP]
0 101 [CLS]		1 126 5		2 102 [SEP]		3 126 5		4 14907 once		5 10120 di		6 52522 rum		7 125 4		8 14907 once		9 10120 di		10 40159 triple		11 37913 sec		12 124 3		13 14907 once		14 10120 di		15 29033 Ti		16 10113 ##a		17 11066 Maria		18 10197 20		19 14907 once		20 10120 di		21 10198 su		22 20493 ##cco		23 172 d		24 112 '		25 13785 ara		26 13212 ##ncia		27 102 [SEP]		
prediction_tokens: tensor([126], device='cuda:1')
prediction: 5
gold: ['5']
char_span_start CharSpan(start=0, end=1)
char_span_prediction 5
char_span_end CharSpan(start=0, end=1)
---------

TypeError: 'NoneType' object is not subscriptable

In [None]:
# save aligned dataset to a new json with _aligned suffix 
with open('/home/pgajo/working/food/data/TASTEset/data/TASTEset_updated_en-it_QAaligned-mBERT-xl-wa.json', 'w', encoding='utf8') as f:
    json.dump(data, f, ensure_ascii=False)

In [None]:
for entry in data['annotations'][:10]:
    print(entry['text_en'])
    print(entry['text_it'])
    print(entry['entities_en'])
    print(entry['entities_it'])
    print()