In [1]:
import json
import time
import pprint
import random
import re
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    T5Config
)
import MeCab
import unidic

# 乱数シードの設定
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
device = 'cuda:0'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def replace_t5tok_2_sptok(t5_text:str) -> str:
	t5_text = t5_text.replace('<pad>','')
	t5_text = t5_text.replace('</s>','')
	t5_text = re.sub(r'\s','',t5_text)
	t5_text = t5_text.replace('<extra_id_99>','→')
	t5_text = t5_text.replace('<extra_id_98>','↑')
	t5_text = t5_text.replace('<extra_id_97>','←')
	return t5_text
def replace_sptok_2_t5tok(sp_text:str) -> str:
	sp_text = sp_text.replace('→','<extra_id_99>')
	sp_text = sp_text.replace('↑','<extra_id_98>')
	sp_text = sp_text.replace('←','<extra_id_97>')
	return sp_text

In [3]:
def extract_argument(output_ids:list,case:str,is_alt:bool=False) -> str:
    sep_id = mega_tokenizer.encode('<extra_id_2>',add_special_tokens=None)[0]
    if(is_alt):
        if(sep_id in output_ids.tolist()):
            sep_index = output_ids.tolist().index(sep_id)
            output_ids = output_ids[sep_index + 1:]
        else:
            print('sep token does not exist in output tokens')
    ja_case = case_en_ja[case]
    # output_tokens = tagger.parse(replace_t5tok_2_sptok(mega_tokenizer.decode(output_ids))).split()
    decoded_str = replace_t5tok_2_sptok(mega_tokenizer.decode(output_ids))
    
    juman_result = jumanpp.analysis(decoded_str)
    output_tokens = [mrph.midasi for mrph in juman_result.mrph_list()]
    #「頼もしさ」の「さ」がガ格の項とき出力は「さが政府に求められる」となり解析不可？
    indices = [i for i, x in enumerate(output_tokens) if x == ja_case]
    if(len(indices) < 1):
        return ''
    else:
        return replace_sptok_2_t5tok(output_tokens[indices[0] - 1])
        #サブワードの末尾を取る場合はこのトークンをsentencepieceでトークナイズしてから末尾を取る
        #ガヲニを基準としてそこから前の部分を全部ラベルとする ex)党代表が雄弁に演説した　ガ格:党代表　or 代表,長めで学習してから評価で末尾のsubwordをとればよい

In [4]:
from pyknp import Jumanpp
jumanpp = Jumanpp()

In [5]:
case_en_ja = {
	'ga':'が',
	'o':'を',
	'ni':'に',
	'yotte':'によって'
}

In [6]:
mega_tokenizer = T5Tokenizer.from_pretrained("megagonlabs/t5-base-japanese-web")
mega_model = T5ForConditionalGeneration.from_pretrained("megagonlabs/t5-base-japanese-web").to(device)
mega_model.resize_token_embeddings(len(mega_tokenizer))
mega_model.load_state_dict(torch.load('/home/sibava/PAS-T5/models/closest_trained.pth'))
mega_model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32100, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32100, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedGeluDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo)

In [9]:
dataset = load_dataset('json',data_files={"test":'/home/sibava/PAS-T5/datasets/pas-dataset/pas_data.test.jsonl'})
input_dataset = dataset.with_format(type='torch',columns=['input_ids'])
tensor_dataloader = DataLoader(input_dataset['test'],batch_size=16)
pas_dataset = iter(dataset['test'])

Using custom data configuration default-f9a589914adc8635
Reusing dataset json (/home/sibava/.cache/huggingface/datasets/json/default-f9a589914adc8635/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)
100%|██████████| 1/1 [00:01<00:00,  2.00s/it]


In [10]:
with open('decoded_psa_juman.jsonl',mode='w') as f:
	pas_dataset = iter(dataset['test'])
	dataloader = DataLoader(dataset['test'],batch_size=16)
	for batch_input in tqdm(iter(tensor_dataloader)):
		outputs = mega_model.generate(batch_input['input_ids'].to(device))
		for output in outputs:
			psa_instance = next(pas_dataset)
			alt_type = psa_instance['alt_type']
			predicate = psa_instance['predicate']
			output_sentence = mega_tokenizer.decode(output)
			for case in ['ga','o','ni']:
				argument = extract_argument(output,case)
				if(argument == 'はがき'):
					print(output)
				decode_dict = {
					'output_token':argument,
					'gold_arguments':psa_instance['gold_arguments'][case],
					'case_name':case,
					'case_type':psa_instance['case_types'][case],
					'predicate':predicate,
					'alt_type':alt_type,
					'input_tokens':psa_instance['input_tokens'],
					'output_sentence':output_sentence
					}
				f.write(json.dumps(decode_dict) + '\n')

  3%|▎         | 50/1648 [01:16<26:59,  1.01s/it] 

tensor([    0,   262,   263,   474,   263, 18803,     1,     0,     0,     0],
       device='cuda:0')
tensor([    0, 17542,   263,   262,   263,   474,   269,  1661,   309,     1],
       device='cuda:0')
tensor([   0, 1575,  263,  262,  263,  474,  465,    1,    0,    0],
       device='cuda:0')
tensor([    0,   943,   262,   263,   474, 13287,     1,     0,     0,     0],
       device='cuda:0')


  3%|▎         | 51/1648 [01:18<29:22,  1.10s/it]

tensor([    0,   262,   263,   474, 14606,     1,     0,     0,     0,     0],
       device='cuda:0')
tensor([    0, 16765,   262,   263,   474,   269,  6624,   583,     1],
       device='cuda:0')
tensor([    0, 16765,   262,   263,   474,   269,  9864,     1,     0],
       device='cuda:0')


 34%|███▍      | 559/1648 [10:34<17:52,  1.02it/s]

tensor([    0, 32000,   263,  1560,   269,   262,   263,   474,   267,  2464,
            1], device='cuda:0')


100%|██████████| 1648/1648 [28:03<00:00,  1.02s/it]
