In [1]:
import gc
import json
from glob import glob
from pathlib import Path
from typing import Any

import numpy as np
import polars as pl
import torch
from datasets import Dataset
from peft import AutoPeftModelForCausalLM, LoraConfig, prepare_model_for_kbit_training
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, set_seed
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
set_seed(1234)

#use bf16 and FlashAttention if supported
compute_dtype = torch.bfloat16
attn_implementation = "flash_attention_2"

## Configuration

In [3]:
SEQ_LENGTH = 1024*4
TRAINING_SEQ_LENGTH = 1024
MODEL_NAME = "Qwen/Qwen2.5-14B-Instruct"
DATA_ROOT = Path("/home/yoku/compe/eedi/input/eedi-mining-misconceptions-in-mathematics")
OUT_DIR_NAME = Path("./qwne25_14B")

In [4]:
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, quantization_config=bnb_config, device_map="auto", attn_implementation=attn_implementation, torch_dtype=torch.bfloat16
)

Loading checkpoint shards: 100%|██████████| 8/8 [00:03<00:00,  2.30it/s]


In [5]:
model = prepare_model_for_kbit_training(model, gradient_checkpointing_kwargs={"use_reentrant": True})
model.config.use_cache = False

peft_config = LoraConfig(
    lora_alpha=32,
    lora_dropout=0.05,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["o_proj", "k_proj", "q_proj", "down_proj", "gate_proj", "up_proj", "v_proj"],
)

In [6]:
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    add_eos_token=True,
    add_bos_token=True
    )

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "right"

## Data Prep

In [7]:
# Load the dataset prepared in `prepare_dataset.ipynb`
train_long = pl.read_parquet("./df_train.parquet")
synth_long = pl.read_parquet("./df_synth.parquet")

In [8]:
train_long = train_long.with_columns(
    MisconceptionId = train_long["MisconceptionId"].cast(pl.Int64),
)

In [9]:
# Feature Engineering: Information related to the problem's category
train_long = train_long.with_columns(
    pl.col("MisconceptionId_pred").str.split(" ").list.eval(pl.element().cast(pl.Int32())),
    knowledge = pl.concat_str(
        [pl.col("FirstSubjectName"), pl.col("SecondSubjectName"),pl.col("ThirdSubjectName"), pl.col("ConstructName")],
        separator=",",
    ),
)
synth_long = synth_long.with_columns(
    pl.col("MisconceptionId_pred").str.split(" ").list.eval(pl.element().cast(pl.Int32())),
    knowledge = pl.concat_str(
        [pl.col("FirstSubjectName"), pl.col("SecondSubjectName"),pl.col("ThirdSubjectName"), pl.col("ConstructName")],
        separator=",",
    ),
)


In [10]:
df_misconceptions = pl.read_csv(DATA_ROOT / "misconception_mapping.csv")
misconception_map = dict(df_misconceptions.rows())

## Make prompts

In [11]:
report_format = """# OBJECTIVE #
Based on the problem, please select which of the following misconceptions might have led to the student's incorrect answer.

Categories: {knowledge}
Question: {question_tex}
Correct Answer: {correct_answer_text}
Student's Answer: {answer_text}
Step-by-Step Solution: {cot_misunderstanding}

Possible Misconceptions:\n{misconceptions_text}

Please respond only with the letter(s) of the misconception(s) you think are most likely to be the cause."""  # noqa: E501


def make_prompt(row: dict[str, Any]) -> tuple[str, str]:
    # Randomly selects a candidate from the top 5 to top 25
    k = np.random.choice([5, 10, 15, 25])
    target_misconceptions = row["MisconceptionId_pred"]
    target_misconceptions = list(target_misconceptions[:k])

    # If the correct answer is not within the Top N options,
    # randomly replace one of the options with the correct answer.
    if row["MisconceptionId"] not in target_misconceptions:
        idx = np.random.randint(0, k)
        target_misconceptions[idx] = row["MisconceptionId"]

    np.random.shuffle(target_misconceptions)

    # Generate the initials
    letters = [chr(i+65) for i in range(k)]
    misconception_names = [misconception_map[_id] for _id in target_misconceptions]
    ans_pos = target_misconceptions.index(row["MisconceptionId"])
    ans_letter = letters[ans_pos]
    choice_str = "\n".join([f"{a}. {n}" for a, n in zip(letters, misconception_names, strict=False)])

    # Prepare a dictionary for validation
    choices = dict(zip(letters, target_misconceptions, strict=False))

    # Make prompt
    q = report_format.format(
        knowledge=row["knowledge"],
        question_tex=row["QuestionText"],
        construct=row["ConstructName"],
        correct_answer_text=row["CorrectAnswerText"],
        answer_text=row["AnswerText"],
        cot_misunderstanding=row["p000-qwen25-32b-instruct-cot_misunderstanding"],
        misconceptions_text=choice_str,
    )
    context = [
        {"role": "system", "content": "You are a mathematics teacher. Your job is to infer and identify the misconceptions behind the incorrect answers to the questions."},
        {
            "role": "user",
            "content": q,
        },
        {
            "role": "assistant",
            "content": ans_letter,
        },
    ]
    prompt: str = tokenizer.apply_chat_template(context, tokenize=False, add_generation_prompt=False)
    return prompt, json.dumps(choices)


In [12]:
# 初期seed
np.random.seed(0)
results = [make_prompt(row) for row in train_long.rows(named=True)]
train_long = train_long.with_columns(
    whole_prompt = pl.Series([ret[0] for ret in results]),
    choices = pl.Series([ret[1] for ret in results]),
)

# 別seed
np.random.seed(42)
results = [make_prompt(row) for row in train_long.rows(named=True)]
train_long2 = train_long.clone().with_columns(
    whole_prompt = pl.Series([ret[0] for ret in results]),
    choices = pl.Series([ret[1] for ret in results]),
)

In [13]:
results = [make_prompt(row) for row in synth_long.rows(named=True)]
synth_long = synth_long.with_columns(
    whole_prompt = pl.Series([ret[0] for ret in results]),
    choices = pl.Series([ret[1] for ret in results]),
)

In [14]:
train_long.write_parquet("./train_long.parquet")

In [15]:
print(train_long["whole_prompt"][0])

<|im_start|>system
You are a mathematics teacher. Your job is to infer and identify the misconceptions behind the incorrect answers to the questions.<|im_end|>
<|im_start|>user
# OBJECTIVE #
Based on the problem, please select which of the following misconceptions might have led to the student's incorrect answer.

Categories: Number,Basic Arithmetic,BIDMAS,Use the order of operations to carry out calculations involving powers
Question: \[
3 \times 2+4-5
\]
Where do the brackets need to go to make the answer equal \( 13 \) ?
Correct Answer: \( 3 \times(2+4)-5 \)
Student's Answer: Does not need brackets
Step-by-Step Solution: The students' misunderstanding lies in their lack of recognition of the importance of the order of operations, which is governed by the BIDMAS rule (Brackets, Indices, Division/Multiplication, Addition/Subtraction). BIDMAS helps determine the sequence in which arithmetic operations should be performed to ensure the correct result.

In the expression \( 3 \times 2 + 

In [16]:
df_all = pl.concat([
    train_long, synth_long, train_long2,
], how="diagonal_relaxed")

df_all.shape

(22906, 46)

### Data cleansing

In [17]:
# Remove prompts with excessively long token lengths
token_les = [tokenizer(t)["input_ids"] for t in tqdm(df_all["whole_prompt"])]
token_les = [len(t) for t in token_les]
df_all = df_all.with_columns(token_len=pl.Series(token_les))

100%|██████████| 22906/22906 [00:14<00:00, 1591.52it/s]


In [18]:
df_all = df_all.filter(pl.col("token_len") < TRAINING_SEQ_LENGTH)
df_all["token_len"].describe()

statistic,value
str,f64
"""count""",21921.0
"""null_count""",0.0
"""mean""",700.721454
"""std""",143.901548
"""min""",299.0
"""25%""",590.0
"""50%""",689.0
"""75%""",810.0
"""max""",1023.0


### Train the model

In [19]:
fold = 0
train_dataset = Dataset.from_polars(df_all.filter(df_all["fold"] != fold)).shuffle(seed=0)
valid_dataset = Dataset.from_polars(train_long.filter(train_long["fold"] == fold))

In [20]:
if "Gemma" in type(tokenizer).__name__:
    sep = "<start_of_turn>model\n"
else:
    sep = "<|im_start|>assistant\n"

In [21]:
# LLMの応答部分のスコアのみ計算
collator = DataCollatorForCompletionOnlyLM(response_template=sep, tokenizer=tokenizer, mlm=False)

output_dir = f"{OUT_DIR_NAME}_fold{fold}"

In [22]:
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset=valid_dataset,
    dataset_text_field = "whole_prompt",
    max_seq_length = SEQ_LENGTH,
    data_collator=collator,
    dataset_num_proc = 8,
    peft_config=peft_config,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 1,
        learning_rate = 3e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 10,
        save_steps=50,
        optim = "paged_adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 12,
        output_dir = output_dir,
        save_total_limit=1,
        remove_unused_columns=True,
    ),
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
Map (num_proc=8): 100%|██████████| 17551/17551 [00:04<00:00, 3622.61 examples/s]
Map (num_proc=8): 100%|██████████| 868/868 [00:01<00:00, 838.58 examples/s] 
  super().__init__(


In [23]:
# debug
if True:
    batch = next(iter(trainer.get_train_dataloader()))
    whole_text = tokenizer.decode(batch["input_ids"][0])
    print(f"{whole_text=}")

    y_label = batch["labels"][0]
    label_text = tokenizer.decode(y_label[y_label!=-100])
    print(f"{label_text=}")

gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

whole_text="<|im_start|>system\nYou are a mathematics teacher. Your job is to infer and identify the misconceptions behind the incorrect answers to the questions.<|im_end|>\n<|im_start|>user\n# OBJECTIVE #\nBased on the problem, please select which of the following misconceptions might have led to the student's incorrect answer.\n\nCategories: Geometry and Measure,Angles,Basic Angle Facts (straight line, opposite, around a point, etc),Use angles on a straight line to form and solve algebraic equations\nQuestion: What is the size of angle \\( a \\) ? ![Angles on a straight line split into 3 unequal parts labelled with a, 4a and a right angle marker]()\nCorrect Answer: \\( 18^{\\degree} \\)\nStudent's Answer: \\( 36^{\\degree} \\)\nStep-by-Step Solution: The misunderstanding likely stems from a confusion about how angles on a straight line sum up to \\(180^\\circ\\), and how to properly distribute this total among the given angles.\n\nHere’s a step-by-step explanation of the mistake:\n\n

In [24]:
# Start training
trainer_stats = trainer.train()

  0%|          | 0/1097 [00:00<?, ?it/s]The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.19 GiB. GPU 0 has a total capacity of 23.64 GiB of which 431.12 MiB is free. Including non-PyTorch memory, this process has 22.26 GiB memory in use. Of the allocated memory 20.01 GiB is allocated by PyTorch, and 1.80 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [25]:
del model, trainer

gc.collect()
torch.cuda.empty_cache()

## Merge the LoRA model with the base model

In [31]:
lora_path = glob(f"./{output_dir}/checkpoint-*")[0]
lora_path

'./qwne25_14B_fold0/checkpoint-1145'

In [32]:
# Load LoRA
model = AutoPeftModelForCausalLM.from_pretrained(lora_path, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(lora_path)

  0%|          | 0/1097 [03:23<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  7.17it/s]


In [33]:
model = model.to("cuda")
model = model.eval()

In [None]:
merged = model.merge_and_unload()

In [None]:
# Save merged model
out_path = f"./{output_dir}/merged_model"
merged.save_pretrained(out_path)
tokenizer.save_pretrained(out_path)

('./qwne25_14B_fold0/merged_model/tokenizer_config.json',
 './qwne25_14B_fold0/merged_model/special_tokens_map.json',
 './qwne25_14B_fold0/merged_model/tokenizer.json')