In [1]:
import argparse
from pathlib import Path
import os
import json
import torch
from transformers import (
    AutoTokenizer, PreTrainedTokenizer
)
from peft import (
    LoraConfig
)
from accelerate import PartialState, Accelerator

import context

from cont_gen.trainer.utils_dist import initialize_accelerator, DistLogger
from cont_gen.data_loader.cuad_prompt import CUAD_SFT, SFT_Padding, CUAD_SFT_Seq2Seq
from cont_gen.data_loader.cuad_sft import CUAD_SFT_Cached
from cont_gen.utils.model_utils import build_hf_or_peft_model, smart_resize_embeddings
from cont_gen.trainer.utils import get_smart_optimizer, compute_clm_loss_with_ignore
from cont_gen.trainer.train_only_accelerate import Trainer_Basic, TrainingArgs_Basic
from cont_gen.model.loss import LM_Simple_Feed

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from cont_gen.run.train_sft import (
    load_train_dataset, build_model, get_parser
)

In [3]:
os.chdir('/storage_fast/rhshui/workspace/contract_review')

In [4]:
args = get_parser().parse_args([])
args.__dict__.update(dict(
    total_batch_size = 16,
    data_path = 'data/ood_split/seed42_tr29/pmt_01/train_data.jsonl',
    base_model = '/storage_fast/rhshui/llm/ms-phi-1_5/',
    dtype = 'bf16',
    lr = 1e-5,
    weight_decay = 0.0,
    device_batch_size = 1,
    max_epochs = 1,
    logging_steps = 5
))

In [5]:
state = PartialState()
# grad acc step
if args.grad_acc_steps is None:
    args.grad_acc_steps = int(args.total_batch_size / args.device_batch_size / state.num_processes)

accelerator, ds_config = initialize_accelerator(
    args.ds_config, args.device_batch_size, args.grad_acc_steps
)

# Get logger
log_file = None if args.output_dir is None else str(Path(args.output_dir) / 'log.txt')
logger = DistLogger(file = log_file)

## log arguments
args_str = json.dumps(args.__dict__, indent = 4, ensure_ascii=False)
if args.output_dir:
    with open(Path(args.output_dir) / 'args.json', 'w') as f:
        f.write(args_str)
logger.log(f'Training args:\n {args_str}')

# Build tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code = True)
if "pad_token" not in tokenizer.special_tokens_map:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Build dataset
is_seq2seq = ('t5' in args.base_model)
dataset = load_train_dataset(args, tokenizer)

## data collate fn
pad_args = {'pad_side': 'right'}
if args.debug:
    pad_args['pad_to_max_len'] = args.max_length if is_seq2seq else args.max_length*2
collate_fn = SFT_Padding(tokenizer.pad_token_id, **pad_args)

# Build model
model = build_model(args, accelerator, logger)

## resize model embedding if necessary
smart_resize_embeddings(tokenizer, model, logger)

accelerator.wait_for_everyone()

## build optimizer
optimizer = get_smart_optimizer(model, args.lr, args.weight_decay)

# Build trainer
## training args
tr_args = TrainingArgs_Basic(
    device_batch_size = args.device_batch_size,
    grad_acc_steps=args.grad_acc_steps,
    max_epochs = args.max_epochs,
    max_steps = args.max_steps,
    logging_steps = args.logging_steps,
    save_steps = args.save_steps,
    save_epochs = args.save_epochs,
    save_total_limit = args.save_total_limit
)
assert tr_args.total_batch_size == args.total_batch_size

## Trainer
trainer = Trainer_Basic(
    tr_args, model, dataset, optimizer, accelerator,
    ds_config = ds_config,
    collate_fn = collate_fn,
    compute_loss_fn = LM_Simple_Feed(),
    output_dir = args.output_dir,
    logger = logger
)


[2024-04-25 01:22:00][Main] Training args:
 {
    "output_dir": null,
    "ds_config": null,
    "debug": false,
    "total_batch_size": 16,
    "data_path": "data/ood_split/seed42_tr29/pmt_01/train_data.jsonl",
    "max_length": 512,
    "labels_on_full": false,
    "base_model": "/storage_fast/rhshui/llm/ms-phi-1_5/",
    "saved_model": null,
    "dtype": "bf16",
    "lora": false,
    "lora_r": 8,
    "lora_alpha": 16,
    "lora_target_modules": null,
    "lora_dropout": 0.05,
    "lr": 1e-05,
    "weight_decay": 0.0,
    "device_batch_size": 1,
    "grad_acc_steps": 16,
    "max_epochs": 1,
    "max_steps": null,
    "logging_steps": 5,
    "save_steps": null,
    "save_epochs": 1,
    "save_total_limit": 5,
    "local_rank": null
}


100%|███████████████████████████████████| 14886/14886 [00:09<00:00, 1597.73it/s]


Write to cache: data/ood_split/seed42_tr29/pmt_01/cache/cached_train_data.jsonl_ms-phi-1_5_v1.0.pkl
[2024-04-25 01:22:16][Main] Not resize embeddings. model: 51200, tokenizer: 50296


In [7]:
sample = dataset[0]
print(sample.keys())
print(sample['input_ids'])
print(sample['labels'])
print(tokenizer.decode(sample['input_ids']))
print(tokenizer.decode([k for k in sample['labels'] if k != -100]))

dict_keys(['input_ids', 'attention_mask', 'labels'])
[1639, 389, 257, 7613, 8796, 13, 6602, 262, 2775, 31485, 290, 3280, 2683, 13, 25235, 262, 4750, 31485, 611, 2152, 26, 4306, 5072, 366, 2949, 1911, 198, 198, 21017, 47404, 2664, 25, 198, 6369, 39, 9865, 2043, 838, 13, 21, 198, 34957, 9865, 3843, 1581, 13077, 2200, 12529, 198, 12680, 34957, 9865, 3843, 1581, 13077, 2200, 12529, 357, 1169, 366, 10262, 10237, 4943, 318, 925, 416, 290, 1022, 13944, 2254, 11421, 1539, 257, 19603, 12017, 5855, 39154, 4943, 290, 13944, 2254, 286, 9486, 11419, 5855, 20344, 2455, 273, 4943, 428, 767, 400, 1110, 286, 2693, 11, 7358, 13, 198, 19644, 2043, 23333, 198, 317, 13, 383, 5834, 338, 7320, 13, 383, 5834, 318, 27606, 7953, 287, 262, 1597, 286, 6301, 281, 2568, 9332, 3335, 11, 543, 318, 6412, 284, 355, 281, 366, 28925, 311, 8770, 1, 543, 743, 307, 6596, 393, 4306, 3421, 422, 663, 1944, 11742, 357, 1169, 366, 48650, 11074, 383, 5834, 743, 8209, 287, 262, 1597, 286, 6301, 584, 3186, 393, 584, 4410, 584, 621,

In [10]:
raw_d = dataset.data[0]

In [12]:
src_enc = tokenizer(raw_d['source'])
tgt_enc = tokenizer(raw_d['target'])
print(len(src_enc.input_ids))
print(len(tgt_enc.input_ids))

input_ids = src_enc.input_ids
input_ids = input_ids + tgt_enc.input_ids
print(len(src_enc.input_ids))

204
9
204
