In [1]:
import os, pprint, re
from datetime import datetime

# Work from the directory above the notebook. `workbookDir` is only captured on
# the first execution, so re-running this cell does not keep climbing the tree.
if 'workbookDir' not in globals():
  workbookDir = os.getcwd()
os.chdir(os.path.join(workbookDir, '..'))

# Optional dark plotting theme. `__set_jtplot__` is presumably only defined in
# some environments (NameError otherwise) -- styling is best-effort, but we no
# longer swallow KeyboardInterrupt/SystemExit the way a bare `except:` did.
try:
  __set_jtplot__('dark')
except Exception:
  pass

pp = pprint.PrettyPrinter(indent=2)

In [2]:
%run datasets.py
ds = CNNDailyMail(split='train')  # train, val, or test
print(f'{len(ds)} documents.')

Using existing data.
287227 documents.


In [3]:
from lexrank import STOPWORDS, LexRank

# Training time (Yenson's laptop):
#  - 1000 segments: ~30s
#  - 2000 segments: ~60s
#  - 4000 segments: ~130s
# Looks roughly linear (maybe nlog(n)?) -- won't know for sure at limited scale.
trainSegs = min(4000, len(ds))

# LexRank is fitted on source documents only, so make indexing return `out.src`.
# NOTE: this filter stays installed on `ds` after the cell finishes.
ds.outFilt = lambda out: out.src  # use sources only
startTime = datetime.now()
lxr = LexRank(ds[:trainSegs], stopwords=STOPWORDS['en'])
# `timedelta.seconds` is only the seconds *component* (it drops days and
# fractions); total_seconds() is the full elapsed duration.
elapsed = (datetime.now() - startTime).total_seconds()
print(f'Creating model took {elapsed:.0f}s.')

Creating model took 131s.


In [4]:
# Drop the source-only output filter so dataset items are returned with both
# source and target from here on.
ds.outFilt = None

def to_sentences(seg):
  """Split a tokenized text into capitalized, period-terminated sentences.

  A '.' ends a sentence only when it is followed by whitespace or by the end
  of the text. Decimals such as '3.5' and tokens such as 'u.s.-based' are
  therefore no longer cut into separate "sentences". Only '.' acts as a
  terminator; '?' and '!' stay inside the surrounding sentence.

  Args:
    seg: the text to split (e.g. a document's `src`, or its joined `tgt`).

  Returns:
    A list of non-empty sentences. In each one, ' , ' is collapsed to ', ',
    the first character is upper-cased and a trailing '.' is appended.
  """
  sentences = []
  for piece in re.split(r'\.(?=\s|$)', seg):
    sent = re.sub(r' , ', ', ', piece.strip())
    if sent:  # skip empty pieces, e.g. after a trailing '.'
      sentences.append(sent[0].upper() + sent[1:] + '.')
  return sentences


# Qualitative check on the last document of the split. When len(ds) exceeds
# trainSegs, this document was not among those used to build `lxr`.
testSeg = ds[-1]
testSents = to_sentences(testSeg.src)
print(f'Test segment has {len(testSents)} sentences.')

# Request a 2-sentence summary; `threshold` is presumably the similarity
# cutoff of LexRank's sentence graph -- see the lexrank package docs.
summary = lxr.get_summary(
  testSents,
  summary_size=2,
  threshold=0.1,
)

Test segment has 41 sentences.


In [5]:
# Pretty-print like the neighboring cells so long sentences wrap instead of
# appearing as one very long line.
pp.pprint(summary)

['Controversial : many claim that the tradition is offensive towards black people.', 'Opponents say pete is an offensive caricature of black people.']


In [6]:
# Reference (target) summary for the same document. split_tags() presumably
# returns the target's tagged pieces as strings. Re-joining them and passing
# the result through to_sentences() formats them like the extracted summary.
tgtJoined = ' '.join(ds.split_tags(testSeg.tgt))
tgt = to_sentences(tgtJoined)
pp.pprint(tgt)

[ "Facebook page supporting tradition gains one million ` likes ' in a day.",
  "` do n't let the netherlands ' most beautiful tradition disappear, ' it "
  'says.',
  'Un has condemned the tradition claiming it reflects racial prejudice.']


In [7]:
from rouge import Rouge 

rouge = Rouge()
# Compare the extracted summary (hypothesis) against the reference summary,
# each flattened into a single space-joined string.
hypothesis = ' '.join(summary)
reference = ' '.join(tgt)
scores = rouge.get_scores(hypothesis, reference)
pp.pprint(scores)  # choose one of the F1 (f) scores

[ { 'rouge-1': { 'f': 0.06779660549267486,
                 'p': 0.09090909090909091,
                 'r': 0.05405405405405406},
    'rouge-2': { 'f': 0.03508771464450662,
                 'p': 0.047619047619047616,
                 'r': 0.027777777777777776},
    'rouge-l': { 'f': 0.08333332864583361,
                 'p': 0.1111111111111111,
                 'r': 0.06666666666666667}}]
