# Custom Model

Use a custom model with custom loss function to co-train bm25 and bert

In [1]:

from sentence_transformers import losses, models, SentenceTransformer
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.search.lexical import BM25Search as BM25
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.train import TrainRetriever
import pathlib, os, tqdm
import logging

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

  from tqdm.autonotebook import tqdm, trange


In [3]:

#### Download nfcorpus.zip dataset and unzip the dataset
dataset = "scifact"

url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join("../datasets")
data_path = util.download_and_unzip(url, out_dir)


In [4]:

#### Provide the data_path where scifact has been downloaded and unzipped
corpus, queries, qrels = GenericDataLoader(data_path).load(split="train")



2024-05-28 21:57:52 - Loading Corpus...


100%|██████████| 5183/5183 [00:00<00:00, 28404.32it/s]


2024-05-28 21:57:53 - Loaded 5183 TRAIN Documents.
2024-05-28 21:57:53 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 ver

In [5]:

#### Lexical Retrieval using Bm25 (Elasticsearch) ####
#### Provide a hostname (localhost) to connect to ES instance
#### Define a new index name or use an already existing one.
#### We use default ES settings for retrieval
#### https://www.elastic.co/

hostname = "localhost" #localhost
index_name = "scifact" # scifact

#### Intialize #### 
# (1) True - Delete existing index and re-index all documents from scratch 
# (2) False - Load existing index
initialize = True # False

#### Sharding ####
# (1) For datasets with small corpus (datasets ~ < 5k docs) => limit shards = 1 
# SciFact is a relatively small dataset! (limit shards to 1)
number_of_shards = 1
model = BM25(index_name=index_name, hostname=hostname, initialize=initialize, number_of_shards=number_of_shards)


2024-05-28 21:57:57 - Activating Elasticsearch....
2024-05-28 21:57:57 - Elastic Search Credentials: {'hostname': 'localhost', 'index_name': 'scifact', 'keys': {'title': 'title', 'body': 'txt'}, 'timeout': 100, 'retry_on_timeout': True, 'maxsize': 24, 'number_of_shards': 1, 'language': 'english'}
2024-05-28 21:57:57 - Deleting previous Elasticsearch-Index named - scifact
2024-05-28 21:58:00 - Creating fresh Elasticsearch-Index named - scifact


In [6]:
bm25 = EvaluateRetrieval(model)

#### Index passages into the index (seperately)
bm25.retriever.index(corpus)

triplets = []
qids = list(qrels) 
hard_negatives_max = 10

#### Retrieve BM25 hard negatives => Given a positive document, find most similar lexical documents
for idx in tqdm.tqdm(range(len(qids)), desc="Retrieve Hard Negatives using BM25"):
    query_id, query_text = qids[idx], queries[qids[idx]]
    pos_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
    pos_doc_texts = [corpus[doc_id]["title"] + " " + corpus[doc_id]["text"] for doc_id in pos_docs]
    hits = bm25.retriever.es.lexical_multisearch(texts=pos_doc_texts, top_hits=hard_negatives_max+1)
    for (pos_text, hit) in zip(pos_doc_texts, hits):
        for (neg_id, _) in hit.get("hits"):
            if neg_id not in pos_docs:
                neg_text = corpus[neg_id]["title"] + " " + corpus[neg_id]["text"]
                triplets.append([query_text, pos_text, neg_text])


  0%|          | 0/5183 [00:00<?, ?docs/s]             
Retrieve Hard Negatives using BM25: 100%|██████████| 809/809 [00:11<00:00, 71.68it/s] 


## Custom Model

In [10]:
import torch
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel
import numpy as np
from typing import List, Dict

class CLEAR(BertPreTrainedModel):
    def __init__(self, model_path=None, **kwargs):
        self.model = None # ---> HERE Load your custom model
        # self.model = SentenceTransformer(model_path)

    def __init__(self, config, model_path, Ksi, Lambda):
        super(CLEAR, self).__init__(config)
        self.model = BertModel(config)
        self.Ksi = Ksi
        self.Lambda = Lambda
        self.act = nn.ReLU()
        self.init_weights()
    
    def forward(self, **kwargs):
        if len(kwargs) == 8:
            query_input_ids, query_mask = kwargs['query_input_ids'], kwargs['query_mask']
            pos_doc_input_ids, pos_doc_mask = kwargs['pos_doc_input_ids'], kwargs['pos_doc_mask']
            neg_doc_input_ids, neg_doc_mask = kwargs['neg_doc_input_ids'], kwargs['neg_doc_mask']
            pos_s_lex, neg_s_lex = kwargs['pos_s_lex'], kwargs['neg_s_lex']
            pos_s_emb = self.S_emb(self.encoding(query_input_ids, query_mask), self.encoding(pos_doc_input_ids, pos_doc_mask))
            neg_s_emb = self.S_emb(self.encoding(query_input_ids, query_mask), self.encoding(neg_doc_input_ids, neg_doc_mask))
            mr = self.Ksi - self.Lambda * (pos_s_lex - neg_s_lex)
            return torch.mean(self.act(mr.squeeze() - pos_s_emb + neg_s_emb))
        elif len(kwargs) == 5:
            query_input_ids, query_mask = kwargs['query_input_ids'], kwargs['query_mask']
            doc_input_ids, doc_mask = kwargs['doc_input_ids'], kwargs['doc_mask']
            s_lex = kwargs['s_lex']
            s_emb = self.S_emb(self.encoding(query_input_ids, query_mask), self.encoding(doc_input_ids, doc_mask))
            s_lex = self.Lambda * s_lex
            return s_lex.squeeze() + s_emb
    
    def mean_pooling(self, sequence_vectors):
        return torch.mean(sequence_vectors, dim=1)

    def encoding(self, input_ids, attention_mask):
        sequence_vectors = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        rep = self.mean_pooling(sequence_vectors)
        return rep

    def S_emb(self, rep_q, rep_d):
        assert rep_q.shape == rep_d.shape
        return torch.mul(rep_q, rep_d).sum(1)
    
    # Write your own encoding query function (Returns: Query embeddings as numpy array)
    # For eg ==> return np.asarray(self.model.encode(queries, batch_size=batch_size, **kwargs))
    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> np.ndarray:
        return np.asarray(self.model.encode(queries, batch_size=batch_size, **kwargs))
    
    # Write your own encoding corpus function (Returns: Document embeddings as numpy array)  
    # For eg ==> sentences = [(doc["title"] + "  " + doc["text"]).strip() for doc in corpus]
    #        ==> return np.asarray(self.model.encode(sentences, batch_size=batch_size, **kwargs))
    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> np.ndarray:
        sentences = [(doc["title"] + "  " + doc["text"]).strip() for doc in corpus]
        return np.asarray(self.model.encode(sentences, batch_size=batch_size, **kwargs))

In [16]:
from transformers import BertConfig, BertTokenizer, AdamW, get_linear_schedule_with_warmup

load_model_path = "bert-base-uncased"
model_name = load_model_path
config = BertConfig.from_pretrained(load_model_path)
model = CLEAR.from_pretrained(load_model_path, config, 1, 10)    


Some weights of CLEAR were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.model.embeddings.LayerNorm.bias', 'bert.model.embeddings.LayerNorm.weight', 'bert.model.embeddings.position_embeddings.weight', 'bert.model.embeddings.token_type_embeddings.weight', 'bert.model.embeddings.word_embeddings.weight', 'bert.model.encoder.layer.0.attention.output.LayerNorm.bias', 'bert.model.encoder.layer.0.attention.output.LayerNorm.weight', 'bert.model.encoder.layer.0.attention.output.dense.bias', 'bert.model.encoder.layer.0.attention.output.dense.weight', 'bert.model.encoder.layer.0.attention.self.key.bias', 'bert.model.encoder.layer.0.attention.self.key.weight', 'bert.model.encoder.layer.0.attention.self.query.bias', 'bert.model.encoder.layer.0.attention.self.query.weight', 'bert.model.encoder.layer.0.attention.self.value.bias', 'bert.model.encoder.layer.0.attention.self.value.weight', 'bert.model.encoder.layer.0.intermediate.dense.bias', 'bert.mode

In [17]:

#### Provide any sentence-transformers or HF model
# model_name = "distilbert-base-uncased" 
# word_embedding_model = models.Transformer(model_name, max_seq_length=300)
# pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
# model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

#### Provide a high batch-size to train better with triplets!
retriever = TrainRetriever(model=model, batch_size=12)


In [18]:

#### Prepare triplets samples
train_samples = retriever.load_train_triplets(triplets=triplets)
train_dataloader = retriever.prepare_train_triplets(train_samples)


Adding Input Examples: 100%|██████████| 749/749 [00:00<00:00, 188397.82it/s]

2024-05-28 22:13:21 - Loaded 8980 training pairs.





In [19]:

#### Training SBERT with cosine-product
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)


In [20]:

#### Prepare dev evaluator
# ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)

#### If no dev set is present from above use dummy evaluator
ir_evaluator = retriever.load_dummy_evaluator()

#### Provide model save path
model_save_path = os.path.join(os.getcwd(), "output", "{}-v2-{}-bm25-hard-negs".format(model_name, dataset))
os.makedirs(model_save_path, exist_ok=True)


In [21]:


#### Configure Train params
num_epochs = 1
evaluation_steps = 10000
warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)

retriever.fit(train_objectives=[(train_dataloader, train_loss)], 
                evaluator=ir_evaluator, 
                epochs=num_epochs,
                output_path=model_save_path,
                warmup_steps=warmup_steps,
                evaluation_steps=evaluation_steps,
                use_amp=True)


2024-05-28 22:13:43 - Starting to Train...


AttributeError: 'CLEAR' object has no attribute 'fit'