In [1]:
import pandas as pd
import numpy as np
import nltk
from datasets import Dataset, DatasetDict, load_metric, load_dataset
from transformers import BartTokenizer, BartForConditionalGeneration, TrainingArguments
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from evaluate import load
import argparse

metric_rouge = load("rouge")
metric_bertscore = load("bertscore")
metric_sari = load("sari")

def compute_metrics(eval_pred):
    predictions, labels, sources = eval_pred

    if isinstance(predictions, tuple):
        predictions = predictions[0]
        print("preds again", predictions)
        
    # Replace -100 in the labels and sources as we can't decode them.
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        
    sources = np.where(sources != -100, sources, tokenizer.pad_token_id)
    decoded_inputs = tokenizer.batch_decode(sources, skip_special_tokens=True)
    
    # Tokenize and clean
    decoded_preds_newln = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_preds_space = [ " ".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_label_newln = ["\n".join(nltk.sent_tokenize(labl.strip())) for labl in decoded_labels]
    decoded_label_space = [ " ".join(nltk.sent_tokenize(labl.strip())) for labl in decoded_labels]
    decoded_input_space = [ " ".join(nltk.sent_tokenize(inpt.strip())) for inpt in decoded_inputs]
    
    result_rouge = metric_rouge.compute(predictions=decoded_preds_newln, references=decoded_label_newln, use_stemmer=True)
    result_berts = metric_bertscore.compute(predictions=decoded_preds_space, references=decoded_label_space, lang="en")
    result_sari  = metric_sari.compute(sources=decoded_input_space, predictions=decoded_preds_space, references=[[i] for i in decoded_label_space])

    # Extract results
    result = result_rouge # {key: value.mid.fmeasure * 100 for key, value in result_rouge.items()}
    result['bert_score'] = np.mean(result_berts['f1'])
    result['sari']       = result_sari['sari']
    prediction_lens      = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"]    = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

# Get dataset from arguments
# parser = argparse.ArgumentParser()
# parser.add_argument("--dataset", required=True)
# parser.add_argument("--lr", required=True)
# parser.add_argument("--epochs", required=True)
# parser.add_argument("--batch_size", required=True)
# parser.add_argument("--checkpoint", default=None, required=False)
# parser.add_argument("--model", required=True)
# parser.add_argument("--predict_only", required=False, default=False, type=bool)
# args = parser.parse_args()
# print(f"Using dataset: {args.dataset}, Args: {args.lr} (lr), {args.epochs} (epochs), {args.batch_size} (batch_size), {args.checkpoint} (checkpoint)")

DATASET_NAME    = 'radiology_indiv' # args.dataset 
dataset         = load_dataset('json', data_files=f'data/{DATASET_NAME}.json', field='train')
dataset['test'] = load_dataset('json', data_files=f'data/{DATASET_NAME}_multiple.json', field='test')['train']


Using custom data configuration default-dfd9fd8626c408cd
Reusing dataset json (/home/lily/lyf6/.cache/huggingface/datasets/json/default-dfd9fd8626c408cd/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-1380f428cc862734
Reusing dataset json (/home/lily/lyf6/.cache/huggingface/datasets/json/default-1380f428cc862734/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)


  0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
dataset['train']

Dataset({
    features: ['input', 'labels', 'vocab', 'report_id'],
    num_rows: 8761
})

In [7]:
# Load in the model and tokenizer, for this we're using BART, which is good at generation tasks
MODEL_NAME = "BART"
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    
def preprocess_function(examples):
    """This function takes a batch of samples, and tokenizes them into IDs for the model
       It does this by adding new arguments to the Dataset dictionary, namely
       - input_ids:      tokenized IDs of the findings
       - attention_mask: mask that tells us which tokens are words and which are padding
       - labels:         tokenized IDs of the impressions
    Args:
        examples (Dataset): {'Findings':[<list of findings texts>],
                             'Impressions':[[<list of impressions texts>] per item]}

    Returns:
        model_inputs (Dataset): {'Findings':      [<list of findings texts>],
                                 'Impressions':   [<list of impressions texts>],
                                 'input_ids':     list of lists with impressions IDs,
                                 'attention_mask':list of lists with impressions IDs masks,
                                 'labels':        list of lists with findings IDs}
    """
    # Tokenize the Findings (the input)
    input_str = examples["input"]
    model_inputs = tokenizer(input_str, max_length=512, padding=True, truncation=True)
    # Tokenize the Impressions (the output)
    labels = tokenizer([lst[0] for lst in examples["labels"]], max_length=512, padding=True, truncation=True)
    # Set the label as the token ids (i.e. the vocab IDs) of the findings
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# We apply the function to all the examples in our train and test datasets
dataset['train'] = dataset['train'].map(preprocess_function, batched=True)
dataset['test']  = dataset['test'].map(preprocess_function, batched=True)

# Remove the original columns
dataset['train'] = dataset['train'].remove_columns(["input"])
dataset['test']  = dataset['test'].remove_columns(["input"])

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['labels', 'vocab', 'report_id', 'input_ids', 'attention_mask'],
        num_rows: 8761
    })
    test: Dataset({
        features: ['labels', 'vocab', 'report_id', 'input_ids', 'attention_mask'],
        num_rows: 1916
    })
})

In [9]:
# Write out the arguments
MODEL_OUT_NAME = f"{MODEL_NAME}_{DATASET_NAME}"

training_args = Seq2SeqTrainingArguments(
    f"models/{MODEL_OUT_NAME}",
    evaluation_strategy = "epoch",
    learning_rate=7e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    predict_with_generate=True,
    fp16=True,
    include_inputs_for_metrics=True
)

data_collator = DataCollatorForSeq2Seq(tokenizer)

# Create the Trainer and train
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)


In [None]:
# Use the model to generate outputs
test_output = trainer.predict(dataset['test'][:20])


In [21]:
x = trainer.predict(dataset['test'].select(range(10)))

In [22]:
y = trainer.predict(dataset['test'].remove_columns(['vocab','report_id']).select(range(10)))

In [31]:
tokenizer.batch_decode(dataset['train'][:10]['input_ids'])

['<s>There are bilateral pulmonary nodules measuring up to 8 mm, for example in the right middle lobe, stable compared to prior exam but no additional exam before 2021 to document stability.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<s>Status post median sternotomy and CABG.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<s>There is skin thickening with fat stranding along the right anterior chest without a discrete fluid collection to suggest an abscess.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

In [25]:
tokenizer.batch_decode(x.predictions)

['</s><s>Significantly improved pulmonary metastatic disease.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '</s><s>An 11 mm nodule at the right lung base previously measured 15 mm.</s><pad><pad>',
 '</s><s>Calcified left and right lymph nodes are noted.</s><pad><pad><pad><pad><pad><pad>',
 '</s><s>Prior thyroidectomy.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '</s><s>Diffuse mild bronchiectasis, areas of mucus impaction, and bron</s>',
 '</s><s>Severe aortic valvular calcifications.</s><pad><pad><pad><pad><pad><pad>',
 '</s><s>Severe mitral annular calcifications.</s><pad><pad><pad><pad><pad><pad><pad><pad>',
 '</s><s>Lungs/Airways/Pleura: Stable scattered pulmonary nodules</s>',
 '</s><s>Previously seen nodules are stable in size (example: Image 195, series 4).</s>',
 '</s><s>Mediastinum/Lymph nodes: Heterogeneous and enlarged thyroid gland</s>']

In [26]:
tokenizer.batch_decode(y.predictions)

['</s><s>Significantly improved pulmonary metastatic disease.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '</s><s>An 11 mm nodule at the right lung base previously measured 15 mm.</s><pad><pad>',
 '</s><s>Calcified left and right lymph nodes are noted.</s><pad><pad><pad><pad><pad><pad>',
 '</s><s>Prior thyroidectomy.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '</s><s>Diffuse mild bronchiectasis, areas of mucus impaction, and bron</s>',
 '</s><s>Severe aortic valvular calcifications.</s><pad><pad><pad><pad><pad><pad>',
 '</s><s>Severe mitral annular calcifications.</s><pad><pad><pad><pad><pad><pad><pad><pad>',
 '</s><s>Lungs/Airways/Pleura: Stable scattered pulmonary nodules</s>',
 '</s><s>Previously seen nodules are stable in size (example: Image 195, series 4).</s>',
 '</s><s>Mediastinum/Lymph nodes: Heterogeneous and enlarged thyroid gland</s>']

In [None]:
test_output = tokenizer.batch_decode(test_output.predictions)
test_output = list(map(lambda s: s.replace('<s>','').replace('</s>','').replace('<pad>',''), test_output))
