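# 目的 (Purpose)
bgeでのretrieveファイルの作成 — build the retrieval (candidate) file using BGE embeddings.

- MAP@25: 0.1589 (scored on only the top RETRIEVE_NUM = 5 candidates, so effectively MAP@5)
- recall: 0.2831 (recall@5)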

Please let me know if there are any mistakes.

# Setting

In [None]:
# Number of candidate misconceptions retrieved per (question, answer) row.
# NOTE: the metric below is named MAP@25, but only the top RETRIEVE_NUM
# candidates are scored, so with RETRIEVE_NUM < 25 it is effectively
# MAP@RETRIEVE_NUM.
RETRIEVE_NUM = 5  # alternative: 25  # TODO: try larger values

# Derive the experiment name from RETRIEVE_NUM so the two cannot drift apart.
EXP_NAME = f"e003-ret-bge-ret{RETRIEVE_NUM}"
DATA_PATH = "../../data"

DATASET_NAME = EXP_NAME
OUTPUT_PATH = f"../../output/retriever/{EXP_NAME}"

# Install

# Import

In [2]:
import os

import polars as pl
import numpy as np

from sklearn.metrics.pairwise import cosine_similarity

import torch
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

In [3]:
# import transformers
# import torch
# import sentence_transformers

# assert transformers.__version__ == "4.44.2"
# assert torch.__version__ == "2.3.1"
# assert sentence_transformers.__version__ == "3.1.0"

# Data Load

In [4]:
# Load the raw competition files: the questions and the misconception catalogue.
train, misconception_mapping = (
    pl.read_csv(f"{DATA_PATH}/{file_stem}.csv")
    for file_stem in ("train", "misconception_mapping")
)

# Preprocess

In [5]:
# Columns shared by every answer option of a question; they are kept as the
# index when the four AnswerXText columns are unpivoted into rows.
common_col = [
    "QuestionId",
    "ConstructName",
    "SubjectName",
    "QuestionText",
    "CorrectAnswer",
]

# One row per (question, answer option) with the text used as the retrieval
# query. Reuses the `train` frame loaded above instead of re-reading
# train.csv, so both long frames are built from the same data.
train_long = (
    train.select(
        pl.col(common_col + [f"Answer{alpha}Text" for alpha in ["A", "B", "C", "D"]])
    )
    .unpivot(
        index=common_col,
        variable_name="AnswerType",
        value_name="AnswerText",
    )
    .with_columns(
        pl.concat_str(
            [
                pl.col("ConstructName"),
                pl.col("SubjectName"),
                pl.col("QuestionText"),
                pl.col("AnswerText"),
            ],
            separator=" ",
            # Without this, a single null field would make the whole AllText
            # null, and that None would then be passed to the encoder.
            ignore_nulls=True,
        ).alias("AllText"),
        pl.col("AnswerType").str.extract(r"Answer([A-D])Text$").alias("AnswerAlphabet"),
    )
    .with_columns(
        # Join key shared with train_misconception_long, e.g. "0_A".
        pl.concat_str(
            [pl.col("QuestionId"), pl.col("AnswerAlphabet")], separator="_"
        ).alias("QuestionId_Answer"),
    )
    .sort("QuestionId_Answer")
)
train_long.head()

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""0_A"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""0_B"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""0_C"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D"""
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerAText""","""\( t \)""","""Simplify an algebraic fraction…","""A""","""1000_A"""


In [6]:
(train_long.height, train_long.width)  # (rows, columns) of the long frame

(7476, 10)

In [7]:
# Gold labels in long form: one row per (question, answer option) holding the
# MisconceptionId of that option (null when the option has no misconception).
train_misconception_long = (
    train.unpivot(
        on=[f"Misconception{alpha}Id" for alpha in ["A", "B", "C", "D"]],
        index="QuestionId",
        variable_name="MisconceptionType",
        value_name="MisconceptionId",
    )
    .select(
        # Same "<QuestionId>_<letter>" key as train_long.
        pl.concat_str(
            [
                pl.col("QuestionId"),
                pl.col("MisconceptionType").str.extract(r"Misconception([A-D])Id$"),
            ],
            separator="_",
        ).alias("QuestionId_Answer"),
        # The ids are read as floats because of the nulls; store them as integers.
        pl.col("MisconceptionId").cast(pl.Int64),
    )
    .sort("QuestionId_Answer")
)
train_misconception_long.head()

QuestionId_Answer,MisconceptionId
str,i64
"""0_A""",
"""0_B""",
"""0_C""",
"""0_D""",1672.0
"""1000_A""",891.0


In [8]:
# Attach the gold MisconceptionId to every (question, answer option) row.
# Any existing MisconceptionId column is dropped first so this cell is
# idempotent (re-running it would otherwise add `MisconceptionId_right`).
# validate="1:1" fails loudly on duplicated QuestionId_Answer keys.
n_rows_before_join = train_long.height
train_long = train_long.drop("MisconceptionId", strict=False).join(
    train_misconception_long, on="QuestionId_Answer", validate="1:1"
)
assert train_long.height == n_rows_before_join, "join dropped or duplicated rows"

# BGE

In [9]:
# Embedding models tried so far (CV scores recorded from earlier runs):
#   SentenceTransformer("BAAI/bge-large-zh-v1.5")  -> CV: about 0.06
#   SentenceTransformer("all-MiniLM-L6-v2")        -> CV: 0.17806659878200218
#   SentenceTransformer("nvidia/NV-Embed-v2", trust_remote_code=True),
#       with max_seq_length = 32768 and tokenizer.padding_side = "right"
#       -> not evaluated (out of memory)
#   SentenceTransformer("dunzhang/stella_en_1.5B_v5", trust_remote_code=True)
#       -> no CV recorded
model = SentenceTransformer("BAAI/bge-large-en-v1.5")  # CV: 0.1841439184198388

# Embed the query texts and the misconception names, L2-normalised.
train_long_vec, misconception_mapping_vec = (
    model.encode(texts, normalize_embeddings=True)
    for texts in (
        train_long["AllText"].to_list(),
        misconception_mapping["MisconceptionName"].to_list(),
    )
)
print(train_long_vec.shape, misconception_mapping_vec.shape, sep="\n")

(7476, 1024)
(2587, 1024)


In [10]:
# Save misconception_mapping_vec into the experiment output directory
# (this directory is later uploaded as the Kaggle dataset).
os.makedirs(OUTPUT_PATH, exist_ok=True)
misconception_vec_path = f"{OUTPUT_PATH}/misconception_mapping_vec.npy"
np.save(misconception_vec_path, misconception_mapping_vec)

In [11]:
# Cosine similarity of every query row against every misconception, then, per
# row, misconception row positions ordered from most to least similar.
train_cos_sim_arr = cosine_similarity(X=train_long_vec, Y=misconception_mapping_vec)
train_sorted_indices = (-train_cos_sim_arr).argsort(axis=1)

In [12]:
# example
def print_example(
    df: pl.DataFrame,
    sorted_indices: np.ndarray,
    idx: int,
    top_k: int = 2,
    mapping: "pl.DataFrame | None" = None,
) -> None:
    """Print one query row and the names of its top-ranked misconceptions.

    Args:
        df: Long frame with an ``AllText`` column (the query text of each row).
        sorted_indices: Per-row misconception row positions, most similar first.
        idx: Row of ``df`` / ``sorted_indices`` to show.
        top_k: How many top-ranked misconceptions to print (default 2, which
            reproduces the original output).
        mapping: Frame with a ``MisconceptionName`` column whose rows
            ``sorted_indices`` points into. Defaults to the notebook-level
            ``misconception_mapping`` so existing calls keep working.
    """
    # Explicit, overridable dependency instead of a silent global lookup.
    if mapping is None:
        mapping = misconception_mapping
    names = mapping["MisconceptionName"]

    print(f"Query idx{idx}")
    print(df["AllText"][idx])
    for rank in range(top_k):
        print(f"\nCos Sim No.{rank + 1}")
        print(names[int(sorted_indices[idx, rank])])

In [13]:
print_example(df=train_long, sorted_indices=train_sorted_indices, idx=0)  # first query row

Query idx0
Use the order of operations to carry out calculations involving powers BIDMAS \[
3 \times 2+4-5
\]
Where do the brackets need to go to make the answer equal \( 13 \) ? \( 3 \times(2+4)-5 \)

Cos Sim No.1
Applies BIDMAS in strict order (does not realize addition and subtraction, and multiplication and division, are of equal priority)

Cos Sim No.2
Misunderstands order of operations in algebraic expressions


In [14]:
print_example(df=train_long, sorted_indices=train_sorted_indices, idx=1)  # second query row

Query idx1
Use the order of operations to carry out calculations involving powers BIDMAS \[
3 \times 2+4-5
\]
Where do the brackets need to go to make the answer equal \( 13 \) ? \( 3 \times 2+(4-5) \)

Cos Sim No.1
Applies BIDMAS in strict order (does not realize addition and subtraction, and multiplication and division, are of equal priority)

Cos Sim No.2
Misunderstands order of operations in algebraic expressions


# Evaluate

In [15]:
# `train_sorted_indices` holds *row positions* in misconception_mapping, not
# MisconceptionIds. Map positions to ids explicitly rather than relying on the
# CSV being ordered so that MisconceptionId == row position.
misconception_id_by_row = misconception_mapping["MisconceptionId"].to_numpy()
train_long = train_long.with_columns(
    pl.Series(
        "PredictMisconceptionId",
        misconception_id_by_row[train_sorted_indices[:, :RETRIEVE_NUM]].tolist(),
    )
)

In [16]:
# https://www.kaggle.com/code/cdeotte/how-to-train-open-book-model-part-1#MAP@3-Metric
def map_at_25(predictions, labels, k: int = 25) -> float:
    """Mean average precision at ``k`` for single-label retrieval.

    Each query has exactly one gold label, so its average precision is
    ``1 / rank`` of the first prediction equal to the label within the top
    ``k`` predictions, or 0 if it is not there. The result is the mean over
    all queries.

    Args:
        predictions: Sequence of ranked candidate-id sequences (best first),
            one per query.
        labels: Sequence of gold ids, aligned with ``predictions``.
        k: Cut-off rank (default 25, hence the name). Predictions beyond rank
            ``k`` are ignored.

    Returns:
        Mean AP@k as a float, or 0.0 when there are no queries.

    Raises:
        ValueError: If ``predictions`` and ``labels`` differ in length.
    """
    n_queries = len(predictions)
    if n_queries != len(labels):
        raise ValueError(
            f"predictions ({n_queries}) and labels ({len(labels)}) differ in length"
        )
    if n_queries == 0:
        return 0.0

    total = 0.0
    for ranked, label in zip(predictions, labels):
        for rank, candidate in enumerate(ranked, start=1):
            if rank > k:
                break
            if candidate == label:
                # Only the first hit counts; a duplicated candidate must not
                # push a query's AP above 1.
                total += 1.0 / rank
                break
    return total / n_queries

In [17]:
# Score only the rows that carry a gold misconception label.
labelled_train_long = train_long.filter(pl.col("MisconceptionId").is_not_null())
map_at_25_score = map_at_25(
    predictions=labelled_train_long["PredictMisconceptionId"],
    labels=labelled_train_long["MisconceptionId"],
)
map_at_25_score

0.15885964912280748

In [18]:
def recall(predictions, labels, k: "int | None" = None) -> float:
    """Fraction of queries whose gold label appears among their predictions.

    Args:
        predictions: Sequence of ranked candidate-id sequences, one per query.
        labels: Sequence of gold ids, aligned with ``predictions``.
        k: If given, only the first ``k`` predictions of each query are
            searched (recall@k); by default all predictions are searched.

    Returns:
        Hit rate as a float, or 0.0 when there are no queries.

    Raises:
        ValueError: If ``predictions`` and ``labels`` differ in length.
    """
    n_queries = len(predictions)
    if n_queries != len(labels):
        raise ValueError(
            f"predictions ({n_queries}) and labels ({len(labels)}) differ in length"
        )
    if n_queries == 0:
        return 0.0

    hits = sum(
        1
        for ranked, label in zip(predictions, labels)
        if label in (ranked if k is None else ranked[:k])
    )
    return hits / n_queries


# Recall over the labelled rows only.
labelled_train_long = train_long.filter(pl.col("MisconceptionId").is_not_null())
recall_score = recall(
    predictions=labelled_train_long["PredictMisconceptionId"],
    labels=labelled_train_long["MisconceptionId"],
)
recall_score

0.28306636155606407

# Make Retrieved Train File

In [19]:
# Inspect the long frame now that PredictMisconceptionId (the top-RETRIEVE_NUM
# candidates per row) has been attached.
train_long

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId,PredictMisconceptionId
i64,str,str,str,str,str,str,str,str,str,i64,list[i64]
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""0_A""",,"[2306, 2586, … 2488]"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""0_B""",,"[2306, 2586, … 2488]"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""0_C""",,"[2306, 2586, … 2488]"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,"[2306, 2586, … 1316]"
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerAText""","""\( t \)""","""Simplify an algebraic fraction…","""A""","""1000_A""",891,"[979, 2398, … 363]"
…,…,…,…,…,…,…,…,…,…,…,…
99,"""Given the perimeter, work out …","""Volume and Capacity Units""","""A regular pentagon has a total…","""C""","""AnswerDText""","""Not enough information""","""Given the perimeter, work out …","""D""","""99_D""",255,"[1815, 535, … 700]"
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerAText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""A""","""9_A""",1889,"[1475, 2497, … 1234]"
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerBText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""B""","""9_B""",1234,"[1475, 2497, … 1401]"
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerCText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""C""","""9_C""",,"[1475, 2497, … 1234]"


In [20]:
# One row per (labelled query, retrieved candidate). `target` is 1 when the
# candidate is the gold misconception. Both the gold and the predicted
# misconception names are attached.
labelled_train_long = train_long.filter(
    # TODO: Consider ways to utilize data where MisconceptionId is null.
    pl.col("MisconceptionId").is_not_null()
)
train_retrieved = (
    labelled_train_long.explode("PredictMisconceptionId")
    .with_columns(
        (pl.col("MisconceptionId") == pl.col("PredictMisconceptionId"))
        .cast(pl.Int64)
        .alias("target")
    )
    .join(misconception_mapping, on="MisconceptionId", validate="m:1")
    .join(
        misconception_mapping.rename(lambda x: "Predict" + x),
        on="PredictMisconceptionId",
        validate="m:1",
    )
)
# Inner joins silently drop rows whose gold or predicted id is missing from
# misconception_mapping; every labelled query must keep all its candidates.
assert train_retrieved.height == labelled_train_long.height * RETRIEVE_NUM
train_retrieved.shape

(21850, 15)

In [21]:
# NOTE(review): redisplays `train_long`, which has not changed since the
# identical display cell before `train_retrieved` was built (that cell only
# reads it); this cell can likely be removed.
train_long

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId,PredictMisconceptionId
i64,str,str,str,str,str,str,str,str,str,i64,list[i64]
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""0_A""",,"[2306, 2586, … 2488]"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""0_B""",,"[2306, 2586, … 2488]"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""0_C""",,"[2306, 2586, … 2488]"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,"[2306, 2586, … 1316]"
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerAText""","""\( t \)""","""Simplify an algebraic fraction…","""A""","""1000_A""",891,"[979, 2398, … 363]"
…,…,…,…,…,…,…,…,…,…,…,…
99,"""Given the perimeter, work out …","""Volume and Capacity Units""","""A regular pentagon has a total…","""C""","""AnswerDText""","""Not enough information""","""Given the perimeter, work out …","""D""","""99_D""",255,"[1815, 535, … 700]"
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerAText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""A""","""9_A""",1889,"[1475, 2497, … 1234]"
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerBText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""B""","""9_B""",1234,"[1475, 2497, … 1401]"
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerCText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""C""","""9_C""",,"[1475, 2497, … 1234]"


In [22]:
(train_retrieved.height, train_retrieved.width)  # (rows, columns); same check as when it was built

(21850, 15)

In [23]:
# The file name records the retrieval depth and both CV metrics.
# NOTE(review): EXP_NAME already ends in the retrieval depth, so "ret5" appears
# twice in the name; kept as-is because the uploaded file already uses it.
retrieved_csv_file = (
    f"{EXP_NAME}-ret{RETRIEVE_NUM}"
    f"-map{map_at_25_score:.4f}-recall{recall_score:.4f}.csv"
)
train_retrieved.write_csv(f"{OUTPUT_PATH}/{retrieved_csv_file}")

In [24]:
# Echo the path of the retrieved-candidates CSV.
retrieved_csv_file = (
    f"{EXP_NAME}-ret{RETRIEVE_NUM}"
    f"-map{map_at_25_score:.4f}-recall{recall_score:.4f}.csv"
)
print("/".join([OUTPUT_PATH, retrieved_csv_file]))

../../output/retriever/e003-ret-bge-ret5/e003-ret-bge-ret5-ret5-map0.1589-recall0.2831.csv


# Kaggle Upload

In [25]:
import os
import json

from kaggle.api.kaggle_api_extended import KaggleApi


def dataset_create_new(dataset_name: str, upload_dir: str, user_name: str = "sinchir0"):
    """Create a new Kaggle dataset from the files in ``upload_dir``.

    Writes ``dataset-metadata.json`` (id ``{user_name}/{dataset_name}``,
    CC0-1.0 licence, title ``dataset_name``) into ``upload_dir``. It then
    uploads the directory through the authenticated Kaggle API, with
    ``dir_mode="tar"`` and no CSV conversion.

    Args:
        dataset_name: Slug and title of the dataset to create.
        upload_dir: Local directory whose contents are uploaded.
        user_name: Kaggle account that owns the dataset; defaults to the
            account that was previously hard-coded here.
    """
    # NOTE: an earlier version rejected dataset names containing "_"
    # ("underscores are not allowed in the dataset name"); that check is disabled.
    # NOTE(review): this only *creates* a dataset; re-running for an existing
    # name presumably needs a new dataset version instead — confirm.
    dataset_metadata = {
        "id": f"{user_name}/{dataset_name}",
        "licenses": [{"name": "CC0-1.0"}],
        "title": dataset_name,
    }
    metadata_path = os.path.join(upload_dir, "dataset-metadata.json")
    with open(metadata_path, "w") as f:
        json.dump(dataset_metadata, f, indent=4)

    api = KaggleApi()
    api.authenticate()
    api.dataset_create_new(folder=upload_dir, convert_to_csv=False, dir_mode="tar")


print(f"Create Dataset name:{DATASET_NAME}, output_dir:{OUTPUT_PATH}")
dataset_create_new(dataset_name=DATASET_NAME, upload_dir=OUTPUT_PATH)

Create Dataset name:e003-ret-bge-ret5, output_dir:../../output/retriever/e003-ret-bge-ret5
Starting upload for file misconception_mapping_vec.npy


100%|██████████| 10.1M/10.1M [00:02<00:00, 3.60MB/s]


Upload successful: misconception_mapping_vec.npy (10MB)
Starting upload for file e003-ret-bge-ret5-ret5-map0.1589-recall0.2831.csv


100%|██████████| 15.2M/15.2M [00:03<00:00, 4.94MB/s]


Upload successful: e003-ret-bge-ret5-ret5-map0.1589-recall0.2831.csv (15MB)
