In [None]:
from google.colab import drive

# Mount Google Drive (remounting if already mounted) so later cells can read and
# write files under drive/MyDrive/.
drive.mount(mountpoint='/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Select the type of Rouge Score
TYPES = ['RL', 'R1', 'R2']

chosen_rouge = 'R1' # Edit this line

if chosen_rouge not in TYPES:
  raise ValueError("Not valid Rouge Score")

# Drive sub-folder that holds the labels for the chosen variant
# ('RL' is stored under 'R-L'; 'R1' is the fallback).
folder_name = {'RL': 'R-L', 'R2': 'R2'}.get(chosen_rouge, 'R1')


In [None]:
# Truncate the train label file of the *selected* ROUGE variant before it is
# regenerated (this used to hard-code R1, wiping the R1 labels whenever another
# variant was being produced).
with open(f'drive/MyDrive/Text Summarization (Dataset and labels)/{folder_name}/train.jsonl', 'w') as f:
  pass

In [None]:
# Install depencencies
!pip install datasets
!pip install nltk
!pip install rouge-score
!pip install evaluate
!pip install bert_score
!pip install python-crfsuite
!pip install Flask
!pip install chardet

In [None]:
# Standard library
import json
import os
import random
import re

# Third-party
import bert_score
import nltk
import pycrfsuite
import scipy.stats as ss
from datasets import load_dataset
from nltk import tokenize
from rouge_score import rouge_scorer

nltk.download('punkt')

# Download the Multi-LexSum dataset locally (ETA is circa 2 mins) and expose
# its three splits as Dataset objects.
multi_lexsum = load_dataset("allenai/multi_lexsum", name="v20220616")
train_dataset, val_dataset, test = (
    multi_lexsum[split] for split in ('train', 'validation', 'test')
)

In [None]:
!pip install transformers

from transformers import LEDTokenizer, LEDForConditionalGeneration
from transformers import AutoTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")

In [None]:
# BERTScorer shared by every bert_scorer call; built on first use because
# constructing it loads the 'amazon/bort' model.
_BERT_SCORER = None

def bert_scorer(sentence, target):
  """Score ``sentence`` (candidate) against ``target`` (reference) with BERTScore.

  The scorer is created once and reused instead of reloading the model on
  every call.

  Returns:
    Element 0 of the first tensor returned by ``BERTScorer.score`` (the
    precision tensor in bert_score's (P, R, F) return convention).
  """
  global _BERT_SCORER
  if _BERT_SCORER is None:
    _BERT_SCORER = bert_score.BERTScorer(lang="en", model_type='amazon/bort')
  return _BERT_SCORER.score([sentence], [target])[0][0]

# RougeScorer instances keyed by rouge type; built on first use and reused,
# since these scorers are called many times inside the greedy search.
_ROUGE_SCORERS = {}

def _rouge_recall(rouge_type, reference, candidate):
  """Return the ``rouge_type`` recall of ``reference`` against ``candidate``.

  ``RougeScorer.score(target, prediction)`` takes the reference first, so the
  recall is measured relative to ``reference`` (the first argument).
  """
  scorer = _ROUGE_SCORERS.get(rouge_type)
  if scorer is None:
    scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True)
    _ROUGE_SCORERS[rouge_type] = scorer
  return scorer.score(reference, candidate)[rouge_type].recall

def rouge1_scorer(sentence, target):
  """ROUGE-1 recall with ``sentence`` as the reference and ``target`` as the candidate."""
  return _rouge_recall('rouge1', sentence, target)

def rouge2_scorer(sentence, target):
  """ROUGE-2 recall with ``sentence`` as the reference and ``target`` as the candidate."""
  return _rouge_recall('rouge2', sentence, target)

def rougeL_scorer(sentence, target):
  """ROUGE-L recall with ``sentence`` as the reference and ``target`` as the candidate."""
  return _rouge_recall('rougeL', sentence, target)

# Lookup from metric name to its scoring function.
scorer_dict = dict(
    bert=bert_scorer,
    rouge1=rouge1_scorer,
    rouge2=rouge2_scorer,
    rougeL=rougeL_scorer,
)
              
# Sentence tokenizer code starts here
# Regexes used by create_token_sig to build token "shape" features for the CRF
# sentence-boundary model (20180904.crfsuite, loaded via init_crf_model).
LOWER_CHAR = re.compile('[a-z]')  # lowercase ASCII letter -> 'c'
UPPER_CHAR = re.compile('[A-Z]')  # uppercase ASCII letter -> 'C'
SINGLE_DIGIT = re.compile('\\d')  # digit -> 'D'
# NOTE(review): this pattern has a stray ']' -- it means "one space/tab followed
# by one or more ']'", so a token made only of spaces/tabs never matches and the
# '*hws' signatures are never produced. Do not correct it blindly: if the CRF
# model was trained with this same pattern, fixing it would change the features
# the model sees. TODO confirm against the luima_sbd training code.
WS = re.compile('^[ \\t]]+$')
# Whole-token run of vertical whitespace -> '*vws' signatures.
LB_WS = re.compile('^[\\n\\r\\v\\f]+$')


def text2sentences(text, offsets=False):
    """Split ``text`` into sentences with the CRF sentence-boundary tagger.

    Args:
        text: input string.
        offsets: if True, return ``(start, end)`` character spans instead of strings.

    Returns:
        A list of sentence strings, or of ``(start, end)`` tuples when ``offsets``.

    NOTE(review): ``[A-z]`` also matches the ASCII characters ``[\\]^_``` between
    'Z' and 'a'; kept as-is to match the tokenization the model expects.
    """
    token_pattern = re.compile(r'[A-z]+|\d+|[ \t\f]+|[\n\r\v]+|[^A-z\d\s]')
    token_matches = list(token_pattern.finditer(text))
    labels = tokens2preds([m.group() for m in token_matches])
    spans = preds2sentences(token_matches, labels)
    if not offsets:
        return [text[begin:end] for begin, end in spans]
    return spans


def preds2sentences(matches, preds):
    """Merge runs of consecutive non-'O' token labels into character spans.

    Args:
        matches: regex match objects, one per token (``.start()``/``.end()`` used).
        preds: one label per token; 'O' means "outside a sentence".
            Pairs are consumed with ``zip``, so the shorter input wins.

    Returns:
        List of ``(start, end)`` tuples, one per maximal run of non-'O' labels.
    """
    spans = []
    open_span = None  # (start, end) of the run currently being extended
    for tag, m in zip(preds, matches):
        if tag == 'O':
            if open_span is not None:
                spans.append(open_span)
                open_span = None
            continue
        if open_span is None:
            open_span = (m.start(), m.end())
        else:
            open_span = (open_span[0], m.end())
    if open_span is not None:
        spans.append(open_span)
    return spans


def tokens2preds(tokens):
    """Label each token with the CRF sentence-boundary model.

    Features use a window of 3 tokens on each side (see ``word2features``).
    Returns the tagger's label sequence, one label per token.
    """
    window_features = [word2features(tokens, position, 3) for position in range(len(tokens))]
    return init_crf_model('20180904.crfsuite').tag(window_features)


def word2features(doc, i, n, extras=None):
    """Build the CRF feature list for token ``i`` of ``doc``.

    Includes a "bias" feature, features for the token itself and the ``n``
    tokens after it (plus an ``'<k>:EOS'`` marker just past the end), then the
    ``n`` tokens before it (plus a ``'-1:BOS'`` marker before the start).
    ``extras`` is accepted but not used.
    """
    feats = ["bias"]
    for offset in range(n + 1):
        pos = i + offset
        if pos < len(doc):
            feats += token2features(token=doc[pos], i=offset)
        elif pos == len(doc):
            feats.append(f"{offset}:EOS")
    for offset in range(-n, 0):
        pos = i + offset
        if pos >= 0:
            feats += token2features(token=doc[pos], i=offset)
        elif pos == -1:
            feats.append(f"{offset}:BOS")
    return feats


def token2features(token, i):
    """Return the CRF feature strings for ``token`` at relative offset ``i``.

    Each feature is prefixed with the offset, e.g. ``'-1:word.lower=the'``.
    """
    prefix = str(i)
    named_values = [
        (":word.lower=", token.lower()),
        (":word.sig=", create_token_sig(token)),
        (":word.length=", get_token_length(token)),
        (":word.islower=", str(token.islower())),
        (":word.isupper=", str(token.isupper())),
        (":word.istitle=", str(token.istitle())),
        (":word.isdigit=", str(token.isdigit())),
        (":word.iswhitespace=", str(token.isspace())),
    ]
    return [prefix + name + value for name, value in named_values]


def create_token_sig(token):
    """Map ``token`` to a coarse shape signature used as a CRF feature.

    ASCII lowercase letters become 'c', uppercase 'C' and digits 'D'. If the
    mapped token matches ``WS`` or ``LB_WS`` it is replaced by a length-bucketed
    whitespace label ('singlehws'/'shorthws'/'hws'/'longhws' or
    'singlevws'/'doublevws'/'triplevws'/'longvws').
    """
    sig = LOWER_CHAR.sub('c', token)
    sig = UPPER_CHAR.sub('C', sig)
    sig = SINGLE_DIGIT.sub('D', sig)
    # Both matches are taken on the mapped token before any relabelling.
    horizontal = WS.match(sig)
    vertical = LB_WS.match(sig)
    if horizontal:
        width = horizontal.end() - horizontal.start()
        if width >= 10:
            sig = 'longhws'
        elif width >= 5:
            sig = 'hws'
        elif width >= 2:
            sig = 'shorthws'
        else:
            sig = 'singlehws'
    if vertical:
        height = vertical.end() - vertical.start()
        if height >= 4:
            sig = 'longvws'
        elif height == 3:
            sig = 'triplevws'
        elif height == 2:
            sig = 'doublevws'
        else:
            sig = 'singlevws'
    return sig


def get_token_length(token):
    """Bucket ``len(token)``: the exact length as a string below 4, 'normal' for 4-6, 'long' from 7."""
    size = len(token)
    if size >= 7:
        return 'long'
    if size >= 4:
        return 'normal'
    return str(size)


# Opened pycrfsuite taggers, keyed by model file name, so the model is
# downloaded and opened once per session instead of on every call.
_CRF_TAGGERS = {}

def init_crf_model(model_type):
    """Return a pycrfsuite Tagger for the luima_sbd model file ``model_type``.

    ``model_type`` is both the local file name and the file name under the
    jsavelka/luima_sbd repository's ``data/`` directory (e.g.
    '20180904.crfsuite'). The file is downloaded only if it does not exist
    locally, and the opened Tagger is cached for subsequent calls.

    Raises:
        requests.HTTPError: if the model download fails.
    """
    tagger = _CRF_TAGGERS.get(model_type)
    if tagger is not None:
        return tagger

    import requests

    model_path = model_type
    if not os.path.exists(model_path):
        url = ("https://github.com/jsavelka/luima_sbd/blob/master/data/"
               + model_type + "?raw=true")
        response = requests.get(url, timeout=60)
        # Fail loudly rather than saving an HTML error page as the model file.
        response.raise_for_status()
        # Write to a temporary name first so an interrupted download never
        # leaves a truncated model file behind.
        tmp_path = model_path + '.part'
        with open(tmp_path, 'wb') as fw:
            fw.write(response.content)
        os.replace(tmp_path, model_path)

    tagger = pycrfsuite.Tagger()
    tagger.open(model_path)
    _CRF_TAGGERS[model_type] = tagger
    return tagger

def preprocess_input(input):
  """Normalise raw source/summary text before sentence splitting.

  Newlines become spaces, every character other than ASCII letters, digits,
  '.', '!', '?' and space is dropped, and "Page N of M" markers are removed.

  Args:
    input: raw text.

  Returns:
    The cleaned string.
  """
  text = re.sub(r'[^a-zA-Z0-9\.\!\? ]', '', input.replace('\n', ' '))
  # Page numbers may have several digits ("Page 12 of 30"); a single-digit
  # pattern misses them or leaves stray digits ("Page 1 of 30" -> "0").
  text = re.sub(r'Page.[0-9]+.of.[0-9]+', '', text)
  # TODO: plug removing duplications here
  return text

def get_sentences(input):
  """Split ``input`` into a list of sentence strings (callers here pass preprocess_input output)."""
  sentences = text2sentences(input, offsets=False)
  return sentences

def add_unimportant(selected_dct):
  """Return a perturbed copy of ``selected_dct`` with some selections swapped out.

  For every document, a random fraction (0-19%) of its selected sentence
  indices is dropped and replaced by the same number of distinct, previously
  unselected indices from ``[0, max(selected)]``. A private ``random.Random(5)``
  makes the result reproducible without reseeding the global ``random`` module.

  Args:
    selected_dct: mapping of document index -> list of selected sentence
      indices. It is not modified.

  Returns:
    A new dict with the same keys; each value is a sorted list of indices
    (empty selections stay empty).
  """
  rng = random.Random(5)
  perturbed = {}
  for key, value in selected_dct.items():
    picked = list(value)  # copy: never shuffle the caller's list in place
    if not picked:
      perturbed[key] = picked
      continue
    picked_set = set(picked)
    unimportant = [idx for idx in range(max(picked) + 1) if idx not in picked_set]
    n_swap = int(rng.choice(range(0, 20)) / 100 * len(picked))
    # Sample without replacement (and no more than available) so exactly as
    # many sentences are added as are removed.
    n_swap = min(n_swap, len(unimportant))
    if n_swap > 0:
      candidates = rng.sample(unimportant, n_swap)
      rng.shuffle(picked)
      picked = picked[:len(picked) - n_swap] + candidates
    perturbed[key] = sorted(set(picked))
  return perturbed


def find_max_sentence(sentences, target, scorer):
  """Return the index of the sentence maximising ``scorer(sentence, target)``.

  Ties keep the earliest index; returns 0 when ``sentences`` is empty or no
  score exceeds 0.
  """
  best_index, best_score = 0, 0
  for index, sentence in enumerate(sentences):
    score = scorer(sentence, target)
    if score > best_score:
      best_index, best_score = index, score
  return best_index

def greedy_search(sentences, paragraph, paragraph_weight, scorer,
                  count_tokens=None, max_tokens=None):
  """Greedily select sentence indices whose concatenation best covers ``paragraph``.

  Starts from the ``find_max_sentence`` seed and scans the remaining sentences
  in order, keeping one whenever appending it (space-separated) raises
  ``scorer(paragraph, selected_text)``. Stops once the selection exceeds
  ``max_tokens * paragraph_weight`` tokens; the sentence that crosses the
  budget is kept.

  Args:
    sentences: candidate sentences of one source document.
    paragraph: summary paragraph to cover.
    paragraph_weight: this paragraph's share of the token budget.
    scorer: ``scorer(reference, candidate) -> number``.
    count_tokens: optional ``text -> int``; defaults to counting tokens with
      the global ``tokenizer``.
    max_tokens: optional token budget; defaults to ``MAX_TOKEN_SIZE``.

  Returns:
    Selected indices in selection order (seed first); ``[]`` if ``sentences``
    is empty.
  """
  if not sentences:
    return []
  if count_tokens is None:
    # len(tokenizer(text)) is the number of BatchEncoding fields
    # (input_ids, attention_mask, ...), not the number of tokens.
    count_tokens = lambda text: len(tokenizer.tokenize(text))
  if max_tokens is None:
    max_tokens = MAX_TOKEN_SIZE
  token_budget = max_tokens * paragraph_weight

  start_index = find_max_sentence(sentences, paragraph, scorer)
  selected = [start_index]
  chosen = {start_index}
  selected_text = sentences[start_index]
  best_score = scorer(paragraph, selected_text)
  total_tokens = count_tokens(selected_text)

  for index, sentence in enumerate(sentences):
    if index in chosen:
      continue  # never add the same sentence twice
    # Keep a separating space so the last selected word and the candidate's
    # first word are not fused into one token.
    candidate_text = selected_text + ' ' + sentence
    score = scorer(paragraph, candidate_text)
    if score > best_score:
      selected.append(index)
      chosen.add(index)
      selected_text = candidate_text
      best_score = score
      total_tokens += count_tokens(sentence)
      if total_tokens > token_budget:
        break
  return selected

def get_order(sources, paragraphs, scorer):
  """Rank source documents for each summary paragraph by ROUGE-L recall.

  Args:
    sources: list of source document texts.
    paragraphs: list of summary paragraphs.
    scorer: accepted for interface compatibility; ordering always uses
      ROUGE-L recall of the paragraph against each source.

  Returns:
    Dict paragraph index -> list of source indices sorted by decreasing score
    (ties keep the original document order). When the paragraph's mean score
    is below 0.5 the list is prefixed with ``None`` so callers skip it at the
    first priority level.
  """
  order_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
  order = {}
  for paragraph_idx, paragraph in enumerate(paragraphs):
    scores = [order_scorer.score(paragraph, source)['rougeL'].recall for source in sources]
    # Document indices by descending score. (Reversing rankdata() output only
    # reverses positions; it does not sort documents by score.)
    ranking = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    # if rouge score is low, defer this paragraph by one priority level
    if sum(scores) < 0.5 * len(scores):
      ranking = [None] + ranking
    order[paragraph_idx] = ranking
  return order

def get_proxy_label(train_idx, scorer, softing=False, unimportant=False, datasplit='train'):
  """Build extractive proxy labels for one Multi-LexSum example.

  Each paragraph of the long summary is matched to source documents in the
  order returned by ``get_order``, and ``greedy_search`` selects sentence
  indices from the matched document. Priority levels are processed in turn
  until the selection exceeds ``MAX_TOKEN_SIZE`` tokens or 15x the summary's
  word count.

  Args:
    train_idx: row index into ``multi_lexsum[datasplit]``.
    scorer: similarity function passed to ``greedy_search``.
    softing: if True, documents with no selected sentence are also searched
      (this also happens when the selection is under 10x the summary's word count).
    unimportant: if True, perturb the result with ``add_unimportant``.
    datasplit: split to read ('train', 'validation' or 'test').

  Returns:
    ``(selected_dct, sources)``: ``selected_dct`` maps each source index to a
    sorted list of selected sentence indices; ``sources`` is the list of raw
    source texts.
  """
  print(train_idx)
  example = multi_lexsum[datasplit][train_idx]
  sources = example['sources']
  target = example['summary/long']
  target_words = len(target.split())
  paragraphs = [paragraph for paragraph in target.split('\n\n') if paragraph != '']
  order = get_order(sources, paragraphs, scorer)

  # Sentence splitting runs the CRF tagger, so split each source at most once.
  sentence_cache = {}

  def doc_sentences(doc_idx):
    # Preprocessed sentences of source ``doc_idx``, memoised for this call.
    if doc_idx not in sentence_cache:
      sentence_cache[doc_idx] = get_sentences(preprocess_input(sources[doc_idx]))
    return sentence_cache[doc_idx]

  def count_tokens(text):
    # len(tokenizer(text)) would count BatchEncoding fields, not tokens.
    return len(tokenizer.tokenize(text))

  selected_dct = {i: [] for i in range(len(sources))}
  token_len = 0
  for priority in range(len(sources) + 1):
    for paragraph_idx, ranking in order.items():
      # Rankings may be shorter than len(sources) + 1, and a None entry means
      # "nothing for this paragraph at this priority".
      if priority >= len(ranking) or ranking[priority] is None:
        continue
      doc_idx = ranking[priority]
      paragraph = paragraphs[paragraph_idx]
      try:
        sentences = doc_sentences(doc_idx)
        if not sentences:
          continue
        candidates = greedy_search(sentences=sentences,
                                   paragraph=paragraph,
                                   paragraph_weight=len(paragraph.split())/target_words,
                                   scorer=scorer)
      except Exception as exc:
        # Best effort: skip this (paragraph, document) pair, but report it.
        print('Skipping paragraph {} / document {}: {!r}'.format(paragraph_idx, doc_idx, exc))
        continue
      selected_dct[doc_idx] = sorted(set(selected_dct[doc_idx] + candidates))

    token_len = sum(count_tokens(doc_sentences(key)[slt])
                    for key, picked in selected_dct.items()
                    for slt in picked)
    if token_len > MAX_TOKEN_SIZE or token_len > 15*target_words:
      break

  # Document softing
  if softing or token_len < 10*target_words:
    print('Searching in other documents')
    unlisted = [key for key, picked in selected_dct.items() if len(picked) < 1]
    print(unlisted)

    for doc_idx in unlisted:
      sentences = doc_sentences(doc_idx)
      if not sentences:
        continue  # greedy_search cannot seed from an empty document
      for paragraph in paragraphs:
        candidates = greedy_search(sentences=sentences,
                                   paragraph=paragraph,
                                   paragraph_weight=len(paragraph.split())/target_words/len(sources),
                                   scorer=scorer)
        selected_dct[doc_idx] = sorted(set(selected_dct[doc_idx] + candidates))

  if unimportant:
    print("Before unimportant: ", selected_dct)
    selected_dct = add_unimportant(selected_dct)
    print("After unimportant: ", selected_dct)
  return selected_dct, sources

In [None]:
import warnings

# Token budget for a document's proxy label; greedy search scales it by each
# paragraph's share of the summary.
MAX_TOKEN_SIZE = 15_000
# Silence all warnings for the long labelling run.
warnings.filterwarnings("ignore")

In [None]:
# Generate proxy labels for every split and write them as JSON lines to Drive.
a=[]
b=[]
datasets_list = ['train', 'validation', 'test']
# Label with the metric chosen in the configuration cell, so that each
# {folder_name} directory really contains labels built with that ROUGE variant.
scorer_by_rouge = {'R1': rouge1_scorer, 'R2': rouge2_scorer, 'RL': rougeL_scorer}

for data in datasets_list:
  out_path = f"drive/MyDrive/Text Summarization (Dataset and labels)/{folder_name}/{data}.jsonl"
  # `with` closes (and flushes) the file even if labelling fails part-way.
  with open(out_path, "w") as store_file:
    for i in range(len(multi_lexsum[data])):
      proxy_indices = []
      proxy_labels = []
      rouge1_scores = []
      rouge2_scores = []
      rougeL_scores = []
      bert_scores = []
      paragraphs = dict()  # source index (as str) -> that source's sentences
      scorers = [scorer_by_rouge[chosen_rouge]]

      for scorer in scorers:
        idx = i
        selected_dct, sources = get_proxy_label(train_idx=idx, scorer=scorer, softing=True, datasplit=data)

        target = preprocess_input(multi_lexsum[data][idx]['summary/long'])

        proxy_indices.append(selected_dct)

        # Reconstruct the label text from the selected sentence indices.
        for key in range(len(sources)):
          sentences = get_sentences(preprocess_input(sources[key]))
          paragraphs[str(key)] = sentences
          proxy_labels.append(' '.join(sentences[slt] for slt in selected_dct[key]))
        # One entry per example (it used to be appended once per source).
        b.append(proxy_labels)

        result = dict()
        result["summary"] = target
        result["source"] = paragraphs
        result["label"] = selected_dct

        store_file.write(json.dumps(result) + "\n")

0
Searching in other documents
[]
1
Searching in other documents
[]
2
Searching in other documents
[]
3
Searching in other documents
[0, 9, 11, 12]
4
Searching in other documents
[]
5
Searching in other documents
[]
6
Searching in other documents
[]
7
Searching in other documents
[0]
8
Searching in other documents
[]
9
Searching in other documents
[0]
10
Searching in other documents
[1]
11
Searching in other documents
[0, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]
12
Searching in other documents
[6, 8, 9, 10, 12, 13]
13
Searching in other documents
[]
14
Searching in other documents
[1, 2]
15
Searching in other documents
[0, 1, 2, 3, 4, 8]
16
Searching in other documents
[0, 2, 3, 4, 7]
17
