In [18]:
from pathlib import Path
import sys
import os
os.environ['HF_HUB_CACHE'] = '/next_share/hf_cache/hub/'
import json
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer
import importlib
import numpy as np
import difflib
from collections import defaultdict
import pandas as pd

import context
os.chdir(context.proj_dir)

import cont_gen
import cont_gen.data_process.ood.build_src_tgt
import cont_gen.data_process.ood.build_sft_meta_data
importlib.reload(cont_gen.data_process.ood.build_src_tgt)
importlib.reload(cont_gen.data_process.ood.build_sft_meta_data)
from cont_gen.data_process.ood.build_src_tgt import process, SFT_Builder, SFT_Builder_YesNo, SFT_Builder_YesNo_Natural
from cont_gen.data_process.ood.build_sft_meta_data import CUAD_Basic, MetaSFT_Train_Builder, MetaSFT_Test_Builder
from cont_gen.data_loader.cuad_sft import CUAD_SFT_Cached
from cont_gen.utils import load_jsonl, save_jsonl

## Overview

- Get train and test label sets.
  - Files: `train_labels.csv`, `test_labels.csv`
    - Keys: `['clause_id', 'clause_type']`
- [Link](#build-meta-data) For each tokenizer, build train and test meta data and save it to `data/ood_split/{split_name}/{tokenizer_name}`.
  - Input para data: `data/cuad_clean/merge_split/paras_{tokenizer_name}_512.jsonl`
  - Files: `train_meta.csv`, `test_meta_ood.csv`, `test_meta_id.csv`
    - Keys: `['title', 'para_idx', 'q_id', 'answers', 'type']`
- [Link](#build-sft-data) For each tokenizer and prompt method, build source and target data and save to `data/ood_split/{split_name}/{tokenizer_name}/{prompt_name}`
  - Files: `train_data.jsonl`, `test_data_ood.jsonl`, `test_data_id.jsonl`
    - Keys: `['title', 'para_idx', 'q_id', 'source', 'target', 'type']`
- [Link](#chat-data) Data for Chat Model

### Build Meta Data

In [15]:
def build_train_meta(train_para_data, train_labels, output_dir, neg_clause_ratio=1.0, num_neg_quest = 1):
    """Build the training meta table and write it to ``{output_dir}/train_meta.csv``.

    Args:
        train_para_data: paragraph records of the training contracts
            (e.g. ``CUAD_Basic.train_para_data``).
        train_labels: clause ids used as training labels.
        output_dir: destination directory; created (with parents) if missing.
        neg_clause_ratio: forwarded to ``MetaSFT_Train_Builder.build_pos_neg_samples``.
        num_neg_quest: forwarded to ``MetaSFT_Train_Builder.build_pos_neg_samples``.

    Returns:
        The DataFrame that was written to disk.
    """
    meta_df = MetaSFT_Train_Builder.build_pos_neg_samples(
        train_para_data,
        train_labels,
        neg_clause_ratio=neg_clause_ratio,
        num_neg_quest=num_neg_quest,
    )

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    meta_df.to_csv(out_dir / 'train_meta.csv', index=False)
    return meta_df

def build_test_meta(test_para_data, test_labels, train_labels, output_dir, neg_ratio = 0.1):
    """Build the OOD and ID test meta tables from the test paragraphs and save them.

    Both tables are built from ``test_para_data``:

    * OOD: labelled with ``test_labels``, written to ``test_meta_ood.csv``.
    * ID: labelled with ``train_labels``, written to ``test_meta_id.csv``.

    Args:
        test_para_data: paragraph records of the test contracts.
        test_labels: clause ids of the out-of-distribution label set.
        train_labels: clause ids of the in-distribution (training) label set.
        output_dir: destination directory; created (with parents) if missing.
        neg_ratio: forwarded to ``MetaSFT_Test_Builder.build_test_and_small``.

    Returns:
        ``(test_df, test_id_df)``: the OOD table and the ID table, in that order.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # OOD first, then ID; the return tuple follows the same order.
    jobs = (
        (test_labels, 'test_meta_ood.csv'),
        (train_labels, 'test_meta_id.csv'),
    )
    built = []
    for label_ids, file_name in jobs:
        meta_df = MetaSFT_Test_Builder.build_test_and_small(test_para_data, label_ids, neg_ratio=neg_ratio)
        meta_df.to_csv(out_dir / file_name, index=False)
        built.append(meta_df)

    return tuple(built)

def process_meta_tokenizer(tkn_name, split_dirs, proj_dir = './'):
    """Build train/test meta data for one tokenizer under several splits.

    The tokenizer-specific paragraph file
    ``data/cuad_clean/merge_split/paras_{tkn_name}_512.jsonl`` (under
    ``proj_dir``) is loaded once through ``CUAD_Basic``. For every split
    directory, the ``clause_id`` columns of ``train_labels.csv`` and
    ``test_labels.csv`` are read, and meta CSVs are written to
    ``{split_dir}/{tkn_name}``.

    Args:
        tkn_name: tokenizer name, e.g. ``'llama3'``.
        split_dirs: iterable of split directories.
        proj_dir: project root used to locate the CUAD source files.

    NOTE(review): ``split_dirs`` are resolved relative to the current working
    directory, not ``proj_dir``.
    """
    root = Path(proj_dir)
    cuad_basic = CUAD_Basic(
        root / 'data/clause/all_info.csv',
        root / f'data/cuad_clean/merge_split/paras_{tkn_name}_512.jsonl',
        root / 'data/cuad_split/ori_train_titles.json',
        root / 'data/cuad_split/ori_test_titles.json',
    )

    def _read_clause_ids(csv_path):
        # Only the `clause_id` column of a label file is needed.
        return pd.read_csv(csv_path)['clause_id'].to_list()

    for split_path in map(Path, split_dirs):
        out_dir = split_path / tkn_name
        train_labels = _read_clause_ids(split_path / 'train_labels.csv')
        test_labels = _read_clause_ids(split_path / 'test_labels.csv')
        build_train_meta(cuad_basic.train_para_data, train_labels, out_dir)
        build_test_meta(cuad_basic.test_para_data, test_labels, train_labels, out_dir)

In [16]:
# List the available OOD split directories.
ood_root = Path('data/ood_split/')
print([entry for entry in ood_root.glob('*')])

[PosixPath('data/ood_split/seed42_tr29'), PosixPath('data/ood_split/seed128_tr29'), PosixPath('data/ood_split/seed89_tr29')]


In [17]:
# Tokenizers and splits whose meta data should be (re)built.
tkn_names = ['flan-t5', 'llama2', 'llama3', 'mistral', 'phi2']

split_names = ['seed42_tr29', 'seed128_tr29', 'seed89_tr29']
split_dirs = [f'data/ood_split/{name}' for name in split_names]

for tkn_name in tkn_names:
    print(f'Handle tokenizer data: {tkn_name}')
    process_meta_tokenizer(tkn_name, split_dirs)

Handle tokenizer data: flan-t5
Handle tokenizer data: llama2
Handle tokenizer data: llama3
Handle tokenizer data: mistral
Handle tokenizer data: phi1
Handle tokenizer data: phi2


In [21]:
# Get some statistics for one tokenizer / split pair.
tkn_name = tkn_names[2]
split = split_names[0]
meta_root = f'data/ood_split/{split}/{tkn_name}'
train_meta, test_meta_ood, test_meta_id = (
    pd.read_csv(f'{meta_root}/{fname}')
    for fname in ('train_meta.csv', 'test_meta_ood.csv', 'test_meta_id.csv')
)
print(train_meta.columns)
print(test_meta_ood.columns)
print(f'Train: {len(train_meta)}, Test OOD: {len(test_meta_ood)}, Test ID: {len(test_meta_id)}')
# Distribution of the `type` column for the train and OOD test tables.
for meta_df in (train_meta, test_meta_ood):
    print(meta_df['type'].value_counts())


Index(['title', 'para_idx', 'q_id', 'answers', 'type'], dtype='object')
Index(['title', 'para_idx', 'q_id', 'answers', 'type'], dtype='object')
Train: 15692, Test OOD: 67188, Test ID: 162371
type
0    5734
1    5734
2    4224
Name: count, dtype: int64
type
0    55380
3     6132
2     5093
1      583
Name: count, dtype: int64


### Build SFT Data

In [27]:
from ast import literal_eval

def process_sft_tokenizer(tkn_name, split_dirs, builder: SFT_Builder, pmt_name):
    """Build SFT source/target data for one tokenizer under multiple splits.

    The paragraph file ``data/cuad_clean/merge_split/paras_{tkn_name}_512.jsonl``
    is loaded and handed to ``builder``. For every split directory, each meta
    CSV in ``{split}/{tkn_name}`` is turned into a JSONL file in
    ``{split}/{tkn_name}/{pmt_name}`` via ``process(builder, meta_df)``:

    * ``train_meta.csv`` -> ``train_data.jsonl``
    * ``test_meta_id.csv`` -> ``test_data_id.jsonl``
    * ``test_meta_ood.csv`` -> ``test_data_ood.jsonl``

    Args:
        tkn_name: tokenizer name; selects the paragraph file and the
            per-tokenizer meta sub-directory.
        split_dirs: iterable of split directories, e.g. ``data/ood_split/seed42_tr29``.
        builder: SFT builder used to render each meta row into source/target.
        pmt_name: prompt name; used as the output sub-directory.
    """
    all_para_data = load_jsonl(f'data/cuad_clean/merge_split/paras_{tkn_name}_512.jsonl')
    builder.set_para_data(all_para_data)

    # (input meta CSV, output SFT JSONL), in the original processing order.
    meta_to_data = [
        ('train_meta.csv', 'train_data.jsonl'),
        ('test_meta_id.csv', 'test_data_id.jsonl'),
        ('test_meta_ood.csv', 'test_data_ood.jsonl'),
    ]

    for split in split_dirs:
        print(f'Process {split}')
        meta_dir = Path(split) / tkn_name
        save_dir = meta_dir / pmt_name
        # Make sure the prompt-specific output directory exists before writing.
        save_dir.mkdir(parents=True, exist_ok=True)
        for meta_file, data_file in meta_to_data:
            # `answers` was written by `to_csv` as a Python literal; parse it back.
            meta_df = pd.read_csv(meta_dir / meta_file, converters={'answers': literal_eval})
            sft_data = process(builder, meta_df)
            save_jsonl(sft_data, save_dir / data_file)

In [28]:
tkn_names = ['flan-t5', 'llama2', 'llama3', 'mistral', 'phi2']

split_names = ['seed42_tr29', 'seed128_tr29', 'seed89_tr29']

clause_info = pd.read_csv('./data/clause/all_info.csv')

# Read the prompt template via pathlib so the file handle is closed right away
# (a bare `open(...).read()` leaves closing it to the garbage collector).
prompt_01 = Path('config/prompts/pmt_01.txt').read_text()

# Prompt name -> SFT builder; both share the same prompt template and clause table.
BUILDER_MAP = {
    'pmt_01': SFT_Builder(prompt_01, clause_info, None, lambda k: k),
    'pmt_01_yes_no': SFT_Builder_YesNo(prompt_01, clause_info, None, lambda k: k)
}

bd_name = 'pmt_01_yes_no' # customize: must be a key of BUILDER_MAP
split_dirs = [f'data/ood_split/{k}' for k in split_names]

for tkn_name in tkn_names:
    print(f'Handle tokenizer data: {tkn_name}')
    process_sft_tokenizer(tkn_name, split_dirs, BUILDER_MAP[bd_name], bd_name)

Handle tokenizer data: flan-t5
Process data/ood_split/seed42_tr29


100%|██████████| 15760/15760 [00:00<00:00, 69887.99it/s]
100%|██████████| 164749/164749 [00:02<00:00, 72025.91it/s]
100%|██████████| 68172/68172 [00:00<00:00, 73251.75it/s]


Process data/ood_split/seed128_tr29


100%|██████████| 14427/14427 [00:00<00:00, 33567.59it/s]
100%|██████████| 164749/164749 [00:02<00:00, 71222.98it/s]
100%|██████████| 68172/68172 [00:00<00:00, 73235.71it/s]


Process data/ood_split/seed89_tr29


100%|██████████| 16548/16548 [00:00<00:00, 68615.79it/s]
100%|██████████| 164749/164749 [00:03<00:00, 45737.71it/s]
100%|██████████| 68172/68172 [00:01<00:00, 41914.12it/s]


Handle tokenizer data: llama2
Process data/ood_split/seed42_tr29


100%|██████████| 15951/15951 [00:00<00:00, 68605.58it/s]
100%|██████████| 167591/167591 [00:02<00:00, 71996.07it/s]
100%|██████████| 69348/69348 [00:00<00:00, 72200.06it/s]


Process data/ood_split/seed128_tr29


100%|██████████| 14598/14598 [00:00<00:00, 69199.55it/s]
100%|██████████| 167591/167591 [00:02<00:00, 71684.12it/s]
100%|██████████| 69348/69348 [00:01<00:00, 42295.26it/s]


Process data/ood_split/seed89_tr29


100%|██████████| 16719/16719 [00:00<00:00, 67865.40it/s]
100%|██████████| 167591/167591 [00:03<00:00, 46311.66it/s]
100%|██████████| 69348/69348 [00:01<00:00, 43955.10it/s]


Handle tokenizer data: llama3
Process data/ood_split/seed42_tr29


100%|██████████| 15692/15692 [00:00<00:00, 69742.33it/s]
100%|██████████| 162371/162371 [00:02<00:00, 71902.33it/s]
100%|██████████| 67188/67188 [00:00<00:00, 73169.04it/s]


Process data/ood_split/seed128_tr29


100%|██████████| 14364/14364 [00:00<00:00, 32788.94it/s]
100%|██████████| 162371/162371 [00:02<00:00, 71719.96it/s]
100%|██████████| 67188/67188 [00:01<00:00, 42879.88it/s]


Process data/ood_split/seed89_tr29


100%|██████████| 16484/16484 [00:00<00:00, 68036.44it/s]
100%|██████████| 162371/162371 [00:02<00:00, 71710.40it/s]
100%|██████████| 67188/67188 [00:01<00:00, 41461.18it/s]


Handle tokenizer data: mistral
Process data/ood_split/seed42_tr29


100%|██████████| 15812/15812 [00:00<00:00, 69851.10it/s]
100%|██████████| 166083/166083 [00:02<00:00, 71738.88it/s]
100%|██████████| 68724/68724 [00:00<00:00, 72884.07it/s]


Process data/ood_split/seed128_tr29


100%|██████████| 14495/14495 [00:00<00:00, 68420.07it/s]
100%|██████████| 166083/166083 [00:02<00:00, 72069.25it/s]
100%|██████████| 68724/68724 [00:00<00:00, 72882.17it/s]


Process data/ood_split/seed89_tr29


100%|██████████| 16599/16599 [00:00<00:00, 67757.91it/s]
100%|██████████| 166083/166083 [00:02<00:00, 70876.73it/s]
100%|██████████| 68724/68724 [00:01<00:00, 42491.53it/s]


Handle tokenizer data: phi1
Process data/ood_split/seed42_tr29


100%|██████████| 15719/15719 [00:00<00:00, 69933.91it/s]
100%|██████████| 162342/162342 [00:02<00:00, 72209.52it/s]
100%|██████████| 67176/67176 [00:01<00:00, 45154.94it/s]


Process data/ood_split/seed128_tr29


100%|██████████| 14376/14376 [00:00<00:00, 69214.18it/s]
100%|██████████| 162342/162342 [00:02<00:00, 71913.78it/s]
100%|██████████| 67176/67176 [00:00<00:00, 72131.58it/s]


Process data/ood_split/seed89_tr29


100%|██████████| 16497/16497 [00:00<00:00, 67033.68it/s]
100%|██████████| 162342/162342 [00:03<00:00, 45891.35it/s]
100%|██████████| 67176/67176 [00:00<00:00, 72228.18it/s]


Handle tokenizer data: phi2
Process data/ood_split/seed42_tr29


100%|██████████| 15719/15719 [00:00<00:00, 37682.40it/s]
100%|██████████| 162342/162342 [00:02<00:00, 72154.71it/s]
100%|██████████| 67176/67176 [00:00<00:00, 73326.98it/s]


Process data/ood_split/seed128_tr29


100%|██████████| 14376/14376 [00:00<00:00, 69160.51it/s]
100%|██████████| 162342/162342 [00:03<00:00, 46881.36it/s]
100%|██████████| 67176/67176 [00:00<00:00, 71887.18it/s]


Process data/ood_split/seed89_tr29


100%|██████████| 16497/16497 [00:00<00:00, 68046.58it/s]
100%|██████████| 162342/162342 [00:02<00:00, 71585.25it/s]
100%|██████████| 67176/67176 [00:01<00:00, 42980.93it/s]


In [30]:
# Show SFT Data for one tokenizer / split / prompt combination.
tkn_name = tkn_names[2]
split = split_names[0]
bd_name = 'pmt_01_yes_no'
sft_root = f'data/ood_split/{split}/{tkn_name}/{bd_name}'
train_data, test_data_ood, test_data_id = (
    load_jsonl(f'{sft_root}/{fname}')
    for fname in ('train_data.jsonl', 'test_data_ood.jsonl', 'test_data_id.jsonl')
)
for records in (train_data, test_data_ood):
    print(records[0].keys())
print(train_data[0])

dict_keys(['title', 'para_idx', 'q_id', 'source', 'target', 'type'])
dict_keys(['title', 'para_idx', 'q_id', 'source', 'target', 'type'])
{'title': 'LIMEENERGYCO_09_09_1999-EX-10-DISTRIBUTOR AGREEMENT', 'para_idx': 0, 'q_id': 2, 'source': 'You are a helpful assistant. Review the contract clauses and answer questions. Output the mentioned clauses if exist; otherwise output "No".\n\n###Clauses:\nEXHIBIT 10.6\n DISTRIBUTOR AGREEMENT\n THIS DISTRIBUTOR AGREEMENT (the "Agreement") is made by and between Electric City Corp., a Delaware corporation ("Company") and Electric City of Illinois LLC ("Distributor") this 7th day of September, 1999.\n RECITALS\n A. The Company\'s Business. The Company is presently engaged in the business of selling an energy efficiency device, which is referred to as an "Energy Saver" which may be improved or otherwise changed from its present composition (the "Products"). The Company may engage in the business of selling other products or other devices other than th

## Chat Data

In [19]:
def build_tkn(path):
    """Load a Hugging Face tokenizer from a hub id or local path.

    ``trust_remote_code=True`` allows the repository's custom tokenizer code to
    run locally, so only pass trusted model ids.
    """
    load_kwargs = dict(trust_remote_code=True)
    return AutoTokenizer.from_pretrained(path, **load_kwargs)

# Tokenizer name -> hub id; names match the `tkn_names` used for the data files.
TKN_PATHS = {
    'flan-t5': 'google/flan-t5-large',
    'llama2': 'meta-llama/Llama-2-7b-hf',
    'llama3': 'meta-llama/Meta-Llama-3-8B',
    'mistral': 'mistralai/Mistral-7B-v0.1',
    # 'phi1': 'microsoft/phi-1_5',
    'phi2': 'microsoft/phi-2',
}
TKN_MAP = {name: build_tkn(hub_id) for name, hub_id in TKN_PATHS.items()}

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# Check whether the base Llama-3 tokenizer defines a chat template (saved output: None).
tk = TKN_MAP['llama3']
chat_tpl = tk.chat_template
print(chat_tpl)

None


In [30]:
path = 'meta-llama/Meta-Llama-3-8B-Instruct'
# path = 'mistralai/Mistral-7B-Instruct-v0.2'
tokenizer = AutoTokenizer.from_pretrained(path)

# A single user turn; an optional system turn could be prepended, e.g.
# dict(role="system", content="You are a pirate chatbot who always responds in pirate speak!")
msg = [dict(role="user", content="Who are you?")]
r = tokenizer.apply_chat_template(
    msg,
    tokenize=False,
    add_generation_prompt=True,
)
print(r)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>




In [34]:
data_path = 'data/ood_split/seed42_tr29/llama3/pmt_01/train_data.jsonl'
# Decoder-only (non-seq2seq) chat formatting with `small=True`.
# Previously tried option: cache_dir = Path(data_path).parent / 'cache'
train_ds = CUAD_SFT_Cached(
    data_path,
    tokenizer,
    is_seq2seq=False,
    is_chat=True,
    small=True,
)

100%|███████████████████████████████████████| 200/200 [00:00<00:00, 1123.16it/s]


In [35]:
first_ids = train_ds[0]['input_ids']
print(tokenizer.decode(first_ids))

<|begin_of_text|><|start_header_id|>user<|end_header_id|>

You are a helpful assistant. Review the contract clauses and answer questions. Output the mentioned clauses if exist; otherwise output "No".

###Clauses:
EXHIBIT 10.6
 DISTRIBUTOR AGREEMENT
 THIS DISTRIBUTOR AGREEMENT (the "Agreement") is made by and between Electric City Corp., a Delaware corporation ("Company") and Electric City of Illinois LLC ("Distributor") this 7th day of September, 1999.
 RECITALS
 A. The Company's Business. The Company is presently engaged in the business of selling an energy efficiency device, which is referred to as an "Energy Saver" which may be improved or otherwise changed from its present composition (the "Products"). The Company may engage in the business of selling other products or other devices other than the Products, which will be considered Products if Distributor exercises its options pursuant to Section 7 hereof.

###Question: The date of the contract

###Answer:<|eot_id|><|start_header_i

In [29]:
# Last 30 tokens of the first sample, i.e. the end of the prompt plus the target.
# NOTE(review): the saved output shows Mistral `[/INST]` tokens, so it predates the
# Llama-3 tokenizer loaded above; re-run to refresh.
tail_ids = train_ds[0]['input_ids'][-30:]
print(tokenizer.convert_ids_to_tokens(tail_ids))

['▁The', '▁date', '▁of', '▁the', '▁contract', '<0x0A>', '<0x0A>', '###', 'An', 'swer', ':', '▁[', '/', 'INST', ']', '▁-', '▁', '7', 'th', '▁day', '▁of', '▁September', ',', '▁', '1', '9', '9', '9', '.', '</s>']


In [23]:
# Longest tokenized sample in the dataset.
seq_lens = [len(sample['input_ids']) for sample in train_ds]
np.max(seq_lens)

1127

In [7]:
chat_template_str = tokenizer.chat_template
chat_template_str

"{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"

In [24]:
# Same chat message, but returned as token ids.
# NOTE(review): the saved output shows Mistral ids ([INST] ... [/INST]), not Llama-3.
r2 = tokenizer.apply_chat_template(
    msg,
    tokenize=True,
    add_generation_prompt=True,
)
print(r2)

[1, 733, 16289, 28793, 6526, 460, 368, 28804, 733, 28748, 16289, 28793]


In [25]:
r2_tokens = tokenizer.convert_ids_to_tokens(r2)
print(r2_tokens)

['<s>', '▁[', 'INST', ']', '▁Who', '▁are', '▁you', '?', '▁[', '/', 'INST', ']']
