In [None]:
import pandas as pd

# Labels for the six columns of the triples file: a query, one positive
# passage and one negative passage, each preceded by its id.
TRIPLE_COLUMNS = ["qid", "query", "pos_pid", "positive", "neg_pid", "negative"]

# header=0 treats the file's first row as a header row, and `names` replaces
# whatever labels that row carried with the ones used in this notebook.
# NOTE(review): the preview of this frame contains sequences such as "Å«"
# and "Â»", which look like UTF-8 text that was mis-decoded before the file
# was written -- worth confirming against whatever produced the TSV.
df = pd.read_csv(
    "qidpidtriples.top3.clean.tsv",
    names=TRIPLE_COLUMNS,
    header=0,
    sep="\t",
    encoding="utf-16",
)

df

Unnamed: 0,qid,query,pos_pid,positive,neg_pid,negative
0,1000094,where is whitemarsh island,5399011,"Whitemarsh Island, Georgia. Whitemarsh Island ...",271630,Underwater Volcano Forms New South Pacific Isl...
1,1000094,where is whitemarsh island,5399011,"Whitemarsh Island, Georgia. Whitemarsh Island ...",5534953,"Komodo is one of the 17,508 islands that make ..."
2,1000684,where is your perineum,6133670,That part of the floor of the PELVIS that lies...,54955,rule of nines (rÅ«l nÄ«nz) Method used in calc...
3,1000684,where is your perineum,6133670,That part of the floor of the PELVIS that lies...,5952792,This delicate triangle is important during chi...
4,1000684,where is your perineum,6133670,That part of the floor of the PELVIS that lies...,4455896,"1 abdomen: Latin abdomen = the belly, the part..."
...,...,...,...,...,...,...
42268,1063764,why did the war on western front turn into a s...,4060164,Another reason why Stalemate developed on the ...,2176755,2001 Mitsubishi Mirage car wont turn over. Eng...
42269,1077589,will kramer robertson be drafted,5374462,7 LSU baseball players selected for MLB draft ...,2798501,"On Tour Â» Poppy Harlow, Sinisa Babcic, Myla K..."
42270,1081834,zale name meaning,3461499,"Meaning of Zale. Greek name. In Greek, the nam...",6954998,Greek Meaning: The name Teresa is a Greek baby...
42271,1081834,zale name meaning,3461499,"Meaning of Zale. Greek name. In Greek, the nam...",2231931,The name Lila is a Polish baby name. In Polish...


## Building labelled (query, passage) pairs from the positive and negative passage text

In [None]:
import pandas as pd

def _labelled_pairs(triples, passage_col, label):
    """Return a new (query, passage, label) frame built from one passage column.

    `passage_col` is renamed to "passage" and every row receives the same
    constant `label`.
    """
    return (
        triples[["query", passage_col]]
        .rename(columns={passage_col: "passage"})
        .assign(label=label)
    )


# Positive passages are labelled 1, negative passages 0.
df_pos = _labelled_pairs(df, "positive", 1)
df_neg = _labelled_pairs(df, "negative", 0)

# Stack positives on top of negatives under a fresh 0..n-1 index.
cross_df = pd.concat([df_pos, df_neg], ignore_index=True)
print(cross_df.shape)

cross_df

(84546, 3)


Unnamed: 0,query,passage,label
0,where is whitemarsh island,"Whitemarsh Island, Georgia. Whitemarsh Island ...",1
1,where is whitemarsh island,"Whitemarsh Island, Georgia. Whitemarsh Island ...",1
2,where is your perineum,That part of the floor of the PELVIS that lies...,1
3,where is your perineum,That part of the floor of the PELVIS that lies...,1
4,where is your perineum,That part of the floor of the PELVIS that lies...,1
...,...,...,...
84541,why did the war on western front turn into a s...,2001 Mitsubishi Mirage car wont turn over. Eng...,0
84542,will kramer robertson be drafted,"On Tour Â» Poppy Harlow, Sinisa Babcic, Myla K...",0
84543,zale name meaning,Greek Meaning: The name Teresa is a Greek baby...,0
84544,zale name meaning,The name Lila is a Polish baby name. In Polish...,0


## Training the cross-encoder

In [None]:
from sentence_transformers import CrossEncoder, InputExample
from torch.utils.data import DataLoader
import torch

# Training settings, gathered in one place so they are easy to find and tune.
SEED = 42
BATCH_SIZE = 16
EPOCHS = 3
WARMUP_STEPS = 100
OUTPUT_DIR = "./cross-encoder-model"

has_cuda = torch.cuda.is_available()
print("Torch CUDA available:", has_cuda)

# Seed torch's global RNG so the random steps that draw from it (such as the
# DataLoader's shuffling) are repeatable from one run to the next.
torch.manual_seed(SEED)

# Convert dataframe rows into InputExamples.
# The three columns are walked in lockstep instead of using
# DataFrame.iterrows(), which builds a Series object for every row and is
# needlessly slow for a frame of this size.
train_samples = [
    InputExample(texts=[query, passage], label=float(label))
    for query, passage, label in zip(
        cross_df["query"], cross_df["passage"], cross_df["label"]
    )
]

# Wrap them in a DataLoader
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=BATCH_SIZE)

# Initialize cross encoder
model = CrossEncoder(
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    num_labels=1  # binary relevance: one output score per (query, passage) pair
)

# Train the model
model.fit(
    train_dataloader=train_dataloader,
    epochs=EPOCHS,
    warmup_steps=WARMUP_STEPS,
    output_path=OUTPUT_DIR,
    # Mixed precision (float16) is faster on modern GPUs (this was run on an
    # NVIDIA T4 on GCP). It is a CUDA feature, so it is only requested when a
    # GPU is actually present; on a CPU-only machine the request is, depending
    # on the installed library versions, either ignored with a warning or
    # rejected outright.
    use_amp=has_cuda,
)
# Explicit save to the same directory passed as output_path above.
model.save(OUTPUT_DIR)

  from .autonotebook import tqdm as notebook_tqdm


Torch CUDA available: True


Step,Training Loss
500,0.1679
1000,0.1507
1500,0.1537
2000,0.1428
2500,0.1511
3000,0.1393
3500,0.143
4000,0.1409
4500,0.1253
5000,0.1185
