# Setup

In [1]:
%env N_EX=4

env: N_EX=4


# Stage 1: Retriever

Retrieve candidate misconceptions

## 1.1 Retriever Scripts

In [5]:
%%writefile eedi_llm_retriever.py

import sys

sys.path.insert(0, '/kaggle/input/eedi-utils-v04')

import os
import argparse
import gc
import json
import os
import warnings
from copy import deepcopy
from itertools import chain

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from accelerate import Accelerator

from llm_embedding.eedi_dataset import MathDataset
from llm_embedding.eedi_loader import TextCollator
from llm_embedding.eedi_model import BiEncoderModel

from utils.retriever_utils import semantic_search

from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch.distributed as dist

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoConfig, AutoModel, BitsAndBytesConfig

warnings.filterwarnings("ignore")


def print_line():
    """Print a separator made of 80 dashes to stdout."""
    separator = "--" * 40
    print(separator)

def query_formatting_func(query):
    """Prefix ``query`` with the misconception-retrieval task instruction.

    Returns a string of the form ``"Instruct: <task>\\nQuery: <query>"``.
    """
    instruction = "Retrieve the key misconception behind the wrong answer when given a math problem and its incorrect and correct solutions."
    return "Instruct: {}\nQuery: {}".format(instruction, query)


def get_base_model(cfg):
    """Load the backbone at ``cfg.model.backbone_path`` as a 4-bit quantized ``AutoModel``.

    Args:
        cfg: config object exposing ``model.backbone_path`` and
            ``model.attn_implementation``.

    Returns:
        The loaded model, with ``use_cache`` disabled on its config and
        ``config.pretraining_tp`` set to 1.
    """
    backbone_path = cfg.model.backbone_path

    # NF4 4-bit weights with double quantization; compute dtype is fp16.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    hf_config = AutoConfig.from_pretrained(backbone_path, trust_remote_code=False)
    hf_config.use_cache = False

    backbone = AutoModel.from_pretrained(
        backbone_path,
        config=hf_config,
        quantization_config=quantization,
        attn_implementation=cfg.model.attn_implementation,
        low_cpu_mem_usage=True,
    )

    backbone.config.pretraining_tp = 1
    return backbone
    
# ---
def show_batch(batch, tokenizer, n_examples=4, print_fn=print):
    """Print a short, decoded preview of a tokenized batch.

    Args:
        batch: mapping with an ``"input_ids"`` tensor whose first dimension is
            the batch size.
        tokenizer: object with a ``decode(ids, skip_special_tokens=...)`` method.
        n_examples: maximum number of rows to decode; capped at the batch size
            so a short (e.g. final or tiny debug) batch does not raise IndexError.
        print_fn: callable used for *all* output, so callers can pass a
            main-process-only printer and get no stray lines from other ranks.
    """
    input_ids = batch["input_ids"]
    bs = input_ids.size(0)
    print_fn(f"batch size: {bs}")

    print_fn(f"shape of input_ids: {input_ids.shape}")

    separator = "--" * 80
    print_fn(separator)
    for idx in range(min(n_examples, bs)):
        print_fn(f"[Text]:\n{tokenizer.decode(input_ids[idx], skip_special_tokens=False)}")
        print_fn(separator)
    
def eedi_process_df(df):
    """Expand each question into one query row per incorrect answer option.

    The first row of every ``QuestionId`` group is used as that question's
    record. For each option in A-D other than ``CorrectAnswer`` a row is
    emitted with ``query_id`` = ``"<QuestionId>_<option>"``, the subject,
    construct and question text, and the correct / incorrect answer texts.

    Returns:
        A DataFrame with one row per (question, incorrect option).
    """
    frame = deepcopy(df)

    # One record per question: first row of each QuestionId group.
    records = {}
    for question_id, rows in frame.groupby("QuestionId"):
        record = rows.to_dict(orient="records")[0]
        record.pop("QuestionId")
        records[question_id] = record

    context_cols = ("SubjectName", "ConstructName", "QuestionText")
    queries = []
    for question_id, record in records.items():
        for option in ("A", "B", "C", "D"):
            if option == record["CorrectAnswer"]:
                continue

            entry = {"query_id": f"{question_id}_{option}"}
            for col in context_cols:
                entry[col] = record[col]
            entry["CorrectAnswerText"] = record[f"Answer{record['CorrectAnswer']}Text"]
            entry["InCorrectAnswerText"] = record[f"Answer{option}Text"]
            queries.append(entry)

    return pd.DataFrame(queries)


def _encode_dataloader(accelerator, model, dataloader):
    """Encode every batch of ``dataloader`` and return all gathered embeddings as one tensor."""
    embeddings = []
    progress_bar = tqdm(range(len(dataloader)))

    for batch in dataloader:
        with torch.no_grad():
            batch_embeddings = accelerator.unwrap_model(model).encode(batch)
        # Collect the per-process outputs so each process holds the full set.
        batch_embeddings = accelerator.gather_for_metrics(batch_embeddings)
        embeddings.append(batch_embeddings)
        progress_bar.update(1)
    progress_bar.close()

    return torch.cat(embeddings, dim=0)


def add_retrieved_results(cfg, infer_df, content_df):
    """
    Find the top-k most similar content rows for each query derived from ``infer_df``.

    Args:
        cfg: config exposing ``predict_params.query_bs``,
            ``predict_params.content_bs`` and ``model.n_neighbour`` (plus
            whatever ``MathDataset``, ``BiEncoderModel`` and ``get_base_model`` read).
        infer_df: question-level frame, expanded to one query per incorrect
            answer by ``eedi_process_df``.
        content_df: candidate frame; must yield a ``content_id`` column in the
            dataset built from it.

    Returns:
        DataFrame with columns ``QuestionId_Answer`` (query id),
        ``MisconceptionId`` (ranked list of content ids as strings) and
        ``pred_scores`` (matching list of scores).
    """
    # accelerator -----------------------------------------------------------------------#
    accelerator = Accelerator()

    query_df = eedi_process_df(infer_df)

    # queries ---------------------------------------------------------------------------#
    ds_handle = MathDataset(cfg, query_formatting_func=query_formatting_func)
    tokenizer = ds_handle.tokenizer

    query_ds = ds_handle.get_dataset(query_df, is_query=True)
    content_ds = ds_handle.get_dataset(content_df, is_query=False)

    query_ds = query_ds.sort("input_length")
    content_ds = content_ds.sort("input_length")

    # Ids are read *after* sorting so they line up with the (unshuffled)
    # order in which embeddings are produced below.
    query_ids = query_ds["query_id"]
    content_ids = content_ds["content_id"]

    collator = TextCollator(tokenizer=tokenizer)

    query_dl = DataLoader(
        query_ds,
        batch_size=cfg.predict_params.query_bs,
        shuffle=False,
        collate_fn=collator,
    )

    content_dl = DataLoader(
        content_ds,
        batch_size=cfg.predict_params.content_bs,
        shuffle=False,
        collate_fn=collator,
    )

    # show ---
    accelerator.print("Showing a batch (Query)...")
    for b in query_dl:
        show_batch(b, tokenizer, print_fn=accelerator.print)
        break

    accelerator.print("Showing a batch (Content)...")
    for b in content_dl:
        show_batch(b, tokenizer, print_fn=accelerator.print)
        break

    # model -----------------------------------------------------------------------------#
    base_model = get_base_model(cfg)
    model = BiEncoderModel(cfg, base_model, accelerator)

    # prepare ---------------------------------------------------------------------------#
    model, query_dl, content_dl = accelerator.prepare(model, query_dl, content_dl)

    # query embeddings ------------------------------------------------------------------#
    query_embeddings = _encode_dataloader(accelerator, model, query_dl)
    accelerator.print(f"shape of query embeddings: {query_embeddings.shape}")
    assert query_embeddings.shape[0] == len(query_ids)

    # content embeddings ----------------------------------------------------------------#
    content_embeddings = _encode_dataloader(accelerator, model, content_dl)
    accelerator.print(f"shape of content embeddings: {content_embeddings.shape}")
    assert content_embeddings.shape[0] == len(content_ids)

    # top-k search ----------------------------------------------------------------------#
    results = semantic_search(query_embeddings, content_embeddings, top_k = cfg.model.n_neighbour)

    # Each result is a list of hits carrying a position into the content
    # embeddings ("corpus_id") and a similarity "score".
    pred_content_ids, pred_scores = [], []
    for re_i in results:
        hit_i = [node["corpus_id"] for node in re_i]
        top_scores_i = [node["score"] for node in re_i]
        top_content_ids_i = [content_ids[pos] for pos in hit_i]
        pred_content_ids.append(top_content_ids_i)
        pred_scores.append(top_scores_i)

    result_df = pd.DataFrame()
    result_df["query_id"] = query_ids
    result_df["pred_ids"] = pred_content_ids
    result_df["pred_scores"] = pred_scores

    # get oof df
    oof_df = result_df.copy()
    oof_df = oof_df.rename(columns={"query_id": "QuestionId_Answer"})
    oof_df = oof_df.rename(columns={"pred_ids": "MisconceptionId"})
    oof_df["MisconceptionId"] = oof_df["MisconceptionId"].apply(lambda x: list(map(str, x)))

    # Separators go through accelerator.print like the lines they frame, so
    # non-main processes do not emit orphan dashed lines.
    accelerator.print("--" * 40)
    accelerator.print("Sample Prediction:")
    accelerator.print(oof_df.sample().T)
    accelerator.print("--" * 40)

    return oof_df

# ------


def execute_inference(cfg, save_dir, model_name):
    """Produce retriever predictions and save them to ``<save_dir>/<model_name>.parquet``.

    Retrieval is actually run when ``KAGGLE_IS_COMPETITION_RERUN`` is set or
    ``cfg.run_on_save`` is truthy; otherwise the previously written
    ``./retriever_outputs/intfloat.parquet`` is re-saved under the new name.

    Outside a competition rerun, the first ``N_EX`` rows of ``train.csv`` stand
    in for the test set.

    Raises:
        RuntimeError: if retrieval has to run outside a competition rerun and
            the ``N_EX`` environment variable is not set.
    """
    data_dir = "/kaggle/input/eedi-mining-misconceptions-in-mathematics"
    is_rerun = bool(os.getenv('KAGGLE_IS_COMPETITION_RERUN'))

    if is_rerun or cfg.run_on_save:
        # Only read the CSVs when they are going to be used.
        if is_rerun:
            test_df = pd.read_csv(f"{data_dir}/test.csv")
        else:
            n_ex = os.getenv("N_EX")
            if n_ex is None:
                raise RuntimeError(
                    "N_EX environment variable must be set (number of train.csv rows "
                    "to use) when not running as a competition rerun"
                )
            test_df = pd.read_csv(f"{data_dir}/train.csv").head(int(n_ex))

        content_df = pd.read_csv(f"{data_dir}/misconception_mapping.csv")
        content_df = content_df.rename(columns={"MisconceptionId": "content_id"})

        test_df = add_retrieved_results(cfg, test_df, content_df)
    else:
        # Shortcut for non-rerun executions: reuse an earlier retriever output
        # instead of running this model.
        test_df = pd.read_parquet("./retriever_outputs/intfloat.parquet")

    save_path = os.path.join(save_dir, f"{model_name}.parquet")
    # NOTE(review): when launched with several processes, every process reaches
    # this write for the same path — confirm that is acceptable.
    test_df.to_parquet(save_path)

    if dist.is_initialized():
        dist.destroy_process_group()

if __name__ == "__main__":
    # CLI: all three options are mandatory string arguments.
    parser = argparse.ArgumentParser()
    for flag in ("--config_path", "--save_dir", "--model_name"):
        parser.add_argument(flag, type=str, required=True)
    cli_args = parser.parse_args()

    config = OmegaConf.load(cli_args.config_path)

    # Make sure the output directory exists before inference writes into it.
    os.makedirs(cli_args.save_dir, exist_ok=True)

    execute_inference(config, save_dir=cli_args.save_dir, model_name=cli_args.model_name)

Overwriting eedi_llm_retriever.py


## 1.2 Retriever Configs

In [None]:
%%writefile eedi_retriever_intfloat.yaml
run_on_save: true

model:
    backbone_path: /kaggle/input/eedi-embed-intfloat-cv476-ff-4bit
    max_length: 768
    sentence_pooling_method: last
    attn_implementation: eager
    negatives_cross_device: false
    add_eos_token: true
    padding_side: left
    trust_remote_code: false

    n_neighbour: 128

predict_params:
    query_bs: 8
    content_bs: 8

train_params: # FIXME: inference should not need these training-only settings
    sub_batch_size: 8 
    num_hard_negatives: 0

In [None]:
%%writefile eedi_retriever_qwen.yaml
run_on_save: false
model:
    backbone_path: /kaggle/input/eedi-embed-qwen14b-cv486-ff-4bit
    max_length: 768
    sentence_pooling_method: last
    attn_implementation: eager
    negatives_cross_device: false
    add_eos_token: true
    padding_side: left
    trust_remote_code: false

    n_neighbour: 128

predict_params:
    query_bs: 8
    content_bs: 8

train_params:
    sub_batch_size: 8 # FIXME: inference should not need this training-only setting
    num_hard_negatives: 0

## 1.3 Retriever Inference

In [None]:
%%time
!accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=2 eedi_llm_retriever.py \
--config_path ./eedi_retriever_intfloat.yaml \
--save_dir ./retriever_outputs \
--model_name intfloat

In [None]:
%%time
!accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=2 eedi_llm_retriever.py \
--config_path ./eedi_retriever_qwen.yaml \
--save_dir ./retriever_outputs \
--model_name qwen_14b

## 1.4 Prepare for Ranker
- Prepares input for re-ranking ("./retriever_outputs/ranker_input.parquet")
- Prepares blended predictions from retrievers ("./retriever_outputs/stage_one_blended.parquet")

In [None]:
%%writefile prepare_for_ranker.py

import argparse
import os
from copy import deepcopy

import pandas as pd
from omegaconf import OmegaConf


def process_df(df):
    """Build one ranker query row per (question, incorrect answer option).

    The first row of each ``QuestionId`` group is taken as the question record.
    Every option in A-D except ``CorrectAnswer`` yields a row keyed by
    ``QuestionId_Answer`` = ``"<QuestionId>_<option>"`` that carries the
    subject, construct, question text, correct / incorrect answer texts and
    ``AllOptionText`` (all four option texts as a ``"\\n- "`` bulleted string).
    """
    frame = deepcopy(df)

    records = {}
    for question_id, rows in frame.groupby("QuestionId"):
        record = rows.to_dict(orient="records")[0]
        record.pop("QuestionId")
        records[question_id] = record

    option_keys = ["A", "B", "C", "D"]
    context_cols = ("SubjectName", "ConstructName", "QuestionText")

    queries = []
    for question_id, record in records.items():
        for option in option_keys:
            if option == record["CorrectAnswer"]:
                continue

            entry = {"QuestionId_Answer": f"{question_id}_{option}"}
            for col in context_cols:
                entry[col] = record[col]

            entry["CorrectAnswerText"] = record[f"Answer{record['CorrectAnswer']}Text"]
            entry["InCorrectAnswerText"] = record[f"Answer{option}Text"]
            # Every option becomes its own "\n- " bullet, including the first.
            bullet_body = "\n- ".join([record[f"Answer{x}Text"] for x in option_keys])
            entry["AllOptionText"] = "\n- " + bullet_body
            queries.append(entry)

    return pd.DataFrame(queries)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # argparse exposes "--config-path" as args.config_path. It is required so a
    # missing flag fails with a usage message instead of trying to open None.
    parser.add_argument("--config-path", type=str, required=True)
    args = parser.parse_args()

    cfg = OmegaConf.load(args.config_path)

    os.makedirs(cfg.output_dir, exist_ok=True)

    # load data ---
    data_dir = "/kaggle/input/eedi-mining-misconceptions-in-mathematics"
    if os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
        test_df = pd.read_csv(f"{data_dir}/test.csv")
    else:
        # Outside a competition rerun, the first N_EX training rows stand in
        # for the test set.
        n_ex = os.getenv("N_EX")
        if n_ex is None:
            raise RuntimeError(
                "N_EX environment variable must be set (number of train.csv rows "
                "to use) when not running as a competition rerun"
            )
        test_df = pd.read_csv(f"{data_dir}/train.csv").head(int(n_ex))
    concept_df = pd.read_csv(f"{data_dir}/misconception_mapping.csv")
    concept_df["MisconceptionId"] = concept_df["MisconceptionId"].astype(str)

    # load retriever outputs ---
    # Each file holds, per query, parallel lists of candidate ids and scores;
    # explode them into one row per (query, candidate).
    flat_dfs = []
    for fp in cfg.retriever_outputs:
        df = pd.read_parquet(fp)
        print(f"Loading from {fp}")
        print(df.head(1).T)
        print("--"*40)
        flat_df = df.explode(["MisconceptionId", "pred_scores"]).reset_index(drop=True)
        flat_dfs.append(flat_df)

    # blend retriever scores ---
    # Sum over retrievers and divide by the number of retrievers, so a
    # candidate missing from one retriever's list contributes 0 for it.
    ret_preds = pd.concat(flat_dfs, ignore_index=True)
    ret_preds = ret_preds.groupby(["QuestionId_Answer", "MisconceptionId"])["pred_scores"].agg("sum").reset_index()
    ret_preds["pred_scores"] = ret_preds["pred_scores"] / len(cfg.retriever_outputs)

    # re-sort by pred_scores
    grouped = ret_preds.groupby("QuestionId_Answer")

    results = []
    for question_id, group in grouped:
        sorted_group = group.sort_values("pred_scores", ascending=False)
        result = {"QuestionId_Answer": question_id, "MisconceptionId": list(sorted_group["MisconceptionId"]), "pred_scores": list(sorted_group["pred_scores"])}
        results.append(result)
    ret_preds = pd.DataFrame(results)

    # save for later use ---
    print("--" * 40)
    # "blned_file_name" (sic) is the key name used in the YAML config.
    save_path = os.path.join(cfg.output_dir, cfg.blned_file_name)
    print(f"Saving stage one blended predictions to {save_path}")
    ret_preds.to_parquet(save_path)
    print(ret_preds.sample().T)
    print("--" * 40)

    # prepare candidate set ---
    # Keep candidates scoring within `margin` of the best one, with the count
    # clipped to [min_top_k, max_top_k].
    margin = cfg.margin
    ret_preds["threshold"] = ret_preds["pred_scores"].apply(lambda x: x[0] - margin)
    ret_preds["cutoff"] = ret_preds.apply(lambda x: sum([y > x["threshold"] for y in x["pred_scores"]]), axis=1)
    ret_preds["cutoff"] = ret_preds["cutoff"].clip(lower=cfg.min_top_k, upper=cfg.max_top_k)

    print("--" * 40)
    print("Cutoff distribution:\n-------")
    print(ret_preds["cutoff"].value_counts().sort_index())
    print(f"Average # candidates: {ret_preds['cutoff'].mean()}")
    print("--" * 40)

    ret_preds["MisconceptionId"] = ret_preds.apply(lambda row: row["MisconceptionId"][:row["cutoff"]], axis=1)
    ret_preds = ret_preds[["QuestionId_Answer", "MisconceptionId"]].copy()
    flat_df = ret_preds.explode(["MisconceptionId"]).reset_index(drop=True)

    # prepare ranker input ---
    rank_df = process_df(test_df)
    rank_df = rank_df.merge(flat_df, on="QuestionId_Answer", how="left")
    rank_df = rank_df.merge(concept_df, on="MisconceptionId", how="left")
    rank_df["MisconceptionId"] = rank_df["MisconceptionId"].astype(str)
    rank_df = rank_df[
        ["QuestionId_Answer", "MisconceptionId", "SubjectName", "ConstructName", "QuestionText", "CorrectAnswerText", "InCorrectAnswerText", "AllOptionText", "MisconceptionName"]
    ].copy()

    print("--" * 40)
    save_path = os.path.join(cfg.output_dir, cfg.ranker_input_file_name)
    print(f"Saving re-ranker input to: {save_path}")
    rank_df.to_parquet(save_path)
    print("--"*40)

    print(rank_df.sample().T)
    print(f"shape of ranker input: {rank_df.shape}")
    print("--"*40)

In [None]:
%%writefile ranker_prep.yaml

retriever_outputs:
    - ./retriever_outputs/intfloat.parquet
    - ./retriever_outputs/qwen_14b.parquet

min_top_k: 32
max_top_k: 64 #64
margin: 0.05 # ave n -> 38.2, recall -> 96.5

output_dir: ./retriever_outputs
blned_file_name: stage_one_blended.parquet
ranker_input_file_name: ranker_input_stage_one.parquet

In [None]:
!python prepare_for_ranker.py --config-path ranker_prep.yaml

# COT

In [None]:
%%time
!pip uninstall -y torch

!pip install -q --no-index --find-links=/kaggle/input/wheels-vllm-0-6-3-post1 torchvision==0.19.1
!pip install -q --no-index --find-links=/kaggle/input/wheels-vllm-0-6-3-post1 vllm

!pip install -q -U --upgrade /kaggle/input/vllm-t4-fix/grpcio-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
!pip install -q -U --upgrade /kaggle/input/vllm-t4-fix/ray-2.11.0-cp310-cp310-manylinux2014_x86_64.whl

!pip uninstall -y pynvml
!pip install --no-deps --no-index /kaggle/input/0-6-3-post1-wheels-vllm/nvidia_ml_py-12.560.30-py3-none-any.whl

## CoT Script

In [None]:
%%writefile run_gen_cot.py

import argparse
import os
from copy import deepcopy

import pandas as pd
import vllm
print('vllm version=', vllm.__version__)

from datasets import Dataset
from omegaconf import OmegaConf
from transformers import AutoTokenizer

# Instruction placed at the start of every CoT generation prompt built in main().
sp = "Analyze the incorrect answer to detect flaws in the student's reasoning."


def get_tokenizer(backbone_path):
    """Load the tokenizer at ``backbone_path`` and normalise its special tokens.

    - An empty-string EOS token is replaced by ``"</s>"``.
    - A missing pad token falls back to the unk token if one exists, else to EOS.
    - The BOS token is set to ``"<|im_start|>"`` and padding is left-sided.
    """
    tok = AutoTokenizer.from_pretrained(backbone_path, add_eos_token=True)

    if tok.eos_token == "":
        tok.add_special_tokens({"eos_token": "</s>"})
        tok.eos_token = "</s>"

    if tok.pad_token is None:
        # Prefer the unk token as padding; fall back to EOS only when there is none.
        if tok.unk_token is not None:
            fallback_token, fallback_id = tok.unk_token, tok.unk_token_id
        else:
            fallback_token, fallback_id = tok.eos_token, tok.eos_token_id
        tok.pad_token = fallback_token
        tok.pad_token_id = fallback_id

    tok.bos_token = "<|im_start|>"
    tok.padding_side = "left"
    return tok


def is_nan(x):
    """Return the result of ``x != x``, which is True for a float NaN."""
    return x != x


def eedi_process_df(df):
    """Expand each question into one row per incorrect answer option.

    Expects the question id in a ``query_id`` column. The first row of each
    group is the question record; every option in A-D other than
    ``CorrectAnswer`` yields a row whose ``query_id`` is
    ``"<question id>_<option>"`` together with the subject, construct,
    question text and the correct / incorrect answer texts.
    """
    frame = deepcopy(df)

    records = {}
    for question_id, rows in frame.groupby("query_id"):
        record = rows.to_dict(orient="records")[0]
        record.pop("query_id")
        records[question_id] = record

    context_cols = ("SubjectName", "ConstructName", "QuestionText")

    queries = []
    for question_id, record in records.items():
        for option in ("A", "B", "C", "D"):
            if option == record["CorrectAnswer"]:
                continue

            entry = {"query_id": f"{question_id}_{option}"}
            for col in context_cols:
                entry[col] = record[col]
            entry["CorrectAnswerText"] = record[f"Answer{record['CorrectAnswer']}Text"]
            entry["InCorrectAnswerText"] = record[f"Answer{option}Text"]
            queries.append(entry)

    return pd.DataFrame(queries)


def main(cfg, save_dir, model_id):
    """Generate a chain-of-thought analysis for every (question, incorrect answer) pair.

    Loads the questions (``test.csv`` in a competition rerun, otherwise the
    first ``N_EX`` rows of ``train.csv``), builds one prompt per incorrect
    answer, generates with vLLM in chunks and writes the results to
    ``<save_dir>/gen_<model_id>.parquet`` with columns ``query_id``,
    ``prompt`` and ``cot``.

    Args:
        cfg: config exposing ``model.backbone_path``, ``model.quantization``
            and ``max_new_tokens``.
        save_dir: existing directory to write the parquet file into.
        model_id: suffix used in the output file name.

    Raises:
        RuntimeError: if ``N_EX`` is unset outside a competition rerun.
    """
    # load data ---
    data_dir = "/kaggle/input/eedi-mining-misconceptions-in-mathematics"
    if os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
        test_df = pd.read_csv(f"{data_dir}/test.csv")
    else:
        n_ex = os.getenv("N_EX")
        if n_ex is None:
            raise RuntimeError(
                "N_EX environment variable must be set (number of train.csv rows "
                "to use) when not running as a competition rerun"
            )
        test_df = pd.read_csv(f"{data_dir}/train.csv").head(int(n_ex))

    test_df = test_df.rename(columns={"QuestionId": "query_id"})
    test_df = eedi_process_df(test_df)

    ds = Dataset.from_pandas(test_df)
    print(f"Number of examples: {len(ds)}")
    query_ids = ds["query_id"]

    print("==" * 50)
    print(f"Generating for model: {cfg.model.backbone_path}")
    print("==" * 50)

    llm = vllm.LLM(
        cfg.model.backbone_path,
        tensor_parallel_size=2,
        quantization=cfg.model.quantization,
        gpu_memory_utilization=0.99,
        trust_remote_code=True,
        dtype="half",
        enforce_eager=True,
        max_model_len=1296,
        enable_prefix_caching=True
    )

    # One plain-text prompt per query, in the same order as query_ids.
    prompts = []
    for example in ds:
        question = example["QuestionText"]
        correct_answer = example["CorrectAnswerText"]
        incorrect_answer = example["InCorrectAnswerText"]

        user_message = f"Question: {question}\nCorrect Answer: {correct_answer}\nIncorrect Answer: {incorrect_answer}"
        text = f"{sp}\n\nQuery: {user_message}\nAnswer:\n"
        prompts.append(text)

    for p in prompts[:5]:
        print(p)
        print("-" * 100)

    # NOTE: temperature 0.7 with no seed passed here, so generations are sampled.
    sampling_params = vllm.SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.0, max_tokens=cfg.max_new_tokens)

    # do in chunks ---
    response_dfs = []

    chunk_size = 256
    n_chunks = (len(prompts) - 1) // chunk_size + 1
    for i in range(0, len(prompts), chunk_size):
        chunk_prompts = prompts[i : i + chunk_size]
        chunk_query_ids = query_ids[i : i + chunk_size]

        print(f"Processing chunk {i//chunk_size + 1} of {n_chunks}")

        generations = llm.generate(chunk_prompts, sampling_params=sampling_params)

        # Pair each query id with the prompt echoed back and the first completion.
        df = pd.DataFrame()
        df["query_id"] = chunk_query_ids
        df["prompt"] = [output.prompt for output in generations]
        df["cot"] = [output.outputs[0].text for output in generations]
        response_dfs.append(df)

    result_df = pd.concat(response_dfs).reset_index(drop=True)

    save_path = os.path.join(save_dir, f"gen_{model_id}.parquet")
    result_df.to_parquet(save_path)
    # ---
    n = min(5, len(result_df))
    samples = result_df.sample(n)['cot'].values.tolist()
    for samp in samples:
        print(samp)
        print('--'*50)


if __name__ == "__main__":
    # CLI: all three options are mandatory string arguments.
    parser = argparse.ArgumentParser()
    for flag in ("--config_path", "--save_dir", "--model_id"):
        parser.add_argument(flag, type=str, required=True)
    args = parser.parse_args()

    cfg = OmegaConf.load(args.config_path)

    os.makedirs(args.save_dir, exist_ok=True)

    # execution ---
    # Generate in a competition rerun or when the config asks for it;
    # otherwise re-save the 7b generations under this model's file name.
    if os.getenv('KAGGLE_IS_COMPETITION_RERUN') or cfg.run_on_save:
        main(cfg, save_dir=args.save_dir, model_id=args.model_id)
    else:
        cached_df = pd.read_parquet("./gen/gen_qwen_7b.parquet")
        target_path = os.path.join(args.save_dir, f"gen_{args.model_id}.parquet")
        cached_df.to_parquet(target_path)

In [None]:
%%writefile conf_llm_gen_7b.yaml
run_on_save: true

max_new_tokens: 256 # 384
stage_one_path: retriever_outputs/stage_one_blended.parquet

model:
    backbone_path: "/kaggle/input/eedi-cot-7b-base-dec4/transformers/default/1"
    max_length: 768
    num_proc: 2
    quantization:
    
    tokenizer:
        padding_side: left
        truncation_side: left
        use_fast: true

In [None]:
%%writefile conf_llm_gen_14b.yaml
run_on_save: false

max_new_tokens: 256 # 384
stage_one_path: retriever_outputs/stage_one_blended.parquet

model:
    backbone_path: "/kaggle/input/eedi-cot-14b-dec6-awq/transformers/default/1"
    max_length: 768
    num_proc: 2
    quantization: awq
     
    tokenizer:
        padding_side: left
        truncation_side: left
        use_fast: true

In [None]:
%%writefile conf_llm_gen_32b.yaml
run_on_save: false

max_new_tokens: 256
stage_one_path: retriever_outputs/stage_one_blended.parquet

model:
    backbone_path: "/kaggle/input/eedi-cot-32b-dec6-awq/transformers/default/1"
    max_length: 768
    num_proc: 2
    quantization: awq
    
    tokenizer:
        padding_side: left
        truncation_side: left
        use_fast: true

In [None]:
%%time
!python run_gen_cot.py --config_path conf_llm_gen_7b.yaml --save_dir ./gen --model_id qwen_7b

In [None]:
%%time
!python run_gen_cot.py --config_path conf_llm_gen_14b.yaml --save_dir ./gen --model_id qwen_14b

In [None]:
%%time
!python run_gen_cot.py --config_path conf_llm_gen_32b.yaml --save_dir ./gen --model_id qwen_32b

In [None]:
# for 256 inputs
# 7b gen time  : ~2 mins
# 14b gen time : 
# 32b gen time : ~12mins

# Stage 2: Re-Ranker (Haiku)

## 2.1 Scripts

In [None]:
%%writefile run_haiku.py

import sys
sys.path.insert(0, "/kaggle/input/eedi-utils-v11")

import argparse
import gc
import os
import vllm

import random
from collections import defaultdict

import numpy as np
import pandas as pd
import torch

from llm_oracle.ranker_dataset import RankerDataset
from omegaconf import OmegaConf


from copy import deepcopy

from datasets import Dataset
from transformers import AutoTokenizer


def is_nan(x):
    """Return the result of ``x != x``, which is True for a float NaN."""
    return x != x
    
def stable_softmax(x, temp=1.0):
    """Softmax of ``x / temp`` computed after subtracting the maximum.

    Subtracting the largest scaled value before exponentiating keeps the
    exponent arguments at or below zero.
    """
    scaled = np.array(x) / temp
    exp_shifted = np.exp(scaled - np.max(scaled))
    total = np.sum(exp_shifted)
    return exp_shifted / total
    
def eedi_process_df(df):
    """Build labelled (question, incorrect answer) queries and their misconception links.

    ``QuestionId`` is renamed to ``query_id`` and the first row of each group
    is used as the question record. For every option in A-D other than
    ``CorrectAnswer`` whose ``Misconception<option>Id`` is not NaN, a query
    row is emitted and the misconception id (as a string) is linked to it.

    Returns:
        (query_df, corr_df, content2query) where ``query_df`` has one row per
        labelled query, ``corr_df`` maps each ``query_id`` to its
        ``content_id``, and ``content2query`` maps a misconception id to the
        list of query ids carrying it.
    """
    frame = df.copy()
    frame = frame.rename(columns={"QuestionId": "query_id"})

    records = {}
    for question_id, rows in frame.groupby("query_id"):
        record = rows.to_dict(orient="records")[0]
        record.pop("query_id")
        records[question_id] = record

    option_keys = ["A", "B", "C", "D"]
    context_cols = ["SubjectId", "SubjectName", "ConstructName", "QuestionText"]

    queries = []
    query2content = defaultdict(list)
    content2query = defaultdict(list)

    for question_id, record in records.items():
        for option in option_keys:
            if option == record["CorrectAnswer"]:
                continue

            # Options without a labelled misconception are skipped entirely.
            raw_mid = record[f"Misconception{option}Id"]
            if is_nan(raw_mid):
                continue

            query_key = f"{question_id}_{option}"
            mid = str(int(raw_mid))
            query2content[query_key].append(mid)
            content2query[mid].append(query_key)

            entry = {"query_id": query_key}
            for col in context_cols:
                entry[col] = record[col]

            entry["CorrectAnswerText"] = record[f"Answer{record['CorrectAnswer']}Text"]
            entry["InCorrectAnswerText"] = record[f"Answer{option}Text"]
            bullet_body = "\n- ".join([record[f"Answer{x}Text"] for x in option_keys])
            entry["AllOptionText"] = "\n- " + bullet_body
            queries.append(entry)

    query_df = pd.DataFrame(queries)

    corr_df = pd.Series(query2content).reset_index()
    corr_df = corr_df.rename(columns={"index": "query_id", 0: "content_id"})
    # Each query has exactly one appended id; unwrap the single-element list.
    corr_df["content_id"] = corr_df["content_id"].apply(lambda x: x[0])

    query_df = query_df.reset_index(drop=True)

    return query_df, corr_df, content2query

def sort_by_scores(pred_ids, scores):
    """Reorder ``pred_ids`` and ``scores`` together by descending score.

    Returns:
        dict with ``"sorted_ids"`` and ``"sorted_scores"`` lists.
    """
    # argsort on the negated scores gives indices from highest to lowest score.
    order = np.argsort(-np.array(scores)).tolist()
    return {
        "sorted_ids": [pred_ids[pos] for pos in order],
        "sorted_scores": [scores[pos] for pos in order],
    }


def format_example(row):
    """Render a row as a three-line demo: question, correct answer, misconception answer."""
    lines = (
        f"Question: {row['QuestionText']}",
        f"Answer:{row['CorrectAnswerText']}",
        f"Misconception Answer: {row['InCorrectAnswerText']}",
    )
    return "\n".join(lines)


def add_fs_examples(df, content2query, query2example, rng, n=2):
    """Attach few-shot demonstrations to each (query, candidate misconception) row.

    For every row, up to ``n`` example queries labelled with the row's
    ``content_id`` are sampled with ``rng`` and joined with ``"\\n--\\n"`` into
    a new ``examples`` column. A row never receives its own query as an example.

    Args:
        df: frame with ``query_id`` and ``content_id`` columns; modified in
            place and returned.
        content2query: mapping ``content_id -> list of query ids``.
        query2example: mapping ``query id -> formatted example text``.
        rng: object with a ``sample(population, k=...)`` method
            (e.g. ``random.Random``).
        n: maximum number of examples per row.

    Returns:
        ``df`` with the added ``examples`` column.
    """
    # Shared examples per content_id, reused by every row that cannot collide
    # with the example pool.
    cache = {}

    def _sample_examples(pool):
        if len(pool) == 0:
            return ""
        picked = rng.sample(pool, k=min(n, len(pool)))
        return "\n--\n".join([query2example[qid] for qid in picked])

    def _add_examples(row):
        cid = row["content_id"]
        own_qid = row["query_id"]
        # .get() avoids inserting empty entries when content2query is a defaultdict.
        candidates = content2query.get(cid, [])

        if own_qid in candidates:
            # This row's query is itself a labelled example for cid. A cached
            # entry built for another row could contain it, and an entry built
            # here would wrongly exclude it for others, so bypass the cache.
            return _sample_examples([qid for qid in candidates if qid != own_qid])

        if cid not in cache:
            cache[cid] = _sample_examples(list(candidates))
        return cache[cid]

    df["examples"] = df.apply(_add_examples, axis=1)
    return df


def main(cfg, save_dir, model_id):
    """Score every (query, misconception) pair with the vLLM ranker and save per-query rankings.

    Args:
        cfg: Config object. This function reads `input_path`, `icl_path`, `seed`,
            `k_shot`, `use_cot`, `cot_path` and `model.backbone_path`, and hands
            the whole object to `RankerDataset`.
        save_dir: Directory in which `ranker_{model_id}.parquet` is written.
        model_id: Tag used in the output file name.
    """
    divider = "--" * 40

    # candidate pairs ---
    test_df = pd.read_parquet(cfg.input_path).rename(
        columns={"QuestionId_Answer": "query_id", "MisconceptionId": "content_id"}
    )
    test_df["content_id"] = test_df["content_id"].astype(str)

    # demonstrations built from the competition data ---
    rng = random.Random(cfg.seed)
    comp_df = pd.read_csv(cfg.icl_path).rename(columns={"QuestionId": "query_id"})
    query_df, _, content2query = eedi_process_df(comp_df)
    query_df["demo"] = query_df.apply(format_example, axis=1)
    query2example = {qid: demo for qid, demo in zip(query_df["query_id"], query_df["demo"])}

    # few-shot examples per candidate; rows are then ordered by content id
    test_df = add_fs_examples(test_df, content2query, query2example, rng, n=cfg.k_shot)
    test_df = test_df.sort_values(by='content_id').reset_index(drop=True)

    # optional chain-of-thought text ---
    if cfg.use_cot:
        print("Loading CoT....")
        test_df = test_df.merge(pd.read_parquet(cfg.cot_path), on='query_id', how='left')
        print(f"# of missing cot: {test_df['cot'].isna().sum()}")
        test_df["cot"] = test_df["cot"].fillna("")
        print(test_df.sample().T.to_dict())
        print(divider)

    # tokenized dataset -> prompt strings ---
    dataset_creator = RankerDataset(cfg)
    infer_ds = dataset_creator.get_dataset(test_df)
    tokenizer = dataset_creator.tokenizer

    # ids are taken in dataset order so they pair up with the responses below
    infer_qa_ids = infer_ds["query_id"]
    infer_mc_ids = infer_ds["content_id"]

    def _decode_prompt(example):
        return {"prompt": tokenizer.decode(example['input_ids'], skip_special_tokens=False)}

    infer_ds = infer_ds.map(_decode_prompt)
    prompts = infer_ds['prompt']

    print(f"# of requests: {len(prompts)}")
    print(f"Example:\n\n{prompts[0]}")

    # inference engine ---
    engine_kwargs = dict(
        quantization="awq",
        tensor_parallel_size=2,
        gpu_memory_utilization=0.99,
        trust_remote_code=True,
        dtype="half",
        enforce_eager=True,
        max_model_len=2048,
        enable_prefix_caching=True,
    )
    llm = vllm.LLM(cfg.model.backbone_path, **engine_kwargs)

    # a single generated token is enough: only its top-20 logprobs are used
    sampling_params = vllm.SamplingParams(
        n=1,
        top_p=0.8,
        logprobs=20,
        max_tokens=1,
        temperature=0.0,
        skip_special_tokens=False,
    )
    responses = llm.generate(prompts, sampling_params, use_tqdm=True)

    # Get Results
    def _last_token_id(word):
        return tokenizer(word, add_special_tokens=False)["input_ids"][-1]

    print(divider)
    yes_tok_id = _last_token_id("Yes")
    no_tok_id = _last_token_id("No")

    print(f">> EediRanker: Yes token id: {yes_tok_id} | Expected: 9454")
    print(f">> EediRanker: No token id: {no_tok_id} | Expected: 2753")
    print(divider)

    kept_qids, kept_cids, scores = [], [], []
    for qid, cid, response in zip(infer_qa_ids, infer_mc_ids, responses):
        logprob_dict = response.outputs[0].logprobs[0]
        has_yes = yes_tok_id in logprob_dict
        has_no = no_tok_id in logprob_dict

        # neither answer token among the returned logprobs: the pair is dropped
        if not (has_yes or has_no):
            print(f"Bad Output for {qid} - {cid}")
            continue

        # a token missing from the returned logprobs falls back to -10.0
        yes_logit = logprob_dict[yes_tok_id].logprob if has_yes else -10.0
        no_logit = logprob_dict[no_tok_id].logprob if has_no else -10.0

        kept_qids.append(qid)
        kept_cids.append(cid)
        scores.append(yes_logit - no_logit)

    result_df = pd.DataFrame()
    result_df["QuestionId_Answer"] = kept_qids
    result_df["MisconceptionId"] = kept_cids
    result_df["score"] = scores

    # one row per query with aligned lists of candidates and scores
    by_query = result_df.groupby("QuestionId_Answer")
    agg_df = pd.merge(
        by_query["MisconceptionId"].agg(list).reset_index(),
        by_query["score"].agg(list).reset_index(),
        on="QuestionId_Answer",
        how="left",
    )

    ranked = agg_df.apply(lambda r: sort_by_scores(r["MisconceptionId"], r["score"]), axis=1)
    agg_df["MisconceptionId"] = ranked.apply(lambda info: info["sorted_ids"])
    agg_df["score"] = ranked.apply(lambda info: info["sorted_scores"])

    # compute oof dataframe ---
    oof_df = agg_df[["QuestionId_Answer", "MisconceptionId", "score"]].copy()
    oof_df = oof_df.rename(columns={"score": "logit_scores"})

    # normalize: softmax over each query's candidates ---
    oof_df["pred_scores"] = oof_df["logit_scores"].apply(stable_softmax)
    oof_df["MisconceptionId"] = oof_df["MisconceptionId"].apply(lambda ids: [str(i) for i in ids])

    # print ---
    print(divider)
    row = oof_df.sample()
    formatted_scores = [f"{s:.3f}" for s in row['pred_scores'].values[0]]
    misconceptions = row['MisconceptionId'].values[0]
    print(f"Showing 1 example: {row['QuestionId_Answer'].values[0]}")
    for m, s in zip(misconceptions, formatted_scores):
        print(f"MisconceptionId: {m} -> Score: {s}")
    print(divider)

    save_path = os.path.join(save_dir, f"ranker_{model_id}.parquet")
    oof_df.to_parquet(save_path)

if __name__ == "__main__":
    # CLI entry point: three mandatory string flags, then a single ranking run
    parser = argparse.ArgumentParser()
    for flag in ('--config_path', '--save_dir', '--model_id'):
        parser.add_argument(flag, type=str, required=True)
    cli = parser.parse_args()

    cfg = OmegaConf.load(cli.config_path)
    os.makedirs(cli.save_dir, exist_ok=True)

    # execution ---
    main(cfg, save_dir=cli.save_dir, model_id=cli.model_id)

## 2.2 Re-Ranker Config

In [None]:
%%writefile conf_llm_oracle_14b_awq.yaml

# Config handed to the ranker script through --config_path.

# Seeds the random.Random instance that samples few-shot examples.
seed: 675
# Parquet of (QuestionId_Answer, MisconceptionId) candidate pairs to score.
input_path: ./retriever_outputs/ranker_input_stage_one.parquet
# Competition train.csv; few-shot demonstrations are built from it.
icl_path: /kaggle/input/eedi-mining-misconceptions-in-mathematics/train.csv

# Maximum number of few-shot examples attached to each candidate.
k_shot: 2
# When false, cot_path is not read.
use_cot: false
cot_path: ./gen/gen_qwen_14b.parquet

model:
    # Model directory passed to vllm.LLM (loaded with quantization="awq").
    backbone_path: "/kaggle/input/eedi-oracle-14b-dec7-cv646-awq-ff/transformers/default/1"
    # NOTE(review): max_length, num_proc and tokenizer.* are not read by the script
    # itself; presumably consumed by RankerDataset(cfg) - confirm.
    max_length: 768
    num_proc: 2
    
    tokenizer:
        padding_side: left
        truncation_side: left
        use_fast: true

## 2.3 Re-Ranker Execution

In [None]:
%%time
!python run_haiku.py --config_path "./conf_llm_oracle_14b_awq.yaml" --save_dir "./ranker_outputs" --model_id "qwen_14b"

## 2.4 Blend Retriever and Ranker

In [None]:
%%writefile blend_one_two.py

import argparse
import os
from copy import deepcopy

import pandas as pd
import numpy as np
from omegaconf import OmegaConf

def get_sorted_pairs(content_ids, scores):
    """Return `(content_id, score)` tuples ordered by descending score.

    Pairs with equal scores keep their input order.
    """
    pairs = list(zip(content_ids, scores))
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs

def cut_at_n(sub_df, n=25):
    """Truncate the per-row `MisconceptionId` and `score` lists to their first `n` items.

    The frame is modified in place and also returned.
    """
    for list_col in ("MisconceptionId", "score"):
        sub_df[list_col] = sub_df[list_col].apply(lambda items: items[:n])
    return sub_df

if __name__ == "__main__":
    # Blend retriever and ranker scores per (QuestionId_Answer, MisconceptionId) pair,
    # save the blended ranking, then keep only the top-N pairs of each query as the
    # input of the next ranking stage.
    parser = argparse.ArgumentParser()
    # Required, like the flags of the ranker scripts: a missing flag now produces a
    # usage error instead of open(None) failing below.
    parser.add_argument("--config-path", type=str, required=True)
    args = parser.parse_args()
    
    with open(args.config_path, "r") as f:
        cfg = OmegaConf.load(f)

    # read predictions ---
    ret_preds = pd.read_parquet(cfg.ret_path)
    ranker_preds = pd.read_parquet(cfg.ranker_path)
    # sample(3) raises when fewer than 3 rows exist, so cap the sample size
    print(ranker_preds.sample(min(3, len(ranker_preds))))

    # flatten: one row per (query, candidate) ---
    ret_preds = ret_preds[["QuestionId_Answer", "MisconceptionId", "pred_scores"]].explode(["MisconceptionId", "pred_scores"]).reset_index(drop=True)
    ret_preds = ret_preds.rename(columns={"pred_scores": "score_ret"})
    ret_preds["MisconceptionId"] = ret_preds["MisconceptionId"].astype(str)

    ranker_preds = ranker_preds[["QuestionId_Answer", "MisconceptionId", "pred_scores"]].explode(["MisconceptionId", "pred_scores"]).reset_index(drop=True)
    ranker_preds = ranker_preds.rename(columns={"pred_scores": "score_ranker"})
    ranker_preds["MisconceptionId"] = ranker_preds["MisconceptionId"].astype(str)

    # blend ---
    w_ret = cfg.ret_weight
    w_ranker = cfg.ranker_weight

    # default inner merge: only pairs scored by both stages survive
    candidate_df = pd.merge(ret_preds, ranker_preds, on=["QuestionId_Answer", "MisconceptionId"])
    # exploded columns are object dtype; cast once and blend with column arithmetic
    # instead of a Python-level call per row
    candidate_df["score"] = (
        w_ret * candidate_df["score_ret"].astype(float)
        + w_ranker * candidate_df["score_ranker"].astype(float)
    )  # blending
    candidate_df = candidate_df[["QuestionId_Answer", "MisconceptionId", "score"]].copy()

    # regroup into one row per query with aligned candidate / score lists
    cdf = candidate_df.groupby("QuestionId_Answer")["MisconceptionId"].agg(list).reset_index()
    sdf = candidate_df.groupby("QuestionId_Answer")["score"].agg(list).reset_index()
    candidate_df = pd.merge(cdf, sdf, on="QuestionId_Answer", how="left")

    candidate_df["sorted"] = candidate_df.apply(lambda x: get_sorted_pairs(x['MisconceptionId'], x['score']), axis=1)
    candidate_df["MisconceptionId"] = candidate_df["sorted"].apply(lambda x: [y[0] for y in x])
    candidate_df["score"] = candidate_df["sorted"].apply(lambda x: [y[1] for y in x])
    candidate_df = candidate_df.drop(columns=['sorted'])
    
    print("--"*40)
    print(f"saving retriever+ranker prediction to: {cfg.blended_pred_path}")
    candidate_df.to_parquet(cfg.blended_pred_path)
    print("Example:")
    print(candidate_df.sample().T)
    print("--"*40)

    # Cut at N ---
    # the full ranking was saved above; from here on only the top cutoff_n are kept
    candidate_df = cut_at_n(candidate_df, n=cfg.cutoff_n)
    input_df = pd.read_parquet(cfg.ranker_input_file)
    print(f"Shape of ranker input previously: {input_df.shape}")
    keep_df = candidate_df[['QuestionId_Answer', 'MisconceptionId']].explode("MisconceptionId").reset_index(drop=True)
    # ids were cast to str for blending; restore the input file's dtype so the merge keys match
    keep_df['MisconceptionId'] = keep_df['MisconceptionId'].astype(input_df["MisconceptionId"].dtype)
    input_df = input_df.merge(keep_df, on=["QuestionId_Answer", "MisconceptionId"], how="inner")
    print(f"shape of ranker input for next stage: {input_df.shape}")

    # # Prepare further ranking ---
    save_path = cfg.reranker_input_path
    print(f"saving output to: {save_path}")
    input_df.to_parquet(save_path)
    print(f"shape of input_df output: {input_df.shape}")
    print("--"*40)

In [None]:
%%writefile one_two_blend.yaml

# Config for blend_one_two.py (passed with --config-path).

# Candidate pairs that were fed to the ranker; filtered down to the kept top-N pairs.
ranker_input_file: ./retriever_outputs/ranker_input_stage_one.parquet

# Per-query predictions of the two stages; both must provide a pred_scores column.
ret_path: ./retriever_outputs/stage_one_blended.parquet
ranker_path: ./ranker_outputs/ranker_qwen_14b.parquet

# Blended score = ret_weight * retriever score + ranker_weight * ranker score.
# With 0.0 / 1.0 the ranking follows the ranker alone; pairs are still restricted
# to those present in both files (inner merge).
ret_weight: 0.0
ranker_weight: 1.0

# Number of top candidates per query kept for the next stage.
cutoff_n: 8 # recall ~0.82

# Full blended ranking (saved before the cutoff is applied).
blended_pred_path: ./ranker_outputs/one_two_blended.parquet
# Filtered candidate pairs written for the next ranking stage.
reranker_input_path: ./ranker_outputs/ranker_input_stage_two.parquet

In [None]:
!python blend_one_two.py --config-path one_two_blend.yaml

## Optional Submission

In [None]:
# import pandas as pd

# pred_df = pd.read_parquet("./ranker_outputs/one_two_blended.parquet")
# pred_df["MisconceptionId"] = pred_df["MisconceptionId"].apply(lambda x: x[:25])
# pred_df["MisconceptionId"] = pred_df["MisconceptionId"].apply(lambda x: " ".join(x))
# sub_df = pred_df[["QuestionId_Answer", "MisconceptionId"]].copy()
# sub_df.to_csv("submission.csv", index=False)

# sub_df.head()

# Stage 3: Reranker (Sonnet)

## 3.1 Scripts

In [None]:
%%writefile run_sonnet_v1.py

import sys
sys.path.insert(0, "/kaggle/input/eedi-utils-v08")

import argparse
import gc
import os
import vllm

import random
from collections import defaultdict

import numpy as np
import pandas as pd
import torch

from llm_oracle.ranker_dataset import RankerDataset
from llm_oracle.ranker_loader import RankerCollator, RankerCollatorTrain, show_batch
from llm_oracle.ranker_model import EediRanker

from omegaconf import OmegaConf

def is_nan(x):
    """Return the result of `x != x`, which is True for a float NaN."""
    return x != x
    
def stable_softmax(x, temp=1.0):
    """Softmax of `x / temp`, computed after subtracting the maximum.

    Args:
        x: Array-like of scores.
        temp: Temperature dividing the scores before normalization.

    Returns:
        Numpy array of non-negative weights that sum to 1.
    """
    scaled = np.array(x) / temp
    # shifting by the max leaves the result unchanged and keeps exp() from overflowing
    weights = np.exp(scaled - np.max(scaled))
    return weights / np.sum(weights)
    
def eedi_process_df(df):
    """Expand question rows into one query per labelled wrong answer.

    Every incorrect option (A-D) that has a non-NaN misconception id becomes a
    query with id `"{question_id}_{option}"`.

    Args:
        df: Question table with `QuestionId` (or `query_id`), `CorrectAnswer`,
            `Answer{A-D}Text`, `Misconception{A-D}Id`, `SubjectId`, `SubjectName`,
            `ConstructName` and `QuestionText` columns.

    Returns:
        Tuple `(query_df, corr_df, content2query)`: the query table, a
        query_id -> content_id table, and a dict mapping each misconception id
        (as str) to the list of query ids labelled with it.
    """
    option_keys = ["A", "B", "C", "D"]
    copied_cols = ["SubjectId", "SubjectName", "ConstructName", "QuestionText"]

    frame = df.copy().rename(columns={"QuestionId": "query_id"})

    # first record of every question, keyed by question id
    question_dict = {}
    for question_id, group in frame.groupby("query_id"):
        record = group.to_dict(orient="records")[0]
        record.pop("query_id")
        question_dict[question_id] = record

    queries = []
    query2content = defaultdict(list)
    content2query = defaultdict(list)

    # ---
    for qid, info in question_dict.items():
        for answer_key in option_keys:
            # the correct option has no misconception attached
            if info["CorrectAnswer"] == answer_key:
                continue

            raw_mid = info[f"Misconception{answer_key}Id"]
            if is_nan(raw_mid):
                continue

            this_key = f"{qid}_{answer_key}"
            mid = str(int(raw_mid))
            query2content[this_key].append(mid)
            content2query[mid].append(this_key)

            # ---
            this_example = {"query_id": this_key}
            for col in copied_cols:
                this_example[col] = info[col]

            this_example["CorrectAnswerText"] = info[f"Answer{info['CorrectAnswer']}Text"]
            this_example["InCorrectAnswerText"] = info[f"Answer{answer_key}Text"]
            option_text = "\n- ".join(info[f"Answer{x}Text"] for x in option_keys)
            this_example["AllOptionText"] = f"\n- {option_text}"
            queries.append(this_example)

    # --
    query_df = pd.DataFrame(queries).reset_index(drop=True)

    corr_df = pd.Series(query2content).reset_index().rename(columns={"index": "query_id", 0: "content_id"})
    # each query carries exactly one misconception id, so unwrap the list
    corr_df["content_id"] = corr_df["content_id"].apply(lambda x: x[0])

    return query_df, corr_df, content2query

def sort_by_scores(pred_ids, scores):
    """Order candidate ids and their scores from the highest score to the lowest.

    Args:
        pred_ids: Candidate ids, positionally aligned with `scores`.
        scores: One numeric score per candidate.

    Returns:
        Dict with `sorted_ids` and `sorted_scores`, both in descending score order.
    """
    # negate so that the ascending argsort yields a descending ranking
    order = np.argsort(-np.array(scores)).tolist()
    return {
        "sorted_ids": [pred_ids[pos] for pos in order],
        "sorted_scores": [scores[pos] for pos in order],
    }


def format_example(row):
    """Render one solved query as a three-line few-shot demonstration.

    `row` must provide `QuestionText`, `CorrectAnswerText` and `InCorrectAnswerText`.
    """
    lines = (
        f"Question: {row['QuestionText']}",
        f"Answer:{row['CorrectAnswerText']}",
        f"Misconception Answer: {row['InCorrectAnswerText']}",
    )
    return "\n".join(lines)

def add_fs_examples(df, content2query, query2example, rng, n=2):
    """Attach up to `n` few-shot demonstrations to every candidate row.

    Demonstrations for a row are drawn from the queries that `content2query`
    lists for the row's `content_id`, never including the row's own
    `query_id`. One draw is shared by all rows of a `content_id` for which it
    is valid; a row whose own query is part of the shared draw gets a fresh
    draw instead of the leaked one.

    Args:
        df: DataFrame with `query_id` and `content_id` columns. It is modified
            in place: an `examples` column is added.
        content2query: Mapping content id -> list of query ids.
        query2example: Mapping query id -> formatted demonstration text.
        rng: `random.Random`-like object providing `sample`.
        n: Maximum number of demonstrations per row.

    Returns:
        The same `df`, now carrying the `examples` column.
    """
    cache = {}  # content_id -> (sampled query ids, joined demonstration text)

    def _add_examples(row):
        cid = row["content_id"]
        own_qid = row["query_id"]

        cached = cache.get(cid)
        # the shared draw is reusable only if it does not contain this row's own query
        if cached is not None and own_qid not in cached[0]:
            return cached[1]

        # .get: unknown content ids simply have no demonstrations
        candidates = content2query.get(cid, [])
        pool = [qid for qid in candidates if qid != own_qid]
        picked = rng.sample(pool, k=min(n, len(pool))) if pool else []
        fs = "\n--\n".join(query2example[qid] for qid in picked)

        # share the draw only when excluding this row's query removed nothing;
        # otherwise the result is specific to this row and must not be reused
        if cached is None and len(pool) == len(candidates):
            cache[cid] = (picked, fs)
        return fs

    df["examples"] = df.apply(_add_examples, axis=1)
    return df


def main(cfg, save_dir, model_id):
    """Score every (query, misconception) pair with the vLLM ranker and save per-query rankings.

    Args:
        cfg: Config object. This function reads `input_path`, `icl_path`, `seed`,
            `k_shot`, `use_cot`, `cot_path` and `model.backbone_path`, and hands
            the whole object to `RankerDataset`.
        save_dir: Directory in which `ranker_{model_id}.parquet` is written.
        model_id: Tag used in the output file name.
    """
    divider = "--" * 40

    # candidate pairs ---
    test_df = pd.read_parquet(cfg.input_path).rename(
        columns={"QuestionId_Answer": "query_id", "MisconceptionId": "content_id"}
    )
    test_df["content_id"] = test_df["content_id"].astype(str)

    # demonstrations built from the competition data ---
    rng = random.Random(cfg.seed)
    comp_df = pd.read_csv(cfg.icl_path).rename(columns={"QuestionId": "query_id"})
    query_df, _, content2query = eedi_process_df(comp_df)
    query_df["demo"] = query_df.apply(format_example, axis=1)
    query2example = {qid: demo for qid, demo in zip(query_df["query_id"], query_df["demo"])}

    # few-shot examples per candidate; rows are then ordered by content id
    test_df = add_fs_examples(test_df, content2query, query2example, rng, n=cfg.k_shot)
    test_df = test_df.sort_values(by='content_id').reset_index(drop=True)

    # optional chain-of-thought text ---
    if cfg.use_cot:
        print("Loading CoT....")
        test_df = test_df.merge(pd.read_parquet(cfg.cot_path), on='query_id', how='left')
        print(f"# of missing cot: {test_df['cot'].isna().sum()}")
        test_df["cot"] = test_df["cot"].fillna("")
        print(test_df.sample().T.to_dict())
        print(divider)

    # tokenized dataset -> prompt strings ---
    dataset_creator = RankerDataset(cfg)
    infer_ds = dataset_creator.get_dataset(test_df)
    tokenizer = dataset_creator.tokenizer

    def _decode_prompt(example):
        return {"prompt": tokenizer.decode(example['input_ids'], skip_special_tokens=False)}

    infer_ds = infer_ds.map(_decode_prompt)

    # ids are taken in dataset order so they pair up with the responses below
    infer_qa_ids = infer_ds["query_id"]
    infer_mc_ids = infer_ds["content_id"]
    prompts = infer_ds['prompt']

    print(f"# of requests: {len(prompts)}")
    print(f"Example:\n\n{prompts[0]}")

    # inference engine ---
    engine_kwargs = dict(
        quantization="awq",
        tensor_parallel_size=2,
        gpu_memory_utilization=0.99,
        trust_remote_code=True,
        dtype="half",
        enforce_eager=True,
        max_model_len=2048,
        enable_prefix_caching=True,
    )
    llm = vllm.LLM(cfg.model.backbone_path, **engine_kwargs)

    # a single generated token is enough: only its top-20 logprobs are used
    sampling_params = vllm.SamplingParams(
        n=1,
        top_p=0.8,
        logprobs=20,
        max_tokens=1,
        temperature=0.0,
        skip_special_tokens=False,
    )
    responses = llm.generate(prompts, sampling_params, use_tqdm=True)

    # Get Results
    def _last_token_id(word):
        return tokenizer(word, add_special_tokens=False)["input_ids"][-1]

    print(divider)
    yes_tok_id = _last_token_id("Yes")
    no_tok_id = _last_token_id("No")

    print(f">> EediRanker: Yes token id: {yes_tok_id} | Expected: 9454")
    print(f">> EediRanker: No token id: {no_tok_id} | Expected: 2753")
    print(divider)

    kept_qids, kept_cids, scores = [], [], []
    for qid, cid, response in zip(infer_qa_ids, infer_mc_ids, responses):
        logprob_dict = response.outputs[0].logprobs[0]
        has_yes = yes_tok_id in logprob_dict
        has_no = no_tok_id in logprob_dict

        # neither answer token among the returned logprobs: the pair is dropped
        if not (has_yes or has_no):
            print(f"Bad Output for {qid} - {cid}")
            continue

        # a token missing from the returned logprobs falls back to -10.0
        yes_logit = logprob_dict[yes_tok_id].logprob if has_yes else -10.0
        no_logit = logprob_dict[no_tok_id].logprob if has_no else -10.0

        kept_qids.append(qid)
        kept_cids.append(cid)
        scores.append(yes_logit - no_logit)

    result_df = pd.DataFrame()
    result_df["QuestionId_Answer"] = kept_qids
    result_df["MisconceptionId"] = kept_cids
    result_df["score"] = scores

    # one row per query with aligned lists of candidates and scores
    by_query = result_df.groupby("QuestionId_Answer")
    agg_df = pd.merge(
        by_query["MisconceptionId"].agg(list).reset_index(),
        by_query["score"].agg(list).reset_index(),
        on="QuestionId_Answer",
        how="left",
    )

    ranked = agg_df.apply(lambda r: sort_by_scores(r["MisconceptionId"], r["score"]), axis=1)
    agg_df["MisconceptionId"] = ranked.apply(lambda info: info["sorted_ids"])
    agg_df["score"] = ranked.apply(lambda info: info["sorted_scores"])

    # compute oof dataframe ---
    oof_df = agg_df[["QuestionId_Answer", "MisconceptionId", "score"]].copy()
    oof_df = oof_df.rename(columns={"score": "logit_scores"})

    # normalize: softmax over each query's candidates ---
    oof_df["pred_scores"] = oof_df["logit_scores"].apply(stable_softmax)
    oof_df["MisconceptionId"] = oof_df["MisconceptionId"].apply(lambda ids: [str(i) for i in ids])

    # print ---
    print(divider)
    row = oof_df.sample()
    formatted_scores = [f"{s:.3f}" for s in row['pred_scores'].values[0]]
    misconceptions = row['MisconceptionId'].values[0]
    print(f"Showing 1 example: {row['QuestionId_Answer'].values[0]}")
    for m, s in zip(misconceptions, formatted_scores):
        print(f"MisconceptionId: {m} -> Score: {s}")
    print(divider)

    save_path = os.path.join(save_dir, f"ranker_{model_id}.parquet")
    oof_df.to_parquet(save_path)

if __name__ == "__main__":
    # CLI entry point: three mandatory string flags, then a single ranking run
    parser = argparse.ArgumentParser()
    for flag in ('--config_path', '--save_dir', '--model_id'):
        parser.add_argument(flag, type=str, required=True)
    cli = parser.parse_args()

    cfg = OmegaConf.load(cli.config_path)
    os.makedirs(cli.save_dir, exist_ok=True)

    # execution ---
    main(cfg, save_dir=cli.save_dir, model_id=cli.model_id)

In [None]:
%%writefile run_sonnet_v2.py

import sys
sys.path.insert(0, "/kaggle/input/eedi-utils-v12")

import argparse
import gc
import os
import vllm

import random
from collections import defaultdict

import numpy as np
import pandas as pd
import torch

from llm_oracle.ranker_dataset import RankerDataset
from llm_oracle.ranker_loader import RankerCollator, RankerCollatorTrain, show_batch
from llm_oracle.ranker_model import EediRanker

from omegaconf import OmegaConf

def is_nan(x):
    """Return the result of `x != x`, which is True for a float NaN."""
    return x != x
    
def stable_softmax(x, temp=1.0):
    """Softmax of `x / temp`, computed after subtracting the maximum.

    Args:
        x: Array-like of scores.
        temp: Temperature dividing the scores before normalization.

    Returns:
        Numpy array of non-negative weights that sum to 1.
    """
    scaled = np.array(x) / temp
    # shifting by the max leaves the result unchanged and keeps exp() from overflowing
    weights = np.exp(scaled - np.max(scaled))
    return weights / np.sum(weights)
    
def eedi_process_df(df):
    """Expand question rows into one query per labelled wrong answer.

    Every incorrect option (A-D) that has a non-NaN misconception id becomes a
    query with id `"{question_id}_{option}"`.

    Args:
        df: Question table with `QuestionId` (or `query_id`), `CorrectAnswer`,
            `Answer{A-D}Text`, `Misconception{A-D}Id`, `SubjectId`, `SubjectName`,
            `ConstructName` and `QuestionText` columns.

    Returns:
        Tuple `(query_df, corr_df, content2query)`: the query table, a
        query_id -> content_id table, and a dict mapping each misconception id
        (as str) to the list of query ids labelled with it.
    """
    option_keys = ["A", "B", "C", "D"]
    copied_cols = ["SubjectId", "SubjectName", "ConstructName", "QuestionText"]

    frame = df.copy().rename(columns={"QuestionId": "query_id"})

    # first record of every question, keyed by question id
    question_dict = {}
    for question_id, group in frame.groupby("query_id"):
        record = group.to_dict(orient="records")[0]
        record.pop("query_id")
        question_dict[question_id] = record

    queries = []
    query2content = defaultdict(list)
    content2query = defaultdict(list)

    # ---
    for qid, info in question_dict.items():
        for answer_key in option_keys:
            # the correct option has no misconception attached
            if info["CorrectAnswer"] == answer_key:
                continue

            raw_mid = info[f"Misconception{answer_key}Id"]
            if is_nan(raw_mid):
                continue

            this_key = f"{qid}_{answer_key}"
            mid = str(int(raw_mid))
            query2content[this_key].append(mid)
            content2query[mid].append(this_key)

            # ---
            this_example = {"query_id": this_key}
            for col in copied_cols:
                this_example[col] = info[col]

            this_example["CorrectAnswerText"] = info[f"Answer{info['CorrectAnswer']}Text"]
            this_example["InCorrectAnswerText"] = info[f"Answer{answer_key}Text"]
            option_text = "\n- ".join(info[f"Answer{x}Text"] for x in option_keys)
            this_example["AllOptionText"] = f"\n- {option_text}"
            queries.append(this_example)

    # --
    query_df = pd.DataFrame(queries).reset_index(drop=True)

    corr_df = pd.Series(query2content).reset_index().rename(columns={"index": "query_id", 0: "content_id"})
    # each query carries exactly one misconception id, so unwrap the list
    corr_df["content_id"] = corr_df["content_id"].apply(lambda x: x[0])

    return query_df, corr_df, content2query

def sort_by_scores(pred_ids, scores):
    """Order candidate ids and their scores from the highest score to the lowest.

    Args:
        pred_ids: Candidate ids, positionally aligned with `scores`.
        scores: One numeric score per candidate.

    Returns:
        Dict with `sorted_ids` and `sorted_scores`, both in descending score order.
    """
    # negate so that the ascending argsort yields a descending ranking
    order = np.argsort(-np.array(scores)).tolist()
    return {
        "sorted_ids": [pred_ids[pos] for pos in order],
        "sorted_scores": [scores[pos] for pos in order],
    }


def format_example(row):
    """Render one solved query as a three-line few-shot demonstration.

    `row` must provide `QuestionText`, `CorrectAnswerText` and `InCorrectAnswerText`.
    """
    lines = (
        f"Question: {row['QuestionText']}",
        f"Answer:{row['CorrectAnswerText']}",
        f"Misconception Answer: {row['InCorrectAnswerText']}",
    )
    return "\n".join(lines)

def add_fs_examples(df, content2query, query2example, rng, n=2):
    """Attach up to `n` few-shot demonstrations to every candidate row.

    Demonstrations for a row are drawn from the queries that `content2query`
    lists for the row's `content_id`, never including the row's own
    `query_id`. One draw is shared by all rows of a `content_id` for which it
    is valid; a row whose own query is part of the shared draw gets a fresh
    draw instead of the leaked one.

    Args:
        df: DataFrame with `query_id` and `content_id` columns. It is modified
            in place: an `examples` column is added.
        content2query: Mapping content id -> list of query ids.
        query2example: Mapping query id -> formatted demonstration text.
        rng: `random.Random`-like object providing `sample`.
        n: Maximum number of demonstrations per row.

    Returns:
        The same `df`, now carrying the `examples` column.
    """
    cache = {}  # content_id -> (sampled query ids, joined demonstration text)

    def _add_examples(row):
        cid = row["content_id"]
        own_qid = row["query_id"]

        cached = cache.get(cid)
        # the shared draw is reusable only if it does not contain this row's own query
        if cached is not None and own_qid not in cached[0]:
            return cached[1]

        # .get: unknown content ids simply have no demonstrations
        candidates = content2query.get(cid, [])
        pool = [qid for qid in candidates if qid != own_qid]
        picked = rng.sample(pool, k=min(n, len(pool))) if pool else []
        fs = "\n--\n".join(query2example[qid] for qid in picked)

        # share the draw only when excluding this row's query removed nothing;
        # otherwise the result is specific to this row and must not be reused
        if cached is None and len(pool) == len(candidates):
            cache[cid] = (picked, fs)
        return fs

    df["examples"] = df.apply(_add_examples, axis=1)
    return df


def main(cfg, save_dir, model_id):
    """Score every (query, misconception) pair with the vLLM ranker and save per-query rankings.

    Args:
        cfg: Config object. This function reads `input_path`, `icl_path`, `seed`,
            `k_shot`, `use_cot`, `cot_path` and `model.backbone_path`, and hands
            the whole object to `RankerDataset`.
        save_dir: Directory in which `ranker_{model_id}.parquet` is written.
        model_id: Tag used in the output file name.
    """
    divider = "--" * 40

    # candidate pairs ---
    test_df = pd.read_parquet(cfg.input_path).rename(
        columns={"QuestionId_Answer": "query_id", "MisconceptionId": "content_id"}
    )
    test_df["content_id"] = test_df["content_id"].astype(str)

    # demonstrations built from the competition data ---
    rng = random.Random(cfg.seed)
    comp_df = pd.read_csv(cfg.icl_path).rename(columns={"QuestionId": "query_id"})
    query_df, _, content2query = eedi_process_df(comp_df)
    query_df["demo"] = query_df.apply(format_example, axis=1)
    query2example = {qid: demo for qid, demo in zip(query_df["query_id"], query_df["demo"])}

    # few-shot examples per candidate; rows are then ordered by content id
    test_df = add_fs_examples(test_df, content2query, query2example, rng, n=cfg.k_shot)
    test_df = test_df.sort_values(by='content_id').reset_index(drop=True)

    # optional chain-of-thought text ---
    if cfg.use_cot:
        print("Loading CoT....")
        test_df = test_df.merge(pd.read_parquet(cfg.cot_path), on='query_id', how='left')
        print(f"# of missing cot: {test_df['cot'].isna().sum()}")
        test_df["cot"] = test_df["cot"].fillna("")
        print(test_df.sample().T.to_dict())
        print(divider)

    # tokenized dataset -> prompt strings ---
    dataset_creator = RankerDataset(cfg)
    infer_ds = dataset_creator.get_dataset(test_df)
    tokenizer = dataset_creator.tokenizer

    def _decode_prompt(example):
        return {"prompt": tokenizer.decode(example['input_ids'], skip_special_tokens=False)}

    infer_ds = infer_ds.map(_decode_prompt)

    # ids are taken in dataset order so they pair up with the responses below
    infer_qa_ids = infer_ds["query_id"]
    infer_mc_ids = infer_ds["content_id"]
    prompts = infer_ds['prompt']

    print(f"# of requests: {len(prompts)}")
    print(f"Example:\n\n{prompts[0]}")

    # inference engine ---
    engine_kwargs = dict(
        quantization="awq",
        tensor_parallel_size=2,
        gpu_memory_utilization=0.99,
        trust_remote_code=True,
        dtype="half",
        enforce_eager=True,
        max_model_len=2048,
        enable_prefix_caching=True,
    )
    llm = vllm.LLM(cfg.model.backbone_path, **engine_kwargs)

    # a single generated token is enough: only its top-20 logprobs are used
    sampling_params = vllm.SamplingParams(
        n=1,
        top_p=0.8,
        logprobs=20,
        max_tokens=1,
        temperature=0.0,
        skip_special_tokens=False,
    )
    responses = llm.generate(prompts, sampling_params, use_tqdm=True)

    # Get Results
    def _last_token_id(word):
        return tokenizer(word, add_special_tokens=False)["input_ids"][-1]

    print(divider)
    yes_tok_id = _last_token_id("Yes")
    no_tok_id = _last_token_id("No")

    print(f">> EediRanker: Yes token id: {yes_tok_id} | Expected: 9454")
    print(f">> EediRanker: No token id: {no_tok_id} | Expected: 2753")
    print(divider)

    kept_qids, kept_cids, scores = [], [], []
    for qid, cid, response in zip(infer_qa_ids, infer_mc_ids, responses):
        logprob_dict = response.outputs[0].logprobs[0]
        has_yes = yes_tok_id in logprob_dict
        has_no = no_tok_id in logprob_dict

        # neither answer token among the returned logprobs: the pair is dropped
        if not (has_yes or has_no):
            print(f"Bad Output for {qid} - {cid}")
            continue

        # a token missing from the returned logprobs falls back to -10.0
        yes_logit = logprob_dict[yes_tok_id].logprob if has_yes else -10.0
        no_logit = logprob_dict[no_tok_id].logprob if has_no else -10.0

        kept_qids.append(qid)
        kept_cids.append(cid)
        scores.append(yes_logit - no_logit)

    result_df = pd.DataFrame()
    result_df["QuestionId_Answer"] = kept_qids
    result_df["MisconceptionId"] = kept_cids
    result_df["score"] = scores

    # one row per query with aligned lists of candidates and scores
    by_query = result_df.groupby("QuestionId_Answer")
    agg_df = pd.merge(
        by_query["MisconceptionId"].agg(list).reset_index(),
        by_query["score"].agg(list).reset_index(),
        on="QuestionId_Answer",
        how="left",
    )

    ranked = agg_df.apply(lambda r: sort_by_scores(r["MisconceptionId"], r["score"]), axis=1)
    agg_df["MisconceptionId"] = ranked.apply(lambda info: info["sorted_ids"])
    agg_df["score"] = ranked.apply(lambda info: info["sorted_scores"])

    # compute oof dataframe ---
    oof_df = agg_df[["QuestionId_Answer", "MisconceptionId", "score"]].copy()
    oof_df = oof_df.rename(columns={"score": "logit_scores"})

    # normalize: softmax over each query's candidates ---
    oof_df["pred_scores"] = oof_df["logit_scores"].apply(stable_softmax)
    oof_df["MisconceptionId"] = oof_df["MisconceptionId"].apply(lambda ids: [str(i) for i in ids])

    # print ---
    print(divider)
    row = oof_df.sample()
    formatted_scores = [f"{s:.3f}" for s in row['pred_scores'].values[0]]
    misconceptions = row['MisconceptionId'].values[0]
    print(f"Showing 1 example: {row['QuestionId_Answer'].values[0]}")
    for m, s in zip(misconceptions, formatted_scores):
        print(f"MisconceptionId: {m} -> Score: {s}")
    print(divider)

    save_path = os.path.join(save_dir, f"ranker_{model_id}.parquet")
    oof_df.to_parquet(save_path)

if __name__ == "__main__":
    # CLI entry point: three mandatory string flags, then a single ranking run
    parser = argparse.ArgumentParser()
    for flag in ('--config_path', '--save_dir', '--model_id'):
        parser.add_argument(flag, type=str, required=True)
    cli = parser.parse_args()

    cfg = OmegaConf.load(cli.config_path)
    os.makedirs(cli.save_dir, exist_ok=True)

    # execution ---
    main(cfg, save_dir=cli.save_dir, model_id=cli.model_id)

## 3.2 Qwen 32b Configs

In [None]:
%%writefile conf_oracle_32b_cv663_ff.yaml

# Config handed to the ranker script through --config_path.

# Seeds the random.Random instance that samples few-shot examples.
seed: 4562
# 0 means the few-shot sampler draws no examples.
k_shot: 0
# When true, the parquet at cot_path is left-merged onto the candidates by query_id;
# missing `cot` values are filled with "".
use_cot: true

# Parquet of (QuestionId_Answer, MisconceptionId) candidate pairs to score.
input_path: ./ranker_outputs/ranker_input_stage_two.parquet
# Competition train.csv; few-shot demonstrations are built from it.
icl_path: /kaggle/input/eedi-mining-misconceptions-in-mathematics/train.csv
cot_path: ./gen/gen_qwen_14b.parquet

model:
    # Model directory passed to vllm.LLM (loaded with quantization="awq").
    backbone_path: "/kaggle/input/eedi-oracle-32b-cv663-dec7-awq-ff/transformers/default/1"
    # NOTE(review): max_length, num_proc and tokenizer.* are not read by the script
    # itself; presumably consumed by RankerDataset(cfg) - confirm.
    max_length: 768
    num_proc: 2
    
    tokenizer:
        padding_side: left
        truncation_side: left
        use_fast: true

## 3.3 Qwen 32b Inference

In [None]:
%%time
!python run_sonnet_v2.py --config_path "./conf_oracle_32b_cv663_ff.yaml" --save_dir "./ranker_outputs" --model_id "qwen_32b_oracle_main"

## 3.4 Prep for stage 4

In [None]:
%%writefile blend_two_three.py

import argparse
import os
from copy import deepcopy

import pandas as pd
import numpy as np
from omegaconf import OmegaConf

def get_sorted_pairs(content_ids, scores):
    """Return `(content_id, score)` tuples ordered by descending score.

    Pairs with equal scores keep their input order.
    """
    pairs = list(zip(content_ids, scores))
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs

def cut_at_n(sub_df, n=25):
    """Keep only the first ``n`` candidates (and their scores) of every row.

    Args:
        sub_df: frame whose ``MisconceptionId`` and ``score`` columns hold
            sliceable sequences (one sequence per row).
        n: maximum number of entries to keep per row.

    Returns:
        A new DataFrame with both columns truncated. The frame passed in is
        left untouched (the previous implementation overwrote the caller's
        columns in place as a side effect).
    """
    return sub_df.assign(
        MisconceptionId=sub_df["MisconceptionId"].apply(lambda ids: ids[:n]),
        score=sub_df["score"].apply(lambda vals: vals[:n]),
    )

def _explode_predictions(preds, score_col):
    """Turn list-valued predictions into one row per (query, misconception) pair.

    The ``pred_scores`` column is renamed to ``score_col`` and the misconception
    ids are cast to ``str`` so the two rankers can be joined on them.
    """
    flat = preds[["QuestionId_Answer", "MisconceptionId", "pred_scores"]]
    flat = flat.explode(["MisconceptionId", "pred_scores"]).reset_index(drop=True)
    flat = flat.rename(columns={"pred_scores": score_col})
    flat["MisconceptionId"] = flat["MisconceptionId"].astype(str)
    return flat


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-path", type=str)
    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        cfg = OmegaConf.load(f)

    # Load the two ranker outputs and show a few rows of each.
    haiku_preds = pd.read_parquet(cfg.haiku_path)
    sonnet_preds = pd.read_parquet(cfg.sonnet_path)

    for title, preds in (("Sample Haiku Preds:", haiku_preds), ("Sample Sonnet Preds:", sonnet_preds)):
        print(title)
        print(preds.sample(3))
        print("--"*50)

    # One row per candidate so both rankers can be aligned pair by pair.
    haiku_preds = _explode_predictions(haiku_preds, "score_haiku")
    sonnet_preds = _explode_predictions(sonnet_preds, "score_sonnet")

    # Weighted sum of the two scores; the inner join keeps only candidates scored by both.
    w_haiku = cfg.haiku_weight
    w_sonnet = cfg.sonnet_weight

    candidate_df = pd.merge(haiku_preds, sonnet_preds, on=["QuestionId_Answer", "MisconceptionId"])
    candidate_df["score"] = candidate_df.apply(
        lambda pair: w_haiku*pair['score_haiku'] + w_sonnet*pair['score_sonnet'], axis=1
    )

    # Back to one row per query, with parallel lists of ids and blended scores.
    candidate_df = (
        candidate_df.groupby("QuestionId_Answer")[["MisconceptionId", "score"]]
        .agg(list)
        .reset_index()
    )

    # Re-order each query's candidates by blended score, best first.
    ranked_pairs = candidate_df.apply(
        lambda query: get_sorted_pairs(query['MisconceptionId'], query['score']), axis=1
    )
    candidate_df["MisconceptionId"] = ranked_pairs.apply(lambda pairs: [pair[0] for pair in pairs])
    candidate_df["score"] = ranked_pairs.apply(lambda pairs: [pair[1] for pair in pairs])

    print("--"*40)
    print(f"Saving sonnet+haiku prediction to: {cfg.blended_pred_path}")
    candidate_df.to_parquet(cfg.blended_pred_path)

    print("Example:")
    print(candidate_df.sample().T)
    print("--"*40)
    print("Distribution:")
    print(candidate_df['MisconceptionId'].apply(len).value_counts())
    print("--"*40)

    # Shortlist: the top `cutoff_n` candidates per query feed the tutor stage.
    candidate_df = cut_at_n(candidate_df, n=cfg.cutoff_n)
    candidate_df.to_parquet(cfg.tutor_base_path)

    # Restrict the earlier ranker input to the shortlisted (query, misconception) pairs.
    input_df = pd.read_parquet(cfg.ranker_input_file)
    print(f"Shape of ranker input previously: {input_df.shape}")

    keep_df = candidate_df[['QuestionId_Answer', 'MisconceptionId']]
    keep_df = keep_df.explode("MisconceptionId").reset_index(drop=True)
    # Align the id dtype with the ranker input so the join keys compare equal.
    keep_df['MisconceptionId'] = keep_df['MisconceptionId'].astype(input_df["MisconceptionId"].dtype)
    input_df = input_df.merge(keep_df, on=["QuestionId_Answer", "MisconceptionId"], how="inner")
    print(f"shape of ranker input for stage 4: {input_df.shape}")

    # Persist the stage-4 ranker input.
    save_path = cfg.stage4_input_path
    print(f"saving output to: {save_path}")
    input_df.to_parquet(save_path)
    print(f"shape of input_df output: {input_df.shape}")
    print("--"*40)

In [None]:
%%writefile two_three_blend.yaml

ranker_input_file: ./retriever_outputs/ranker_input_stage_one.parquet

haiku_path: ./ranker_outputs/ranker_qwen_14b.parquet
sonnet_path: ./ranker_outputs/ranker_qwen_32b_oracle_main.parquet

haiku_weight: 1.0
sonnet_weight: 2.0

cutoff_n: 5

blended_pred_path: ./ranker_outputs/two_three_blended.parquet
tutor_base_path: ./ranker_outputs/tutor_base.parquet
stage4_input_path: ./ranker_outputs/ranker_input_stage_four.parquet

In [None]:
!python blend_two_three.py --config-path two_three_blend.yaml

## 3.5 Tutor Data Prep

In [None]:
%%writefile prep_tutor_data.py

import argparse
import os
from copy import deepcopy

import pandas as pd
import numpy as np
from omegaconf import OmegaConf

def eedi_process_df(df):
    """Expand a question table into one query row per incorrect answer option.

    Only the first row of each ``QuestionId`` is used. For every option in
    A-D that is not the correct answer, a record with id ``"<QuestionId>_<option>"``
    is produced, carrying the subject, construct and question text together
    with the correct and the incorrect answer texts.

    Returns:
        DataFrame with columns ``query_id``, ``SubjectName``, ``ConstructName``,
        ``QuestionText``, ``CorrectAnswerText`` and ``InCorrectAnswerText``.
        The input frame is not modified.
    """
    renamed = df.copy().rename(columns={"QuestionId": "query_id"})

    # groupby walks the question ids in sorted order; keep the first record of each.
    first_rows = {}
    for question_id, rows in renamed.groupby("query_id"):
        record = rows.to_dict(orient="records")[0]
        record.pop("query_id")
        first_rows[question_id] = record

    shared_cols = ("SubjectName", "ConstructName", "QuestionText")
    queries = []
    for qid, info in first_rows.items():
        for answer_key in "ABCD":
            if answer_key == info["CorrectAnswer"]:
                continue  # the correct option carries no misconception

            example = {"query_id": f"{qid}_{answer_key}"}
            example.update({col: info[col] for col in shared_cols})
            example["CorrectAnswerText"] = info[f"Answer{info['CorrectAnswer']}Text"]
            example["InCorrectAnswerText"] = info[f"Answer{answer_key}Text"]
            queries.append(example)

    return pd.DataFrame(queries)
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-path", type=str)
    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        cfg = OmegaConf.load(f)

    # Questions: the real test set on a competition rerun, otherwise the first
    # N_EX training rows (N_EX comes from the environment).
    test_df = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/test.csv")
    if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
        n_ex = int(os.getenv("N_EX"))
        test_df = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/train.csv").head(n_ex)

    content_df = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/misconception_mapping.csv")
    id2name = dict(zip(content_df['MisconceptionId'], content_df['MisconceptionName']))

    query_df = eedi_process_df(test_df)

    # Shortlisted candidates per query, with ids converted back to integers.
    oof_df = pd.read_parquet(cfg.tutor_base_path)
    oof_df["MisconceptionId"] = oof_df["MisconceptionId"].apply(lambda ids: [int(mid) for mid in ids])
    oof_df = oof_df.rename(columns={'QuestionId_Answer': 'query_id', 'MisconceptionId': 'content_ids'})
    oof_df = oof_df[['query_id', 'content_ids']].copy()

    # Chain-of-thought generations, one frame per generator size.
    cot_sizes = ("7b", "14b", "32b")
    cot_frames = {}
    for size in cot_sizes:
        cot_frame = pd.read_parquet(cfg[f"cot_{size}_path"])
        cot_frames[size] = cot_frame[['query_id', 'cot']].rename(columns={'cot': f'cot_{size}'})

    # Attach the candidates and their names to every query.
    query_df = query_df.merge(oof_df, on="query_id", how="left")
    query_df['MisconceptionNameList'] = query_df['content_ids'].apply(
        lambda ids: [id2name[mid] for mid in ids]
    )

    # Attach the generations; queries without one end up with an empty string.
    for size in cot_sizes:
        query_df = query_df.merge(cot_frames[size], on='query_id', how='left')

    for size in cot_sizes:
        print(query_df[f"cot_{size}"].isna().sum())

    for size in cot_sizes:
        query_df[f"cot_{size}"] = query_df[f"cot_{size}"].fillna("")

    output_cols = [
        "query_id", "content_ids", "SubjectName", "ConstructName", "QuestionText",
        "CorrectAnswerText", "InCorrectAnswerText", "MisconceptionNameList",
        "cot_7b", "cot_14b", "cot_32b",
    ]
    final_df = query_df[output_cols].copy()

    print("--"*50)
    final_df = final_df.sort_values(by="query_id").reset_index(drop=True)

    # Show one random record field by field.
    ex = final_df.sample().to_dict(orient='records')[0]
    for field, value in ex.items():
        print(f"{field} -> {value}")

    print(f"saving tutor data to: {cfg.tutor_input_path}")
    final_df.to_parquet(cfg.tutor_input_path)
    print("--"*50)

In [None]:
%%writefile tutor_data_prep.yaml

tutor_base_path: ./ranker_outputs/tutor_base.parquet

cot_7b_path: ./gen/gen_qwen_7b.parquet
cot_14b_path: ./gen/gen_qwen_14b.parquet
cot_32b_path: ./gen/gen_qwen_32b.parquet

tutor_input_path: ./ranker_outputs/tutor_input.parquet

In [None]:
!python prep_tutor_data.py --config-path tutor_data_prep.yaml

# Stage 4: Ranker (Opus)

In [None]:
%%writefile run_expert_tutor.py

import sys

sys.path.insert(0, "/kaggle/input/eedi-utils-v12")

import argparse
import os
import random
from collections import defaultdict

import numpy as np
import pandas as pd
import vllm
from llm_tutor.ranker_dataset import RankerDataset
from omegaconf import OmegaConf


def is_nan(x):
    """Return the result of ``x != x``.

    A float NaN is the only float value that compares unequal to itself, so
    for scalar floats this is True exactly when ``x`` is NaN. Unlike
    ``math.isnan`` the comparison does not raise for non-numeric values such
    as strings; it simply evaluates to False for them.
    """
    return x != x


def eedi_process_df(df):
    """Build labelled query records from a question table.

    Only the first row of each ``QuestionId`` is used. A query
    ``"<QuestionId>_<option>"`` is produced for every option in A-D that is
    not the correct answer and that has a non-NaN ``Misconception<option>Id``.

    Returns:
        A 3-tuple ``(query_df, corr_df, content2query)``:
        * ``query_df`` - one row per query with the subject/construct/question
          fields, the correct and incorrect answer texts and ``AllOptionText``
          (all four options as a "\\n- " bulleted block).
        * ``corr_df`` - columns ``query_id`` and ``content_id`` (misconception
          id as a string).
        * ``content2query`` - ``defaultdict(list)`` mapping each misconception
          id (string) to the query ids labelled with it.
    """
    renamed = df.copy().rename(columns={"QuestionId": "query_id"})

    # groupby walks the question ids in sorted order; keep the first record of each.
    question_dict = {}
    for question_id, rows in renamed.groupby("query_id"):
        record = rows.to_dict(orient="records")[0]
        record.pop("query_id")
        question_dict[question_id] = record

    option_keys = ["A", "B", "C", "D"]
    queries = []
    query2content = defaultdict(list)
    content2query = defaultdict(list)

    for qid, info in question_dict.items():
        for answer_key in option_keys:
            if answer_key == info["CorrectAnswer"]:
                continue

            raw_mid = info[f"Misconception{answer_key}Id"]
            if is_nan(raw_mid):
                continue  # unlabelled distractor: nothing to learn from

            query_key = f"{qid}_{answer_key}"
            mid = str(int(raw_mid))
            query2content[query_key].append(mid)
            content2query[mid].append(query_key)

            example = {"query_id": query_key}
            for col in ["SubjectId", "SubjectName", "ConstructName", "QuestionText"]:
                example[col] = info[col]

            example["CorrectAnswerText"] = info[f"Answer{info['CorrectAnswer']}Text"]
            example["InCorrectAnswerText"] = info[f"Answer{answer_key}Text"]
            options_text = "\n- ".join([info[f"Answer{opt}Text"] for opt in option_keys])
            example["AllOptionText"] = f"\n- {options_text}"
            queries.append(example)

    query_df = pd.DataFrame(queries).reset_index(drop=True)

    # Each query has exactly one misconception; flatten the single-item lists.
    corr_df = pd.Series(query2content).reset_index().rename(columns={"index": "query_id", 0: "content_id"})
    corr_df["content_id"] = corr_df["content_id"].apply(lambda mids: mids[0])

    return query_df, corr_df, content2query


def stable_softmax(x, temp=1.0):
    """Softmax of ``x / temp``, computed in a numerically stable way.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow. Returns a numpy array that sums to 1.
    """
    scaled = np.array(x) / temp
    shifted = scaled - np.max(scaled)
    weights = np.exp(shifted)
    return weights / np.sum(weights)


def sort_by_scores(pred_ids, scores):
    """Order ids and scores together by descending score.

    Returns a dict with two parallel lists: ``sorted_ids`` and ``sorted_scores``.
    """
    order = np.argsort(-np.array(scores)).tolist()

    sorted_ids, sorted_scores = [], []
    for position in order:
        sorted_ids.append(pred_ids[position])
        sorted_scores.append(scores[position])

    return dict(sorted_ids=sorted_ids, sorted_scores=sorted_scores)


def format_example(row, id2name, query2content):
    """Render one solved query as a few-shot demonstration.

    The misconception of ``row["query_id"]`` is looked up in ``query2content``,
    converted to ``int`` and resolved to its name through ``id2name``.
    """
    misconception_name = id2name[int(query2content[row["query_id"]])]
    parts = [
        f"Question: {row['QuestionText']}",
        f"Answer:{row['CorrectAnswerText']}",
        f"Incorrect Answer: {row['InCorrectAnswerText']}",
        f"Misconception: {misconception_name}",
    ]
    return "\n".join(parts)


def add_fs_examples_tutor_for_eval(df, content2query, query2example, rng):
    """Attach up to three few-shot demonstrations to every row of ``df``.

    For each candidate misconception of a row, one already-labelled query
    that exhibits that misconception is picked (never the row's own query),
    and its pre-formatted demonstration text is looked up in
    ``query2example``. The demonstrations are joined with ``"\\n--\\n"`` and
    stored in a new ``examples`` column ("" when nothing is available).

    Args:
        df: frame with ``query_id`` and ``content_ids`` (candidate ids) columns.
        content2query: mapping misconception id -> list of labelled query ids.
            ``eedi_process_df`` in this file builds it with *string* keys,
            while the candidates are handled as integers here, so both key
            forms are accepted. (Previously only the integer form was looked
            up, which never matched the string keys: no example was ever
            attached and empty lists were inserted into the defaultdict.)
        query2example: mapping query id -> demonstration text.
        rng: ``random.Random`` instance used for all random choices.

    Returns:
        ``df`` itself, with the ``examples`` column added in place.
    """
    cache = {}  # misconception id -> demonstration query chosen the first time it was seen

    def _labelled_queries(cid):
        # Read-only lookup (.get) so a miss does not grow a defaultdict.
        for key in (cid, str(cid)):
            qids = content2query.get(key)
            if qids:
                return qids
        return []

    def _add_examples(row):
        selected_qids = []
        for cid in row["content_ids"]:
            cid = int(cid)
            qids = [qid for qid in _labelled_queries(cid) if qid != row["query_id"]]
            if not qids:
                continue

            chosen = cache.get(cid)
            if chosen is None or chosen == row["query_id"]:
                # A cached pick may be this very row (cached while processing
                # another row); draw again so a row never sees its own answer.
                chosen = rng.choice(qids)
                cache.setdefault(cid, chosen)
            selected_qids.append(chosen)

        if len(selected_qids) == 0:
            return ""

        # keep max of 3 examples
        selected_qids = rng.sample(selected_qids, k=min(3, len(selected_qids)))
        selected_qids = sorted(selected_qids)  # stable order helps prefix caching
        return "\n--\n".join(query2example[qid] for qid in selected_qids)

    df["examples"] = df.apply(_add_examples, axis=1)
    return df


def main(cfg, save_dir, model_id):
    """Score every query's shortlisted misconceptions with the tutor LLM and save the ranking.

    Steps:
      1. Read the tutor input (``cfg.input_path``); with ``cfg.use_tta`` each
         query is duplicated with its candidate list reversed.
      2. Attach few-shot demonstrations built from ``cfg.icl_path``.
      3. Build prompts through ``RankerDataset`` and ask vLLM for a single
         next token with its top-20 log-probabilities.
      4. Read the log-probabilities of the option letters A-E, softmax them
         into per-candidate scores, average over TTA copies if enabled, and
         sort candidates by score.
      5. Write ``ranker_<model_id>.parquet`` into ``save_dir`` with columns
         ``QuestionId_Answer``, ``MisconceptionId``, ``logit_scores`` and
         ``pred_scores``.

    Args:
        cfg: OmegaConf config (``input_path``, ``icl_path``, ``seed``,
            ``use_tta``, ``model.backbone_path`` are read here).
        save_dir: existing directory for the output parquet.
        model_id: suffix used in the output file name.
    """
    test_df = pd.read_parquet(cfg.input_path)

    if cfg.use_tta:
        # Test-time augmentation: present the same candidates in reverse order too.
        test_df_tta = test_df.copy()
        test_df_tta["content_ids"] = test_df_tta["content_ids"].apply(lambda x: x[::-1])
        test_df_tta["MisconceptionNameList"] = test_df_tta["MisconceptionNameList"].apply(lambda x: x[::-1])
        test_df = pd.concat([test_df, test_df_tta]).reset_index(drop=True)
        test_df = test_df.sort_values(by="query_id").reset_index(drop=True)

    # comp few shot examples ---
    rng = random.Random(cfg.seed)

    fs_df = pd.read_csv(cfg.icl_path).rename(columns={"QuestionId": "query_id"})
    content_df = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/misconception_mapping.csv")
    id2name = dict(zip(content_df["MisconceptionId"], content_df["MisconceptionName"]))

    query_df, fs_corr_df, content2query = eedi_process_df(fs_df)
    fs_query2content = dict(zip(fs_corr_df["query_id"], fs_corr_df["content_id"]))
    query_df["demo"] = query_df.apply(lambda x: format_example(x, id2name, fs_query2content), axis=1)
    query2example = dict(zip(query_df["query_id"], query_df["demo"]))

    test_df = add_fs_examples_tutor_for_eval(test_df, content2query, query2example, rng)
    print(f"shape of test data: {test_df.shape}")
    print("--" * 40)

    # ---
    dataset_creator = RankerDataset(cfg)
    infer_ds = dataset_creator.get_dataset(test_df)
    tokenizer = dataset_creator.tokenizer

    # Token id of each option letter; the i-th letter stands for the i-th candidate.
    option_letters = ["A", "B", "C", "D", "E"]
    option_tok_ids = [tokenizer(letter, add_special_tokens=False)["input_ids"][-1] for letter in option_letters]
    for letter, tok_id in zip(option_letters, option_tok_ids):
        print(f">> EediRanker: {letter} token id: {tok_id}")

    infer_ds = infer_ds.map(lambda example: {"prompt": tokenizer.decode(example["input_ids"], skip_special_tokens=False)})

    infer_qa_ids = infer_ds["query_id"]
    infer_mc_ids = infer_ds["content_ids"]
    prompts = infer_ds["prompt"]

    print(f"# of requests: {len(prompts)}")
    print(f"Example:\n\n{prompts[0]}")
    print("data preparation done...")

    # -- create the llm ----#
    llm = vllm.LLM(
        cfg.model.backbone_path,
        # quantization="awq",
        tensor_parallel_size=2,
        gpu_memory_utilization=0.99,
        trust_remote_code=True,
        dtype="half",
        enforce_eager=True,
        max_model_len=2048,
        disable_log_stats=True,
        cpu_offload_gb=8,
        swap_space=1,
        device="cuda",
        max_num_seqs=20,
        # enable_prefix_caching=True
    )

    # One generated token; only its top-20 log-probabilities are returned.
    sampling_params = vllm.SamplingParams(n=1, top_p=0.8, logprobs=20, max_tokens=1, temperature=0.0, skip_special_tokens=False)
    responses = llm.generate(prompts, sampling_params, use_tqdm=True)
    print("inference done...")

    # get results ---
    print("--" * 40)

    QuestionId_Answer = []
    MisconceptionId = []
    scores = []

    for qid, cids, response in zip(infer_qa_ids, infer_mc_ids, responses):
        logprob_dict = response.outputs[0].logprobs[0]

        if not any(tok_id in logprob_dict for tok_id in option_tok_ids):
            # None of the option letters made the top-20. The query used to be
            # dropped here, which removed it from every later inner merge and
            # from the submission. Keep it: all options get the floor value
            # below, i.e. a uniform score that leaves the decision to the
            # other rankers.
            print(f"Bad Output for {qid}")

        # Letters missing from the returned top-20 get a floor log-prob of -10.
        logits = np.array(
            [logprob_dict[tok_id].logprob if tok_id in logprob_dict else -10.0 for tok_id in option_tok_ids]
        )

        # Keep exactly one score per candidate: with fewer candidates than
        # letters, the extra scores would break explode()/sort_by_scores below.
        cids = list(cids)
        n_candidates = min(len(cids), len(logits))
        if n_candidates == 0:
            print(f"No candidates for {qid}")
            continue
        cids = cids[:n_candidates]
        logits = logits[:n_candidates]

        logits_max = np.max(logits)
        exp_logits = np.exp(logits - logits_max)
        normalized_scores = exp_logits / np.sum(exp_logits)

        QuestionId_Answer.append(qid)
        MisconceptionId.append(cids)
        scores.append(normalized_scores)

    result_df = pd.DataFrame()
    result_df["QuestionId_Answer"] = QuestionId_Answer
    result_df["MisconceptionId"] = MisconceptionId
    result_df["MisconceptionId"] = result_df["MisconceptionId"].apply(lambda x: [str(y) for y in x])
    result_df["score"] = scores

    # ----
    if cfg.use_tta:
        # Average each candidate's score over the original and the reversed presentation.
        result_df = result_df.explode(["MisconceptionId", "score"]).reset_index(drop=True)
        result_df = result_df.groupby(["QuestionId_Answer", "MisconceptionId"]).agg({"score": "mean"}).reset_index()

        # regroup --
        agg_df = result_df.groupby("QuestionId_Answer")["MisconceptionId"].agg(list).reset_index()
        score_agg_df = result_df.groupby("QuestionId_Answer")["score"].agg(list).reset_index()
        agg_df = pd.merge(agg_df, score_agg_df, on="QuestionId_Answer", how="left")
        result_df = agg_df.copy()

    # --------
    agg_df = result_df.copy()
    agg_df["topk_info"] = agg_df.apply(lambda x: sort_by_scores(x["MisconceptionId"], x["score"]), axis=1)
    agg_df["MisconceptionId"] = agg_df["topk_info"].apply(lambda x: x["sorted_ids"])
    agg_df["score"] = agg_df["topk_info"].apply(lambda x: x["sorted_scores"])

    # compute oof dataframe ---
    oof_df = agg_df.copy()
    oof_df = oof_df[["QuestionId_Answer", "MisconceptionId", "score"]].copy()
    # NOTE: despite the name, "logit_scores" holds the softmax-normalised option scores from above.
    oof_df = oof_df.rename(columns={"score": "logit_scores"})

    # normalize ---
    oof_df["pred_scores"] = oof_df["logit_scores"].apply(stable_softmax)
    oof_df["MisconceptionId"] = oof_df["MisconceptionId"].apply(lambda x: list(map(str, x)))

    # print ---
    print("--" * 40)
    row = oof_df.sample()
    formatted_scores = [f"{s:.3f}" for s in row["pred_scores"].values[0]]
    misconceptions = row["MisconceptionId"].values[0]
    print(f"Showing 1 example: {row['QuestionId_Answer'].values[0]}")
    for m, s in zip(misconceptions, formatted_scores):
        print(f"MisconceptionId: {m} -> Score: {s}")
    print("--" * 40)

    save_path = os.path.join(save_dir, f"ranker_{model_id}.parquet")
    oof_df.to_parquet(save_path)


if __name__ == "__main__":
    # Command-line entry point; all three flags are mandatory.
    cli = argparse.ArgumentParser()
    for flag in ("--config_path", "--save_dir", "--model_id"):
        cli.add_argument(flag, type=str, required=True)
    args = cli.parse_args()

    # Load the run configuration, then make sure the output folder exists.
    cfg = OmegaConf.load(args.config_path)
    os.makedirs(args.save_dir, exist_ok=True)

    # Hand over to the pipeline.
    main(cfg, save_dir=args.save_dir, model_id=args.model_id)

In [None]:
%%writefile conf_expert_tutor.yaml

seed: 8798
k_shot: 0

use_tta: false

input_path: ./ranker_outputs/tutor_input.parquet
icl_path: /kaggle/input/eedi-mining-misconceptions-in-mathematics/train.csv

model:
    backbone_path: "/kaggle/input/eedi-tutor-72b-cv661-dec7-awq/transformers/default/1"
    max_length: 2048
    num_proc: 2
    
    tokenizer:
        padding_side: left
        truncation_side: left
        use_fast: true

In [None]:
%%time
!python run_expert_tutor.py --config_path "./conf_expert_tutor.yaml" --save_dir "./ranker_outputs" --model_id "qwen_72b_tutor"

In [None]:
# time for 16 examples (tta, no prompt caching): Processed prompts: 100%|█| 96/96 [04:45<00:00,  2.97s/it, est. speed input: 177.
# time for 16 examples (tta, prompt caching)   : 100%|█| 96/96 [03:54<00:00,  2.44s/it, est. speed input: 215.

## 4.2 Additional 32b

In [None]:
%%writefile conf_oracle_32b_cv652.yaml

seed: 9461
k_shot: 0
use_cot: false

input_path: ./ranker_outputs/ranker_input_stage_four.parquet
icl_path: /kaggle/input/eedi-mining-misconceptions-in-mathematics/train.csv
cot_path: ./gen/gen_qwen_14b.parquet

model:
    backbone_path: "/kaggle/input/eedi-ranker-32b-cv655-nov24-custom-awq/transformers/default/1"
    max_length: 640
    num_proc: 2
    
    tokenizer:
        padding_side: left
        truncation_side: left
        use_fast: true

In [None]:
# %%time
# !python run_sonnet_v1.py --config_path "./conf_oracle_32b_cv652.yaml" --save_dir "./ranker_outputs" --model_id "qwen_32b_oracle_support"

## 4.3 Ensemble Rankers

In [None]:
%%writefile ensemble_rankers.py

import argparse
import os
from copy import deepcopy

import pandas as pd
import numpy as np
from omegaconf import OmegaConf

def get_sorted_pairs(content_ids, scores):
    """Return ``(content_id, score)`` tuples ordered from highest to lowest score.

    Ties keep their input order because the sort is stable.
    """
    def _score_of(pair):
        return pair[1]

    return sorted(zip(content_ids, scores), key=_score_of, reverse=True)

def stable_softmax(x, temp=1.0):
    """Temperature-scaled softmax that avoids overflow.

    ``x`` is divided by ``temp`` and shifted by its maximum before the
    exponential is taken. The returned numpy array sums to 1.
    """
    scaled = np.array(x) / temp
    unnormalised = np.exp(scaled - np.max(scaled))
    total = np.sum(unnormalised)
    return unnormalised / total
    
def _load_df(pth, cutoff):
    """Read one ranker output and recompute its ``pred_scores`` column.

    ``pred_scores`` is overwritten with the softmax of the stored
    ``logit_scores`` of each row. One random row is printed for inspection.

    Args:
        pth: path of the ranker parquet file.
        cutoff: accepted for interface compatibility but not applied; no
            truncation of the candidate lists happens here.
    """
    ranker_df = pd.read_parquet(pth)

    print(f"recomputing softmax for {pth}...")
    ranker_df["pred_scores"] = ranker_df["logit_scores"].apply(stable_softmax)

    print(ranker_df.sample().T)
    print("--"*40)
    return ranker_df


def _explode_ranker_preds(preds, score_col):
    """Turn list-valued predictions into one row per (query, misconception) pair.

    ``pred_scores`` is renamed to ``score_col`` and the misconception ids are
    cast to ``str`` so both rankers can be joined on them.
    """
    flat = preds[["QuestionId_Answer", "MisconceptionId", "pred_scores"]]
    flat = flat.explode(["MisconceptionId", "pred_scores"]).reset_index(drop=True)
    flat = flat.rename(columns={"pred_scores": score_col})
    flat["MisconceptionId"] = flat["MisconceptionId"].astype(str)
    return flat


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-path", type=str)
    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        cfg = OmegaConf.load(f)

    # Load both ranker outputs (softmax recomputed on load).
    one_preds = _load_df(cfg.ranker_1, cutoff=cfg.cutoff_n)
    two_preds = _load_df(cfg.ranker_2, cutoff=cfg.cutoff_n)

    # One row per candidate so the rankers can be aligned pair by pair.
    one_preds = _explode_ranker_preds(one_preds, "score_one")
    two_preds = _explode_ranker_preds(two_preds, "score_two")

    # Weighted sum of the two scores; the inner join keeps only candidates scored by both.
    w1, w2 = cfg.r1_weight, cfg.r2_weight

    candidate_df = pd.merge(one_preds, two_preds, on=["QuestionId_Answer", "MisconceptionId"])
    print(candidate_df.sample().T)

    candidate_df["score"] = candidate_df.apply(
        lambda pair: w1*pair['score_one'] + w2*pair['score_two'], axis=1
    )

    # Back to one row per query, with parallel lists of ids and blended scores.
    candidate_df = (
        candidate_df.groupby("QuestionId_Answer")[["MisconceptionId", "score"]]
        .agg(list)
        .reset_index()
    )

    # Re-order each query's candidates by blended score, best first.
    ranked_pairs = candidate_df.apply(
        lambda query: get_sorted_pairs(query['MisconceptionId'], query['score']), axis=1
    )
    candidate_df["MisconceptionId"] = ranked_pairs.apply(lambda pairs: [pair[0] for pair in pairs])
    candidate_df["score"] = ranked_pairs.apply(lambda pairs: [pair[1] for pair in pairs])

    print("Sample:")
    print(candidate_df.sample().T)

    # Only the ordered ids are needed downstream.
    ranked_df = candidate_df[["QuestionId_Answer", "MisconceptionId"]].copy()
    ranked_df = ranked_df.reset_index(drop=True)

    print("--"*40)
    ranked_df.to_parquet(cfg.outfile_path)
    print(ranked_df.sample(3))
    print("--"*40)

In [None]:
%%writefile ensemble_rankers.yaml

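cutoff_n: 5 # NOTE: currently has no effect -- the truncation in ensemble_rankers.py (_load_df) is commented out. TODO: confirm the intended value (25 was noted as an alternative)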

ranker_1: ./ranker_outputs/ranker_qwen_72b_tutor.parquet
ranker_2: ./ranker_outputs/ranker_qwen_32b_oracle_main.parquet

r1_weight: 4.0
r2_weight: 1.0

outfile_path: ./ranker_outputs/ranker_blend_stage4.parquet

In [None]:
!python ensemble_rankers.py --config-path ensemble_rankers.yaml

# Submission

In [None]:
import pandas as pd
pd.options.display.max_colwidth = None

# Three rankings of decreasing quality: the stage-4 blend (best), the
# stage 2+3 blend, and the stage 1+2 blend. Only the id lists are kept from
# the latter two.
top_df = pd.read_parquet("./ranker_outputs/ranker_blend_stage4.parquet")
top_df = top_df.rename(columns={"MisconceptionId": "top_ids"})

mid_df = pd.read_parquet("./ranker_outputs/two_three_blended.parquet")
mid_df = mid_df.rename(columns={"MisconceptionId": "mid_ids"}).drop(columns=['score'])

low_df = pd.read_parquet("./ranker_outputs/one_two_blended.parquet")
low_df = low_df.rename(columns={"MisconceptionId": "low_ids"}).drop(columns=['score'])

# Inner joins: a query must be present in all three rankings.
pred_df = top_df.merge(mid_df, on=["QuestionId_Answer"]).merge(low_df, on=["QuestionId_Answer"])

In [None]:
def get_final_ids(row):
    """Concatenate the three rankings of a row into one ordered id list.

    All ids are converted to ``str``. ``top_ids`` are taken as they are;
    ids from ``mid_ids`` and then ``low_ids`` are appended only when not
    already present.
    """
    top_ids, mid_ids, low_ids = (
        [str(value) for value in row[column]] for column in ("top_ids", "mid_ids", "low_ids")
    )

    merged = list(top_ids)
    seen = set(merged)

    # add mid ids first, then low ids, skipping anything already ranked
    for this_id in mid_ids + low_ids:
        if this_id in seen:
            continue
        seen.add(this_id)
        merged.append(this_id)

    return merged

In [None]:
# Final answer per query: merged ranking, capped at 25 ids, space separated.
pred_df["MisconceptionId"] = pred_df.apply(
    lambda row: " ".join(get_final_ids(row)[:25]), axis=1
)

sub_df = pred_df[["QuestionId_Answer", "MisconceptionId"]].copy()
sub_df.to_csv("submission.csv", index=False)

In [None]:
# Read the written file back and display its first rows as a quick check.
sub_df = pd.read_csv("submission.csv")
sub_df.head()

In [None]:
# End ---#