In [8]:
import yaml, codecs
import json

import numpy as np

from collections import OrderedDict

In [55]:
from datasets import load_from_disk

In [66]:
from transformers import (set_seed,
                          AutoTokenizer
                          )

In [68]:
# Directory (or hub id) of the fine-tuned checkpoint whose tokenizer is reused
# below to turn input ids back into sub-word pieces.
model_name_or_path = 'CASE22_subtask4_xlm-roberta-base'

print('Loading tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path,
    use_fast=True,
)

# Reuse the end-of-sequence token as the padding token ...
tokenizer.pad_token = tokenizer.eos_token
# ... and pad on the left-hand side of each sequence.
tokenizer.padding_side = "left"

Loading tokenizer...


In [151]:
# Show the sub-word pieces of one test example as a sanity check.
# The dataset is loaded here explicitly so this cell also works on a fresh
# kernel; it previously relied on a `test_data` variable that is only created
# by the decoding loop further down the notebook.
# NOTE(review): the recorded output below is Spanish text, so the Spanish test
# set is loaded — confirm this is the split that was meant to be inspected.
sample_dataset = load_from_disk('CASE22_hf/test_spanish.hf')
tokenizer.convert_ids_to_tokens(sample_dataset['test'][5]['input_ids'])

['<s>',
 '▁SAM',
 'PLE',
 '_',
 'STAR',
 'T',
 '▁Cuando',
 '▁ya',
 '▁se',
 '▁había',
 '▁apa',
 'gado',
 '▁el',
 '▁fuego',
 '▁que',
 '▁devo',
 'ró',
 '▁la',
 '▁formación',
 '▁más',
 '▁moderna',
 '▁de',
 '▁la',
 '▁ex',
 '▁línea',
 '▁Sar',
 'miento',
 '▁',
 ',',
 '▁A',
 'ní',
 'bal',
 '▁Fernández',
 '▁en',
 'cendi',
 'ó',
 '▁otra',
 '▁ho',
 'guera',
 '▁que',
 '▁continúa',
 '▁prend',
 'ida',
 '▁',
 '.',
 '▁[',
 'S',
 'EP',
 ']',
 '▁Para',
 '▁explicar',
 '▁los',
 '▁incendi',
 'os',
 '▁en',
 '▁los',
 '▁tren',
 'es',
 '▁de',
 '▁la',
 '▁empresa',
 '▁T',
 'BA',
 '▁',
 ',',
 '▁el',
 '▁ministro',
 '▁de',
 '▁Seguridad',
 '▁y',
 '▁Justicia',
 '▁lan',
 'zó',
 '▁una',
 '▁acusa',
 'ción',
 '▁muy',
 '▁dura',
 '▁contra',
 '▁el',
 '▁Partido',
 '▁Obr',
 'ero',
 '▁',
 ',',
 '▁el',
 '▁M',
 'ST',
 '▁',
 ',',
 '▁Que',
 'bra',
 'cho',
 '▁y',
 '▁Proyecto',
 '▁Sur',
 '▁de',
 '▁Pino',
 '▁Sola',
 'nas',
 '▁',
 '.',
 '▁[',
 'S',
 'EP',
 ']',
 '▁La',
 '▁acusa',
 'ción',
 '▁por',
 '▁la',
 '▁presun',
 'ta',
 '▁respons

In [9]:
# Load the label vocabulary. The decoding code below reads its
# 'ids_to_labels' mapping to turn predicted class ids into tag strings.
# The encoding is given explicitly so the result does not depend on the
# platform's default text encoding.
with open('CASE22_combined_labels.json', 'r', encoding='utf-8') as fp:
    label_dict = json.load(fp)

In [10]:
# Display the id -> tag mapping. The keys are strings (JSON object keys), so
# predicted integer ids have to be converted with str() before the lookup.
label_dict['ids_to_labels']

{'0': 'O',
 '1': 'B-trigger',
 '2': 'B-target',
 '3': 'I-target',
 '4': 'B-place',
 '5': 'I-place',
 '6': 'B-etime',
 '7': 'I-etime',
 '8': 'B-fname',
 '9': 'I-fname',
 '10': 'B-participant',
 '11': 'I-trigger',
 '12': 'I-participant',
 '13': 'B-organizer',
 '14': 'I-organizer'}

In [148]:
# Languages to decode; each name is substituted into the prediction, dataset,
# word-list and submission file names used by the decoding loop.
# NOTE(review): the recorded output of the decoding cell lists english,
# portuguese and spanish, but only portuguese is selected here — confirm which
# set is intended before re-running.
langs = ['portuguese']  # full set: ['english', 'portuguese', 'spanish']

In [153]:
# SentencePiece marks the first piece of every whitespace-separated word with
# U+2581 (the leading character of pieces such as '▁Cuando' in the token dump
# shown earlier in this notebook).
WORD_START_MARKER = '\u2581'
# First piece of the artificial 'SAMPLE_START' word that opens a document.
SAMPLE_START_PIECE = WORD_START_MARKER + 'SAM'


def _read_conll_sentences(path):
    """Read a one-word-per-line file into a list of sentences.

    Sentences are separated by whitespace-only lines. Each returned sentence
    is a list of the raw lines with the newline character removed.

    A whitespace-only line at the very end of the file does not create an
    extra empty sentence, so the number of sentences stays aligned with the
    number of dataset rows.
    """
    sentences = []
    current = []
    # Explicit UTF-8: the word files contain accented characters and must not
    # depend on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as word_file:
        for line in word_file:
            if line.isspace():
                sentences.append(current)
                current = []
                continue
            current.append(line.replace('\n', ''))
    if current:
        sentences.append(current)
    return sentences


def _word_labels(words, tokens, label_ids, id2label):
    """Pick one label per word from sub-word level predictions.

    Parameters
    ----------
    words : list of str
        The words of one sentence, in order.
    tokens : list of str
        Sub-word pieces of the same sentence, including special tokens.
    label_ids : 1-D integer array
        Predicted class id per position of the (possibly padded) model input.
    id2label : dict
        Maps the string form of a class id to its tag; key '0' is used as the
        fallback tag.

    Returns
    -------
    list of str
        One tag per entry of ``words``. A word receives the prediction made
        for its first sub-word piece; words without a matching piece (for
        example after truncation) receive the fallback tag.
    """
    fallback = id2label['0']
    # Keep only the trailing len(tokens) predictions.
    # NOTE(review): this assumes the prediction rows were left padded, matching
    # tokenizer.padding_side = "left" — confirm against the inference script.
    # max(..., 0) keeps the whole row when there are more tokens than
    # predictions and yields an empty slice when there are no tokens at all.
    label_ids = label_ids[max(len(label_ids) - len(tokens), 0):]

    labels = []
    pos = 0
    for _ in words:
        # Advance to the next piece that starts a new word; special tokens and
        # continuation pieces are skipped.
        while pos < len(tokens) and not tokens[pos].startswith(WORD_START_MARKER):
            pos += 1
        if pos == 1 and pos < len(tokens) and tokens[pos] == SAMPLE_START_PIECE:
            # The leading SAMPLE_START marker is always tagged with the
            # fallback label, regardless of the prediction.
            labels.append(fallback)
        elif pos < len(label_ids):
            labels.append(id2label[str(int(label_ids[pos]))])
        else:
            labels.append(fallback)
        pos += 1
    return labels


for lang in langs:
    print('decoding... {}'.format(lang))
    # SECURITY: yaml.UnsafeLoader can construct arbitrary Python objects while
    # parsing. Only load prediction files you produced yourself.
    # NOTE(review): presumably needed because the predictions were dumped as
    # non-plain objects — confirm, and switch to yaml.safe_load if possible.
    with codecs.open('CASE22_test_{}_pred (2).yaml'.format(lang), 'r', encoding='utf-8') as fp:
        data = yaml.load(fp, Loader=yaml.UnsafeLoader)

    # Collapse the last axis (class scores) into one class id per position.
    predictions = np.argmax(data['predictions'], axis=2)

    test_data = load_from_disk('CASE22_hf/test_{}.hf'.format(lang))
    test_split = test_data['test']

    sentences = _read_conll_sentences('{}.txt'.format(lang))

    # Fail before the submission file is created instead of crashing halfway
    # through writing it.
    available = min(len(predictions), len(test_split))
    if len(sentences) > available:
        raise ValueError(
            '{}: {} sentences in the word file but only {} predicted examples'.format(
                lang, len(sentences), available))

    with open('xlm-roberta-base-{}-submission.txt'.format(lang), 'w', encoding='utf-8') as fp:
        for i, sent in enumerate(sentences):
            tokenized_ts = tokenizer.convert_ids_to_tokens(test_split[i]['input_ids'])
            sent_labels = _word_labels(sent, tokenized_ts, predictions[i],
                                       label_dict['ids_to_labels'])
            for word, label in zip(sent, sent_labels):
                fp.write(word + '\t' + label + '\n')
            # Blank line between sentences, none after the last one.
            if i < len(sentences) - 1:
                fp.write('\n')

decoding... english
decoding... portuguese
decoding... spanish
