In [None]:
################# Convert LAMA into preparation file.

import os
from transformers import AutoTokenizer
import torch
import jsonlines
import re
import torch.nn as nn
from tqdm import tqdm
import json
import numpy as np
import os
import random

model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token


# Relation configurations: relation code -> full record from relations.jsonl
with jsonlines.open('data/LAMA/relations.jsonl', 'r') as reader:
    config_dic = {entry['relation']: entry for entry in reader}

## TREx file list, ordered by the numeric part of the relation code
list0 = os.listdir('data/LAMA/TREx')
list0.sort(key=lambda x: int(x.replace('P','').replace('.jsonl','')))
TREX_relations = [name.replace('.jsonl','') for name in list0]

### Invariant / variant relation lists
# Relation codes that are labelled "invariant" by hand.
invariant_rel = [
    'P19', 'P20', 'P279', 'P37', 'P449', 'P47', 'P138',
    'P364', 'P527', 'P176', 'P27', 'P407', 'P30', 'P178',
    'P1376', 'P131', 'P1412', 'P17', 'P276', 'P937', 'P140',
    'P103', 'P190', 'P1001', 'P495', 'P36', 'P740', 'P361',
]
file_list = os.listdir('dataset/data/TREx')
total_rel_list = sorted(
    (name.replace('.jsonl','') for name in file_list),
    key=lambda x: int(x[1:]),
)
# Every listed relation that is not in invariant_rel counts as variant.
variant_rel = [code for code in total_rel_list if code not in invariant_rel]

### Filter and tailor into preparation file
def make_task_schematic(sub, rel, obj):
    """Return the schematic prompt that spells out subject, relation and object.

    Args:
        sub: subject text.
        rel: relation label text.
        obj: object text.

    Returns:
        The prompt string with the three values filled in.
    """
    return "Guess the object. \n  subject is {} , relation is {} , object is {}".format(sub, rel, obj)

def make_task_descriptive(sub, obj,template):
    """Fill a relation template with the subject and the object.

    '[X]' is replaced by ``sub`` first, then '[Y]' is replaced by ``obj`` in the
    result (so a '[Y]' occurring inside ``sub`` is replaced as well).

    Args:
        sub: subject text substituted for '[X]'.
        obj: object text substituted for '[Y]'.
        template: sentence template containing the placeholders.

    Returns:
        The filled-in sentence.
    """
    with_subject = template.replace('[X]', sub)
    return with_subject.replace('[Y]', obj)



# Build one preparation record per TREx fact and write it as JSON lines.
# The output name is spelled 'TREx' (not 'TReX') so that it is the same file the
# scoring cell opens; on a case-sensitive filesystem the two spellings differ.
with jsonlines.open('temp/TREx_for_train_attention.jsonl', 'w') as writer:
    for rel in total_rel_list:
        relation_label = config_dic[rel]['label']  # place_of_birth, ...
        template = config_dic[rel]['template']

        if rel not in TREX_relations:
            continue

        # variant_rel is built as "everything in total_rel_list that is not in
        # invariant_rel", so membership in invariant_rel alone decides the flag.
        invariant = rel in invariant_rel

        # Read the relation file from the directory TREX_relations was listed from
        # (data/LAMA/TREx), which guarantees that the file exists for this relation.
        with jsonlines.open(f'data/LAMA/TREx/{rel}.jsonl', 'r') as reader1:
            for dic1 in tqdm(reader1):
                X = dic1['sub_label']  # X
                Y = dic1['obj_label']  # Y

                # 1. make 2 tasks: task_descriptive, task_schematic
                task_descriptive = make_task_descriptive(X, Y, template)
                task_schematic = make_task_schematic(X, relation_label, Y)

                ### position & length
                # Use the longest masked sentence as evidence (the first one on ties).
                evidences = [ev['masked_sentence'] for ev in dic1['evidences']]
                masked_evidence = max(evidences, key=len)

                # Relative character offset of the first [MASK] in the masked sentence.
                position = re.search(r'\[MASK\]', masked_evidence).span()[0] / len(masked_evidence)

                # Evidence with the object filled in; its length is what gets stored.
                evidence = masked_evidence.replace('[MASK]', Y)

                write_dic = {
                    'relation_code': rel,
                    'uuid': dic1['uuid'],
                    'task_descriptive': task_descriptive,
                    'task_schematic': task_schematic,
                    'subject': X,
                    'relation_label': relation_label,
                    'object': Y,
                    'masked_evidence': masked_evidence,
                    'evidence': evidence,
                    'position': position,
                    'evidence_length': len(evidence),
                    'invariant': invariant,
                    'scores': {},  # filled in by the scoring cell
                }
                writer.write(write_dic)


In [None]:
############### Measure Accuracy with baseline models ###########
from transformers import AutoTokenizer
from TAALM import TAALM, Llama2_kadapter

model_name = "meta-llama/Llama-2-7b-hf"
adapter_file= "Llama-2-7b"
tinyllama_name = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

############## put in other models if you want ##################
# Four baselines, all moved to the GPU: two TAALM models and two k-adapter models.
model_7b = TAALM.init_theta(model_name=model_name, adapter_file=adapter_file, onepiece=True).to('cuda')
model_1b = TAALM.init_theta(model_name=tinyllama_name, adapter_file="Llama-2-1b", onepiece=True).to('cuda')
model_1b_kadapter = Llama2_kadapter.init_gamma_theta(model_name=tinyllama_name, onepiece=True).to('cuda')
model_7b_kadapter = Llama2_kadapter.init_gamma_theta(model_name=model_name, onepiece=True).to('cuda')
####################################################################


# Execute calculating
def giveme_label_mask(query_token, label_token):
    """Mark the last occurrence of ``label_token`` inside ``query_token``.

    The search runs from the end of the query towards the start and stops at the
    first match, so only the right-most occurrence is marked.

    Args:
        query_token: 1-D tensor of token ids for the full prompt.
        label_token: 1-D tensor of token ids to locate inside the prompt.

    Returns:
        A float tensor of length ``len(query_token) - 1``: the 0/1 mask over the
        query with its first position dropped. It is all zeros when the label
        does not occur.
    """
    query_len = len(query_token)
    label_len = len(label_token)
    mask = torch.zeros(query_len)

    start = query_len - label_len
    while start >= 0:
        stop = start + label_len
        if torch.equal(query_token[start:stop], label_token):
            mask[start:stop] = 1
            break
        start -= 1

    # Drop the first position so the mask lines up with next-token predictions.
    return mask[1:]

def giveme_acc(model, task_descriptive, task_descriptive_label_mask):
    """Token-level accuracy of greedy next-token predictions on the masked span.

    The function no longer reads the module-level name ``object``: the target
    tokens are taken from the prompt's own ``input_ids`` at the masked
    positions. Positions are selected with a boolean mask, so a predicted token
    whose id is 0 is kept instead of being filtered out together with the
    unmasked positions.

    Args:
        model: callable invoked as ``model(**task_descriptive)``; the result must
            expose ``.logits`` of shape (batch, seq, vocab).
        task_descriptive: mapping holding at least ``"input_ids"`` with shape
            (batch, seq); it is forwarded unchanged to the model.
        task_descriptive_label_mask: 1-D tensor of length ``seq - 1`` whose
            non-zero entries mark the positions (in the shifted, next-token
            frame) to score.

    Returns:
        A 0-dim float tensor with the fraction of masked positions predicted
        correctly. The value is NaN when the mask selects no position.
    """
    response = model(**task_descriptive)

    # Logits at position t predict token t+1, hence the shift by one on both sides.
    # argmax over the logits picks the same token as argmin over -log_softmax.
    predicted = response.logits[:, :-1, :].argmax(dim=-1)
    targets = task_descriptive["input_ids"][:, 1:]

    keep = task_descriptive_label_mask.bool()
    hits = predicted[:, keep] == targets[:, keep]
    return hits.float().mean()



# Score every preparation record with the four baseline models and write the record,
# with its per-model accuracies stored under 'scores', to a second JSONL file.
with jsonlines.open('temp/TREx_for_train_attention.jsonl', 'r') as reader, jsonlines.open('temp/TREx_for_train_attention_w_scores.jsonl', 'w') as writer:
    n=0  # only used by a debugging early-exit that is no longer active
    for dic in tqdm(reader):
        obj = dic['object']

        # NOTE: the name `object` shadows the builtin. It is kept as a module-level
        # name on purpose because code outside this loop may read it as a global.
        object = tokenizer(dic['object'], return_tensors='pt', add_special_tokens=False).to('cuda')
        task_descriptive = tokenizer(dic['task_descriptive'], return_tensors='pt').to('cuda')
        task_schematic = tokenizer(dic['task_schematic'], return_tensors='pt').to('cuda')

        # Masks marking where the object tokens sit inside each prompt.
        task_descriptive_label_mask = giveme_label_mask(task_descriptive.input_ids[0] , object.input_ids[0]).to('cuda')
        task_schematic_label_mask = giveme_label_mask(task_schematic.input_ids[0] , object.input_ids[0]).to('cuda')

        # Evaluation order matches the key order of the stored score dictionaries.
        baselines = (
            ('llama_7b', model_7b),
            ('llama_1b', model_1b),
            ('llama_7b_kadapter', model_7b_kadapter),
            ('llama_1b_kadapter', model_1b_kadapter),
        )

        with torch.no_grad():
            scores = {
                'descriptive': {
                    name: giveme_acc(model, task_descriptive, task_descriptive_label_mask).cpu().numpy().tolist()
                    for name, model in baselines
                },
                'schematic': {
                    name: giveme_acc(model, task_schematic, task_schematic_label_mask).cpu().numpy().tolist()
                    for name, model in baselines
                },
            }

        dic['scores'] = scores
        writer.write(dic)

In [None]:
################### Sample and Save the dataset ###################
from matplotlib import pyplot as plt


# Buckets filled from the scored preparation file.
invar_descriptive = []
invar_schematic = []
variant = []
train_attention = []

# Counters carried over from the original cell; nothing below updates them.
TP_invar_descriptive = 0  # true positive
TP_invar_schematic = 0
TN_var_descriptive = 0  # true negative
TN_var_schematic = 0

with jsonlines.open('temp/TREx_for_train_attention_w_scores.jsonl', 'r') as reader:
    for dic in tqdm(reader):
        evidence = dic['evidence']
        sub = dic['subject']
        obj = dic['object']

        # Keep only evidences longer than 200 characters that contain both the
        # object and the subject as substrings.
        if dic['evidence_length'] <= 200 or obj not in evidence or sub not in evidence:
            continue

        scores = dic['scores']
        descriptive = scores['descriptive']
        schematic = scores['schematic']
        invariant = dic['invariant']
        models = descriptive.keys()

        # Mean accuracy over all baseline models, per prompt style.
        descriptive_mean = np.mean(list(descriptive.values()))
        schematic_mean = np.mean(list(schematic.values()))

        if invariant:
            # Invariant facts that every model answers perfectly.
            if descriptive_mean == 1:
                invar_descriptive.append(dic)
            if schematic_mean == 1:
                invar_schematic.append(dic)
        elif descriptive_mean == 0 and schematic_mean == 0:
            # Variant facts that no model answers at all.
            variant.append(dic)

        # Regardless of the invariant flag: poorly answered in both prompt styles.
        if descriptive_mean < 0.5 and schematic_mean < 0.5:
            train_attention.append(dic)


### control 'P530' relation type
# Keep at most the first 130 'P530' records of `variant`; all other records pass through.

p530 = 0
temp_variant = []
for var in variant:
    is_p530 = var['relation_code'] == 'P530'
    if is_p530:
        p530 += 1
    if not is_p530 or p530 <= 130:
        temp_variant.append(var)


### Sample and observe the distribution
# Each evaluation set is drawn with the same seed so that every draw is reproducible
# on its own. random.sample raises ValueError if a pool holds fewer than 500 records.

random.seed(42)
sample_invar_descriptive =  random.sample(invar_descriptive, k=500)
random.seed(42)
sample_invar_schematic = random.sample(invar_schematic, k=500)
random.seed(42)
sample_variant = random.sample(temp_variant, k=500)

uuids = [x['uuid'] for x in sample_variant]
# Set copy of the sampled uuids: constant-time membership tests in the loop below
# instead of scanning the 500-element list for every training candidate.
sampled_variant_uuids = set(uuids)

# Upper bound on how many 'P530' records may enter the training data.
max_p530_train = 350

p530 = 0
sample_train_attention = []
for x in train_attention:
    # Records already drawn into the variant evaluation set are kept out of training.
    if x['uuid'] in sampled_variant_uuids:
        continue
    if x['relation_code'] == 'P530':
        p530+=1
        if p530 > max_p530_train:
            continue
    sample_train_attention.append(x)

# Relation codes ordered by their numeric part; they define the x-axis of the bar chart.
file_list = os.listdir('dataset/data/TREx')
total_rel_list = sorted(
    (name.replace('.jsonl','') for name in file_list),
    key=lambda x: int(x[1:]),
)

# Swap `sample_train_attention` for another sampled set here to inspect its distribution.
rel_codes = [record['relation_code'] for record in sample_train_attention]

# Number of sampled records per relation code (0 for relations that never occur).
rel_dics = {rel: rel_codes.count(rel) for rel in total_rel_list}

plt.figure(figsize=(15,1))
plt.bar(list(rel_dics.keys()), list(rel_dics.values()))
plt.xticks(rotation=90)
plt.show()


In [38]:
#### Save into dataset
# Each sampled set goes to its own JSONL file, written in the order listed here.
for out_path, records in (
    ('data/LAMA_ckl/invariant_descriptive.jsonl', sample_invar_descriptive),
    ('data/LAMA_ckl/invariant_schematic.jsonl', sample_invar_schematic),
    ('data/LAMA_ckl/variant.jsonl', sample_variant),
    ('data/LAMA_ckl/train_attention_traindata.jsonl', sample_train_attention),
):
    with jsonlines.open(out_path, 'w') as writer:
        writer.write_all(records)

159