In [1]:
import torch
import pandas as pd
from functools import partial
from pathlib import Path
from datasets import Dataset
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from transformers import Trainer, TrainingArguments, AutoTokenizer, DataCollatorForLanguageModeling
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from evals import evaluate, load_eval_dataset
from utils import DatasetArgs, get_dataset_args

# Fail fast when no GPU is visible, before any weights are downloaded or loaded.
assert torch.cuda.is_available(), "CUDA not available"
device = torch.device("cuda")

# Notebook-wide identifiers: the base checkpoint and the dataset key used for
# every per-dataset lookup below.
model_id = "meta-llama/Llama-3.2-1B"
dataset_id = "FPB"

# Reuse EOS as the padding token when the tokenizer does not define one.
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Base model weights are loaded in float16.
base_model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype="float16")

### Dataset setup
Run this cell before the Eval section as well: it defines `args`, which the evaluation cells use.

In [2]:
# Locate the Hugging Face hub cache under the current user's home directory
# instead of hard-coding one account's absolute path, so the notebook runs
# unchanged on other machines that use the default cache location.
hub_basepath = Path.home() / ".cache" / "huggingface" / "hub"

# Per-dataset configuration (paths, prompt templates, label maps, ...) consumed
# by both the training and the evaluation cells.
args = get_dataset_args(tokenizer, hub_basepath)

### Dataset preprocessing

In [None]:
def train_preprocess(args: DatasetArgs, example: dict):
    """Tokenize one training row into a fixed-length causal-LM sample.

    The prompt is built from the template registered for the notebook-level
    ``dataset_id``; the label text is appended as the target. Only the target
    tokens keep their ids in ``labels`` -- prompt and padding positions are set
    to -100 so they do not contribute to the loss.

    Args:
        args: Dataset configuration. This function reads ``prompt_args``,
            ``prompt_templates`` and ``id2labels`` (each indexed by
            ``dataset_id``) and ``max_length``.
        example: One dataset row. Must contain every column named in
            ``args.prompt_args[dataset_id]`` plus a ``"label"`` entry.

    Returns:
        The tokenizer output for ``prompt + target`` (padded/truncated to
        ``args.max_length``) with an added ``"labels"`` list of the same length.
    """
    # Keep the template values under their own name: rebinding ``args`` here
    # would replace the DatasetArgs object the lookups below still need.
    prompt_values = [example[key] for key in args.prompt_args[dataset_id]]
    prompt = args.prompt_templates[dataset_id].format(*prompt_values)

    target = args.id2labels[dataset_id][example["label"]]
    full_text = prompt + target

    # Tokenize prompt + target to a fixed length.
    tokenized = tokenizer(full_text,
                          truncation=True,
                          padding="max_length",
                          max_length=args.max_length)

    # Tokenize the prompt alone (unpadded) to learn how many leading tokens it
    # occupies in the full sequence.
    prompt_tokenized = tokenizer(prompt,
                                 truncation=True,
                                 max_length=args.max_length)
    prompt_length = len(prompt_tokenized["input_ids"])

    # Start from the input ids, masking every padded position.
    labels = [
        token_id if is_real else -100
        for token_id, is_real in zip(tokenized["input_ids"],
                                     tokenized["attention_mask"])
    ]
    # Mask the prompt so only the target is supervised.
    # NOTE(review): assumes the tokenizer pads on the right, so the prompt
    # starts at index 0 -- confirm tokenizer.padding_side.
    labels[:prompt_length] = [-100] * prompt_length
    tokenized["labels"] = labels
    return tokenized

In [None]:
# Resolve where this dataset lives and bind the shared config into the row-wise
# preprocessing function.
dataset_path = args.paths[dataset_id]
preprocess_func = partial(train_preprocess, args)

# Read the raw CSV (column names come from the config, not from the file),
# convert it to a Hugging Face Dataset, tokenize row by row, then drop the
# columns listed for this dataset in the config.
train_dataset = Dataset.from_pandas(
    pd.read_csv(
        dataset_path / "train.csv",
        delimiter=args.del_mapping[dataset_id],
        names=args.names_mapping[dataset_id],
    )
)
train_dataset = train_dataset.map(preprocess_func, batched=False)
train_dataset = train_dataset.remove_columns(args.columns[dataset_id])

## LoRA Setup

In [7]:
# LoRA hyper-parameters for a fresh adapter on the query/value projections.
peft_config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.01,
    task_type=TaskType.CAUSAL_LM,
)

# Wrap the base model with the adapter and report how many weights will train.
peft_model = get_peft_model(base_model, peft_config)
peft_model.print_trainable_parameters()

trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689


## LoRA Continue Training

In [None]:
# Adapter checkpoint saved by a previous run for the current dataset.
lora_ckpt = Path(rf"D:/models/expert-Llama-3_2-1B-{dataset_id}") / "checkpoint-best"

# Load the saved adapter on top of the base model for further training.
# ``is_trainable=True`` is required here: by default PEFT loads adapters frozen
# (inference only), and calling ``.train()`` only switches module mode -- it
# does not re-enable gradients on the adapter weights.
peft_model = PeftModel.from_pretrained(
    base_model,
    lora_ckpt,
    is_trainable=True,
).to(device).train()

## Trainer setup

In [None]:
def data_collator(features):
    """Stack pre-tokenized, equal-length features into a batch of tensors.

    ``DataCollatorForLanguageModeling(mlm=False)`` rebuilds ``labels`` from
    ``input_ids``, which would throw away the prompt masking stored by
    ``train_preprocess``. This collator keeps the stored labels instead and
    only masks padded positions.

    Args:
        features: List of dataset rows, each holding ``input_ids``,
            ``attention_mask`` and ``labels`` lists of the same length
            (rows are padded to a fixed ``max_length`` during preprocessing).

    Returns:
        Dict of tensors with keys ``input_ids``, ``attention_mask`` and
        ``labels``, each with one row per feature.
    """
    batch = {
        key: torch.tensor([feature[key] for feature in features])
        for key in ("input_ids", "attention_mask", "labels")
    }
    # Padding carries no training signal; exclude it from the loss.
    batch["labels"][batch["attention_mask"] == 0] = -100
    return batch


# Checkpoints are written under this directory, one per epoch.
out_dir = Path(rf"D:/models/expert-Llama-3_2-1B-{dataset_id}")
training_args = TrainingArguments(
    output_dir=str(out_dir),
    num_train_epochs=6,
    per_device_train_batch_size=2,
    learning_rate=1e-3,
    weight_decay=0.01,
    warmup_steps=5000,
    save_strategy="epoch",
    do_train=True,
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

In [None]:
# Launch fine-tuning with the Trainer configured above; left as the cell's last
# expression so the notebook displays the value returned by train().
trainer.train()

# Eval

In [3]:
# Pick the dataset to evaluate and the matching adapter checkpoint.
dataset_id = "Topics"
ckpt_path = Path(rf"D:/models/expert-Llama-3_2-1B-{dataset_id}") / "checkpoint-best"

# Reload a clean base model in eval mode, then attach the saved adapter and
# move the combined model to the GPU.
base_model = LlamaForCausalLM.from_pretrained(model_id, torch_dtype="float16")
base_model = base_model.eval()

expert_model = PeftModel.from_pretrained(base_model, ckpt_path, torch_dtype="float16")
expert_model = expert_model.eval().to(device)

In [4]:
# Build the evaluation split for the selected dataset.
testset = load_eval_dataset(tokenizer, dataset_id, args)

# Score the adapter-equipped model; `guidance` and the per-dataset `token_opts`
# are forwarded as-is to the project's `evaluate` helper.
results = evaluate(
    expert_model,
    tokenizer,
    testset,
    guidance=True,
    token_opts=args.token_opts[dataset_id],
)

Loading Topics from path C:\Users\samba\.cache\huggingface\hub\datasets--Sujet--TopicClassification


Map:   0%|          | 0/850 [00:00<?, ? examples/s]

Map:   0%|          | 0/850 [00:00<?, ? examples/s]

91.18: 100%|██████████| 850/850 [00:33<00:00, 25.43it/s]
