In [1]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy
import time
import torch
from torch import nn

## Text Summarization

In [2]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import pipeline

# We can use either BART or T5

# for BART...we can use the built-in pipeline, which uses BART.
BART_summarizer = pipeline("summarization")

# ...or we can specify T5 specifically.
T5_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
T5_tokenizer = AutoTokenizer.from_pretrained("t5-base")

No model was supplied, defaulted to sshleifer/distilbart-cnn-12-6 (https://huggingface.co/sshleifer/distilbart-cnn-12-6)


In [3]:
# T5 uses a max_length of 512 so we cut the article to 512 tokens.
def T5_summarize(article_text, max_length=150, min_length=40, print_stats=True):
    inputs = T5_tokenizer("summarize: " + article_text, return_tensors="pt", max_length=512, truncation=True)
    outputs = T5_model.generate(
        inputs["input_ids"], max_length=max_length, min_length=min_length, length_penalty=2.0, num_beams=4, early_stopping=True
    )
    summary_text = T5_tokenizer.decode(outputs[0])
    if print_stats:
        print(f"Input text had {len(article_text)}, now summarized to {len(summary_text)}.")
    return summary_text

In [4]:
def BART_summarize(article_text, max_length=130, min_length=30):
    summary_text = BART_summarizer(article_text,max_length=max_length,min_length=min_length)
    # current form: [{'summary_text': ' Liana Barrientos, 39, is charged with two cou... '}]
    summary_text = summary_text[0]['summary_text']
    return summary_text

In [5]:
def summarize(article_text, use_BART, **kwargs):
    # other arguments are max_length and min_length for article text
    if use_BART:
        return BART_summarize(article_text, **kwargs)
    else:
        return T5_summarize(article_text, **kwargs)

## Getting Wikipedia Intro Text to Summarize

In [6]:
import wikipedia

In [7]:
import flask

In [8]:
from functools import lru_cache

In [9]:
# watch out - wikipedia.summary will raise a DisambiguationError if the page is a disambiguation page, or 
# a PageError if the page doesn’t exist. We cache results since we are just retrieving intro pages and wish
# to reduce Wikipedia lookup time and focus on summarization in timings.

@lru_cache()
def get_wikipedia_content_for_page(page_name,allow_auto_suggest=False):
    '''Returns page content if page can be found and is unique, otherwise a list of
    alternatives to try if disambiguation required. First value is whether page was
    found uniquely, second value either the contents or disambiguation alternatives.
    
    Since auto-suggest can sometimes return an incorrect page title (like auto-correct!),
    we first try the page name as is, without auto_suggest. If that fails then we try
    again with auto_suggest on.
    
    '''
    try:
        contents = wikipedia.summary(page_name,auto_suggest=allow_auto_suggest)
        return True,contents
    except wikipedia.exceptions.DisambiguationError as e:
        if not allow_auto_suggest:
            match,results = get_wikipedia_content_for_page(page_name,allow_auto_suggest=True)
            return match,results
        else:
            return f'Disambiguation Error for name {page_name}',e.options
    except wikipedia.exceptions.PageError as e:
        if not allow_auto_suggest:
            match,results = get_wikipedia_content_for_page(page_name,allow_auto_suggest=True)
            return match,results
        return f'No Such Page Error: {page_name}',[]
    except wikipedia.exceptions.HTTPTimeoutError as e:
        return f'Timeout Error: {page_name}',[]
    except wikipedia.exceptions.RedirectError as e:
        return f'Redirect Error: {page_name}',[]

In [10]:
get_wikipedia_content_for_page('Julius Caesar')

(True,
 'Gaius Julius Caesar (Latin: [ˈɡaːiʊs ˈjuːliʊs ˈkae̯sar]; 12 July 100 BC – 15 March 44 BC) was a Roman general and statesman. A member of the First Triumvirate, Caesar led the Roman armies in the Gallic Wars before defeating his political rival Pompey in a civil war, and subsequently became dictator of Rome from 49 BC until his assassination in 44 BC. He played a critical role in the events that led to the demise of the Roman Republic and the rise of the Roman Empire.\nIn 60 BC, Caesar, Crassus and Pompey formed the First Triumvirate, a political alliance that dominated Roman politics for several years. Their attempts to amass power as Populares were opposed by the Optimates within the Roman Senate, among them Cato the Younger with the frequent support of Cicero. Caesar rose to become one of the most powerful politicians in the Roman Republic through a string of military victories in the Gallic Wars, completed by 51 BC, which greatly extended Roman territory. During this time h

In [11]:
from flask import Flask,request,jsonify
import uuid

In [12]:
app = Flask(__name__)

In [13]:
expected_parameters = ['page_title','max_len','min_len','BART_or_T5']

abs_max_max_summary_len = 1024 # truncate max summary length requested to 1024
abs_max_min_summary_len = 512  # truncate max min summary length requested to 512
abs_max_intro_text_len =  4096 # truncate max intro text considered to this size

@app.route("/api",methods=['POST'])
def summarize_article():
    content = request.json
    errors =[]
    
    # check all expected parameters present
    for name in expected_parameters:
        if name not in content:
            errors.append(f"Missing value: {name}")
            
    # get parameters if no errors
    page_title = content['page_title']
    max_len = content['max_len']
    min_len = content['min_len']
    BART_or_T5 = content['BART_or_T5']
    
    if max_len < min_len:
        errors.append(f"Max len, {max_len}, should be greater than min len, {min_len}.")
        
    if BART_or_T5 not in ["BART","T5"]:
        errors.append(f'BART_or_T5 was not "BART" or "T5" but instead "{BART_OR_T5}"')
    
    if not(errors):
        title_found, intro_text = get_wikipedia_content_for_page(page_title)
        if not title_found:
            errors.append(f"Could not find '{page_title}'")
            
        # truncate text to max_len
        if title_found and len(intro_text) > abs_max_intro_text_len:
            intro_text = intro_text[:abs_max_intro_text_len]
            
        if max_len > abs_max_max_summary_len:
            max_len = abs_max_max_summary_len
            
        if min_len > abs_max_min_summary_len:
            min_len = abs_max_min_summary_len
            
        summarized_text = summarize(intro_text,
                                    use_BART=(BART_or_T5=="BART"), 
                                    max_length=max_len,
                                    min_length=min_len)
        
        response = {"id":str(uuid.uuid4()), "text":summarized_text, "errors":errors}
        
    else:
        
        response = {"id":str(uuid.uuid4()), "errors":errors}
        
    return jsonify(response)

In [None]:
if __name__=="__main__":
    app.run(debug=True, use_reloader=False)

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: on


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [13/Jan/2022 15:12:32] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [13/Jan/2022 15:12:44] "POST /api HTTP/1.1" 200 -


Input text had 4096, now summarized to 269.


127.0.0.1 - - [13/Jan/2022 15:13:10] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [13/Jan/2022 15:13:45] "POST /api HTTP/1.1" 200 -


Input text had 4096, now summarized to 900.


127.0.0.1 - - [13/Jan/2022 15:14:10] "POST /api HTTP/1.1" 200 -
Your max_length is set to 512, but you input_length is only 487. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)


Input text had 2374, now summarized to 963.


127.0.0.1 - - [13/Jan/2022 15:14:25] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [13/Jan/2022 15:14:48] "POST /api HTTP/1.1" 200 -
Your max_length is set to 1024, but you input_length is only 836. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)
127.0.0.1 - - [13/Jan/2022 15:15:43] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [13/Jan/2022 15:16:06] "POST /api HTTP/1.1" 200 -
Your max_length is set to 1024, but you input_length is only 629. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)
127.0.0.1 - - [13/Jan/2022 15:16:50] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [13/Jan/2022 15:17:14] "POST /api HTTP/1.1" 200 -
Your max_length is set to 1024, but you input_length is only 719. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)
127.0.0.1 - - [13/Jan/2022 15:17:57] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [13/Jan/2022 16:40:40] "GET / HTTP/1.1" 404 -
127.0.0.1 - - [13/Jan/2022 