In [None]:
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap.umap_ as umap

from falcon.falcon_util import *
from falcon.model_editor import *
from falcon.find_with_wiki import *

import torch, random
from transformers import AutoModelForCausalLM, AutoTokenizer

seed = 42
# Seed every RNG source this notebook may draw from (Python, NumPy, torch CPU,
# torch current CUDA device) so stochastic steps are repeatable across runs.
for _seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seeder(seed)

In [None]:
# --- Editing setup ------------------------------------------------------------
alg_name = 'ROME'
dir_name = alg_name  # mirrors the algorithm name
model_name = 'gpt2-xl'
# model_name = 'meta-llama/Meta-Llama-3-8B'
hparams_fname = f'{model_name}.json'
ds_name = 'cf'
num_edits = 1

# --- Runtime / evaluation options passed to ModelEditor ------------------------
dataset_size_limit = None
continue_from_run = None
skip_generation_tests = False
generation_test_interval = 1
conserve_memory = False
use_cache = False
# Presumably makes the model return `hidden_states`, which the embedding plots read — confirm in ModelEditor.
output_hidden_states = True

In [None]:
# Positional arguments for ModelEditor, in the order its constructor takes them.
_editor_args = (
    alg_name, model_name, hparams_fname, ds_name,
    dataset_size_limit, continue_from_run, skip_generation_tests,
    generation_test_interval, conserve_memory, dir_name, num_edits, use_cache, output_hidden_states,
)
model_editor = ModelEditor(*_editor_args)

# Original (unedited) model: a deep copy, so edits later applied to
# model_editor._model cannot change it. The tokenizer is shared, not copied.
model = deepcopy(model_editor._model)
tok = model_editor._tok

# Edited model
# model_editor.load_data()
# model_editor.edit()
# edited_model = model_editor._model

In [None]:
def load_data(in_file_path, select_ids: set, do_edit):
    """Load a JSON list of rewrite records and group the selected ones by target.

    Args:
        in_file_path: path to a JSON file holding a list of records; each record has
            'requested_rewrite' -> 'target_true' / 'target_new' -> {'id', 'str'}.
        select_ids: target ids to keep; records whose target id is not in it are skipped.
        do_edit: if True, group by 'target_new'; otherwise by 'target_true'.

    Returns:
        (datas, data_dict): every loaded record, and a dict mapping
        "<target_id>\\t<target_str>" to the list of selected records with that target,
        in file order.
    """
    import json

    print(f'load_data() in_file_path : {in_file_path}')
    # Close the handle as soon as the JSON is parsed (it was previously leaked).
    with open_file(in_file_path, mode='r') as in_file:
        datas = json.load(in_file)
    print(f'load_data() datas size : {len(datas)}')

    target_field = 'target_new' if do_edit else 'target_true'

    data_dict = {}
    for data in datas:
        target_info = data['requested_rewrite'][target_field]
        target_id = target_info['id']
        if target_id not in select_ids:
            continue

        key = f'{target_id}\t{target_info["str"]}'
        data_dict.setdefault(key, []).append(data)

    key_size, value_size = len(data_dict), sum(len(v) for v in data_dict.values())
    print(f'load_data() data_dict key size : {key_size}')
    print(f'load_data() data_dict value size : {value_size}\n')

    return datas, data_dict

In [None]:
def merge_data(in_file_path1, in_file_path2, out_file_path):
    """Append the records of a second JSON list to a first one and write the result.

    Records from ``in_file_path2`` get fresh 'case_id' values that continue after
    the largest 'case_id' in ``in_file_path1`` (never below -1, so an empty first
    file makes them start at 0). The merged list is written with
    ``write_json_to_file`` to ``out_file_path``; nothing is returned.
    """
    import json

    with open_file(in_file_path1, mode='r') as in_file1:
        datas1 = json.load(in_file1)
    print(f'merge_data() datas1 len : {len(datas1)}, in_file_path1 : {in_file_path1}')

    # Same floor as before: start from -1 even if every existing id is smaller.
    idx_max = max([-1] + [data['case_id'] for data in datas1])

    with open_file(in_file_path2, mode='r') as in_file2:
        datas2 = json.load(in_file2)
    print(f'merge_data() datas2 len : {len(datas2)}, in_file_path2 : {in_file_path2}')

    # Renumber the second file's records so case_ids stay unique after merging.
    for new_id, data in enumerate(datas2, start=idx_max + 1):
        data['case_id'] = new_id
        datas1.append(data)

    print(f'merge_data() datas_merged len : {len(datas1)} -> out_file_path : {out_file_path}\n')
    write_json_to_file(datas1, out_file_path)

In [None]:
def tokenize(tok, text: str, do_print=False):
    """Tokenize one string and return ``(token_ids, number_of_tokens)``.

    ``token_ids`` is the single row of the tokenizer's 'input_ids' tensor.
    With ``do_print``, each id is printed beside its decoded text, then a blank line.
    """
    token_ids = tok(text, return_tensors='pt')['input_ids'][0]
    num_tokens = len(token_ids)

    if do_print:
        for token_id in token_ids:
            decoded = tok.decode(token_id)
            print(f'{token_id}\t\t->\t{decoded}')
        print()

    return token_ids, num_tokens

In [None]:
def reduce_dim(data_org, method, n_components, random_state):
    """Project (N, L, D) embeddings to (N, L, n_components).

    All N * L vectors are fitted jointly, so points from different layers share
    one low-dimensional space.

    Args:
        data_org: array of shape (N, L, D); it is copied, never modified.
        method: 'pca', 'tsne' or 'umap'.
        n_components: target dimensionality.
        random_state: seed passed to t-SNE and UMAP (PCA is built without one).

    Returns:
        np.ndarray of shape (N, L, n_components).

    Raises:
        ValueError: if ``method`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    data = np.copy(data_org)
    (N, L, D) = data.shape

    if method == 'pca':
        reducer = PCA(n_components=n_components)
    elif method == 'tsne':
        reducer = TSNE(n_components=n_components, random_state=random_state)
    elif method == 'umap':
        reducer = umap.UMAP(n_components=n_components, random_state=random_state)
    else:
        raise ValueError(f"reduce_dim() unknown method {method!r}; expected 'pca', 'tsne' or 'umap'")

    # Flatten to (N * L, D): every (sample, layer) vector is one point for the reducer.
    reduced_data = reducer.fit_transform(data.reshape(-1, D))
    return reduced_data.reshape(N, L, n_components)

In [None]:
def plot_from_reduce(words, data, title, method, save_path: str, do_fix):
    """Draw a 3-D scatter of reduced embeddings (one colour per layer), save it, show it.

    Args:
        words: per-sample labels. Currently unused: points are annotated with their
            sample index instead.
        data: array of shape (N, L, D) with D >= 3; only the first three components
            are plotted and used for the axis limits.
        title: title fragment, also substituted into ``save_path``.
        method: reduction method name, used in the title and ``save_path``.
        save_path: format string with four ``{}`` slots, filled with
            (title, 'fix' or 'auto', title, method).
        do_fix: if True, use a margin of 10 around the data and a fixed camera
            (elev=30, azim=30); otherwise a margin of 5 and the default camera.
    """
    (N, L, D) = data.shape

    # One viridis colour per layer; the last layer is forced to red so it stands out.
    colors = plt.cm.viridis(np.linspace(0, 1, L))
    colors[-1] = np.array([1, 0, 0, 1])

    fig = plt.figure(figsize=(20, 14))
    ax = fig.add_subplot(111, projection='3d')

    for layer in range(L):
        layer_data = data[:, layer, :]  # (N, D): every sample at this layer
        ax.scatter(layer_data[:, 0], layer_data[:, 1], layer_data[:, 2], color=colors[layer], label=f'Layer {layer+1}', s=100, alpha=0.7)

        # Label each point with its sample index; the last sample is highlighted.
        for i in range(N):
            if i == N - 1:
                color, fontsize, weight = 'red', 12, 'bold'
            else:
                color, fontsize, weight = 'black', 10, 'normal'
            ax.text(layer_data[i, 0], layer_data[i, 1], layer_data[i, 2], i, ha='center', va='center', alpha=0.8,
                    color=color, fontsize=fontsize, weight=weight)

    # Axis limits: min/max over all layers of the three plotted components, plus a margin.
    margin = 10 if do_fix else 5
    all_data = data.reshape(-1, D)[:, :3]
    x_min, y_min, z_min = all_data.min(axis=0) - margin
    x_max, y_max, z_max = all_data.max(axis=0) + margin
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_min, y_max])
    ax.set_zlim([z_min, z_max])
    if do_fix:
        ax.view_init(elev=30, azim=30)  # fixed elevation / azimuth so figures are comparable

    vi_title = f'[ {title} ] with {method}'
    ax.set_title(f'3D Visualization of Data across Layers {vi_title}')
    ax.set_xlabel('Component 1')
    ax.set_ylabel('Component 2')
    ax.set_zlabel('Component 3')
    ax.legend()

    save_path = save_path.format(title, 'fix' if do_fix else 'auto', title, method)
    make_parent(save_path)
    # Save before show() so the written file does not depend on what the backend
    # does with the figure when it is displayed.
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # This function runs several times per target; release the figure explicitly
    # so pyplot does not keep every one alive.
    plt.close(fig)

In [None]:
def plot_embeddings_3d_by_layer(words, embeddings, title, n_components=3, save_path='', random_state=42,
                                methods=('pca', 'tsne', 'umap')):
    """Reduce (N, L, D) embeddings with each method and plot every result twice.

    For each entry of ``methods`` the embeddings are reduced with ``reduce_dim`` and
    passed to ``plot_from_reduce`` once with automatic axis ranges (do_fix=False)
    and once with the fixed-range settings (do_fix=True).

    Args:
        words: per-sample labels forwarded to ``plot_from_reduce``.
        embeddings: array of shape (N, L, D).
        title: plot title fragment / save-path component.
        n_components: reduced dimensionality; the 3-D plot needs at least 3.
        save_path: save-path format string forwarded to ``plot_from_reduce``.
        random_state: seed forwarded to ``reduce_dim``.
        methods: reduction methods to run, in order (new, defaults to all three).
    """
    for method in methods:
        embeddings_reduce = reduce_dim(embeddings, method, n_components, random_state)
        for do_fix in (False, True):
            plot_from_reduce(words, embeddings_reduce, title, method, save_path, do_fix=do_fix)

In [None]:
def average_token(batch_hidden_states, do_print):
    """Average each sample's hidden states over its tokens.

    An entity may be split into several tokens; its vector is taken to be the mean
    over those tokens (dim 0 of each tensor).

    Args:
        batch_hidden_states: sequence of N torch tensors, each laid out as (T, L, D)
            with T tokens (T may differ between samples).
        do_print: if True, print the input and output shapes.

    Returns:
        np.ndarray of shape (N, L, D).
    """
    ent_states = np.array([
        sample.mean(dim=0).detach().cpu().numpy()
        for sample in batch_hidden_states
    ])

    if do_print:
        shape_org = (len(batch_hidden_states),) + batch_hidden_states[0].shape
        shape = ent_states.shape
        print(f'average_token() batch_hidden_states shape : (N, T, L, D) = {shape_org}')
        print(f'average_token() ent_states shape : (N, L, D) = {shape}\n')

    return ent_states


def _collect_target_hidden_states(model, tok, prompt_for_gen, target, do_print):
    """Extend a prompt by as many generated tokens as ``target`` has, and return their hidden states.

    Returns a tensor of shape (T_target, num_hidden_state_entries, D): for each of the
    last T_target positions of the final text, the vector from every entry of
    ``model(...).hidden_states``. Returns None when ``target`` has no real tokens.
    """
    _, prompt_tok_len = tokenize(tok, prompt_for_gen, do_print)
    target_tok_ids, target_tok_len = tokenize(tok, target, do_print)

    # Tokenizers that prepend a BOS token (e.g. Llama's) add one id that is not part
    # of the target; GPT-2's does not. Checking the tokenizer itself avoids depending
    # on the mutable notebook global `model_name`.
    starts_with_bos = (
        target_tok_len > 0
        and tok.bos_token_id is not None
        and int(target_tok_ids[0]) == tok.bos_token_id
    )
    n_target_tokens = target_tok_len - 1 if starts_with_bos else target_tok_len
    if n_target_tokens <= 0:
        return None

    device = next(model.parameters()).device
    hidden_states = None
    # Inference only: no autograd graph is needed for the hidden states.
    with torch.no_grad():
        for i in range(1, n_target_tokens + 1):
            text_gen, outputs_1 = generate(model, tok, [prompt_for_gen], 1, prompt_tok_len + i)
            prompt_for_gen = text_gen

            outputs_2 = model(**tok([prompt_for_gen], return_tensors='pt').to(device))
            hidden_states = outputs_2.hidden_states

            if do_print:
                print(f'text_gen : {text_gen}')
                print(f'outputs : {outputs_1}\n')
                for layer_idx, layer_hidden_state in enumerate(hidden_states):
                    print(f'Layer {layer_idx} shape: {layer_hidden_state.shape}')
                print()

    # hidden_states is a per-layer tuple of (1, T, D) tensors for this single prompt:
    # concatenate to (num_layers, T, D), then reorder to (T, num_layers, D).
    per_token = torch.cat(hidden_states, dim=0).permute(1, 0, 2)
    # Keep only the generated (target) positions; clone so the full-sequence tensor can be freed.
    return per_token[-n_target_tokens:].clone()


def compare_embedding_in_3d(model: 'AutoModelForCausalLM', tok: 'AutoTokenizer', data_dict: dict, save_path: str, do_print=False):
    """For each target group, plot the layer-wise embedding of the target across its prompts.

    For every record in a group, the prompt (``prompt.format(subject)``) is extended
    by generation for as many tokens as the true target string has. The hidden states
    of those generated positions are averaged per layer (``average_token``). The
    resulting (N, L, D) array is plotted with ``plot_embeddings_3d_by_layer``.

    Args:
        model: causal LM whose forward output exposes ``hidden_states``.
        tok: its tokenizer.
        data_dict: mapping group key -> list of records (see ``load_data``).
        save_path: save-path format string forwarded to the plotting code.
        do_print: print tokenization / generation / shape details.
    """
    for key, datas in data_dict.items():
        print(f'target_id : {key}\n')

        batch_hidden_states = []
        prompt_for_gens = []
        target = ''

        for data in datas:
            rewrite = data['requested_rewrite']
            prompt = rewrite['prompt']
            subject = rewrite['subject']
            target = ' ' + rewrite['target_true']['str']
            prompt_for_gen = prompt.format(subject)

            target_hidden_states = _collect_target_hidden_states(model, tok, prompt_for_gen, target, do_print)
            if target_hidden_states is None:
                print(f'compare_embedding_in_3d() skip case_id {data.get("case_id")} : target {target!r} has no tokens')
                continue

            batch_hidden_states.append(target_hidden_states)
            prompt_for_gens.append(prompt_for_gen)

        if not batch_hidden_states:
            print(f'compare_embedding_in_3d() skip {key!r} : nothing to plot\n')
            continue

        ent_states = average_token(batch_hidden_states, do_print)
        plot_embeddings_3d_by_layer(prompt_for_gens, ent_states, target.strip(), 3, save_path)

In [None]:
def run(data_dir, ds_name, model_name, select_ids: set, do_print=False):
    """Plot layer-wise target embeddings before and after editing, per selected target.

    Produces three kinds of figures:
      1. the original model on the original-knowledge data (all targets at once);
      2. the edited model on that target's original-knowledge data;
      3. the edited model on that target's merged (original + new) data.
    For each target the model is edited with that target's records from the
    ``*_edit.json`` file. Its weights are restored afterwards, even if plotting fails.
    Targets without edit records are skipped.

    Relies on the notebook globals ``model``, ``tok`` and ``model_editor``.
    ``ds_name`` and ``model_name`` only select the input directory.
    """
    in_path = f'{data_dir}/find_with_wiki_from_{ds_name}_{model_name}'
    print(f'in_path : {in_path}')

    in_file_path_org = f'{in_path}/json/find_knowledge_in_model_with_wiki.json'
    in_file_path_edit = f'{in_path}/json/find_knowledge_in_model_with_wiki_edit.json'
    in_file_path_new = f'{in_path}/json/find_knowledge_in_model_with_wiki_new.json'
    in_file_path_merged = f'{in_path}/json/find_knowledge_in_model_with_wiki_merged.json'

    # Load data (only the grouped dicts are used here).
    _, data_dict_org = load_data(in_file_path_org, select_ids, False)
    _, data_dict_edit = load_data(in_file_path_edit, select_ids, True)
    merge_data(in_file_path_org, in_file_path_new, in_file_path_merged)
    _, data_dict_merged = load_data(in_file_path_merged, select_ids, False)

    # 1. Original model, "knowledge already in the model" data
    save_path = f'{in_path}/images/' + '{}/lim_{}/{}_{}_1_org_org.jpg'
    compare_embedding_in_3d(model, tok, data_dict_org, save_path, do_print)

    # Edit per target key, visualize, then restore.
    for data_key, org_datas in data_dict_org.items():
        edit_datas = data_dict_edit.get(data_key)
        if not edit_datas:
            print(f'run() skip {data_key!r} : no edit records for this target\n')
            continue

        model_editor.edit_ext_datas(edit_datas)
        try:
            edited_model = model_editor._model

            # 2. Edited model, "knowledge already in the model" data
            save_path = f'{in_path}/images/' + '{}/lim_{}/{}_{}_2_edit_org.jpg'
            compare_embedding_in_3d(edited_model, tok, {data_key: org_datas}, save_path, do_print)

            # 3. Edited model, "existing knowledge + edit" data
            # (the "eidt" spelling is kept so existing output file names stay stable)
            save_path = f'{in_path}/images/' + '{}/lim_{}/{}_{}_3_edit_eidt.jpg'
            compare_embedding_in_3d(edited_model, tok, {data_key: data_dict_merged[data_key]}, save_path, do_print)
        finally:
            # Always undo the edit: the editor's model is shared with later cells and iterations.
            model_editor.restore_weights()

In [None]:
# Inputs for this run. The dataset / model tags here only select the find_with_wiki
# output directory, so they get their own names instead of overwriting the
# notebook-level `ds_name` / `model_name` configuration used to build the ModelEditor.
data_dir = './data'
run_ds_name = 'mquake'
run_model_tag = 'gpt'

select_ids = ['Q61', 'Q62', 'Q49', 'Q95', 'Q2283']
run(data_dir, run_ds_name, run_model_tag, set(select_ids))

In [None]:
# ds = []

# d1 = torch.randn(1, 2, 3)
# print(f'{d1.shape}\n{d1}\n')

# d2 = torch.randn(1, 2, 3)
# print(f'{d2.shape}\n{d2}\n')

# ds.append(d1)
# ds.append(d2)

# stacked_ds = torch.cat(ds, dim=0)
# print(f'{stacked_ds.shape}\n{stacked_ds}\n')

# transposed_ds = stacked_ds.permute(1, 0, 2)
# print(f'{transposed_ds.shape}\n{transposed_ds}\n')