In [6]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np

from tqdm import trange

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering, DistilBertModel
from transformers import AutoTokenizer

In [4]:
news = pd.read_csv("/kaggle/input/covidqa/news.csv", usecols=["question", "answer"])
community = pd.read_csv("/kaggle/input/covidqa/community.csv", usecols=["title", "answer"])
community = community.rename(columns={"title": "question"})
data = pd.concat([news, community], axis=0)

data.head()

Unnamed: 0,question,answer
0,What are the symptoms?,"Symptoms include fever, coughing, sore throat,..."
1,When should I get tested?,Your doctor will tell you if you need to get t...
2,What's the difference between physical distanc...,"As cases of coronavirus surge, health authorit..."
3,How do I practice physical distancing?,If you have been in contact with a person with...
4,What's closed?,Physical distancing is the reason the Federal ...


In [7]:
x_train, x_valid = train_test_split(data, test_size=0.1, random_state=42)
x_train.reset_index(drop=True, inplace=True)
x_valid.reset_index(drop=True, inplace=True)

In [8]:
class MeanPooling(nn.Module):
    def __init__(self):
        super(MeanPooling, self).__init__()
        
    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, dim=1)
        sum_mask = input_mask_expanded.sum(dim=1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        
        mean_embeddings = F.normalize(mean_embeddings, p=2, dim=1)
        return mean_embeddings

In [9]:
class QAEstimator:
    """ 
    input: str - question
    output: str - answer
    """
    def __init__(self, data):
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.feature_extractor = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pooling = MeanPooling()
        self.data = data
        
        self.data.dropna(inplace=True)
        self.data.reset_index(drop=True, inplace=True)
        
        print("Initializing embeddings...")
        self.answers = []
        for i in trange(len(data)):
            answer = data.loc[i].squeeze()["answer"]
            answer = self._get_embedding(answer)
            self.answers.append(answer)
        self.answers = torch.cat(self.answers)
        print("Ready to answer.")
        
    def __call__(self, question: str, num_answers: int = 3) -> str:
        question = self._get_embedding([question])
        similarity = question @ self.answers.T
        similarity = similarity[0].topk(num_answers)
        
        answers = [self.data.loc[i.item()].squeeze()["answer"] for i in similarity[1]]
        
        return answers
        
    def _get_embedding(self, query):
        inputs = self.tokenizer(query,
                                max_length=100,
                                truncation=True,
                                return_tensors="pt")
        raw_embedding = self.feature_extractor(**inputs).last_hidden_state
        norm_embedding = self.pooling(raw_embedding, inputs["attention_mask"])
        return norm_embedding

In [10]:
qa = QAEstimator(news)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/256M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Initializing embeddings...


100%|██████████| 479/479 [00:58<00:00,  8.21it/s]

Ready to answer.





In [11]:
for answer in qa("What are the symptoms?", 10):
    print(answer.replace("\n", " "), "\n")

Will the hospitals remain open? Yes. Hospitals will operate normally. 

Symptoms show up 14 days after exposure to the virus. 

Yes, some spread might be possible before people show symptoms, according to the CDC. However, people are most contagious when they are most symptomatic. 

Yes, although people aged under 18 are less susceptible to the virus, according to the limited clinical reports available. 

Yes, testing positive means that you have the virus, but it does not mean that you will develop symptoms. Some people who have the virus don't have any symptoms at all.  At the same time, testing negative does not necessarily mean that you don't have the virus. 

 At the same time, testing negative does not necessarily mean that you don't have the virus. op symptoms. Some people who have the virus don't have any symptoms at all.

The most common symptoms reported are cough, fever, and shortness of breath. For some, these symptoms lead to respiratory distress and pneumonia. Other sympt