In [1]:
# Switch the kernel's working directory to ~/corpus.
# NOTE(review): the cells visible here only use absolute paths, so this may be unnecessary — confirm.
%cd ~/corpus

/home/sibava/corpus


In [2]:
import json
from transformers import T5Tokenizer
import pprint
import copy
import itertools
from tqdm.auto import tqdm
import glob
import os

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Tokenizer of the 'megagonlabs/t5-base-japanese-web' checkpoint; the conversion
# cell below uses it to encode both the marked context and the label strings.
tokenizer = T5Tokenizer.from_pretrained('megagonlabs/t5-base-japanese-web')

In [4]:
# Romanized case name -> Japanese case particle. The particle is appended right
# after the selected gold argument when the label string is assembled.
case_en_ja = dict(
    ga='が',
    o='を',
    ni='に',
    yotte='によって',
)

In [5]:
def extract_gold_arguments(gold_labels:list):
    """Collect the gold arguments of one predicate and build its label prefix.

    This function reads two module-level names that must exist before it is
    called: ``sents`` (indexed as ``sents[sent_idx][word_idx]``) and
    ``case_en_ja`` (case name -> particle string).

    Args:
        gold_labels: list of dicts, one per case. Each dict has 'case_type' and
            'case_name'; unless the case type is 'null' or starts with
            'exog'/'exo1'/'exo2', it also has 'gold_positions', a list of dicts
            with 'sent_idx', 'word_idx' and 'case_type'.

    Returns:
        A tuple ``(gold_arguments, argument_and_case, case_types)``:
        * gold_arguments: case name -> list of distinct gold argument strings
          (empty for 'null' cases), in first-occurrence order.
        * argument_and_case: concatenation, over the non-null cases in input
          order, of the selected argument followed by its particle.
        * case_types: case name -> case type.

    Raises:
        ValueError: if a case has gold positions but none of them carries the
            same case type as the label itself, so no argument can be selected.
    """
    # Exophoric case types are represented by reserved sentinel tokens.
    exophora_tokens = (
        ('exog', '<extra_id_99>'),
        ('exo1', '<extra_id_98>'),
        ('exo2', '<extra_id_97>'),
    )
    gold_arguments = {}
    case_types = {}
    argument_and_case = ''
    for gold in gold_labels:
        case_type = gold['case_type']
        case_name = gold['case_name']
        case_types[case_name] = case_type
        if case_type == 'null':
            gold_arguments[case_name] = []
            continue

        exophora_token = next(
            (token for prefix, token in exophora_tokens if case_type.startswith(prefix)),
            None,
        )
        if exophora_token is not None:
            golds_chain = [exophora_token]
            closest_gold_argument = exophora_token
        else:
            positions = gold['gold_positions']
            golds_chain = [sents[x['sent_idx']][x['word_idx']] for x in positions]
            # Select the first position whose case type equals the label's own
            # case type (original note: the gold argument closest to the
            # predicate; in the inter setting the first argument is selected).
            selected = next((x for x in positions if x['case_type'] == case_type), None)
            if selected is None:
                # Previously this silently reused the argument chosen for the
                # preceding case (or raised UnboundLocalError on the first one).
                raise ValueError(
                    "no gold position with case_type %r for case %r" % (case_type, case_name)
                )
            closest_gold_argument = sents[selected['sent_idx']][selected['word_idx']]

        argument_and_case += closest_gold_argument + case_en_ja[case_name]
        # dict.fromkeys de-duplicates while keeping first-occurrence order, so
        # the result does not depend on string hash randomization.
        gold_arguments[case_name] = list(dict.fromkeys(golds_chain))
    return gold_arguments,argument_and_case,case_types

In [8]:
# Convert the paired active/passive input files (JSON lines) into one JSON-lines
# file per input file, holding tokenized inputs and labels for T5.
#
# glob.glob does not guarantee any ordering, so both listings are sorted to make
# the positional active/passive pairing below deterministic.
# NOTE(review): this assumes the two directories hold files whose names sort into
# corresponding order — confirm.
active_path_list = sorted(glob.glob("/home/sibava/ZAR-konno/data/ntc_add_yotte_active_a/*"))
passive_path_list = sorted(glob.glob("/home/sibava/ZAR-konno/data/ntc_add_yotte_passive_a/*"))
output_dir_path = "/home/sibava/PAS-T5/pas-dataset-yotte"
os.makedirs(output_dir_path, exist_ok=True)

max_input_length = 512  # number of input ids kept per example
max_label_length = 30   # padding length requested for the label ids

for active_input_path,passive_input_path in zip(active_path_list,passive_path_list):
    # Context managers make sure every handle is closed, even if a line fails to parse.
    with open(active_input_path, mode='r', encoding='utf-8') as active_file:
        active_lines = active_file.readlines()
    with open(passive_input_path, mode='r', encoding='utf-8') as passive_file:
        passive_lines = passive_file.readlines()
    outputfile_path = os.path.join(output_dir_path, os.path.basename(active_input_path))
    with open(outputfile_path, mode='w', encoding='utf-8') as outputfile:
        for i in tqdm(range(len(active_lines))):
            doc = json.loads(active_lines[i])
            passive_doc = json.loads(passive_lines[i])
            # `sents` must stay a module-level name: extract_gold_arguments reads it as a global.
            sents = doc["sents"]
            pas_list = doc["pas_list"]
            passive_pas_list = passive_doc["pas_list"]
            for j in range(len(pas_list)):
                pas = pas_list[j]
                pas_dict = json.loads(pas)  # each entry of pas_list is itself a JSON string
                predicate_positions = pas_dict["predicate_positions"]
                prd_sent_idx = pas_dict["prd_sent_idx"]
                alt_type = pas_dict["alt_type"]
                predicate = "".join(sents[prd_sent_idx][pos] for pos in predicate_positions)

                # Context = all sentences up to and including the predicate sentence,
                # with the predicate wrapped in <extra_id_0> ... <extra_id_1>.
                context = copy.deepcopy(sents[0:prd_sent_idx + 1])
                context[prd_sent_idx].insert(predicate_positions[0], "<extra_id_0>")
                # +2 = one step past the last predicate token, plus one for the marker
                # inserted just above (assumes predicate_positions is ascending — TODO confirm).
                context[prd_sent_idx].insert(predicate_positions[-1] + 2, "<extra_id_1>")
                concated_context = ''.join(itertools.chain.from_iterable(context)) + tokenizer.eos_token
                padded_input_ids = tokenizer(concated_context, max_length=max_input_length, padding="max_length").input_ids

                active_gold_arguments, active_argument_and_case, active_case_types = extract_gold_arguments(pas_dict['gold_labels'])
                active_labels = active_argument_and_case + predicate
                if alt_type == "passive":
                    # The label string starts with the passive-side arguments of the same predicate.
                    passive_pas_dict = json.loads(passive_pas_list[j])
                    _, passive_argument_and_case, _ = extract_gold_arguments(passive_pas_dict['gold_labels'])
                    passive_labels = passive_argument_and_case + predicate
                    first_labels = passive_labels
                else:
                    # Otherwise the active labels appear on both sides of <extra_id_2>.
                    first_labels = active_labels
                labels = tokenizer(first_labels + '<extra_id_2>' + active_labels, max_length=max_label_length, padding="max_length").input_ids

                t5_pas_dict = {
                    # Keep only the last ids, i.e. the tail of the context where the predicate sentence is.
                    'input_ids': padded_input_ids[-max_input_length:],
                    'input_tokens': concated_context,
                    'labels': labels,
                    'gold_arguments': active_gold_arguments,
                    'case_types': active_case_types,
                    'predicate': predicate,
                    'alt_type': alt_type,
                }
                outputfile.write(json.dumps(t5_pas_dict) + '\n')


100%|██████████| 1751/1751 [01:37<00:00, 17.97it/s]
100%|██████████| 480/480 [00:16<00:00, 28.30it/s]
100%|██████████| 696/696 [00:35<00:00, 19.70it/s]


In [22]:
# Inspect one converted example: t5_pas_dict is whatever record the last
# iteration of the conversion loop left behind.
pprint.pprint(t5_pas_dict)

{'alt_type': 'active',
 'case_types': {'ga': 'dep', 'ni': 'null', 'o': 'null', 'yotte': 'null'},
 'gold_arguments': {'ga': ['先'], 'ni': [], 'o': [], 'yotte': []},
 'input_ids': [263,
               17658,
               330,
               374,
               260,
               296,
               652,
               368,
               2336,
               17334,
               259,
               25010,
               262,
               2887,
               265,
               13488,
               9837,
               338,
               1455,
               659,
               259,
               28024,
               263,
               6336,
               259,
               1586,
               5529,
               308,
               20412,
               262,
               338,
               1455,
               1715,
               260,
               1819,
               9879,
               9671,
               558,
               259,
               1755,
            