In [1]:
import os
import time
import torch
from pytorch_transformers import BertConfig, BertTokenizer, BertForQuestionAnswering
from pytorch_transformers import XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer

In [2]:
# Registry of supported model types: each key maps to the
# (config class, question-answering model class, tokenizer class) triple
# that QuestionAnswering unpacks in its constructor.
MODEL_CLASSES = dict([
    ('bert', (BertConfig, BertForQuestionAnswering, BertTokenizer)),
    ('xlnet', (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer)),
])

In [3]:
class QuestionAnswering(object):
    """Extractive question answering on top of a fine-tuned BERT or XLNet model.

    The class loads a configuration, weights and tokenizer from local files,
    joins a question and a passage into one token sequence, runs the model on
    it and returns the highest-scoring token span as a plain string.
    """

    def __init__(self, config_file, weight_file, tokenizer_file, model_type):
        """Load the model and tokenizer selected by ``model_type``.

        Args:
            config_file: Path to the JSON model configuration.
            weight_file: Path to the serialized state dict read by ``torch.load``.
            tokenizer_file: Path handed to the tokenizer class constructor.
            model_type: Key of ``MODEL_CLASSES`` ('bert' or 'xlnet').

        Raises:
            KeyError: If ``model_type`` is not a key of ``MODEL_CLASSES``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[model_type]
        self.config = self.config_class.from_json_file(config_file)
        self.model = self.model_class(self.config)
        # torch.load unpickles the file: only load weights from trusted sources.
        self.model.load_state_dict(torch.load(weight_file, map_location=self.device))
        # A freshly constructed module lives on the CPU and load_state_dict only
        # copies values into it, so the model must be moved explicitly for
        # self.device to have any effect on inference.
        self.model.to(self.device)
        self.tokenizer = self.tokenizer_class(tokenizer_file)
        self.model_type = model_type

    def to_list(self, tensor):
        """Return ``tensor`` as a (nested) Python list, detached and on the CPU."""
        return tensor.detach().cpu().tolist()

    def get_reply(self, question, passage):
        """Answer ``question`` using the text in ``passage``.

        Args:
            question: Question text.
            passage: Text the answer span is extracted from.

        Returns:
            The selected token span converted back to a string. The string is
            empty when the selected end index lies before the start index.

        Raises:
            ValueError: If ``self.model_type`` is neither 'bert' nor 'xlnet'.
        """
        self.model.eval()
        with torch.no_grad():
            input_ids, _, tokens = self.prepare_features(question, passage)
            # Keep the input on the same device as the model.
            input_ids = input_ids.to(self.device)
            if self.model_type == 'bert':
                # The model returns two tensors; the argmax position of each is
                # used as the first / last token index of the answer.
                span_start, span_end = self.model(input_ids)
                start = int(torch.argmax(span_start))
                end = int(torch.argmax(span_end))
                answer = self.bert_convert_tokens_to_string(tokens[start:end + 1])
            elif self.model_type == 'xlnet':
                outputs = self.model(input_ids=input_ids,
                                     start_positions=None,
                                     end_positions=None)
                # NOTE(review): assumes outputs[0]/outputs[2] hold candidate
                # scores and outputs[1]/outputs[3] the matching token indices
                # for start and end -- confirm against the installed
                # pytorch_transformers version.
                best_start = int(torch.argmax(outputs[0]))
                best_end = int(torch.argmax(outputs[2]))
                start = self.to_list(outputs[1])[0][best_start]
                end = self.to_list(outputs[3])[0][best_end]
                answer = self.xlnet_convert_tokens_to_string(tokens[start:end + 1])
            else:
                raise ValueError("Unsupported model_type: %r" % (self.model_type,))
        return answer

    def bert_convert_tokens_to_string(self, tokens):
        """Join WordPiece-style tokens, merging '##' continuation pieces.

        When a literal '@' token is present all spaces are removed, so that
        e-mail addresses split into several tokens come back in one piece.
        """
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        if '@' in tokens:
            out_string = out_string.replace(' ', '')
        return out_string

    def xlnet_convert_tokens_to_string(self, tokens):
        """Concatenate tokens, turning each '▁' marker into a space."""
        return ''.join(tokens).replace('▁', ' ').strip()

    def prepare_features(self, question, passage, max_seq_length=300,
                         zero_pad=False, include_CLS_token=True, include_SEP_token=True):
        """Tokenize question and passage into a single model input.

        The sequence layout is ``[cls_token] + question + [sep_token] + passage``.

        Args:
            question: Question text.
            passage: Passage text.
            max_seq_length: The question is truncated to ``max_seq_length - 2``
                tokens, and this is the target length for zero padding. The
                passage itself is not truncated, so the final sequence can be
                longer than this value.
            zero_pad: When true, pad ids and mask with 0 up to ``max_seq_length``.
            include_CLS_token: Prepend the tokenizer's ``cls_token``.
            include_SEP_token: Insert the tokenizer's ``sep_token`` after the question.

        Returns:
            A tuple ``(input_ids, input_mask, tokens)``: a tensor of shape
            ``(1, seq_len)``, a list with 1 for real tokens and 0 for padding,
            and the list of unpadded tokens.
        """
        question_tokens = self.tokenizer.tokenize(question)
        passage_tokens = self.tokenizer.tokenize(passage)
        question_limit = max_seq_length - 2
        if len(question_tokens) > question_limit:
            question_tokens = question_tokens[0:question_limit]

        tokens = []
        if include_CLS_token:
            tokens.append(self.tokenizer.cls_token)
        tokens.extend(question_tokens)
        if include_SEP_token:
            tokens.append(self.tokenizer.sep_token)
        tokens.extend(passage_tokens)

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        missing = max_seq_length - len(input_ids)
        if zero_pad and missing > 0:
            input_ids.extend([0] * missing)
            input_mask.extend([0] * missing)
        return torch.tensor(input_ids).unsqueeze(0), input_mask, tokens

In [4]:
# Passage the model answers questions from: one short fact per sentence.
# Adjacent string literals are concatenated into a single string; every
# fragment except the last ends with one space so that neighbouring sentences
# never run together.
facts = (
    " My wife is great. "
    "My project name is e-project. "
    "My e-mail roberto.dias@gmail.com "
    "I work for ADP. "
    "My wife is 33 years old. "
    "My complete name is Roberto Pereira Silveira. "
    "I am 40 years old. "
    "My wife was born in 1985. "
    "My wife is an urban planner. "
    "My dog is cool. "
    "My dog breed is jack russel. "
    "My dog was born in 2014. "
    "Best computer science university is Unisinos. "
    "My favorite city in RS is Sao Leopoldo. "
    "Best soccer team is Gremio. "
    "My dog name is Mallu."
)

In [5]:
# List the files in the local BERT model directory; the paths passed to
# QuestionAnswering in the following cell are taken from this directory.
!ls ../../models/bert

bert-base-uncased-vocab.txt
bert-large-cased-whole-word-masking-finetuned-squad-config.json
bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin
bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt
bert-large-uncased-vocab.txt
bert-large-uncased-whole-word-masking-finetuned-squad-config.json
bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin


In [6]:
# BERT-large (uncased, whole-word-masking, fine-tuned on SQuAD).
# Config, weights and vocabulary all come from the bert-large-uncased family so
# the tokenizer is guaranteed to match the checkpoint it feeds.
bert_big = QuestionAnswering(
    config_file='../../models/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json',
    weight_file='../../models/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin',
    tokenizer_file='../../models/bert/bert-large-uncased-vocab.txt',
    model_type='bert',
)

In [7]:
# Ask for the e-mail address; the returned span is shown as the cell output.
question = "What is my e-mail?"
bert_big.get_reply(question, facts)

'roberto.dias@gmail.com'

In [8]:
# Ask for the dog's name; the returned span is shown as the cell output.
question = "What is my dog name?"
bert_big.get_reply(question, facts)

'mallu'

In [9]:
# Ask for the speaker's age; the returned span is shown as the cell output.
question = "What is my age?"
bert_big.get_reply(question, facts)

'40 years old'

In [10]:
# Ask for the best CS university; the returned span is shown as the cell output.
question = "What is best CS university?"
bert_big.get_reply(question, facts)

'unisinos'

In [11]:
# Ask for the favorite city; the returned span is shown as the cell output.
question = "What my favorite city?"
bert_big.get_reply(question, facts)

'sao leopoldo'

In [12]:
# Ask for the best soccer team; the returned span is shown as the cell output.
question = "What is best soccer team?"
bert_big.get_reply(question, facts)

'gremio'