<a href="https://colab.research.google.com/github/sushi15/Wiki-QA/blob/main/wiki_qa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
# General downloads
!pip install transformers datasets
!pip install wptools
!pip install wikipedia 

import itertools 
import os 
import numpy 
import re 

# HuggingFace Transformers
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline 
import tensorflow as tf
import spacy 

import nltk 
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger') 

# MediaWiki API 
import wptools 
import wikipedia 

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity 
from sklearn.metrics.pairwise import linear_kernel 

Collecting transformers
  Downloading transformers-4.10.2-py3-none-any.whl (2.8 MB)
[K     |████████████████████████████████| 2.8 MB 5.2 MB/s 
[?25hCollecting datasets
  Downloading datasets-1.12.0-py3-none-any.whl (269 kB)
[K     |████████████████████████████████| 269 kB 29.5 MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 51.8 MB/s 
[?25hCollecting huggingface-hub>=0.0.12
  Downloading huggingface_hub-0.0.17-py3-none-any.whl (52 kB)
[K     |████████████████████████████████| 52 kB 1.5 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 43.5 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 47.1 MB/s 
Collecting fsspe

In [None]:
# Extractive question-answering model from the Hugging Face Hub. Per its name,
# this is a RoBERTa-base checkpoint fine-tuned on SQuAD 2.0 — see the deepset
# model card for the exact training/evaluation data (NOTE(review): the earlier
# comment said "dev set", which reads as if it were trained on evaluation data).
model_name = "deepset/roberta-base-squad2"

# The same checkpoint name is passed for the tokenizer so model and vocabulary
# match. `qa_final` is used as a module-level global by getAnswers().
qa_final = pipeline('question-answering', model = model_name, tokenizer = model_name) 

Downloading:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/496M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/79.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/772 [00:00<?, ?B/s]

## Modules

In [None]:
def getKeywords(question):
    """Extract content words from a question to use as a Wikipedia search query.

    The question is tokenized and POS-tagged with NLTK, whose tagger uses the
    Penn Treebank tag set. Only foreign words, adjectives, nouns and symbols
    are kept. The tagged tokens and the resulting query are printed.

    Args:
        question: The raw question text.

    Returns:
        The kept tokens joined by single spaces, in their original order.
    """
    keep_tags = {'FW', 'JJ', 'JJS', 'JJR', 'NN', 'NNS', 'NNP', 'NNPS', 'SYM'}

    tokens = nltk.word_tokenize(question)
    tagged = nltk.pos_tag(tokens)
    print(tagged)

    kept_words = []
    for word, tag in tagged:
        if tag in keep_tags:
            kept_words.append(word)

    keywords = ' '.join(kept_words)
    print(keywords)
    return keywords

In [None]:
def retrieveDocs(keywords):
    """Fetch and clean the Wikipedia articles that best match a keyword query.

    For each title returned by ``wikipedia.search``:
      * the title is resolved to a page id with ``wptools``;
      * the article is fetched by that id with ``wikipedia.page(pageid=...)``;
      * its text is trimmed with ``cleanDoc``.

    Args:
        keywords: Search query string (typically the output of ``getKeywords``).

    Returns:
        List of cleaned article texts, in search-result order. Titles that
        cannot be resolved or fetched are skipped, and a notice is printed.
    """
    wiki_search = wikipedia.search(keywords)
    print("Wiki search results:")
    print(wiki_search)
    print("\n" + '=' * 20 + "\n")

    documents = []
    documentTitles = []
    for result in wiki_search:
        title = str(result)
        try:
            # Resolution and fetch are both inside the try: previously a
            # failing get_parse() (unknown/ambiguous title) aborted the whole
            # query because it sat outside the exception handler.
            page = wptools.page(title)
            page.get_parse()
            content = wikipedia.page(pageid=page.data['pageid'])
        except Exception as exc:
            # Best-effort per title, but unlike a bare `except:` this does not
            # swallow KeyboardInterrupt/SystemExit, and it reports the skip.
            print(f"Skipping {title!r}: {type(exc).__name__}: {exc}")
            continue
        documentTitles.append(title)
        documents.append(cleanDoc(content.content))

    print("Entries considered:")
    print(documentTitles)
    print("\n" + '=' * 20 + "\n")
    return documents

In [None]:
def cleanDoc(content, headings=None):
    """Truncate a Wikipedia article at its first trailing reference-style section.

    Everything from the earliest occurrence of any of the given headings
    onward is dropped. By default these are sections such as "See also",
    "References", "Notes" and "External links".

    Args:
        content: Plain-text article body, with headings written like ``== Title ==``.
        headings: Optional iterable of literal heading strings to cut at.
            Defaults to the built-in list below. Matched as plain substrings,
            not as regular expressions.

    Returns:
        ``content`` up to (not including) the first matching heading, or the
        whole of ``content`` if none of the headings is present.
    """
    if headings is None:
        headings = (
            '== Further reading ==', '== Further references ==',
            '=== Citations ===', '== References ==', '== Footnotes ==',
            '=== Notes ===', '== Notes ==', '=== Sources ===', '== Sources ==',
            # NOTE(review): no closing '==' here — possibly intentional so that
            # spacing variants still match; confirm.
            '== External links', '== See also ==',
        )
    # Literal substring search. The previous version joined the headings into an
    # unescaped regex alternation, which misfires on headings that contain
    # regex metacharacters such as '(' or '+'.
    positions = [pos for pos in (content.find(h) for h in headings) if pos != -1]
    return content[:min(positions, default=len(content))]

In [None]:
def splitDocs(question, documents):
    """Flatten documents into a single passage list headed by the question.

    Each document is split on newlines. Blank lines and section headings
    (lines starting with '=') are discarded.

    Args:
        question: The user question, placed at index 0 of the result.
        documents: Iterable of document texts.

    Returns:
        ``[question, passage_1, passage_2, ...]`` across all documents, in order.
    """
    def is_body_line(line):
        return bool(line) and not line.startswith('=')

    return [question] + [
        line
        for doc in documents
        for line in doc.split('\n')
        if is_body_line(line)
    ]

In [None]:
# def retrievePassages(passages): 
#     tfidf = TfidfVectorizer().fit_transform(passages) 
#     cosSims = linear_kernel(tfidf[0:1], tfidf).flatten()
#     # print(cosSims) 
#     passageInds = cosSims.argsort()[:-12:-1]
#     print("Indices of most relevant passages: ")
#     print(passageInds[1:]) 
#     print("\n" + '='*(20) + "\n")
#     return passageInds 

# def printRelevantPassages(passages, passageInds): 
#     print("Most relevant passages: ")
#     for i in range(1, len(passageInds)): 
#         print(passages[passageInds[i]]) 
#     print("\n" + '='*(20) + "\n") 

# def getAnswers(passages, passageInds): 
#     possibleAnswers = []
#     for i in range(1, len(passageInds)): 
#         possibleAnswers.append(qa_final(question = passages[0], context = passages[passageInds[i]])) 
#     # print(possibleAnswers) 
#     possibleAnswers = sorted(possibleAnswers, key = lambda i: i['score']) 
#     return possibleAnswers 

In [None]:
# def retrievePassages(question, documents): 
#     passages = {}
#     for i in documents: 
#         curr_passages = [p for p in i.split('\n') if p and not p.startswith('=')] 
#         curr_passages.insert(0, question) 
#         tfidf = TfidfVectorizer().fit_transform(curr_passages) 
#         cosSims = linear_kernel(tfidf[0:1], tfidf).flatten()
#         print(cosSims) 
#         passageInds = cosSims.argsort()[:-12:-1] 
#         print(cosSims) 
#         print(passageInds)
#         for i in range(1, len(passageInds)): 
#             passages[curr_passages[passageInds[i]]] = cosSims[i]
#         # passages.append(curr_passages[passageInds[-1]]) 
#     return passages 

def retrievePassages(question, documents, per_doc=9, top_k=10):
    """Select the passages most similar to the question across all documents.

    For each document:
      * its non-blank, non-heading lines (lines not starting with '=') are
        de-duplicated;
      * a TF-IDF vectorizer is fitted on the question plus those lines;
      * each line is scored against the question with ``linear_kernel``
        (cosine similarity under TfidfVectorizer's default L2 normalization).

    The ``per_doc`` best lines of every document are pooled, and the ``top_k``
    best of the pool are returned. The intermediate scores are printed.

    Because each document gets its own TF-IDF fit, scores from different
    documents are not on a shared IDF scale.
    NOTE(review): consider fitting one vectorizer over all passages if
    cross-document ranking looks off.

    Args:
        question: The user question.
        documents: Cleaned article texts (e.g. from ``retrieveDocs``).
        per_doc: Maximum passages kept from each document (default 9, as before).
        top_k: Maximum passages returned overall (default 10, as before).

    Returns:
        Up to ``top_k`` passage strings, highest score first.
    """
    scored = {}
    for doc in documents:
        # dict.fromkeys drops duplicate lines while keeping first-seen order.
        doc_passages = list(dict.fromkeys(
            p for p in doc.split('\n') if p and not p.startswith('=')
        ))
        if not doc_passages:
            continue
        try:
            tfidf = TfidfVectorizer().fit_transform([question] + doc_passages)
        except ValueError as exc:
            # e.g. empty vocabulary (no usable tokens) — skip this document only.
            print(f"Skipping document: {exc}")
            continue
        # Row 0 is the question. Compare it only against the passage rows, so
        # the question is excluded by position rather than by assuming it
        # ranks first after sorting (ties could displace it).
        sims = linear_kernel(tfidf[0:1], tfidf[1:]).flatten()
        ranked = sorted(zip(doc_passages, sims), key=lambda ps: ps[1], reverse=True)
        for passage, score in ranked[:per_doc]:
            scored[passage] = score

    print(scored)
    ranked_all = sorted(scored.items(), key=lambda ps: ps[1], reverse=True)
    print([score for _, score in ranked_all])
    passages = [passage for passage, _ in ranked_all[:top_k]]
    print(passages)
    return passages

In [None]:
def printRelevantPassages(passages):
    """Print each retrieved passage on its own line under a header, then a separator line."""
    header = "Most relevant passages: "
    separator = "\n" + "=" * 20 + "\n"
    print(header)
    for passage in passages:
        print(passage)
    print(separator)

In [None]:
def getAnswers(question, passages):
    """Run the global QA pipeline ``qa_final`` on the question against each passage.

    Args:
        question: The user question.
        passages: Context strings to extract answers from.

    Returns:
        One pipeline result per passage (dicts with at least 'answer' and
        'score'), sorted by ascending 'score' — the best answer is last.
    """
    candidates = [qa_final(question=question, context=passage) for passage in passages]
    candidates.sort(key=lambda candidate: candidate['score'])
    return candidates

In [None]:
def printAllAnswers(possibleAnswers):
    """Print candidate answers from most to least confident, numbered from 1.

    ``possibleAnswers`` is expected in ascending score order (as returned by
    ``getAnswers``), so it is walked in reverse. Each line has the form
    ``<rank>.<answer>:<score>``.
    """
    print("Possible answers sorted by confidence rating: ")
    for rank, candidate in enumerate(reversed(possibleAnswers), start=1):
        print(str(rank) + '.' + candidate['answer'] + ':' + str(candidate['score']))
    print("\n" + "=" * 20 + "\n")

## System

In [None]:

#####
# Type either keywords only or the entire question itself.
# Does not yet work for yes/no questions such as "Is Australia a continent?".
# The model used to derive answers from context will be changed to improve accuracy; NLG for the answer will also be tried out.
# Example questions that work: "What is the capital of Assam?", "Who is the Greek goddess of wisdom?", "Where is Addis Ababa?"
# Example questions that don't work: "Who played Harley Quinn in the Suicide Squad?", "What is a binary search tree?", "Who is the CEO of Apple?"
#####

question = input("Enter question: ").strip()
if not question:
    # An empty question would yield empty keywords, which are then passed to
    # wikipedia.search(''); fail here with a clear message instead.
    raise ValueError("Please enter a non-empty question.")

Enter question: What is a binary search tree


In [None]:
# keywords = getKeywords(question) 
# documents = retrieveDocs(keywords) 
# passages = splitDocs(question, documents) 
# passageInds = retrievePassages(passages) 
# printRelevantPassages(passages, passageInds) 
# possibleAnswers = getAnswers(passages, passageInds) 
# printAllAnswers(possibleAnswers) 

In [None]:
# End-to-end run: question -> search keywords -> Wikipedia articles ->
# most similar passages -> extractive QA answers.
keywords = getKeywords(question)                     # POS-filtered search query
documents = retrieveDocs(keywords)                   # cleaned article texts
passages = retrievePassages(question, documents)     # top passages by TF-IDF similarity
printRelevantPassages(passages)
possibleAnswers = getAnswers(question, passages)     # ascending by score; best is last
printAllAnswers(possibleAnswers)

[('What', 'WP'), ('is', 'VBZ'), ('a', 'DT'), ('binary', 'JJ'), ('search', 'NN'), ('tree', 'NN')]
binary search tree
Wiki search results:
['Binary search tree', 'Self-balancing binary search tree', 'Binary tree', 'Binary search algorithm', 'Optimal binary search tree', 'Tree traversal', 'Search tree', 'Splay tree', 'Red–black tree', 'Treap']




en.wikipedia.org (parse) Binary search tree
Binary search tree (en) data
{
  infobox: <dict(4)> name, type, invented_by, invented_year
  iwlinks: <list(1)> https://commons.wikimedia.org/wiki/Category:B...
  pageid: 4320
  parsetree: <str(54977)> <root><template><title>short description...
  requests: <list(1)> parse
  title: Binary search tree
  wikibase: Q623818
  wikidata_url: https://www.wikidata.org/wiki/Q623818
  wikitext: <str(42014)> {{short description|Data structure in tre...
}
en.wikipedia.org (parse) Self-balancing binary search tree
Self-balancing binary search tree (en) data
{
  pageid: 378310
  parsetree: <str(9144)> <root><template><title>short description<...
  requests: <list(1)> parse
  title: Self-balancing binary search tree
  wikibase: Q245955
  wikidata_url: https://www.wikidata.org/wiki/Q245955
  wikitext: <str(7947)> {{short description|Any node-based binary ...
}
en.wikipedia.org (parse) Binary tree
Binary tree (en) data
{
  iwlinks: <list(1)> https://commons.w

Entries considered:
['Binary search tree', 'Self-balancing binary search tree', 'Binary tree', 'Binary search algorithm', 'Optimal binary search tree', 'Tree traversal', 'Search tree', 'Splay tree', 'Red–black tree', 'Treap']


{'Here is how a typical binary search tree insertion might be performed in a binary tree in C++:': 0.31650395376481383, 'If a search tree is not intended to be modified, and it is known exactly how often each item will be accessed, it is possible to construct an optimal binary search tree, which is a search tree where the average cost of looking up an item (the expected search cost) is minimized.': 0.29772886934250853, 'In computer science, a binary search tree (BST), also called an ordered or sorted binary tree, is a rooted binary tree data structure whose internal nodes each store a key greater than all the keys in the node’s left subtree and less than those in its right subtree. A binary tree is a type of data structure for storing data such as numbers in an 

In [None]:
# Show the question next to the single highest-scoring answer.
print(question)
# getAnswers() sorts ascending by score, so the best answer is the last item.
# Guard the empty case (no articles or passages found) instead of raising IndexError.
if possibleAnswers:
    print(possibleAnswers[-1]['answer'])
else:
    print("No answer found.")

What is a binary search tree
A splay tree
