In [1]:
from typing import List
import string
from abc import ABC
import csv
from collections import defaultdict
import statistics
import numpy as np
import pandas as pd
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from rouge_score import rouge_scorer
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          LogitsProcessor, LogitsProcessorList)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NLTK Import
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\vince\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


### Load LogitsProcessor for Beam Search Constrainment

We implement the LogitsProcessor class to get our desired effect. Our custom class should implement the __call__ method of LogitsProcessor.

This method will be called during each step of the beam search algorithm. The method takes as input the input_ids sequence of the partially generated beam and the scores of the next possible tokens.

By manipulating these scores based on the tokens present in the input_ids, we can control the structure of the generated sentence.

In [3]:
class WordValidator(ABC):
    def __init__(self):
        pass

class BannedWords(WordValidator):
    def __init__(self, dictionary):
        self.dictionary = dictionary

    def is_valid_word(self, word, input_idx, beam_sequence, beam_scores):
        return word not in self.dictionary

In [37]:
SPLIT_WORD_TOKENS = {
    ' ',
    '.',
    ',',
    '_',
    '?',
    '!',
}

class ConsistentLogitsProcessor(LogitsProcessor):
    r"""
    [`ConsistentLogitsProcessor`] enforcing constraints from source documentation on logits
    Args:
        min_length (`int`):
            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`int`):
            The id of the *end-of-sequence* token.
    """

    def __init__(self, tokenizer, num_beams, word_validator: WordValidator):
        self.tokenizer = tokenizer
        self.word_validator = word_validator
        self.num_beams = num_beams
        self.excluded_beams_by_input_idx = defaultdict(list)
        self.words_to_check_by_input_idx = defaultdict(lambda: 0)
        self.failed_sequences = set()

    def is_valid_beam(
        self,
        input_idx, # input idx being processed
        sequence,  # sequence generated so far
        token_id,  # next token to be generated
        beam_scores,  # probability of all tokens to be generated
    ):
        """
        Check whether the beam is valid. This method backtracks to confirm
        words are valid when it detects the predicted suword (token) is a
        word ending
        """
        
        current_subword = self.tokenizer.decode(token_id)
        backtrack_word = ""
        is_subword_ending = False
        for char in current_subword:
            if char in SPLIT_WORD_TOKENS:
                is_subword_ending = True
                break
            else:
                backtrack_word += char
        
        backtrack_done = False
        if is_subword_ending:
            prev_subword_idx = len(sequence) - 1
            while prev_subword_idx != 0 and not backtrack_done:
                prev_token_id = sequence[prev_subword_idx]
                prev_subword = self.tokenizer.decode(prev_token_id)
                prev_char_idx = len(prev_subword) - 1
                while prev_char_idx >= 0:
                    prev_char = prev_subword[prev_char_idx]
                    if prev_char not in SPLIT_WORD_TOKENS:
                        backtrack_word = prev_char + backtrack_word
                    else:
                        backtrack_done = True
                        break 
                    prev_char_idx -= 1
                prev_subword_idx -= 1
            self.words_to_check_by_input_idx[input_idx] += 1
            # Call validator to check whether the word is valid
            if not self.word_validator.is_valid_word(
                backtrack_word,
                input_idx, 
                sequence, 
                beam_scores
            ):
                return False
        return True
            
                

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        blocked_beams_by_input_idx = defaultdict(lambda: 0)
        # for every beam (partially generated sentence)
        for beam_idx, (beam_input_ids, beam_scores) in enumerate(
            zip(input_ids, scores)
        ):
            top_k = beam_scores.topk(k=5)
            for prob, idx in zip(top_k[0], top_k[1]):
                input_idx = beam_idx // self.num_beams
                if not self.is_valid_beam(
                    input_idx, beam_input_ids, idx.item(), scores[beam_idx]
                ):
                    scores[beam_idx, :] = -float("inf")
                    self.excluded_beams_by_input_idx[input_idx].append((
                        beam_input_ids,
                        idx.item(),
                        prob.item(),
                    ))
                    blocked_beams_by_input_idx[input_idx] += 1
        
        for input_idx, n_blocked in blocked_beams_by_input_idx.items():
            if n_blocked == self.num_beams:
                self.failed_sequences.add(input_idx)
                    
        return scores

In [5]:
def entropy(p_dist: torch.Tensor) -> float:
    """ "
    Calculates Shannon entropy for a probability distribution
    Args:
        p_dist: probability distribution (torch.Tensor)
    Returns:
        entropy (float)
    """
    # add epsilon because log(0) = nan
    p_dist = p_dist.view(-1) + 1e-12
    return -torch.mul(p_dist, p_dist.log()).sum(0).item()


def generate_summaries_with_constraints(
    model: AutoModelForSeq2SeqLM,
    tokenizer: AutoTokenizer,
    docs_to_summarize: List[str],
    word_validator: WordValidator,
    num_beams: int = 4,
    max_length: int = 150,
    return_beam_metadata: bool = False,
    device: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
):
    inputs = tokenizer(
        docs_to_summarize,
        max_length=1024,
        truncation=True,
        return_tensors="pt",
        padding=True,
    )
    input_token_ids = inputs.input_ids.to(device)
    consistency_forced = ConsistentLogitsProcessor(
        tokenizer,
        num_beams,
        word_validator
    )
    model_output = model.generate(
        input_token_ids,
        num_beams=num_beams,
        early_stopping=True,
        return_dict_in_generate=True,
        output_scores=True,
        # remove_invalid_values=True,
        logits_processor=LogitsProcessorList([consistency_forced]),
        max_length=max_length,
    )

    generated_summaries = [
        (
            tokenizer.decode(
                ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            if (
                consistency_forced is None
                or idx not in consistency_forced.failed_sequences
            )
            else "<Failed generation: blocked all beams>"
        )
        for idx, ids in enumerate(model_output.sequences)
    ]
    
    if not return_beam_metadata:
        return generated_summaries
    
    else:
        # reshape model_output scores to (n_seqs x seq len x n_beams x vocab)
        model_beam_scores = (
            torch.stack(model_output.scores)
            .reshape(len(model_output.scores), len(generated_summaries), num_beams, -1)
            .permute(1, 0, 2, 3)
        )
        # Collect Beam Search Metadata
        beams_metadata = []
        if model_output.beam_indices is not None:
            for seq_idx in range(model_output.sequences.shape[0]):
                top_beam_indices = [x.item() for x in model_output.beam_indices[seq_idx]]
                seq_beams = {
                    "beams": [list() for _ in range(num_beams)],
                    "selected_beam_indices": top_beam_indices,
                    "dropped_seqs": consistency_forced.excluded_beams_by_input_idx[seq_idx],
                    "n_words_checked": consistency_forced.words_to_check_by_input_idx[seq_idx],
                }
                beams_metadata.append(seq_beams)

                for idx, output_token_id in enumerate(model_output.sequences[seq_idx][1:]):
                    for beam_idx in range(num_beams):
                        beam_probs = torch.exp(model_beam_scores[seq_idx][idx][beam_idx])
                        beam_top_alternatives = []
                        top_probs = torch.topk(beam_probs, k=num_beams)
                        for i, v in zip(top_probs.indices, top_probs.values):
                            beam_top_alternatives.append(
                                {
                                    "token": tokenizer.decode(i),
                                    "token_id": i.item(),
                                    "probability": v.item(),
                                }
                            )
                        seq_beams["beams"][beam_idx].append(
                            {
                                "top_tokens": beam_top_alternatives,
                                "entropy": entropy(beam_probs),
                            }
                        )

        return generated_summaries, beams_metadata

### Generate a summary without beam search constraints

In [6]:
def generate_summaries_without_constraints(
    model: AutoModelForSeq2SeqLM,
    tokenizer: AutoTokenizer,
    docs_to_summarize: List[str],
    num_beams: int = 4,
    max_length: int = 150,
    device: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
):
    inputs = tokenizer(docs_to_summarize, max_length=1024, return_tensors='pt', truncation=True,)
    ids = model.generate(inputs['input_ids'].to(device), num_beams=num_beams, max_length=max_length, early_stopping=True)
    summaries = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in ids]
    return summaries

### Load HPI Summarization Models
The first sentence of the hospital course section

In [8]:
df = pd.read_csv('../data/hpi-dataset/summarization-datasets/HPI_TEST.csv')

In [9]:
def load_model_and_tokenizer(path):
    return (
        AutoModelForSeq2SeqLM.from_pretrained(path),
        AutoTokenizer.from_pretrained(path),
    )

In [10]:
# Load the model and tokenizer
model, tokenizer = load_model_and_tokenizer('../models/HPI-BART')

### Create a Medical Dictionary from SNOMED

In [11]:
def to_words(text):
    res = re.findall(r'\b[^\d\W]+\b', text)
    return [x.lower() for x in res]

In [18]:
with open('../data/snomed/medical_dictionary.txt', 'r',encoding="ISO-8859-1") as fd:
    medical_words = []
    for row in fd:
        medical_words.append(row.rstrip())

In [19]:
len(medical_words)

79521

### Set HPI summary we will use

In [21]:
text = df.loc[2,"TEXT"] #14 is a good example
true_summary = df.loc[2,"SUMMARY"] 

In [22]:
text

'CATEGORY: Physician Attending Admission Note - MICU, TYPE: EMERGENCY, ADMIT LOCATION: EMERGENCY ROOM ADMIT, AGE: 81, GENDER: M, MARITAL STATUS: WIDOWED, ETHNICITY: WHITE, ADMIT DX: LOWER GI BLEED, ADMIT TEXT: 81 yr old man with remote hx of colon ca, diverticulitis, presents with 4 days of BRBPR and LLQ pain. This AM home health aide noted blood in stool. In ED afeb 129/83 but HR 100s rectal BRB Abd'

### Set Banned Words as Medical Words Not in Source Text

In [23]:
words_in_text = list(set(to_words(text)))

In [24]:
medical_in_text = [x for x in medical_words if x in words_in_text]

In [25]:
print(len(medical_in_text))

5


In [26]:
medical_in_text

['diverticulitis', 'colon', 'brb', 'brbpr', 'rectal']

In [27]:
# allow synonyms from NLTK wordnet
medical_synonyms_in_text = []
for word in medical_in_text:
  for syn in wordnet.synsets(word):
    for i in syn.lemmas():
      medical_synonyms_in_text.append(i.name())
medical_synonyms_in_text = medical_in_text + list(set(medical_synonyms_in_text))

In [28]:
banned_words = BannedWords(set([x for x in medical_words if x not in medical_synonyms_in_text]))

In [29]:
print(len(banned_words.dictionary))

79516


### Run Summarization with Constraints

In [30]:
model.to(device)
pred_summaries_notconstrained = generate_summaries_without_constraints(
    model,
    tokenizer,
    [text],
    num_beams=6,
    max_length=150
)

In [38]:
model.to(device)
pred_summaries_constrained, metadata_constrained = generate_summaries_with_constraints(
    model,
    tokenizer,
    [text],
    word_validator = banned_words,
    num_beams=6,
    max_length = 150,
    return_beam_metadata=True,
)

In [39]:
true_summary

'Acute blood loss anemia [**1-9**] GIB: The patient was initially admitted to the MICU for a GI bleed.'

In [40]:
pred_summaries_notconstrained

['Acute Blood Loss Anemia due to GI Bleeding due to Duodenal Ulcer - GI Consultation was obtained.Patient was initially admitted to the MICU and given his history of BRBPR and diverticulitis, was taken to the']

In [41]:
pred_summaries_constrained

['Acute Blood Loss Anemia due to GI Bleeding due to Duodenal Ulcers: The patient was initially admitted with blood in the rectal BRBPR and had an upper GI Bleed which was treated with Vaseline and a band anast']

In [42]:
scorer = rouge_scorer.RougeScorer(['rouge1','rouge2','rougeL'], use_stemmer=True)
scores_nc = scorer.score(true_summary, pred_summaries_notconstrained[0])
scores_c = scorer.score(true_summary, pred_summaries_constrained[0])
print("ROUGE NOT CONSTRAINED:")
print('ROUGE1": ',round(scores_nc['rouge1'][1],3))
print('ROUGE2": ',round(scores_nc['rouge2'][1],3))
print('ROUGEL": ',round(scores_nc['rougeL'][1],3))
print("/n")
print("ROUGE CONSTRAINED:")
print('ROUGE1": ',round(scores_c['rouge1'][1],3))
print('ROUGE2": ',round(scores_c['rouge2'][1],3))
print('ROUGEL": ',round(scores_c['rougeL'][1],3))

ROUGE NOT CONSTRAINED:
ROUGE1":  0.737
ROUGE2":  0.556
ROUGEL":  0.579
/n
ROUGE CONSTRAINED:
ROUGE1":  0.737
ROUGE2":  0.444
ROUGEL":  0.632


### RUN EVALUATION

In [43]:
def add_breaks(text):
  return re.sub("(.{50})", "\\1\n", text, 0, re.DOTALL)

In [None]:
from pathlib import Path
Path("../results").mkdir(parents=True, exist_ok=True)

In [None]:
model.to(device)

rouge_scores_nc = {'r1': [], 'r2': [], 'rL': []}
rouge_scores_c = {'r1': [], 'r2': [], 'rL': []}

thrown_out = 0

for id in range(0, len(df)):
    text = df.loc[id,"TEXT"]
    true_summary = df.loc[id,"SUMMARY"]
    words_in_text = list(set(to_words(text)))
    medical_in_text = [x for x in medical_words if x in words_in_text]

    # allow synonyms from NLTK wordnet
    medical_synonyms_in_text = []
    for word in medical_in_text:
      for syn in wordnet.synsets(word):
        for i in syn.lemmas():
          medical_synonyms_in_text.append(i.name())
    medical_synonyms_in_text = medical_in_text + list(set(medical_synonyms_in_text))
    
    banned_words = BannedWords(set([x for x in medical_words if x not in medical_synonyms_in_text]))
    
    pred_summaries_notconstrained = generate_summaries_without_constraints(
        model,
        tokenizer,
        [text],
        num_beams=6,
        max_length=150
    )
    
    pred_summaries_constrained, metadata_constrained = generate_summaries_with_constraints(
        model,
        tokenizer,
        [text],
        word_validator = banned_words,
        num_beams=6,
        max_length = 150,
        return_beam_metadata=True
    )
    
    if id % 10 == 0:
        print(str(id) + " of " + str(len(df)) + " summaries processed.")
    
    if pred_summaries_constrained[0] != "<Failed generation: blocked all beams>":
        scorer = rouge_scorer.RougeScorer(['rouge1','rouge2','rougeL'], use_stemmer=True)
        scores_nc = scorer.score(true_summary, pred_summaries_notconstrained[0])
        scores_c = scorer.score(true_summary, pred_summaries_constrained[0])
        rouge_scores_nc['r1'].append(scores_nc['rouge1'][1])
        rouge_scores_nc['r2'].append(scores_nc['rouge2'][1])
        rouge_scores_nc['rL'].append(scores_nc['rougeL'][1])
        rouge_scores_c['r1'].append(scores_c['rouge1'][1])
        rouge_scores_c['r2'].append(scores_c['rouge2'][1])
        rouge_scores_c['rL'].append(scores_c['rougeL'][1])

        #save summaries
        text_file = 'ADMISSION NOTE:\n'
        text_file += add_breaks(text)
        text_file += '\n\nGOLD TRUTH SUMMARY:\n'
        text_file += add_breaks(true_summary)
        text_file += '\n\nNON-CONSTRAINED SUMMARY:\n'
        text_file += add_breaks(pred_summaries_notconstrained[0])
        text_file += '\n\nCONSTRAINED SUMMARY:\n'
        text_file += add_breaks(pred_summaries_constrained[0])

        file_path = "../results/"
        file_path += "file" + str(id + 1) + ".txt"
        with open(file_path, "w") as f:
          f.write(text_file)

    else:
        thrown_out += 1
print("Summaries thrown out: " + str(thrown_out))

0 of 560 summaries processed.
10 of 560 summaries processed.


In [None]:
print("ROUGE NOT CONSTRAINED:")
print('ROUGE1": ',round(statistics.mean(rouge_scores_nc['r1']),3))
print('ROUGE2": ',round(statistics.mean(rouge_scores_nc['r2']),3))
print('ROUGEL": ',round(statistics.mean(rouge_scores_nc['rL']),3))
print("/n")
print("ROUGE CONSTRAINED:")
print('ROUGE1": ',round(statistics.mean(rouge_scores_c['r1']),3))
print('ROUGE2": ',round(statistics.mean(rouge_scores_c['r2']),3))
print('ROUGEL": ',round(statistics.mean(rouge_scores_c['rL']),3))

In [None]:
# Examples

