In [None]:
from pathlib import Path

import torch
from huggingface_hub import constants as hub_c
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from evals import load_eval_dataset, evaluate
from utils import DatasetArgs, get_dataset_args, load_train_datasets

# The evaluation cell moves the model to `device`, so fail fast when no GPU is present.
assert torch.cuda.is_available(), "CUDA not available"
device = torch.device("cuda")

# Seed Python's `random`, NumPy and torch (CPU and CUDA) in one call.
# `torch.manual_seed` alone would leave the other generators unseeded while
# the datasets are loaded and preprocessed.
seed = 42
set_seed(seed)

# NOTE(review): not referenced by any other cell visible in this notebook —
# confirm whether it is still needed.
dataset_cache_path = Path(r"D:/datasets/general-3-tasks")

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token when the tokenizer does not define one,
# so that batches of unequal-length sequences can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Dataset configuration (prompt templates, label maps, ...) built from the
# tokenizer and the Hugging Face hub cache directory.
args = get_dataset_args(tokenizer, Path(hub_c.HF_HUB_CACHE))

## Load and Preprocess Dataset

In [None]:
def train_preprocess_causal(args: DatasetArgs, dataset_id: str, example: dict):
    """Tokenize one training example as a single prompt-plus-answer string.

    The prompt is produced by filling the dataset's template positionally with
    the example fields listed in ``args.prompt_args[dataset_id]``. The answer
    text is looked up in ``args.id2labels[dataset_id]`` via ``example["label"]``
    and appended directly to the prompt (no separator is inserted here).

    Args:
        args: Dataset configuration exposing ``prompt_args``,
            ``prompt_templates`` and ``id2labels``, each indexed by dataset id.
        dataset_id: Key selecting this dataset's entries in ``args``.
        example: A single record; must contain ``"label"`` and every field
            named in ``args.prompt_args[dataset_id]``.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the concatenated
        text when called with ``truncation=False``.
    """
    # Gather the template arguments in the order the template expects them.
    field_values = tuple(example[name] for name in args.prompt_args[dataset_id])
    template = args.prompt_templates[dataset_id]
    filled_prompt = template.format(*field_values)

    # Map the label id to its textual form.
    label_text = args.id2labels[dataset_id][example["label"]]

    # Prompt and answer are tokenized together; truncation is disabled so the
    # answer is never cut off.
    return tokenizer(filled_prompt + label_text, truncation=False)

In [None]:
# Build the training set, handing `train_preprocess_causal` to the loader as
# the per-example preprocessing callback.
train_dataset = load_train_datasets(
    args,
    train_preprocess_causal,
)

## LoRA Setup

In [None]:
# LoRA hyper-parameters: rank-8 adapters (alpha 16, dropout 0.01) attached to
# the modules named `q_proj` and `v_proj`, configured for causal LM.
lora_settings = dict(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.01,
    target_modules=["q_proj", "v_proj"],
)
peft_config = LoraConfig(**lora_settings)

# Base checkpoint, loaded in float16.
model_id = "meta-llama/Llama-3.2-1B"
base_model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype="float16",
)

# Wrap the base model with the adapters and report how many parameters train.
peft_model = get_peft_model(base_model, peft_config)
peft_model.print_trainable_parameters()

## Train

In [None]:
# mlm=False selects causal-LM collation (no masked-token objective).
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)

out_dir = Path(r"D:/models/general-Llama-3_2-1B") ## OUTPUT DIRECTORY
training_args = TrainingArguments(
    output_dir=str(out_dir),
    num_train_epochs=5,
    # 2 sequences per device x 8 accumulation steps = 16 sequences per
    # optimizer step on each device.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-3,
    weight_decay=0.01,
    warmup_steps=128,
    logging_steps=32,
    # One checkpoint is written under `out_dir` at the end of every epoch.
    save_strategy="epoch",
    # Trainer re-seeds from TrainingArguments.seed when it is constructed, so
    # forward the notebook-level seed; otherwise the library default is used
    # no matter what `seed` is set to.
    seed=seed,
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

trainer.train()

## Evals

In [None]:
# Dataset to evaluate on; the same id is used later to select the matching
# entry in `args.token_opts`.
dataset_id = "FPB"
testset = load_eval_dataset(
    tokenizer,
    dataset_id,
    args,
)

### Evaluate general model checkpoint

In [None]:
# Adapter checkpoint to evaluate: a step folder inside the model directory.
ckpt_path = Path(r"D:/models/general-Llama-3_2-1B").joinpath("checkpoint-28725")

# Load a fresh float16 base model rather than reusing the training-time object.
model_id = "meta-llama/Llama-3.2-1B"
base_model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype="float16",
)

# Attach the saved adapter, then switch to inference mode and move to the GPU.
# NOTE(review): `torch_dtype` is not a documented parameter of
# PeftModel.from_pretrained — confirm it has any effect here.
expert_model = PeftModel.from_pretrained(base_model, ckpt_path, torch_dtype="float16")
expert_model = expert_model.eval().to(device)

In [None]:
# Score the loaded checkpoint on the test set, passing the token options
# registered for this dataset, and show the outcome.
results = evaluate(
    expert_model,
    tokenizer,
    testset,
    args.token_opts[dataset_id],
)
print(results)