# 目的
BAAI/bge-en-iclをlora fine-tuningする

ref: https://sbert.net/docs/sentence_transformer/training_overview.html#trainer

lora ft: https://github.com/UKPLab/sentence-transformers/issues/2748#issuecomment-2212508370

In [1]:
!nvidia-smi

Sun Nov 10 18:09:05 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090 Ti     On  |   00000000:01:00.0 Off |                  Off |
|  0%   37C    P8             17W /  450W |       1MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                     

# Setting

In [2]:
EXP_NAME = "e019-ret-ilu-100k"
DATA_PATH = "data"
# MODEL_NAME = "nvidia/NV-Embed-v2" # NotImplementedErrorになる
MODEL_NAME = "BAAI/bge-en-icl"
# MODEL_NAME = "Salesforce/SFR-Embedding-2_R"
# MODEL_NAME = "BAAI/bge-large-en-v1.5"
ENV_PATH = "env_file"
COMPETITION_NAME = "eedi-mining-misconceptions-in-mathematics"
# RETRIVED_FILE_PATH = "output/retriever/e003-ret-bge-ret5/e003-ret-bge-ret5-ret5-map0.1589-recall0.2831.csv"
RETRIVED_FILE_PATH = (
    "output/retriever/e003-ret-bge/e003-ret-bge-ret25-map0.1841-recall0.5506.csv"
)
# RETRIVED_EX_FILE_PATH = "output/retriever/e013-ret-bge-prepare/e013-ret-bge-prepare-ret25-map0.4374-recall0.8410.csv"

DATASET_NAME = EXP_NAME
OUTPUT_PATH = f"output/retriever/{EXP_NAME}"
# MODEL_OUTPUT_PATH = f"{OUTPUT_PATH}/trained_model"
MODEL_OUTPUT_PATH = f"{OUTPUT_PATH}/checkpoint-5010"

TRAIN_USE_NUM = 5
RETRIEVE_NUM = 25  # TODO: 多くしてみる

EPOCH = 5
BS = 32  # 8
GRAD_ACC_STEP = 1  # bugがあるので1にする。# 128 // BS
LR = 2e-5

TRAINING = True
DEBUG = False
WANDB = True

In [3]:
def resolve_path(base_path: str) -> str:
    import os

    cwd = os.getcwd()
    print(cwd)
    if cwd == f"/notebooks":
        print("Jupyter Kernel By VSCode!")
        return f"/notebooks/{COMPETITION_NAME}/{base_path}"
    elif cwd == f"/notebooks/{COMPETITION_NAME}":
        print("nohup!")
        return base_path
    elif cwd == f"/notebooks/{COMPETITION_NAME}/{COMPETITION_NAME}/exp":
        print("Jupyter Lab!")
        return f"../../{base_path}"
    elif cwd == f"/root/{COMPETITION_NAME}/exp/reranker":
        print("VastAi! Reranker")
        return f"../../{base_path}"
    elif cwd == f"/root/{COMPETITION_NAME}/exp/retriever":
        print("VastAi! Retriever")
        return f"../../{base_path}"
    elif cwd == f"/root/{COMPETITION_NAME}":
        print("VastAi!")
        return base_path
    else:
        raise Exception("Unknown environment")


DATA_PATH = resolve_path(DATA_PATH)
print(DATA_PATH)
OUTPUT_PATH = resolve_path(OUTPUT_PATH)
print(OUTPUT_PATH)
MODEL_OUTPUT_PATH = resolve_path(MODEL_OUTPUT_PATH)
print(MODEL_OUTPUT_PATH)
ENV_PATH = resolve_path(ENV_PATH)
print(ENV_PATH)
RETRIVED_FILE_PATH = resolve_path(RETRIVED_FILE_PATH)
print(RETRIVED_FILE_PATH)
# RETRIVED_EX_FILE_PATH = resolve_path(RETRIVED_EX_FILE_PATH)
# print(RETRIVED_EX_FILE_PATH)

# Import

In [4]:
import os
import numpy as np

from datasets import load_dataset

import wandb
import polars as pl

import torch
from transformers import BitsAndBytesConfig
from peft import (
    prepare_model_for_kbit_training,
    get_peft_model,
    LoraConfig,
    TaskType,
    get_peft_model,
)

from sklearn.metrics.pairwise import cosine_similarity

from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator

In [5]:
import transformers
import sentence_transformers

assert transformers.__version__ == "4.42.4"
assert sentence_transformers.__version__ == "3.1.0"

In [6]:
NUM_PROC = 16

# Setting

In [7]:
from dotenv import load_dotenv

load_dotenv(f"{ENV_PATH}/.env")

True

# WANDB

In [8]:
if WANDB:
    wandb.login(key=os.environ["WANDB_API_KEY"])
    wandb.init(project=COMPETITION_NAME, name=EXP_NAME)
    REPORT_TO = "wandb"
else:
    REPORT_TO = "none"

REPORT_TO

'wandb'

# Data Load

In [9]:
train = (
    load_dataset(
        "csv",
        data_files=RETRIVED_FILE_PATH,
        split="train",
    )
    .filter(
        lambda example: example["MisconceptionId"] is not None,
        num_proc=NUM_PROC,
    )
    .filter(  # anchor、positive、negativeの構成にするため、positiveとnegativeが一致している行を削除する
        lambda example: example["MisconceptionId"] != example["PredictMisconceptionId"],
        num_proc=NUM_PROC,
    )
)

Generating train split: 0 examples [00:00, ? examples/s]

Filter (num_proc=16):   0%|          | 0/109250 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/109250 [00:00<?, ? examples/s]

In [10]:
train

Dataset({
    features: ['QuestionId', 'ConstructName', 'SubjectName', 'QuestionText', 'CorrectAnswer', 'AnswerType', 'AnswerText', 'AllText', 'AnswerAlphabet', 'QuestionId_Answer', 'MisconceptionId', 'PredictMisconceptionId', 'target', 'MisconceptionName', 'PredictMisconceptionName'],
    num_rows: 106844
})

In [11]:
# # 同じMisconceptionNameのpositiveの例を25個→5個に減らす
# train = Dataset.from_pandas(
#     train.to_pandas().groupby("QuestionId_Answer").head(TRAIN_USE_NUM)
# )

In [12]:
if DEBUG:
    train = train.select(range(1000))

In [13]:
print(train["AllText"][0])

In [14]:
print(train["MisconceptionName"][0])

In [15]:
print(train["PredictMisconceptionName"][0])

In [16]:
# # load ex data
# train_ex = (
#     load_dataset(
#         "csv",
#         data_files=RETRIVED_EX_FILE_PATH,
#         split="train",
#     )
#     .filter(  # Nameを正しく生成できず、Idが紐づかなかったデータを落とす
#         lambda example: example["MisconceptionId"] is not None,
#         num_proc=NUM_PROC,
#     )
#     .filter(  # anchor、positive、negativeの構成にするため、positiveとnegativeが一致している行を削除する
#         lambda example: example["MisconceptionId"] != example["PredictMisconceptionId"],
#         num_proc=NUM_PROC,
#     )
# )

# train_ex

# Train Valid Split

In [17]:
train, valid = (
    train.filter(lambda example: example["QuestionId"] % 3 != 0, num_proc=NUM_PROC),
    train.filter(lambda example: example["QuestionId"] % 3 == 0, num_proc=NUM_PROC),
)

Filter (num_proc=16):   0%|          | 0/106844 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/106844 [00:00<?, ? examples/s]

In [18]:
print(train)
print(valid)

In [19]:
# from datasets import concatenate_datasets

# train = concatenate_datasets([train, train_ex])

In [20]:
# print(train)
# print(valid)

# Fine-tuning

In [21]:
if TRAINING:
    config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        # bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = SentenceTransformer(
        MODEL_NAME, trust_remote_code=True, model_kwargs={"quantization_config": config}
    )
    model.to("cuda:0")

    model[0].auto_model = prepare_model_for_kbit_training(model[0].auto_model)
    # loadするときは、model = PeftModel.from_pretrained(model, peft_model_id)を利用する。

    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    model[0].auto_model = get_peft_model(model[0].auto_model, peft_config)

    loss = MultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=OUTPUT_PATH,
        # Optional training parameters:
        num_train_epochs=EPOCH,
        per_device_train_batch_size=BS,  # 16,
        gradient_accumulation_steps=GRAD_ACC_STEP,
        per_device_eval_batch_size=BS,  # 16,
        eval_accumulation_steps=GRAD_ACC_STEP,
        learning_rate=LR,
        weight_decay=0.01,
        warmup_ratio=0.1,
        fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
        fp16_full_eval=True,
        bf16=False,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        # Optional tracking/debugging parameters:
        lr_scheduler_type="cosine_with_restarts",
        eval_strategy="steps",
        eval_steps=0.1,
        save_strategy="steps",
        save_steps=0.1,
        save_total_limit=2,
        logging_steps=100,
        report_to=REPORT_TO,  # Will be used in W&B if `wandb` is installed
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )

    dev_evaluator = TripletEvaluator(
        anchors=valid["AllText"],
        positives=valid["MisconceptionName"],
        negatives=valid["PredictMisconceptionName"],
        name=f"{MODEL_NAME}-dev",
    )
    # dev_evaluator(model)

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train.select_columns(
            ["AllText", "MisconceptionName", "PredictMisconceptionName"]
        ),
        eval_dataset=valid.select_columns(
            ["AllText", "MisconceptionName", "PredictMisconceptionName"]
        ),
        loss=loss,
        evaluator=dev_evaluator,
    )

    trainer.train()
    model.save_pretrained(MODEL_OUTPUT_PATH)
else:
    model = SentenceTransformer(MODEL_OUTPUT_PATH)

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/22.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/9.89G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/10.0G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/8.56G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/1.57k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/69.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/640 [00:00<?, ?B/s]

Step,Training Loss,Validation Loss,Baai/bge-en-icl-dev Cosine Accuracy,Baai/bge-en-icl-dev Dot Accuracy,Baai/bge-en-icl-dev Manhattan Accuracy,Baai/bge-en-icl-dev Euclidean Accuracy,Baai/bge-en-icl-dev Max Accuracy
1112,0.4121,2.744061,0.637089,0.348607,0.527278,0.532764,0.637089
2224,0.3305,3.110274,0.654192,0.33839,0.50508,0.509335,0.654192
3336,0.2249,2.825666,0.674262,0.328425,0.518209,0.521512,0.674262
4448,0.2201,2.966031,0.688985,0.314654,0.509531,0.512218,0.688985
5560,0.1939,2.612731,0.74119,0.268859,0.518992,0.522547,0.74119
6672,0.1599,2.635054,0.76669,0.245486,0.526074,0.529125,0.76669
7784,0.1353,2.809931,0.76795,0.243415,0.534724,0.537159,0.76795
8896,0.046,2.853986,0.771924,0.241036,0.529769,0.532596,0.771924
10008,0.0338,2.831838,0.782449,0.232526,0.53881,0.540378,0.782449


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

# Make Vector

In [22]:
common_col = [
    "QuestionId",
    "ConstructName",
    "SubjectName",
    "QuestionText",
    "CorrectAnswer",
]

train_long = (
    pl.read_csv(f"{DATA_PATH}/train.csv")
    .select(
        pl.col(common_col + [f"Answer{alpha}Text" for alpha in ["A", "B", "C", "D"]])
    )
    .unpivot(
        index=common_col,
        variable_name="AnswerType",
        value_name="AnswerText",
    )
    .with_columns(
        pl.concat_str(
            [
                pl.col("ConstructName"),
                pl.col("SubjectName"),
                pl.col("QuestionText"),
                pl.col("AnswerText"),
            ],
            separator=" ",
        ).alias("AllText"),
        pl.col("AnswerType").str.extract(r"Answer([A-D])Text$").alias("AnswerAlphabet"),
    )
    .with_columns(
        pl.concat_str(
            [pl.col("QuestionId"), pl.col("AnswerAlphabet")], separator="_"
        ).alias("QuestionId_Answer"),
    )
    .sort("QuestionId_Answer")
)
train_long.head()

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""0_A"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""0_B"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""0_C"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D"""
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerAText""","""\( t \)""","""Simplify an algebraic fraction…","""A""","""1000_A"""


In [23]:
train_misconception_long = (
    pl.read_csv(f"{DATA_PATH}/train.csv")
    .select(
        pl.col(
            common_col + [f"Misconception{alpha}Id" for alpha in ["A", "B", "C", "D"]]
        )
    )
    .unpivot(
        index=common_col,
        variable_name="MisconceptionType",
        value_name="MisconceptionId",
    )
    .with_columns(
        pl.col("MisconceptionType")
        .str.extract(r"Misconception([A-D])Id$")
        .alias("AnswerAlphabet"),
    )
    .with_columns(
        pl.concat_str(
            [pl.col("QuestionId"), pl.col("AnswerAlphabet")], separator="_"
        ).alias("QuestionId_Answer"),
    )
    .sort("QuestionId_Answer")
    .select(pl.col(["QuestionId_Answer", "MisconceptionId"]))
    .with_columns(pl.col("MisconceptionId").cast(pl.Int64))
)
train_misconception_long.head()

QuestionId_Answer,MisconceptionId
str,i64
"""0_A""",
"""0_B""",
"""0_C""",
"""0_D""",1672.0
"""1000_A""",891.0


In [24]:
# join MisconceptionId
train_long = train_long.join(train_misconception_long, on="QuestionId_Answer")

In [25]:
valid_long = train_long.filter(
    pl.col("QuestionId_Answer").is_in(set(valid["QuestionId_Answer"]))
)

In [26]:
valid_long

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId
i64,str,str,str,str,str,str,str,str,str,i64
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672
1002,"""Convert fractions less than 1 …","""Converting between Fractions a…","""Convert \( \frac{3}{20} \) int…","""A""","""AnswerBText""","""\( 0.3 \)""","""Convert fractions less than 1 …","""B""","""1002_B""",1715
1002,"""Convert fractions less than 1 …","""Converting between Fractions a…","""Convert \( \frac{3}{20} \) int…","""A""","""AnswerCText""","""\( 3.20 \)""","""Convert fractions less than 1 …","""C""","""1002_C""",2308
1005,"""Convert fractions less than 1 …","""Converting between Fractions a…","""What is \( \frac{3}{5} \) as a…","""D""","""AnswerAText""","""\( 0.5 \)""","""Convert fractions less than 1 …","""A""","""1005_A""",760
1005,"""Convert fractions less than 1 …","""Converting between Fractions a…","""What is \( \frac{3}{5} \) as a…","""D""","""AnswerBText""","""\( 0.3 \)""","""Convert fractions less than 1 …","""B""","""1005_B""",1715
…,…,…,…,…,…,…,…,…,…,…
99,"""Given the perimeter, work out …","""Volume and Capacity Units""","""A regular pentagon has a total…","""C""","""AnswerBText""","""\( 12 \mathrm{~cm} \)""","""Given the perimeter, work out …","""B""","""99_B""",1815
99,"""Given the perimeter, work out …","""Volume and Capacity Units""","""A regular pentagon has a total…","""C""","""AnswerDText""","""Not enough information""","""Given the perimeter, work out …","""D""","""99_D""",255
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerAText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""A""","""9_A""",1889
9,"""Identify horizontal translatio…","""Transformations of functions i…","""What transformation maps the g…","""C""","""AnswerBText""","""Translation by vector \( \left…","""Identify horizontal translatio…","""B""","""9_B""",1234


In [27]:
# model = SentenceTransformer(MODEL_OUTPUT_PATH)

valid_long_vec = model.encode(
    valid_long["AllText"].to_list(), normalize_embeddings=True
)
misconception_mapping = pl.read_csv(f"{DATA_PATH}/misconception_mapping.csv")
misconception_mapping_vec = model.encode(
    misconception_mapping["MisconceptionName"].to_list(), normalize_embeddings=True
)
print(valid_long_vec.shape)
print(misconception_mapping_vec.shape)

In [28]:
# misconception_mapping_vecを保存する
os.makedirs(OUTPUT_PATH, exist_ok=True)
np.save(f"{OUTPUT_PATH}/misconception_mapping_vec.npy", misconception_mapping_vec)

In [29]:
valid_cos_sim_arr = cosine_similarity(valid_long_vec, misconception_mapping_vec)
valid_sorted_indices = np.argsort(-valid_cos_sim_arr, axis=1)

In [30]:
# example
def print_example(df: pl.DataFrame, sorted_indices: np.ndarray, idx: int) -> None:
    print(f"Query idx{idx}")
    print(df["AllText"][idx])
    print("\nCos Sim No.1")
    print(misconception_mapping["MisconceptionName"][int(sorted_indices[idx, 0])])
    print("\nCos Sim No.2")
    print(misconception_mapping["MisconceptionName"][int(sorted_indices[idx, 1])])

In [31]:
print_example(train_long, valid_sorted_indices, 0)

In [32]:
print_example(train_long, valid_sorted_indices, 1)

# Evaluate

In [33]:
valid_long = valid_long.with_columns(
    pl.Series(valid_sorted_indices[:, :RETRIEVE_NUM].tolist()).alias(
        "PredictMisconceptionId"
    )
)

In [34]:
# https://www.kaggle.com/code/cdeotte/how-to-train-open-book-model-part-1#MAP@3-Metric
def map_at_25(predictions, labels):
    map_sum = 0
    for x, y in zip(predictions, labels):
        z = [1 / i if y == j else 0 for i, j in zip(range(1, 26), x)]
        map_sum += np.sum(z)
    return map_sum / len(predictions)

In [35]:
map_at_25_score = map_at_25(
    valid_long.filter(pl.col("MisconceptionId").is_not_null())[
        "PredictMisconceptionId"
    ],
    valid_long.filter(pl.col("MisconceptionId").is_not_null())["MisconceptionId"],
)
map_at_25_score

0.18824167124198196

In [36]:
def recall(predictions, labels):
    acc_num = np.sum([1 for x, y in zip(predictions, labels) if y in x])
    return acc_num / len(predictions)


recall_score = recall(
    valid_long.filter(pl.col("MisconceptionId").is_not_null())[
        "PredictMisconceptionId"
    ],
    valid_long.filter(pl.col("MisconceptionId").is_not_null())["MisconceptionId"],
)
recall_score

0.49246575342465754

In [37]:
# output_textを保存
with open(f"{MODEL_OUTPUT_PATH}/cv_score.txt", "w") as f:
    f.write(f"MAP@25:{map_at_25_score:.4f}, Recall:{recall_score:.4f}")

# Make Retrieved Train File

In [38]:
valid_retrieved = (
    valid_long.filter(pl.col("MisconceptionId").is_not_null())
    .explode("PredictMisconceptionId")
    .with_columns(
        (pl.col("MisconceptionId") == pl.col("PredictMisconceptionId"))
        .cast(pl.Int64)
        .alias("target")
    )
    .join(
        misconception_mapping.with_columns(pl.all().name.prefix("Predict")),
        on="PredictMisconceptionId",
    )
)
valid_retrieved.head()

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId,PredictMisconceptionId,target,MisconceptionId_right,MisconceptionName,PredictMisconceptionName
i64,str,str,str,str,str,str,str,str,str,i64,i64,i64,i64,str,str
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,1005,0,1005,"""Carries out operations from le…","""Carries out operations from le…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,1345,0,1345,"""Inserts brackets but not chang…","""Inserts brackets but not chang…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,315,0,315,"""Has not realised that the answ…","""Has not realised that the answ…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,706,0,706,"""Carries out operations from ri…","""Carries out operations from ri…"
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672,974,0,974,"""Believes multiplying a positiv…","""Believes multiplying a positiv…"


In [39]:
valid_retrieved.write_csv(
    f"{MODEL_OUTPUT_PATH}/{EXP_NAME}-valid-ret{RETRIEVE_NUM}-map{map_at_25_score:.4f}-recall{recall_score:.4f}.csv",
)

In [40]:
print(
    f"{MODEL_OUTPUT_PATH}/{EXP_NAME}-valid-ret{RETRIEVE_NUM}-map{map_at_25_score:.4f}-recall{recall_score:.4f}.csv",
)

# Kaggle Upload

In [41]:
import os
import json

from kaggle.api.kaggle_api_extended import KaggleApi


def dataset_create_new(dataset_name: str, upload_dir: str):
    # if "_" in dataset_name:
    #     raise ValueError("datasetの名称に_の使用は禁止です")
    dataset_metadata = {}
    dataset_metadata["id"] = f"sinchir0/{dataset_name}"
    dataset_metadata["licenses"] = [{"name": "CC0-1.0"}]
    dataset_metadata["title"] = dataset_name
    with open(os.path.join(upload_dir, "dataset-metadata.json"), "w") as f:
        json.dump(dataset_metadata, f, indent=4)
    api = KaggleApi()
    api.authenticate()
    api.dataset_create_new(folder=upload_dir, convert_to_csv=False, dir_mode="tar")


print(f"Create Dataset name:{DATASET_NAME}, output_dir:{OUTPUT_PATH}")
dataset_create_new(dataset_name=DATASET_NAME, upload_dir=OUTPUT_PATH)