In [14]:
import argparse
from argparse import Namespace
from pathlib import Path
import os
os.environ['CUDA_VISIBLE_DEVICES']='0,1'
os.environ['HF_HUB_CACHE'] = '/next_share/hf_cache/hub/'
import json
import torch
from transformers import (
    AutoTokenizer, PreTrainedTokenizer
)
from peft import (
    LoraConfig
)
from accelerate import PartialState, Accelerator

import context
os.chdir(context.proj_dir)

from cont_gen.trainer.utils_dist import initialize_accelerator, DistLogger
from cont_gen.data_loader.cuad_prompt import CUAD_SFT, SFT_Padding, CUAD_SFT_Seq2Seq
from cont_gen.data_loader.cuad_sft import CUAD_SFT_Cached
from cont_gen.utils.model_utils import build_hf_or_peft_model, smart_resize_embeddings, load_hf_model_from_checkpoint
from cont_gen.trainer.utils import get_smart_optimizer, compute_clm_loss_with_ignore
from cont_gen.trainer.train_only_accelerate import Trainer_Basic, TrainingArgs_Basic
from cont_gen.model.loss import LM_Simple_Feed
from cont_gen.run.infer_sft import SimpleGenerator, load_test_dataset

In [2]:
from cont_gen.run.train_sft import (
    load_train_dataset, build_model, get_parser
)

In [8]:
tk_name='llama3'
model_path = 'meta-llama/Meta-Llama-3-8B'
ckpt = 'runs/ood/llama3/seed42_tr29/pmt_01_lr1e-5_bs16_wd0.0/checkpoint-15692'
args = get_parser().parse_args([])
args.__dict__.update(dict(
    total_batch_size = 16,
    data_path = f'data/ood_split/seed42_tr29/{tk_name}/pmt_01/train_data.jsonl',
    base_model = model_path,
    dtype = 'bf16',
    lr = 1e-5,
    weight_decay = 0.0,
    device_batch_size = 1,
    max_epochs = 1,
    logging_steps = 5
))

In [9]:
state = PartialState()
# grad acc step
if args.grad_acc_steps is None:
    args.grad_acc_steps = int(args.total_batch_size / args.device_batch_size / state.num_processes)

accelerator, ds_config = initialize_accelerator(
    args.ds_config, args.device_batch_size, args.grad_acc_steps
)

# Get logger
log_file = None
logger = DistLogger(file = log_file)

# Build tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code = True)
if "pad_token" not in tokenizer.special_tokens_map:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Build dataset
is_seq2seq = ('t5' in args.base_model)
dataset = load_train_dataset(args, tokenizer)

## data collate fn
pad_args = {'pad_side': 'right'}
if args.debug:
    pad_args['pad_to_max_len'] = args.max_length if is_seq2seq else args.max_length*2
collate_fn = SFT_Padding(tokenizer.pad_token_id, **pad_args)

# Build model
# model = build_model(args, accelerator, logger)
# Load from checkpoint
model = load_hf_model_from_checkpoint(ckpt, accelerator, args.dtype)

## resize model embedding if necessary
# smart_resize_embeddings(tokenizer, model, logger)

accelerator.wait_for_everyone()

## build optimizer
optimizer = get_smart_optimizer(model, args.lr, args.weight_decay)

# Build trainer
## training args
tr_args = TrainingArgs_Basic(
    device_batch_size = args.device_batch_size,
    grad_acc_steps=args.grad_acc_steps,
    max_epochs = args.max_epochs,
    max_steps = args.max_steps,
    logging_steps = args.logging_steps,
    save_steps = args.save_steps,
    save_epochs = args.save_epochs,
    save_total_limit = args.save_total_limit
)
assert tr_args.total_batch_size == args.total_batch_size

## Trainer
trainer = Trainer_Basic(
    tr_args, model, dataset, optimizer, accelerator,
    ds_config = ds_config,
    collate_fn = collate_fn,
    compute_loss_fn = LM_Simple_Feed(),
    output_dir = args.output_dir,
    logger = logger
)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Load from cache: data/ood_split/seed42_tr29/llama3/pmt_01/cache/cached_train_data.jsonl_Meta-Llama-3-8B_v1.0.pkl


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Resize embedding num to 128257


In [10]:
sample = dataset[0]
print(sample.keys())
print(sample['input_ids'])
print(sample['labels'])
print(tokenizer.decode(sample['input_ids']))
print(tokenizer.decode([k for k in sample['labels'] if k != -100]))

dict_keys(['input_ids', 'attention_mask', 'labels'])
[128000, 2675, 527, 264, 11190, 18328, 13, 10506, 279, 5226, 50198, 323, 4320, 4860, 13, 9442, 279, 9932, 50198, 422, 3073, 26, 6062, 2612, 330, 2822, 11690, 14711, 65217, 4881, 512, 3337, 39, 3336, 964, 220, 605, 13, 21, 198, 98941, 878, 76590, 16837, 198, 10245, 98941, 878, 76590, 16837, 320, 1820, 330, 9219, 17589, 909, 374, 1903, 555, 323, 1990, 21246, 4409, 22621, 2637, 264, 40838, 27767, 3573, 14831, 909, 323, 21246, 4409, 315, 19174, 15620, 3573, 35, 79488, 909, 420, 220, 22, 339, 1938, 315, 6250, 11, 220, 2550, 24, 627, 75236, 34288, 50, 198, 362, 13, 578, 8351, 596, 8184, 13, 578, 8351, 374, 50801, 17045, 304, 279, 2626, 315, 11486, 459, 4907, 15374, 3756, 11, 902, 374, 14183, 311, 439, 459, 330, 33775, 328, 7403, 1, 902, 1253, 387, 13241, 477, 6062, 5614, 505, 1202, 3118, 18528, 320, 1820, 330, 18219, 1865, 578, 8351, 1253, 16988, 304, 279, 2626, 315, 11486, 1023, 3956, 477, 1023, 7766, 1023, 1109, 279, 15899, 11, 902, 690,

In [45]:
print(tokenizer.convert_ids_to_tokens([k for k in sample['labels'] if k != -100]))

['<|begin_of_text|>', '-', 'Ġ', '7', 'th', 'Ġday', 'Ġof', 'ĠSeptember', ',', 'Ġ', '199', '9', '.', '<|end_of_text|>']


## Generate

In [15]:
generator = SimpleGenerator(tokenizer, is_encoder_decoder=is_seq2seq)

test_args = Namespace(
    data_path = f'data/ood_split/seed42_tr29/{tk_name}/pmt_01/test_data_id.jsonl',
    max_length = 1000,
    debug = False,
)
test_ds = load_test_dataset(test_args, tokenizer, is_seq2seq, part = 'sampled')

Load from cache: data/ood_split/seed42_tr29/llama3/pmt_01/cache/cached_test_data_id.jsonl_Meta-Llama-3-8B_v1.0.pkl


In [28]:
def to_batch(data):
    return {
        'input_ids': torch.tensor(data['input_ids']).unsqueeze(0).cuda(),
        'attention_mask': torch.tensor(data['attention_mask']).unsqueeze(0).cuda(),
        'labels': torch.tensor(data['labels']).unsqueeze(0).cuda(),
    }

In [21]:
tokenizer.decode(test_ds[0]['input_ids'])

'<|begin_of_text|>You are a helpful assistant. Review the contract clauses and answer questions. Output the mentioned clauses if exist; otherwise output "No".\n\n###Clauses:\n3\nSource: LOHA CO. LTD., F-1, 12/9/2019\n12.2 (1) Invoice in 3 originals indicating contract number and L/C number. (2) Final acceptance certificate signed by the Buyer and the Seller. 13. SHIPMENT: CIP The seller shall contract on usual terms at his own expenses for the carriage of the goods to the agreed point at the named place of destination and bear all risks and expenses until the goods have been delivered to the port of destination. The Sellers shall ship the goods within the shipment time from the port of shipment to the port of destination. Transshipment is allowed. Partial Shipment is allowed. In case the goods are to be dispatched by parcel post/sea-freight, the Sellers shall, 3 days before the time of delivery, inform the Buyers by cable/letter of the estimated date of delivery, Contract No., commodit

In [29]:
r = generator(model, to_batch(sample))

In [51]:
r

{'pred_tokens': [tensor([128000,   2822, 128001], device='cuda:0')]}

In [30]:
tokenizer.decode(r['pred_tokens'][0])

'<|begin_of_text|>No<|end_of_text|>'

In [31]:
with torch.no_grad():
    print(model(**to_batch(sample)).loss)

tensor(0.6665, device='cuda:0')


In [41]:
tgt_idx = sample['labels'].index(128000)
new_tgt = r['pred_tokens'][0].tolist()
new_ids = sample['input_ids'][:tgt_idx] + new_tgt
new_mask = [1] * len(new_ids)
new_labels = [-100] * tgt_idx + new_tgt
neg_sample = {'input_ids': new_ids, 'attention_mask': new_mask, 'labels': new_labels}

In [43]:
# print(tokenizer.decode(neg_sample['input_ids']))
with torch.no_grad():
    print(model(**to_batch(neg_sample)).loss)

tensor(0.0191, device='cuda:0')


In [47]:
# Force model to predict a span.
tgt_len = len([k for k in sample['labels'] if k != -100])
new_len = len(sample['input_ids']) - tgt_len + 3 # 3 is the first three token
pos_sample = {k:v[:new_len] for k,v in sample.items()}
print(tokenizer.decode(pos_sample['input_ids'][-20:]))

 7 hereof.

###Question: The date of the contract

###Answer:<|begin_of_text|>- 


In [48]:
r_pos = generator(model, to_batch(pos_sample))

In [50]:
tokenizer.decode(r_pos['pred_tokens'][0])

'7th day of September, 1999<|end_of_text|>'

In [52]:
pos_head = [12800, 12]
neg_head = [12800, 2822]

In [53]:
def add_target_head(sample, head):
    """Append target head to inputs"""
    ori_ids = sample['input_ids']
    ipt_len = len([k for k in sample['labels'] if k == -100])
    new_ids = ori_ids[:ipt_len] + list(head)
    new_mask = [1] * len(new_ids)
    new_labels = [-100] * ipt_len + list(head)
    return {'input_ids': new_ids, 'attention_mask': new_mask, 'labels': new_labels}

In [81]:
with torch.no_grad():
    model.eval()
    pos_batch = to_batch(add_target_head(sample, [12800]))
    pos_out = model(**pos_batch)

In [82]:
probs = torch.softmax(pos_out.logits[0,-1], dim = 0)
top = torch.argsort(probs, descending = True)
for i in range(10):
    token = top[i].item()
    print(f'{tokenizer.convert_ids_to_tokens(token)} {token}: {probs[token]}')

<|end_of_text|> 128001: 0.2812570631504059
<|begin_of_text|> 128000: 0.13285644352436066
Ġon 389: 0.10346870124340057
Cla 65217: 0.07111292332410812
Ġ 220: 0.06275693327188492
No 2822: 0.045913953334093094
Ġclauses 50198: 0.03155616670846939
ĠNo 2360: 0.011608866043388844
Ġthe 279: 0.010244788601994514
Ġfor 369: 0.005837304051965475


In [75]:
ipt = add_target_head(sample, [])
model.generate(torch.tensor(ipt['input_ids']).unsqueeze(0).cuda(), do_sample = False, eos_token_id = tokenizer.eos_token_id)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


tensor([[128000,   2675,    527,    264,  11190,  18328,     13,  10506,    279,
           5226,  50198,    323,   4320,   4860,     13,   9442,    279,   9932,
          50198,    422,   3073,     26,   6062,   2612,    330,   2822,  11690,
          14711,  65217,   4881,    512,   3337,     39,   3336,    964,    220,
            605,     13,     21,    198,  98941,    878,  76590,  16837,    198,
          10245,  98941,    878,  76590,  16837,    320,   1820,    330,   9219,
          17589,    909,    374,   1903,    555,    323,   1990,  21246,   4409,
          22621,   2637,    264,  40838,  27767,   3573,  14831,    909,    323,
          21246,   4409,    315,  19174,  15620,   3573,     35,  79488,    909,
            420,    220,     22,    339,   1938,    315,   6250,     11,    220,
           2550,     24,    627,  75236,  34288,     50,    198,    362,     13,
            578,   8351,    596,   8184,     13,    578,   8351,    374,  50801,
          17045,    304,    

In [78]:
print(tokenizer.convert_ids_to_tokens(ipt['input_ids'][-20:]))

['Ġpursuant', 'Ġto', 'ĠSection', 'Ġ', '7', 'Ġhere', 'of', '.ĊĊ', '###', 'Question', ':', 'ĠThe', 'Ġdate', 'Ġof', 'Ġthe', 'Ġcontract', 'ĊĊ', '###', 'Answer', ':']
