In [1]:
import json
import time
import pprint
import random
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    T5Config
)
# import MeCab
# import unidic
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed every random number generator used in this notebook.
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Seed value handed unchanged to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # The CUDA generators are only seeded when CUDA is available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(42)
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [3]:
def replace_t5tok_2_sptok(t5_text:str) -> str:
    """Strip T5 padding/end markers and map sentinel tokens to arrow symbols.

    ``<pad>`` and ``</s>`` are deleted; ``<extra_id_99>``, ``<extra_id_98>``
    and ``<extra_id_97>`` become ``→``, ``↑`` and ``←`` respectively.

    Args:
        t5_text: Text containing T5 special-token strings.

    Returns:
        The text with the substitutions applied in the order listed above.
    """
    substitutions = (
        ('<pad>', ''),
        ('</s>', ''),
        ('<extra_id_99>', '→'),
        ('<extra_id_98>', '↑'),
        ('<extra_id_97>', '←'),
    )
    converted = t5_text
    for t5_token, replacement in substitutions:
        converted = converted.replace(t5_token, replacement)
    return converted
def replace_sptok_2_t5tok(sp_text:str) -> str:
    """Map the arrow symbols back to their T5 sentinel-token strings.

    ``→``, ``↑`` and ``←`` become ``<extra_id_99>``, ``<extra_id_98>`` and
    ``<extra_id_97>`` respectively.

    Args:
        sp_text: Text that may contain the arrow symbols.

    Returns:
        The text with every arrow symbol replaced by its sentinel token.
    """
    arrow_to_sentinel = (
        ('→', '<extra_id_99>'),
        ('↑', '<extra_id_98>'),
        ('←', '<extra_id_97>'),
    )
    restored = sp_text
    for arrow, sentinel in arrow_to_sentinel:
        restored = restored.replace(arrow, sentinel)
    return restored
def extrace_argument(output_ids:list,case:str,is_alt:bool=False) -> str:
    """Extract the predicted argument for one case from generated token ids.

    The ids are decoded with ``mega_tokenizer``, T5 special tokens are
    stripped or mapped to arrow symbols, all whitespace is removed, and the
    resulting string is segmented with Juman++.  The morpheme immediately
    preceding the first occurrence of the case particle is returned.

    Args:
        output_ids: Generated token ids for one example.  Annotated as
            ``list``; objects exposing ``tolist()`` (e.g. tensors) are
            accepted as well.
        case: Key of ``case_en_ja`` (e.g. 'ga', 'o', 'ni').
        is_alt: When True, only the ids after the first ``<extra_id_2>``
            separator are analysed; a message is printed if the separator
            is missing and the full sequence is used instead.

    Returns:
        The surface form of the morpheme right before the case particle, with
        arrow symbols converted back to T5 sentinel tokens.  Returns '' when
        the decoded text is empty, the particle does not occur, or the
        particle is the very first morpheme (nothing precedes it).

    Raises:
        KeyError: If ``case`` is not a key of ``case_en_ja``.
    """
    if is_alt:
        # The separator id is only needed on this path, so encode it lazily.
        sep_id = mega_tokenizer.encode('<extra_id_2>',add_special_tokens=False)[0]
        id_list = output_ids.tolist() if hasattr(output_ids,'tolist') else list(output_ids)
        if sep_id in id_list:
            output_ids = output_ids[id_list.index(sep_id) + 1:]
        else:
            print('sep token does not exist in output tokens')
    ja_case = case_en_ja[case]
    decoded_str = replace_t5tok_2_sptok(mega_tokenizer.decode(output_ids))
    decoded_str = re.sub(r'\s','',decoded_str)
    if not decoded_str:
        # Nothing to analyse, hence no particle and no argument.
        return ''
    juman_result = jumanpp.analysis(decoded_str)
    output_tokens = [mrph.midasi for mrph in juman_result.mrph_list()]
    # Open question: when the ga-case argument is the suffix 'さ' of '頼もしさ',
    # the output becomes 'さが政府に求められる', which may not be analysable.
    if ja_case not in output_tokens:
        return ''
    case_index = output_tokens.index(ja_case)
    if case_index == 0:
        # The particle opens the sentence; indexing with -1 would wrongly
        # return the LAST morpheme, so report "no argument" instead.
        return ''
    return replace_sptok_2_t5tok(output_tokens[case_index - 1])
    # To take the last subword instead, tokenize this morpheme with
    # sentencepiece and keep the final piece.
    # Alternative: treat everything before ga/o/ni as the label, e.g.
    # 党代表が雄弁に演説した -> ga-case: 党代表 or 代表; train on the longer
    # span and take the final subword at evaluation time.

In [4]:
# Load the tokenizer and the pretrained Japanese T5 model, then restore the
# fine-tuned weights from the local checkpoint.
mega_tokenizer = T5Tokenizer.from_pretrained("megagonlabs/t5-base-japanese-web")
mega_model = T5ForConditionalGeneration.from_pretrained("megagonlabs/t5-base-japanese-web").to(device)
# Size the embedding matrix to the tokenizer vocabulary before loading weights.
mega_model.resize_token_embeddings(len(mega_tokenizer))
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different device (e.g. another GPU index) can still be loaded.
mega_model.load_state_dict(
    torch.load(
        '/home/sibava/PAS-T5/models/singleArg/checkpoint-12500/pytorch_model.bin',
        map_location=device,
    )
)

# This notebook only runs inference, so switch the model to evaluation mode.
mega_model.eval()


# Maps the romanised case label to the Japanese case particle that
# extrace_argument looks for in the segmented model output.
case_en_ja = dict(
    ga='が',
    o='を',
    ni='に',
    yotte='によって',
)


# Juman++ analyser (via pyknp); extrace_argument calls `jumanpp.analysis()` to
# segment the decoded model output into morphemes.
# NOTE(review): this import sits outside the notebook's import cell; consider
# moving it there so all dependencies are visible at the top.
from pyknp import Jumanpp
jumanpp = Jumanpp()

In [5]:
# Load the test split from the JSON-lines file.
dataset = load_dataset(
    'json',
    data_files=dict(test='/home/sibava/PAS-T5/datasets/pas-dataset-singleArg/test.doc.jsonl'),
)
# View of the same data formatted as torch tensors, limited to `input_ids`.
input_dataset = dataset.with_format(type='torch', columns=['input_ids'])

# Iterator over the unformatted test records, plus a batched loader over the
# tensor view.
pas_dataset = iter(dataset['test'])
tensor_dataloader = DataLoader(input_dataset['test'], batch_size=16)

Using custom data configuration default-ba31821273c3dc6d
Reusing dataset json (/home/sibava/.cache/huggingface/datasets/json/default-ba31821273c3dc6d/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)
100%|██████████| 1/1 [00:00<00:00, 27.48it/s]


In [8]:
# Generate an output for every test example and write one JSON line per
# (example, annotated case) pair.
with open('/home/sibava/PAS-T5/decoded/decoded_psa_sinigleArg.jsonl',mode='w') as f:
    # Test records are consumed in lock-step with the generated outputs; this
    # relies on tensor_dataloader yielding examples in dataset order.
    pas_dataset = iter(dataset['test'])
    for batch_input in tqdm(tensor_dataloader):
        outputs = mega_model.generate(batch_input['input_ids'].to(device))
        for output in outputs:
            psa_instance = next(pas_dataset)
            alt_type = psa_instance['alt_type']
            predicate = psa_instance['predicate']
            # Decode once per example rather than once per case.
            output_sentence = mega_tokenizer.decode(output)
            for case in ['ga','o','ni']:
                # Cases without an annotation for this instance are skipped.
                if psa_instance['case_types'][case] is None:
                    continue
                # NOTE(review): alt_type is recorded but extrace_argument is
                # always called with is_alt at its default (False); confirm
                # this is intended for non-active instances.
                argument = extrace_argument(output,case)
                decode_dict = {
                    'output_token':argument,
                    'gold_arguments':psa_instance['gold_arguments'][case],
                    'case_name':case,
                    'case_type':psa_instance['case_types'][case],
                    'predicate':predicate,
                    'alt_type':alt_type,
                    'output_sentence':output_sentence,
                    'input_tokens':psa_instance['input_tokens']
                    }
                f.write(json.dumps(decode_dict) + '\n')


  0%|          | 0/4944 [00:00<?, ?it/s]

{'alt_type': 'active',
 'case_types': {'ga': 'dep', 'ni': None, 'o': None},
 'gold_arguments': {'ga': ['姿'], 'ni': None, 'o': None},
 'input_ids': [263,
               1968,
               328,
               441,
               264,
               4733,
               8757,
               7379,
               1727,
               270,
               259,
               32099,
               638,
               1645,
               313,
               32098,
               2891,
               261,
               1582,
               21653,
               260,
               1,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
 

  0%|          | 1/4944 [00:00<1:05:47,  1.25it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'dep'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': ['釣り糸']},
 'input_ids': [269,
               1968,
               328,
               441,
               264,
               4733,
               8757,
               7379,
               1727,
               270,
               259,
               638,
               1645,
               313,
               2891,
               261,
               1582,
               21653,
               260,
               1787,
               949,
               263,
               4643,
               292,
               19883,
               19770,
               649,
               6615,
               3285,
               269,
               32099,
               10566,
               4566,
               308,
               32098,
               13491,
               26595,
               259,
               4674,
               267,
               11419,
             

  0%|          | 2/4944 [00:01<1:19:31,  1.04it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': 'dep', 'o': None},
 'gold_arguments': {'ga': None, 'ni': ['水面'], 'o': None},
 'input_ids': [267,
               1968,
               328,
               441,
               264,
               4733,
               8757,
               7379,
               1727,
               270,
               259,
               638,
               1645,
               313,
               2891,
               261,
               1582,
               21653,
               260,
               1787,
               949,
               263,
               4643,
               292,
               19883,
               19770,
               649,
               6615,
               3285,
               269,
               10566,
               4566,
               308,
               13491,
               26595,
               259,
               4674,
               267,
               11419,
               10202,
               4166,
               

  0%|          | 3/4944 [00:03<1:29:02,  1.08s/it]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': 'null', 'o': None},
 'gold_arguments': {'ga': None, 'ni': [], 'o': None},
 'input_ids': [267,
               1968,
               328,
               441,
               264,
               4733,
               8757,
               7379,
               1727,
               270,
               259,
               638,
               1645,
               313,
               2891,
               261,
               1582,
               21653,
               260,
               1787,
               949,
               263,
               4643,
               292,
               19883,
               19770,
               649,
               6615,
               3285,
               269,
               10566,
               4566,
               308,
               13491,
               26595,
               259,
               4674,
               267,
               11419,
               10202,
               4166,
               273

  0%|          | 4/4944 [00:04<1:26:12,  1.05s/it]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'null'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': []},
 'input_ids': [269,
               1968,
               328,
               14550,
               27847,
               13094,
               3501,
               1228,
               5153,
               985,
               3669,
               2581,
               261,
               5609,
               264,
               25843,
               270,
               259,
               732,
               1722,
               1237,
               7170,
               2675,
               26441,
               271,
               1461,
               806,
               1479,
               269,
               2026,
               18793,
               28353,
               260,
               795,
               2115,
               263,
               11027,
               287,
               5341,
               280,
               32099,
               

  0%|          | 5/4944 [00:04<1:17:41,  1.06it/s]

{'alt_type': 'passive',
 'case_types': {'ga': None, 'ni': None, 'o': 'intra'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': ['会', 'ジェトロ']},
 'input_ids': [269,
               1968,
               328,
               14550,
               27847,
               13094,
               3501,
               1228,
               5153,
               985,
               3669,
               2581,
               261,
               5609,
               264,
               25843,
               270,
               259,
               732,
               1722,
               1237,
               7170,
               2675,
               26441,
               271,
               1461,
               806,
               1479,
               269,
               2026,
               18793,
               28353,
               260,
               795,
               2115,
               263,
               11027,
               287,
               5341,
               11158,
               259,
  

  0%|          | 6/4944 [00:05<1:14:24,  1.11it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': 'null', 'o': None},
 'gold_arguments': {'ga': None, 'ni': [], 'o': None},
 'input_ids': [267,
               1968,
               328,
               14550,
               27847,
               13094,
               3501,
               1228,
               5153,
               985,
               3669,
               2581,
               261,
               5609,
               264,
               25843,
               270,
               259,
               732,
               1722,
               1237,
               7170,
               2675,
               26441,
               271,
               1461,
               806,
               1479,
               269,
               2026,
               18793,
               28353,
               260,
               795,
               2115,
               263,
               11027,
               287,
               5341,
               11158,
               259,
               

  0%|          | 7/4944 [00:06<1:08:04,  1.21it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': 'null', 'o': None},
 'gold_arguments': {'ga': None, 'ni': [], 'o': None},
 'input_ids': [267,
               1968,
               328,
               14550,
               27847,
               13094,
               3501,
               1228,
               5153,
               985,
               3669,
               2581,
               261,
               5609,
               264,
               25843,
               270,
               259,
               732,
               1722,
               1237,
               7170,
               2675,
               26441,
               271,
               1461,
               806,
               1479,
               269,
               2026,
               18793,
               28353,
               260,
               795,
               2115,
               263,
               11027,
               287,
               5341,
               11158,
               259,
               

  0%|          | 8/4944 [00:07<1:03:50,  1.29it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'null'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': []},
 'input_ids': [269,
               1968,
               328,
               14550,
               2077,
               331,
               1956,
               263,
               11928,
               287,
               13633,
               261,
               2023,
               4647,
               7463,
               267,
               4332,
               6787,
               796,
               3602,
               7514,
               261,
               13205,
               263,
               11027,
               287,
               259,
               32099,
               3737,
               5979,
               32098,
               260,
               1,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
               0,
     

  0%|          | 9/4944 [00:07<1:06:03,  1.25it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'dep'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': ['姿勢']},
 'input_ids': [269,
               1968,
               328,
               14550,
               2077,
               331,
               1956,
               263,
               11928,
               287,
               13633,
               261,
               2023,
               4647,
               7463,
               267,
               4332,
               6787,
               796,
               3602,
               7514,
               261,
               13205,
               263,
               11027,
               287,
               259,
               3737,
               5979,
               260,
               274,
               2179,
               279,
               6364,
               341,
               273,
               269,
               1571,
               21675,
               267,
               346,
               1465

  0%|          | 10/4944 [00:08<1:09:49,  1.18it/s]

{'alt_type': 'active',
 'case_types': {'ga': 'inter', 'ni': None, 'o': None},
 'gold_arguments': {'ga': ['首相'], 'ni': None, 'o': None},
 'input_ids': [263,
               1968,
               328,
               14550,
               2077,
               331,
               1956,
               263,
               11928,
               287,
               13633,
               261,
               2023,
               4647,
               7463,
               267,
               4332,
               6787,
               796,
               3602,
               7514,
               261,
               13205,
               263,
               11027,
               287,
               259,
               3737,
               5979,
               260,
               274,
               2179,
               279,
               6364,
               341,
               273,
               269,
               1571,
               21675,
               267,
               346,
               14

  0%|          | 11/4944 [00:09<1:09:12,  1.19it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'null'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': []},
 'input_ids': [269,
               1968,
               328,
               14550,
               2077,
               331,
               1956,
               263,
               11928,
               287,
               13633,
               261,
               2023,
               4647,
               7463,
               267,
               4332,
               6787,
               796,
               3602,
               7514,
               261,
               13205,
               263,
               11027,
               287,
               259,
               3737,
               5979,
               260,
               274,
               2179,
               279,
               6364,
               341,
               273,
               269,
               1571,
               21675,
               267,
               346,
               1465,
 

  0%|          | 12/4944 [00:10<1:12:39,  1.13it/s]

{'alt_type': 'active',
 'case_types': {'ga': 'inter', 'ni': None, 'o': None},
 'gold_arguments': {'ga': ['相'], 'ni': None, 'o': None},
 'input_ids': [263,
               1968,
               328,
               14550,
               2077,
               331,
               1956,
               271,
               13886,
               263,
               1458,
               8989,
               839,
               450,
               8964,
               1473,
               942,
               323,
               789,
               264,
               13886,
               263,
               1458,
               9526,
               9685,
               263,
               259,
               560,
               1896,
               1208,
               839,
               450,
               1473,
               942,
               261,
               7066,
               261,
               17814,
               317,
               282,
               3598,
               19671,


  0%|          | 13/4944 [00:11<1:16:50,  1.07it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'null'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': []},
 'input_ids': [269,
               1968,
               328,
               14550,
               2077,
               331,
               1956,
               271,
               13886,
               263,
               1458,
               8989,
               839,
               450,
               8964,
               1473,
               942,
               323,
               789,
               264,
               13886,
               263,
               1458,
               9526,
               9685,
               263,
               259,
               560,
               1896,
               1208,
               839,
               450,
               1473,
               942,
               261,
               7066,
               261,
               17814,
               317,
               282,
               3598,
               19671,
    

  0%|          | 14/4944 [00:12<1:18:45,  1.04it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'dep'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': ['化']},
 'input_ids': [269,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               269,
               32099,
               2853,
               289,
               32098,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               36

  0%|          | 15/4944 [00:13<1:19:07,  1.04it/s]

{'alt_type': 'active',
 'case_types': {'ga': 'dep', 'ni': None, 'o': None},
 'gold_arguments': {'ga': ['大統領'], 'ni': None, 'o': None},
 'input_ids': [263,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               15820,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               3602,
               269,
               21359,
               12338,
             

  0%|          | 16/4944 [00:14<1:23:10,  1.01s/it]

{'alt_type': 'active',
 'case_types': {'ga': 'dep', 'ni': None, 'o': None},
 'gold_arguments': {'ga': ['の'], 'ni': None, 'o': None},
 'input_ids': [263,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               15820,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               3602,
               269,
               21359,
               12338,
               

  0%|          | 17/4944 [00:15<1:18:43,  1.04it/s]

{'alt_type': 'passive',
 'case_types': {'ga': 'dep', 'ni': None, 'o': None},
 'gold_arguments': {'ga': ['省'], 'ni': None, 'o': None},
 'input_ids': [263,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               15820,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               3602,
               269,
               21359,
               12338,
              

  0%|          | 18/4944 [00:16<1:15:06,  1.09it/s]

{'alt_type': 'active',
 'case_types': {'ga': 'dep', 'ni': None, 'o': None},
 'gold_arguments': {'ga': ['省'], 'ni': None, 'o': None},
 'input_ids': [263,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               15820,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               3602,
               269,
               21359,
               12338,
               

  0%|          | 19/4944 [00:17<1:12:26,  1.13it/s]

{'alt_type': 'active',
 'case_types': {'ga': 'dep', 'ni': None, 'o': None},
 'gold_arguments': {'ga': ['者'], 'ni': None, 'o': None},
 'input_ids': [263,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               15820,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               3602,
               269,
               21359,
               12338,
               

  0%|          | 20/4944 [00:18<1:13:02,  1.12it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': 'null', 'o': None},
 'gold_arguments': {'ga': None, 'ni': [], 'o': None},
 'input_ids': [267,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               15820,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               3602,
               269,
               21359,
               12338,
               26

  0%|          | 21/4944 [00:19<1:15:04,  1.09it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'null'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': []},
 'input_ids': [269,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               15820,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               3602,
               269,
               21359,
               12338,
               26

  0%|          | 22/4944 [00:20<1:20:49,  1.01it/s]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'null'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': []},
 'input_ids': [269,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               15820,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               3602,
               269,
               21359,
               12338,
               26

  0%|          | 23/4944 [00:21<1:23:54,  1.02s/it]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'null'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': []},
 'input_ids': [269,
               1968,
               328,
               338,
               2164,
               854,
               506,
               13451,
               8757,
               10634,
               19974,
               274,
               9294,
               1073,
               3657,
               273,
               261,
               10634,
               434,
               15820,
               343,
               259,
               9162,
               972,
               3066,
               259,
               972,
               10406,
               261,
               4513,
               1147,
               5233,
               269,
               6403,
               8711,
               537,
               18199,
               3602,
               269,
               21359,
               12338,
               26

  0%|          | 24/4944 [00:22<1:23:53,  1.02s/it]

{'alt_type': 'active',
 'case_types': {'ga': None, 'ni': None, 'o': 'null'},
 'gold_arguments': {'ga': None, 'ni': None, 'o': []},
 'input_ids': [269,
               1968,
               328,
               10726,
               269,
               19600,
               280,
               12187,
               1317,
               387,
               3592,
               782,
               2142,
               262,
               11027,
               287,
               259,
               16561,
               270,
               4338,
               32099,
               11559,
               284,
               32098,
               259,
               1053,
               10688,
               1130,
               2859,
               1793,
               352,
               7136,
               269,
               3056,
               10313,
               280,
               260,
               1,
               0,
               0,
               0,
               0,
        

  0%|          | 24/4944 [00:22<1:17:31,  1.06it/s]


KeyboardInterrupt: 