In [1]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Example dataset (replace this with your own dataset)
search_queries = ["How to bake a cake", "Python tutorial", "Best hiking trails"]
relevant_documents = ["A guide to baking cakes", "Learn Python programming", "Top hiking trails in the world"]

# Split dataset into train and validation sets
train_queries, val_queries, train_docs, val_docs = train_test_split(search_queries, relevant_documents, test_size=0.2, random_state=42)

# Tokenize data
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_encodings = tokenizer(train_queries, train_docs, truncation=True, padding=True)
val_encodings = tokenizer(val_queries, val_docs, truncation=True, padding=True)

class SearchDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = SearchDataset(train_encodings, [1] * len(train_queries))  # Binary labels for relevant documents
val_dataset = SearchDataset(val_encodings, [1] * len(val_queries))

# Fine-tune BERT for search
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)  # Binary classification
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch + 1}, Training loss: {train_loss / len(train_loader)}")

# Evaluation on validation set
model.eval()
val_loader = DataLoader(val_dataset, batch_size=8)
val_loss = 0
for batch in tqdm(val_loader, desc="Validation"):
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        val_loss += loss.item()

print(f"Validation loss: {val_loss / len(val_loader)}")

# Example of using the fine-tuned model for search
query = "How to bake a cake"
relevant_docs = ["A guide to baking cakes", "The history of baking", "10 cake recipes for beginners"]

input_encoding = tokenizer(query, relevant_docs, truncation=True, padding=True, return_tensors="pt").to(device)
with torch.no_grad():
    output = model(**input_encoding)
probabilities = torch.softmax(output.logits, dim=1)
print("Document relevancy probabilities:", probabilities[:, 1])  # Probability of being relevant


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Train


Epoch 1, Training loss: 0.6078182458877563


Train


Epoch 2, Training loss: 0.4706692695617676


Train


Epoch 3, Training loss: 0.2963221073150635


Valid

Validation loss: 0.30848565697669983
Document relevancy probabilities: tensor([0.7020])



