<h2 align="center"><i>Two-stage learning to rank without supervision from relevance labels </i><br></h2>


<p><br>In this notebook we will train a <i>"tiny"</i> version of a <i>state-of-the-art</i> information retrieval model inspired by the following <i>Google AI</i> blog posts:<br/></p>
<ul style="font-size:120%;">
<li>
<a href="https://ai.googleblog.com/2019/01/transformer-xl-unleashing-potential-of.html"><i>Transformer-XL</i></a>
</li>
<li>
<a href="https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html"><i>BERT</i></a>
</li><br/>
</ul>


<p style="font-size:120%;"> Our model will consist of 2 stages:<br/></p>
<ul>
<li>
<p style="font-size:120%; list-style-type:disc;">A classical <i>matrix-based</i> ranking method (<i>BM25Plus</i> with <i>default</i> parameters) retrieves the most relevant documents. Retrieval is fast, but may be <i>inaccurate</i></p>
</li>
<li>
<p style="font-size:120%;"><i>"Tiny" BERT</i> reranks the documents retrieved by the <i>first-stage</i> algorithm</p><br/>
</li>
</ul>

<p style="font-size:120%;">
The <i>original BERT</i> was trained for a <i>week</i> using <i>64</i> GPUs on <i>extremely</i> large text corpora.<br/><br/><br><br></p>

<img src="https://github.com/xkaple01/bert-information-retrieval/blob/master/bert_ir_system_gc/bert_detailed.png?raw=1">


<p style="font-size:120%;"><br/><br>Our model will be trained from <i>scratch</i> on a <i>single</i> <i>GeForce GTX 850M</i> GPU.<br/><br/></p>
<p align="center" style="text-align:center;font-size:120%;">
<b><i>The model will never touch the relevance labels during the training process</i></b><br/><br/><br/>
</p> 




In [1]:
! pip install git+https://gitlab.fi.muni.cz/xstefan3/pv211-utils.git@master
! git clone https://github.com/xkaple01/bert-information-retrieval.git
! pip install nltk
! pip install rank_bm25

Collecting git+https://gitlab.fi.muni.cz/xstefan3/pv211-utils.git@master
  Cloning https://gitlab.fi.muni.cz/xstefan3/pv211-utils.git (to revision master) to /tmp/pip-req-build-txbp9ov2
  Running command git clone -q https://gitlab.fi.muni.cz/xstefan3/pv211-utils.git /tmp/pip-req-build-txbp9ov2
Building wheels for collected packages: pv211-utils
  Building wheel for pv211-utils (setup.py) ... done
  Created wheel for pv211-utils: filename=pv211_utils-0.1.dev22+g3e6e680-cp36-none-any.whl size=503621 sha256=445fa555b335e4f4ccdef01dc52f66dff19ee0d9c580b3edf7aa59e6d35ab5ff
  Stored in directory: /tmp/pip-ephem-wheel-cache-9enlcqrm/wheels/fc/91/96/1054afbd540b3a8d5913d19a3f34c07b339c0e58340950a781
Successfully built pv211-utils
fatal: destination path 'bert-information-retrieval' already exists and is not an empty directory.


In [2]:
from pv211_utils.entities import DocumentBase
from pv211_utils.entities import QueryBase
from pv211_utils.irsystem import IRSystem
from pv211_utils.loader import load_documents
from pv211_utils.loader import load_queries
from pv211_utils.loader import load_judgements
from pv211_utils.eval import mean_average_precision

import nltk
from nltk.corpus import stopwords
from nltk.stem.snowball import SnowballStemmer
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Plus
nltk.download('stopwords')
nltk.download('wordnet')

import re
import numpy as np
import os
import pickle
from random import shuffle
from collections import OrderedDict
from copy import deepcopy


class Document(DocumentBase):       
    def __init__(self, document_id, authors, bibliography, title, body):
        super().__init__(document_id, authors, bibliography, title, body)
        self.preprocessed = preprocess_text(self.body)

        
class Query(QueryBase):
    def __init__(self, query_id, body):
        super().__init__(query_id, body)
        self.preprocessed = preprocess_text(self.body)


sm_st = SnowballStemmer("english")
stop_words = set(stopwords.words('english')) | \
             set(['', 'co', 'eq', 'pr', 'tr', 'rl', 'psf', 'ko', 'la', 'vz', 'plk', 'o', 'etc', 'igy', 'soc',
                   'ic', 'ible','ser', 'ing', 'ob', 'feb', 'wkb', 'ao', 'dp', 'tne', 'sr', 'ux', 'som', 'aft', 'con',
                   'rev', 'j', 'b', 'p', 'e', 'a', 'sq', 'op', 'er', 'oc', 'ab', 'bc', 'de', 'im', 'fs', 'vs', 'rf',
                   'bi','et', 'al', 'th', 'rd', 'nd', 'pb', 'rt', 'rm', 'qn', 'fd', 'qe', 'qm', 'de', 'vas', 'fig',
                   'ty', 'tx', 'tz', 'pai', 'ied', 'ref', 'thn', 'jan', 'pre', 'mth', 'nth', 'uhf', 'btu', 'ink', 'rae',
                   'ofr', 'n','f', 'dx', 'dy', 'dz', 'x', 'xx', 'y', 'yy', 'z', 'zz', 'q', 's', 'ax', 'cx', 'cf', 'b',
                   'du','u', 'r', 'h', 'l', 'jmin','jmjn','viz', 'fl', 'ld', 'dvl', 're', 'tn', 'aec', 'k', 'i', 'ii',
                   'iii', 'iv', 'v', 'vi', 'vii', 'viii', 'ix', 'x', 'xi', 'xii', 'xiii', 'xiv', 'xv', 'xvi', 'xvii',
                   'xviii', 'xix', 'xx']) 


def preprocess_text(text):
    sentences = re.split(r'\s\.\s', text)
    prepr_sentences = []
    for s in sentences:
        # Raw string: avoids invalid-escape warnings for sequences such as \s and \( in the pattern.
        word_tokens = list(filter(None, re.split(r'[\s,\(\)0-9\'\:\$\*\;\?\"\/\+\=\-]|[\s.*\.]|(aero|air|aer|super|sub|hyper|ultra|sonic|retro|poly|multi|therm|magnet|hydro|fero|accel|electro|photo|strobo|alumin|anti|less|non|post|cross|eigen|dynam|curv|cylind|linear|off|ellip|compress|molecul|metal|correl|stream|visc|crystall)', s)))
        prepr_sentence_tokens = [sm_st.stem(w) for w in word_tokens if w not in stop_words]
        if len(prepr_sentence_tokens) != 0:
            prepr_sentences.append(prepr_sentence_tokens)
    return prepr_sentences


def find_max_length(texts):
    l=[]
    for text in texts.values():
        l.append(len(set(flatten_text(text.preprocessed))))
    return max(l), np.argmax(l)


def flatten_text(text):
    return [w for s in text for w in s]
        
    
def create_new_doc_odict():
    new_doc_odict = OrderedDict()
    cnt = 0
    for doc in documents.values():
        if doc.body != '':
            cnt += 1
            new_doc_odict[cnt] = doc
    return new_doc_odict
    
    
def count_term_frequencies(texts):
    words_fr = {}
    for text in texts:
        for sentence in text.preprocessed:
            for w in sentence:
                if w not in words_fr.keys():
                    words_fr[w] = 1
                else:
                    words_fr[w] = words_fr[w] + 1
    return words_fr
                    

def remove_rare_terms_from_words_fr_dict(words_fr):
    words_fr_filtered = {}
    for w, fr in words_fr.items():
        if fr > 1:
            words_fr_filtered[w] = fr
    return words_fr_filtered


def filter_text(text, words):
    filtered_text = []
    for s in text:
        filtered_sentence = []
        for w in s:
            if w not in words:
                filtered_sentence.append(w)
        if len(filtered_sentence)>0:
            filtered_text.append(filtered_sentence)
    return filtered_text


def create_words_to_ids_dict():
    words_to_ids = {}
    words_fr = count_term_frequencies(list(documents.values()) + list(queries.values()))
    words_fr = remove_rare_terms_from_words_fr_dict(words_fr)
    unique_terms = gather_unique_terms(words_fr)    
    words_to_ids = {'CLS':0, 'SEP':1}    
    cnt = 2
    for w in sorted(unique_terms):
        words_to_ids[w] = cnt
        cnt+=1
    return words_to_ids, sorted(list(unique_terms))


def gather_unique_terms(words_fr):
    unique_terms = gather_unique_terms_from_text(queries)
    remove_duplicates(documents, unique_terms, words_fr)
    remove_duplicates(queries, unique_terms, words_fr)
    return set(unique_terms) & set(words_fr.keys())


def gather_unique_terms_from_text(texts):
    unique_terms = []
    for t in texts.values():
        for s in t.preprocessed:
            for w in s:
                if w not in unique_terms:
                    unique_terms.append(w)
    return unique_terms


def remove_duplicates(texts, unique_terms, words_fr):
    for t in texts.values():
        filtered_text = []
        for s in t.preprocessed:
            filtered_sentence = []
            for w in s:
                if w in unique_terms and w in words_fr.keys():
                    filtered_sentence.append(w)
            if len(filtered_sentence)>0:
                filtered_text.append(filtered_sentence)
        t.preprocessed = filtered_text

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
documents = load_documents(Document)
queries = load_queries(Query)
relevant = load_judgements(queries, documents)

documents = create_new_doc_odict()
words_to_ids, unique_terms = create_words_to_ids_dict()
vocab_size = len(words_to_ids.keys())
num_docs = len(documents.values())
num_queries = len(queries.values())

print('Documents: ', num_docs)
print('Queries: ', num_queries)

Documents:  1398
Queries:  225


<br>
<h3 align="center" style="text-align:center;"><i>First-stage retrieval algorithm</i><br/><br></h3>


<p style="font-size:120%;"><br/>The <i>preprocessing</i> and retrieval pipeline includes the following steps:<br/><br/></p>
<ul style="font-size:120%;">
<li>Text is split into sentences, sentences are split into tokens, <i>stopwords</i> are removed and each token is <i>stemmed</i> by <i>SnowballStemmer</i><br/><br/></li>
<li><i>Relevance</i> scores between the <i>query</i> and <i>each document</i> are calculated using the <i>BM25Plus</i> algorithm with its <i>default</i> configuration (no explicit hyperparameter optimization). The most <i>relevant documents</i> appear at the top of the retrieved list of documents. However, the <i>order</i> of documents may be <i>inaccurate.</i></li></ul><br>
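<p style="font-size:120%;">The scoring step above can be sketched in plain Python. This is only an illustration of the <i>BM25+</i> formula, not the <i>rank_bm25</i> implementation: the function name <i>bm25plus_scores</i> is hypothetical, and the parameter values (<i>k1=1.5</i>, <i>b=0.75</i>, <i>delta=1</i>) and the IDF variant are assumptions chosen to resemble commonly used defaults.</p>

```python
import math
from collections import Counter


def bm25plus_scores(query_tokens, corpus, k1=1.5, b=0.75, delta=1.0):
    """Score every tokenized document in `corpus` against `query_tokens` with BM25+."""
    n_docs = len(corpus)
    avg_len = sum(len(doc) for doc in corpus) / n_docs
    # Document frequency: in how many documents each term occurs.
    doc_freq = Counter(term for doc in corpus for term in set(doc))
    scores = []
    for doc in corpus:
        term_freq = Counter(doc)
        length_norm = 1 - b + b * len(doc) / avg_len
        score = 0.0
        for term in query_tokens:
            if term not in doc_freq:
                continue  # term never occurs in the corpus
            idf = math.log((n_docs + 1) / doc_freq[term])  # assumed IDF variant
            tf = term_freq[term]
            score += idf * (delta + tf * (k1 + 1) / (tf + k1 * length_norm))
        scores.append(score)
    return scores


# Toy corpus of already stemmed tokens (made up for illustration).
corpus = [
    ["boundari", "layer", "flow"],
    ["shock", "wave", "flow", "flow"],
    ["heat", "transfer", "plate"],
]
scores = bm25plus_scores(["shock", "flow"], corpus)
ranking = sorted(range(len(corpus)), key=lambda i: scores[i], reverse=True)
print(ranking)  # [1, 0, 2]
```

<p style="font-size:120%;">Sorting the document indices by descending score corresponds to what <i>first_stage_ranking</i> does with <i>np.argsort</i> in the cell below.</p>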


<br><br>

In [0]:
class BM25IRSystem(IRSystem):
    def __init__(self):
        self.documents = list(documents.values())
        self.corpus = [flatten_text(doc.preprocessed) for doc in self.documents]
        self.bm25 = BM25Plus(self.corpus)

        
    def first_stage_ranking(self, preprocessed_query):
        scores = self.bm25.get_scores(preprocessed_query.preprocessed[0])
        doc_list_ids = np.argsort(scores)[::-1]
        return np.take(self.documents, doc_list_ids, axis=0).tolist()[:]

                
    def search(self, query):
        preprocessed_query = queries[query.query_id]
        fs_ranked_docs = self.first_stage_ranking(preprocessed_query)
        return fs_ranked_docs

In [5]:
mean_average_precision(BM25IRSystem(), submit_result=False, author_name="Kaplenko, Mykola")

Mean average precision: 42.703% 
Not submitted.


<br>
<h3 align="center" style="text-align:center;"><i>Now, let's train the second-stage reranking model: "tiny" BERT</i><br/><br></h3>


<p style="font-size:120%;"><br/>    
Usually, reranking models are trained in a <i>supervised</i> manner using <i>relevance labels</i>: the model takes <i>2</i> documents as inputs (one document is <i>relevant</i>, the other is always <i>non-relevant</i>) and tries to guess which of the <i>2</i> input documents is <i>relevant</i>.<br/><br/></p>

<p style="font-size:120%;"><br/>  
We <b>will not use</b> the <i>relevance labels</i>. Instead, we slightly reformulate our task in order to train the model without any supervision from the <i>relevance labels</i>.<br/><br/></p>

<p style="font-size:120%;"> The <i>key idea</i> is to train our model in such a way that it develops the <i>general skill</i> of language understanding.<br/><br/><br/>
Training process is organized as follows:<br/><br/></p>
<ul style="font-size:120%;"><li>
We randomly choose <i>2</i> documents from the list of documents. We then randomly extract several (let's say <i>N</i>) words from the <i>first document</i> and randomly remove <i>K</i> words from the <i>second document</i> so that both documents now have the <i>same</i> length<br/><br/></li>
<li>Model takes <i>3</i> inputs: the <i>extracted words</i>, the <i>first document</i> (without <i>extracted words</i>) and the <i>second document</i> (from which K words were removed)</li>
<br/>
<li>Model tries to answer the <i>question</i>: do the <i>extracted words</i> belong to the <i>first document</i> or to the <i>second one</i>? We know which document we <i>extracted</i> the words from, so we can train the model in a <i>supervised</i> manner, but we <b>do not use</b> the <i>relevance labels</i><br/><br/></li></ul>
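<p style="font-size:120%;">The three steps above can be sketched as follows. This is a simplified illustration, not the notebook's <i>select_random_words_from_texts</i>: the function name <i>make_training_example</i> is hypothetical, and the sketch works only with words that are unique to each document.</p>

```python
import random


def make_training_example(doc_a, doc_b, n_words, rng):
    """Build one self-supervised example from two tokenized documents."""
    # Keep only words unique to each document, so the answer is unambiguous.
    unique_a = sorted(set(doc_a) - set(doc_b))
    unique_b = sorted(set(doc_b) - set(doc_a))
    n_words = min(n_words, len(unique_a) // 2, len(unique_b) // 2)

    extracted = rng.sample(unique_a, n_words)  # plays the role of a query
    remaining_a = [w for w in unique_a if w not in extracted][:n_words]
    trimmed_b = rng.sample(unique_b, n_words)  # same length as remaining_a

    # Randomise which slot holds the source document; the label is one-hot.
    if rng.random() < 0.5:
        return extracted, (remaining_a, trimmed_b), [1, 0]
    return extracted, (trimmed_b, remaining_a), [0, 1]


rng = random.Random(0)
doc_a = ["shock", "wave", "mach", "nozzl", "flow", "jet"]
doc_b = ["heat", "transfer", "plate", "wall", "flow", "temperatur"]
extracted, (first, second), label = make_training_example(doc_a, doc_b, 2, rng)
print(extracted, first, second, label)
```

<p style="font-size:120%;">The label is derived only from which document the words were taken from, so no relevance judgement is involved at any point.</p>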

<p style="font-size:120%;">
To perform its task <i>successfully</i>, the model has to learn the <i>relations</i> between the <i>context</i> and the <i>words</i> whose <i>meaning</i> depends on this <i>context</i>.<br/><br/></p>

<img src="https://github.com/xkaple01/bert-information-retrieval/blob/master/bert_ir_system_gc/multiple_choice_2.png?raw=1" align="middle">

<br>

<p style="font-size:120%;"><br/>
But <i><b>how</b></i> can we use this to rerank documents?<br/><br/></p>

<p style="font-size:120%;">
The trained model again takes <i>3</i> inputs: the <i>query</i>, the <i>whole first document</i> (with no <i>words extracted</i>), and the <i>whole second document</i>. The model then <i>predicts</i>: is it more likely that the <i>query words</i> originate ("<i>were extracted</i>") from the <i>first document</i> or from the <i>second document</i>?<br/><br/> Documents in the list of retrieved documents are then rearranged <i>pairwise</i> in such a way that the more relevant documents occur at higher positions in the list than the less relevant documents.<br/><br></p>
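<p style="font-size:120%;">A minimal sketch of this pairwise rearrangement is shown below. The names <i>rerank_adjacent_pairs</i> and <i>toy_pair_probabilities</i> are hypothetical; the toy scorer merely stands in for the trained model, and the margin of <i>0.2</i> mirrors the threshold used later in <i>swap_docs_based_on_predictions</i>.</p>

```python
def rerank_adjacent_pairs(docs, pair_probabilities, margin=0.2):
    """Swap neighbours (0,1), (2,3), ... when the second one is clearly preferred."""
    reranked = list(docs)  # copy, so the first-stage list is left untouched
    for i in range(0, len(reranked) - 1, 2):
        p_first, p_second = pair_probabilities(reranked[i], reranked[i + 1])
        if p_first + margin < p_second:
            reranked[i], reranked[i + 1] = reranked[i + 1], reranked[i]
    return reranked


query = {"shock", "wave"}


def toy_pair_probabilities(doc_a, doc_b):
    # Hypothetical stand-in for the model: smoothed share of query-word overlap.
    a = len(query & set(doc_a)) + 1
    b = len(query & set(doc_b)) + 1
    return a / (a + b), b / (a + b)


docs = [["heat", "plate"], ["shock", "wave"], ["shock", "flow"], ["wave", "flow"]]
reranked = rerank_adjacent_pairs(docs, toy_pair_probabilities)
print(reranked)
# [['shock', 'wave'], ['heat', 'plate'], ['shock', 'flow'], ['wave', 'flow']]
```

<p style="font-size:120%;">Only the first pair is swapped: in the second pair both documents get probability <i>0.5</i>, which does not exceed the margin, so the first-stage order is kept.</p>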

<br>
<br>

In [6]:
! pip install transformers
import tensorflow as tf
from transformers import BertConfig, TFBertForMultipleChoice

PATH_TO_LOGS = './logs'
PATH_TO_MODEL = './model/model_checkpoint' 

def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)



In [7]:
documents = load_documents(Document)
queries = load_queries(Query)
relevant = load_judgements(queries, documents)

documents = create_new_doc_odict()
words_to_ids, unique_terms = create_words_to_ids_dict()
vocab_size = len(words_to_ids.keys())
num_docs = len(documents.values())
num_queries = len(queries.values())

print('Documents: ', num_docs)
print('Queries: ', num_queries)

Documents:  1398
Queries:  225


<br><br>
<h3 align="center"><i>Let's create the data generators:</i><br/><br></h3>
<p style="font-size:120%;">Note that the <i>TrainDataGenerator</i> never touches the relevance pairs (query - relevant document) during training.
Relevance labels are only used in the <i>ValidationDataGenerator</i> to monitor the model performance during training and <b>do not affect</b> the training process.</p>
<br><br>

In [0]:
class TrainDataGenerator(tf.keras.utils.Sequence):   
    def __len__(self):
        return 300

    def __getitem__(self, index):
        data_x_input_ids = []
        data_x_attention_masks = []
        data_x_token_type_ids = []
        data_x_position_ids = []
        data_y = []
        for _ in range(BATCH_SIZE):
            while True:                  
                random_doc_ids = np.random.choice(num_docs, 2 , replace=False) + 1
                random_docs = np.array([documents[random_doc_ids[0]], documents[random_doc_ids[1]]])
                if not bool(set(random_docs[0].preprocessed[0]) & set(random_docs[1].preprocessed[0])):
                    break
                                
            polarity = np.random.choice(2, 2, replace=False)
            p_doc = random_docs[0]
            n_doc = random_docs[1]
            
            p_random_words, p_words, n_words = select_random_words_from_texts(p_doc.preprocessed, n_doc.preprocessed)
           
            text1 = p_random_words
            p_text = p_words
            n_text = n_words

            n_p_texts = np.array([n_text, p_text])
            random_texts = np.take(n_p_texts, polarity, axis=0).tolist()
            text2 = random_texts[0]
            text3 = random_texts[1]

            sample_input_ids, sample_attention_mask, sample_token_type_ids, sample_position_ids = create_sample(text1, text2, text3)
            label = polarity
            
            data_x_input_ids.append(sample_input_ids)
            data_x_attention_masks.append(sample_attention_mask)
            data_x_token_type_ids.append(sample_token_type_ids)
            data_x_position_ids.append(sample_position_ids)
            data_y.append(label)
            
        return ([np.array(data_x_input_ids), np.array(data_x_attention_masks), np.array(data_x_token_type_ids), np.array(data_x_position_ids)], np.array(data_y))


    
class ValidationDataGenerator(tf.keras.utils.Sequence):
    def __init__(self):
        self.rel_pairs = list(relevant)
        self.num_rel_pairs = len(self.rel_pairs)
        self.len = 25
        
    def __len__(self):
        return self.len

    def __getitem__(self, index):
        data_x_input_ids = []
        data_x_attention_masks = []
        data_x_token_type_ids = []
        data_x_sample_position_ids = []
        data_y = []
        for _ in range(BATCH_SIZE):
            random_pair_idx = np.random.randint(self.num_rel_pairs, size=1).item()
            random_pair = self.rel_pairs[random_pair_idx]
            polarity = np.random.choice(2, 2, replace=False)
              
            while True:
                n_random_doc_id = np.random.randint(num_docs, size=1).item() + 1
                if not (random_pair[0], documents[n_random_doc_id]) in relevant:
                    break

            p_text = random_pair[1].preprocessed  
            n_text = documents[n_random_doc_id].preprocessed
        
            p_random_words, p_words, n_words = select_random_words_from_texts(p_text, n_text)
           
            text1 = p_random_words
            p_text = p_words
            n_text = n_words 
        
            n_p_texts = np.array([n_text, p_text])
            random_texts = np.take(n_p_texts, polarity, axis=0).tolist()
            text2 = random_texts[0]
            text3 = random_texts[1]
            
            sample_input_ids, sample_attention_mask, sample_token_type_ids, sample_position_ids = create_sample(text1, text2, text3)
            label = polarity
            
            data_x_input_ids.append(sample_input_ids)
            data_x_attention_masks.append(sample_attention_mask)
            data_x_token_type_ids.append(sample_token_type_ids)
            data_x_sample_position_ids.append(sample_position_ids)
            data_y.append(label)
            
        return ([np.array(data_x_input_ids), np.array(data_x_attention_masks), np.array(data_x_token_type_ids), np.array(data_x_sample_position_ids)], np.array(data_y))

In [0]:
def create_sample_part(text, token_type_id, position_id):
    words = []
    input_ids = []
    attention_mask = []
    token_type_ids = []
    position_ids = []
    for sentence in text:
        for w in sentence:
            if w in words_to_ids.keys():
                words.append(w)
                input_ids.append(words_to_ids[w])
                attention_mask.append(1)
                token_type_ids.append(token_type_id)
                position_ids.append(position_id)
    words.append('SEP')
    input_ids.append(words_to_ids['SEP'])
    attention_mask.append(1)
    token_type_ids.append(token_type_id)
    position_ids.append(position_id)
    return words, input_ids, attention_mask, token_type_ids, position_ids


def create_sample(text1, text2, text3):
    sample_words = ['CLS']
    sample_input_ids = [words_to_ids['CLS']]
    sample_attention_mask = [1]
    sample_token_type_ids = [0]
    sample_position_ids = [0]

    part1 = create_sample_part(text1, token_type_id=0, position_id=0)
    part2 = create_sample_part(text2, token_type_id=1, position_id=0)
    part3 = create_sample_part(text3, token_type_id=2, position_id=0)
    
    n_pad_1 = BERT_MAX_SEQ_LEN - (len(part1[0]) + len(part2[0]) + 1)
    n_pad_2 = BERT_MAX_SEQ_LEN - (len(part1[0]) + len(part3[0]) + 1)

    zero_pad_1 = np.zeros(n_pad_1, dtype='int32')
    zero_pad_2 = np.zeros(n_pad_2, dtype='int32')
    
    sample_words_1 = np.concatenate([np.array(sample_words + part1[0] + part2[0]), zero_pad_1])
    sample_input_ids_1 = np.concatenate([np.array(sample_input_ids + part1[1] + part2[1]), zero_pad_1])
    sample_attention_mask_1 = np.concatenate([np.array(sample_attention_mask + part1[2] + part2[2]), zero_pad_1])
    sample_token_type_ids_1 = np.concatenate([np.array(sample_token_type_ids + part1[3] + part2[3]), zero_pad_1])
    sample_position_ids_1 = np.concatenate([np.array(sample_position_ids + part1[4] + part2[4]), zero_pad_1])
    
    sample_words_2 = np.concatenate([np.array(sample_words + part1[0] + part3[0]), zero_pad_2])
    sample_input_ids_2 = np.concatenate([np.array(sample_input_ids + part1[1] + part3[1]), zero_pad_2])
    sample_attention_mask_2 = np.concatenate([np.array(sample_attention_mask + part1[2] + part3[2]), zero_pad_2])
    sample_token_type_ids_2 = np.concatenate([np.array(sample_token_type_ids + part1[3] + part3[3]), zero_pad_2])
    sample_position_ids_2 = np.concatenate([np.array(sample_position_ids + part1[4] + part3[4]), zero_pad_2])
    
    sample_words = np.stack([sample_words_1, sample_words_2])
    sample_input_ids = np.stack([sample_input_ids_1, sample_input_ids_2])
    sample_attention_mask = np.stack([sample_attention_mask_1, sample_attention_mask_2])
    sample_token_type_ids = np.stack([sample_token_type_ids_1, sample_token_type_ids_2])
    sample_position_ids = np.stack([sample_position_ids_1, sample_position_ids_2])
        
    return sample_input_ids, sample_attention_mask, sample_token_type_ids, sample_position_ids


def select_random_words_from_texts(p_text, n_text):
    num_words_to_select = np.random.choice(range(MIN_NUM_WORDS_TO_SELECT, MAX_NUM_WORDS_TO_SELECT), 1).item()
    p_words = flatten_text(p_text)
    n_words = flatten_text(n_text)
    
    p_n_words = list(set(p_words) & set(n_words))
    p_unique_words = list(set(p_words) - set(p_n_words))
    n_unique_words = list(set(n_words) - set(p_n_words))
    num_words_to_select = np.min([num_words_to_select, len(p_unique_words)//2, len(n_unique_words)//2])

    shuffle(p_unique_words)
    p_random_words = p_unique_words[:num_words_to_select]
    
    shuffle(n_unique_words)
    n_random_words = n_unique_words[:num_words_to_select]

    p_filtered_text = filter_text(p_text, p_random_words)
    p_filtered_text_fl = flatten_text(p_filtered_text)
    shuffle(p_filtered_text_fl)
    p_words = p_filtered_text_fl[:num_words_to_select]
      
    return [p_random_words], [p_words], [n_random_words]



def prepare_doc_batch(query, docs):
    data_x_input_ids = []
    data_x_attention_masks = []
    data_x_token_type_ids = []
    data_x_sample_position_ids = []
    
    for i in range(0, len(docs)-1, 2):             
        p_words = flatten_text(docs[i].preprocessed)
        n_words = flatten_text(docs[i+1].preprocessed)
    
        p_n_words = list(set(p_words) & set(n_words))
        p_unique_words = list(set(p_words) - set(p_n_words))
        n_unique_words = list(set(n_words) - set(p_n_words))
          
        text1 = query.preprocessed
        text2 = [p_unique_words]
        text3 = [n_unique_words]
        
        sample_input_ids, sample_attention_mask, sample_token_type_ids, sample_position_ids = create_sample(text1, text2, text3)
        
        data_x_input_ids.append(sample_input_ids)
        data_x_attention_masks.append(sample_attention_mask)
        data_x_token_type_ids.append(sample_token_type_ids)
        data_x_sample_position_ids.append(sample_position_ids)
        
    return [np.array(data_x_input_ids), np.array(data_x_attention_masks), np.array(data_x_token_type_ids), np.array(data_x_sample_position_ids)]
      
    
    
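def swap_docs_based_on_predictions(raw_doc_batch, predictions):
    # Note: reranked_docs aliases raw_doc_batch, so the swaps happen in place.
    reranked_docs = raw_doc_batch   
    for i in range(predictions.shape[0]):
        # Swap a pair only if the second document wins by a margin of more than 0.2.
        if predictions[i][0] + 0.2 < predictions[i][1]:
            t = reranked_docs[2*i]
            reranked_docs[2*i] = reranked_docs[2*i+1]
            reranked_docs[2*i+1] = t            
    return reranked_docs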
        
    
    
def rerank_docs_based_on_query(query, fs_ranked_documents):   
    # Only positions 5 to 44 of the first-stage list are reranked; the top 4 keep their BM25Plus order.
    docs_to_rerank = fs_ranked_documents[4:44]
    prepared_batch = prepare_doc_batch(query, docs_to_rerank)
    predictions = model.predict(prepared_batch, batch_size=BATCH_SIZE)
    reranked_docs = swap_docs_based_on_predictions(docs_to_rerank, predictions)
    fs_ranked_documents[4:44] = reranked_docs 
    return fs_ranked_documents  

<br><br>
<h3 align="center"><i>Now, the most interesting part - we define the model:</i><br/><br></h3>
<p style="font-size:120%;">Model's inputs have the following structure: <br/><br></p>
<ul style="font-size:120%;">
<li><i>input_ids</i>: each token has a unique integer <i>id</i>; the <i>BERT</i> <i>embedding layer</i> uses this <i>id</i> to look up the floating-point vector that represents the token</li><br/>
<li><i>attention_mask</i>: the model always takes an input of a predefined length; inputs shorter than this length are padded with zeros, and the attention mask indicates which positions hold real tokens and which hold padding</li><br/>
<li><i>token_type_ids</i>: indicate whether a token belongs to the extracted words (or the query), to the first input document, or to the second input document</li><br/>
<li><i>position_ids</i>: encode the order of words in the documents; in this notebook every position id is set to <i>0</i>, so the word order is not used</li><br/></ul>

<p style="font-size:120%;">
The model predicts <i>2</i> softmax scores: the probability that the words were extracted from the first document and the probability that they were extracted from the second one<br><br></p>
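<p style="font-size:120%;">To make the input structure concrete, here is a small sketch that encodes one of the two choices (extracted words plus one document). The helper name <i>encode_choice</i> and the toy vocabulary are hypothetical; the layout (<i>CLS</i>, words, <i>SEP</i>, document, <i>SEP</i>, zero padding) follows <i>create_sample</i> above.</p>

```python
def encode_choice(words, context, word_ids, context_type, max_len):
    """Encode CLS + words + SEP + context + SEP and pad with zeros up to max_len."""
    tokens = ["CLS"] + words + ["SEP"] + context + ["SEP"]
    input_ids = [word_ids[t] for t in tokens]
    # Type 0 covers CLS, the extracted words and their SEP; the context gets its own type.
    token_type_ids = [0] * (len(words) + 2) + [context_type] * (len(context) + 1)
    attention_mask = [1] * len(tokens)
    n_pad = max_len - len(tokens)
    return {
        "input_ids": input_ids + [0] * n_pad,
        "attention_mask": attention_mask + [0] * n_pad,
        "token_type_ids": token_type_ids + [0] * n_pad,
    }


word_ids = {"CLS": 0, "SEP": 1, "flow": 2, "shock": 3, "wave": 4}
enc = encode_choice(["shock"], ["wave", "flow"], word_ids, context_type=1, max_len=8)
print(enc["input_ids"])       # [0, 3, 1, 4, 2, 1, 0, 0]
print(enc["attention_mask"])  # [1, 1, 1, 1, 1, 1, 0, 0]
print(enc["token_type_ids"])  # [0, 0, 0, 1, 1, 1, 0, 0]
```

<p style="font-size:120%;">The padding value <i>0</i> coincides with the id of <i>CLS</i> in this vocabulary, which is why the attention mask is needed to tell real tokens from padding.</p>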
<br><br>

In [0]:
def create_model(bert_config):
    num_choices = 2
    bert = TFBertForMultipleChoice(bert_config)
    input_ids = tf.keras.Input(shape=(num_choices, BERT_MAX_SEQ_LEN,), dtype=tf.int32, name='input_ids')
    attention_mask = tf.keras.Input(shape=(num_choices, BERT_MAX_SEQ_LEN,), dtype=tf.int32, name='attention_mask')
    token_type_ids = tf.keras.Input(shape=(num_choices, BERT_MAX_SEQ_LEN,), dtype=tf.int32, name='token_type_ids')
    position_ids = tf.keras.Input(shape=(num_choices, BERT_MAX_SEQ_LEN,), dtype=tf.int32, name='position_ids')
    
    
    logits = bert([input_ids, attention_mask, token_type_ids, position_ids])[0]
    output = tf.keras.layers.Softmax(input_shape=(num_choices,))(logits)
    model = tf.keras.Model([input_ids, attention_mask, token_type_ids, position_ids], output)
    
    
    model.summary()
    return model

<br><br>
<h3 align="center"><i>Configure the model for training:<br><br></i></h3>
<p style="font-size:120%;">Note that we train the model from <i>scratch</i>. The training process may be unstable, so:<br><br></p>
<ul style="font-size:120%;"> 
<li>Use a batch size of at least <i>128</i> (if the batch <i>does not</i> fit into memory you will get an <i>OOM error</i>; in that case just <i>decrease</i> the batch size)</li><br/>
<li>Use a <i>low</i> learning rate (<i>e.g. 3e-6</i>)</li><br/><br/></ul>


<p style="font-size:120%;">You can experiment with different settings: <br/><br></p>
<ul style="font-size:120%;">
<li>Increasing the batch size $n$ times reduces the <i>standard deviation</i> of the gradient estimate by a factor of $\sqrt{n}$ (and its <i>variance</i> by a factor of $n$). In other words, we <i>do not</i> benefit much from <i>very large</i> batch sizes: they still <i>do not</i> provide the true <i>gradient</i>, but they significantly <i>slow down</i> the training. On the other hand, with a <i>too small</i> batch size the model may <i>not converge</i> at all</li><br/>
<li>If you decrease the batch size $n$ times, you should decrease the learning rate by a factor of $\sqrt{n}$ to preserve the stability of the learning process</li><br/>
<li><i>Epochs</i> is just an arbitrarily high number. Model weights are saved after <i>each</i> epoch through the <i>ModelCheckpoint</i> callback, so feel free to stop the training whenever you want (once the model is <i>precise</i> enough)<br><br><br></li></ul>
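<p style="font-size:120%;">The square-root rule from the list above can be written down in a few lines. The function names are hypothetical, and the noise estimate assumes independent, identically distributed samples in a batch, so treat this as a rule of thumb rather than a guarantee.</p>

```python
import math


def scaled_learning_rate(base_lr, base_batch, new_batch):
    """Square-root rule: scale the learning rate with the square root of the batch-size ratio."""
    return base_lr * math.sqrt(new_batch / base_batch)


def gradient_noise_ratio(base_batch, new_batch):
    """Std of the mini-batch gradient estimate relative to the base batch (i.i.d. samples assumed)."""
    return math.sqrt(base_batch / new_batch)


# Going from a batch of 128 down to 32 (4 times smaller):
print(scaled_learning_rate(3e-6, base_batch=128, new_batch=32))  # 1.5e-06
print(gradient_noise_ratio(base_batch=128, new_batch=32))        # 2.0
```

<p style="font-size:120%;">So a batch that is <i>4</i> times smaller gives a gradient estimate that is about twice as noisy, and the rule suggests halving the learning rate to compensate.</p>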
    
    
<p style="font-size:120%;">The model whose performance is reported below was trained with an <i>initial</i> learning rate of <i>1e-5</i> to converge faster in the <i>initial</i> phase of training. Then the learning rate was decreased to <i>3e-6</i> and the network continued to train for <i>14</i> additional hours. The whole training process took approximately <i>18</i> hours (on the <i>weeny</i> but <i>proud</i> <i>GeForce GTX 850M</i> GPU). <br><br> The model has only <i>65k</i> parameters. Experiments with <i>higher learning rates</i> led to <i>divergence</i> of the model; <i>smaller batch sizes</i> (<i>e.g. 8</i>) led to <i>convergence</i> to <i>local minima</i>, <i>performance degradation</i> and an <i>inability</i> to return to the normal training process.<br><br></p>
    
    
<p style="font-size:120%;">Performance in the final phase of training: <br><br></p>

<img src="https://github.com/xkaple01/bert-information-retrieval/blob/master/bert_ir_system_gc/learning_process.gif?raw=1" align="middle">
<br><br><br>

In [0]:
BATCH_SIZE = 128
LEARNING_RATE = 0.000003
EPOCHS = 1000
MIN_NUM_WORDS_TO_SELECT = 75
MAX_NUM_WORDS_TO_SELECT = 76 
BERT_MAX_SEQ_LEN = 25 + MAX_NUM_WORDS_TO_SELECT + 4

<br><br>
<h3><i>Original BERT:</i></h3><br>
<ul style="font-size:120%;">
<li><i>12</i> attention layers</li><br/>
<li>Each token (the vocabulary contains approximately <i>30000</i> word pieces) is encoded as a vector of <i>768</i> numbers</li><br/>
<li>Memory size (dimensionality of pointwise dense layers in encoder blocks) is <i>3072</i></li></ul><br/>
<p style="font-size:120%;">As a result, the model requires <i>gigabytes</i> of textual data, <i>64</i> GPUs and a <i>week</i> of training. <br/><br/> We will design a <i>"tiny"</i> version of <i>BERT</i>. The model has to be powerful enough to capture the hidden relations between words, and at the same time it must not be overparametrized, so that it does not overfit heavily.</p><br/><br/>
    
<p style="font-size:120%;">Recommended settings are following:<br/><br/></p>
<ul style="font-size:120%;">
<li><i>2</i> attention layers (preferably at least <i>3</i>, although this will, of course, slow down the training and increase the inference time)</li><br/>
<li>Hidden size is <i>32</i>: each token in our scientific Cranfield language will be represented as a vector of <i>32</i> numbers</li><br/>
<li>Intermediate (memory size) is at least <i>256</i>: we want to reliably capture the hidden relations between words</li><br/>
<li>Number of attention heads is just <i>1</i>. Dot products between such small (<i>32</i>-dim) representations do not require the vector to be split into several parts and attended by multiple heads. But feel free to set the number of attention heads to <i>2</i> or <i>4</i></li></ul>
<br><br>
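<p style="font-size:120%;">As a sanity check on the <i>65k</i> parameters quoted earlier, the count can be estimated by hand. The sketch below assumes the standard <i>BERT</i> layout (embeddings with layer normalization, encoder blocks, a pooler and a single-logit multiple-choice head) and uses the values passed to <i>BertConfig</i> in the next cell; the function name <i>tiny_bert_param_count</i> is hypothetical.</p>

```python
def tiny_bert_param_count(vocab_size, hidden, layers, intermediate, max_positions, type_vocab):
    """Estimate the number of trainable parameters of a BERT multiple-choice model."""

    def dense(n_in, n_out):
        return n_in * n_out + n_out  # weights + biases

    layer_norm = 2 * hidden  # scale + shift
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm
    attention = 4 * dense(hidden, hidden) + layer_norm  # query, key, value, output projection
    feed_forward = dense(hidden, intermediate) + dense(intermediate, hidden) + layer_norm
    pooler = dense(hidden, hidden)
    classifier = dense(hidden, 1)  # one logit per choice
    return embeddings + layers * (attention + feed_forward) + pooler + classifier


n_params = tiny_bert_param_count(
    vocab_size=680, hidden=32, layers=2, intermediate=256, max_positions=3, type_vocab=3
)
print(n_params)  # 65153
```

<p style="font-size:120%;">Under these assumptions the number of attention heads does not change the count, and roughly a third of the parameters sit in the token embedding table.</p>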


In [0]:
bert_train_config = BertConfig(vocab_size=680, hidden_size=32, num_hidden_layers=2, num_attention_heads=1, intermediate_size=256, max_position_embeddings=3, type_vocab_size=3, training=True)

train_data_generator = TrainDataGenerator()
val_data_generator = ValidationDataGenerator()

model = create_model(bert_train_config)
opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

tb_callback = tf.keras.callbacks.TensorBoard(log_dir=PATH_TO_LOGS, histogram_freq=1, write_graph=True, write_images=False)
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(filepath=PATH_TO_MODEL, monitor='accuracy', verbose=1, save_best_only=False, save_weights_only=True, save_freq='epoch')


history = model.fit(x=train_data_generator, validation_data = val_data_generator, epochs=EPOCHS, callbacks=[tb_callback, ckpt_callback])
model.save_weights(PATH_TO_MODEL)

<br><br>
<p style="font-size:120%;">Note how the <i>validation loss</i> <i>improves</i> as the training proceeds. The model's growing <i>general language understanding ability</i> leads to an improvement in reranking accuracy even though the model <b>does not have</b> any information about the <i>query - document</i> relevance.</p><br><br>

<h3 align="center"><i>It's time to test our model:</i><br/><br></h3>
<p style="font-size:120%;">Let's configure the inference phase:</p><br><br>

In [0]:
BATCH_SIZE = 1
BERT_MAX_SEQ_LEN = 512 
bert_inference_config = BertConfig(vocab_size=680, hidden_size=32, num_hidden_layers=2, num_attention_heads=1, intermediate_size=256, max_position_embeddings=3, type_vocab_size=3, training=False)

<br><br>
<p style="font-size:120%;">Alternatively, you can load the trained model to reproduce the final result reported below.</p><br><br>

In [16]:
PATH_TO_MODEL  = '/content/bert-information-retrieval/bert_ir_system_gc/trained_model/model_checkpoint'
model = create_model(bert_inference_config)
model.load_weights(PATH_TO_MODEL)

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_ids (InputLayer)          [(None, 2, 512)]     0                                            
__________________________________________________________________________________________________
attention_mask (InputLayer)     [(None, 2, 512)]     0                                            
__________________________________________________________________________________________________
token_type_ids (InputLayer)     [(None, 2, 512)]     0                                            
__________________________________________________________________________________________________
position_ids (InputLayer)       [(None, 2, 512)]     0                                            
____________________________________________________________________________________________

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd7c41bdfd0>

In [0]:
BATCH_SIZE = 128
LEARNING_RATE = 0.000003
EPOCHS = 1000
MIN_NUM_WORDS_TO_SELECT = 75
MAX_NUM_WORDS_TO_SELECT = 76 
BERT_MAX_SEQ_LEN = 25 + MAX_NUM_WORDS_TO_SELECT + 4
PATH_TO_MODEL  = '/content/bert-information-retrieval/bert_ir_system_gc/trained_model/model_checkpoint'
bert_train_config = BertConfig(vocab_size=680, hidden_size=32, num_hidden_layers=2, num_attention_heads=1, intermediate_size=256, max_position_embeddings=3, type_vocab_size=3, training=True)

train_data_generator = TrainDataGenerator()
val_data_generator = ValidationDataGenerator()

model = create_model(bert_train_config)
model.load_weights(PATH_TO_MODEL)
opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

tb_callback = tf.keras.callbacks.TensorBoard(log_dir=PATH_TO_LOGS, histogram_freq=1, write_graph=True, write_images=False)
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(filepath=PATH_TO_MODEL, monitor='accuracy', verbose=1, save_best_only=False, save_weights_only=True, save_freq='epoch')


history = model.fit(x=train_data_generator, validation_data = val_data_generator, epochs=EPOCHS, callbacks=[tb_callback, ckpt_callback])
model.save_weights(PATH_TO_MODEL)

In [0]:
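class BertIRSystem(IRSystem):
    """Two-stage retrieval: BM25Plus ranking followed by "tiny" BERT reranking."""

    def __init__(self):
        self.documents = list(documents.values())
        self.corpus = [flatten_text(doc.preprocessed) for doc in self.documents]
        # First-stage ranker: BM25Plus with default parameters over the flattened documents
        self.bm25 = BM25Plus(self.corpus)

    def first_stage_ranking(self, preprocessed_query):
        # Score every document against the query and sort by descending score
        scores = self.bm25.get_scores(preprocessed_query.preprocessed[0])
        doc_list_ids = np.argsort(scores)[::-1]
        return np.take(self.documents, doc_list_ids, axis=0).tolist()

    def search(self, query):
        # Look up the preprocessed version of the query by its ID
        preprocessed_query = queries[query.query_id]
        fs_ranked_docs = self.first_stage_ranking(preprocessed_query)
        # Second stage: the trained model reranks the first-stage ordering
        reranked_docs = rerank_docs_based_on_query(preprocessed_query, fs_ranked_docs)
        return reranked_docs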

In [18]:
mean_average_precision(BertIRSystem(), submit_result=False, author_name="Kaplenko, Mykola")

Mean average precision: 42.761% 
Not submitted.
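
<p style="font-size:120%;">For readers unfamiliar with the metric, the sketch below shows how mean average precision is commonly computed. The function names and the toy data are hypothetical, and the <i>mean_average_precision</i> function from <i>pv211-utils</i> may differ in its details. Relevance judgements appear here only for evaluation, never for training.</p>

```python
def average_precision(ranked_doc_ids, relevant_doc_ids):
    """Average of precision@k over the ranks k at which a relevant document appears."""
    relevant = set(relevant_doc_ids)
    if not relevant:
        return 0.0
    hits = 0
    precision_sum = 0.0
    for rank, doc_id in enumerate(ranked_doc_ids, start=1):
        if doc_id in relevant:
            hits += 1
            precision_sum += hits / rank
    # Relevant documents that were never retrieved contribute zero
    return precision_sum / len(relevant)

def map_sketch(rankings, judgements):
    """Mean of the per-query average precision values."""
    values = [average_precision(rankings[q], judgements[q]) for q in rankings]
    return sum(values) / len(values)

# Toy data (an assumption for illustration, not Cranfield results)
rankings = {"q1": [3, 1, 2, 4], "q2": [5, 6]}
judgements = {"q1": {1, 4}, "q2": {5}}

ap_q1 = average_precision(rankings["q1"], judgements["q1"])  # (1/2 + 2/4) / 2
toy_map = map_sketch(rankings, judgements)
```

<p style="font-size:120%;">Because precision is accumulated at the rank of every relevant document, moving a relevant document up by even one position raises the score, which is exactly what the second-stage reranker tries to achieve.</p>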


<br><br>
<p style="font-size:120%;">Note the increase in MAP after document reranking by our second-stage model. The increase is modest, but:<br/><br/>
<ul style="font-size:120%;">
<li>The reranking was applied only once to pairs of neighbouring documents (to achieve a reasonable query response time). A higher number of pairwise rerankings significantly slows down document retrieval</li><br/>
<li>The model was trained from scratch. We did no fine-tuning and did not use any huge model pretrained by <i>Google</i> on <i>gigabytes</i> of data</li><br/> 
<li>Our model was trained on a general language understanding task (i.e. document prediction based on extracted words) and has never used the relevance labels. In contrast, real-world information retrieval systems typically rely on extremely large datasets of queries and human-defined document relevance labels</li></ul>
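<p style="font-size:120%;">The trade-off between reranking quality and response time can be illustrated with a single pass over neighbouring pairs. This is only a sketch under stated assumptions: <i>rerank_neighbours_once</i> and <i>toy_prefers_second</i> are hypothetical names, the overlapping pairing scheme is a guess, and the word-overlap comparator merely stands in for the trained model. The notebook's own <i>rerank_docs_based_on_query</i> may work differently.</p>

```python
def rerank_neighbours_once(query, ranked_docs, prefers_second):
    """One pass over adjacent pairs: swap a pair when the comparator prefers the second document."""
    docs = list(ranked_docs)  # copy so the first-stage ranking is left untouched
    comparisons = 0
    for i in range(len(docs) - 1):
        comparisons += 1  # each comparison would be one model call
        if prefers_second(query, docs[i], docs[i + 1]):
            docs[i], docs[i + 1] = docs[i + 1], docs[i]
    return docs, comparisons

def toy_prefers_second(query, doc_a, doc_b):
    """Stand-in for the model: prefer the document sharing more words with the query."""
    overlap_a = len(set(query) & set(doc_a))
    overlap_b = len(set(query) & set(doc_b))
    return overlap_b > overlap_a

# Toy query and first-stage ranking (assumptions for illustration)
query = ["boundary", "layer"]
first_stage = [["heat", "transfer"], ["boundary", "layer", "flow"], ["layer"]]

reranked, comparisons = rerank_neighbours_once(query, first_stage, toy_prefers_second)
```

<p style="font-size:120%;">A single pass over <i>n</i> documents costs <i>n - 1</i> comparisons, while repeating the pass until the order is stable can cost on the order of <i>n<sup>2</sup></i> comparisons, which explains the slowdown mentioned above.</p>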
<br>
<p style="font-size:120%;">It would be interesting to change the project slightly in the coming years by providing training and validation datasets, so that models can be trained for the primary task they will be used for.</p><br>
<hr>

<br/>
<p align="center" style="font-size:120%;text-align:center;">Thank you for your <i>attention</i>. I hope you have enjoyed our small information retrieval challenge.<br/><br>