In [1]:
import torch
import pandas as pd
import numpy
from functools import partial
from pathlib import Path
from huggingface_hub import constants as hub_c
from datasets import Dataset
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from transformers import Trainer, TrainingArguments, AutoTokenizer, DataCollatorForLanguageModeling
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from evals import evaluate, load_eval_dataset
from utils import DatasetArgs, get_dataset_args

assert torch.cuda.is_available(), "CUDA not available"
device = torch.device("cuda")

seed = 42
torch.manual_seed(seed)

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype="float16")

dataset_id = "FPB"

### Dataset setup
Run the cell for eval as well

In [2]:
args = get_dataset_args(tokenizer, Path(hub_c.HF_HUB_CACHE))

### Dataset preprocessing

In [3]:
def train_preprocess(args: DatasetArgs, example: dict):
    # Create prompt and target text
    prompt_args = [example[key] for key in args.prompt_args[dataset_id]]
    prompt = args.prompt_templates[dataset_id].format(*prompt_args)

    target = args.id2labels[dataset_id][example["label"]]
    
    # tokenize text
    return tokenizer(prompt + target, truncation=False)

In [None]:
dataset_path = args.paths[dataset_id]
df_dataset = pd.read_csv(dataset_path / "train.csv",
                         delimiter=args.del_mapping[dataset_id],
                         names=args.names_mapping[dataset_id])

preprocess_func = partial(train_preprocess, args)
dataset = (Dataset
            .from_pandas(df_dataset)
            .map(preprocess_func,
                 batched=False,
                 remove_columns=args.columns[dataset_id])
            .filter(lambda sample: len(sample["input_ids"]) <= args.max_length))

dataset = dataset.train_test_split(test_size=0.1, seed=seed)

## LoRA Setup

In [None]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.01,
    target_modules=["q_proj", "v_proj"]
)

peft_model = get_peft_model(base_model, peft_config)
peft_model.print_trainable_parameters()

trainable params: 5,636,096 || all params: 1,241,450,496 || trainable%: 0.4540


## Trainer setup

In [6]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)

out_dir = Path(rf"D:/models/expert-Llama-3_2-1B-{dataset_id}-expert256")
training_args = TrainingArguments(
    output_dir=str(out_dir),
    num_train_epochs=12,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-3,
    weight_decay=0.01,
    warmup_steps=128,
    logging_steps=32,
    save_steps=128,
    save_strategy="steps",
    eval_steps=128,
    eval_strategy="steps",
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
)

In [7]:
trainer.train()

  0%|          | 0/1308 [00:00<?, ?it/s]

{'loss': 3.5022, 'grad_norm': 1.8773185014724731, 'learning_rate': 0.00025, 'epoch': 0.29}
{'loss': 2.1105, 'grad_norm': 1.111354112625122, 'learning_rate': 0.0005, 'epoch': 0.59}
{'loss': 1.8909, 'grad_norm': 0.83847975730896, 'learning_rate': 0.00075, 'epoch': 0.88}
{'loss': 1.8484, 'grad_norm': 0.9039217829704285, 'learning_rate': 0.001, 'epoch': 1.17}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 1.8227686882019043, 'eval_runtime': 2.0506, 'eval_samples_per_second': 189.21, 'eval_steps_per_second': 23.895, 'epoch': 1.17}
{'loss': 1.8198, 'grad_norm': 0.7252324223518372, 'learning_rate': 0.0009728813559322034, 'epoch': 1.47}
{'loss': 1.7596, 'grad_norm': 0.9131329655647278, 'learning_rate': 0.0009457627118644068, 'epoch': 1.76}
{'loss': 1.7466, 'grad_norm': 0.8325889706611633, 'learning_rate': 0.0009186440677966102, 'epoch': 2.06}
{'loss': 1.6769, 'grad_norm': 0.9228724837303162, 'learning_rate': 0.0008915254237288136, 'epoch': 2.35}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 1.762556552886963, 'eval_runtime': 2.0346, 'eval_samples_per_second': 190.702, 'eval_steps_per_second': 24.084, 'epoch': 2.35}
{'loss': 1.6605, 'grad_norm': 0.7208706736564636, 'learning_rate': 0.000864406779661017, 'epoch': 2.64}
{'loss': 1.6715, 'grad_norm': 0.7269070148468018, 'learning_rate': 0.0008372881355932204, 'epoch': 2.94}
{'loss': 1.5726, 'grad_norm': 0.7768827080726624, 'learning_rate': 0.0008101694915254238, 'epoch': 3.23}
{'loss': 1.5561, 'grad_norm': 0.8208379149436951, 'learning_rate': 0.0007830508474576272, 'epoch': 3.52}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 1.7494487762451172, 'eval_runtime': 2.1111, 'eval_samples_per_second': 183.793, 'eval_steps_per_second': 23.211, 'epoch': 3.52}
{'loss': 1.5934, 'grad_norm': 0.8654942512512207, 'learning_rate': 0.0007559322033898304, 'epoch': 3.82}
{'loss': 1.5172, 'grad_norm': 0.8670519590377808, 'learning_rate': 0.0007288135593220338, 'epoch': 4.11}
{'loss': 1.4442, 'grad_norm': 0.962871253490448, 'learning_rate': 0.0007016949152542373, 'epoch': 4.4}
{'loss': 1.4476, 'grad_norm': 1.011186122894287, 'learning_rate': 0.0006745762711864407, 'epoch': 4.7}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 1.7689759731292725, 'eval_runtime': 2.0289, 'eval_samples_per_second': 191.239, 'eval_steps_per_second': 24.151, 'epoch': 4.7}
{'loss': 1.4766, 'grad_norm': 1.066024899482727, 'learning_rate': 0.0006474576271186441, 'epoch': 4.99}
{'loss': 1.3497, 'grad_norm': 1.3634631633758545, 'learning_rate': 0.0006203389830508474, 'epoch': 5.28}
{'loss': 1.3504, 'grad_norm': 1.041918158531189, 'learning_rate': 0.0005932203389830508, 'epoch': 5.58}
{'loss': 1.3783, 'grad_norm': 1.1129660606384277, 'learning_rate': 0.0005661016949152542, 'epoch': 5.87}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 1.8163354396820068, 'eval_runtime': 2.1143, 'eval_samples_per_second': 183.513, 'eval_steps_per_second': 23.176, 'epoch': 5.87}
{'loss': 1.3077, 'grad_norm': 0.9928992390632629, 'learning_rate': 0.0005389830508474577, 'epoch': 6.17}
{'loss': 1.238, 'grad_norm': 1.202432632446289, 'learning_rate': 0.0005118644067796611, 'epoch': 6.46}
{'loss': 1.2591, 'grad_norm': 1.1913760900497437, 'learning_rate': 0.00048474576271186445, 'epoch': 6.75}
{'loss': 1.268, 'grad_norm': 1.150791883468628, 'learning_rate': 0.0004576271186440678, 'epoch': 7.05}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 1.925498604774475, 'eval_runtime': 2.0819, 'eval_samples_per_second': 186.367, 'eval_steps_per_second': 23.536, 'epoch': 7.05}
{'loss': 1.1477, 'grad_norm': 1.194187879562378, 'learning_rate': 0.0004305084745762712, 'epoch': 7.34}
{'loss': 1.1758, 'grad_norm': 1.207614541053772, 'learning_rate': 0.0004033898305084746, 'epoch': 7.63}
{'loss': 1.1808, 'grad_norm': 1.2369171380996704, 'learning_rate': 0.00037627118644067796, 'epoch': 7.93}
{'loss': 1.0966, 'grad_norm': 1.3432742357254028, 'learning_rate': 0.00034915254237288134, 'epoch': 8.22}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 2.0290937423706055, 'eval_runtime': 2.4131, 'eval_samples_per_second': 160.788, 'eval_steps_per_second': 20.306, 'epoch': 8.22}
{'loss': 1.0607, 'grad_norm': 1.5493149757385254, 'learning_rate': 0.0003220338983050847, 'epoch': 8.51}
{'loss': 1.1015, 'grad_norm': 1.2644352912902832, 'learning_rate': 0.00029491525423728815, 'epoch': 8.81}
{'loss': 1.0602, 'grad_norm': 1.3782566785812378, 'learning_rate': 0.00026779661016949153, 'epoch': 9.1}
{'loss': 0.9908, 'grad_norm': 1.322336196899414, 'learning_rate': 0.00024067796610169494, 'epoch': 9.39}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 2.116011142730713, 'eval_runtime': 2.02, 'eval_samples_per_second': 192.076, 'eval_steps_per_second': 24.257, 'epoch': 9.39}
{'loss': 1.0188, 'grad_norm': 1.4284334182739258, 'learning_rate': 0.0002135593220338983, 'epoch': 9.69}
{'loss': 1.0046, 'grad_norm': 1.4194382429122925, 'learning_rate': 0.0001864406779661017, 'epoch': 9.98}
{'loss': 0.9215, 'grad_norm': 1.6090210676193237, 'learning_rate': 0.00015932203389830508, 'epoch': 10.28}
{'loss': 0.9493, 'grad_norm': 1.4917473793029785, 'learning_rate': 0.00013220338983050849, 'epoch': 10.57}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 2.1924102306365967, 'eval_runtime': 2.0669, 'eval_samples_per_second': 187.716, 'eval_steps_per_second': 23.706, 'epoch': 10.57}
{'loss': 0.949, 'grad_norm': 1.5981618165969849, 'learning_rate': 0.00010508474576271186, 'epoch': 10.86}
{'loss': 0.9095, 'grad_norm': 1.636603832244873, 'learning_rate': 7.796610169491526e-05, 'epoch': 11.16}
{'loss': 0.8865, 'grad_norm': 1.5029820203781128, 'learning_rate': 5.084745762711865e-05, 'epoch': 11.45}
{'loss': 0.9041, 'grad_norm': 1.4532837867736816, 'learning_rate': 2.3728813559322036e-05, 'epoch': 11.74}


  0%|          | 0/49 [00:00<?, ?it/s]

{'eval_loss': 2.2749266624450684, 'eval_runtime': 2.2017, 'eval_samples_per_second': 176.227, 'eval_steps_per_second': 22.255, 'epoch': 11.74}
{'train_runtime': 662.2032, 'train_samples_per_second': 63.207, 'train_steps_per_second': 1.975, 'train_loss': 1.3838878829909391, 'epoch': 12.0}


TrainOutput(global_step=1308, training_loss=1.3838878829909391, metrics={'train_runtime': 662.2032, 'train_samples_per_second': 63.207, 'train_steps_per_second': 1.975, 'total_flos': 1.74494483226624e+16, 'train_loss': 1.3838878829909391, 'epoch': 12.0})

# Eval

In [None]:
dataset_id = "FPB"
ckpt_path = Path(rf"D:/models/expert-Llama-3_2-1B-{dataset_id}-expert256") / "checkpoint-4620"

base_model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype="float16").eval()
expert_model = PeftModel.from_pretrained(base_model, ckpt_path, torch_dtype="float16").eval().to(device)

In [8]:
testset = load_eval_dataset(tokenizer, dataset_id, args)
results = evaluate(peft_model, tokenizer,
                   testset,
                   guidance=True,
                   token_opts=args.token_opts[dataset_id])

Loading FPB dataset from AdaptLLM/finance-tasks


60.21: 100%|██████████| 970/970 [00:37<00:00, 25.57it/s]
