# 目的
学習するretrieveデータの数を25→50にする

ref: https://sbert.net/docs/sentence_transformer/training_overview.html#trainer

# Setting

In [1]:
# Experiment-wide configuration. Every path here is repository-relative and is
# rewritten for the current machine by resolve_path() in the next cell.
EXP_NAME = "e011-ret-add-ret-num"
DATA_PATH = "data"
# Base checkpoint that is fine-tuned when TRAINING is True.
MODEL_NAME = "BAAI/bge-large-en-v1.5"
# Directory that contains the `.env` file read by load_dotenv().
ENV_PATH = "env_file"
COMPETITION_NAME = "eedi-mining-misconceptions-in-mathematics"
# CSV of previously retrieved candidates that is loaded as the training data.
# The active entry is the "ret50" file; the "ret25" one is kept for reference.
RETRIVED_FILE_PATH = (
    # "output/retriever/e003-ret-bge/e003-ret-bge-ret25-map0.1841-recall0.5506.csv"
    "output/retriever/e003-ret-bge/e003-ret-bge-ret50-map0.1841-recall0.6586.csv"
)

# Name of the Kaggle dataset created in the upload cell.
DATASET_NAME = EXP_NAME
OUTPUT_PATH = f"output/retriever/{EXP_NAME}"
# MODEL_OUTPUT_PATH = f"{OUTPUT_PATH}/trained_model"
# Directory the model is loaded from when TRAINING is False, and the
# save_pretrained() target when TRAINING is True.
MODEL_OUTPUT_PATH = f"{OUTPUT_PATH}/checkpoint-5605"

# TRAIN_USE_NUM = 5
# Number of top-ranked misconceptions kept per query in the evaluation section.
RETRIEVE_NUM = 25

EPOCH = 2
LR = 2e-5
BS = 8
# Chosen so that BS * GRAD_ACC_STEP == 128 samples per optimizer step (per device).
GRAD_ACC_STEP = 128 // BS

# False: skip fine-tuning and load the model stored at MODEL_OUTPUT_PATH.
TRAINING = False
# True: only the first 1000 rows of the training data are used.
DEBUG = False
# True: log in to Weights & Biases and report training metrics there.
WANDB = True

In [2]:
def resolve_path(base_path: str) -> str:
    """Translate a repository-relative path for the current working directory.

    The current working directory is printed, then compared against the list
    of known launch locations. On a match the environment label is printed and
    the path adjusted for that location is returned.

    Args:
        base_path: Path relative to the competition repository root.

    Returns:
        ``base_path`` rewritten so that it is valid from the current directory.

    Raises:
        Exception: If the working directory is not one of the known locations.
    """
    import os

    here = os.getcwd()
    print(here)

    # (working directory, label shown to the user, resolved path)
    known_environments = (
        (
            "/notebooks",
            "Jupyter Kernel By VSCode!",
            f"/notebooks/{COMPETITION_NAME}/{base_path}",
        ),
        (f"/notebooks/{COMPETITION_NAME}", "nohup!", base_path),
        (
            f"/notebooks/{COMPETITION_NAME}/{COMPETITION_NAME}/exp",
            "Jupyter Lab!",
            f"../../{base_path}",
        ),
        (
            f"/root/{COMPETITION_NAME}/exp/reranker",
            "VastAi! Reranker",
            f"../../{base_path}",
        ),
        (
            f"/root/{COMPETITION_NAME}/exp/retriever",
            "VastAi! Retriever",
            f"../../{base_path}",
        ),
        (f"/root/{COMPETITION_NAME}", "VastAi!", base_path),
    )

    # First matching location wins, mirroring an if/elif chain.
    for workdir, label, resolved in known_environments:
        if here == workdir:
            print(label)
            return resolved

    raise Exception("Unknown environment")


# Rewrite every configured path for the current machine and echo the result.
# NOTE: this cell is not idempotent. In the environments where resolve_path()
# adds a prefix, running it a second time prefixes the already-resolved values
# again; re-run the configuration cell first if this cell must be repeated.
DATA_PATH = resolve_path(DATA_PATH)
print(DATA_PATH)
OUTPUT_PATH = resolve_path(OUTPUT_PATH)
print(OUTPUT_PATH)
MODEL_OUTPUT_PATH = resolve_path(MODEL_OUTPUT_PATH)
print(MODEL_OUTPUT_PATH)
ENV_PATH = resolve_path(ENV_PATH)
print(ENV_PATH)
RETRIVED_FILE_PATH = resolve_path(RETRIVED_FILE_PATH)
print(RETRIVED_FILE_PATH)

# Import

In [3]:
import os
import numpy as np

from datasets import load_dataset, Dataset

import wandb
import polars as pl

from sklearn.metrics.pairwise import cosine_similarity

from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator

In [4]:
import sentence_transformers

assert sentence_transformers.__version__ == "3.1.0"

In [5]:
NUM_PROC = 16

# Setting

In [6]:
from dotenv import load_dotenv

load_dotenv(f"{ENV_PATH}/.env")

True

# WANDB

In [7]:
# Choose the experiment tracker; REPORT_TO is later passed to the training arguments.
if WANDB:
    # The API key is read from the environment (presumably supplied by the
    # .env file loaded above); a missing key raises KeyError here.
    wandb.login(key=os.environ["WANDB_API_KEY"])
    wandb.init(project=COMPETITION_NAME, name=EXP_NAME)
    REPORT_TO = "wandb"
else:
    REPORT_TO = "none"

REPORT_TO

'wandb'

# Data Load

In [8]:
# Load the retrieved-candidates CSV and keep only rows usable as training triplets.
train = (
    load_dataset(
        "csv",
        data_files=RETRIVED_FILE_PATH,
        split="train",
    )
    # Keep only rows that have a ground-truth misconception.
    .filter(
        lambda example: example["MisconceptionId"] is not None,
        num_proc=NUM_PROC,
    )
    .filter(  # Drop rows whose retrieved candidate equals the ground truth, so every row is a valid (anchor, positive, negative) triplet
        lambda example: example["MisconceptionId"] != example["PredictMisconceptionId"],
        num_proc=NUM_PROC,
    )
)

Generating train split: 0 examples [00:00, ? examples/s]

Filter (num_proc=16):   0%|          | 0/218500 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/218500 [00:00<?, ? examples/s]

In [9]:
train

Dataset({
    features: ['QuestionId', 'ConstructName', 'SubjectName', 'QuestionText', 'CorrectAnswer', 'AnswerType', 'AnswerText', 'AllText', 'AnswerAlphabet', 'QuestionId_Answer', 'MisconceptionId', 'PredictMisconceptionId', 'target', 'MisconceptionName', 'PredictMisconceptionName'],
    num_rows: 215622
})

In [10]:
# # Reduce the positive examples that share the same MisconceptionName from 25 to 5
# train = Dataset.from_pandas(
#     train.to_pandas().groupby("QuestionId_Answer").head(TRAIN_USE_NUM)
# )

In [11]:
if DEBUG:
    train = train.select(range(1000))

In [12]:
print(train["AllText"][0])

In [13]:
print(train["MisconceptionName"][0])

In [14]:
print(train["PredictMisconceptionName"][0])

# Train Valid Split

In [15]:
# Split by question so all rows of one question land on the same side:
# QuestionId divisible by 3 -> validation, everything else -> training.
# NOTE: `train` is overwritten here. Re-running this cell without reloading
# the data yields an empty `valid`, because the validation rows have already
# been removed from `train`.
train, valid = (
    train.filter(lambda example: example["QuestionId"] % 3 != 0, num_proc=NUM_PROC),
    train.filter(lambda example: example["QuestionId"] % 3 == 0, num_proc=NUM_PROC),
)

Filter (num_proc=16):   0%|          | 0/215622 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/215622 [00:00<?, ? examples/s]

In [16]:
print(train)
print(valid)

# Fine-tuning BGE

In [17]:
# Either fine-tune the base model on the triplets (TRAINING=True) or load an
# already fine-tuned model from MODEL_OUTPUT_PATH (TRAINING=False).
if TRAINING:
    model = SentenceTransformer(MODEL_NAME)

    loss = MultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=OUTPUT_PATH,
        # Optional training parameters:
        num_train_epochs=EPOCH,
        per_device_train_batch_size=BS,
        gradient_accumulation_steps=GRAD_ACC_STEP,
        per_device_eval_batch_size=BS,
        eval_accumulation_steps=GRAD_ACC_STEP,
        learning_rate=LR,
        weight_decay=0.01,
        warmup_ratio=0.1,
        fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
        fp16_full_eval=True,
        bf16=False,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        # Optional tracking/debugging parameters:
        lr_scheduler_type="cosine_with_restarts",
        eval_strategy="steps",
        # A value below 1 is a fraction of the total training steps, i.e.
        # evaluate and checkpoint every 10% of the run.
        eval_steps=0.1,
        save_strategy="steps",
        save_steps=0.1,
        save_total_limit=2,
        logging_steps=100,
        report_to=REPORT_TO,  # Will be used in W&B if `wandb` is installed
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )

    # Triplet accuracy on the validation split: anchor = question/answer text,
    # positive = true misconception, negative = retrieved wrong misconception.
    dev_evaluator = TripletEvaluator(
        anchors=valid["AllText"],
        positives=valid["MisconceptionName"],
        negatives=valid["PredictMisconceptionName"],
        name=f"{MODEL_NAME}-dev",
    )
    # Evaluate the untrained base model once; the returned metrics are not stored.
    dev_evaluator(model)

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        # Only the three text columns are passed, in (anchor, positive, negative)
        # order; the loss consumes the columns by position.
        train_dataset=train.select_columns(
            ["AllText", "MisconceptionName", "PredictMisconceptionName"]
        ),
        eval_dataset=valid.select_columns(
            ["AllText", "MisconceptionName", "PredictMisconceptionName"]
        ),
        loss=loss,
        evaluator=dev_evaluator,
    )

    trainer.train()
    # NOTE(review): MODEL_OUTPUT_PATH is configured as a "checkpoint-5605"
    # directory, so a fresh training run writes its final model there — confirm
    # this is intended before setting TRAINING = True.
    model.save_pretrained(MODEL_OUTPUT_PATH)
else:
    model = SentenceTransformer(MODEL_OUTPUT_PATH)

Step,Training Loss,Validation Loss,Baai/bge-large-en-v1.5-dev Cosine Accuracy,Baai/bge-large-en-v1.5-dev Dot Accuracy,Baai/bge-large-en-v1.5-dev Manhattan Accuracy,Baai/bge-large-en-v1.5-dev Euclidean Accuracy,Baai/bge-large-en-v1.5-dev Max Accuracy
561,1.2962,1.604903,0.747987,0.251999,0.748084,0.747987,0.748084
1122,1.1324,1.736033,0.805153,0.194833,0.804709,0.805153,0.805153
1683,0.9122,1.436146,0.839497,0.160489,0.838539,0.839497,0.839497
2244,0.7754,1.612417,0.847285,0.152701,0.84766,0.847285,0.84766
2805,0.5102,1.473543,0.845758,0.154228,0.845827,0.845758,0.845827
3366,0.5521,1.537305,0.850339,0.149647,0.849811,0.850339,0.850339
3927,0.3669,1.570691,0.845286,0.1547,0.845286,0.845286,0.845286
4488,0.3102,1.529935,0.862596,0.137376,0.862985,0.862596,0.862985
5049,0.2201,1.584729,0.852074,0.147912,0.851685,0.852074,0.852074


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

# Make Vector

In [18]:
# Question-level columns that are repeated on every answer row.
common_col = [
    "QuestionId",
    "ConstructName",
    "SubjectName",
    "QuestionText",
    "CorrectAnswer",
]

# Reshape train.csv from one row per question to one row per (question, answer).
train_long = (
    pl.read_csv(f"{DATA_PATH}/train.csv")
    .select(
        pl.col(common_col + [f"Answer{alpha}Text" for alpha in ["A", "B", "C", "D"]])
    )
    # Wide -> long: the four Answer?Text columns become (AnswerType, AnswerText) rows.
    .unpivot(
        index=common_col,
        variable_name="AnswerType",
        value_name="AnswerText",
    )
    .with_columns(
        # Query text that gets embedded: construct, subject, question and answer joined by spaces.
        pl.concat_str(
            [
                pl.col("ConstructName"),
                pl.col("SubjectName"),
                pl.col("QuestionText"),
                pl.col("AnswerText"),
            ],
            separator=" ",
        ).alias("AllText"),
        # "AnswerAText" -> "A", etc.
        pl.col("AnswerType").str.extract(r"Answer([A-D])Text$").alias("AnswerAlphabet"),
    )
    .with_columns(
        # Row key such as "0_A"; it is the join key for the misconception labels.
        pl.concat_str(
            [pl.col("QuestionId"), pl.col("AnswerAlphabet")], separator="_"
        ).alias("QuestionId_Answer"),
    )
    # The key is a string, so this ordering is lexicographic ("0_D" precedes "1000_A").
    .sort("QuestionId_Answer")
)
train_long.head()

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""0_A"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""0_B"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""0_C"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D"""
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerAText""","""\( t \)""","""Simplify an algebraic fraction…","""A""","""1000_A"""


In [19]:
# Long-format ground truth: one (QuestionId_Answer, MisconceptionId) row per
# answer option, built from the Misconception?Id columns of train.csv.
train_misconception_long = (
    pl.read_csv(f"{DATA_PATH}/train.csv")
    .select(
        pl.col(
            common_col + [f"Misconception{alpha}Id" for alpha in ["A", "B", "C", "D"]]
        )
    )
    # Wide -> long: the four Misconception?Id columns become (MisconceptionType, MisconceptionId) rows.
    .unpivot(
        index=common_col,
        variable_name="MisconceptionType",
        value_name="MisconceptionId",
    )
    .with_columns(
        # "MisconceptionAId" -> "A", etc.
        pl.col("MisconceptionType")
        .str.extract(r"Misconception([A-D])Id$")
        .alias("AnswerAlphabet"),
    )
    .with_columns(
        # Same "<QuestionId>_<letter>" key as in train_long.
        pl.concat_str(
            [pl.col("QuestionId"), pl.col("AnswerAlphabet")], separator="_"
        ).alias("QuestionId_Answer"),
    )
    .sort("QuestionId_Answer")
    .select(pl.col(["QuestionId_Answer", "MisconceptionId"]))
    # Store ids as integers; answers without a misconception stay null.
    # (Presumably the CSV column is read as float because of the missing values — TODO confirm.)
    .with_columns(pl.col("MisconceptionId").cast(pl.Int64))
)
train_misconception_long.head()

QuestionId_Answer,MisconceptionId
str,i64
"""0_A""",
"""0_B""",
"""0_C""",
"""0_D""",1672.0
"""1000_A""",891.0


In [20]:
# join MisconceptionId
train_long = train_long.join(train_misconception_long, on="QuestionId_Answer")

In [21]:
valid_long = train_long.filter(
    pl.col("QuestionId_Answer").is_in(set(valid["QuestionId_Answer"]))
)

In [22]:
valid_long

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId
i64,str,str,str,str,str,str,str,str,str,i64
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672
1002,"""Convert fractions less than 1 …","""Converting between Fractions a…","""Convert \( \frac{3}{20} \) int…","""A""","""AnswerBText""","""\( 0.3 \)""","""Convert fractions less than 1 …","""B""","""1002_B""",1715
1002,"""Convert fractions less than 1 …","""Converting between Fractions a…","""Convert \( \frac{3}{20} \) int…","""A""","""AnswerCText""","""\( 3.20 \)""","""Convert fractions less than 1 …","""C""","""1002_C""",2308
1005,"""Convert fractions less than 1 …","""Converting between Fractions a…","""What is \( \frac{3}{5} \) as a…","""D""","""AnswerAText""","""\( 0.5 \)""","""Convert fractions less than 1 …","""A""","""1005_A""",760
1005,"""Convert fractions less than 1 …","""Converting between Fractions a…","""What is \( \frac{3}{5} \) as a…","""D""","""AnswerBText""","""\( 0.3 \)""","""Convert fractions less than 1 …","""B""","""1005_B""",1715
…,…,…,…,…,…,…,…,…,…,…
99,"""Given the perimeter, work out …","""Volume and Capacity Units""","""A regular pentagon has a total…","""C""","""AnswerBText""","""\( 12 \mathrm{~cm} \)""","""Given the perimeter, work out …","""B""","""99_B""",1815
99,"""Given the perimeter, work out …","""Volume and Capacity Units""","""A regular pentagon has a total…","""C""","""AnswerDText""","""Not enough information""","""Given the perimeter, work out …","""D""","""99_D""",255
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerAText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""A""","""9_A""",1889
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerBText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""B""","""9_B""",1234


In [23]:
# Embed the validation queries; row i of valid_long_vec corresponds to row i of valid_long.
# Embeddings are L2-normalised by the encoder (normalize_embeddings=True).
valid_long_vec = model.encode(
    valid_long["AllText"].to_list(), normalize_embeddings=True
)
# Embed every misconception name; row i corresponds to row i of misconception_mapping.
misconception_mapping = pl.read_csv(f"{DATA_PATH}/misconception_mapping.csv")
misconception_mapping_vec = model.encode(
    misconception_mapping["MisconceptionName"].to_list(), normalize_embeddings=True
)
print(valid_long_vec.shape)
print(misconception_mapping_vec.shape)

In [24]:
# Save the misconception embeddings to disk (the output directory is created if missing).
os.makedirs(OUTPUT_PATH, exist_ok=True)
np.save(f"{OUTPUT_PATH}/misconception_mapping_vec.npy", misconception_mapping_vec)

In [25]:
# Similarity of every validation query (rows) to every misconception (columns).
valid_cos_sim_arr = cosine_similarity(valid_long_vec, misconception_mapping_vec)
# argsort is ascending, so negate to get, per query, misconception row indices
# ordered from most to least similar.
valid_sorted_indices = np.argsort(-valid_cos_sim_arr, axis=1)

In [26]:
# example
def print_example(df: pl.DataFrame, sorted_indices: np.ndarray, idx: int) -> None:
    """Print one query and the names of its two most similar misconceptions.

    Args:
        df: Frame whose "AllText" column holds the query texts.
        sorted_indices: Per-query ranking of row indices into the module-level
            ``misconception_mapping`` frame, best match first.
        idx: Row of ``df`` / ``sorted_indices`` to display.
    """
    print(f"Query idx{idx}")
    print(df["AllText"][idx])
    for rank in (1, 2):
        print(f"\nCos Sim No.{rank}")
        neighbour_row = int(sorted_indices[idx, rank - 1])
        print(misconception_mapping["MisconceptionName"][neighbour_row])

In [27]:
# valid_sorted_indices was computed from valid_long["AllText"], so its rows line
# up with valid_long (not train_long); read the query text from the same frame.
print_example(valid_long, valid_sorted_indices, 0)

In [28]:
# valid_sorted_indices was computed from valid_long["AllText"], so its rows line
# up with valid_long (not train_long); read the query text from the same frame.
print_example(valid_long, valid_sorted_indices, 1)

# Evaluate

In [29]:
# Store, per query, the RETRIEVE_NUM best-ranked candidates as a list column.
# NOTE(review): the values are row positions in misconception_mapping, and the
# cells below compare and join them as MisconceptionIds. This assumes row i of
# misconception_mapping.csv has MisconceptionId == i — TODO confirm.
valid_long = valid_long.with_columns(
    pl.Series(valid_sorted_indices[:, :RETRIEVE_NUM].tolist()).alias(
        "PredictMisconceptionId"
    )
)

In [30]:
# https://www.kaggle.com/code/cdeotte/how-to-train-open-book-model-part-1#MAP@3-Metric
def map_at_25(predictions, labels, k: int = 25):
    """Mean average precision at ``k`` with a single correct label per query.

    For each query the score is the reciprocal of the rank (1-based) at which
    the label appears within the first ``k`` predictions, or 0 if it is not
    among them. The result is the mean of these scores over all queries.

    Args:
        predictions: Sequence of ranked candidate-id sequences, best first.
        labels: Sequence of ground-truth ids, aligned with ``predictions``.
        k: Cutoff rank. Defaults to 25, the metric this function is named after.

    Returns:
        The mean score, or 0.0 when ``predictions`` is empty.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    n_queries = len(predictions)
    if n_queries == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    map_sum = 0.0
    for ranked, truth in zip(predictions, labels):
        # zip() stops at the shorter input, so fewer than k candidates is fine.
        map_sum += np.sum(
            [1 / rank for rank, cand in zip(range(1, k + 1), ranked) if truth == cand]
        )
    return map_sum / n_queries

In [31]:
map_at_25_score = map_at_25(
    valid_long.filter(pl.col("MisconceptionId").is_not_null())[
        "PredictMisconceptionId"
    ],
    valid_long.filter(pl.col("MisconceptionId").is_not_null())["MisconceptionId"],
)
map_at_25_score

0.2886289975607982

In [32]:
def recall(predictions, labels):
    """Fraction of queries whose label appears among its predictions.

    Args:
        predictions: Sequence of candidate-id collections, one per query.
        labels: Sequence of ground-truth ids, aligned with ``predictions``.

    Returns:
        Number of queries with a hit divided by the number of queries.
    """
    hits = []
    for retrieved, truth in zip(predictions, labels):
        if truth in retrieved:
            hits.append(1)
    return np.sum(hits) / len(predictions)


recall_score = recall(
    valid_long.filter(pl.col("MisconceptionId").is_not_null())[
        "PredictMisconceptionId"
    ],
    valid_long.filter(pl.col("MisconceptionId").is_not_null())["MisconceptionId"],
)
recall_score

0.6664383561643835

In [33]:
# Save the validation scores as a one-line text file next to the other outputs.
with open(f"{OUTPUT_PATH}/cv_score.txt", "w") as f:
    f.write(f"MAP@25:{map_at_25_score:.4f}, Recall:{recall_score:.4f}")

# Make Retrieved Train File

In [34]:
# Build the retrieved-candidates table for the validation split:
# one row per (query, retrieved candidate) pair.
valid_retrieved = (
    # Only queries with a ground-truth misconception are kept.
    valid_long.filter(pl.col("MisconceptionId").is_not_null())
    # One row per element of the PredictMisconceptionId list.
    .explode("PredictMisconceptionId")
    .with_columns(
        # target = 1 when the retrieved candidate is the ground truth, else 0.
        (pl.col("MisconceptionId") == pl.col("PredictMisconceptionId"))
        .cast(pl.Int64)
        .alias("target")
    )
    # Attach the candidate's name; every column of misconception_mapping is
    # prefixed with "Predict" so the join key matches PredictMisconceptionId.
    # NOTE(review): unlike the CSV loaded for training, this table gets no
    # MisconceptionName column for the ground truth — add one if downstream
    # code expects the same schema.
    .join(
        misconception_mapping.rename(lambda x: "Predict" + x),
        on="PredictMisconceptionId",
    )
)
valid_retrieved.head()

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId,PredictMisconceptionId,target,PredictMisconceptionName
i64,str,str,str,str,str,str,str,str,str,i64,i64,i64,str
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,638,0,"""Believes multiplication is not…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,1361,0,"""Does not recognise that additi…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,1336,0,"""Does not realise addition is c…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,1180,0,"""Does not know the properties o…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,1929,0,"""Does not understand the concep…"


In [35]:
valid_retrieved.write_csv(
    f"{OUTPUT_PATH}/{EXP_NAME}-valid-ret{RETRIEVE_NUM}-map{map_at_25_score:.4f}-recall{recall_score:.4f}.csv",
)

In [36]:
print(
    f"{OUTPUT_PATH}/{EXP_NAME}-valid-ret{RETRIEVE_NUM}-map{map_at_25_score:.4f}-recall{recall_score:.4f}.csv",
)

# Kaggle Upload

In [37]:
import os
import json

from kaggle.api.kaggle_api_extended import KaggleApi


def dataset_create_new(dataset_name: str, upload_dir: str):
    """Create a new Kaggle dataset from the contents of ``upload_dir``.

    Writes ``dataset-metadata.json`` into ``upload_dir`` and then uploads the
    folder through the Kaggle API.

    Args:
        dataset_name: Used both as the dataset slug (under the "sinchir0"
            account) and as its title.
        upload_dir: Folder to upload; the metadata file is written into it.
    """
    # NOTE: a guard that rejected "_" in dataset_name used to sit here and is disabled.
    metadata = {
        "id": f"sinchir0/{dataset_name}",
        "licenses": [{"name": "CC0-1.0"}],
        "title": dataset_name,
    }
    metadata_path = os.path.join(upload_dir, "dataset-metadata.json")
    with open(metadata_path, "w") as fh:
        json.dump(metadata, fh, indent=4)

    client = KaggleApi()
    client.authenticate()
    client.dataset_create_new(folder=upload_dir, convert_to_csv=False, dir_mode="tar")


print(f"Create Dataset name:{DATASET_NAME}, output_dir:{OUTPUT_PATH}")
dataset_create_new(dataset_name=DATASET_NAME, upload_dir=OUTPUT_PATH)