In [2]:
!pip install wikipedia
!pip install googlesearch-python
!pip install requests
!pip install transformers==4.2.2
!pip install tensorflow==2.4.1



In [3]:
import os
import re
import json
import itertools

import wikipedia
import requests

import numpy as np
import pandas as pd
import tensorflow as tf

from concurrent.futures import ThreadPoolExecutor
from pprint import pprint
from googlesearch import search
from bs4 import BeautifulSoup
from tqdm.notebook import tqdm
from transformers import TFAutoModel, AutoTokenizer, AutoConfig

In [4]:
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased', use_fast=True)

In [5]:
BASE_DIR = 'drive/MyDrive'
DO_TRAIN = False

In [6]:
class QARecord:
    def __init__(self, question, context, is_impossible, answer_text=None, answer_start=-1):
        self.question = question
        self.context = context
        self.is_impossible = is_impossible
        self.answer_text = answer_text
        self.answer_start = answer_start
        if not is_impossible:
            self.answer_end = answer_start + len(answer_text)
        else:
            self.answer_end = -1

In [7]:
class QADataset:
    SQUAD_DIR = os.path.join(BASE_DIR, 'datasets/squad20')
    
    def __init__(self):
        pass
    
    def _record_generator(self, filename):
        with open(os.path.join(self.SQUAD_DIR, filename)) as f:
            d = json.load(f)
        for x in d['data']:
            for p in x['paragraphs']:
                context = p['context']
                qas = p['qas']
                for q in qas:
                    question = q['question']
                    is_impossible = q['is_impossible']
                    answer, answer_start = None, -1

                    if not is_impossible:
                        answer = q['answers'][0]['text']
                        answer_start = q['answers'][0]['answer_start']

                    yield QARecord(question, context, is_impossible, answer, answer_start)
    
    def train(self, **kwargs):
        return self._make(self._record_generator('train-v2.0.json'), **kwargs)
    
    def dev(self, **kwargs):
        return self._make(self._record_generator('dev-v2.0.json'), **kwargs)
    
    def test(self, **kwargs):
        pass
    
    def predict(self, q, c):
        return self._make([QARecord(q, c, True)])
    
    def _make(self, records, shuffle=False, drop_remainder=False):
        MAX_LENGTH = 128
        def generator():
            for record in records:
                encoding = tokenizer.encode_plus(
                    record.question,
                    record.context,
                    max_length=MAX_LENGTH,
                    truncation=True,
                    padding='max_length',
                    return_offsets_mapping=True,
                    return_overflowing_tokens=True,
                    stride=16
                )

                input_ids = encoding['input_ids']
                attention_masks = encoding['attention_mask']
                token_type_ids = encoding.get('token_type_ids', itertools.repeat(None))
                offset_mappings = encoding['offset_mapping']

                for input_id, attention_mask, token_type_id, offset_mapping in zip(input_ids, attention_masks, token_type_ids, offset_mappings):
                    if record.is_impossible:
                        start, end = 0, 0
                    else:
                        try:
                            start = [i for i, x in enumerate(offset_mapping) if x[0]==record.answer_start][0]
                            end = [i for i, x in enumerate(offset_mapping) if x[1]==record.answer_end][0]
                        except IndexError:
                            start, end = 0, 0

                    yield {
                        'input_ids': input_id,
                        'attention_mask': attention_mask,
                        #'token_type_ids': token_type_id,
                    }, [start, end]

        dataset = tf.data.Dataset.from_generator(
            generator,
            output_types=( 
                {
                    'input_ids': tf.int32,
                    'attention_mask': tf.int32,
                    #'token_type_ids': tf.int32,
                },
                tf.int32
            ),
            output_shapes=(
                {
                    'input_ids': (MAX_LENGTH,),
                    'attention_mask': (MAX_LENGTH,),
                    #'token_type_ids': (MAX_LENGTH,)
                },
                (2,)
            )
        )

        if shuffle:
            dataset.shuffle(10000)
        return dataset.batch(32, drop_remainder=drop_remainder).prefetch(tf.data.AUTOTUNE)

In [8]:
class QAModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        config = AutoConfig.from_pretrained(
            'distilbert-base-uncased', output_attentions=False, output_hidden_states=False
        )
        self.bert = TFAutoModel.from_pretrained('distilbert-base-uncased', config=config) 
        self.dropout = tf.keras.layers.Dropout(0.1)
        self.dense = tf.keras.layers.Dense(2, dtype=tf.float32)
    
    def call(self, inputs, training=False):
        res = self.bert(inputs, training=training)
        seq = res['last_hidden_state']
        x = self.dropout(seq)
        x = self.dense(x)
        return x

In [9]:
def QAloss(labels, logits):
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True
    )
    start_logits, end_logits = logits[:, :, 0], logits[:, :, 1]
    start_loss = loss_fn(labels[:, 0], start_logits)
    end_loss = loss_fn(labels[:, 1], end_logits)
    return (start_loss + end_loss)/2

In [10]:
model_checkpoint_callback  = tf.keras.callbacks.ModelCheckpoint(
    os.path.join(BASE_DIR, 'models/QA/QAModel'),
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=True, mode='auto'
)

In [11]:
dataset = QADataset()

In [12]:
model = QAModel()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    loss=QAloss,
)

Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertModel: ['activation_13', 'vocab_layer_norm', 'vocab_projector', 'vocab_transform']
- This IS expected if you are initializing TFDistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFDistilBertModel were initialized from the model checkpoint at distilbert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


In [13]:
if DO_TRAIN:
    history = model.fit(
      dataset.train(shuffle=True, drop_remainder=True), epochs=1, validation_data=dataset.dev(),
      callbacks=[model_checkpoint_callback]
    )
else:
    model.load_weights(os.path.join(BASE_DIR, 'models/QA/QAModel'))

In [14]:
def get_url_content(url):
    try:
        return requests.get(url, timeout=10).text
    except Exception as e:
        pass

def get_context(question, source='wiki', top_n=5):
    if source == 'wiki':
        for st in wikipedia.search(question)[:top_n]:
            yield wikipedia.page(st).content
    elif source == 'google':
        urls = search(question)[:top_n]
        with ThreadPoolExecutor() as pool:
            for text in list(pool.map(get_url_content, urls)):
                if text:
                    soup = BeautifulSoup(text)
                    yield re.sub(r'[\n]+', '\n', soup.get_text())

In [15]:
def _get_offset(input_ids):
    sep = np.where(input_ids == 102)[0]
    offset = sep[0] + 1
    try:
        till = sep[1]
    except IndexError:
        till = len(input_ids)
    return offset, till

def _get_answer(dp, top_n=1):
    
    dp_a = dp.take(-1)
    p_a = model.predict(dp_a)
    
    for idx, d in enumerate(iter(dp.unbatch().batch(1))):
        p = p_a[[idx], :, :]
        
        top_n_temp = top_n
        input_ids = d[0]['input_ids'].numpy()[0]

        offset, till = _get_offset(input_ids)

        start_proba, end_proba = p[:, :, 0][0], p[:, :, 1][0]
        
        mask = [0 if offset <= x < till else 1e-8 for x in range(len(start_proba))]
        mask[0] = 0
        
        start_proba = start_proba + mask
        end_proba = end_proba + mask
        
        start_proba = tf.nn.softmax(start_proba).numpy()
        end_proba = tf.nn.softmax(end_proba).numpy()
        
        no_answer_score = start_proba[0] * end_proba[0]
        
        start_proba = start_proba[offset:till]
        end_proba = end_proba[offset:till]
        
        results = []
        
        for i,s in enumerate(start_proba):
            for j,e in enumerate(end_proba):
                if i>j:
                    continue
                results.append({
                    'start': offset + i,
                    'end': offset + j,
                    'score': s*e
                })
        
        results = sorted(results, key=lambda x: x['score'], reverse=True)
        results = [x for x in results if x['score']>no_answer_score]
        
        for r in results:
            r['text'] = tokenizer.decode(input_ids[r['start']:r['end']+1])
            r['context'] = tokenizer.decode(input_ids[r['start']-15:r['end']+1+15])

        yield results[:top_n]

In [22]:
def bert_ama(question, source='google', max_urls=5, top_n=None, thresh=0):
    """
    question: The question you want to ask
    source: wiki or google. No need to change. google works better.
    max_urls: Maximum number of search results where we look for answers.
    top_n: Number of esults returned.
    thresh: Only answers which have a score greater than thresh will be consider further.
    """
    results = []
    for context in get_context(question, source, top_n=max_urls):
        dp = dataset.predict(question, context)

        for result in _get_answer(dp, top_n):
            results.extend(result)
    
    df = pd.DataFrame(results)
    if df.empty:
      return df

    def scorer(arr):
        """Calculate final score for a givne answer
        
        Total score = A[0] + A[1]/4 + A[2]/9 + A[3]/16 ....
        A[i] is the i'th score in descending order for a particuar answer 
        """
        return sum(
            x/i for i,x in enumerate(sorted(arr, reverse=True), start=1)
        )
  
    df = df.sort_values(by='score', ascending=False)
    if thresh:
      df = df[df['score'] > thresh]
    df = df.groupby('text').agg({
        'score': scorer, 'context': list
    })
    df = df.reset_index()
    df = df.sort_values(by='score', ascending=False)
    if top_n:
        df = df.head(top_n)
    return df

In [35]:
bert_ama('Who is the founder of google?', source='google', max_urls=5, top_n=1)

Unnamed: 0,text,score,context
2,larry page,0.638208,[— about a month after donald j. trump was ele...


In [36]:
bert_ama('What is the chemical formula of benzene?', source='google', max_urls=5, top_n=1)

Unnamed: 0,text,score,context
1,c6h6,1.329239,[is as shown in the figure below. the chemical...


In [37]:
bert_ama('What is the capital of India?', source='google', max_urls=5, top_n=1)

Unnamed: 0,text,score,context
1,new delhi,1.207674,[; malvika singh ; rudrangshu mukherjee ( 2009...


In [38]:
bert_ama('How many carbon atoms does buckminsterfullerene have?', source='google', max_urls=5, top_n=1)

Unnamed: 0,text,score,context
1,sixty,0.801346,[is slippery and has a low melting point. buck...


In [40]:
bert_ama('when is the independence day celebrated in india?', source='google', max_urls=5, top_n=1)

Unnamed: 0,text,score,context
4,15th of august,1.086717,"[independence day of india, which is celebrate..."


In [41]:
bert_ama('what is the answer to the ultimate question of life, the universe, and everything?', source='google', max_urls=5, top_n=1)

Unnamed: 0,text,score,context
1,42,1.347246,[slang dictionary 42 [ fawr - tee too ] what d...
