In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import AutoModelForSequenceClassification
from transformers import TFAutoModelForSequenceClassification
from transformers import AutoTokenizer
import numpy as np
import pickle
from general_functions import *
import glob, os
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Create Article Summary

In [None]:
def summarize_article(article,tokenizer,model,max_length_s,n_beams,max_input_length=1024):
    """Generate summary text for one article with a seq2seq model.

    Args:
        article: Source text to summarise.
        tokenizer: Tokenizer called as ``tokenizer([article], ...)`` and used to
            ``decode`` the generated ids (a ``BartTokenizer`` in this notebook).
        model: Model exposing ``generate`` (a ``BartForConditionalGeneration``
            in this notebook).
        max_length_s: ``max_length`` passed to ``generate`` (summary length cap).
        n_beams: Number of beams for beam search.
        max_input_length: Number of article tokens kept before truncation;
            defaults to the previously hard-coded 1024.

    Returns:
        List of decoded summary strings, one per sequence returned by ``generate``.
    """
    # Explicit truncation=True: with only max_length given, transformers
    # falls back to truncation anyway but logs a warning on every call.
    encoded = tokenizer([article], max_length=max_input_length, truncation=True, return_tensors='pt')

    # Follow the model's device so that `model.to('cuda')` is all that is
    # needed to run on GPU (replaces the commented-out `.to('cuda')`).
    device = getattr(model, 'device', None)
    if device is not None:
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    summary_ids = model.generate(
        encoded['input_ids'],
        attention_mask=encoded.get('attention_mask'),
        num_beams=n_beams,
        max_length=max_length_s,
        early_stopping=False,
    )
    return [
        tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for ids in summary_ids
    ]

# Model Init and Run Over Dataset

In [None]:
# Substring of the input file name -> dataset type handed to the cleaning
# helpers. When several markers match, the later entry wins.
_SRC_TYPE_MARKERS = (
    ("cnn", "cnn_dailymail"),
    ("multi", "multi_news"),
    ("newsroom", "newsroom"),
)


def _infer_src_type(file_name):
    """Return the dataset type implied by ``file_name``, or None if no marker matches."""
    src_type = None
    for marker, kind in _SRC_TYPE_MARKERS:
        if marker in file_name:
            src_type = kind  # keep overwriting: last match wins
    return src_type


def run_summarization(model_name,file_name_in,max_length_s,n_beams,file_name_out,src_type=None):
    """Summarise every pickled example in ``file_name_in`` and pickle the results.

    For each entry, both the cleaned article and the modified article are
    summarised. Each output row is
    ``[entry, highlights, summary(clean article), summary(modified article)]``.

    NOTE(review): entry layout is inferred from the variable names only —
    entry[2] = modified article, entry[6] = raw article, entry[7] = raw
    highlights. Confirm against whatever produced the input pickles.

    Args:
        model_name: Hugging Face model id loaded with the BART classes.
        file_name_in: Pickle file containing a list of entries.
        max_length_s: Summary ``max_length`` passed to ``summarize_article``.
        n_beams: Beam count passed to ``summarize_article``.
        file_name_out: Destination pickle for the results.
        src_type: Dataset type for ``clean_article_new`` / ``clean_high_new``.
            When None, it is inferred from ``file_name_in`` ("cnn", "multi",
            "newsroom"; later markers take precedence).

    Raises:
        ValueError: if ``src_type`` is None and cannot be inferred from the file name.
    """
    print(file_name_out)

    # Resolve the dataset type explicitly instead of reading a notebook global,
    # and fail before the (slow) pickle/model load if it is unknown.
    if src_type is None:
        src_type = _infer_src_type(file_name_in)
    if src_type is None:
        raise ValueError(
            "Cannot infer src_type from file name %r; pass src_type explicitly." % (file_name_in,)
        )

    # NOTE(review): pickle.load can execute arbitrary code — only run this on
    # trusted files.
    with open(file_name_in, 'rb') as handle:
        data_list = pickle.load(handle)

    model = BartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = BartTokenizer.from_pretrained(model_name)

    final_list_all = []
    for example_counter, entry in enumerate(data_list):
        print(example_counter)
        article_mod = entry[2]
        article = clean_article_new(entry[6], src_type)
        highlights = clean_high_new(entry[7], src_type)

        final_list_all.append([
            entry,
            highlights,
            summarize_article(article, tokenizer, model, max_length_s, n_beams),
            summarize_article(article_mod, tokenizer, model, max_length_s, n_beams),
        ])

    with open(file_name_out, 'wb') as handle:
        pickle.dump(final_list_all, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Get All Pickle Files to Summarize

In [None]:
def find_input_pickles(base_dir):
    """Return paths, relative to ``base_dir``, of the input pickles to summarise.

    Walks ``base_dir`` recursively. It keeps files whose name contains "pickle"
    but not "beams" (summary outputs are named with "_beams_"). Directories and
    files are visited in sorted order so the run order is reproducible. Each
    kept path is printed.
    """
    found = []
    for root, dirs, files in os.walk(base_dir):
        dirs.sort()  # in-place, so os.walk descends in sorted order
        for file in sorted(files):
            if "pickle" in file and "beams" not in file:
                rel_path = os.path.relpath(os.path.join(root, file), base_dir)
                print(rel_path)
                found.append(rel_path)
    return found


# Paths are relative to the working directory, so pickles found in
# subdirectories can still be opened from here; top-level files keep their
# bare file name.
file_name_in_list = find_input_pickles(os.getcwd())

# Run Over All Files

In [None]:
max_length_s_list = [400]
n_beams_list = [1,4,10]
model_name_list = ["facebook/bart-large-cnn","facebook/bart-large-xsum","sshleifer/distilbart-cnn-12-6","sshleifer/distilbart-xsum-12-6"]

# File-name substring -> dataset type for the cleaning helpers. Checked in
# order with later matches winning, i.e. "newsroom" beats "multi" beats "cnn".
SRC_TYPE_BY_MARKER = (
    ("cnn", "cnn_dailymail"),
    ("multi", "multi_news"),
    ("newsroom", "newsroom"),
)

for model_name in model_name_list:
    # Hub ids contain '/', which cannot appear inside an output file name.
    mname = model_name.replace('/','-')

    for file_name_in in file_name_in_list:
        # Resolve the dataset type afresh for every file. An unmatched name must
        # not silently reuse the previous file's type. `src_type` stays a
        # notebook-level name because run_summarization may read it as a global.
        src_type = None
        for marker, kind in SRC_TYPE_BY_MARKER:
            if marker in file_name_in:
                src_type = kind
        if src_type is None:
            raise ValueError("Cannot determine src_type for input file %r" % (file_name_in,))

        for n_beams in n_beams_list:
            for max_length_s in max_length_s_list:
                file_name_out = (file_name_in + '_mLen_' + str(max_length_s)
                                 + '_name_' + mname + '_beams_' + str(n_beams) + '.pickle')
                run_summarization(model_name,file_name_in,max_length_s,n_beams,file_name_out)