# Overview
I prepared three notebooks:

1. [Train TF-IDF Retriever](https://www.kaggle.com/code/sinchir0/retriever-tfidf-reranker-deberta-1-trn-ret) (Recall: 0.4530, CV: 0.1378, LB: 0.128)

2. [Train DeBERTa Reranker](https://www.kaggle.com/code/sinchir0/retriever-tfidf-reranker-deberta-2-trn-rerank) (CV: 0.1616)

3. Inference with the TF-IDF Retriever and DeBERTa Reranker (LB: 0.169) <- this notebook

Submission run time: 65 min

Please let me know if there are any mistakes.

# Install

In [1]:
# Remove the preinstalled copies of these packages so that the versions installed
# from /kaggle/input/eedi-library in the following cell are the ones imported.
!pip uninstall -qq -y \
scikit-learn \
polars \
transformers \
accelerate \
datasets

In [2]:
# Offline install: --no-index disables PyPI, and --find-links restricts pip to the
# wheels in the attached eedi-library dataset (presumably because the submission
# environment has no internet access). The versions are checked by asserts further down.
!python -m pip install -qq --no-index --find-links=/kaggle/input/eedi-library \
scikit-learn \
polars \
transformers \
accelerate \
datasets

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 0.22.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.10.0, but you have google-cloud-bigquery 2.34.4 which is incompatible.
bigframes 0.22.0 requires google-cloud-storage>=2.0.0, but you have google-cloud-storage 1.44.0 which is incompatible.
bigframes 0.22.0 requires pandas<2.1.4,>=1.5.0, but you have pandas 2.2.2 which is incompatible.
dataproc-jupyter-plugin 0.1.79 requires pydantic~=1.10.0, but you have pydantic 2.8.2 which is incompatible.
spaghetti 1.7.6 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.
spopt 0.6.1 requires shapely>=2.0.1, but you have shapely 1.8.5.post1 which is incompatible.[0m[31m
[0m

# Settings

In [3]:
# Input locations: competition data, the retriever artifacts (vectorizer + misconception
# vectors), and the reranker model/tokenizer directory.
DATA_PATH = "/kaggle/input/eedi-mining-misconceptions-in-mathematics"
RETRIEVER_PATH = "/kaggle/input/retriever-tfidf-reranker-deberta-1-trn-ret"
RERANKER_PATH = "/kaggle/input/e003-gr-ml256-deberta-v3-xsmall"

# Inference knobs.
RETRIEVE_NUM = 25           # candidate misconceptions kept per (question, answer) by TF-IDF
EVAL_BS = 4                 # reranker batch size
INFERENCE_MAX_LENGTH = 256  # truncation length (tokens) for reranker inputs

# Import

In [4]:
import os
import pickle

from tqdm.auto import tqdm

import numpy as np
import polars as pl

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

import torch
from datasets import Dataset
from scipy.special import softmax
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

In [5]:
NUM_PROC: "int | None" = os.cpu_count()  # worker processes for Dataset.map; cpu_count() may be None

In [6]:
# Run the reranker on the first GPU when one is present; fall back to CPU otherwise
# (the original hard-coded "cuda:0" and used an f-string with no placeholders).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
import transformers
import sklearn
import datasets

# The offline wheels installed above must be exactly these versions. Each failing check
# names the library and both versions, so a mismatch is diagnosable from the traceback.
for _lib_name, _installed, _expected in (
    ("polars", pl.__version__, "1.7.1"),
    ("transformers", transformers.__version__, "4.44.2"),
    ("scikit-learn", sklearn.__version__, "1.5.2"),
    ("datasets", datasets.__version__, "3.0.0"),
):
    assert _installed == _expected, f"{_lib_name}=={_installed} is installed, expected {_expected}"

# Load

In [8]:
# Retriever artifacts: the fitted TF-IDF vectorizer and the matrix that test texts are
# compared against with cosine similarity (presumably the vectorized misconception texts).
# NOTE(review): pickle.load can execute arbitrary code; only load files you produced
# yourself (here, the output of the retriever-training notebook).
vectorizer_pkl_path = f"{RETRIEVER_PATH}/vectorizer.pkl"
with open(vectorizer_pkl_path, "rb") as fh:
    vectorizer = pickle.load(fh)

misconception_mapping_vec = np.load(f"{RETRIEVER_PATH}/misconception_mapping_vec.npy")

# Check Environment

In [9]:
# Record the Python interpreter version used for this run.
!python --version

  pid, fd = os.forkpty()


Python 3.10.14


In [10]:
# Record the GPU model, driver and CUDA version available to this run.
!nvidia-smi

Mon Sep 16 02:12:17 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             26W /  250W |       0MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                     

# Preprocess Test

In [11]:
# Per-question columns that are repeated on every answer-option row after unpivoting.
common_col = ["QuestionId", "ConstructName", "SubjectName", "QuestionText", "CorrectAnswer"]
answer_text_cols = [f"Answer{letter}Text" for letter in ["A", "B", "C", "D"]]
# Fields joined (space-separated) into the TF-IDF retrieval query text.
all_text_parts = ["ConstructName", "SubjectName", "QuestionText", "AnswerText"]

test_wide = pl.read_csv(f"{DATA_PATH}/test.csv").select(common_col + answer_text_cols)

# Wide -> long: one row per (question, answer option).
test_long = test_wide.unpivot(
    index=common_col, variable_name="AnswerType", value_name="AnswerText"
)
test_long = test_long.with_columns(
    pl.concat_str([pl.col(part) for part in all_text_parts], separator=" ").alias("AllText"),
    pl.col("AnswerType").str.extract(r"Answer([A-Z])Text$").alias("AnswerAlphabet"),
)
# Row key such as "1869_B": question id plus answer letter.
test_long = test_long.with_columns(
    pl.concat_str([pl.col("QuestionId"), pl.col("AnswerAlphabet")], separator="_").alias("QuestionId_Answer"),
).sort("QuestionId_Answer")
test_long.head()

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""1869_A"""
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""1869_B"""
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""1869_C"""
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""1869_D"""
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerAText""","""\( m+1 \)""","""Simplify an algebraic fraction…","""A""","""1870_A"""


# Retrieval

In [12]:
# TF-IDF retrieval: score every row of misconception_mapping_vec against each
# (question, answer) text and keep the RETRIEVE_NUM highest-scoring column indices.
test_long_vec = vectorizer.transform(test_long["AllText"])
test_cos_sim_arr = cosine_similarity(test_long_vec, misconception_mapping_vec)
test_sorted_indices = np.argsort(-test_cos_sim_arr, axis=1)  # descending similarity
top_candidates = test_sorted_indices[:, :RETRIEVE_NUM]

# NOTE(review): column positions are used directly as MisconceptionId; this assumes the
# rows of misconception_mapping_vec are ordered by MisconceptionId 0..N-1 — confirm.
candidate_series = pl.Series(top_candidates.tolist())
test_long = test_long.with_columns(candidate_series.alias("PredictMisconceptionId"))
test_long.head()

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,PredictMisconceptionId
i64,str,str,str,str,str,str,str,str,str,list[i64]
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""1869_A""","[2488, 2532, … 15]"
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""1869_B""","[2488, 2532, … 15]"
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""1869_C""","[2488, 2532, … 15]"
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""1869_D""","[2488, 2551, … 1756]"
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerAText""","""\( m+1 \)""","""Simplify an algebraic fraction…","""A""","""1870_A""","[1540, 979, … 419]"


In [13]:
# The mapping keeps its original columns and gains "Predict"-prefixed copies, so it can be
# joined on PredictMisconceptionId to attach each candidate's misconception text.
predict_mapping = pl.read_csv(f"{DATA_PATH}/misconception_mapping.csv").with_columns(
    pl.all().name.prefix("Predict")
)
# One row per (question, answer, retrieved candidate); inner join.
test = test_long.explode("PredictMisconceptionId").join(
    predict_mapping, on="PredictMisconceptionId"
)
test.head(10)

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,PredictMisconceptionId,MisconceptionId,MisconceptionName,PredictMisconceptionName
i64,str,str,str,str,str,str,str,str,str,i64,i64,str,str
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""1869_A""",15,15,"""Confuses the order of operatio…","""Confuses the order of operatio…"
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""1869_B""",15,15,"""Confuses the order of operatio…","""Confuses the order of operatio…"
1869,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""1869_C""",15,15,"""Confuses the order of operatio…","""Confuses the order of operatio…"
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerAText""","""\( m+1 \)""","""Simplify an algebraic fraction…","""A""","""1870_A""",29,29,"""Forgot to simplify the fractio…","""Forgot to simplify the fractio…"
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerBText""","""\( m+2 \)""","""Simplify an algebraic fraction…","""B""","""1870_B""",29,29,"""Forgot to simplify the fractio…","""Forgot to simplify the fractio…"
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerCText""","""\( m-1 \)""","""Simplify an algebraic fraction…","""C""","""1870_C""",29,29,"""Forgot to simplify the fractio…","""Forgot to simplify the fractio…"
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerDText""","""Does not simplify""","""Simplify an algebraic fraction…","""D""","""1870_D""",29,29,"""Forgot to simplify the fractio…","""Forgot to simplify the fractio…"
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerAText""","""\( m+1 \)""","""Simplify an algebraic fraction…","""A""","""1870_A""",59,59,"""Cannot identify a common facto…","""Cannot identify a common facto…"
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerBText""","""\( m+2 \)""","""Simplify an algebraic fraction…","""B""","""1870_B""",59,59,"""Cannot identify a common facto…","""Cannot identify a common facto…"
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerCText""","""\( m-1 \)""","""Simplify an algebraic fraction…","""C""","""1870_C""",59,59,"""Cannot identify a common facto…","""Cannot identify a common facto…"


# Rerank

In [14]:
# Reranker: sequence-classification model and its tokenizer, both loaded from RERANKER_PATH.
model = AutoModelForSequenceClassification.from_pretrained(RERANKER_PATH)
tokenizer = AutoTokenizer.from_pretrained(RERANKER_PATH)
# NOTE(review): data_collator is not used by the inference code below, which pads each
# batch with pad_without_fast_tokenizer_warning instead.
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    pad_to_multiple_of=16,
)

In [15]:
def tokenize(examples, max_token_length: int):
    """Build the reranker input text for a single row and tokenize it.

    The construct, subject, question, answer and candidate-misconception texts are
    joined with " [SEP] ", then tokenized with the module-level ``tokenizer``,
    truncated to ``max_token_length`` tokens, and left unpadded.
    """
    fields = (
        "ConstructName",
        "SubjectName",
        "QuestionText",
        "AnswerText",
        # TODO: change the special token separating the answer from the misconception
        "PredictMisconceptionName",
    )
    joined_text = " [SEP] ".join(examples[field] for field in fields)
    return tokenizer(joined_text, max_length=max_token_length, truncation=True, padding=False)


# Tokenize each candidate row (one example at a time, NUM_PROC workers) and convert back to polars.
test = (
    Dataset.from_polars(test)
    .map(
        tokenize,
        fn_kwargs=dict(max_token_length=INFERENCE_MAX_LENGTH),
        batched=False,
        num_proc=NUM_PROC,
    )
    .to_polars()
)

  self.pid = os.fork()


Map (num_proc=4):   0%|          | 0/300 [00:00<?, ? examples/s]

In [16]:
# Sort rows by token count so batches padded to their longest sequence waste little padding.
test = test.with_columns(length=pl.col("input_ids").list.len()).sort("length")

In [17]:
@torch.inference_mode()
@torch.amp.autocast("cuda")
def inference(
    test: pl.DataFrame,
    model,
    device,
    batch_size=EVAL_BS,
    max_length=INFERENCE_MAX_LENGTH
):
    """Score every candidate row with the reranker and attach the probabilities.

    Rows are processed in slices of ``batch_size``. Each slice is padded to its longest
    sequence with the module-level ``tokenizer``, run through ``model`` on ``device``,
    and converted to the softmax probability of label 1 (presumably the "relevant"
    class — confirm against the training notebook).

    Args:
        test: Frame with tokenized ``input_ids`` and ``attention_mask`` list columns.
        model: Sequence-classification model. It is moved to ``device`` in place and
            put in eval mode.
        device: torch device used for the forward pass.
        batch_size: Rows per forward pass.
        max_length: Unused; truncation already happens at tokenization time. Kept
            for interface compatibility.

    Returns:
        ``test`` with an extra ``pred_prob`` column.
    """
    # The original never used `device`: model and inputs stayed on CPU, so the CUDA
    # autocast above had no effect and the GPU sat idle.
    model.to(device)
    model.eval()

    probabilities = []
    # tqdm takes its total from len(range(...)), which is the exact (ceil) batch count.
    for start in tqdm(range(0, len(test), batch_size)):
        batch = test[start:start + batch_size]
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {
                "input_ids": batch["input_ids"].to_list(),
                "attention_mask": batch["attention_mask"].to_list(),
            },
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
        ).to(device)

        # Upcast (autocast may yield fp16) and copy to host before handing to numpy/scipy.
        logits = model(**inputs).logits.float().cpu().numpy()
        probabilities.extend(softmax(logits, -1)[:, 1])

    return test.with_columns(pl.Series(probabilities).alias("pred_prob"))

In [18]:
results = inference(test=test, model=model, device=device)  # adds the pred_prob column

  0%|          | 0/75 [00:00<?, ?it/s]

In [19]:
results.head(5)  # preview the first scored candidate rows

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,PredictMisconceptionId,MisconceptionId,MisconceptionName,PredictMisconceptionName,input_ids,token_type_ids,attention_mask,length,pred_prob
i64,str,str,str,str,str,str,str,str,str,i64,i64,str,str,list[i32],list[i8],list[i8],u32,f32
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerDText""","""Does not simplify""","""Simplify an algebraic fraction…","""D""","""1870_D""",29,29,"""Forgot to simplify the fractio…","""Forgot to simplify the fractio…","[1, 56341, … 2]","[0, 0, … 0]","[1, 1, … 1]",58,0.013615
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerDText""","""Does not simplify""","""Simplify an algebraic fraction…","""D""","""1870_D""",1825,1825,"""Does not fully simplify fracti…","""Does not fully simplify fracti…","[1, 56341, … 2]","[0, 0, … 0]","[1, 1, … 1]",58,0.011827
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerDText""","""Does not simplify""","""Simplify an algebraic fraction…","""D""","""1870_D""",1918,1918,"""Does not fully simplify ratio""","""Does not fully simplify ratio""","[1, 56341, … 2]","[0, 0, … 0]","[1, 1, … 1]",58,0.008678
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerDText""","""Does not simplify""","""Simplify an algebraic fraction…","""D""","""1870_D""",634,634,"""Does not know the term numerat…","""Does not know the term numerat…","[1, 56341, … 2]","[0, 0, … 0]","[1, 1, … 1]",59,0.009045
1870,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""D""","""AnswerDText""","""Does not simplify""","""Simplify an algebraic fraction…","""D""","""1870_D""",792,792,"""Expands when asked to simplify""","""Expands when asked to simplify""","[1, 56341, … 2]","[0, 0, … 0]","[1, 1, … 1]",59,0.01169


In [20]:
# Within each (question, answer) pair, order candidates by reranker probability
# (highest first) and collect them as a ranked list of misconception ids.
results = (
    results
    .sort(by=["QuestionId_Answer", "pred_prob"], descending=[False, True])
    .group_by("QuestionId_Answer", maintain_order=True)
    .agg(MisconceptionId=pl.col("PredictMisconceptionId"))
)

# Make Submission File

In [21]:
# One row per answer option that is not the correct answer. Each row holds the reranked
# misconception ids (best first) as a space-separated string.
submission = (
    test_long
    # Filtering before the join avoids joining rows that would be dropped anyway.
    .filter(pl.col("CorrectAnswer") != pl.col("AnswerAlphabet"))
    .join(results, on="QuestionId_Answer")
    .select(
        pl.col("QuestionId_Answer"),
        # Native list[int] -> list[str] -> "id id id" instead of a per-row Python lambda.
        pl.col("MisconceptionId").cast(pl.List(pl.String)).list.join(" "),
    )
    .sort("QuestionId_Answer")
)

In [22]:
submission.head(n=10)  # sanity-check the first submission rows

QuestionId_Answer,MisconceptionId
str,str
"""1869_B""","""2488 521 1005 2586 1507 706 25…"
"""1869_C""","""2488 1005 521 2586 1507 706 25…"
"""1869_D""","""2488 1005 2131 656 1392 256 25…"
"""1870_A""","""606 363 1593 1540 59 29 1825 2…"
"""1870_B""","""606 363 1540 1593 59 29 1825 2…"
"""1870_C""","""606 363 1593 1540 59 29 1825 2…"
"""1871_A""","""1349 2319 397 632 1059 365 216…"
"""1871_C""","""1349 397 2319 632 1059 365 216…"
"""1871_D""","""1349 397 2319 632 1059 365 216…"


In [23]:
submission.write_csv(file="submission.csv")  # file name expected by the competition