In [1]:
from pathlib import Path

import numpy as np
import polars as pl
import seaborn as sns
from hydra import compose, initialize
from matplotlib import pyplot as plt


pl.Config.set_fmt_str_lengths(100)

with initialize(config_path="config", version_base=None):
    cfg = compose(config_name="config")
    cfg.exp_number = Path().resolve().name


In [2]:
EXP_NAME = "fine-tuning-bge"
DATA_PATH = "/kaggle/input/eedi-mining-misconceptions-in-mathematics"
MODEL_NAME = "BAAI/bge-large-en-v1.5"
COMPETITION_NAME = "eedi-mining-misconceptions-in-mathematics"
OUTPUT_PATH = "."
MODEL_OUTPUT_PATH = f"{OUTPUT_PATH}/trained_model"
RETRIEVE_NUM = 25

EPOCH = 2
LR = 2e-05
BS = 8
GRAD_ACC_STEP = 128 // BS

TRAINING = True
DEBUG = False
WANDB = False


In [3]:
import os
import numpy as np

from datasets import load_dataset, Dataset

import polars as pl

from sklearn.metrics.pairwise import cosine_similarity

from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import datasets
import sentence_transformers


In [5]:
NUM_PROC = os.cpu_count()
NUM_PROC


32

In [6]:
if WANDB:
    # Settings -> add wandb api
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    wandb.login(key=user_secrets.get_secret("wandb_api"))
    wandb.init(project=COMPETITION_NAME, name=EXP_NAME)
    REPORT_TO = "wandb"
else:
    REPORT_TO = "none"

REPORT_TO


'none'

In [7]:
train = pl.read_csv(cfg.path.train, try_parse_dates=True)
test = pl.read_csv(cfg.path.test, try_parse_dates=True)
misconception_mapping = pl.read_csv(cfg.path.misconception_mapping, try_parse_dates=True)


In [8]:
common_col = [
    "QuestionId",
    "ConstructName",
    "SubjectName",
    "QuestionText",
    "CorrectAnswer",
]

train_long = (
    train
    .select(
        pl.col(common_col + [f"Answer{alpha}Text" for alpha in ["A", "B", "C", "D"]])
    )
    .unpivot(
        index=common_col,
        variable_name="AnswerType",
        value_name="AnswerText",
    )
    .with_columns(
        pl.concat_str(
            [
                pl.col("ConstructName"),
                pl.col("SubjectName"),
                pl.col("QuestionText"),
                pl.col("AnswerText"),
            ],
            separator=" ",
        ).alias("AllText"),
        pl.col("AnswerType").str.extract(r"Answer([A-D])Text$").alias("AnswerAlphabet"),
    )
    .with_columns(
        pl.concat_str(
            [pl.col("QuestionId"), pl.col("AnswerAlphabet")], separator="_"
        ).alias("QuestionId_Answer"),
    )
    .sort("QuestionId_Answer")
)
train_long.head()


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
0,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to carry out calculations involving powers BIDMAS \[ 3 \times 2+4-5 \] W…","""A""","""0_A"""
0,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to carry out calculations involving powers BIDMAS \[ 3 \times 2+4-5 \] W…","""B""","""0_B"""
0,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to carry out calculations involving powers BIDMAS \[ 3 \times 2+4-5 \] W…","""C""","""0_C"""
0,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to carry out calculations involving powers BIDMAS \[ 3 \times 2+4-5 \] W…","""D""","""0_D"""
1000,"""Simplify an algebraic fraction by factorising the numerator""","""Simplifying Algebraic Fractions""","""Simplify the following, if possible: \( \frac{1-t}{t-1} \)""","""B""","""AnswerAText""","""\( t \)""","""Simplify an algebraic fraction by factorising the numerator Simplifying Algebraic Fractions Simplify…","""A""","""1000_A"""


In [9]:
train_misconception_long = (
    train.select(
        pl.col(
            common_col + [f"Misconception{alpha}Id" for alpha in ["A", "B", "C", "D"]]
        )
    )
    .unpivot(
        index=common_col,
        variable_name="MisconceptionType",
        value_name="MisconceptionId",
    )
    .with_columns(
        pl.col("MisconceptionType")
        .str.extract(r"Misconception([A-D])Id$")
        .alias("AnswerAlphabet"),
    )
    .with_columns(
        pl.concat_str(
            [pl.col("QuestionId"), pl.col("AnswerAlphabet")], separator="_"
        ).alias("QuestionId_Answer"),
    )
    .sort("QuestionId_Answer")
    .select(pl.col(["QuestionId_Answer", "MisconceptionId"]))
    .with_columns(pl.col("MisconceptionId").cast(pl.Int64))
)

train_misconception_long.head()


QuestionId_Answer,MisconceptionId
str,i64
"""0_A""",
"""0_B""",
"""0_C""",
"""0_D""",1672.0
"""1000_A""",891.0


In [10]:
train_long = train_long.join(train_misconception_long, on="QuestionId_Answer")
train_long.head()


QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId
i64,str,str,str,str,str,str,str,str,str,i64
0,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to carry out calculations involving powers BIDMAS \[ 3 \times 2+4-5 \] W…","""A""","""0_A""",
0,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to carry out calculations involving powers BIDMAS \[ 3 \times 2+4-5 \] W…","""B""","""0_B""",
0,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to carry out calculations involving powers BIDMAS \[ 3 \times 2+4-5 \] W…","""C""","""0_C""",
0,"""Use the order of operations to carry out calculations involving powers""","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do the brackets need to go to make the answer equal \( 13 \) ?""","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to carry out calculations involving powers BIDMAS \[ 3 \times 2+4-5 \] W…","""D""","""0_D""",1672.0
1000,"""Simplify an algebraic fraction by factorising the numerator""","""Simplifying Algebraic Fractions""","""Simplify the following, if possible: \( \frac{1-t}{t-1} \)""","""B""","""AnswerAText""","""\( t \)""","""Simplify an algebraic fraction by factorising the numerator Simplifying Algebraic Fractions Simplify…","""A""","""1000_A""",891.0


In [11]:
model = SentenceTransformer(MODEL_NAME)

train_long_vec = model.encode(
    train_long["AllText"].to_list(), normalize_embeddings=True
)
misconception_mapping_vec = model.encode(
    misconception_mapping["MisconceptionName"].to_list(), normalize_embeddings=True
)
print(train_long_vec.shape)
print(misconception_mapping_vec.shape)


(7476, 1024)
(2587, 1024)


In [12]:
train_cos_sim_arr = cosine_similarity(train_long_vec, misconception_mapping_vec)
train_sorted_indices = np.argsort(-train_cos_sim_arr, axis=1)


In [13]:
train_long = train_long.with_columns(
    pl.Series(train_sorted_indices[:, :RETRIEVE_NUM].tolist()).alias(
        "PredictMisconceptionId"
    )
)


In [14]:
train_retrieved = (
    train_long.filter(
        pl.col(
            "MisconceptionId"
        ).is_not_null()  # TODO: Consider ways to utilize data where MisconceptionId is NaN.
    )
    .explode("PredictMisconceptionId")
    .join(
        misconception_mapping,
        on="MisconceptionId",
    )
    .join(
        misconception_mapping.rename(lambda x: "Predict" + x),
        on="PredictMisconceptionId",
    )
)
train_retrieved.shape


(109250, 14)

In [15]:
train = (
    Dataset.from_polars(train_retrieved)
    .filter(  # To create an anchor, positive, and negative structure, delete rows where the positive and negative are identical.
        lambda example: example["MisconceptionId"] != example["PredictMisconceptionId"],
        num_proc=NUM_PROC,
    )
)


Filter (num_proc=32): 100%|██████████| 109250/109250 [00:08<00:00, 12527.15 examples/s]


In [18]:
model = SentenceTransformer(MODEL_NAME)

loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    output_dir=OUTPUT_PATH,
    num_train_epochs=EPOCH,
    per_device_train_batch_size=BS,
    gradient_accumulation_steps=GRAD_ACC_STEP,
    per_device_eval_batch_size=BS,
    eval_accumulation_steps=GRAD_ACC_STEP,
    learning_rate=LR,
    weight_decay=0.01,
    warmup_ratio=0.1,
    fp16=False,  # 一旦無効化
    bf16=False,  # 一旦無効化
    batch_sampler=None,  # サンプラーを無効化
    lr_scheduler_type="cosine_with_restarts",
    save_strategy="steps",
    save_steps=10,  # `0.1`を具体的な整数に変更
    save_total_limit=2,
    logging_steps=100,
    report_to=REPORT_TO,
    run_name=EXP_NAME,
    do_eval=False
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train.select_columns(
        ["AllText", "MisconceptionName", "PredictMisconceptionName"]
    ),
    loss=loss
)

trainer.train()
model.save_pretrained(MODEL_OUTPUT_PATH)


ValueError: None is not a valid BatchSamplers, please select one of ['batch_sampler', 'no_duplicates', 'group_by_label']