In [1]:
import json
import torch
import pandas as pd

from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training
from contextlib import nullcontext
from transformers import (default_data_collator, Trainer, TrainingArguments,
                          TrainerCallback, AutoTokenizer, AutoModelForCausalLM,
                          BitsAndBytesConfig)

from transformers.integrations import WandbCallback

from utils.dataset import CombineDataset, template

# setting up wandb
%env WANDB_PROJECT=disco-limbic-dialogue

env: WANDB_PROJECT=disco-limbic-dialogue


# Params

In [2]:
# model
model_id = 'microsoft/phi-2'
model_type = 'phi'

# data settings
dataset_train_path = 'data/dataset/v1/train.json'
dataset_test_path = 'data/dataset/v1/test.json'
max_data_length = 384

# lora settings
lora_r = 32
lora_alpha = 16
lora_dropout = 0.05
target_modules = ['Wqkv', 'out_proj']

# train settings 
device = 'cuda'
lr = 3e-4
num_train_epochs = 5

gradient_accumulation_steps = 32
per_device_train_bs = 1
per_device_eval_bs = 1

log_steps = 10
eval_steps = 30


output_dir = f'lora/disco-limbic-dialogue-phi2-eos'

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    flash_attn=True,
    flash_rotary=True,
    fused_dense=True,
    trust_remote_code=True,
    device_map={'': 0},
    quantization_config=bnb_config
)


tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# load data

In [4]:
with open(dataset_train_path, 'r') as f:
    raw_data = json.load(f)

with open(dataset_test_path, 'r') as f:
    raw_data_test = json.load(f)

## filter
data_train = []
for i in raw_data:
    if len(tokenizer(template(i, model_type))['input_ids']) < max_data_length:
        data_train.append(i)
print(len(raw_data), len(data_train))

data_test = []
for i in raw_data_test:
    if len(tokenizer(template(i, model_type))['input_ids']) < max_data_length:
        data_test.append(i)
print(len(raw_data_test), len(data_test))

#data_train, data_test = torch.utils.data.random_split(data, lengths=[0.95, 0.05])
train_dataset = CombineDataset(data_train, tokenizer,
                               max_words=max_data_length,
                               model_type=model_type)
test_dataset = CombineDataset(data_test, tokenizer,
                              max_words=max_data_length,
                              model_type=model_type)
len(train_dataset), len(test_dataset)

4768 4755
118 118


(4755, 118)

In [5]:
model.train()

def create_peft_config(model):
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules = target_modules
    )

    model = prepare_model_for_int8_training(model, use_gradient_checkpointing=False)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, peft_config

# create peft config
model, lora_config = create_peft_config(model)

trainable params: 15,728,640 || all params: 2,790,783,096 || trainable%: 0.5635923487763593




In [6]:
enable_profiler = False
config = {
    'lora_config': lora_config,
    'learning_rate': lr,
    'num_train_epochs': num_train_epochs,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'per_device_train_batch_size': per_device_train_bs,
    'gradient_checkpointing': False,
}

profiler = nullcontext()


In [7]:
def decode_predictions(tokenizer, predictions):
    prediction_text = tokenizer.batch_decode(predictions.predictions.argmax(axis=-1))
    return {"predictions": prediction_text} # "labels": labels, 


class WandbPredictionProgressCallback(WandbCallback):
    def __init__(self, trainer, tokenizer, val_dataset, num_samples=100, freq=2):
        super().__init__()
        self.trainer = trainer
        self.tokenizer = tokenizer
        self.sample_dataset = [val_dataset[i] for i in range(num_samples)]
        self.freq = freq

    def on_evaluate(self, args, state, control, **kwargs):
        super().on_evaluate(args, state, control, **kwargs)
        if state.global_step % self.freq == 0:
            predictions = self.trainer.predict(self.sample_dataset)
            predictions = decode_predictions(self.tokenizer, predictions)
            predictions_df = pd.DataFrame(predictions)
            predictions_df["epoch"] = state.epoch
            records_table = self._wandb.Table(dataframe=predictions_df)
            self._wandb.log({"sample_predictions": records_table})

In [8]:
print(output_dir)
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    #bf16=True,  # Use BF16 if available
    ## eval strat
    do_eval=True,
    evaluation_strategy='steps',
    eval_steps=eval_steps,
    per_device_eval_batch_size=per_device_eval_bs,
    ## logging strategies
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=log_steps,
    ## wandb
    report_to="wandb",
    run_name=output_dir.split('/')[-1],
    ## other
    save_strategy="no",
    #optim="adamw_torch_fused",
    max_steps=-1,
    **{k:v for k,v in config.items() if k != 'lora_config'}
)

with profiler:
    # Create Trainer instance``
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=default_data_collator,
        callbacks=[],
    )
    # Instantiate the WandbPredictionProgressCallback
    progress_callback = WandbPredictionProgressCallback(
        trainer=trainer,
        tokenizer=tokenizer,
        val_dataset=test_dataset,
        num_samples=16,
        freq=30,
    )

    # Add the callback to the trainer
    trainer.add_callback(progress_callback)
    
    trainer.train()

lora/disco-limbic-dialogue-phi2-eos


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mxxond[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/740 [00:00<?, ?it/s]

{'loss': 3.2236, 'learning_rate': 0.0002959459459459459, 'epoch': 0.07}
{'loss': 2.6829, 'learning_rate': 0.0002918918918918919, 'epoch': 0.13}
{'loss': 2.4297, 'learning_rate': 0.0002878378378378378, 'epoch': 0.2}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 2.1827139854431152, 'eval_runtime': 70.4136, 'eval_samples_per_second': 1.676, 'eval_steps_per_second': 1.676, 'epoch': 0.2}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 2.3622, 'learning_rate': 0.00028378378378378377, 'epoch': 0.27}
{'loss': 2.3757, 'learning_rate': 0.0002797297297297297, 'epoch': 0.34}
{'loss': 2.2837, 'learning_rate': 0.00027567567567567564, 'epoch': 0.4}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 2.035541534423828, 'eval_runtime': 71.1061, 'eval_samples_per_second': 1.659, 'eval_steps_per_second': 1.659, 'epoch': 0.4}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 2.2348, 'learning_rate': 0.0002716216216216216, 'epoch': 0.47}
{'loss': 2.2665, 'learning_rate': 0.00026756756756756756, 'epoch': 0.54}
{'loss': 2.2118, 'learning_rate': 0.0002635135135135135, 'epoch': 0.61}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.9697753190994263, 'eval_runtime': 70.7888, 'eval_samples_per_second': 1.667, 'eval_steps_per_second': 1.667, 'epoch': 0.61}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 2.1797, 'learning_rate': 0.00025945945945945944, 'epoch': 0.67}
{'loss': 2.1993, 'learning_rate': 0.00025540540540540537, 'epoch': 0.74}
{'loss': 2.2058, 'learning_rate': 0.0002513513513513513, 'epoch': 0.81}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.9466423988342285, 'eval_runtime': 70.9011, 'eval_samples_per_second': 1.664, 'eval_steps_per_second': 1.664, 'epoch': 0.81}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 2.1102, 'learning_rate': 0.0002472972972972973, 'epoch': 0.87}
{'loss': 2.0951, 'learning_rate': 0.00024324324324324323, 'epoch': 0.94}
{'loss': 2.0526, 'learning_rate': 0.00023918918918918917, 'epoch': 1.01}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.8913027048110962, 'eval_runtime': 70.6881, 'eval_samples_per_second': 1.669, 'eval_steps_per_second': 1.669, 'epoch': 1.01}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.9671, 'learning_rate': 0.0002351351351351351, 'epoch': 1.08}
{'loss': 1.9674, 'learning_rate': 0.00023108108108108106, 'epoch': 1.14}
{'loss': 1.9685, 'learning_rate': 0.00022702702702702703, 'epoch': 1.21}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.8418546915054321, 'eval_runtime': 71.405, 'eval_samples_per_second': 1.653, 'eval_steps_per_second': 1.653, 'epoch': 1.21}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.9107, 'learning_rate': 0.00022297297297297293, 'epoch': 1.28}
{'loss': 1.9056, 'learning_rate': 0.0002189189189189189, 'epoch': 1.35}
{'loss': 1.8871, 'learning_rate': 0.00021486486486486486, 'epoch': 1.41}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.8122501373291016, 'eval_runtime': 71.2387, 'eval_samples_per_second': 1.656, 'eval_steps_per_second': 1.656, 'epoch': 1.41}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.9272, 'learning_rate': 0.0002108108108108108, 'epoch': 1.48}
{'loss': 1.8774, 'learning_rate': 0.00020675675675675673, 'epoch': 1.55}
{'loss': 1.9759, 'learning_rate': 0.0002027027027027027, 'epoch': 1.62}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.7968631982803345, 'eval_runtime': 71.0649, 'eval_samples_per_second': 1.66, 'eval_steps_per_second': 1.66, 'epoch': 1.62}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.8823, 'learning_rate': 0.00019864864864864863, 'epoch': 1.68}
{'loss': 1.9766, 'learning_rate': 0.0001945945945945946, 'epoch': 1.75}
{'loss': 1.7849, 'learning_rate': 0.0001905405405405405, 'epoch': 1.82}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.7745530605316162, 'eval_runtime': 71.3129, 'eval_samples_per_second': 1.655, 'eval_steps_per_second': 1.655, 'epoch': 1.82}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.8021, 'learning_rate': 0.00018648648648648646, 'epoch': 1.88}
{'loss': 1.7248, 'learning_rate': 0.00018243243243243242, 'epoch': 1.95}
{'loss': 1.7413, 'learning_rate': 0.00017837837837837839, 'epoch': 2.02}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.7711162567138672, 'eval_runtime': 71.1884, 'eval_samples_per_second': 1.658, 'eval_steps_per_second': 1.658, 'epoch': 2.02}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.6942, 'learning_rate': 0.0001743243243243243, 'epoch': 2.09}
{'loss': 1.6288, 'learning_rate': 0.00017027027027027026, 'epoch': 2.15}
{'loss': 1.6064, 'learning_rate': 0.0001662162162162162, 'epoch': 2.22}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.7206695079803467, 'eval_runtime': 71.1625, 'eval_samples_per_second': 1.658, 'eval_steps_per_second': 1.658, 'epoch': 2.22}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.7151, 'learning_rate': 0.00016216216216216215, 'epoch': 2.29}
{'loss': 1.6608, 'learning_rate': 0.0001581081081081081, 'epoch': 2.36}
{'loss': 1.6717, 'learning_rate': 0.00015405405405405402, 'epoch': 2.42}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.719363808631897, 'eval_runtime': 70.8947, 'eval_samples_per_second': 1.664, 'eval_steps_per_second': 1.664, 'epoch': 2.42}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.6593, 'learning_rate': 0.00015, 'epoch': 2.49}
{'loss': 1.6779, 'learning_rate': 0.00014594594594594595, 'epoch': 2.56}
{'loss': 1.5934, 'learning_rate': 0.00014189189189189188, 'epoch': 2.62}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.7094346284866333, 'eval_runtime': 70.59, 'eval_samples_per_second': 1.672, 'eval_steps_per_second': 1.672, 'epoch': 2.62}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.6777, 'learning_rate': 0.00013783783783783782, 'epoch': 2.69}
{'loss': 1.6242, 'learning_rate': 0.00013378378378378378, 'epoch': 2.76}
{'loss': 1.6045, 'learning_rate': 0.00012972972972972972, 'epoch': 2.83}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.7291371822357178, 'eval_runtime': 71.2976, 'eval_samples_per_second': 1.655, 'eval_steps_per_second': 1.655, 'epoch': 2.83}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.6177, 'learning_rate': 0.00012567567567567565, 'epoch': 2.89}
{'loss': 1.6313, 'learning_rate': 0.00012162162162162162, 'epoch': 2.96}
{'loss': 1.5747, 'learning_rate': 0.00011756756756756755, 'epoch': 3.03}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.6819088459014893, 'eval_runtime': 71.1538, 'eval_samples_per_second': 1.658, 'eval_steps_per_second': 1.658, 'epoch': 3.03}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.5624, 'learning_rate': 0.00011351351351351351, 'epoch': 3.1}
{'loss': 1.5442, 'learning_rate': 0.00010945945945945945, 'epoch': 3.16}
{'loss': 1.5451, 'learning_rate': 0.0001054054054054054, 'epoch': 3.23}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.7009185552597046, 'eval_runtime': 71.4913, 'eval_samples_per_second': 1.651, 'eval_steps_per_second': 1.651, 'epoch': 3.23}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.5141, 'learning_rate': 0.00010135135135135135, 'epoch': 3.3}
{'loss': 1.5681, 'learning_rate': 9.72972972972973e-05, 'epoch': 3.36}
{'loss': 1.3834, 'learning_rate': 9.324324324324323e-05, 'epoch': 3.43}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.675742745399475, 'eval_runtime': 71.0196, 'eval_samples_per_second': 1.662, 'eval_steps_per_second': 1.662, 'epoch': 3.43}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.5383, 'learning_rate': 8.918918918918919e-05, 'epoch': 3.5}
{'loss': 1.4985, 'learning_rate': 8.513513513513513e-05, 'epoch': 3.57}
{'loss': 1.4509, 'learning_rate': 8.108108108108108e-05, 'epoch': 3.63}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.6477264165878296, 'eval_runtime': 70.8866, 'eval_samples_per_second': 1.665, 'eval_steps_per_second': 1.665, 'epoch': 3.63}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.5927, 'learning_rate': 7.702702702702701e-05, 'epoch': 3.7}
{'loss': 1.4288, 'learning_rate': 7.297297297297297e-05, 'epoch': 3.77}
{'loss': 1.4929, 'learning_rate': 6.891891891891891e-05, 'epoch': 3.84}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.6548166275024414, 'eval_runtime': 71.0992, 'eval_samples_per_second': 1.66, 'eval_steps_per_second': 1.66, 'epoch': 3.84}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.4757, 'learning_rate': 6.486486486486486e-05, 'epoch': 3.9}
{'loss': 1.4985, 'learning_rate': 6.081081081081081e-05, 'epoch': 3.97}
{'loss': 1.428, 'learning_rate': 5.6756756756756757e-05, 'epoch': 4.04}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.6648551225662231, 'eval_runtime': 71.3656, 'eval_samples_per_second': 1.653, 'eval_steps_per_second': 1.653, 'epoch': 4.04}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.4649, 'learning_rate': 5.27027027027027e-05, 'epoch': 4.11}
{'loss': 1.3716, 'learning_rate': 4.864864864864865e-05, 'epoch': 4.17}
{'loss': 1.4827, 'learning_rate': 4.4594594594594596e-05, 'epoch': 4.24}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.6489076614379883, 'eval_runtime': 70.98, 'eval_samples_per_second': 1.662, 'eval_steps_per_second': 1.662, 'epoch': 4.24}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.4928, 'learning_rate': 4.054054054054054e-05, 'epoch': 4.31}
{'loss': 1.4289, 'learning_rate': 3.648648648648649e-05, 'epoch': 4.37}
{'loss': 1.4378, 'learning_rate': 3.243243243243243e-05, 'epoch': 4.44}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.6418145895004272, 'eval_runtime': 70.6234, 'eval_samples_per_second': 1.671, 'eval_steps_per_second': 1.671, 'epoch': 4.44}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.3064, 'learning_rate': 2.8378378378378378e-05, 'epoch': 4.51}
{'loss': 1.402, 'learning_rate': 2.4324324324324324e-05, 'epoch': 4.58}
{'loss': 1.3731, 'learning_rate': 2.027027027027027e-05, 'epoch': 4.64}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.6465519666671753, 'eval_runtime': 71.1071, 'eval_samples_per_second': 1.659, 'eval_steps_per_second': 1.659, 'epoch': 4.64}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.5085, 'learning_rate': 1.6216216216216215e-05, 'epoch': 4.71}
{'loss': 1.4209, 'learning_rate': 1.2162162162162162e-05, 'epoch': 4.78}
{'loss': 1.375, 'learning_rate': 8.108108108108107e-06, 'epoch': 4.85}


  0%|          | 0/118 [00:00<?, ?it/s]

{'eval_loss': 1.6428526639938354, 'eval_runtime': 71.2188, 'eval_samples_per_second': 1.657, 'eval_steps_per_second': 1.657, 'epoch': 4.85}


  0%|          | 0/16 [00:00<?, ?it/s]

{'loss': 1.4029, 'learning_rate': 4.054054054054054e-06, 'epoch': 4.91}
{'loss': 1.4102, 'learning_rate': 0.0, 'epoch': 4.98}
{'train_runtime': 31989.8005, 'train_samples_per_second': 0.743, 'train_steps_per_second': 0.023, 'train_loss': 1.7601289143433443, 'epoch': 4.98}


In [9]:
torch.cuda.current_device()

0

In [10]:
model.save_pretrained(output_dir)

In [11]:

dialog = [
    #"[Electrochemistry]: Whoa! In your hand: *pyrholidon* -- the double rainbow of synthetic hallucinogens. Rare and gritty, a product of the age of atomic power.",
    #"Look at the little puck of liquid.",
    #"[Electrochemistry]: What a funny little cap! Don't let the *scary* medical warnings throw you off. It's an inadequate antidote to radiation poisoning, but a *potent* antidote to *boredom*.",
    #"Hmm... open the cap.",
    #'Look around',
    #'"How can I take shit without taking off my sweater?"',
    'who am i?</s>'
]



query = template(dialog) + ' [|Assistant|] '
model_inputs = tokenizer(query, return_tensors="pt", add_special_tokens=False).input_ids.to('cuda')
generated_ids = model.generate(input_ids=model_inputs, max_new_tokens=64,
                               do_sample=True,
                               #pad_token_id=tokenizer.eos_token_id,
                               temperature=0.7,
                               repetition_penalty=1.15)
#outputs = model(input_ids=input_ids),# max_length=cut_len, min_length=8, top_p=0.9, do_sample=True)
output = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
print(output)

TypeError: template() missing 1 required positional argument: 'model_type'