In [None]:
%pip install sec-api
%pip install -U transformers==3.0.0
%python -m nltk.downloader punkt
%git clone https://github.com/patil-suraj/question_generation.git
%cd question_generation

In [2]:
import os
import re
from getpass import getpass

import spacy
from nltk.tokenize import sent_tokenize
from sec_api import ExtractorApi

from pipelines import pipeline  # local module from the cloned question_generation repo

In [3]:
def get_section(filing_url, section, extractor=None):
    """Fetch one item (section) of an SEC filing as text.

    Parameters
    ----------
    filing_url : str
        URL of the filing document. Anything after the first ``.htm`` /
        ``.html`` suffix (e.g. an ``#anchor`` or query string) is dropped
        before the request. URLs without that suffix are passed unchanged.
    section : str
        Item identifier passed to the extractor, e.g. ``'1A'`` or ``'7'``.
    extractor : optional
        Object exposing ``get_section(url, section, return_type)``. Defaults
        to the module-level ``extractorApi`` client.

    Returns
    -------
    Whatever ``extractor.get_section(url, section, "text")`` returns.
    """
    client = extractorApi if extractor is None else extractor
    # Cut after a real ".htm"/".html" suffix rather than the first bare "htm"
    # substring, so ".html" URLs stay intact and a missing suffix doesn't raise.
    match = re.search(r"\.html?", filing_url)
    clean_url = filing_url[:match.end()] if match else filing_url
    return client.get_section(clean_url, section, "text")

In [4]:
def preprocess_text(text):
    """Normalise raw filing text for matching.

    Lower-cases the text, replaces every newline with a space, and then
    removes a single space in front of each ``%`` (``"5 %"`` -> ``"5%"``).
    """
    lowered = text.lower()
    flattened = lowered.replace('\n', ' ')
    return flattened.replace(' %', '%')

In [5]:
def extract_metric_vals(text, val_type="PERCENT", NER=None):
    """Extract candidate metric values of one type from ``text``.

    Parameters
    ----------
    text : str
        Text to scan.
    val_type : str
        ``"PERCENT"``, ``"NUMBER"``, ``"RATIO"`` or ``"MONEY"``. Any other
        value yields an empty list.
    NER : callable, optional
        spaCy-style pipeline, required only for ``"MONEY"``. It is called as
        ``NER(text)``, and the result's ``.ents`` are filtered on
        ``.label_ == 'MONEY'``.

    Returns
    -------
    list of str
        Matched substrings in order of appearance.

    Raises
    ------
    ValueError
        If ``val_type`` is ``"MONEY"`` and no ``NER`` pipeline is given.
    """
    if val_type == "PERCENT":
        # A number with an optional "%" that ends a token. Trailing sentence
        # punctuation is tolerated so "was 0.7%." and "(0.7%)" still match.
        return re.findall(r'(\d+(?:\.\d+)?%?)(?=[.,;:!?)\]]*(?:\s|$))', text)
    if val_type == "NUMBER":
        # The optional leading sign is matched but not captured, so the
        # returned numbers are unsigned.
        return re.findall(r"[+-]?([0-9]+\.?[0-9]*|\.[0-9]+)", text)
    if val_type == "RATIO":
        return re.findall(r"([0-9]+:[0-9]+)", text)
    if val_type == "MONEY":
        if NER is None:
            raise ValueError("val_type='MONEY' requires an NER pipeline")
        return [ent.text for ent in NER(text).ents if ent.label_ == 'MONEY']
    return []

In [6]:
# 1. search the metric name in the document
def search_metric(text_list, metric_list):
    """Return every index of ``text_list`` at which the token sequence
    ``metric_list`` starts (exact element-wise comparison)."""
    width = len(metric_list)
    return [
        start
        for start in range(len(text_list))
        if text_list[start:start + width] == metric_list
    ]

# 2. get k words before and after the searched metric
def extract_phrases(text_list, matched_indices, k):
    """For each index, gather the tokens at most ``k`` positions away on
    either side (clipped to the list bounds). Each token is followed by a
    single space."""
    phrases = []
    for centre in matched_indices:
        lo = max(0, centre - k)
        hi = max(0, centre + k + 1)
        window = text_list[lo:hi]
        phrases.append("".join(word + " " for word in window))
    return phrases

# 3. apply NER and check for corresponding entity
def find_possible_values(text, metric, NER, k, val_type='PERCENT'):
    """Collect values of type ``val_type`` found within ``k`` words of each
    occurrence of ``metric`` in ``text``.

    Parameters
    ----------
    text : str
        Document text to search.
    metric : str
        Metric name, e.g. ``"churn rate"``; matched word by word.
    NER : callable or None
        Passed through to ``extract_metric_vals`` (needed for ``"MONEY"``).
    k : int
        Number of words kept on each side of a metric occurrence.
    val_type : str
        Value type understood by ``extract_metric_vals``.

    Returns
    -------
    list of str
        Extracted values for all occurrences, in document order. Empty when
        ``metric`` contains no words.
    """
    # Commas and hyphens are treated as word breaks in both the text and the
    # metric, so "churn-rate" matches "churn rate".
    text = text.replace(',', ' ').replace('-', ' ')
    metric_list = metric.replace(',', ' ').replace('-', ' ').split()
    if not metric_list:
        return []
    # split() (not split(' ')) collapses runs of whitespace. Empty tokens would
    # otherwise break multi-word matches and consume the k-word window.
    text_list = text.split()
    matched_indices = search_metric(text_list, metric_list)
    phrases_extracted = extract_phrases(text_list, matched_indices, k)
    possible_values = []
    for phrase in phrases_extracted:
        possible_values += extract_metric_vals(phrase, val_type, NER)
    return possible_values

In [7]:
def filter_passage(doc, metric, tokenize=None):
    """Keep only the sentences of ``doc`` that contain ``metric``.

    Parameters
    ----------
    doc : str
        Text to filter.
    metric : str
        Substring that a sentence must contain (case-sensitive).
    tokenize : callable, optional
        Function splitting text into a list of sentences. Defaults to nltk's
        ``sent_tokenize``.

    Returns
    -------
    str
        Matching sentences joined by single spaces, or ``""`` if none match.
    """
    split_sentences = sent_tokenize if tokenize is None else tokenize
    # The sentences already carry their own terminal punctuation, so join them
    # with a space. Joining with "." produced ".." and no whitespace between them.
    return " ".join(s for s in split_sentences(doc) if metric in s)

def get_output(passage, question, metric, NER, val_type='PERCENT'):
    """Ask the QA model ``question`` over the sentences of ``passage`` that
    mention ``metric``, and extract values of type ``val_type`` from the answer.

    Parameters
    ----------
    passage : str
        Raw or preprocessed section text. It is run through ``preprocess_text``.
    question : str
        Question posed to the global ``nlp`` pipeline.
    metric : str
        Metric name used to select context sentences.
    NER : callable or None
        Passed through to ``extract_metric_vals``.
    val_type : str
        Value type understood by ``extract_metric_vals``.

    Returns
    -------
    list of str
        Values found in the answer. Empty if no sentence mentions ``metric``.
    """
    text = preprocess_text(passage)
    context = filter_passage(text, metric)
    if not context:
        # Nothing to ask about: skip the model call on an empty context.
        return []
    # The pipeline's answer is treated as a string (``.replace`` is called on it).
    answer = nlp({"question": question, "context": context})
    # Same comma handling as find_possible_values, so the extracted values
    # can be compared for equality in get_correct_value.
    answer = answer.replace(',', ' ')
    return extract_metric_vals(answer, val_type, NER)

def get_correct_value(possible_values, output_values):
    """Pick the value that both extraction routes agree on.

    Walks ``output_values`` in order and keeps the first occurrence of each
    value that also appears in ``possible_values``. Returns the last value
    kept, or ``''`` when the two lists share nothing.
    """
    agreed = []
    for candidate in output_values:
        if candidate in agreed:
            continue
        if candidate in possible_values:
            agreed.append(candidate)
    return agreed[-1] if agreed else ''

In [None]:
# spaCy is used here only for entity extraction (the MONEY branch of
# extract_metric_vals), so the listed components are disabled.
_NER_DISABLED_COMPONENTS = ["tok2vec", "tagger", "parser", "attribute_ruler", "lemmatizer"]
NER = spacy.load("en_core_web_sm", disable=_NER_DISABLED_COMPONENTS)
# QA / question-generation pipeline from the cloned question_generation repo.
# This notebook calls it only with {"question", "context"} inputs.
nlp = pipeline("multitask-qa-qg")

In [13]:
# SEC-API key: read it from the SEC_API_KEY environment variable, or prompt for
# it. Never hard-code credentials in a notebook.
# NOTE(review): the key previously committed in this cell is exposed and
# should be revoked/rotated.
api_key = os.environ.get("SEC_API_KEY") or getpass("SEC-API key: ")
extractorApi = ExtractorApi(api_key)

In [23]:
metric = "churn rate"
val_type = "PERCENT"
k = 6  # words of context kept on each side of a metric mention
url = "https://www.sec.gov/Archives/edgar/data/0001710583/000171058319000010/swch12311810-k.htm"
relevant_sections = ['1', '1A', '6', '7']
question = f'What is the value of {metric}?'

# Try the sections in order. Stop at the first one where a value parsed from
# the QA answer also appears next to the metric in the text.
correct_value = ''  # defined even if no section is searched or matched
for sec in relevant_sections:
    text = preprocess_text(get_section(url, sec))
    possible_values = find_possible_values(text, metric, NER, k, val_type)
    output_values = get_output(text, question, metric, NER, val_type)
    correct_value = get_correct_value(possible_values, output_values)
    if correct_value:
        break
print(correct_value)

0.7%
