In [1]:
import json
import time
import pprint
import random
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    T5Config
)
import MeCab
import unidic
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Random-seed setup
def set_seed(seed):
    """Seed every RNG used in this notebook: Python, NumPy and PyTorch (plus all CUDA devices when a GPU is present)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(42)
# Fall back to the CPU when no GPU is visible instead of failing later at `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [3]:
def replace_t5tok_2_sptok(t5_text:str) -> str:
    """Convert decoded T5 text back to the arrow-symbol form.

    Strips ``<pad>`` and ``</s>`` and maps the sentinels ``<extra_id_99>``,
    ``<extra_id_98>`` and ``<extra_id_97>`` to ``→``, ``↑`` and ``←``.
    Substitutions are applied one after another in the listed order.
    """
    substitutions = (
        ('<pad>', ''),
        ('</s>', ''),
        ('<extra_id_99>', '→'),
        ('<extra_id_98>', '↑'),
        ('<extra_id_97>', '←'),
    )
    for token, symbol in substitutions:
        t5_text = t5_text.replace(token, symbol)
    return t5_text
def replace_sptok_2_t5tok(sp_text:str) -> str:
    """Map the arrow symbols ``→``, ``↑``, ``←`` to the T5 sentinels ``<extra_id_99>``, ``<extra_id_98>``, ``<extra_id_97>``."""
    arrow_to_sentinel = str.maketrans({
        '→': '<extra_id_99>',
        '↑': '<extra_id_98>',
        '←': '<extra_id_97>',
    })
    return sp_text.translate(arrow_to_sentinel)
def get_index(l, x, default=False):
    """Return the position of the first ``x`` in ``l``, or -256 when ``x`` is absent.

    Works for sequences (element lookup) and strings (substring lookup).
    NOTE(review): ``default`` is accepted but ignored; the not-found sentinel is
    always -256, which is what the callers in this notebook compare against.
    """
    try:
        return l.index(x)
    except ValueError:
        return -256

In [4]:
def extrace_argument(output_ids:list,case:str,is_alt:bool=False) -> str:
    """Extract the argument for ``case`` as the token right before its case particle.

    Earlier, tagger-based variant of ``extract_argument``. The generated ids are
    decoded, the T5 sentinel tokens are mapped back to arrow symbols, and the text
    is split with ``tagger.parse(...).split()``. The token immediately preceding
    the first occurrence of the Japanese particle for ``case`` (looked up in
    ``case_en_ja``) is returned, with arrows mapped back to sentinels.

    Args:
        output_ids: generated ids for one example. When ``is_alt`` is True this
            must support ``.tolist()`` and slicing (presumably a 1-D torch tensor).
        case: key of ``case_en_ja`` (e.g. 'ga', 'o', 'ni').
        is_alt: if True, only the ids after the first ``<extra_id_2>`` are used.

    Returns:
        The preceding token, or '' if no particle with a preceding token exists.

    NOTE(review): depends on the notebook globals ``mega_tokenizer``, ``case_en_ja``
    and ``tagger``. The ``tagger = MeCab.Tagger('-Owakati')`` line is commented out
    in this notebook, so ``tagger`` must be defined before calling this.
    NOTE: when e.g. the "さ" of "頼もしさ" is the ga-case argument, the output is
    "さが政府に求められる", which may not be analysable correctly.
    TODO: to use the last subword instead, tokenize the returned token with
    sentencepiece and take its final piece. Alternative: treat everything before the
    particle as the label (e.g. 党代表が雄弁に演説した -> ga: 党代表 or 代表), i.e.
    train on the longer span and take the last subword at evaluation time.
    """
    if is_alt:
        sep_id = mega_tokenizer.encode('<extra_id_2>', add_special_tokens=False)[0]
        id_list = output_ids.tolist()
        if sep_id in id_list:
            output_ids = output_ids[id_list.index(sep_id) + 1:]
        else:
            print('sep token does not exist in output tokens')
    ja_case = case_en_ja[case]
    decoded = replace_t5tok_2_sptok(mega_tokenizer.decode(output_ids))
    output_tokens = tagger.parse(decoded).split()
    # Start at 1: a particle at index 0 has no preceding token, and ``i - 1``
    # would otherwise wrap around to the last token.
    for i in range(1, len(output_tokens)):
        if output_tokens[i] == ja_case:
            return replace_sptok_2_t5tok(output_tokens[i - 1])
    return ''

In [4]:
def extract_argument(output_ids:list,case:str,is_alt:bool=False,tokenizer=None) -> str:
    """Extract the argument filling ``case`` from one generated output.

    The ids are decoded, whitespace and ``<pad>`` are removed, and the argument is
    taken as the text between the nearest preceding case particle (が for 'o';
    が or を for 'ni') and the first occurrence of the requested particle.
    This is a character-level heuristic: the first が/を/に anywhere in the string
    is used, even when it is part of a word.

    Args:
        output_ids: generated ids for one example. When ``is_alt`` is True this must
            support ``.tolist()`` and slicing (presumably a 1-D torch tensor).
        case: one of 'ga', 'o', 'ni'.
        is_alt: if True, only the ids after the first ``<extra_id_2>`` are used.
        tokenizer: object with ``encode``/``decode``; defaults to the notebook-level
            ``mega_tokenizer``.

    Returns:
        The argument string, or '' when the particle for ``case`` is absent.

    Raises:
        ValueError: if ``case`` is not 'ga', 'o' or 'ni'.

    TODO: to use only the last subword, tokenize the result with sentencepiece and
    take its final piece (e.g. 党代表が雄弁に演説した -> ga: 党代表 or 代表).
    """
    particles = {'ga': 'が', 'o': 'を', 'ni': 'に'}
    # Particles that may precede the target one and so bound the argument's start.
    preceding = {'ga': (), 'o': ('ga',), 'ni': ('ga', 'o')}
    if case not in particles:
        raise ValueError(f'unsupported case: {case!r} (expected one of {sorted(particles)})')
    if tokenizer is None:
        tokenizer = mega_tokenizer

    if is_alt:
        sep_id = tokenizer.encode('<extra_id_2>', add_special_tokens=False)[0]
        id_list = output_ids.tolist()
        if sep_id in id_list:
            output_ids = output_ids[id_list.index(sep_id) + 1:]
        else:
            print('sep token does not exist in output tokens')

    decoded_str = re.sub(r'\s|<pad>', '', tokenizer.decode(output_ids))
    positions = {name: decoded_str.find(p) for name, p in particles.items()}

    end = positions[case]
    if end == -1:
        return ''
    # Begin right after the closest particle found *before* the target one;
    # particles that are absent (-1) or come later are ignored.
    starts = [positions[name] + 1 for name in preceding[case] if -1 < positions[name] < end]
    return decoded_str[max(starts, default=0):end]

In [5]:
# Load the pretrained Japanese T5 tokenizer/model, then overwrite its weights with the
# fine-tuned checkpoint.
mega_tokenizer = T5Tokenizer.from_pretrained("megagonlabs/t5-base-japanese-web")
mega_model = T5ForConditionalGeneration.from_pretrained("megagonlabs/t5-base-japanese-web").to(device)
# Resize the embedding matrix to the tokenizer's vocabulary size before loading;
# presumably the checkpoint was saved after the same resize — confirm.
mega_model.resize_token_embeddings(len(mega_tokenizer))
# map_location puts the tensors directly on `device`, whichever device the
# checkpoint was saved from.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
mega_model.load_state_dict(torch.load('/home/sibava/PAS-T5/models/closest_50/checkpoint-6500/pytorch_model.bin', map_location=device))


mega_model.eval()


# English case names -> Japanese case particles.
case_en_ja = {
	'ga':'が',
	'o':'を',
	'ni':'に',
	'yotte':'によって'
}

# `tagger` is required by extrace_argument; uncomment to use that function.
# tagger = MeCab.Tagger('-Owakati')
# from pyknp import Jumanpp
# jumanpp = Jumanpp()


In [9]:
# Load the test split from JSON Lines.
dataset = load_dataset(
    'json',
    data_files={"test": '/home/sibava/PAS-T5/datasets/suffix/test.doc.jsonl'},
)
# Plain iterator over the raw examples (all fields, no tensor formatting).
pas_dataset = iter(dataset['test'])

# Tensor view exposing only `input_ids`, batched for generation.
input_dataset = dataset.with_format(type='torch', columns=['input_ids'])
tensor_dataloader = DataLoader(input_dataset['test'], batch_size=16)

Using custom data configuration default-157737582470aa31


Downloading and preparing dataset json/default to /home/sibava/.cache/huggingface/datasets/json/default-157737582470aa31/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 2577.94it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 369.54it/s]


Dataset json downloaded and prepared to /home/sibava/.cache/huggingface/datasets/json/default-157737582470aa31/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:01<00:00,  1.23s/it]


In [12]:
# Generate an output for every test example and write one JSON line per
# (example, case) that has an annotated argument type.
with open('/home/sibava/PAS-T5/decoded/suffix.jsonl', mode='w', encoding='utf-8') as f:
	# Fresh iterator over the raw examples; consumed in the same order as the batches
	# (the DataLoader does not shuffle).
	pas_dataset = iter(dataset['test'])
	for batch_input in tqdm(tensor_dataloader):
		outputs = mega_model.generate(batch_input['input_ids'].to(device))
		for output in outputs:
			pas_instance = next(pas_dataset)
			alt_type = pas_instance['alt_type']
			predicate = pas_instance['predicate']
			# Decode once per example instead of once per case.
			output_sentence = mega_tokenizer.decode(output)
			for case in ['ga', 'o', 'ni']:
				# Skip cases without an annotated argument type for this predicate.
				if pas_instance['arg_types'][case] is None:
					continue
				decode_dict = {
					'output_token': extract_argument(output, case),
					'gold_arguments': pas_instance['gold_arguments'][case],
					'case_name': case,
					'arg_type': pas_instance['arg_types'][case],
					'predicate': predicate,
					'alt_type': alt_type,
					'output_sentence': output_sentence,
				}
				f.write(json.dumps(decode_dict) + '\n')
	# Every raw example should have been paired with exactly one generated output.
	assert next(pas_dataset, None) is None, 'more examples than generated outputs'


100%|██████████| 1649/1649 [11:22<00:00,  2.42it/s]
