In [9]:
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
import random
import torch
from torch.nn import CosineSimilarity
import torch.nn.functional as F
import numpy as np

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load SCIDOCS corpus
corpus = load_dataset("BeIR/scidocs", "corpus")

# Initialize tokenizer and models, move models to device
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
query_model = AutoModel.from_pretrained("bert-base-uncased").to(device)
document_model = AutoModel.from_pretrained("bert-base-uncased").to(device)

# Cosine similarity function
cosine_sim = CosineSimilarity(dim=1)

# Function to encode text in batches
def encode_text_batch(texts, tokenizer, model):
    inputs = tokenizer(texts, return_tensors="pt", truncation=True, padding=True, max_length=64).to(device)
    embeddings = model(**inputs).last_hidden_state[:, 0, :]  # CLS token embedding
    return embeddings

# Function to create ICT and BFS pairs
def create_ict_pair(text):
    sentences = text.split(". ")
    if len(sentences) <= 1:
        return None, None
    query = sentences.pop(random.randint(0, len(sentences) - 1))
    document = ". ".join(sentences)
    return query, document

def create_bfs_pair(text):
    sections = text.split(". ")
    if len(sections) <= 1:
        return None, None
    query = sections[0]
    document = ". ".join(sections[1:])
    return query, document

# Training parameters
learning_rate = 5e-6  # Lower learning rate
batch_size = 8
epochs = 2
loss_threshold = 0.01  # Higher threshold for early stopping

# Initialize optimizer and learning rate scheduler
optimizer = torch.optim.Adam(list(query_model.parameters()) + list(document_model.parameters()), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7, verbose=True)

# Training loop with early stopping
for epoch in range(epochs):
    queries = []
    documents = []
    early_stop = False  # Flag for early stopping
    
    for idx, item in enumerate(corpus['corpus']):
        document_text = item['text']
        
        # Alternate between ICT and BFS
        if random.random() > 0.5:
            query, document = create_ict_pair(document_text)
        else:
            query, document = create_bfs_pair(document_text)

        # Append to batch if pair is valid
        if query and document:
            queries.append(query)
            documents.append(document)
        
        # Process in batches
        if len(queries) == batch_size:
            # Encode queries and documents
            query_embeddings = encode_text_batch(queries, tokenizer, query_model)
            document_embeddings = encode_text_batch(documents, tokenizer, document_model)
            
            # Calculate similarity and loss (margin-based contrastive loss)
            positive_similarity = cosine_sim(query_embeddings, document_embeddings)
            loss = -torch.log(positive_similarity + 1e-8).mean()
            
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            print(f"Epoch {epoch + 1}, Step {idx + 1}, Loss: {loss.item()}")
            
            # Check for early stopping
            if loss.item() < loss_threshold:
                print(f"Early stopping at Epoch {epoch + 1}, Step {idx + 1}, Loss: {loss.item()}")
                early_stop = True
                break
            
            # Update learning rate based on loss
            scheduler.step(loss)
            
            # Clear batch
            queries = []
            documents = []

    if early_stop:
        break  # Stop training if early stop condition is met

# Save models
query_model.save_pretrained("query_model_scidocs")
document_model.save_pretrained("document_model_scidocs")
tokenizer.save_pretrained("tokenizer_scidocs")




Epoch 1, Step 8, Loss: 0.2534204125404358
Epoch 1, Step 16, Loss: 0.2427542358636856
Epoch 1, Step 24, Loss: 0.19090424478054047
Epoch 1, Step 32, Loss: 0.15652111172676086
Epoch 1, Step 40, Loss: 0.1298975646495819
Epoch 1, Step 48, Loss: 0.1380399465560913
Epoch 1, Step 56, Loss: 0.0988500788807869
Epoch 1, Step 64, Loss: 0.10447648912668228
Epoch 1, Step 72, Loss: 0.10988999158143997
Epoch 1, Step 80, Loss: 0.0863654688000679
Epoch 1, Step 88, Loss: 0.07435011863708496
Epoch 1, Step 96, Loss: 0.08043905347585678
Epoch 1, Step 104, Loss: 0.0729428231716156
Epoch 1, Step 112, Loss: 0.06611660122871399
Epoch 1, Step 120, Loss: 0.06419971585273743
Epoch 1, Step 129, Loss: 0.041641946882009506
Epoch 1, Step 138, Loss: 0.04795615002512932
Epoch 1, Step 146, Loss: 0.04832075536251068
Epoch 1, Step 154, Loss: 0.04018230736255646
Epoch 1, Step 162, Loss: 0.038446441292762756
Epoch 1, Step 171, Loss: 0.04135090112686157
Epoch 1, Step 179, Loss: 0.0362999327480793
Epoch 1, Step 188, Loss: 0.04

('tokenizer_scidocs\\tokenizer_config.json',
 'tokenizer_scidocs\\special_tokens_map.json',
 'tokenizer_scidocs\\vocab.txt',
 'tokenizer_scidocs\\added_tokens.json',
 'tokenizer_scidocs\\tokenizer.json')