In [1]:
# %pip install -q -U bitsandbytes
# %pip install -q -U git+https://github.com/huggingface/transformers.git
# %pip install -q -U git+https://github.com/huggingface/peft.git
# %pip install -q -U git+https://github.com/huggingface/accelerate.git
# %pip install -q -U trl

In [2]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer

In [3]:
peft_config = LoraConfig(
    lora_alpha=8,
    lora_dropout=0.1,
    r=8,
    bias="none",
    task_type="CAUSAL_LM",
)

In [4]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

In [5]:
# Load the entire model on the GPU 0
device_map = {"": 0}

In [6]:
from datasets import Dataset
import pandas as pd

# Load the data using pandas
data_file = "../data/ARCSolver_move_obj_puzzles_100000.json"
df = pd.read_json(data_file)

# Convert the pandas dataframe to a dataset
dataset = Dataset.from_pandas(df)

def generate_prompt(data_point):
    text = '<s>[INST] ' + data_point["instruction"] + ' [/INST] ' + str(data_point["output"]) + '</s>'
    return text

# add the "prompt" column in the dataset
text_column = [generate_prompt(data_point) for data_point in dataset]
dataset = dataset.add_column("prompt", text_column)

In [7]:
# Load base model
model_name = "../merged_models/merged_model_step_1"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)
model.config.use_cache = False
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# Load LLaMA tokenizer
base_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(base_model_name, add_eos_token=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



In [8]:
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
max_seq_length = 1250

new_adapter_name = "Finetuned_merge_001_v001"

output_dir = "../results/" + new_adapter_name

steps_per_epoch = len(dataset)//(per_device_train_batch_size*gradient_accumulation_steps)
print("Steps per epoch:", steps_per_epoch)

# Set training parameters
training_arguments = TrainingArguments(
    output_dir=output_dir,
    max_steps=1000,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim="paged_adamw_8bit",
    save_strategy="steps",
    evaluation_strategy="no",
    save_steps=100,
    logging_steps=1,
    learning_rate=2e-4,
    fp16=True,
    warmup_steps=0.03,
    group_by_length=True,
    gradient_checkpointing=True,
)

# Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="prompt",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)

Steps per epoch: 25000
Total steps: 25000


Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [9]:
# Train model
trainer.train()

  0%|          | 0/1000 [00:00<?, ?it/s]

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 0.881, 'learning_rate': 0.0001998059941798254, 'epoch': 0.0}
{'loss': 0.9062, 'learning_rate': 0.0001996059881796454, 'epoch': 0.0}
{'loss': 0.8805, 'learning_rate': 0.0001996059881796454, 'epoch': 0.0}
{'loss': 0.8698, 'learning_rate': 0.0001996059881796454, 'epoch': 0.0}
{'loss': 0.8238, 'learning_rate': 0.0001996059881796454, 'epoch': 0.0}
{'loss': 0.8857, 'learning_rate': 0.00019940598217946538, 'epoch': 0.0}
{'loss': 0.8357, 'learning_rate': 0.00019920597617928537, 'epoch': 0.0}
{'loss': 0.7429, 'learning_rate': 0.00019900597017910538, 'epoch': 0.0}
{'loss': 0.646, 'learning_rate': 0.00019880596417892537, 'epoch': 0.0}
{'loss': 0.6545, 'learning_rate': 0.00019860595817874536, 'epoch': 0.0}
{'loss': 0.594, 'learning_rate': 0.00019840595217856537, 'epoch': 0.0}
{'loss': 0.5464, 'learning_rate': 0.00019820594617838536, 'epoch': 0.0}
{'loss': 0.5189, 'learning_rate': 0.00019800594017820535, 'epoch': 0.0}
{'loss': 0.5051, 'learning_rate': 0.00019780593417802536, 'epoch': 0.0}




{'loss': 0.1148, 'learning_rate': 0.00018040541216236488, 'epoch': 0.0}
{'loss': 0.134, 'learning_rate': 0.00018020540616218487, 'epoch': 0.0}
{'loss': 0.1302, 'learning_rate': 0.00018000540016200486, 'epoch': 0.0}
{'loss': 0.1507, 'learning_rate': 0.00017980539416182487, 'epoch': 0.0}
{'loss': 0.1327, 'learning_rate': 0.00017960538816164486, 'epoch': 0.0}
{'loss': 0.1414, 'learning_rate': 0.00017940538216146485, 'epoch': 0.0}
{'loss': 0.1693, 'learning_rate': 0.00017920537616128484, 'epoch': 0.0}
{'loss': 0.1252, 'learning_rate': 0.00017900537016110482, 'epoch': 0.0}
{'loss': 0.15, 'learning_rate': 0.00017880536416092484, 'epoch': 0.0}
{'loss': 0.1451, 'learning_rate': 0.00017860535816074482, 'epoch': 0.0}
{'loss': 0.1565, 'learning_rate': 0.0001784053521605648, 'epoch': 0.0}
{'loss': 0.1511, 'learning_rate': 0.00017820534616038483, 'epoch': 0.0}
{'loss': 0.1549, 'learning_rate': 0.00017800534016020481, 'epoch': 0.0}
{'loss': 0.1466, 'learning_rate': 0.0001778053341600248, 'epoch': 0.



{'loss': 0.1258, 'learning_rate': 0.00016040481214436432, 'epoch': 0.01}
{'loss': 0.1364, 'learning_rate': 0.00016020480614418434, 'epoch': 0.01}
{'loss': 0.133, 'learning_rate': 0.00016000480014400433, 'epoch': 0.01}
{'loss': 0.1482, 'learning_rate': 0.0001598047941438243, 'epoch': 0.01}
{'loss': 0.1403, 'learning_rate': 0.0001596047881436443, 'epoch': 0.01}
{'loss': 0.1222, 'learning_rate': 0.00015940478214346431, 'epoch': 0.01}
{'loss': 0.1336, 'learning_rate': 0.0001592047761432843, 'epoch': 0.01}
{'loss': 0.13, 'learning_rate': 0.0001590047701431043, 'epoch': 0.01}
{'loss': 0.1304, 'learning_rate': 0.00015880476414292428, 'epoch': 0.01}
{'loss': 0.1252, 'learning_rate': 0.0001586047581427443, 'epoch': 0.01}
{'loss': 0.1247, 'learning_rate': 0.00015840475214256428, 'epoch': 0.01}
{'loss': 0.1376, 'learning_rate': 0.00015820474614238427, 'epoch': 0.01}
{'loss': 0.1236, 'learning_rate': 0.00015800474014220428, 'epoch': 0.01}
{'loss': 0.1607, 'learning_rate': 0.00015780473414202427, '



{'loss': 0.1237, 'learning_rate': 0.0001406042181265438, 'epoch': 0.01}
{'loss': 0.1277, 'learning_rate': 0.0001404042121263638, 'epoch': 0.01}
{'loss': 0.1204, 'learning_rate': 0.00014020420612618378, 'epoch': 0.01}
{'loss': 0.1424, 'learning_rate': 0.0001400042001260038, 'epoch': 0.01}
{'loss': 0.1224, 'learning_rate': 0.00013980419412582378, 'epoch': 0.01}
{'loss': 0.1271, 'learning_rate': 0.00013960418812564377, 'epoch': 0.01}
{'loss': 0.133, 'learning_rate': 0.00013940418212546375, 'epoch': 0.01}
{'loss': 0.1041, 'learning_rate': 0.00013920417612528377, 'epoch': 0.01}
{'loss': 0.1378, 'learning_rate': 0.00013900417012510376, 'epoch': 0.01}
{'loss': 0.1419, 'learning_rate': 0.00013880416412492374, 'epoch': 0.01}
{'loss': 0.1551, 'learning_rate': 0.00013860415812474376, 'epoch': 0.01}
{'loss': 0.145, 'learning_rate': 0.00013840415212456374, 'epoch': 0.01}
{'loss': 0.1758, 'learning_rate': 0.00013820414612438373, 'epoch': 0.01}
{'loss': 0.1333, 'learning_rate': 0.00013800414012420372



{'loss': 0.1268, 'learning_rate': 0.00012060361810854325, 'epoch': 0.02}
{'loss': 0.1571, 'learning_rate': 0.00012040361210836326, 'epoch': 0.02}
{'loss': 0.1243, 'learning_rate': 0.00012020360610818324, 'epoch': 0.02}
{'loss': 0.1301, 'learning_rate': 0.00012000360010800324, 'epoch': 0.02}
{'loss': 0.1506, 'learning_rate': 0.00011980359410782324, 'epoch': 0.02}
{'loss': 0.1404, 'learning_rate': 0.00011960358810764323, 'epoch': 0.02}
{'loss': 0.1374, 'learning_rate': 0.00011940358210746323, 'epoch': 0.02}
{'loss': 0.1315, 'learning_rate': 0.00011920357610728321, 'epoch': 0.02}
{'loss': 0.1408, 'learning_rate': 0.00011900357010710321, 'epoch': 0.02}
{'loss': 0.1242, 'learning_rate': 0.00011880356410692321, 'epoch': 0.02}
{'loss': 0.139, 'learning_rate': 0.0001186035581067432, 'epoch': 0.02}
{'loss': 0.1417, 'learning_rate': 0.0001184035521065632, 'epoch': 0.02}
{'loss': 0.1458, 'learning_rate': 0.0001182035461063832, 'epoch': 0.02}
{'loss': 0.1113, 'learning_rate': 0.0001180035401062031



{'loss': 0.1402, 'learning_rate': 0.00010060301809054271, 'epoch': 0.02}
{'loss': 0.1271, 'learning_rate': 0.00010040301209036271, 'epoch': 0.02}
{'loss': 0.1388, 'learning_rate': 0.00010020300609018272, 'epoch': 0.02}
{'loss': 0.1231, 'learning_rate': 0.0001000030000900027, 'epoch': 0.02}
{'loss': 0.1606, 'learning_rate': 9.98029940898227e-05, 'epoch': 0.02}
{'loss': 0.1234, 'learning_rate': 9.960298808964268e-05, 'epoch': 0.02}
{'loss': 0.1346, 'learning_rate': 9.940298208946269e-05, 'epoch': 0.02}
{'loss': 0.136, 'learning_rate': 9.920297608928269e-05, 'epoch': 0.02}
{'loss': 0.1378, 'learning_rate': 9.900297008910267e-05, 'epoch': 0.02}
{'loss': 0.1159, 'learning_rate': 9.880296408892266e-05, 'epoch': 0.02}
{'loss': 0.1425, 'learning_rate': 9.860295808874268e-05, 'epoch': 0.02}
{'loss': 0.1382, 'learning_rate': 9.840295208856266e-05, 'epoch': 0.02}
{'loss': 0.1344, 'learning_rate': 9.820294608838265e-05, 'epoch': 0.02}
{'loss': 0.1667, 'learning_rate': 9.800294008820264e-05, 'epoch



{'loss': 0.1127, 'learning_rate': 8.060241807254217e-05, 'epoch': 0.02}
{'loss': 0.1721, 'learning_rate': 8.040241207236217e-05, 'epoch': 0.02}
{'loss': 0.1556, 'learning_rate': 8.020240607218216e-05, 'epoch': 0.02}
{'loss': 0.1235, 'learning_rate': 8.000240007200216e-05, 'epoch': 0.02}
{'loss': 0.1332, 'learning_rate': 7.980239407182215e-05, 'epoch': 0.02}
{'loss': 0.1331, 'learning_rate': 7.960238807164215e-05, 'epoch': 0.02}
{'loss': 0.1115, 'learning_rate': 7.940238207146214e-05, 'epoch': 0.02}
{'loss': 0.1105, 'learning_rate': 7.920237607128214e-05, 'epoch': 0.02}
{'loss': 0.1279, 'learning_rate': 7.900237007110214e-05, 'epoch': 0.02}
{'loss': 0.1219, 'learning_rate': 7.880236407092213e-05, 'epoch': 0.02}
{'loss': 0.1348, 'learning_rate': 7.860235807074211e-05, 'epoch': 0.02}
{'loss': 0.1342, 'learning_rate': 7.840235207056213e-05, 'epoch': 0.02}
{'loss': 0.1473, 'learning_rate': 7.820234607038212e-05, 'epoch': 0.02}
{'loss': 0.128, 'learning_rate': 7.80023400702021e-05, 'epoch': 



{'loss': 0.1387, 'learning_rate': 6.060181805454164e-05, 'epoch': 0.03}
{'loss': 0.1203, 'learning_rate': 6.040181205436163e-05, 'epoch': 0.03}
{'loss': 0.1304, 'learning_rate': 6.020180605418163e-05, 'epoch': 0.03}
{'loss': 0.116, 'learning_rate': 6.000180005400162e-05, 'epoch': 0.03}
{'loss': 0.1202, 'learning_rate': 5.9801794053821616e-05, 'epoch': 0.03}
{'loss': 0.1124, 'learning_rate': 5.9601788053641603e-05, 'epoch': 0.03}
{'loss': 0.113, 'learning_rate': 5.9401782053461604e-05, 'epoch': 0.03}
{'loss': 0.1379, 'learning_rate': 5.92017760532816e-05, 'epoch': 0.03}
{'loss': 0.1292, 'learning_rate': 5.900177005310159e-05, 'epoch': 0.03}
{'loss': 0.1315, 'learning_rate': 5.8801764052921594e-05, 'epoch': 0.03}
{'loss': 0.1466, 'learning_rate': 5.860175805274159e-05, 'epoch': 0.03}
{'loss': 0.1496, 'learning_rate': 5.8401752052561575e-05, 'epoch': 0.03}
{'loss': 0.1306, 'learning_rate': 5.820174605238157e-05, 'epoch': 0.03}
{'loss': 0.1498, 'learning_rate': 5.800174005220157e-05, 'epoc



{'loss': 0.1259, 'learning_rate': 4.060121803654109e-05, 'epoch': 0.03}
{'loss': 0.1127, 'learning_rate': 4.040121203636109e-05, 'epoch': 0.03}
{'loss': 0.1022, 'learning_rate': 4.020120603618109e-05, 'epoch': 0.03}
{'loss': 0.1307, 'learning_rate': 4.000120003600108e-05, 'epoch': 0.03}
{'loss': 0.1328, 'learning_rate': 3.9801194035821075e-05, 'epoch': 0.03}
{'loss': 0.1368, 'learning_rate': 3.960118803564107e-05, 'epoch': 0.03}
{'loss': 0.119, 'learning_rate': 3.9401182035461064e-05, 'epoch': 0.03}
{'loss': 0.1356, 'learning_rate': 3.9201176035281065e-05, 'epoch': 0.03}
{'loss': 0.115, 'learning_rate': 3.900117003510105e-05, 'epoch': 0.03}
{'loss': 0.1377, 'learning_rate': 3.880116403492105e-05, 'epoch': 0.03}
{'loss': 0.1593, 'learning_rate': 3.860115803474104e-05, 'epoch': 0.03}
{'loss': 0.1135, 'learning_rate': 3.840115203456104e-05, 'epoch': 0.03}
{'loss': 0.139, 'learning_rate': 3.820114603438103e-05, 'epoch': 0.03}
{'loss': 0.1219, 'learning_rate': 3.800114003420103e-05, 'epoch'



{'loss': 0.109, 'learning_rate': 2.060061801854056e-05, 'epoch': 0.04}
{'loss': 0.1134, 'learning_rate': 2.0400612018360552e-05, 'epoch': 0.04}
{'loss': 0.1549, 'learning_rate': 2.0200606018180547e-05, 'epoch': 0.04}
{'loss': 0.1075, 'learning_rate': 2.000060001800054e-05, 'epoch': 0.04}
{'loss': 0.1387, 'learning_rate': 1.9800594017820535e-05, 'epoch': 0.04}
{'loss': 0.1263, 'learning_rate': 1.9600588017640532e-05, 'epoch': 0.04}
{'loss': 0.1288, 'learning_rate': 1.9400582017460526e-05, 'epoch': 0.04}
{'loss': 0.1396, 'learning_rate': 1.920057601728052e-05, 'epoch': 0.04}
{'loss': 0.1173, 'learning_rate': 1.9000570017100515e-05, 'epoch': 0.04}
{'loss': 0.1243, 'learning_rate': 1.880056401692051e-05, 'epoch': 0.04}
{'loss': 0.1051, 'learning_rate': 1.8600558016740503e-05, 'epoch': 0.04}
{'loss': 0.1336, 'learning_rate': 1.8400552016560497e-05, 'epoch': 0.04}
{'loss': 0.1191, 'learning_rate': 1.8200546016380495e-05, 'epoch': 0.04}
{'loss': 0.1198, 'learning_rate': 1.800054001620049e-05,

TrainOutput(global_step=1000, training_loss=0.1563199037387967, metrics={'train_runtime': 13860.6359, 'train_samples_per_second': 0.289, 'train_steps_per_second': 0.072, 'train_loss': 0.1563199037387967, 'epoch': 0.04})

In [10]:
# Save trained model
adapter_from_merged = "../adapters/" + new_adapter_name
trainer.model.save_pretrained(adapter_from_merged)