In [3]:
from pathlib import Path
import sys
import os
import json
from tqdm import tqdm
from transformers import AutoTokenizer
import importlib
import difflib
from collections import defaultdict

import context
os.chdir(context.proj_dir)

import cont_gen
import cont_gen.data_process.pre_process.split_long_para
importlib.reload(cont_gen.data_process.pre_process.split_long_para)
from cont_gen.data_process.pre_process.split_long_para import process_doc_split_paras, split_to_chunk_span
from cont_gen.utils import load_jsonl, save_jsonl

In [4]:
def build_tkn(path, cache_dir = '/next_share/hf_cache/hub', trust_remote_code = True):
    """Load a tokenizer through ``AutoTokenizer.from_pretrained``.

    Args:
        path: Model name or path, forwarded unchanged as the first positional
            argument of ``from_pretrained``.
        cache_dir: Cache directory forwarded to ``from_pretrained``. The default
            is the location that used to be hard-coded here, so existing
            ``build_tkn(path)`` calls behave exactly as before; pass another
            directory when running on a different machine.
        trust_remote_code: Forwarded to ``from_pretrained``. Defaults to the
            previous hard-coded ``True``.
            NOTE(review): ``True`` allows code shipped with the model repo to
            run locally; pass ``False`` for repositories you do not trust.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns for ``path``.
    """
    return AutoTokenizer.from_pretrained(
        path, cache_dir = cache_dir, trust_remote_code = trust_remote_code
    )

# Load the merged CUAD paragraph records (iterated below as one item per document).
all_docs = load_jsonl('./data/cuad_clean/CUADv1_paras_merge_new.jsonl')

# Short tokenizer name -> Hugging Face model id.
# 'phi1' is disabled; recorded outputs further down that mention phi1
# presumably come from a run where it was still enabled.
_TKN_PATHS = {
    'flan-t5': 'google/flan-t5-large',
    'llama2': 'meta-llama/Llama-2-7b-hf',
    'llama3': 'meta-llama/Meta-Llama-3-8B',
    'mistral': 'mistralai/Mistral-7B-v0.1',
    # 'phi1': 'microsoft/phi-1_5',
    'phi2': 'microsoft/phi-2',
}

# Short tokenizer name -> loaded tokenizer object (not a path).
TKN_MAP = {
    tk_name: build_tkn(tk_path)
    for tk_name, tk_path in _TKN_PATHS.items()
}


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [43]:

# save_data = [process_doc_split_paras(d, TKN_MAP['llama3'], 512) for d in tqdm(all_docs)]


100%|██████████| 510/510 [00:09<00:00, 53.04it/s]


In [50]:
# Run the paragraph splitter once per tokenizer and save each result to its
# own JSONL file. 512 is passed as the third argument of
# process_doc_split_paras (presumably the per-chunk token limit — confirm in
# split_long_para).
save_path_tpl = 'data/cuad_clean/merge_split/paras_{tk_name}_512.jsonl'
tk_names = TKN_MAP.keys()
for tk_name in tk_names:
    print(f'Process {tk_name}')
    tokenizer = TKN_MAP[tk_name]
    save_data = []
    for doc in tqdm(all_docs):
        save_data.append(process_doc_split_paras(doc, tokenizer, 512))
    save_jsonl(save_data, save_path_tpl.format(tk_name = tk_name))

Process flan-t5


  0%|          | 0/510 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (520 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 510/510 [00:09<00:00, 52.58it/s]


Process llama2


100%|██████████| 510/510 [00:11<00:00, 44.00it/s]


Process llama3


100%|██████████| 510/510 [00:09<00:00, 51.46it/s]


Process mistral


100%|██████████| 510/510 [00:11<00:00, 44.04it/s]


Process phi1


100%|██████████| 510/510 [00:09<00:00, 54.68it/s]


Process phi2


100%|██████████| 510/510 [00:09<00:00, 56.09it/s]


In [47]:
def check_one_answer(context, answer):
    """Return True when the answer text equals the context slice at its span.

    ``answer`` must provide ``'text'``, ``'start_pos'`` and ``'end_pos'``;
    ``end_pos`` is inclusive, so the slice runs up to ``end_pos + 1``.
    """
    expected = answer['text']
    begin = answer['start_pos']
    stop = answer['end_pos'] + 1
    return expected == context[begin:stop]

def check_answers_of_para(para_data):
    """Return True when every answer in the paragraph matches its recorded span.

    Answers are gathered from ``para_data['qas'][i]['answers']`` and each one is
    checked against ``para_data['text']`` with ``check_one_answer``. A paragraph
    with no answers yields True.
    """
    text = para_data['text']
    # Flatten all answers first, then check them, so every answer is evaluated.
    flat_answers = []
    for qa in para_data['qas']:
        flat_answers.extend(qa['answers'])
    outcomes = [check_one_answer(text, ans) for ans in flat_answers]
    return all(outcomes)

def sim_score(a, b):
    """Return the ``difflib.SequenceMatcher`` similarity ratio of ``a`` and ``b``.

    The ratio is a float in [0, 1]; identical sequences give 1.0.
    """
    return difflib.SequenceMatcher(None, a, b).ratio()

def check_all_answer(pro_data):
    """Check the answer spans of every paragraph in every processed document.

    Returns one bool per document: True when all of its paragraphs pass
    ``check_answers_of_para``. Also prints how many documents failed.
    """
    results = [
        all([check_answers_of_para(para) for para in doc['paras']])
        for doc in pro_data
    ]
    n_failed = results.count(False)
    print('Error(anser match span): ', n_failed)
    return results

def check_all_doc(all_docs, pro_data):
    """Check that split paragraphs concatenate back to the original paragraphs.

    Documents are paired positionally with ``zip``. Within a pair, each split
    paragraph is grouped by its ``'old_para_idx'`` and the pieces are joined in
    the order they appear, then compared with the original paragraph text.

    Returns a list with one entry per document pair; each entry is a list of
    bools, one per original paragraph. Also prints the number of documents
    containing at least one mismatching paragraph.
    """
    def _para_flags(ori_doc, p_doc):
        # One bucket of text pieces per original paragraph.
        originals = [p['text'] for p in ori_doc['paras']]
        buckets = [[] for _ in originals]
        for piece in p_doc['paras']:
            bucket = buckets[piece['old_para_idx']]
            bucket.append(piece['text'])
        rejoined = [''.join(parts) for parts in buckets]
        return [ori == new for ori, new in zip(originals, rejoined)]

    results = [_para_flags(ori_doc, p_doc) for ori_doc, p_doc in zip(all_docs, pro_data)]
    mismatched_docs = sum(1 for flags in results if not all(flags))
    print('Doc match < 1.0: ', mismatched_docs)
    return results

def check_doc_clause(ori_doc, pro_data):
    """Compare, per question, the answer text of a document before and after splitting.

    For each document the answer texts are collected per ``'q_id'`` in paragraph
    order and concatenated. The two concatenations are scored with ``sim_score``.

    Returns a dict mapping each ``q_id`` of ``ori_doc`` to its similarity score.
    Raises ``AssertionError`` when the two documents have a different number of
    distinct question ids.
    """
    def _answers_by_question(doc):
        grouped = defaultdict(list)
        for para in doc['paras']:
            for qa in para['qas']:
                texts = [ans['text'] for ans in qa['answers']]
                # The key is created even when the question has no answers.
                grouped[qa['q_id']] += texts
        return grouped

    ori_answers = _answers_by_question(ori_doc)
    pro_answers = _answers_by_question(pro_data)
    assert len(ori_answers) == len(pro_answers)

    return {
        q_id: sim_score(''.join(texts), ''.join(pro_answers[q_id]))
        for q_id, texts in ori_answers.items()
    }
        

In [55]:
# Reload each saved split and run the three consistency checks on it:
# answer spans, paragraph re-join, and per-question answer text.
for tk_name in TKN_MAP:
    print(f'Check {tk_name}')
    save_data = load_jsonl(save_path_tpl.format(tk_name  = tk_name))

    r = check_all_answer(save_data)
    doc_match_r = check_all_doc(all_docs, save_data)

    all_q2sim = []
    for ori_d, new_d in zip(all_docs, save_data):
        all_q2sim.append(check_doc_clause(ori_d, new_d))

    # A document passes only when every question scores exactly 1.0.
    clause_r = [all(score == 1.0 for score in q2sim.values()) for q2sim in all_q2sim]
    print('Error (clause): ', sum(1 for ok in clause_r if not ok))

Check flan-t5
Error(anser match span):  0
Doc match < 1.0:  0
Error (clause):  0
Check llama2
Error(anser match span):  0
Doc match < 1.0:  0
Error (clause):  0
Check llama3
Error(anser match span):  0
Doc match < 1.0:  0
Error (clause):  0
Check mistral
Error(anser match span):  0
Doc match < 1.0:  1
Error (clause):  1
Check phi1
Error(anser match span):  0
Doc match < 1.0:  0
Error (clause):  0
Check phi2
Error(anser match span):  0
Doc match < 1.0:  0
Error (clause):  0


In [56]:
# Reload the saved Mistral split: it is the only tokenizer whose checks
# reported errors in the recorded run.
mistral_split_path = save_path_tpl.format(tk_name  = 'mistral')
save_data = load_jsonl(mistral_split_path)

In [57]:
# Recompute the per-paragraph match flags for the split currently in save_data.
doc_match_r = check_all_doc(all_docs = all_docs, pro_data = save_data)

Doc match < 1.0:  1


In [60]:
# Documents with at least one paragraph that does not re-join to the original.
error_idx = [doc_i for doc_i, para_flags in enumerate(doc_match_r) if not all(para_flags)]

# Inspect the first failing document.
ci = error_idx[0]
print(ci)

# Paragraph indices inside that document that fail to match.
pi = [para_i for para_i, matched in enumerate(doc_match_r[ci]) if not matched]
print(pi)

# Original paragraph text and the split chunks that came from it.
bad_para = pi[0]
ori_ctx = all_docs[ci]['paras'][bad_para]['text']
sp_ctxs = [chunk['text'] for chunk in save_data[ci]['paras'] if chunk['old_para_idx'] == bad_para]

# Chunk lengths, and whether each chunk is a substring of the original.
print(list(map(len, sp_ctxs)))
print([chunk in ori_ctx for chunk in sp_ctxs])

210
[0]
[2138, 2018, 1428]
[True, True, True]


In [63]:
# Show the text on both sides of the first chunk boundary. The offset is the
# length of the first chunk (2138 in the recorded run) rather than a copied
# constant, so the cell stays valid if the split changes.
# Assumes the first chunk starts at offset 0 of the original paragraph.
boundary = len(sp_ctxs[0])
# Clamp the start so a short first chunk cannot produce a negative slice index.
print(ori_ctx[max(boundary - 20, 0): boundary + 20])
print(sp_ctxs[0][-20:])
print(sp_ctxs[1][:20])

Target Demand ("          "). The Auctio
Target Demand ("    
       "). The Aucti


In [65]:
# Tokenize the 40-character window around the first chunk boundary with the
# Mistral tokenizer. The offset is the first chunk's length (2138 in the
# recorded run) instead of a hard-coded number; the start is clamped at 0.
boundary = len(sp_ctxs[0])
print(TKN_MAP['mistral'].tokenize(ori_ctx[max(boundary - 20, 0): boundary + 20]))

['▁Target', '▁Dem', 'and', '▁("', '<0xE2>', '<0x80>', '<0xAF>', '<0xE2>', '<0x80>', '<0xAF>', '<0xE2>', '<0x80>', '<0xAF>', '<0xE2>', '<0x80>', '<0xAF>', '<0xE2>', '<0x80>', '<0xAF>', '<0xE2>', '<0x80>', '<0xAF>', '<0xE2>', '<0x80>', '<0xAF>', '<0xE2>', '<0x80>', '<0xAF>', '<0xE2>', '<0x80>', '<0xAF>', '<0xE2>', '<0x80>', '<0xAF>', '").', '▁The', '▁A', 'uct', 'io']


### Compare Tokenizer

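Check whether the phi1 and phi2 tokenizers are the same.

Conclusion: they are the same. (Note: the comparison cell shown below is run on `llama2` vs `llama3`, whose vocabularies differ.)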

In [6]:
def compare_vocab(vocab1, vocab2):
    """Return the token ids on which two vocabularies disagree.

    Args:
        vocab1: Mapping of token -> id.
        vocab2: Mapping of token -> id.

    Returns:
        Ascending list of ids that map to different tokens in the two
        vocabularies, or that exist in only one of them. Empty when the
        id -> token mappings are identical.
    """
    inv_1 = {v: k for k, v in vocab1.items()}
    inv_2 = {v: k for k, v in vocab2.items()}
    # Compare over the ids that actually occur. Iterating range(max(len))
    # silently skips any id >= the vocabulary length, which happens whenever
    # ids are not contiguous from 0.
    all_ids = sorted(inv_1.keys() | inv_2.keys())
    return [i for i in all_ids if inv_1.get(i) != inv_2.get(i)]

def compare_tokenizer(tk1, tk2):
    """Print both tokenizers' ``vocab_size`` and the number of mismatching ids.

    The mismatch count is the length of ``compare_vocab(tk1.vocab, tk2.vocab)``.
    Nothing is returned.
    """
    for tk in (tk1, tk2):
        print(tk.vocab_size)
    mismatched_ids = compare_vocab(tk1.vocab, tk2.vocab)
    print(len(mismatched_ids))

In [8]:
# Compare the Llama 2 and Llama 3 vocabularies id by id.
compare_tokenizer(tk1 = TKN_MAP['llama2'], tk2 = TKN_MAP['llama3'])

32000
128000
128250
