# Setup

In [None]:
!pip install transformers
!pip install datasets

In [None]:
from transformers import (
    AutoTokenizer,
    LEDForConditionalGeneration,
)
from datasets import load_dataset, load_metric
import torch
import random

# Helper Functions

In [None]:
"""
Called by clean. Listify source documents, clean up summary (remove "– " at the beginning)

input: single datapoint {'document': String, 'summary': String}
output: {'document': List, 'summary': String}
"""
def clean_single_dp(datapoint):
  docs_str = datapoint['document']
  doc_sep = "|||||"
  doc_list = docs_str.split("|||||") #list of the source documents

  sum = datapoint['summary']
  summary_clean = sum[2:] #get rid of "– " at beginning of each summary

  return doc_list, summary_clean

In [None]:
"""
Clean each datapoint (listify docs, clean summaries) 
Create dicitonary where keys are number of source docs, in case we decide to 
aggregate our data based on #docs (to standardize input size)

input: unaugmented multiNews dataset: List of {'document': String, 'summary': String}
output: 
  all_data: List of {'document_list': List, 'clean_summary': String}
  numdocs_dict: Dictionary where key = #source docs, value = list of datapoints - 
                each one is {'document': String, 'summary': String}
"""
def clean_orig(data): #takes all data
  numdocs_dict = {4: ["hi"]} #initialize dictionary in which key = # source docs, value = list of datapoints (dicitonaries)
  all_data = []

  for point in data:
    #augment single point
    doc_list, summary_clean = clean_single_dp(point)

    #add to all_data
    all_data += [{'document_list': doc_list, 'clean_summary': summary_clean}]

    #add to numdocs_dict
    num_docs = len(doc_list)
    if num_docs in numdocs_dict:
      numdocs_dict[num_docs] += [{'document_list': doc_list, 'clean_summary': summary_clean}]
    else:
      numdocs_dict[num_docs] = [{'document_list': doc_list, 'clean_summary': summary_clean}]


  return all_data

In [None]:
"""
Called by group_data. Aggregate given list of datapoints.

input: list of datapoints, each of which is {document_list: [...], clean_summary: "..."}
output: 
  combo_documents: list of doc lists
  combo_summaries_str: String of combined summaries separated by \n\n
  combo_summaries_list: list of summaries
"""
def combine_points(data_list):

  combo_documents = [] #list of lists
  combo_summaries_list = [] #list of strings

  for point in data_list:
    # print("point has", len(point['document_list']), "sources")
    combo_documents += point['document_list']
    combo_summaries_list += [point['clean_summary']]

  combo_summaries_str = "\n\n".join(combo_summaries_list) #string (concatenated summaries separated by a black line)

  return combo_documents, combo_summaries_str, combo_summaries_list

In [None]:
"""
Randomly partition dataset into groups of GROUP_SIZE, aggregate each group.

input: full cleaned dataset - List of {'document_list': List, 'clean_summary': String}
output: List of aggregated data {'documents': List, 'summary': String'
"""
def group_data(data):
  GROUP_SIZE = 3

  #partition
  random.shuffle(data)
  groups = [data[i:i+GROUP_SIZE] for i in range(0, len(data), GROUP_SIZE)]

  #if last group is a lonely datapoint, merge it with the previous group
  if len(groups[-1]) == 1:
    groups[-2] += groups[-1]
    groups = groups[:-1]

  #aggregate each group
  for i in range(len(groups)):
    combo_docs, combo_sum_str, combo_sum_list = combine_points(groups[i])
    aggregated_group = {'documents': combo_docs, 'summary': combo_sum_str}
    groups[i] = aggregated_group

  return groups

In [None]:
"""
Compute distribution of #documents i.e. how many datapoints have x docs

input: grouped and aggregated dataset - List of {'documents': List, 'summary': String'
output: {int: int} key = number of documents, value = number of datapoints with that number of documents
"""
def get_group_numdoc_freq(grouped_data):
    numdoc_frequency = {}

    for data in grouped_data:
        num_docs = len(data.get('documents', []))
        if num_docs in numdoc_frequency:
            numdoc_frequency[num_docs] += 1
        else:
            numdoc_frequency[num_docs] = 1

    print("Number of Documents Frequency:")
    for numdocs, frequency in numdoc_frequency.items():
        print(f"#documents {numdocs}: {frequency}")

    return numdoc_frequency

# Dataset Augmentation

In [None]:
# Download/load the Multi-News dataset (all splits) via Hugging Face `datasets`.
dataset = load_dataset('multi_news')



  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Materialize each split as a plain Python list of example dicts.
train, val, test = (list(dataset[split]) for split in ('train', 'validation', 'test'))

# Report the size of each split.
for label, rows in (("Train: ", train), ("Val: ", val), ("Test: ", test)):
  print(label, len(rows))

Train:  44972
Val:  5622
Test:  5622


In [None]:
"""
Full data augmentation process
"""
def augment_data(data):
  clean_data = clean_orig(data)
  augmented_data = group_data(clean_data)
  numdoc_freq = get_group_numdoc_freq(augmented_data)
  return augmented_data #List of {'documents': List, 'summary': String}

In [None]:
# Build the augmented (grouped) training set.
aug_train = augment_data(data=train)