# Обучение кросс-энкодера

Для переранжирования кандидатов я буду использовать BERT-модель, дообученную на задачу классификации. Модель оценивает каждого предложенного кандидата — тройку из контекста, вопроса и ответа — на предмет того, является ли ответ продолжением контекста и вопроса.

Для ранжирования ответов я буду использовать уверенность модели в классификации: чем выше уверенность модели в том, что ответ подходит, тем выше кандидат в выдаче.

Ниже представлен код для обучения модели и её сохранения на Hugging Face для дальнейшего использования в чат-боте.

In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
from transformers import Trainer, TrainingArguments, set_seed
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
import evaluate
import wandb
from transformers import EvalPrediction
import os
import warnings

warnings.filterwarnings("ignore")

In [2]:
# Authenticate with the Hugging Face Hub via the interactive token widget;
# needed because TrainingArguments is configured with push_to_hub=True.
import huggingface_hub

huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
# Seed Python, NumPy and torch RNGs so the train/valid split -- and, when the
# cells are run top to bottom, the classifier-head initialisation -- are
# reproducible between runs.
SEED = 42
set_seed(SEED)

# NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files.
data = pd.read_pickle("data/scripts_for_reranker.pkl")

# Stratify on the target so both splits keep the same 0/1 class balance.
df_train, df_valid = train_test_split(
    data,
    test_size=0.2,
    random_state=SEED,
    stratify=data["label"],
)

dataset = DatasetDict(
    {
        "train": Dataset.from_pandas(df_train.reset_index(drop=True)),
        "valid": Dataset.from_pandas(df_valid.reset_index(drop=True)),
    }
)
dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'context', 'label', 'answer', 'combined'],
        num_rows: 17436
    })
    valid: Dataset({
        features: ['question', 'context', 'label', 'answer', 'combined'],
        num_rows: 4360
    })
})

In [4]:
from collections import Counter

# Sanity check: class balance of the training split.
train_labels = dataset["train"]["label"]
Counter(train_labels)

Counter({0: 8773, 1: 8663})

In [5]:
# Report which accelerator is available. `device` is not referenced by the
# later cells shown here; it is informational only.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cuda


In [6]:
def preprocess_data(examples, tok=None, text_column="combined", max_length=None):
    """Tokenize a batch of reranker examples for the BERT cross-encoder.

    Intended for ``dataset.map(preprocess_data, batched=True)``, so ``examples``
    is a dict mapping column name -> list of values.

    Args:
        examples: Batch dict; must contain ``text_column``.
        tok: Tokenizer to use. Defaults to the notebook-level ``tokenizer``,
            looked up at call time, so it must exist before ``map`` runs.
        text_column: Column holding the model input text. Defaults to
            ``"combined"`` (presumably context + question + answer joined in
            the data-prep step -- confirm there).
        max_length: Truncation length; ``None`` lets the tokenizer use the
            model's own maximum input length.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask`` and, for BERT,
        ``token_type_ids``) without padding; padding is done per batch by the
        data collator.
    """
    tok = tokenizer if tok is None else tok
    return tok(
        examples[text_column],
        add_special_tokens=True,
        # BERT has a fixed maximum input length; truncate instead of crashing
        # on over-long inputs.
        # NOTE(review): truncation drops tokens from the END of the text; if
        # the answer is the last part of `combined`, consider
        # tokenizer.truncation_side = "left" so the answer is preserved.
        truncation=True,
        max_length=max_length,
    )

In [7]:
# Accuracy metric from the Hugging Face `evaluate` library.
ACCURACY = evaluate.load(path="accuracy")


def compute_metrics(p: "EvalPrediction"):
    """Compute validation accuracy for the ``Trainer``.

    Args:
        p: Evaluation output with ``predictions`` (presumably the classifier
            logits of shape (n_samples, num_labels), or a tuple whose first
            element is the logits) and ``label_ids``.

    Returns:
        ``{"accuracy": float}``: the fraction of argmax predictions that equal
        the reference labels.
    """
    logits = p.predictions
    # Models that return extra outputs yield a tuple; logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(p.label_ids)
    # Plain NumPy accuracy: the same value as the `evaluate` "accuracy"
    # metric, without fetching a metric script from the Hub.
    return {"accuracy": float(np.mean(preds == labels))}

In [8]:
# Tokenizer matching the checkpoint; the collator pads each batch dynamically
# to the length of its longest sequence.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

# Tokenize both splits, drop the raw text columns, and rename the target to
# `labels`, the keyword the sequence-classification model's forward() takes.
encoded_dataset = (
    dataset.map(preprocess_data, batched=True)
    .remove_columns(["context", "question", "answer", "combined"])
    .rename_column("label", "labels")
)
encoded_dataset.set_format("torch")
encoded_dataset

Map:   0%|          | 0/17436 [00:00<?, ? examples/s]

Map:   0%|          | 0/4360 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 17436
    })
    valid: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 4360
    })
})

In [9]:
# Weights & Biases settings read by the Trainer's W&B integration: log to the
# "reranker_train" project and upload the trained model as an artifact.
os.environ.update({"WANDB_PROJECT": "reranker_train", "WANDB_LOG_MODEL": "true"})

# BERT encoder with a freshly initialised 2-way classification head.
# NOTE(review): bert-base-uncased is an English-only, lower-cased checkpoint;
# confirm the training scripts are in English.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Authenticate with Weights & Biases (training reports to it via report_to="wandb").
# The return value is left as the cell output (True in the saved run).
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkatya_shakhova[0m ([33mshakhova[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [11]:
# Fine-tuning hyper-parameters for the reranker.
# NOTE(review): with save_strategy="no" the Trainer never writes a checkpoint,
# so save_total_limit has no effect and push_to_hub=True pushes nothing during
# training (the Hub push is triggered by saves). Call `trainer.push_to_hub()`
# after training to publish the final model.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; update the keyword when upgrading.
training_args = TrainingArguments(
    output_dir="RerankerModel_chat_bot",  # by default also the Hub repo name
    overwrite_output_dir=True,
    evaluation_strategy="epoch",  # evaluate on the valid split after each epoch
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    weight_decay=0.001,
    num_train_epochs=3,
    warmup_ratio=0.1,  # warm up the LR over the first 10% of steps
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    save_strategy="no",
    save_total_limit=1,
    group_by_length=True,  # batch samples of similar length to reduce padding
    push_to_hub=True,
    report_to="wandb",
)

In [12]:
# Wire model, data splits, metric and collator into the HF Trainer and
# fine-tune. Passing the tokenizer makes the Trainer save it with the model.
# NOTE(review): in the logged run eval_loss rose at epoch 3 (0.52 -> 0.79)
# while accuracy still improved, which suggests over-confidence/overfitting.
# The reranker ranks candidates by this confidence, so consider fewer epochs
# or load_best_model_at_end with an eval_loss criterion.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["valid"],
    compute_metrics=compute_metrics,
)

train_output = trainer.train()
train_output

  0%|          | 0/6540 [00:00<?, ?it/s]

{'loss': 0.6812, 'learning_rate': 3.822629969418961e-05, 'epoch': 0.23}
{'loss': 0.6066, 'learning_rate': 4.957490420675732e-05, 'epoch': 0.46}
{'loss': 0.5726, 'learning_rate': 4.7494362173553114e-05, 'epoch': 0.69}
{'loss': 0.56, 'learning_rate': 4.38212672347195e-05, 'epoch': 0.92}


  0%|          | 0/545 [00:00<?, ?it/s]

{'eval_loss': 0.5250728726387024, 'eval_accuracy': 0.7254587155963302, 'eval_runtime': 16.6173, 'eval_samples_per_second': 262.377, 'eval_steps_per_second': 32.797, 'epoch': 1.0}
{'loss': 0.4904, 'learning_rate': 3.881566668443446e-05, 'epoch': 1.15}
{'loss': 0.4679, 'learning_rate': 3.2831946374551544e-05, 'epoch': 1.38}
{'loss': 0.4406, 'learning_rate': 2.629374095149702e-05, 'epoch': 1.61}
{'loss': 0.416, 'learning_rate': 1.9663941426082897e-05, 'epoch': 1.83}


  0%|          | 0/545 [00:00<?, ?it/s]

{'eval_loss': 0.516691267490387, 'eval_accuracy': 0.7612385321100917, 'eval_runtime': 18.6674, 'eval_samples_per_second': 233.562, 'eval_steps_per_second': 29.195, 'epoch': 2.0}
{'loss': 0.3747, 'learning_rate': 1.3411923476378066e-05, 'epoch': 2.06}
{'loss': 0.29, 'learning_rate': 7.980316649956704e-06, 'epoch': 2.29}
{'loss': 0.2644, 'learning_rate': 3.7536671351888096e-06, 'epoch': 2.52}
{'loss': 0.2795, 'learning_rate': 1.0312127105846947e-06, 'epoch': 2.75}
{'loss': 0.2469, 'learning_rate': 5.697347762481653e-09, 'epoch': 2.98}


  0%|          | 0/545 [00:00<?, ?it/s]

{'eval_loss': 0.7941048741340637, 'eval_accuracy': 0.7814220183486239, 'eval_runtime': 18.8117, 'eval_samples_per_second': 231.77, 'eval_steps_per_second': 28.971, 'epoch': 3.0}
{'train_runtime': 610.1069, 'train_samples_per_second': 85.736, 'train_steps_per_second': 10.719, 'train_loss': 0.43677769144740675, 'epoch': 3.0}


TrainOutput(global_step=6540, training_loss=0.43677769144740675, metrics={'train_runtime': 610.1069, 'train_samples_per_second': 85.736, 'train_steps_per_second': 10.719, 'train_loss': 0.43677769144740675, 'epoch': 3.0})

In [13]:
# Close the W&B run so the remaining logs and the model artifact finish uploading.
wandb.finish()

VBox(children=(Label(value='418.577 MB of 418.577 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/accuracy,▁▅█
eval/loss,▁▁█
eval/runtime,▁██
eval/samples_per_second,█▁▁
eval/steps_per_second,█▁▁
train/epoch,▁▂▂▃▃▃▄▄▅▅▆▆▇▇███
train/global_step,▁▂▂▃▃▃▄▄▅▅▆▆▇▇███
train/learning_rate,▆██▇▆▆▅▄▃▂▂▁▁
train/loss,█▇▆▆▅▅▄▄▃▂▁▂▁
train/total_flos,▁

0,1
eval/accuracy,0.78142
eval/loss,0.7941
eval/runtime,18.8117
eval/samples_per_second,231.77
eval/steps_per_second,28.971
train/epoch,3.0
train/global_step,6540.0
train/learning_rate,0.0
train/loss,0.2469
train/total_flos,2270683352193840.0
