In [1]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM

#checkpoint = "facebook/bart-base"
#checkpoint = "facebook/bart-large"
checkpoint = "GanjinZero/biobart-base"
#checkpoint = "GanjinZero/biobart-large"
#checkpoint = "GanjinZero/biobart-v2-base"
#checkpoint = "GanjinZero/biobart-v2-large"
#dataset_config = "mimic-iii"
dataset_config = "mimic-cxr"

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [2]:
from pathlib import Path
import datasets
#dataset_config = 'mimic-cxr','mimic-iii'  
#split = 'train','validate',test

def build_dataset(dataset_config, tokenizer, split):
    data_path = '/nfs/turbo/umms-vgvinodv/data/bioNLP23-Task-1B/data/'
    findings_file_path = Path(data_path).joinpath(dataset_config).joinpath(split+'.findings.tok')
    impression_file_path = Path(data_path).joinpath(dataset_config).joinpath(split+'.impression.tok')

    findings = [line.strip() for line in open(findings_file_path).readlines()]
    impression = [line.strip() for line in open(impression_file_path).readlines()]

    dataset = datasets.Dataset.from_dict({"text":findings,"summary":impression}) 
    
    
    def preprocess_function(samples):
        texts = samples["text"]
        summaries = samples["summary"]
        prompt = " The main impression based on the given FINDINGS section of the chest X-ray report are:"
        #prompt = "summarize: "

        inputs = [_text+prompt for _text in texts]
        model_inputs = tokenizer(inputs)
        
        labels = tokenizer(text_target=summaries, max_length=1024, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    
    dataset = dataset.map(preprocess_function, batched=True, num_proc=4, remove_columns=list(dataset.features))

    return dataset

In [3]:
tokenized_train_data = build_dataset(dataset_config,tokenizer,"train")
tokenized_eval_data = build_dataset(dataset_config,tokenizer,"test")

Map (num_proc=4):   0%|          | 0/125417 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/1624 [00:00<?, ? examples/s]

In [4]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [5]:
import evaluate
from radgraph import F1RadGraph
from f1chexbert import F1CheXbert

rouge = evaluate.load("rouge")
f1radgraph = F1RadGraph(reward_level="partial")
f1chexbert = F1CheXbert(device="cuda")

In [6]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result["F1RadGraph"] = f1radgraph(hyps=decoded_preds, refs=decoded_labels)[0]
    
    class_report_5 = f1chexbert(hyps=decoded_preds,refs=decoded_labels)[-1]
    result["F1CheXbert"] = class_report_5["micro avg"]["f1-score"]

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [7]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

model_name = checkpoint.split("/")[-1]
batch_size = 16
num_train_epochs = 20 #5
save_path: str="/nfs/turbo/umms-vgvinodv/models/finetuned-checkpoints/radsum"
save_path = f"{save_path}/{model_name}-{dataset_config}"

training_args = Seq2SeqTrainingArguments(
    output_dir=save_path,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    overwrite_output_dir = True,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    fp16=True,
    #push_to_hub=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_eval_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,F1radgraph,F1chexbert,Gen Len
1,1.0371,1.150739,0.4101,0.2774,0.3857,0.3862,0.3552,0.5605,15.9828
2,0.9415,1.100073,0.4335,0.2965,0.4078,0.4076,0.377,0.5822,16.197
3,0.8851,1.066168,0.4355,0.2976,0.4082,0.4082,0.3737,0.5921,15.8319
4,0.8367,1.049886,0.4407,0.3014,0.4122,0.4125,0.3826,0.5901,16.7494
5,0.7945,1.049267,0.4371,0.2977,0.4121,0.412,0.3806,0.5887,15.8947
6,0.7663,1.049064,0.434,0.2961,0.4071,0.4072,0.3766,0.577,15.9748
7,0.7271,1.049103,0.439,0.2979,0.4117,0.4114,0.3835,0.5837,16.5012
8,0.7039,1.054283,0.4478,0.3095,0.4198,0.4192,0.3923,0.5954,16.5246
9,0.6658,1.066923,0.4465,0.3052,0.4167,0.4157,0.3868,0.5924,16.6293
10,0.6482,1.069488,0.4489,0.3114,0.4217,0.4214,0.3909,0.5923,16.6268


TrainOutput(global_step=156780, training_loss=0.6754704514941569, metrics={'train_runtime': 19268.9263, 'train_samples_per_second': 130.175, 'train_steps_per_second': 8.136, 'total_flos': 2.4242863811152896e+17, 'train_loss': 0.6754704514941569, 'epoch': 20.0})

## Evaluate on Hidden Test Set

In [8]:
hidden_test_data = build_dataset(dataset_config,tokenizer,"test.hidden")
trainer.evaluate(hidden_test_data)

Map (num_proc=4):   0%|          | 0/1000 [00:00<?, ? examples/s]

{'eval_loss': 3.865882396697998,
 'eval_rouge1': 0.279,
 'eval_rouge2': 0.1552,
 'eval_rougeL': 0.2517,
 'eval_rougeLsum': 0.2515,
 'eval_F1RadGraph': 0.0968,
 'eval_F1CheXbert': 0.3869,
 'eval_gen_len': 17.875,
 'eval_runtime': 123.0317,
 'eval_samples_per_second': 8.128,
 'eval_steps_per_second': 0.512,
 'epoch': 20.0}