In [None]:
import gc
import json
from glob import glob
from pathlib import Path
from typing import Any

import numpy as np
import polars as pl
import torch
from datasets import Dataset
from omegaconf import OmegaConf
from peft import AutoPeftModelForCausalLM, LoraConfig, prepare_model_for_kbit_training
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, set_seed
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer


In [None]:
# Seed the libraries covered by `transformers.set_seed` so runs are repeatable.
set_seed(386)

# Use bf16 and FlashAttention-2 only when the GPU actually supports bf16;
# otherwise fall back to fp16 with PyTorch's SDPA attention. This mirrors the
# fp16/bf16 switch used further down in TrainingArguments.
if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
    attn_implementation = "flash_attention_2"
else:
    compute_dtype = torch.float16
    attn_implementation = "sdpa"

## Configuration

In [None]:
# Project settings (data directory, sequence-length limits) live in a YAML file
# next to the notebook.
CONFIG_PATH = "config.yaml"
config = OmegaConf.load(CONFIG_PATH)

In [None]:
# Base checkpoint that gets quantized and fine-tuned.
MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"

# Prefix of the per-fold output directory (the fold suffix is appended later).
OUT_DIR_NAME = Path("./llama33_70B")

# Values taken from config.yaml.
DATA_ROOT = Path(config.data_dir)
# Passed to SFTTrainer as `max_seq_length`.
SEQ_LENGTH = config.max_seq_length
# Prompts whose token count reaches this value are filtered out before training.
TRAINING_SEQ_LENGTH = config.prompt_max_seq_length

In [None]:
# QLoRA-style quantization: 4-bit NF4 weights with double quantization.
# `compute_dtype` is the dtype selected in the setup cell.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load the base model with the quantization config above. `torch_dtype` follows
# `compute_dtype` instead of a hard-coded bfloat16 so the non-quantized modules
# and the quantized compute path always agree on one dtype.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation,
    torch_dtype=compute_dtype,
)

In [None]:
# Run PEFT's k-bit training preparation on the quantized model, requesting
# re-entrant gradient checkpointing.
checkpointing_options = {"use_reentrant": True}
model = prepare_model_for_kbit_training(model, gradient_checkpointing_kwargs=checkpointing_options)
# Turn the generation cache off on the model config for training.
model.config.use_cache = False

# Projection layers that receive LoRA adapters.
LORA_TARGET_MODULES = ["o_proj", "k_proj", "q_proj", "down_proj", "gate_proj", "up_proj", "v_proj"]

peft_config = LoraConfig(
    r=64,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=LORA_TARGET_MODULES,
)

In [None]:
# Tokenizer for the base checkpoint, asked to add BOS/EOS tokens.
tokenizer_options = {"add_eos_token": True, "add_bos_token": True}
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **tokenizer_options)

# Reuse EOS for padding when the checkpoint ships without a pad token.
# NOTE(review): with pad == eos, a collator that masks pad positions in the
# labels would also mask the final EOS of each answer — confirm this is
# acceptable for how the model is used at inference time.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Pad on the right-hand side of each sequence.
tokenizer.padding_side = "right"

## Data Prep

In [None]:
# Load the datasets prepared in `prepare_dataset.ipynb`:
# the original training rows and the synthetic rows.
train_long, synth_long = (
    pl.read_parquet(parquet_path)
    for parquet_path in ("./df_train.parquet", "./df_synth.parquet")
)

In [None]:
# Store the ground-truth misconception id as a 64-bit integer.
train_long = train_long.with_columns(pl.col("MisconceptionId").cast(pl.Int64))

In [None]:
# Feature Engineering: Information related to the problem's category
def _with_category_features(frame: pl.DataFrame) -> pl.DataFrame:
    """Return ``frame`` with parsed candidate ids and a ``knowledge`` column.

    ``MisconceptionId_pred`` is split on single spaces and each piece is cast
    to Int32, replacing the original string column. ``knowledge`` is the
    comma-joined concatenation of the three subject-name columns and
    ``ConstructName``.
    """
    candidate_ids = pl.col("MisconceptionId_pred").str.split(" ").list.eval(pl.element().cast(pl.Int32()))
    category_columns = ["FirstSubjectName", "SecondSubjectName", "ThirdSubjectName", "ConstructName"]
    knowledge_text = pl.concat_str([pl.col(name) for name in category_columns], separator=",")
    return frame.with_columns(candidate_ids, knowledge=knowledge_text)


train_long = _with_category_features(train_long)

synth_long = _with_category_features(synth_long)

In [None]:
# Lookup table built from the two columns of the mapping CSV (first column ->
# second column). make_prompt uses it to turn a misconception id into the text
# shown to the model.
df_misconceptions = pl.read_csv(DATA_ROOT / "misconception_mapping.csv")
misconception_map = {key: value for key, value in df_misconceptions.iter_rows()}

## Make prompts

In [None]:
# User-turn template for the multiple-choice task. It is filled in by
# `make_prompt` via `str.format` with these placeholders:
#   knowledge            - comma-joined subject / construct names
#   question_tex         - the question text
#   correct_answer_text  - the correct answer
#   answer_text          - the student's (incorrect) answer
#   cot_misunderstanding - a step-by-step explanation generated beforehand
#   misconceptions_text  - the lettered candidate list, one option per line
# The string is not raw, so the "\n" after "Possible Misconceptions:" is a real newline.
report_format = """# OBJECTIVE #
Based on the problem, please select which of the following misconceptions might have led to the student's incorrect answer.

Categories: {knowledge}
Question: {question_tex}
Correct Answer: {correct_answer_text}
Student's Answer: {answer_text}
Step-by-Step Solution: {cot_misunderstanding}

Possible Misconceptions:\n{misconceptions_text}

Please respond only with the letter(s) of the misconception(s) you think are most likely to be the cause."""  # noqa: E501


def make_prompt(row: dict[str, Any]) -> tuple[str, str]:
    # Randomly selects a candidate from the top 5 to top 25
    k = np.random.choice([5, 10, 15, 25])
    target_misconceptions = row["MisconceptionId_pred"]
    target_misconceptions = list(target_misconceptions[:k])

    # If the correct answer is not within the Top N options,
    # randomly replace one of the options with the correct answer.
    if row["MisconceptionId"] not in target_misconceptions:
        idx = np.random.randint(0, k)
        target_misconceptions[idx] = row["MisconceptionId"]

    np.random.shuffle(target_misconceptions)

    # Generate the initials
    letters = [chr(i+65) for i in range(k)]
    misconception_names = [misconception_map[_id] for _id in target_misconceptions]
    ans_pos = target_misconceptions.index(row["MisconceptionId"])
    ans_letter = letters[ans_pos]
    choice_str = "\n".join([f"{a}. {n}" for a, n in zip(letters, misconception_names, strict=False)])

    # Prepare a dictionary for validation
    choices = dict(zip(letters, target_misconceptions, strict=False))

    # Make prompt
    q = report_format.format(
        knowledge=row["knowledge"],
        question_tex=row["QuestionText"],
        construct=row["ConstructName"],
        correct_answer_text=row["CorrectAnswerText"],
        answer_text=row["AnswerText"],
        cot_misunderstanding=row["p000-qwen25-32b-instruct-cot_misunderstanding"],
        misconceptions_text=choice_str,
    )
    context = [
        {"role": "system", "content": "You are a mathematics teacher. Your job is to infer and identify the misconceptions behind the incorrect answers to the questions."},
        {
            "role": "user",
            "content": q,
        },
        {
            "role": "assistant",
            "content": ans_letter,
        },
    ]
    prompt: str = tokenizer.apply_chat_template(context, tokenize=False, add_generation_prompt=False)
    return prompt, json.dumps(choices)


In [None]:
# Each "epoch" renders the same training rows again with a different seed, so
# the random candidate count, replacement slot and option order differ.
EPOCH_SEEDS = (8996, 537)

# epoch1
np.random.seed(EPOCH_SEEDS[0])
results = [make_prompt(row) for row in train_long.rows(named=True)]
epoch1_prompts = [prompt for prompt, _ in results]
epoch1_choices = [choice_json for _, choice_json in results]
train_long = train_long.with_columns(
    whole_prompt=pl.Series(epoch1_prompts),
    choices=pl.Series(epoch1_choices),
)

# epoch2 (rows are read from the epoch-1 frame; its prompt columns are overwritten in the copy)
np.random.seed(EPOCH_SEEDS[1])
results = [make_prompt(row) for row in train_long.rows(named=True)]
epoch2_prompts = [prompt for prompt, _ in results]
epoch2_choices = [choice_json for _, choice_json in results]
train_long2 = train_long.clone().with_columns(
    whole_prompt=pl.Series(epoch2_prompts),
    choices=pl.Series(epoch2_choices),
)

In [None]:
# Render prompts for the synthetic rows. There is no reseed here: NumPy's
# global RNG simply continues from whatever state the previous cell left it in.
results = [make_prompt(row) for row in synth_long.rows(named=True)]
synth_prompts = [prompt for prompt, _ in results]
synth_choices = [choice_json for _, choice_json in results]
synth_long = synth_long.with_columns(
    whole_prompt=pl.Series(synth_prompts),
    choices=pl.Series(synth_choices),
)

In [None]:
# Store the preprocessed (epoch-1) prompts on disk for the later quantization step.
PROMPT_CACHE_PATH = "./train_long.parquet"
train_long.write_parquet(PROMPT_CACHE_PATH)

In [None]:
# Eyeball one fully rendered training prompt.
first_prompt = train_long["whole_prompt"][0]
print(first_prompt)

In [None]:
# Stack the epoch-1 prompts, the synthetic prompts and the epoch-2 prompts into
# one training pool. "diagonal_relaxed" is used so the frames need not have
# identical schemas.
frames_to_stack = [train_long, synth_long, train_long2]
df_all = pl.concat(frames_to_stack, how="diagonal_relaxed")

df_all.shape

### Data cleansing

In [None]:
# Measure every prompt's token count so that excessively long prompts can be
# removed in the next cell.
token_les = [len(tokenizer(text)["input_ids"]) for text in tqdm(df_all["whole_prompt"])]
df_all = df_all.with_columns(pl.Series("token_len", token_les))

In [None]:
# Keep only prompts strictly shorter than the configured prompt budget.
# NOTE(review): the trainer truncates at SEQ_LENGTH, not TRAINING_SEQ_LENGTH —
# confirm prompt_max_seq_length <= max_seq_length in config.yaml.
within_budget = pl.col("token_len") < TRAINING_SEQ_LENGTH
df_all = df_all.filter(within_budget)
df_all.get_column("token_len").describe()

### Train the model

In [None]:
# Fold held out for validation; every other fold is used for training.
fold = 0

# Training pool: all stacked prompts outside the held-out fold, shuffled once.
train_rows = df_all.filter(pl.col("fold") != fold)
train_dataset = Dataset.from_polars(train_rows).shuffle(seed=0)

# Validation: epoch-1 prompts of the held-out fold. These come from train_long,
# which has not been through the token-length filter applied to df_all.
valid_rows = train_long.filter(pl.col("fold") == fold)
valid_dataset = Dataset.from_polars(valid_rows)

In [None]:
# Text that marks the start of the assistant turn in the rendered prompt; the
# collator uses it as the response template. Start from the fallback marker and
# override it for Gemma tokenizers or Llama model names.
sep = "<|im_start|>assistant\n"
tokenizer_class_name = type(tokenizer).__name__
if "Gemma" in tokenizer_class_name:
    sep = "<start_of_turn>model\n"
elif "Llama" in MODEL_NAME:
    sep = "<|start_header_id|>assistant<|end_header_id|>\n\n"

In [None]:
# Checkpoints for this fold are written under "<OUT_DIR_NAME>_fold<fold>".
output_dir = f"{OUT_DIR_NAME}_fold{fold}"

# Compute the loss only on the LLM's response part (everything after `sep`).
collator = DataCollatorForCompletionOnlyLM(response_template=sep, tokenizer=tokenizer, mlm=False)

In [None]:
# Mixed precision follows the hardware: bf16 when supported, fp16 otherwise.
bf16_available = torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir=output_dir,
    seed=334,
    # Batching: 4 sequences per device, gradients accumulated over 4 steps.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=1,  # Set this for 1 full training run.
    # Optimizer and schedule.
    optim="paged_adamw_8bit",
    learning_rate=3e-4,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=5,
    # Precision.
    fp16=not bf16_available,
    bf16=bf16_available,
    # Logging and checkpointing.
    logging_steps=20,
    save_steps=50,
    save_total_limit=1,
    remove_unused_columns=True,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=peft_config,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    dataset_text_field="whole_prompt",
    dataset_num_proc=8,
    max_seq_length=SEQ_LENGTH,
    data_collator=collator,
    packing=False,  # Can make training 5x faster for short sequences.
)

In [None]:
# debug: decode the first example of one training batch, then only the label
# positions that are not masked with -100, to check what the loss is computed on.
batch = next(iter(trainer.get_train_dataloader()))
first_input_ids = batch["input_ids"][0]
whole_text = tokenizer.decode(first_input_ids)
print(f"{whole_text=}")

y_label = batch["labels"][0]
supervised_positions = y_label != -100
label_text = tokenizer.decode(y_label[supervised_positions])
print(f"{label_text=}")


def _bytes_to_gib(n_bytes):
    """Convert a byte count to GiB, rounded to three decimals."""
    return round(n_bytes / 1024 / 1024 / 1024, 3)


# Report device 0's name, its total memory and the peak memory reserved so far.
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = _bytes_to_gib(torch.cuda.max_memory_reserved())
max_memory = _bytes_to_gib(gpu_stats.total_memory)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
# Start training. The value returned by `trainer.train()` is kept in
# `trainer_stats` for later inspection.
trainer_stats = trainer.train()

In [None]:
# Drop the notebook's references to the training objects, then ask Python and
# the CUDA caching allocator to release what they can before the merge step.
del trainer
del model

gc.collect()
torch.cuda.empty_cache()

## Merge the LoRA model with the base model

In [None]:
# Locate the checkpoint to merge. glob() makes no ordering guarantee, so rather
# than taking whichever match comes first, keep only "checkpoint-<step>"
# entries and pick the one with the highest step number.
checkpoint_dirs = [
    path
    for path in glob(f"./{output_dir}/checkpoint-*")
    if path.rsplit("-", 1)[-1].isdigit()
]
if not checkpoint_dirs:
    # Fail with a clear message instead of a bare IndexError.
    raise FileNotFoundError(f"No checkpoint-<step> directory found under ./{output_dir}")
lora_path = max(checkpoint_dirs, key=lambda path: int(path.rsplit("-", 1)[-1]))
lora_path

In [None]:
# Load LoRA: the tokenizer and the adapter model are both read from the
# checkpoint directory; the model weights are requested in bfloat16.
tokenizer = AutoTokenizer.from_pretrained(lora_path)
adapter_load_options = {"torch_dtype": torch.bfloat16}
model = AutoPeftModelForCausalLM.from_pretrained(lora_path, **adapter_load_options)

In [None]:
# Put the reloaded model into evaluation mode before merging.
model = model.eval()

In [None]:
# Merge the LoRA adapter into the base model via PEFT's `merge_and_unload()`;
# the result is kept in `merged` and saved in the next cell.
merged = model.merge_and_unload()

In [None]:
# Save merged model and its tokenizer side by side in one directory.
out_path = f"./{output_dir}/merged_model"
for artifact in (merged, tokenizer):
    artifact.save_pretrained(out_path)