In [1]:
import os
import sys

# Make the project root (the parent of the notebook directory) importable.
PATH = os.path.abspath('..')
print(f"path : {PATH}")

# Guard against duplicate entries when this cell is re-run in the same kernel.
if PATH not in sys.path:
    sys.path.append(PATH)

path : /home/nlpshlee/dev_env/git/repos/memit


In [3]:
import shutil
from typing import Tuple, Union, List
from itertools import chain
from time import time

from falcon.falcon_util import *
from util.globals import *
from util import nethook

from baselines.ft import FTHyperParams, apply_ft_to_model
from baselines.mend import MENDHyperParams, MendRewriteExecutor
from rome import ROMEHyperParams, apply_rome_to_model
from memit import MEMITHyperParams, apply_memit_to_model

from experiments.py.eval_utils_counterfact import compute_rewrite_quality_counterfact, test_batch_prediction, test_generation
from experiments.py.eval_utils_zsre import compute_rewrite_quality_zsre


from dsets import (
    AttributeSnippets,
    CounterFactDataset,
    MENDQADataset,
    MultiCounterFactDataset,
    get_tfidf_vectorizer
)

# Editing algorithm name -> (hyper-parameter class, the algorithm's apply function).
# NOTE: MendRewriteExecutor() is instantiated right here, when this cell runs; only
# its bound apply_to_model method is kept in the table.
ALG_DICT = dict(
    FT=(FTHyperParams, apply_ft_to_model),
    MEND=(MENDHyperParams, MendRewriteExecutor().apply_to_model),
    ROME=(ROMEHyperParams, apply_rome_to_model),
    MEMIT=(MEMITHyperParams, apply_memit_to_model),
)


# Dataset key -> (dataset class, matching compute_rewrite_quality_* function).
DS_DICT = dict(
    mcf=(MultiCounterFactDataset, compute_rewrite_quality_counterfact),
    cf=(CounterFactDataset, compute_rewrite_quality_counterfact),
    zsre=(MENDQADataset, compute_rewrite_quality_zsre),
)


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import random

import numpy as np

# Seed every RNG the experiments may touch so runs are reproducible.
seed = 7
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)           # seeds the CPU generator and all CUDA devices
torch.cuda.manual_seed_all(seed)  # explicit for multi-GPU; silently ignored without CUDA


# NOTE(review): this silences *every* warning (including library deprecations);
# narrow the filter if only a specific warning is meant to be hidden.
import warnings
warnings.filterwarnings('ignore')

In [4]:
def init_model(model_name, device='cuda'):
    """Return a ``(model, tokenizer)`` pair, loading it with ``from_pretrained`` if needed.

    Args:
        model_name: Either a model identifier/path (``str``) passed to
            ``AutoModelForCausalLM`` / ``AutoTokenizer.from_pretrained``, or an
            already-loaded ``(model, tok)`` pair, which is returned as-is.
        device: Device a freshly loaded model is moved to. Defaults to ``'cuda'``,
            matching the previous hard-coded ``.cuda()``; ignored for a pre-loaded pair.

    Returns:
        ``(model, tok)``.
    """
    if isinstance(model_name, str):
        print('# init_model() Instantiating model')
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        tok = AutoTokenizer.from_pretrained(model_name)
        # Pad with EOS, presumably because the tokenizer defines no pad token of its
        # own (the case for GPT-2).
        tok.pad_token = tok.eos_token
    else:
        model, tok = model_name
    print(f'\tmodel : {type(model)}')
    print(f'\ttokenizer : {type(tok)}\n')

    return model, tok

In [5]:
def load_data(ds_name, tok, dataset_size_limit):
    """Build the dataset registered under ``ds_name`` in ``DS_DICT``.

    Args:
        ds_name: Key of ``DS_DICT`` (e.g. ``'mcf'``, ``'cf'``, ``'zsre'``).
        tok: Tokenizer forwarded to the dataset constructor.
        dataset_size_limit: Forwarded to the dataset constructor as ``size=``.

    Returns:
        The dataset instance, constructed from ``DATA_DIR``.

    Raises:
        KeyError: if ``ds_name`` is not registered; the message lists the valid keys.
    """
    if ds_name not in DS_DICT:
        raise KeyError(f'unknown ds_name {ds_name!r}; expected one of {list(DS_DICT)}')

    ds_class, ds_eval_method = DS_DICT[ds_name]
    print(f'# load_data() ds_class : {ds_class}')
    print(f'# load_data() ds_eval_method : {ds_eval_method}\n')

    return ds_class(DATA_DIR, tok=tok, size=dataset_size_limit)

In [6]:
# Load GPT-2 XL and its tokenizer through init_model(), which also moves the model to the GPU.
model_name = 'gpt2-xl'
model, tok = init_model(model_name)

# init_model() Instantiating model
	model : <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>
	tokenizer : <class 'transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast'>



In [7]:
# Load the 'mcf' dataset (MultiCounterFactDataset per DS_DICT). dataset_size_limit=None
# presumably means "no size cap"; the recorded run loaded 20877 records.
ds_name = 'mcf'
dataset_size_limit = None
ds = load_data(ds_name, tok, dataset_size_limit)

# data_load() ds_class : <class 'dsets.counterfact.MultiCounterFactDataset'>
# data_load() ds_eval_method : <function compute_rewrite_quality_counterfact at 0x7fb84f2713a0>

Loaded dataset with 20877 elements


In [8]:
def generate(model, tok, prompts: List[str], top_k, max_out_len):
    """Greedily decode a continuation of the prompt(s) and return the first result as one line.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tok: Tokenizer matching ``model``.
        prompts: Prompts to encode; only the first output sequence is decoded.
        top_k: Forwarded to ``model.generate``; it has no effect because sampling is off.
        max_out_len: Forwarded as ``max_length`` (per the Hugging Face docs this
            counts the prompt tokens too).

    Returns:
        The decoded text (prompt included) with special tokens removed, newlines
        turned into tabs and runs of tabs collapsed into a single tab.
    """
    import re

    inputs = tok(prompts, return_tensors='pt').to(model.device)
    outputs = model.generate(**inputs,
                             top_k=top_k, max_length=max_out_len,
                             do_sample=False, num_beams=1,
                             pad_token_id=tok.eos_token_id)

    # The keyword is ``skip_special_tokens`` (plural); the old misspelling was ignored,
    # so special tokens such as EOS were left in the text.
    text = tok.decode(outputs[0], skip_special_tokens=True)
    # str.replace() does not interpret regexes, so '[\t]+' never matched; use re.sub.
    return re.sub(r'\t+', '\t', text.replace('\n', '\t'))

In [None]:
###############################################################################
# Execution code: subjects shared by several records whose true targets the
# unedited model already reproduces
###############################################################################

# subject -> list of {case_id, prompt (subject filled in), target_new, target_true}
record_dict = {}

for record in ds:
    rewrite = record['requested_rewrite']
    subject = rewrite['subject']
    record_dict.setdefault(subject, []).append({
        'case_id': record['case_id'],
        'prompt': rewrite['prompt'].format(subject),
        'target_new': rewrite['target_new']['str'],
        'target_true': rewrite['target_true']['str'],
    })

with open('./searched_multi_subject.txt', 'w', encoding='utf-8') as f:
    for subject, values in record_dict.items():
        value_len = len(values)
        if value_len <= 1:
            continue

        # generate() decodes greedily (do_sample=False), so each prompt is generated
        # once here and the text reused for the report instead of being recomputed.
        gen_texts_list = []
        for value in values:
            org_texts = f"{value['prompt']} {value['target_true']}"
            gen_texts = generate(model, tok, [value['prompt']], 1, 100)
            if not gen_texts.startswith(org_texts):
                break  # a single miss disqualifies the whole subject
            gen_texts_list.append(gen_texts)

        if len(gen_texts_list) < value_len:
            continue

        f.write(f'### Subject : {subject}, value_len : {value_len}\n\n')

        for idx, (value, gen_texts) in enumerate(zip(values, gen_texts_list)):
            case_id, prompt, target_true, target_new = value['case_id'], value['prompt'], value['target_true'], value['target_new']
            org_texts = f'{prompt} {target_true}'

            f.write(f'[{idx}] case_id : {case_id}, prompt : {prompt}\n')
            f.write(f'[{idx}] target_true : {target_true}\n')
            f.write(f'[{idx}] target_new : {target_new}\n\n')
            f.write(f'[{idx}] org_texts : {org_texts}\n')
            f.write(f'[{idx}] gen_texts : {gen_texts}\n\n')
            f.flush()

In [9]:
def sorted_dict_key(in_dict: dict, is_reverse=False) -> dict:
    """Return a new dict with the items of ``in_dict`` ordered by key.

    Args:
        in_dict: Dictionary to sort; it is not modified.
        is_reverse: ``False`` for ascending order, ``True`` for descending order.

    Returns:
        A new ``dict`` whose insertion order follows the sorted keys.
    """
    return dict(sorted(in_dict.items(), key=lambda item: item[0], reverse=is_reverse))

def sorted_dict_value(in_dict: dict, is_reverse=False) -> dict:
    """Return a new dict with the items of ``in_dict`` ordered by value.

    The sort is stable, so items with equal values keep their relative order from
    ``in_dict`` (also when ``is_reverse`` is ``True``).

    Args:
        in_dict: Dictionary to sort; it is not modified. Values must be mutually comparable.
        is_reverse: ``False`` for ascending order, ``True`` for descending order.

    Returns:
        A new ``dict`` whose insertion order follows the sorted values.
    """
    return dict(sorted(in_dict.items(), key=lambda item: item[1], reverse=is_reverse))

def sorted_dict(in_dict: dict) -> dict:
    """Return a new dict ordered by value (descending), ties broken by key (ascending).

    Done in two passes: an ascending key sort, then a descending value sort. Python's
    sort is stable even with ``reverse=True``, so the key order from the first pass
    survives among equal values.
    """
    return sorted_dict_value(sorted_dict_key(in_dict, False), True)

In [10]:
def remove_edge_symbols(text: str, symbols='!.,?-\'"'):
    """Trim leading and trailing characters of ``text`` that occur in ``symbols``.

    Only the two ends are touched; interior punctuation and any whitespace not listed
    in ``symbols`` are kept. The default set is common punctuation plus both quote marks.
    """
    trimmed = text.strip(symbols)
    return trimmed

In [11]:
import re

def _post(text: str):
    return re.sub(r'[\t\n ]+', ' ', text).strip()

In [12]:
def _add_dict_freq(in_dict: dict, key: str, value=1):
    if key in in_dict.keys():
        in_dict[key] += value
    else:
        in_dict[key] = value

In [13]:
def _write_sorted(in_dict: dict, out_file_path: str):
    """Write ``in_dict`` to ``out_file_path`` as ``key<TAB>value`` lines.

    Lines follow ``sorted_dict`` order: value descending, ties by key ascending.
    The file is overwritten and encoded as UTF-8 regardless of the platform locale.
    """
    with open(out_file_path, 'w', encoding='utf-8') as f:
        f.writelines(f'{key}\t{value}\n' for key, value in sorted_dict(in_dict).items())

In [14]:
def get_all_substrings(text):
    """Return every contiguous word span of ``text`` after light cleaning.

    Words are whitespace-separated. Each span is cleaned as follows:
    - ``'{}'`` placeholders are removed.
    - Edge punctuation is stripped with ``remove_edge_symbols``.
    - A leading ``'s '`` (left over from a possessive ``"{}'s"``) is dropped.
    - Whitespace is normalized with ``_post``.

    Spans that end up empty are discarded. The result is ordered by start word, then
    by span length, and duplicates are kept.
    """
    tokens = text.split()
    n_tokens = len(tokens)
    candidates = [' '.join(tokens[lo:hi])
                  for lo in range(n_tokens)
                  for hi in range(lo + 1, n_tokens + 1)]

    cleaned = []
    for candidate in candidates:
        candidate = remove_edge_symbols(candidate.replace('{}', ''))
        if candidate.startswith('s '):
            candidate = candidate[2:]
        candidate = _post(candidate)
        if candidate:
            cleaned.append(candidate)

    return cleaned

In [None]:
###############################################################################
# Execution code: frequency of every (lower-cased) word span of the prompt templates
###############################################################################

check_dict = {}

for record in ds:
    prompt = record['requested_rewrite']['prompt']

    # Lower-case so spans differing only in capitalization share one count.
    for substr in get_all_substrings(prompt):
        _add_dict_freq(check_dict, substr.lower())

sorted_check_dict = sorted_dict(check_dict)

# One line per span: <number of words>\t<span>\t<frequency>
with open('./searched_multi_substrings_from_prompt.txt', 'w', encoding='utf-8') as f:
    for key, value in sorted_check_dict.items():
        key_len = len(key.split())
        f.write(f'{key_len}\t{key}\t{value}\n')

In [None]:
###############################################################################
# Execution code: classify prompt templates by where the subject slot '{}' sits,
# and count the relation phrases of templates that start with the subject
###############################################################################

relation_dict = {}
cnt, cnt_relation_front, cnt_relation_mid, cnt_relation_last = 0, 0, 0, 0

for record in ds:
    prompt = record['requested_rewrite']['prompt']

    cnt += 1
    if prompt.startswith('{}'):
        # Template begins with the subject slot; the remainder is taken as the relation.
        cnt_relation_front += 1

        relation = _post(remove_edge_symbols(prompt[2:]))
        # NOTE(review): for a possessive template ("{}'s ...") the apostrophe is stripped
        # but the leading "s " is kept, unlike get_all_substrings() — confirm intended.
        _add_dict_freq(relation_dict, relation)
    elif prompt.endswith('{}'):
        cnt_relation_last += 1
    else:
        cnt_relation_mid += 1

print(f'cnt : {cnt}')
print(f'cnt_relation_front : {cnt_relation_front}')
print(f'cnt_relation_mid : {cnt_relation_mid}')
print(f'cnt_relation_last : {cnt_relation_last}')

_write_sorted(relation_dict, './searched_multi_relation_freq.txt')

In [15]:
def _print_known(ent: dict):
    known_id = ent['known_id']
    subject = ent['subject']
    attribute = ent['attribute']
    template = ent['template']
    prediction = ent['prediction']
    prompt = ent['prompt']
    relation_id = ent['relation_id']

    print(f'known_id : {known_id}')
    print(f'subject : {subject}')
    print(f'attribute : {attribute}')
    print(f'template : {template}')
    print(f'prediction : {prediction}')
    print(f'prompt : {prompt}')
    print(f'relation_id : {relation_id}\n')

In [16]:
###############################################################################
# Execution code: for subject-first templates ("{} <relation>"), swap the roles
# of subject and relation and save the switched records
###############################################################################

in_file_path = '../data/known_1000.json'
out_file_path = '../data/known_1000_sr_switch.json'

json_dict_list = json_file_to_dict(in_file_path)
write_dict_list = []

for ent in json_dict_list:
    template = ent['template']
    # Only templates that begin with the subject slot are switched.
    if not template.startswith('{}'):
        continue

    relation = _post(remove_edge_symbols(template[2:]))

    # The original subject becomes fixed template text and the relation phrase moves
    # into the subject slot. Note: `ent` (an element of json_dict_list) is edited in place.
    subject = ent['subject']
    ent['template'] = subject + ' {}'
    ent['subject'] = relation

    write_dict_list.append(ent)

print(len(write_dict_list))

f = open_file(out_file_path, mode='w')
try:
    f.write(to_json_str(write_dict_list))
finally:
    f.close()  # close even if serialization or writing fails

1004
