In [1]:
import huggingface
from transformers import AutoTokenizer, BartForConditionalGeneration
from datasets import load_dataset
from sklearn.model_selection import train_test_split

from tqdm import tqdm
import pickle


  from .autonotebook import tqdm as notebook_tqdm


#### Load Dataset

In [2]:
# Fetch (or reuse the local cache of) the CNN/DailyMail corpus, config "1.0.0"
dataset = load_dataset(path="abisee/cnn_dailymail", name="1.0.0")
dataset

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})

#### Check Dataset

In [3]:
# Print a few test-split articles alongside their human-written highlights

example_indices = [50, 100, 154, 200]


dash_line = 100 * '-'

for example_number, test_index in enumerate(example_indices, start=1):
    test_record = dataset['test'][test_index]
    print(dash_line)
    print('Example ', example_number)
    print(dash_line)
    print('INPUT ARTICLE:')
    print(test_record['article'])
    print(dash_line)
    print('BASELINE HUMAN HIGHLIGHTS:')
    print(test_record['highlights'])
    print(dash_line)
    print()

----------------------------------------------------------------------------------------------------
Example  1
----------------------------------------------------------------------------------------------------
INPUT ARTICLE:
(CNN)According to an outside review by Columbia Journalism School professors, "(a)n institutional failure at Rolling Stone resulted in a deeply flawed article about a purported gang rape at the University of Virginia." The Columbia team concluded that "The failure encompassed reporting, editing, editorial supervision and fact-checking." Hardly a ringing endorsement of the editorial process at the publication. The magazine's managing editor, Will Dana, wrote, "We would like to apologize to our readers and to all of those who were damaged by our story and the ensuing fallout, including members of the Phi Kappa Psi fraternity and UVA administrators and students." Brian Stelter: Fraternity to 'pursue all available legal action' The next question is: . Can UVA, Phi K

These examples show that each news article is summarised fairly succinctly, with the highlights capturing just its main points.

#### Pre-process Dataset

In [None]:
# Load the tokenizer and model from the same pretrained checkpoint
MODEL_CHECKPOINT = "facebook/bart-large-cnn"

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)

In [None]:
def tokenize_dataset(batch_size, data, category, max_length=512, padding=True):
    '''
    Tokenize one text column of one split of the notebook-level ``dataset``,
    ``batch_size`` texts per tokenizer call.

    batch_size (int): number of texts tokenized at once (smaller values reduce
        peak memory per call).
    data (str): split name, e.g. 'train', 'validation' or 'test'.
    category (str): column name, 'article' or 'highlights'.
    max_length (int): sequences are truncated to at most this many tokens
        (default 512, as before).
    padding: forwarded to the tokenizer. ``True`` pads each batch only to the
        longest sequence in that batch, so lengths can differ between batches.

    Returns ``(input_ids, attention_masks)``: two lists in split order with one
    entry per example. With ``return_tensors='pt'`` each entry is one row of a
    batch tensor.

    Relies on the global ``dataset`` and ``tokenizer`` objects defined earlier
    in the notebook.
    '''
    split = dataset[data]

    input_ids = []
    attention_masks = []

    for start in tqdm(range(0, len(split), batch_size)):
        # Slice rows first, then select the column. Indexing the column first
        # (dataset[data][category]) can build the whole column on every
        # iteration, which makes the loop quadratic in the split size.
        batch_texts = split[start:start + batch_size][category]
        tokenized_batch = tokenizer(batch_texts, max_length=max_length, padding=padding,
                                    truncation=True, return_tensors='pt')
        # extend() walks the batch dimension, so each appended item is a single
        # example's sequence rather than the whole batch.
        input_ids.extend(tokenized_batch['input_ids'])
        attention_masks.extend(tokenized_batch['attention_mask'])

    return input_ids, attention_masks



    

In [None]:
# For now the train split is skipped to reduce dataset size and the time spent
# tokenizing and training; it will be included in the future.
TOKENIZE_BATCH_SIZE = 1000  # texts per tokenizer call


# Train (disabled). tokenize_dataset returns an (input_ids, attention_masks)
# tuple, and the summary column is named 'highlights' (plural).
# print("Tokenizing Train:Article")
# train_tokenized_article_ids, train_tokenized_article_masks = tokenize_dataset(TOKENIZE_BATCH_SIZE, 'train', 'article')

# print("Tokenizing Train:Highlight")
# train_tokenized_highlight_ids, train_tokenized_highlight_masks = tokenize_dataset(TOKENIZE_BATCH_SIZE, 'train', 'highlights')


# Validation
print("Tokenizing Validation:Article")
validation_tokenized_article_ids, validation_tokenized_article_masks = tokenize_dataset(TOKENIZE_BATCH_SIZE, 'validation', 'article')

print("Tokenizing Validation:Highlight")
validation_tokenized_highlight_ids, validation_tokenized_highlight_masks = tokenize_dataset(TOKENIZE_BATCH_SIZE, 'validation', 'highlights')


# Test
print("Tokenizing Test:Article")
test_tokenized_article_ids, test_tokenized_article_masks = tokenize_dataset(TOKENIZE_BATCH_SIZE, 'test', 'article')

print("Tokenizing Test:Highlight")
test_tokenized_highlight_ids, test_tokenized_highlight_masks = tokenize_dataset(TOKENIZE_BATCH_SIZE, 'test', 'highlights')

In [None]:
# Save tokenized data
def save_tokenized_data(tokenized_data, filename):
    '''
    Pickle ``tokenized_data`` (any picklable object, e.g. an (ids, masks)
    tuple) to ``filename``, overwriting any existing file.
    '''
    with open(filename, 'wb') as f:
        pickle.dump(tokenized_data, f)

# Load tokenized data
def load_tokenized_data(filename):
    '''
    Load and return the object pickled in ``filename`` (e.g. by
    ``save_tokenized_data``).

    WARNING: unpickling can execute arbitrary code. Only load files you
    created yourself.
    '''
    with open(filename, 'rb') as f:
        return pickle.load(f)


# tokenize_dataset returns (input_ids, attention_masks), so each pair is saved together.
# save_tokenized_data((validation_tokenized_article_ids, validation_tokenized_article_masks), 'validation_tokenized_article.pkl')
# save_tokenized_data((validation_tokenized_highlight_ids, validation_tokenized_highlight_masks), 'validation_tokenized_highlight.pkl')
# save_tokenized_data((test_tokenized_article_ids, test_tokenized_article_masks), 'test_tokenized_article.pkl')
# save_tokenized_data((test_tokenized_highlight_ids, test_tokenized_highlight_masks), 'test_tokenized_highlight.pkl')

In [None]:
# Round-trip a sample sentence through the tokenizer to check encode/decode agree
sentence = "Does the tokenizer work I wonder?"
sentence_encoded = tokenizer(sentence, return_tensors='pt')
encoded_ids = sentence_encoded["input_ids"][0]
sentence_decoded = tokenizer.decode(encoded_ids, skip_special_tokens=True)

print('ENCODED SENTENCE:')
print(encoded_ids)
print('\nDECODED SENTENCE:')
print(sentence_decoded)

It does: decoding the encoded ids (skipping special tokens) gives back the original sentence.

#### Combine the validation and test sets and re-split them into train/validation/test sets

In [None]:
# Pool the per-example lists of both splits: validation examples first, then test
tokenized_articles_ids = [*validation_tokenized_article_ids, *test_tokenized_article_ids]
tokenized_articles_masks = [*validation_tokenized_article_masks, *test_tokenized_article_masks]
tokenized_highlights_ids = [*validation_tokenized_highlight_ids, *test_tokenized_highlight_ids]
tokenized_highlights_masks = [*validation_tokenized_highlight_masks, *test_tokenized_highlight_masks]


In [None]:

# Function to flatten both input_ids and attention_masks
def flatten_batches(input_batches, mask_batches):
    '''
    Flatten parallel token-id and attention-mask collections into one list of
    per-example sequences each.

    Each element of ``input_batches`` is either a batch or a single example:
    - a batch (2-D tensor/array, or a list whose first item is itself a
      sequence) is unrolled row by row, pairing id rows with mask rows;
    - a single example (1-D tensor/array, or a flat list of token ids) is kept
      whole. Already-flat input therefore passes through unchanged instead of
      being split into individual tokens.

    Returns ``(flat_input_ids, flat_attention_masks)``.
    Raises ValueError if the two inputs have different lengths.
    '''
    def is_batch(item):
        ndim = getattr(item, "ndim", None)
        if ndim is not None:
            return ndim >= 2
        # Plain sequences: treat an empty one as an empty batch (adds nothing)
        return len(item) == 0 or hasattr(item[0], "__len__")

    if len(input_batches) != len(mask_batches):
        raise ValueError(
            f"input_batches ({len(input_batches)}) and mask_batches "
            f"({len(mask_batches)}) must have the same length")

    flat_input_ids = []
    flat_attention_masks = []

    for input_batch, mask_batch in tqdm(zip(input_batches, mask_batches), total=len(input_batches), desc="Flattening batches"):
        if is_batch(input_batch):
            for input_item, mask_item in zip(input_batch, mask_batch):
                flat_input_ids.append(input_item)
                flat_attention_masks.append(mask_item)
        else:
            flat_input_ids.append(input_batch)
            flat_attention_masks.append(mask_batch)

    return flat_input_ids, flat_attention_masks



In [None]:
import torch

def flatten_batches_in_chunks(input_batches, mask_batches, chunk_size=1000):
    '''
    Flatten parallel token-id and attention-mask collections into one list of
    per-example sequences each, walking the input ``chunk_size`` elements at a
    time with one progress bar per chunk.

    Each element of ``input_batches`` is either a batch or a single example:
    - a batch (2-D tensor/array, or a list whose first item is itself a
      sequence) is unrolled row by row, pairing id rows with mask rows;
    - a single example (1-D tensor/array, or a flat list of token ids) is kept
      whole. Output collected with ``list.extend`` over tokenizer batches is
      already in this form, so it passes through unchanged rather than being
      split into individual tokens.

    chunk_size (int): number of elements handled per progress bar; must be >= 1.

    Returns ``(flat_input_ids, flat_attention_masks)``.
    Raises ValueError if ``chunk_size < 1`` or the two inputs differ in length.

    Note: every result is still accumulated in the returned lists, so chunking
    by itself does not lower peak memory.
    '''
    def is_batch(item):
        ndim = getattr(item, "ndim", None)
        if ndim is not None:
            return ndim >= 2
        # Plain sequences: treat an empty one as an empty batch (adds nothing)
        return len(item) == 0 or hasattr(item[0], "__len__")

    # A negative step would make range() empty and silently return nothing
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if len(input_batches) != len(mask_batches):
        raise ValueError(
            f"input_batches ({len(input_batches)}) and mask_batches "
            f"({len(mask_batches)}) must have the same length")

    flat_input_ids = []
    flat_attention_masks = []

    for i in range(0, len(input_batches), chunk_size):
        chunk_input_batches = input_batches[i:i + chunk_size]
        chunk_mask_batches = mask_batches[i:i + chunk_size]

        # The progress-bar total is the size of this chunk, not of the whole input
        for input_batch, mask_batch in tqdm(zip(chunk_input_batches, chunk_mask_batches), total=len(chunk_input_batches), desc="Flattening batches"):
            if is_batch(input_batch):
                for input_item, mask_item in zip(input_batch, mask_batch):
                    flat_input_ids.append(input_item)
                    flat_attention_masks.append(mask_item)
            else:
                flat_input_ids.append(input_batch)
                flat_attention_masks.append(mask_batch)

        # Optionally save intermediate results to disk to reduce memory usage
        # torch.save((flat_input_ids, flat_attention_masks), f"flattened_chunk_{i//chunk_size}.pt")

    return flat_input_ids, flat_attention_masks

# Flatten articles and highlights in smaller chunks
flattened_article_ids, flattened_article_masks = flatten_batches_in_chunks(tokenized_articles_ids, tokenized_articles_masks, chunk_size=1000)
flattened_highlight_ids, flattened_highlight_masks = flatten_batches_in_chunks(tokenized_highlights_ids, tokenized_highlights_masks, chunk_size=1000)

# Each article must still pair with exactly one highlight (and each id sequence
# with one mask). A mismatch means examples were split or dropped while flattening.
assert len(flattened_article_ids) == len(flattened_highlight_ids), (
    f"{len(flattened_article_ids)} articles vs {len(flattened_highlight_ids)} highlights after flattening")
assert len(flattened_article_ids) == len(flattened_article_masks)
assert len(flattened_highlight_ids) == len(flattened_highlight_masks)



In [None]:

# Recreate train:val:test splits (roughly 70:15:15). All four parallel lists go
# through a single train_test_split call at each level, so ids, masks and
# highlights share one shuffle and stay aligned by construction.
SPLIT_SEED = 10

(train_articles, temp_articles,
 train_highlights, temp_highlights,
 train_article_masks, temp_article_masks,
 train_highlight_masks, temp_highlight_masks) = train_test_split(
    flattened_article_ids, flattened_highlight_ids,
    flattened_article_masks, flattened_highlight_masks,
    test_size=0.3, random_state=SPLIT_SEED)

# Split the held-out 30% in half to give the validation and test sets
(validation_articles, test_articles,
 validation_highlights, test_highlights,
 validation_article_masks, test_article_masks,
 validation_highlight_masks, test_highlight_masks) = train_test_split(
    temp_articles, temp_highlights, temp_article_masks, temp_highlight_masks,
    test_size=0.5, random_state=SPLIT_SEED)


In [None]:
# Report how many article/highlight pairs landed in each split
splits_to_report = (
    ("Train", train_articles, train_highlights),
    ("Val", validation_articles, validation_highlights),
    ("Test", test_articles, test_highlights),
)
for split_label, split_articles, split_highlights in splits_to_report:
    print(f"{split_label} Articles: {len(split_articles)} |  {split_label} highlights: {len(split_highlights)}")



In [None]:
# Decoding sanity check: turn one training example's token ids back into text
sentence_decoded = tokenizer.decode(train_articles[1], skip_special_tokens=True)

print('\nDECODED ARTICLE:')
print(sentence_decoded)

In [None]:
# Longest tokenized training article, measured in tokens
max_length = max(map(len, train_articles))
max_length

In [None]:
# Show one training article and its highlight, separated by a divider
sample_index = 4
print(tokenizer.decode(train_articles[sample_index]))
print("\n\n", 200 * '-')
print(tokenizer.decode(train_highlights[sample_index]))

#### Test out pre-trained BART

In [None]:
# test the model out on a few examples

indexes = [1, 50, 120]

# Beam-search settings shared by every example below
generation_kwargs = dict(max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)

# Iterate over the selected examples
for index in indexes:
    # Pre-tokenized token-id tensors for this example
    article = train_articles[index]
    highlight = train_highlights[index]

    # The model expects a batch of inputs, so add a leading batch dimension.
    inputs = article.unsqueeze(0)
    # Pass the matching mask explicitly so any padding added during tokenization
    # (padding=True pads to the longest text in each tokenizer batch) is masked
    # out, rather than relying on generate() to infer the mask from pad ids.
    attention_mask = train_article_masks[index].unsqueeze(0)

    # Generate a summary from the pre-tokenized input ids
    outputs = model.generate(inputs, attention_mask=attention_mask, **generation_kwargs)

    print(dash_line)

    # Decode the tokenized article
    decoded_input = tokenizer.decode(article.tolist(), skip_special_tokens=True)
    print(f'INPUT PROMPT:\n{decoded_input}')

    print(dash_line)

    # Decode the human-written highlight
    decoded_highlight = tokenizer.decode(highlight.tolist(), skip_special_tokens=True)
    print(f'Baseline Human Highlights:\n{decoded_highlight}')

    print(dash_line)

    # Decode the model's generated summary
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f'Base Model Generation:\n{decoded_output}\n')


In [None]:
# Peek at the raw token-id representation of the first training article
first_train_article = train_articles[0]
first_train_article