In [130]:
import spacy
import json
import re 

nlp = spacy.load("en_core_web_lg")

In [213]:
def check_pos (token):
    tag_list = [ 'PRP', 'PRP$', 'WP', 'WP$', 'NN', 'NNP', 'NNPS', 'NNS',
                    'JJ', 'JJR', 'JJS', 'MD', 'VB', 'VBD', 'VBG',
                    'VBN', 'VBP', 'VBZ', 'RB', 'RBR', 'RBS', 'WRB']
    contraction = ['\'m', '\'ll', '\'s', 'n\'t', '\'ve']
    text, pos, tag, dep = token.text, token.pos_, token.tag_, token.dep_
    
    if ((pos == 'SCONJ' or tag in tag_list) and text not in contraction \
        and not (text == 'as' and dep == 'prep')) :
        return True

In [214]:
def is_hyphenated (headline, index):
    headline_len = len(headline)
    if (index != headline_len - 1 and headline[index + 1].text == '-') or headline[index - 1].text == '-':
        return True 

In [215]:
def is_camel_case (word):
    if not (len(word) == 2 and word[0].isupper()):
        return word != word.lower() and word != word.upper() and word != word.istitle()

In [216]:
def letter_and_digit (token):
    if any(map(str.isdigit,token)) and any(map(str.isalpha ,token)):
        return True

In [217]:
def handle_leading_trailing(arr, reverse=False):
    arr = arr[:]
    rng = range(len(arr))
    
    if reverse:
        rng = reversed(rng)
    
    for i in rng:
        item = arr[i]
        token, _, is_entity = item
        if is_entity:
            break
        if token.strip().isalpha():
            if not token.isupper() or is_camel_case(token):
                arr[i][0] = token.title()
            break    
    return arr

In [219]:
def format_headline (headline): 
    formatted_headline = []
    parsed_headline = nlp(headline)
    
    for i, token in enumerate(parsed_headline):
        text = token.text
         
        if text.isupper() or is_camel_case(text) or letter_and_digit(text):
            formatted_text = text
        elif len(token) > 3 or check_pos(token) or is_hyphenated(parsed_headline, i):
            formatted_text = text.title()
        else:
            formatted_text = text.lower()
        formatted_headline.append([formatted_text, token.whitespace_, token.ent_type != 0])
    
    formatted_headline = handle_leading_trailing(formatted_headline)
    formatted_headline = handle_leading_trailing(formatted_headline, True)
    formatted_headline = [text + whitespace for text, whitespace, _ in formatted_headline]
    return ''.join(formatted_headline)

In [225]:
def test_headlines (f):
    tp = 0
    with open(f) as json_file:
        data = json.load(json_file)
        for entry in data:
            formatted_headline = format_headline(entry[0])
            if formatted_headline == entry[1]:
                tp += 1
        accuracy = tp/len(data)
        return accuracy

In [226]:
test_headlines('headlines-test-set.json')

0.93

In [228]:
with open('examiner-headlines.txt') as f:
    counter = 0
    for line in f:
        formatted_headline = format_headline(line)
        if formatted_headline == line:
            counter += 1
    print(counter)

697
