# TL;DR – Too Long, Doctor

TL;DR is an ML model designed to summarize and cluster scientific papers. Tailored for students and researchers who want to make the most of their study time, TL;DR helps them quickly grasp the essence of complex scientific material. It also serves readers who want a concise summary or a preliminary overview of a paper before committing to a detailed reading.

## Importing libraries

In [None]:
# Import library to extract data from XML file
import xml.etree.ElementTree as ET
import io
import os
import json
from pymongo import MongoClient
import re
import logging
import string

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BartTokenizer, BartForConditionalGeneration, AdamW, pipeline
from tqdm.auto import tqdm

In [None]:
# Configure root logging once for the whole notebook: INFO level and above, timestamped lines
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

## Dataset Generation (from XML to JSON)

In [None]:
def _direct_child_text(element, skip_tags=('title', 'sec')):
    """
    Return the text of `element`, leaving out direct children whose tag is in `skip_tags`.

    The full text of every other child (with all its descendants) is kept, as is the tail text
    that follows each child, skipped or not. Skipping nested <sec> children stops a parent
    section from repeating its subsections, which are reported as their own entries.

    Parameters:
    element (xml.etree.ElementTree.Element): Element whose text to collect.
    skip_tags (tuple of str): Tags of *direct* children to leave out.

    Returns:
    str: The collected text, stripped of leading/trailing whitespace.
    """
    parts = [element.text or '']
    for child in element:
        if child.tag not in skip_tags:
            parts.append(''.join(child.itertext()))
        # The tail is part of the parent's text flow, so keep it even for skipped children
        parts.append(child.tail or '')
    return ''.join(parts).strip()


def _title_text(element):
    """Return the stripped text of `element`'s direct <title> child, or None if there is none."""
    title_element = element.find('title')
    if title_element is None:
        return None
    return ''.join(title_element.itertext()).strip()


def extract_information_from_xml(xml_content):
    """
    Extract title, abstract, keywords and body sections from a JATS-style XML article.

    Parameters:
    xml_content (str or bytes): The XML document itself (not a path to a file).

    Returns:
    dict or None: A dictionary with keys
        'Title' (str),
        'Abstract' ({'Simple Summary': str, 'Detailed Abstract': str}),
        'Sections' (list of {'Title': str, 'Content': str}, one per <sec> under <body>,
                    each section's content excluding its title and its nested sections),
        'Keywords' (list of str),
    or None if the XML could not be parsed or processed (the error is logged).
    """
    try:
        root = ET.fromstring(xml_content)

        extracted_info = {
            'Title': '',
            'Abstract': {'Simple Summary': '', 'Detailed Abstract': ''},
            'Sections': [],
            'Keywords': []
        }

        # Title: first <article-title> in document order
        title_element = root.find('.//article-title')
        if title_element is not None:
            extracted_info['Title'] = ''.join(title_element.itertext())

        # Abstract
        abstract_element = root.find('.//abstract')
        if abstract_element is not None:
            abstract_sections = abstract_element.findall('.//sec')
            if abstract_sections:
                # Structured abstract: a "Simple Summary" section goes to its own key,
                # every other section is appended to the detailed abstract.
                detailed_parts = []
                for sec in abstract_sections:
                    section_title = _title_text(sec) or ''
                    section_text = _direct_child_text(sec)
                    if 'simple summary' in section_title.lower():
                        extracted_info['Abstract']['Simple Summary'] = section_text
                    else:
                        detailed_parts.append(section_text)
                extracted_info['Abstract']['Detailed Abstract'] = ' '.join(
                    part for part in detailed_parts if part
                )
            else:
                # Unstructured abstract (plain paragraphs, no <sec>): use all of its text,
                # minus a direct heading such as <title>Abstract</title>.
                extracted_info['Abstract']['Detailed Abstract'] = _direct_child_text(
                    abstract_element, skip_tags=('title',)
                )

        # Keywords: itertext() also captures keywords wrapped in markup such as <italic>
        for kwd_group in root.findall('.//kwd-group'):
            for kwd in kwd_group.findall('.//kwd'):
                keyword = ''.join(kwd.itertext()).strip()
                if keyword:
                    extracted_info['Keywords'].append(keyword)

        # Body sections. Only a *direct* <title> child counts as the section title; searching
        # './/title' could pick up e.g. a figure-caption title deeper in the section.
        for sec in root.findall('.//body//sec'):
            section_title = _title_text(sec)
            extracted_info['Sections'].append({
                'Title': section_title if section_title is not None else "No Title",
                'Content': _direct_child_text(sec),
            })

        return extracted_info
    except ET.ParseError as e:
        logging.error(f"XML parsing error: {e}")
        return None
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        return None

In [None]:
# MongoDB setup: every connection detail comes from the environment (a missing variable raises KeyError)
mongo_conn_string, db_name, xml_collection_name, json_raw_collection_name = (
    os.environ[variable]
    for variable in ('MONGO_CONN_STRING', 'DB_NAME', 'XML_COLLECTION', 'JSON_RAW_COLLECTION')
)

client = MongoClient(mongo_conn_string)
db = client[db_name]
# Source collection (raw XML documents) and destination collection (extracted JSON)
xml_collection, json_raw_collection = db[xml_collection_name], db[json_raw_collection_name]

def extract_information_and_write_to_json_file(mongo_id, json_raw_collection, xml_collection):
    """
    Extract Title, Abstract, Sections and Keywords from one XML paper stored in MongoDB and
    insert the result into the raw-JSON MongoDB collection. Nothing is written to disk.

    A source document whose id is already recorded in `json_raw_collection` (field
    'mongo_id') is skipped, so the function can safely be re-run over the whole collection.

    Parameters:
    mongo_id (ObjectId): _id of the source document in `xml_collection`.
    json_raw_collection (pymongo Collection): Destination collection for the extracted data.
    xml_collection (pymongo Collection): Source collection; its documents are read through
        their 'content' (XML text) and 'filename' fields.
    """
    try:
        # Query the same field name that is inserted below. Checking any other key never
        # matches, which re-inserted a duplicate on every run.
        if json_raw_collection.count_documents({'mongo_id': mongo_id}, limit=1) > 0:
            logging.info(f"Document {mongo_id} already processed, skipping.")
            return

        document = xml_collection.find_one({'_id': mongo_id})
        if document is None:
            logging.warning(f"Document {mongo_id} not found in {xml_collection.name}, skipping.")
            return

        extracted_info = extract_information_from_xml(document['content'])
        if extracted_info is None:
            # Extraction already logged the reason; don't store a record with empty content
            logging.warning(f"Could not extract information from document {mongo_id}, skipping.")
            return

        # Title is the original filename without its extension
        title = document['filename'].rsplit('.', 1)[0]

        json_raw_collection.insert_one({
            'mongo_id': mongo_id,
            'title': title,
            'content': extracted_info,
        })
        logging.info(f"Data inserted into {json_raw_collection.name} for document with ID: {mongo_id}")
    except Exception as e:
        # Best effort per document: log and let the caller continue with the next one
        logging.exception(f"Unexpected error: {e}")

In [None]:
# TODO: add .limit(n) to the query below to process only a subset of the documents.
# Only _id is needed here (the function re-fetches the full document itself), so project
# everything else away instead of transferring every XML body twice.
for document in xml_collection.find({}, {'_id': 1}):
    extract_information_and_write_to_json_file(document['_id'], json_raw_collection, xml_collection)

## Preprocessing (from JSON to cleaned JSON that can be given to train the model)

In [None]:
def clean_html_tags(text):
    """
    Remove HTML/XML tags from the given text.

    Only spans that look like a tag are removed: '<', an optional '/', then a letter, up to the
    next '>' with no '<' or '>' in between. Comparison operators that are common in scientific
    text, such as "p < 0.05 and n > 10", are therefore kept. The previous pattern `<[^>]+>`
    deleted everything between them. Inequalities written with a letter straight after '<'
    (e.g. "a<b and c>d") can still be mistaken for a tag.

    Parameters:
    - text (str): Text to clean.

    Returns:
    str: Text with HTML/XML tags removed.
    """
    return re.sub(r'</?[A-Za-z][^<>]*>', '', text)

In [None]:
def standardize_special_characters(text):
    """
    Map a few typographic punctuation marks to plain ASCII.

    Curly double quotes become '"', curly single quotes/apostrophes become "'", and the em
    dash becomes '-'. Every other character, including the en dash, is returned unchanged.

    Parameters:
    - text (str): Text to process.

    Returns:
    str: Text with standardized special characters.
    """
    # One translation table instead of a chain of .replace() calls; none of the replacement
    # characters is itself a key, so the result matches replacing them one at a time.
    typographic_to_ascii = str.maketrans({
        '“': '"',
        '”': '"',
        '‘': "'",
        '’': "'",
        '—': '-',
    })
    return text.translate(typographic_to_ascii)

In [None]:
def replace_figures_tables_references(text):
    """
    Remove in-text references to figures and tables.

    Handled forms: "Figure"/"Figures", "Fig."/"Figs."/"Fig", and "Table"/"Tables". Each is
    followed by an optional supplementary "S", a number and an optional panel letter, and
    optionally by more such numbers joined with ",", "and", "-" or "–". Examples: "Figure 1A",
    "Figs. 2 and 3", "Tables S1-S3". Parentheses left empty by the removal, as in
    "(Figure 1)", are dropped together with the whitespace before them. Matching is
    case-sensitive.

    Parameters:
    - text (str): Text to process.

    Returns:
    str: Text with references to figures and tables removed.
    """
    reference_pattern = (
        r'\b(?:Figures?|Figs?\.?|Tables?)\s*S?\d+[A-Za-z]?'
        r'(?:\s*(?:,|and|–|-)\s*S?\d+[A-Za-z]?)*'
    )
    text_without_references = re.sub(reference_pattern, '', text)

    # "Results (Figure 2) hold." would otherwise become "Results () hold."
    text_without_references = re.sub(r'\s*\(\s*\)', '', text_without_references)

    return text_without_references

In [None]:
def remove_bibliography_references(text):
    """
    Remove numeric bibliography citations such as [1], [1,2], [1, 2], [1-3], [4–6] or [1; 5].

    Spaces inside the brackets and both the hyphen and the en dash are accepted for ranges.
    Journal XML commonly uses the en dash, and standardize_special_characters does not convert
    it. Whitespace directly before the citation is removed too, so "shown [1]." becomes
    "shown." rather than "shown .". Brackets with non-numeric content (e.g. "array[i]") are
    kept.

    Parameters:
    - text (str): Text to process.

    Returns:
    str: Text with bibliography references removed.
    """
    number_or_range = r'\d+(?:\s*[-–]\s*\d+)?'
    citation_pattern = rf'\s*\[\s*{number_or_range}(?:\s*[,;]\s*{number_or_range})*\s*\]'
    return re.sub(citation_pattern, '', text)

In [None]:
def normalize_whitespace(text):
    """
    Collapse every run of whitespace (spaces, tabs, line breaks and other Unicode spaces)
    into a single space, and remove whitespace at both ends.

    Parameters:
    - text (str): Text to process.

    Returns:
    str: Text with normalized whitespace.
    """
    # str.split() with no separator splits on any whitespace run and drops empty pieces at
    # the ends, so joining with one space both collapses and trims.
    return ' '.join(text.split())

In [None]:
def _clean_paper_text(text):
    """
    Run one string through the notebook's text-cleaning chain, in this order: strip HTML/XML
    tags, standardize special characters, remove figure/table references, remove bibliography
    references, normalize whitespace.
    """
    text = clean_html_tags(text)
    text = standardize_special_characters(text)
    text = replace_figures_tables_references(text)
    text = remove_bibliography_references(text)
    return normalize_whitespace(text)


def preprocess_json(original_json_path, output_folder):
    """
    Clean one extracted-paper JSON file and save a flattened, training-ready version.

    The input is read through the keys 'Title', 'Abstract' -> 'Detailed Abstract',
    'Sections' -> [{'Content': ...}, ...] and, if present, 'Keywords'. All section contents
    are joined into a single 'Body' string. The body and the detailed abstract both go through
    the same cleaning chain. The result is written to `output_folder` under the same file name.

    Parameters:
    - original_json_path: Path to the original JSON file.
    - output_folder: Folder where the processed JSON should be saved (created if missing).

    Returns:
    str or None: Path of the written file, or None if the file could not be read or
    processed (the error is logged).
    """
    try:
        # exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(output_folder, exist_ok=True)

        with open(original_json_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except Exception as e:
        logging.error(f"Error reading {original_json_path}: {e}")
        return None

    try:
        # Sections are joined with blank lines; the whitespace normalization in the cleaning
        # chain then collapses those separators to single spaces.
        combined_sections = "\n\n".join(
            section.get('Content') or '' for section in data['Sections']
        )

        processed_data = {
            'Title': data['Title'],
            'Abstract': _clean_paper_text(data['Abstract']['Detailed Abstract']),
            'Body': _clean_paper_text(combined_sections),
        }

        # An empty abstract means an empty training target; keep the file but flag it
        if not processed_data['Abstract']:
            logging.warning(f"Empty abstract after preprocessing: {original_json_path}")

        if 'Keywords' in data:
            processed_data['Keywords'] = data['Keywords']

        output_path = os.path.join(output_folder, os.path.basename(original_json_path))
        with open(output_path, 'w', encoding='utf-8') as outfile:
            json.dump(processed_data, outfile, ensure_ascii=False, indent=4)

        logging.info(f"Processed file saved to {output_path}")
    except Exception as e:
        logging.error(f"Error processing {original_json_path}: {e}")
        return None

    return output_path

In [None]:
json_directory = "./data/json/"
json_processed_output_directory = "./data/json_processed"

# Every *.json file directly inside json_directory gets cleaned into the output folder
json_files = list(filter(lambda file_name: file_name.endswith('.json'), os.listdir(json_directory)))

for json_file in json_files:
    preprocess_json(os.path.join(json_directory, json_file), json_processed_output_directory)

## Preprocessing (from cleaned JSON to Model Input)
For each paper in the dataset, we tokenize both the body (the model input) and the abstract (the target summary) with the BART tokenizer, converting the text into model-compatible input IDs. BART accepts at most 1,024 input tokens, so longer texts must be truncated, and shorter ones padded, to a fixed length.

In [None]:
def tokenize_for_bart(text, tokenizer, max_length=1024):
    """
    Encode `text` into fixed-length BART inputs.

    The text is truncated to `max_length` tokens when longer and padded up to exactly
    `max_length` when shorter. The encoding is returned as PyTorch tensors.

    Parameters:
    text (str): Text to encode.
    tokenizer: Hugging Face tokenizer providing encode_plus().
    max_length (int): Target sequence length in tokens.

    Returns:
    The tokenizer's encoding (input_ids, attention_mask, ...).
    """
    encoding_options = dict(
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    return tokenizer.encode_plus(text, **encoding_options)

The following custom dataset class takes a list of file paths (data_files), loads the data, and stores body-abstract pairs. When accessed, it returns tokenized versions of these pairs, suitable for training.

In [None]:
class SummarizationDataset(Dataset):
    """
    PyTorch Dataset of (body, abstract) pairs read from preprocessed paper JSON files.

    Each file is read through its 'Body' (model input) and 'Abstract' (target summary) keys.
    All pairs are loaded eagerly in __init__, and tokenization happens lazily in __getitem__.

    Parameters:
    tokenizer: Hugging Face tokenizer (e.g. BartTokenizer) used for both texts.
    data_files (iterable of str): Paths to the JSON files.
    max_length (int): Token length the body is truncated/padded to.
    max_target_length (int): Token length the abstract is truncated/padded to.
    """

    def __init__(self, tokenizer, data_files, max_length=1024, max_target_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_target_length = max_target_length
        # (body, abstract) string pairs, one per file
        self.data = []
        for file_path in data_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                paper = json.load(f)
            self.data.append((paper['Body'], paper['Abstract']))

    def __len__(self):
        return len(self.data)

    def _encode(self, text, length):
        """Tokenize `text`, truncating/padding to exactly `length` tokens, as PyTorch tensors."""
        return self.tokenizer(
            text, max_length=length, truncation=True, padding="max_length", return_tensors="pt"
        )

    def __getitem__(self, idx):
        body_text, abstract_text = self.data[idx]

        body_tokens = self._encode(body_text, self.max_length)
        abstract_tokens = self._encode(abstract_text, self.max_target_length)

        # squeeze(0) drops only the leading batch dimension added by return_tensors="pt".
        # A bare squeeze() would also collapse a length-1 sequence into a 0-d tensor.
        return {
            "input_ids": body_tokens['input_ids'].squeeze(0),
            "attention_mask": body_tokens['attention_mask'].squeeze(0),
            # Labels keep the tokenizer's pad ids; padding is NOT replaced with -100 here.
            # To exclude padding from the loss, mask positions where decoder_attention_mask == 0.
            "labels": abstract_tokens['input_ids'].squeeze(0),
            "decoder_attention_mask": abstract_tokens['attention_mask'].squeeze(0),
        }

Assuming we have a list of JSON file paths we want to use for training (see previous cells), we create an instance of our dataset and then a DataLoader.

In [None]:
model_name = "facebook/bart-large-cnn"
# Pretrained BART summarization checkpoint and its matching tokenizer
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

clean_json_directory = "./data/json_processed/"

# Full paths of every cleaned paper; each file contributes one (Body, Abstract) training pair
data_files = [
    os.path.join(clean_json_directory, file_name)
    for file_name in os.listdir(clean_json_directory)
    if file_name.endswith('.json')
]

# Tokenization happens lazily per item; batches of 2 in random order
dataset = SummarizationDataset(tokenizer, data_files)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# AdamW over all model parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [None]:
# Training loop
# Train on the GPU when one is available; nothing earlier moves the model off the CPU.
# Module.to() moves parameters in place, so the optimizer created earlier still tracks them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()
num_epochs = 3

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in tqdm(dataloader, desc=f"Training Epoch {epoch + 1}"):
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Set padded label positions to -100, the index the model's loss ignores.
        # Without this, the loss also trains the model to predict pad tokens.
        labels = batch["labels"].masked_fill(batch["decoder_attention_mask"] == 0, -100).to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the dataset is empty
    print(f"Epoch {epoch + 1} completed. Average loss: {epoch_loss / max(len(dataloader), 1)}")

In [None]:
# Persist the fine-tuned weights and the tokenizer side by side, so ./model/ can be reloaded on its own
save_directory = './model/'
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)

## Evaluation

### Preparing the Test Set

In [None]:
xml_test_directory = "./data/test/xml/"
output_test_directory = "./data/test/json"

# The test XML files live on disk, not in MongoDB. The Mongo-based
# extract_information_and_write_to_json_file(mongo_id, json_raw_collection, xml_collection)
# cannot process them: calling it with (path, directory) raised a TypeError.
# Parse each file directly and write the extracted dict as <name>.json, the layout that the
# test preprocessing step reads.
os.makedirs(output_test_directory, exist_ok=True)

# List all the XML files in the directory
xml_test_files = [f for f in os.listdir(xml_test_directory) if f.endswith('.xml')]

for xml_test_file in xml_test_files:
    xml_path = os.path.join(xml_test_directory, xml_test_file)
    # Read bytes so the parser honours the encoding declared in the XML prolog
    with open(xml_path, 'rb') as xml_file:
        extracted_info = extract_information_from_xml(xml_file.read())
    if extracted_info is None:
        logging.warning(f"Skipping {xml_path}: information could not be extracted")
        continue

    json_path = os.path.join(output_test_directory, os.path.splitext(xml_test_file)[0] + '.json')
    with open(json_path, 'w', encoding='utf-8') as json_file:
        json.dump(extracted_info, json_file, ensure_ascii=False, indent=4)
    logging.info(f"Extracted {xml_path} -> {json_path}")

In [None]:
json_test_directory = "./data/test/json/"
json_test_processed_output_directory = "./data/test/json_processed"

# Every *.json test file directly inside json_test_directory gets cleaned into the output folder
json_test_files = list(filter(lambda file_name: file_name.endswith('.json'), os.listdir(json_test_directory)))

for json_test_file in json_test_files:
    preprocess_json(os.path.join(json_test_directory, json_test_file), json_test_processed_output_directory)

We reuse the `SummarizationDataset` class to process the test set, with a few differences. The labels (tokenized abstracts) are not used for training here; they are decoded back into reference summaries for the ROUGE evaluation. The data is not shuffled, so outputs stay in file order. The model is put in evaluation mode.

In [None]:
# Function to process the test set using the SummarizationDataset class
def process_test_set(tokenizer, test_files, batch_size=2, max_length=1024, max_target_length=128):
    """
    Build an unshuffled DataLoader over the preprocessed test papers.

    Parameters:
    tokenizer: Hugging Face tokenizer used to encode bodies and abstracts.
    test_files (list of str): Paths to preprocessed JSON files.
    batch_size (int): Papers per batch.
    max_length (int): Token length the body is truncated/padded to (defaults to BART's 1024).
    max_target_length (int): Token length the reference abstract is truncated/padded to.

    Returns:
    DataLoader: Batches in the same order as `test_files`.
    """
    test_dataset = SummarizationDataset(
        tokenizer, test_files, max_length=max_length, max_target_length=max_target_length
    )
    # shuffle=False keeps outputs aligned with the order of test_files
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Collect the full paths of the preprocessed test papers
test_json_directory = "./data/test/json_processed/"
test_data_files = [
    os.path.join(test_json_directory, file_name)
    for file_name in os.listdir(test_json_directory)
    if file_name.endswith('.json')
]

# Tokenize them into an unshuffled DataLoader
test_dataloader = process_test_set(tokenizer, test_data_files)

# Switch the model to inference behaviour
model.eval()

### Generate Summaries

In [None]:
# Reload the fine-tuned weights that the training section saved to ./model/
model = BartForConditionalGeneration.from_pretrained(pretrained_model_name_or_path='./model/')

In [None]:
# Function to generate summaries for the test set with progress logging
def generate_summaries_with_logging(model, dataloader, device, text_tokenizer=None):
    """
    Generate a summary for every paper in `dataloader` and decode its reference abstract.

    Parameters:
    model: Seq2seq model with .generate() (e.g. BartForConditionalGeneration). It is moved
        to `device` and switched to eval mode.
    dataloader: Yields dicts with 'input_ids', 'attention_mask' and 'labels' tensors.
    device (torch.device): Device used for generation.
    text_tokenizer: Tokenizer used to decode ids back to text. When omitted, the
        notebook-level `tokenizer` is used, as this function always did before.

    Returns:
    tuple[list[str], list[str]]: (generated summaries, reference abstracts), in dataloader order.
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    model.to(device)
    model.eval()
    summaries = []
    actuals = []
    progress_bar = tqdm(dataloader, desc='Generating Summaries')

    with torch.no_grad():
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
            summaries.extend(
                text_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for g in outputs
            )

            # Labels are the tokenized target abstracts. They are only decoded, so they stay
            # on the CPU. Any -100 (the loss's ignore index) is mapped back to the pad id
            # first, because decoding negative ids fails.
            labels = batch['labels']
            labels = labels.masked_fill(labels == -100, text_tokenizer.pad_token_id)
            actuals.extend(
                text_tokenizer.decode(a, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for a in labels
            )

            progress_bar.set_description(f"Processed {len(summaries)} summaries")

    return summaries, actuals

In [None]:
# Prefer the GPU when one is present, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Generate a summary for every test paper together with its decoded reference abstract
generated_summaries, actual_abstracts = generate_summaries_with_logging(model, test_dataloader, device)

### Calculate ROUGE Scores

In [None]:
from rouge_score import rouge_scorer, scoring

# Function to calculate ROUGE scores
def calculate_rouge_scores(references, predictions):
    """
    Compute, print and return corpus-level ROUGE-1, ROUGE-2 and ROUGE-L scores.

    Scores are aggregated with rouge_score's BootstrapAggregator. The printed values are the
    'mid' estimates.

    Parameters:
    references (iterable of str): Reference (ground-truth) summaries.
    predictions (iterable of str): Generated summaries, in the same order as `references`.

    Returns:
    dict: The aggregator result keyed by ROUGE type ('rouge1', 'rouge2', 'rougeL'), or an
    empty dict when there is nothing to score.

    Raises:
    ValueError: If `references` and `predictions` differ in length. zip() would otherwise
        silently drop the unmatched pairs.
    """
    references = list(references)
    predictions = list(predictions)
    if len(references) != len(predictions):
        raise ValueError(
            f"Got {len(references)} references but {len(predictions)} predictions"
        )
    if not references:
        print("No summaries to score.")
        return {}

    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    aggregator = scoring.BootstrapAggregator()

    for reference, prediction in zip(references, predictions):
        # RougeScorer.score takes (target, prediction)
        aggregator.add_scores(scorer.score(reference, prediction))

    result = aggregator.aggregate()

    for key, aggregate in result.items():
        print(f"{key}: F1: {aggregate.mid.fmeasure:.3f}, Precision: {aggregate.mid.precision:.3f}, Recall: {aggregate.mid.recall:.3f}")

    return result

In [None]:
calculate_rouge_scores(references=actual_abstracts, predictions=generated_summaries)

### Qualitative Testing

In [None]:
import numpy as np

def process_json_file(file_path, tokenizer):
    """
    Load one preprocessed paper and tokenize its 'Body' text for the summarizer.

    The body is truncated/padded to exactly 1024 tokens and returned as PyTorch tensors.

    Parameters:
    file_path (str): Path to a preprocessed paper JSON file.
    tokenizer: Hugging Face tokenizer (callable).

    Returns:
    The tokenizer's encoding of the body.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        body_text = json.load(file)['Body']
    return tokenizer(
        body_text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
        padding="max_length",
    )

# Full paths of every preprocessed test paper
test_json_directory = "./data/test/json_processed/"
test_data_files = [
    os.path.join(test_json_directory, file_name)
    for file_name in filter(lambda candidate: candidate.endswith('.json'), os.listdir(test_json_directory))
]

# Pick one paper uniformly at random and tokenize its body
random_paper_index = np.random.randint(len(test_data_files))
random_paper_path = test_data_files[random_paper_index]
tokenized_inputs = process_json_file(random_paper_path, tokenizer)

In [None]:
# Reload the fine-tuned weights saved in ./model/
model = BartForConditionalGeneration.from_pretrained(pretrained_model_name_or_path='./model/')

def generate_summary(model, tokenizer, tokenized_inputs):
    """
    Generate and decode a summary for one tokenized paper.

    The model is switched to eval mode, the inputs are moved to the model's device, and
    beam search runs without gradient tracking (4 beams, 40-150 tokens, length penalty 2.0).

    Parameters:
    model: Seq2seq model with .generate() and .device.
    tokenizer: Tokenizer used to decode the generated ids.
    tokenized_inputs: Mapping of input names to tensors (e.g. from process_json_file).

    Returns:
    str: The first generated sequence, decoded without special tokens.
    """
    model.eval()
    target_device = model.device
    inputs_on_device = {}
    for name, tensor in tokenized_inputs.items():
        inputs_on_device[name] = tensor.to(target_device)
    with torch.no_grad():
        summary_ids = model.generate(
            **inputs_on_device,
            max_length=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
    first_summary_ids = summary_ids[0]
    return tokenizer.decode(first_summary_ids, skip_special_tokens=True)

# Summarize the randomly selected test paper
generated_summary = generate_summary(model, tokenizer, tokenized_inputs)

# Read the same paper's reference abstract for a side-by-side comparison
with open(random_paper_path, 'r', encoding='utf-8') as file:
    actual_paper = json.load(file)
actual_abstract = actual_paper['Abstract']

# Print the model output first, then the ground truth
print("Generated Summary:\n", generated_summary)
print("\nActual Abstract:\n", actual_abstract)