# Fine-tuning Sentence Transformer

Reference: https://huggingface.co/blog/how-to-train-sentence-transformers

In [1]:
# Debug switch: when True, the cells below work on a small sample
# (about 100 topics / 100 triplets) and train for a single epoch,
# which is enough for a quick end-to-end smoke test.
DEBUG = False

In [2]:
from datasets import load_dataset

In [3]:
!pip -qqq install sentence-transformers
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer, models, InputExample, losses
from datasets import Dataset
from torch.utils.data import DataLoader

In [4]:
# Load the three competition tables: topics, content items, and the
# topic -> content correlations.
DATA_PATH = "/kaggle/input/learning-equality-curriculum-recommendations/"
topics, content, correlations = (
    pd.read_csv(DATA_PATH + filename)
    for filename in ("topics.csv", "content.csv", "correlations.csv")
)

In [5]:
# Substitute a placeholder for missing titles in both tables so that every
# row carries some text.
for frame in (topics, content):
    frame["title"] = frame["title"].fillna("No Title")

In [6]:
# Preview the raw topic -> content mapping; at this point content_ids is a
# single space-separated string per topic.
correlations.head()

Unnamed: 0,topic_id,content_ids
0,t_00004da3a1b2,c_1108dd0c7a5d c_376c5a8eb028 c_5bc0e1e2cba0 c...
1,t_00068291e9a4,c_639ea2ef9c95 c_89ce9367be10 c_ac1672cdcd2c c...
2,t_00069b63a70a,c_11a1dc0bfb99
3,t_0006d41a73a8,c_0c6473c3480d c_1c57a1316568 c_5e375cf14c47 c...
4,t_0008768bdee6,c_34e1424229b4 c_7d1a964d66d5 c_aab93ee667f4


In [7]:
# Externally prepared table with `topics_ids` and `content_ids` columns
# (both are used by the groupby in the next cell).
# NOTE(review): the dataset name suggests the top-50 retrieved candidates per
# topic from an unsupervised stage — confirm.
train = pd.read_csv("/kaggle/input/lecr-uns-top-n-50/train.csv")

In [8]:
# For every topic id in `train`, collect the set of content ids listed for it;
# the resulting Series is indexed by topic id and sorted on that index.
train_topic_ids_content_ids = train.groupby("topics_ids")["content_ids"].apply(set).sort_index()

In [9]:
# Parse each space-separated id string into a set of ground-truth content ids,
# keyed by topic_id and sorted on that key.
#
# The parsed column lives only in this derived frame; `correlations` itself is
# left untouched. Overwriting correlations["content_ids"] in place made the
# cell fail with AttributeError when run a second time (a set has no .split)
# and silently changed the table that was previewed earlier.
corr_topic_ids_content_ids = (
    correlations
    .assign(content_ids=correlations["content_ids"].map(lambda ids: set(ids.split())))
    .set_index("topic_id")
    .sort_index()
)

In [10]:
# Preview the per-topic content-id sets built from `train`.
train_topic_ids_content_ids.head()

topics_ids
t_00004da3a1b2    {c_cbbf192e3fb1, c_ec09c6bd0877, c_adb20c7622a...
t_00068291e9a4    {c_83e247629e9b, c_5b35ca71313d, c_f96a0ab78be...
t_00069b63a70a    {c_271a79646124, c_864332cb2d95, c_4ea70c66b21...
t_0006d41a73a8    {c_af08c6756929, c_9cfd108287fb, c_9a1da8bc33f...
t_0008768bdee6    {c_fd2a0d4fdf2a, c_7b1ff48ee7d2, c_bf882e1890d...
Name: content_ids, dtype: object

In [11]:
# Preview the per-topic ground-truth content-id sets from the correlations table.
corr_topic_ids_content_ids.head()

Unnamed: 0_level_0,content_ids
topic_id,Unnamed: 1_level_1
t_00004da3a1b2,"{c_1108dd0c7a5d, c_76231f9d0b5e, c_376c5a8eb02..."
t_00068291e9a4,"{c_ebb7fdf10a7e, c_ac1672cdcd2c, c_89ce9367be1..."
t_00069b63a70a,{c_11a1dc0bfb99}
t_0006d41a73a8,"{c_d7a0d7eaf799, c_5e375cf14c47, c_0c6473c3480..."
t_0008768bdee6,"{c_7d1a964d66d5, c_34e1424229b4, c_aab93ee667f4}"


In [12]:
# Split every topic's ids into positives (ground-truth correlations) and
# negatives (ids listed in `train` that are not ground truth).
#
# Ground truth is looked up by topic_id rather than by zipping the two sorted
# objects positionally: a positional zip silently pairs a topic with another
# topic's ground truth as soon as the two topic sets differ, and it truncates
# to the shorter of the two.
corr_lookup = corr_topic_ids_content_ids["content_ids"].to_dict()

output_id = {}
for topic_id, train_content_id in tqdm(
    train_topic_ids_content_ids.items(),
    total=len(train_topic_ids_content_ids)
):
    corr_content_id = corr_lookup.get(topic_id)
    if corr_content_id is None:
        # No ground truth for this topic, so it has no positives to train on.
        continue
    output_id[topic_id] = {
        "pos": corr_content_id,
        "neg": train_content_id - corr_content_id,
    }

  0%|          | 0/61517 [00:00<?, ?it/s]

In [13]:
# Build the id -> title lookups once, up front. Filtering the full `topics`
# and `content` tables inside the loop (a boolean mask plus an isin() scan per
# topic) costs a whole-table pass for each of the ~61k topics.
# drop_duplicates keeps the first row per id, matching the previous `.values[0]`.
topic_title_by_id = topics.drop_duplicates("id").set_index("id")["title"].to_dict()
content_title_by_id = dict(zip(content["id"], content["title"]))
# Row position of every content id. Titles are emitted in `content` table
# order (the order the earlier isin() selection produced) so that a later
# "first N negatives" slice does not depend on set iteration order.
content_row_by_id = {content_id: row for row, content_id in enumerate(content["id"])}


def _titles_in_table_order(content_ids):
    """Return the titles for `content_ids`, ordered as the ids appear in `content`.

    Ids that do not occur in `content` are dropped.
    NOTE(review): assumes content["id"] is unique — confirm; with duplicated
    ids only the last row's title would be returned.
    """
    known_ids = sorted(
        (content_id for content_id in content_ids if content_id in content_row_by_id),
        key=content_row_by_id.__getitem__,
    )
    return [content_title_by_id[content_id] for content_id in known_ids]


output_list = []
for topic_id, id_sets in tqdm(output_id.items(), total=len(output_id)):
    # Each record is wrapped in a one-element list so that a DataFrame built
    # from output_list gets a single column holding the whole dict.
    output_list.append([{
        "query": topic_title_by_id[topic_id],
        "pos": _titles_in_table_order(id_sets["pos"]),
        "neg": _titles_in_table_order(id_sets["neg"]),
    }])
    if DEBUG and len(output_list) == 100:
        break

  0%|          | 0/61517 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Every element of output_list is a one-element list, so this produces a single
# "set" column whose cells are the {"query", "pos", "neg"} dicts.
df_dataset = pd.DataFrame(output_list, columns=["set"])

In [None]:
import itertools
from sentence_transformers import InputExample

# Expand each topic into (query, positive, negative) text triplets.
train_data = df_dataset['set']
n_examples = df_dataset.shape[0]
train_examples = []
cnt = 0

for i in tqdm(range(n_examples), total=n_examples):
    example = train_data[i]
    # Pair every positive title with at most the first 10 negative titles.
    triplets = [
        InputExample(texts=[example['query'], pos, neg])
        for pos, neg in itertools.product(example['pos'], example['neg'][:10])
    ]
    train_examples.extend(triplets)
    cnt += len(triplets)

    # Stop once enough triplets exist: 100 in DEBUG mode, otherwise 100k.
    # Original author's note: positives alone amount to roughly 60k samples,
    # so the cap is chosen to exceed that.
    # The cap is checked once per topic, so the final count can overshoot it.
    limit = 100 if DEBUG else 100_000
    if cnt >= limit:
        break

In [None]:
from torch.utils.data import DataLoader

# Serve the triplets in shuffled batches of 64.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=64)

In [None]:
from sentence_transformers import SentenceTransformer

# Load the pretrained checkpoint that will be fine-tuned.
model_id = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
model = SentenceTransformer(model_id)

In [None]:
from sentence_transformers import losses

# Triplet objective over the (query, positive, negative) texts. Only `model` is
# passed, so every other TripletLoss setting is left at the library default.
train_loss = losses.TripletLoss(model=model)

In [None]:
# A single epoch is enough for a DEBUG smoke test; a full run trains for 10.
num_epochs = 1 if DEBUG else 10

# Warm up over the first 10% of all training steps
# (batches per epoch x number of epochs).
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(total_steps * 0.1)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    warmup_steps=warmup_steps,
    epochs=num_epochs,
)

In [None]:
# Write the fine-tuned model to the Kaggle working directory.
model.save("/kaggle/working/paraphrase-multilingual-mpnet-base-v2-exp")