In [1]:
import os, sys, shutil, re, copy, random
from typing import Tuple, Union, List
from itertools import chain
from time import time

from falcon import falcon_util
from util.globals import *
from util import nethook

from baselines.ft import FTHyperParams, apply_ft_to_model
from baselines.mend import MENDHyperParams, MendRewriteExecutor
from rome import ROMEHyperParams, apply_rome_to_model
from memit import MEMITHyperParams, apply_memit_to_model

from experiments.py.eval_utils_counterfact import compute_rewrite_quality_counterfact, test_batch_prediction, test_generation
from experiments.py.eval_utils_zsre import compute_rewrite_quality_zsre


from dsets import (
    AttributeSnippets,
    CounterFactDataset,
    MENDQADataset,
    MultiCounterFactDataset,
    get_tfidf_vectorizer
)

# Registry of editing algorithms: key -> (hyper-parameter class, callable that applies the edit).
# Note: the 'MEND' entry instantiates MendRewriteExecutor when this cell runs and stores
# the bound `apply_to_model` method of that single instance.
ALG_DICT = {
    'FT': (FTHyperParams, apply_ft_to_model),
    'MEND': (MENDHyperParams, MendRewriteExecutor().apply_to_model),
    'ROME': (ROMEHyperParams, apply_rome_to_model),
    'MEMIT': (MEMITHyperParams, apply_memit_to_model)
}


# Registry of datasets: key -> (dataset class, evaluation function).
# _load_data() looks entries up by key and instantiates the dataset class.
DS_DICT = {
    'mcf': (MultiCounterFactDataset, compute_rewrite_quality_counterfact),
    'cf': (CounterFactDataset, compute_rewrite_quality_counterfact),
    'zsre': (MENDQADataset, compute_rewrite_quality_zsre),
}


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Seed shared by the notebook; applied here to PyTorch's RNGs.
# Python's `random` module is seeded separately further down.
seed = 7
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


import warnings
# Silences every warning for the rest of the session (including deprecation notices).
warnings.filterwarnings('ignore')

In [2]:
def _init_model(model_name):
    if type(model_name) is str:
        print('# init_model() Instantiating model')
        model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
        tok = AutoTokenizer.from_pretrained(model_name)
        tok.pad_token = tok.eos_token
    else:
        model, tok = model_name
        model_name = model.config._name_or_path
    print(f'\tmodel : {type(model)}')
    print(f'\ttokenizer : {type(tok)}\n')

    return model, tok

In [3]:
def _load_data(ds_name, tok, dataset_size_limit=None):
    """Instantiate the dataset registered under ``ds_name`` in ``DS_DICT``.

    Args:
        ds_name: Key of ``DS_DICT``.
        tok: Tokenizer forwarded to the dataset class as ``tok``.
        dataset_size_limit: Forwarded to the dataset class as ``size``.

    Returns:
        The dataset instance built from ``DATA_DIR``.
    """
    dataset_cls, eval_fn = DS_DICT[ds_name]
    # The evaluation function is only reported here; it is not returned.
    print(f'# data_load() ds_class : {dataset_cls}')
    print(f'# data_load() ds_eval_method : {eval_fn}\n')
    return dataset_cls(DATA_DIR, tok=tok, size=dataset_size_limit)

In [4]:
# Checkpoint name; because it is a str, _init_model() loads model and tokenizer
# via from_pretrained() and moves the model to CUDA.
model_name = 'gpt2-xl'
model, tok = _init_model(model_name)

# init_model() Instantiating model
	model : <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>
	tokenizer : <class 'transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast'>



In [5]:
# ds_name = 'mcf'
# dataset_size_limit = None
# ds = _load_data(ds_name, tok, dataset_size_limit)

In [6]:
# def _generate(model, tok, prompts: List[str], top_k, max_out_len):
#     outputs = model.generate(**tok(prompts, return_tensors='pt').to('cuda'),
#                              top_k=top_k, max_length=max_out_len,
#                              do_sample=False, num_beams=1,
#                              pad_token_id=tok.eos_token_id
#     )
    
#     return str(tok.decode(outputs[0], skip_special_token=True)).replace('\n', '\t').replace('[\t]+', '\t')

# 여기서부턴 일회성(임시) 함수

In [7]:
def _write_sorted(out_dict: dict, out_file_path: str):
    """Write ``out_dict`` to ``out_file_path`` as ``key<TAB>value`` lines.

    The entries are written in the order produced by ``falcon_util.sorted_dict``.

    Args:
        out_dict: Mapping to dump.
        out_file_path: Destination text file (overwritten).
    """
    print(f'_write_sorted() dict : {len(out_dict)} -> {out_file_path}')

    sorted_dict = falcon_util.sorted_dict(out_dict)

    # Explicit UTF-8 so non-ASCII keys do not depend on the platform's default
    # encoding; the `with` block closes the file, no manual close() needed.
    with open(out_file_path, 'w', encoding='utf-8') as f:
        for key, value in sorted_dict.items():
            f.write(f'{key}\t{value}\n')

In [8]:
import random

def _random_sampling(datas, n=100, seed=7):
    random.seed(seed)
    sampled = random.sample(datas, n)
    
    print(f'_random_sampling() datas : {len(datas)}, sampled : {len(sampled)}')
    return sampled

In [9]:
def _rep_do_simple(datas):
    ret_datas = []
    for data in datas:
        case_id = data['case_id']
        subject = data['requested_rewrite']['subject']
        prompt = data['requested_rewrite']['prompt']
        
        ret_data = {'case_id': case_id, 'subject': subject, 'prompt': prompt}
        ret_datas.append(ret_data)
    
    return ret_datas


def _write_datas(out_file_path, datas, do_simple=False):
    """Serialise ``datas`` to JSON and write it to ``out_file_path``.

    Args:
        out_file_path: Destination file; its parent directory is created through
            ``falcon_util.make_parent``.
        datas: Records to write.
        do_simple: When true, records are first reduced by ``_rep_do_simple``.
    """
    write_datas = _rep_do_simple(datas) if do_simple else datas

    print(f'_write_datas() datas : {len(write_datas)} -> {out_file_path}')

    falcon_util.make_parent(out_file_path)
    f = falcon_util.open_file(out_file_path, mode='w')
    try:
        f.write(falcon_util.to_json_str(write_datas))
    finally:
        # Release the handle even if serialisation or the write fails.
        f.close()

In [10]:
def _rm_dup_subj(datas, ext_n):
    subject_set = set()
    rm_dup_subj_datas = []

    for data in datas:
        subject = data['requested_rewrite']['subject']

        if not subject in subject_set:
            rm_dup_subj_datas.append(data)
            subject_set.add(subject)

            if len(rm_dup_subj_datas) == ext_n:
                break

    print(f'_get_rm_dup_subj() datas : {len(datas)}, rm duplicated subject : {len(rm_dup_subj_datas)}')
    return rm_dup_subj_datas

In [11]:
def _get_subj_rela_dict(datas):
    subj_rela_dict = {}

    for data in datas:
        subject = data['requested_rewrite']['subject']

        if not subject in subj_rela_dict.keys():
            subj_rela_dict[subject] = [data]
        else:
            subj_rela_dict[subject].append(data)

    print(f'_get_subj_rela_dict() datas : {len(datas)}, subj_rela_dict size : {len(subj_rela_dict)}')
    return subj_rela_dict

In [12]:
def _get_dup_subj_lists(datas):
    """Bucket records by how many records share their subject.

    Args:
        datas: Sized iterable of records with ``requested_rewrite.subject``.

    Returns:
        List of lists where element ``k-1`` holds every record whose subject
        occurs exactly ``k`` times. Buckets for sizes that do not occur stay
        empty; an empty input gives an empty list.
    """
    subj_rela_dict = _get_subj_rela_dict(datas)

    # Frequency of each group size, reported for inspection only.
    dup_sizes = {}
    for subject_datas in subj_rela_dict.values():
        falcon_util.add_dict_freq(dup_sizes, len(subject_datas))
    print(f'_get_dup_subj_lists() dup_sizes : {dup_sizes}')

    # Size the bucket list by the LARGEST group, not by the number of distinct
    # sizes: with sizes such as {1, 2, 4} the latter would be 3 and index 3
    # would be out of range.
    max_dup = max((len(subject_datas) for subject_datas in subj_rela_dict.values()), default=0)
    dup_subj_datas_lists = [[] for _ in range(max_dup)]

    for subject_datas in subj_rela_dict.values():
        dup_subj_datas_lists[len(subject_datas) - 1].extend(subject_datas)

    return dup_subj_datas_lists

In [13]:
def _sr_swap(prompt: str, subject: str):
    eojs = prompt.split()
    
    idx = -1
    eoj, suf = '#%#', '#%#'

    for i, eoj in enumerate(eojs):
        idx = i
        if eoj == '{}\'s':
            suf = '\'s'
        elif eoj == '{}?':
            suf = '?'
        elif eoj == '{},':
            suf = ','
        elif eoj == '{}':
            suf = ''
        else:
            idx = -1
        
        if idx != -1:
            break
    # print(f'{idx} : {eoj} : {suf}')

    
    if idx == -1:
        print(f'_sr_swap() prompt error : {prompt}')
        return None, None

    
    if idx == 0:
        relation = ' '.join(eojs[1:])
        prompt_sr_swap = f'{subject}{suf}' + ' {}'
    else:
        left = ' '.join(eojs[:idx])
        right = ' '.join(eojs[idx+1:])

        if 'play?' in right:
            relation = right
            prompt_sr_swap = f'{left} {subject}{suf}' + ' {}'
        elif left == 'The headquarter of' and right == 'is located in':
            relation = right
            prompt_sr_swap = f'{left} {subject}{suf}' + ' {}'
        elif len(left) < len(right):
            relation = right
            prompt_sr_swap = f'{left} {subject}{suf}' + ' {}'
        elif len(left) > len(right):
            relation = left
            prompt_sr_swap = '{} ' + f'{subject}{suf} {right}'
        else:
            return None, None
    
    return prompt_sr_swap, relation

In [14]:
# prep_list = ['in', 'of', 'by', 'from', 'on', 'the', 'is', 'as', 'at', 'for', 'does', 'to', 'a', 'after', 'an', 'with', 'label', ':']
# Function words that _check_prep() strips from the tail of a relation.
# Each word is listed once ('with' used to appear twice); prep_set is unchanged.
prep_list = ['in', 'of', 'by', 'from', 'on', 'the', 'is', 'as', 'at', 'for', 'does', 'to', 'a', 'after', 'an', 'with', 'label']
prep_set = set(prep_list)
print(f'prep_set size : {len(prep_set)}\n{prep_set}\n')

def _check_prep(data: dict, do_print=False):
    """Swap subject/relation for one record and strip trailing words of ``prep_set``.

    Args:
        data: Record with ``requested_rewrite.prompt`` and
            ``requested_rewrite.subject``.
        do_print: Print every intermediate value when true.

    Returns:
        8-tuple ``(prompt, subject, prompt_sr_swap, relation,
        prompt_sr_swap_rmd_prep, relation_rmd_prep, is_rm_prep, rmd_prep)``.
        ``prompt_sr_swap_rmd_prep`` is ``''`` when nothing was stripped.
    """
    prompt = data['requested_rewrite']['prompt']
    subject = data['requested_rewrite']['subject']

    prompt_sr_swap, relation = _sr_swap(prompt, subject)

    # NOTE(review): exact behaviour of falcon_util.trim is not visible here;
    # presumably it trims the tokens - confirm.
    relation_toks = falcon_util.trim(relation.split(), True)

    # Move the cut point left for as long as the token before it is in prep_set.
    idx = len(relation_toks)
    while idx > 0 and relation_toks[idx - 1] in prep_set:
        idx -= 1

    relation_rmd_prep = ' '.join(relation_toks[:idx])
    rmd_prep = ' '.join(relation_toks[idx:])

    is_rm_prep = relation != relation_rmd_prep
    if is_rm_prep and not relation.startswith(relation_rmd_prep):
        print(f'# error:\nrelation : {relation}\nrelation_rmd_prep : {relation_rmd_prep}')

    if not is_rm_prep:
        prompt_sr_swap_rmd_prep = ''
    elif prompt_sr_swap[:2] == '{}':
        # Slot comes first: the stripped words go right after it.
        prompt_sr_swap_rmd_prep = '{}' + f' {rmd_prep}{prompt_sr_swap[2:]}'
    else:
        # Slot comes last: the stripped words are appended after it.
        prompt_sr_swap_rmd_prep = f'{prompt_sr_swap} {rmd_prep}'

    if do_print:
        report = (
            f'prompt : {prompt}',
            f'subject : {subject}\n',
            f'prompt_sr_swap : {prompt_sr_swap}',
            f'relation : {relation}\n',
            f'prompt_sr_swap_rmd_prep : {prompt_sr_swap_rmd_prep}',
            f'relation_rmd_prep : {relation_rmd_prep}\n',
            f'is_rm_prep : {is_rm_prep}, rmd_prep : {rmd_prep}\n\n',
        )
        for line in report:
            print(line)

    return prompt, subject, prompt_sr_swap, relation, prompt_sr_swap_rmd_prep, relation_rmd_prep, is_rm_prep, rmd_prep


# Test
# datas = [
#     {'requested_rewrite': {'prompt': '{}\'s originated in', 'subject': 'We Are Wolves'}},
#     {'requested_rewrite': {'prompt': 'The native language a of {} is', 'subject': 'Willem Johan Kolff'}},
#     {'requested_rewrite': {'prompt': 'In {}, an official language is', 'subject': 'Dutch Language Union'}}
# ]

# for data in datas:
#     _check_prep(data, True)

prep_set size : 17
{'is', 'as', 'to', 'with', 'label', 'for', 'on', 'after', 'at', 'by', 'does', 'from', 'of', 'the', 'an', 'in', 'a'}



In [15]:
# Seed the global `random` generator that _sr_swap_random() draws from.
# Note: _random_sampling() re-seeds the same global generator on every call.
random.seed(seed)

def _sr_swap_random(prompt: str, subject: str):
    eojs = prompt.split()
    eoj_len = len(eojs)

    idx_rand = -1
    while 1:
        idx_rand = random.randint(0, eoj_len-1)
        eoj_rand = eojs[idx_rand]

        if not '{}' in eoj_rand:
            break
    
    if idx_rand != -1:
        relation = eojs[idx_rand]
        eojs[idx_rand] = '#%#'
        prompt = ' '.join(eojs)
        prompt_sr_swap_random = prompt.format(subject)
        prompt_sr_swap_random = re.sub('#%#', '{}', prompt_sr_swap_random)

        return prompt_sr_swap_random, relation
    else:
        print(f'_sr_swap_random () error, prompt : {prompt}')
        return None, None



# Test
# datas = [
#     ['{}\'s in originated in', 'We Are Wolves'],
#     ['The native language a of {} is', 'Willem Johan Kolff'],
#     ['In {}, an official language is', 'Dutch Language Union']
# ]

# for data in datas:
#     prompt_sr_swap_random, relation = _sr_swap_random(data[0], data[1])
    
#     print(f'prompt : {data[0]}')
#     print(f'subject : {data[1]}\n')
#     print(f'prompt_sr_swap_random : {prompt_sr_swap_random}')
#     print(f'relation : {relation}\n\n')

In [16]:
# def __sr_swap_idx(prompt: str, subject: str, idx=-1):
#     eojs = prompt.split()

#     if eojs[idx] == '{}':
#         idx -= 1

#     relation = eojs[idx]
#     eojs[idx] = '#%#'
#     prompt = ' '.join(eojs)
#     prompt_sr_swap_idx = prompt.format(subject)
#     prompt_sr_swap_idx = re.sub('#%#', '{}', prompt_sr_swap_idx)

#     return prompt_sr_swap_idx, relation


def _sr_swap_idx(prompt: str, subject: str, idx=-1):
    eojs = prompt.format(subject).split()
    eoj_len = len(eojs)

    if 0 <= idx and eoj_len <= idx:
        idx = eoj_len-1
    elif idx < 0 and (eoj_len + idx) < 0:
        idx = 0
    
    relation = eojs[idx]
    eojs[idx] = '{}'
    prompt_sr_swap_idx = ' '.join(eojs)
    return prompt_sr_swap_idx, relation


# Test
# datas = [
#     ['{}\'s in originated in', 'We Are Wolves'],
#     ['The native language a of {} is', 'Willem Johan Kolff'],
#     ['In {}, an official language is', 'Dutch Language Union']
# ]

# for data in datas:
#     prompt_sr_swap_idx, relation = _sr_swap_idx(data[0], data[1], -2)
    
#     print(f'prompt : {data[0]}')
#     print(f'subject : {data[1]}\n')
#     print(f'prompt_sr_swap_idx : {prompt_sr_swap_idx}')
#     print(f'relation : {relation}\n\n')

In [17]:
def _print_data(data: dict, do_sr_swap=False):
    case_id = data['case_id']
    data_rq = data['requested_rewrite']

    subject = data_rq['subject']
    prompt = data_rq['prompt']

    print(f'\ncase_id : {case_id}')
    if 'org_prompt' in data_rq.keys():
        print(f'org_prompt : {data_rq["org_prompt"]}')
    if 'org_subject' in data_rq.keys():
        print(f'org_subject : {data_rq["org_subject"]}')

    print(f'\nprompt : {prompt}')
    print(f'subject : {subject}\n')

    if do_sr_swap:
        prompt_sr_swap, relation = _sr_swap(prompt, subject)

        relation_toks = falcon_util.trim(relation.split(), True)
        relation_last_tok = relation_toks[-1]

        print(f'prompt_sr_swap : {prompt_sr_swap}')
        print(f'relation : {relation}')
        print(f'relation_last_tok : {relation_last_tok}\n')

In [18]:
# def _print_known(ent: dict):
#     known_id = ent['known_id']
#     subject = ent['subject']
#     attribute = ent['attribute']
#     template = ent['template']
#     prediction = ent['prediction']
#     prompt = ent['prompt']
#     relation_id = ent['relation_id']

#     print(f'known_id : {known_id}')
#     print(f'subject : {subject}')
#     print(f'attribute : {attribute}')
#     print(f'template : {template}')
#     print(f'prediction : {prediction}')
#     print(f'prompt : {prompt}')
#     print(f'relation_id : {relation_id}\n')

# 여기서부턴 실제 실행 코드

In [19]:
###############################################################################
# Run cell: build the HCLT experiment files from MultiCounterFact
###############################################################################

do_run, do_simple = False, False

if do_run:
    out_dir = './data/hclt'

    rand_n = 10000
    ext_n = 300

    # data load
    ds_name = 'mcf'
    dataset_size_limit = None
    datas = _load_data(ds_name, tok, dataset_size_limit)

    # draw rand_n random records
    sampled = _random_sampling(datas.data, rand_n)

    # keep only records whose subject is not duplicated (at most ext_n of them)
    rm_dup_subj_datas = _rm_dup_subj(sampled, ext_n)

    out_file_path = f'{out_dir}/hclt_multi_counterfact_rand_uniq_subj_{ext_n}.json'
    _write_datas(out_file_path, rm_dup_subj_datas, do_simple)

    # bucket the records by how many records share their subject
    dup_subj_datas_lists = _get_dup_subj_lists(datas)

    for i, dup_subj_datas_list in enumerate(dup_subj_datas_lists):
        out_file_path = f'{out_dir}/hclt_multi_counterfact_dup_subj_n{i+1}.json'
        _write_datas(out_file_path, dup_subj_datas_list, do_simple)

    # experiment 1: vary the ratio between unique-subject records and
    # duplicated-subject records
    total_len = len(rm_dup_subj_datas)
    step = 10
    step_len = int(total_len / step)

    for i in range(0, step+1):
        idx = i * step_len

        if idx == 0:
            datas_1 = rm_dup_subj_datas

        else:
            datas_1 = rm_dup_subj_datas[:-(i*step_len)]
        # Index the list of buckets (subjects occurring twice), not the loop
        # variable left over from the bucket-writing loop above.
        datas_2 = dup_subj_datas_lists[1][:(i*step_len)]

        write_datas = []
        write_datas.extend(datas_1)
        write_datas.extend(datas_2)

        out_file_path = f'{out_dir}/hclt_multi_counterfact_experiment1_{(step-i)}-{i}.json'
        _write_datas(out_file_path, write_datas, do_simple)

    # experiment 2: from the duplicated-subject records, build batches so that
    # a subject never repeats inside a batch but does repeat across batches
    batch_size = 3
    batch_dup_subj_datas_list = [[] for _ in range(batch_size)]
    batch_dup_subj_datas_all = []

    # Bucket batch_size-1 holds subjects occurring exactly batch_size times;
    # their records are stored consecutively, so round-robin assignment puts
    # each occurrence into a different batch.
    for i, data in enumerate(dup_subj_datas_lists[batch_size-1]):
        batch_dup_subj_datas_list[i%batch_size].append(data)

    for i, batch_dup_subj_datas in enumerate(batch_dup_subj_datas_list):
        out_file_path = f'{out_dir}/hclt_multi_counterfact_experiment2_batch{i+1}.json'
        _write_datas(out_file_path, batch_dup_subj_datas, do_simple)

        batch_dup_subj_datas_all.extend(batch_dup_subj_datas)

    out_file_path = f'{out_dir}/hclt_multi_counterfact_experiment2_batch_all.json'
    # batch_dup_subj_datas_all.reverse()
    _write_datas(out_file_path, batch_dup_subj_datas_all, do_simple)
else:
    print('아무것도 실행되지 않았음')

아무것도 실행되지 않았음


In [20]:
###############################################################################
# Run cell: check prompts that contain symbols, and test _sr_swap()
###############################################################################

do_run = False

if do_run:
    out_dir = './data'

    ds_name = 'mcf'
    datas = _load_data(ds_name, tok)


    prompt_in_symbol_freq_dict = dict()

    for data in datas:
        data_rq = data['requested_rewrite']
        prompt_org = data_rq['prompt']
        prompt = prompt_org

        # prompt = re.sub('{},', ' ', prompt)         # {},
        # prompt = re.sub('{}\\?', ' ', prompt)       # {}?
        # prompt = re.sub('{}\\\'s', ' ', prompt)     # {}'s
        # prompt = re.sub('{}', ' ', prompt)

        if falcon_util.contains_symbol(prompt):
            falcon_util.add_dict_freq(prompt_in_symbol_freq_dict, prompt_org)
        else:
            print(f'### error 지금은 여기 걸리면 안됨 : {prompt_org}')

    print(f'prompt_in_symbol_freq_dict size : {len(prompt_in_symbol_freq_dict)}')

    out_file_path = f'{out_dir}/multi_counterfact_prompt_freq.txt'
    _write_sorted(prompt_in_symbol_freq_dict, out_file_path)


    # verify that the subject-relation swap round-trips
    out_file_path = f'{out_dir}/multi_counterfact_prompt_sr_swap_check.txt'
    out_file = falcon_util.open_file(out_file_path, mode='w')

    subject = '#%#'
    try:
        for prompt in prompt_in_symbol_freq_dict.keys():
            prompt_sr_swap, relation = _sr_swap(prompt, subject)

            # _sr_swap() reports an unusable prompt as (None, None); skip it
            # instead of failing on None.format(...) below.
            if prompt_sr_swap is None:
                print(f'_sr_swap() returned no result, skipped : {prompt}')
                continue

            relation_rm_symbol_edge = falcon_util.remove_symbol_edge(relation)
            _a = prompt.format(subject)
            _b = prompt_sr_swap.format(relation)

            # for review as a text file
            # out_file.write(f'prompt : {prompt}\n')
            # out_file.write(f'subject : {subject}\n\n')
            # out_file.write(f'prompt_sr_swap : {prompt}\n')
            # out_file.write(f'relation : {relation}\n\n')
            # if _a == _b:
            #     out_file.write(f'# subject-relation swap check passed\n\n')
            # else:
            #     out_file.write(f'# error\n_a : {_a}\n_b : {_b}\n\n')

            # for review in a spreadsheet (tab-separated columns)
            out_line = f'{prompt}\t{subject}\t{prompt_sr_swap}\t{relation}'
            out_line += f'\t{relation_rm_symbol_edge}\t{relation==relation_rm_symbol_edge}'
            out_line += f'\t{_a}\t{_b}\t{_a == _b}\n'
            out_file.write(out_line)
    finally:
        # Close the report even when a prompt raises part-way through.
        out_file.close()
else:
    print('아무것도 실행되지 않았음')

아무것도 실행되지 않았음


In [22]:
###############################################################################
# Run cell: build the datasets used to compare model-editing performance when
# the relation is treated as the subject
###############################################################################

do_run, do_print = True, False

if do_run:
    out_dir = './data/test_sets'
    rand_n = 1000
    
    ds_name = 'mcf'
    datas = _load_data(ds_name, tok)

    # Group the records by subject, then bucket them by the number of records
    # each subject has
    data_dup_subj_lists = _get_dup_subj_lists(datas)

    # Records whose subject occurs exactly once
    data_uniq_subj_list = data_dup_subj_lists[0]
    print(f'\ndata_uniq_subj_list size : {len(data_uniq_subj_list)}\n')


    # (English for the note below) Check in advance whether trailing function
    # words get removed from the relation part of the prompt:
    #   - datas_rmd_f : nothing is removed        (speaks the language)
    #   - datas_rmd_t : something is removed      (an official language "is")
    '''
        전치사의 제거 여부를 미리 확인
            - datas_rmd_f : prompt의 relation 부분에서 전치사가 제거되지 않는 경우 (speaks the language)
            - datas_rmd_t : prompt의 relation 부분에서 전치사가 제거되는 경우 (an official language "is")
    '''
    datas_rmd_f, datas_rmd_t = [], []
    for data in data_uniq_subj_list:
        # Only is_rm_prep is used in this loop; the other values are discarded.
        prompt, subject, prompt_sr_swap, relation, prompt_sr_swap_rmd_prep, relation_rmd_prep, is_rm_prep, rmd_prep = _check_prep(data)
        # data['requested_rewrite']['org_prompt'] = prompt
        # data['requested_rewrite']['org_subject'] = subject
        # data['requested_rewrite']['prompt'] = prompt_sr_swap
        # data['requested_rewrite']['subject'] = relation_rmd_prep

        if not is_rm_prep:
            datas_rmd_f.append(data)
        else:
            datas_rmd_t.append(data)
    print(f'datas_rmd_f size : {len(datas_rmd_f)}')
    print(f'datas_rmd_t size : {len(datas_rmd_t)}\n')

    # Each _random_sampling() call re-seeds the global `random` generator, so
    # the _sr_swap_random() draws further down continue from the state left by
    # the second call. Reordering these calls changes the generated files.
    datas_rmd_f_rand = _random_sampling(datas_rmd_f, rand_n)
    datas_rmd_t_rand = _random_sampling(datas_rmd_t, rand_n)
    out_file_path = f'{out_dir}/multi_counterfact_uniq_subj_rmd_f_rand_{rand_n}.json'
    _write_datas(out_file_path, datas_rmd_f_rand)
    out_file_path = f'{out_dir}/multi_counterfact_uniq_subj_rmd_t_rand_{rand_n}.json'
    _write_datas(out_file_path, datas_rmd_t_rand)

    
    # (English for the note below) For datas_rmd_t, swap the subject and
    # relation values in the records so that the edit targets the relation:
    #   - *_sr_swap.json          : everything except the subject is the relation
    #   - *_sr_swap_rmd_prep.json : same, with the trailing function words removed
    '''
        datas_rmd_t : prompt의 relation 부분에서 전치사가 제거되는 경우 (an official language "is")
            - 위 경우에 대하여, 실제 데이터에서 subject와 relation의 value를 swap

            - relation에 대해서 편집을 수행하기 위함
                - subject를 제외한 나머지 부분을 relation으로 사용해서 편집한 경우 (*_sr_swap.json)
                - subject를 제외한 나머지 부분을 relation으로 사용하되, 전치사 부분을 제거하고 편집한 경우 (*_sr_swap_rmd_prep.json)
    '''
    datas_rmd_t_rand_sr_swap = []
    datas_rmd_t_rand_sr_swap_rmd_prep = []
    datas_rmd_t_rand_sr_swap_random = []
    
    for data in datas_rmd_t_rand:
        prompt, subject, prompt_sr_swap, relation, prompt_sr_swap_rmd_prep, relation_rmd_prep, is_rm_prep, rmd_prep = _check_prep(data)

        # Deep copies keep the sampled records themselves untouched.
        data_1 = copy.deepcopy(data)
        data_1['requested_rewrite']['prompt'] = prompt_sr_swap
        data_1['requested_rewrite']['subject'] = relation
        datas_rmd_t_rand_sr_swap.append(data_1)

        data_2 = copy.deepcopy(data)
        data_2['requested_rewrite']['prompt'] = prompt_sr_swap_rmd_prep
        data_2['requested_rewrite']['subject'] = relation_rmd_prep
        datas_rmd_t_rand_sr_swap_rmd_prep.append(data_2)

        # Variant with a randomly chosen token as the edit target.
        prompt_sr_swap_random, relation_random = _sr_swap_random(prompt, subject)
        data_3 = copy.deepcopy(data)
        data_3['requested_rewrite']['prompt'] = prompt_sr_swap_random
        data_3['requested_rewrite']['subject'] = relation_random
        datas_rmd_t_rand_sr_swap_random.append(data_3)

    out_file_path = f'{out_dir}/multi_counterfact_uniq_subj_rmd_t_rand_{rand_n}_sr_swap.json'
    _write_datas(out_file_path, datas_rmd_t_rand_sr_swap)

    out_file_path = f'{out_dir}/multi_counterfact_uniq_subj_rmd_t_rand_{rand_n}_sr_swap_rmd_prep.json'
    _write_datas(out_file_path, datas_rmd_t_rand_sr_swap_rmd_prep)

    out_file_path = f'{out_dir}/multi_counterfact_uniq_subj_rmd_t_rand_{rand_n}_sr_swap_random.json'
    _write_datas(out_file_path, datas_rmd_t_rand_sr_swap_random)

    # (English for the note below) Pick the edit-target token by a plain token
    # index, e.g. always the last token or the one before it.
    '''
        무조건 마지막 토큰, 또는 마지막 이전 토큰 등의 단순한 토큰 index로 편집 대상 토큰 설정
    '''
    idxs = [0, 1, 2, -3, -2, -1]
    for idx in idxs:

        # One output file per token index.
        datas_rmd_t_rand_sr_swap_idx = []
        for data in datas_rmd_t_rand:
            prompt = data['requested_rewrite']['prompt']
            subject = data['requested_rewrite']['subject']

            prompt_sr_swap_idx, relation_idx = _sr_swap_idx(prompt, subject, idx)

            data_1 = copy.deepcopy(data)
            data_1['requested_rewrite']['prompt'] = prompt_sr_swap_idx
            data_1['requested_rewrite']['subject'] = relation_idx
            datas_rmd_t_rand_sr_swap_idx.append(data_1)
        
        out_file_path = f'{out_dir}/multi_counterfact_uniq_subj_rmd_t_rand_{rand_n}_sr_swap_idx_{str(idx)}.json'
        _write_datas(out_file_path, datas_rmd_t_rand_sr_swap_idx)

else:    
    print('아무것도 실행되지 않았음')

# data_load() ds_class : <class 'dsets.counterfact.MultiCounterFactDataset'>
# data_load() ds_eval_method : <function compute_rewrite_quality_counterfact at 0x7f0610955700>

Loaded dataset with 20877 elements

_get_subj_rela_dict() datas : 20877, subj_rela_dict size : 20099
_get_dup_subj_lists() dup_sizes : {1: 19366, 2: 693, 3: 35, 4: 5}

data_uniq_subj_list size : 19366

datas_rmd_f size : 2655
datas_rmd_t size : 16711

_random_sampling() datas : 2655, sampled : 1000
_random_sampling() datas : 16711, sampled : 1000
_write_datas() datas : 1000 -> ./data/test_sets/multi_counterfact_uniq_subj_rmd_f_rand_1000.json
_write_datas() datas : 1000 -> ./data/test_sets/multi_counterfact_uniq_subj_rmd_t_rand_1000.json
_write_datas() datas : 1000 -> ./data/test_sets/multi_counterfact_uniq_subj_rmd_t_rand_1000_sr_swap.json
_write_datas() datas : 1000 -> ./data/test_sets/multi_counterfact_uniq_subj_rmd_t_rand_1000_sr_swap_rmd_prep.json
_write_datas() datas : 1000 -> ./data/test_sets/multi_counterfac