<a href="https://colab.research.google.com/github/oya163/bert-llm/blob/master/medical_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Question Answering based on Medical Transcription using BioBERT Model

In [1]:
!python3 -m pip install -U pip
!python3 -m pip install -U huggingface_hub
!python3 -m pip install -U accelerate
!python3 -m pip install -U transformers
!python3 -m pip install -U datasets evaluate

Collecting pip
  Downloading pip-23.3.1-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-23.3.1
Collecting accelerate
  Downloading accelerate-0.24.1-py3-none-any.whl.metadata (18 kB)
Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
Installing collected packages: accelerate
Successfully installed accelerate-0.24.1
Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl.metadata (20 kB)
Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-non

In [2]:
# Wrap the text in ipython notebook
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

# Data Preprocessing

## Load our BioQA SQuAD dataset

Our BioQA SQuAD dataset was prepared and annotated with the [cdQA-annotator](https://github.com/cdqa-suite/cdQA-annotator) SQuAD annotation tool, using text taken from [Medical Transcription samples](https://www.mtsamples.com/). A [custom data loading script](https://huggingface.co/datasets/lhoestq/custom_squad/raw/main/custom_squad.py) is used to load this custom SQuAD-style dataset.
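
For orientation, the sketch below flattens a SQuAD v1.1-style nested JSON record into one row per question, with the same fields that the loaded dataset exposes (`id`, `title`, `context`, `question`, `answers`). It assumes the annotated files follow the standard SQuAD layout; `flatten_squad` and the toy record are hypothetical and are not part of the loading script.

```python
def flatten_squad(squad_dict):
    """Flatten nested SQuAD v1.1-style JSON into one row per question."""
    rows = []
    for article in squad_dict["data"]:
        title = article.get("title", "")
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                rows.append(
                    {
                        "id": qa["id"],
                        "title": title,
                        "context": context,
                        "question": qa["question"],
                        "answers": {
                            "text": [a["text"] for a in qa["answers"]],
                            "answer_start": [a["answer_start"] for a in qa["answers"]],
                        },
                    }
                )
    return rows


# Hypothetical toy record, for illustration only
toy_context = "The patient is a 56-year-old female."
toy_squad = {
    "data": [
        {
            "title": "toy",
            "paragraphs": [
                {
                    "context": toy_context,
                    "qas": [
                        {
                            "id": "toy-0",
                            "question": "How old is the patient?",
                            "answers": [
                                {"text": "56", "answer_start": toy_context.index("56")}
                            ],
                        }
                    ],
                }
            ],
        }
    ]
}

toy_rows = flatten_squad(toy_squad)
print(toy_rows[0]["answers"])  # {'text': ['56'], 'answer_start': [17]}
```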

Clone the repo to download the required dataset

In [3]:
!git clone https://github.com/oya163/bert-llm.git

Cloning into 'bert-llm'...
remote: Enumerating objects: 32, done.
remote: Counting objects: 100% (32/32), done.
remote: Compressing objects: 100% (27/27), done.
remote: Total 32 (delta 7), reused 21 (delta 2), pack-reused 0
Receiving objects: 100% (32/32), 5.05 MiB | 6.40 MiB/s, done.
Resolving deltas: 100% (7/7), done.


In [4]:
import os
from datasets import load_dataset

data_dir = "/content/bert-llm/MedicalQA"
raw_datasets = load_dataset(
    os.path.join(data_dir, "custom_squad.py"),
    data_files={
        "train": os.path.join(data_dir, "train.json"),
        "validation": os.path.join(data_dir, "val.json"),
    },
)

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

In [5]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 16
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 8
    })
})

Verify context, question and answer

In [6]:
print("Context: ", raw_datasets["train"][0]["context"])
print("Question: ", raw_datasets["train"][0]["question"])
print("Answer: ", raw_datasets["train"][0]["answers"])

Context:  SUBJECTIVE:, This 23-year-old white female presents with complaint of allergies. She used to have allergies when she lived in Seattle but she thinks they are worse here. In the past, she has tried Claritin, and Zyrtec. Both worked for short time but then seemed to lose effectiveness. She has used Allegra also. She used that last summer and she began using it again two weeks ago. It does not appear to be working very well. She has used over-the-counter sprays but no prescription nasal sprays. She does have asthma but doest not require daily medication for this and does not think it is flaring up.,MEDICATIONS: , Her only medication currently is Ortho Tri-Cyclen and the Allegra.,ALLERGIES: , She has no known medicine allergies.,OBJECTIVE:,Vitals: Weight was 130 pounds and blood pressure 124/78.,HEENT: Her throat was mildly erythematous without exudate. Nasal mucosa was erythematous and swollen. Only clear drainage was seen. TMs were clear.,Neck: Supple without adenopathy.,Lungs:

Training uses exactly one answer per question, so check whether any training example has more or fewer than one answer:

In [7]:
raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)

Filter:   0%|          | 0/16 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 0
})

In [8]:
print(raw_datasets["validation"][0]["answers"])
print(raw_datasets["validation"][2]["answers"])

{'text': ['56'], 'answer_start': [2566]}
{'text': ['gastric bypass surgery'], 'answer_start': [2725]}
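
Each `answer_start` is a character offset into `context`, so a quick sanity check is that slicing the context at that offset reproduces the answer text. The helper below is a minimal sketch of such a check; `misaligned_answers` and the toy examples are hypothetical and only assume the `context`/`answers` layout shown above.

```python
def misaligned_answers(examples):
    """Return the ids of examples whose answer text is not found at answer_start."""
    bad_ids = []
    for ex in examples:
        texts = ex["answers"]["text"]
        starts = ex["answers"]["answer_start"]
        for text, start in zip(texts, starts):
            if ex["context"][start : start + len(text)] != text:
                bad_ids.append(ex["id"])
                break
    return bad_ids


# Hypothetical toy examples, for illustration only
toy_examples = [
    {
        "id": "ok",
        "context": "The patient is 56 years old.",
        "answers": {"text": ["56"], "answer_start": [15]},
    },
    {
        "id": "shifted",
        "context": "The patient is 56 years old.",
        "answers": {"text": ["56"], "answer_start": [14]},
    },
]

toy_bad = misaligned_answers(toy_examples)
print(toy_bad)  # ['shifted']
```

In the notebook, the same function could be applied to `raw_datasets["train"]` and `raw_datasets["validation"]`, since iterating over a `Dataset` yields one dictionary per example.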


In [9]:
print(raw_datasets["validation"][2]["context"])
print(raw_datasets["validation"][2]["question"])

PAST MEDICAL HISTORY:, Significant for hypertension. The patient takes hydrochlorothiazide for this. She also suffers from high cholesterol and takes Crestor. She also has dry eyes and uses Restasis for this. She denies liver disease, kidney disease, cirrhosis, hepatitis, diabetes mellitus, thyroid disease, bleeding disorders, prior DVT, HIV and gout. She also denies cardiac disease and prior history of cancer.,PAST SURGICAL HISTORY: , Significant for tubal ligation in 1993. She had a hysterectomy done in 2000 and a gallbladder resection done in 2002.,MEDICATIONS: , Crestor 20 mg p.o. daily, hydrochlorothiazide 20 mg p.o. daily, Veramist spray 27.5 mcg daily, Restasis twice a day and ibuprofen two to three times a day.,ALLERGIES TO MEDICATIONS: , Bactrim which causes a rash. The patient denies latex allergy.,SOCIAL HISTORY: , The patient is a life long nonsmoker. She only drinks socially one to two drinks a month. She is employed as a manager at the New York department of taxation. She

Use a tokenizer to convert the input text into token IDs that the model can process

In [10]:
from transformers import AutoTokenizer

model_checkpoint = "dmis-lab/biobert-large-cased-v1.1-squad"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

config.json:   0%|          | 0.00/631 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/467k [00:00<?, ?B/s]

The tokenizer inserts special tokens to build an input of this form

    [CLS] question [SEP] context [SEP]

In [11]:
context = raw_datasets["train"][0]["context"]
question = raw_datasets["train"][0]["question"]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS] how old is the patient? [SEP] subjective :, this 23 - year - old white female presents with complaint of allergies. she used to have allergies when she lived in seattle but she thinks they are worse here. in the past, she has tried claritin, and zyrtec. both worked for short time but then seemed to lose effectiveness. she has used allegra also. she used that last summer and she began using it again two weeks ago. it does not appear to be working very well. she has used over - the - counter sprays but no prescription nasal sprays. she does have asthma but doest not require daily medication for this and does not think it is flaring up., medications :, her only medication currently is ortho tri - cyclen and the allegra., allergies :, she has no known medicine allergies., objective :, vitals : weight was 130 pounds and blood pressure 124 / 78., heent : her throat was mildly erythematous without exudate. nasal mucosa was erythematous and swollen. only clear drainage was seen. tms wer

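Some contexts are longer than the model's maximum input length. Such contexts are split into several overlapping chunks: only the context is truncated (`truncation="only_second"`), never the question, and `stride` sets how many tokens consecutive chunks share, as shown below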

In [12]:
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
    print(tokenizer.decode(ids))


[CLS] how old is the patient? [SEP] subjective :, this 23 - year - old white female presents with complaint of allergies. she used to have allergies when she lived in seattle but she thinks they are worse here. in the past, she has tried claritin, and zyrtec. both worked for short time but then seemed to lose effectiveness. she has used allegra also. she used that last summer and she began using it again two weeks ago. [SEP]
[CLS] how old is the patient? [SEP] the past, she has tried claritin, and zyrtec. both worked for short time but then seemed to lose effectiveness. she has used allegra also. she used that last summer and she began using it again two weeks ago. it does not appear to be working very well. she has used over - the - counter sprays but no prescription nasal sprays. she does have asthma but doest not require daily medication for [SEP]
[CLS] how old is the patient? [SEP] she began using it again two weeks ago. it does not appear to be working very well. she has used over

In [13]:
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'overflow_to_sample_mapping'])

In [14]:
print(inputs['overflow_to_sample_mapping'])

[0, 0, 0, 0, 0, 0, 0]
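
The seven entries above mean that the single example was split into seven features. The sketch below shows, in plain Python, roughly how such overlapping chunks can be produced. `split_with_overlap` is a hypothetical helper, not the tokenizer's implementation, and it ignores the question and special tokens that also count toward `max_length`.

```python
def split_with_overlap(tokens, window, overlap):
    """Split `tokens` into chunks of at most `window` items.

    Consecutive chunks share `overlap` items, which plays the role of the
    tokenizer's `stride` argument in this sketch.
    """
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    step = window - overlap
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start : start + window])
        # Stop once the current chunk reaches the end of the sequence
        if start + window >= len(tokens):
            break
        start += step
    return chunks


toy_tokens = list(range(10))
toy_chunks = split_with_overlap(toy_tokens, window=4, overlap=2)
for chunk in toy_chunks:
    print(chunk)
# [0, 1, 2, 3]
# [2, 3, 4, 5]
# [4, 5, 6, 7]
# [6, 7, 8, 9]
```

The overlap makes it more likely that an answer cut by one chunk boundary appears whole in a neighbouring chunk.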


In [15]:
inputs = tokenizer(
    raw_datasets["train"][2:6]["question"],
    raw_datasets["train"][2:6]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(f"The 4 examples gave {len(inputs['input_ids'])} features.")
print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

The 4 examples gave 67 features.
Here is where each comes from: [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3].


To label each feature, we first find the token indices where the context starts and ends, then look at the corresponding offsets, which are tuples of two integers representing each token's span of characters inside the original context. We can thus detect whether the chunk of context in this feature starts after the answer begins or ends before the answer ends, in which case the label is (0, 0). If that is not the case, we loop to find the first and last token of the answer:

In [16]:
answers = raw_datasets["train"][2:6]["answers"]
start_positions = []
end_positions = []

for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

print(start_positions, end_positions)

[25, 0, 0, 0, 0, 0, 0, 0, 0, 92, 53, 14, 0, 0, 0, 0, 28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [27, 0, 0, 0, 0, 0, 0, 0, 0, 94, 55, 16, 0, 0, 0, 0, 28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [17]:
start_positions[0], end_positions[0]

(25, 27)

In [18]:
tokenizer.decode(inputs["input_ids"][0][25 : 27 + 1])

'allergies'

Verify that the labeled span decodes to the expected answer

In [19]:
idx = 0
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start : end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

Theoretical answer: allergies, labels give: allergies


In [20]:
idx = 4
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

decoded_example = tokenizer.decode(inputs["input_ids"][idx])
print(f"Theoretical answer: {answer}, decoded example: {decoded_example}")

Theoretical answer: allergies, decoded example: [CLS] what is the reason for this consultation? [SEP] cyclen and the allegra., allergies :, she has no known medicine allergies., objective :, vitals : weight was 130 pounds and blood pressure 124 / 78., heent : her throat was mildly erythematous without exudate. nasal mucosa was erythematous and swollen. only clear drainage was seen. tms were clear., neck : supple without adenop [SEP]
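
For the feature decoded above, the labels are (0, 0): the annotated answer span is not inside this chunk, even though the word itself happens to occur in it. The toy sketch below isolates that decision and the character-to-token mapping. `char_span_to_token_span` is a hypothetical helper written for illustration, the offsets are hand-made whitespace tokens rather than tokenizer output, and the returned indices are relative to the context tokens only.

```python
def char_span_to_token_span(offsets, start_char, end_char):
    """Map a character span to (first_token, last_token) indices within one chunk.

    `offsets` holds the (start, end) character span of each context token in
    the chunk. Returns None when the answer is not fully inside the chunk,
    which corresponds to the (0, 0) label used in this notebook.
    """
    if not offsets or offsets[0][0] > start_char or offsets[-1][1] < end_char:
        return None
    # Last token that starts at or before the answer start
    first = max(i for i, (s, _) in enumerate(offsets) if s <= start_char)
    # First token that ends at or after the answer end
    last = min(i for i, (_, e) in enumerate(offsets) if e >= end_char)
    return first, last


toy_context = "patient reports seasonal allergies"
toy_offsets = [(0, 7), (8, 15), (16, 24), (25, 34)]  # hand-made whitespace tokens
toy_answer = "seasonal allergies"
toy_start = toy_context.index(toy_answer)
toy_end = toy_start + len(toy_answer)

full_chunk = char_span_to_token_span(toy_offsets, toy_start, toy_end)
cut_chunk = char_span_to_token_span(toy_offsets[:3], toy_start, toy_end)
print(full_chunk)  # (2, 3)
print(cut_chunk)   # None
```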


## Preprocess training dataset

Tokenize the questions and contexts with padding, truncation, offset mapping, and special tokens, strip extra whitespace from the questions, and compute the start and end labels

In [21]:
max_length = 384
stride = 128


def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [22]:
train_dataset = raw_datasets["train"].map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
len(raw_datasets["train"]), len(train_dataset)

Map:   0%|          | 0/16 [00:00<?, ? examples/s]

(16, 36)

## Preprocess validation dataset

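Preprocessing the validation data is slightly easier because it does not need to generate labels. Instead, it records which example each feature came from and sets the offsets of non-context tokens to `None`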

In [23]:
def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs
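
The offset masking step can be seen in isolation on toy values. In the sketch below, `toy_sequence_ids` imitates what `sequence_ids` returns for a question/context pair (`None` for special tokens, `0` for the question, `1` for the context); the values and the helper name `mask_non_context_offsets` are made up for illustration.

```python
def mask_non_context_offsets(offsets, sequence_ids):
    """Keep offsets only for context tokens (sequence id 1); set the rest to None."""
    return [o if sid == 1 else None for o, sid in zip(offsets, sequence_ids)]


# Toy layout: [CLS] q q [SEP] c c c [SEP]
toy_sequence_ids = [None, 0, 0, None, 1, 1, 1, None]
toy_offsets = [(0, 0), (0, 3), (4, 7), (0, 0), (0, 7), (8, 15), (16, 24), (0, 0)]

toy_masked = mask_non_context_offsets(toy_offsets, toy_sequence_ids)
print(toy_masked)
# [None, None, None, None, (0, 7), (8, 15), (16, 24), None]
```

During post-processing, a `None` offset is what marks a candidate start or end token as lying outside the context.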

In [24]:
validation_dataset = raw_datasets["validation"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)
len(raw_datasets["validation"]), len(validation_dataset)

Map:   0%|          | 0/8 [00:00<?, ? examples/s]

(8, 16)

The model outputs logits for the start and end positions of the answer in the input IDs. Post-processing turns these logits into an answer span in four steps:

1. Mask the start and end logits corresponding to tokens outside of the context.
2. Convert the start and end logits into probabilities using a softmax.
3. Score each (start_token, end_token) pair by taking the product of the corresponding two probabilities.
4. Pick the pair with the maximum score that yields a valid answer (e.g., a start_token that does not come after end_token).

The code in this notebook skips the softmax and scores each pair by the sum of its two logits, which gives the same ranking within a feature.
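
The four steps above can be sketched with NumPy on toy logits. `best_span`, `softmax`, and `context_mask` are names introduced here for illustration; this is a minimal sketch for a single feature, not the code used later in the notebook.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    x = x - np.max(x)
    e = np.exp(x)
    return e / e.sum()


def best_span(start_logits, end_logits, context_mask, max_answer_length=30):
    """Return (start, end, score) of the best valid span in one feature."""
    start_logits = np.asarray(start_logits, dtype=float)
    end_logits = np.asarray(end_logits, dtype=float)
    context_mask = np.asarray(context_mask, dtype=bool)

    # 1. Mask tokens outside of the context
    start_logits = np.where(context_mask, start_logits, -np.inf)
    end_logits = np.where(context_mask, end_logits, -np.inf)

    # 2. Convert logits to probabilities
    start_probs = softmax(start_logits)
    end_probs = softmax(end_logits)

    # 3. Score every (start, end) pair with the product of probabilities
    scores = np.outer(start_probs, end_probs)

    # 4. Keep pairs with start <= end and length <= max_answer_length
    scores = np.triu(scores) - np.triu(scores, k=max_answer_length)
    start, end = np.unravel_index(np.argmax(scores), scores.shape)
    return int(start), int(end), float(scores[start, end])


# Toy feature: position 0 and 4 stand for [CLS] and [SEP]
toy_start = [5.0, 0.0, 3.0, 1.0, 0.0]
toy_end = [5.0, 0.0, 1.0, 4.0, 0.0]
toy_mask = [False, True, True, True, False]

span_start, span_end, span_score = best_span(toy_start, toy_end, toy_mask)
print(span_start, span_end)  # 2 3
```

Without the mask, the large logits at position 0 would win, which is why the special tokens are excluded first.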

## Prediction

First, get predictions from the already fine-tuned `distilbert-base-cased-distilled-squad` model on the first four validation examples and use them for evaluation

In [25]:
small_eval_set = raw_datasets["validation"].select(range(4))
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
eval_set = small_eval_set.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)

tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/473 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

In [26]:
import torch
from transformers import AutoModelForQuestionAnswering

# Get the tokenizer back to our actual model
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Pass our small evaluation set through the
# `distilbert-base-cased-distilled-squad` model
# to get the output logits
eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(
    device
)

with torch.no_grad():
    outputs = trained_model(**batch)

model.safetensors:   0%|          | 0.00/261M [00:00<?, ?B/s]

In [27]:
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

Now, we need to find the predicted answer for each example in our small_eval_set. One example may have been split into several features in eval_set, so the first step is to map each example in small_eval_set to the corresponding features in eval_set:

In [28]:
import collections

example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(eval_set):
    example_to_features[feature["example_id"]].append(idx)
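
On toy data, the mapping built above looks like the sketch below. The example ids and the feature list are made up for illustration.

```python
import collections

# Hypothetical features: the example "q1" was split into three features
toy_features = [
    {"example_id": "q1"},
    {"example_id": "q1"},
    {"example_id": "q2"},
    {"example_id": "q1"},
]

toy_example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(toy_features):
    toy_example_to_features[feature["example_id"]].append(idx)

print(dict(toy_example_to_features))  # {'q1': [0, 1, 3], 'q2': [2]}
```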

Once we have all the scored possible answers for one example, we just pick the one with the best logit score


In [29]:
import numpy as np

n_best = 20
max_answer_length = 30
predicted_answers = []


for example in small_eval_set:
    example_id = example["id"]
    context = example["context"]
    answers = []

    for feature_index in example_to_features[example_id]:
        start_logit = start_logits[feature_index]
        end_logit = end_logits[feature_index]
        offsets = eval_set["offset_mapping"][feature_index]

        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip answers that are not fully in the context
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                # Skip answers with a length that is either < 0 or > max_answer_length.
                if (
                    end_index < start_index
                    or end_index - start_index + 1 > max_answer_length
                ):
                    continue

                answers.append(
                    {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                )

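    # Fall back to an empty prediction if no valid span was found
    if answers:
        best_answer = max(answers, key=lambda x: x["logit_score"])
        predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})
    else:
        predicted_answers.append({"id": example_id, "prediction_text": ""})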
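
The slice `np.argsort(...)[-1 : -n_best - 1 : -1]` used in the loop above returns the indices of the `n_best` largest logits, largest first. A small check on toy values (the helper name `top_k_indices` is introduced here for illustration):

```python
import numpy as np


def top_k_indices(logits, k):
    """Indices of the k largest values, in descending order of value."""
    return np.argsort(logits)[-1 : -k - 1 : -1].tolist()


toy_logits = np.array([0.1, 2.5, -1.0, 3.2, 0.7])
toy_top = top_k_indices(toy_logits, 3)
print(toy_top)  # [3, 1, 4]
```

Limiting both the start and end candidates to `n_best` keeps the pair search at no more than `n_best * n_best` combinations per feature.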

## Evaluation

Use `squad` evaluation metric from Hugging Face

In [30]:
import evaluate

metric = evaluate.load("squad")

Downloading builder script:   0%|          | 0.00/4.53k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.32k [00:00<?, ?B/s]

Format the ground-truth answers as required by the `squad` metric from `evaluate`. The `question` key is added here only for display; it is removed from `theoretical_answers` after printing and before the metric is computed.

In [31]:
theoretical_answers = [
    {"id": ex["id"], "answers": ex["answers"], "question": ex["question"]} for ex in small_eval_set
]

for x, y in zip(predicted_answers, theoretical_answers):
    print("Question: ", y['question'])
    print("Ground truth ", y['answers']['text'][0])
    print("Predictions ", x['prediction_text'])
    y.pop('question', None)
    print('\n')

Question:  How old is the patient?
Ground truth  56
Predictions  56


Question:  Does the patient have any complaints?
Ground truth  Positive for hot flashes. She also complains about snoring and occasional slight asthma. She does complain about peripheral ankle swelling and heartburn
Predictions  The patient denies latex allergy


Question:  What is the reason for this consultation?
Ground truth  gastric bypass surgery
Predictions  dietician and the psychologist preoperatively


Question:  What other symptoms does the patient have?
Ground truth  hypertension
Predictions  denies latex allergy




In [32]:
metric.compute(predictions=predicted_answers, references=theoretical_answers)

{'exact_match': 25.0, 'f1': 25.0}
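
To make the two numbers concrete, the sketch below re-implements exact match and token-level F1 in the style of the SQuAD v1.1 evaluation (lowercasing, stripping punctuation and English articles, then comparing tokens). It is a simplified illustration for a single reference answer; the `squad` metric loaded above remains the reference implementation, and the function names here are hypothetical.

```python
import re
import string
from collections import Counter

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def squad_exact_match(prediction, reference):
    return float(normalize_answer(prediction) == normalize_answer(reference))


def squad_f1(prediction, reference):
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(squad_exact_match("56", "56"))                      # 1.0
print(squad_f1("denies latex allergy", "hypertension"))   # 0.0
print(squad_f1("Significant for hypertension", "hypertension"))  # 0.5
```

F1 gives partial credit when the predicted span overlaps the reference without matching it exactly, which exact match does not.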

Create a `compute_metrics` function that computes the `squad` metric from the model's start and end logits.
The training loop calls this function at the end of every epoch.

In [33]:
from tqdm.auto import tqdm


def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

Check that `compute_metrics` produces the same results as before.

In [34]:
compute_metrics(start_logits, end_logits, eval_set, small_eval_set)

  0%|          | 0/4 [00:00<?, ?it/s]

{'exact_match': 25.0, 'f1': 25.0}

# Fine-tuning

## Data Loader

In [35]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

train_dataset.set_format("torch")
validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"])
validation_set.set_format("torch")

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=2,
)
eval_dataloader = DataLoader(
    validation_set, collate_fn=default_data_collator, batch_size=2
)

Get the pre-trained model

In [36]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

pytorch_model.bin:   0%|          | 0.00/1.45G [00:00<?, ?B/s]

Create the AdamW optimizer (Adam with decoupled weight decay)

In [37]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-5)

Set up the Accelerator for training.

Accelerate is a library that enables the same PyTorch code to be run across any distributed configuration by adding just four lines of code.

In [38]:
from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

Create the learning rate scheduler.

The `linear` schedule decreases the learning rate linearly from its initial value (2e-5 here) to 0 over `num_training_steps`. With `num_warmup_steps=0` there is no warmup phase.

In [39]:
from transformers import get_scheduler

num_train_epochs = 3
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
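
The shape of the linear schedule can be sketched as a multiplier applied to the initial learning rate. The function below is a plain-Python illustration, not the library code; the name `linear_schedule_factor` is hypothetical. The 54 steps assume 36 training features, a batch size of 2, and 3 epochs, as configured in this notebook.

```python
def linear_schedule_factor(step, num_warmup_steps, num_training_steps):
    """Multiplier for the initial learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear warmup from 0 to 1
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


initial_lr = 2e-5
total_steps = 54  # 3 epochs * 18 batches, see the assumption above
for step in (0, 27, 54):
    factor = linear_schedule_factor(step, 0, total_steps)
    print(step, initial_lr * factor)
```

With no warmup, the learning rate starts at its full value, reaches half of it midway through training, and ends at 0.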

Set the model_name and output directory where the trained model will be saved

In [41]:
model_name = "bio_qa_model"
output_dir = model_name

## Training

In [42]:
from tqdm.auto import tqdm
import torch

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for step, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    start_logits = []
    end_logits = []
    accelerator.print("Evaluation!")
    for batch in tqdm(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())
        end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())

    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)
    start_logits = start_logits[: len(validation_dataset)]
    end_logits = end_logits[: len(validation_dataset)]

    metrics = compute_metrics(
        start_logits, end_logits, validation_dataset, raw_datasets["validation"]
    )
    print(f"epoch {epoch}:", metrics)

    # Save and upload
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)

  0%|          | 0/54 [00:00<?, ?it/s]

Evaluation!


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

epoch 0: {'exact_match': 62.5, 'f1': 62.5}
Evaluation!


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

epoch 1: {'exact_match': 37.5, 'f1': 44.58333333333333}
Evaluation!


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

epoch 2: {'exact_match': 37.5, 'f1': 45.58333333333333}


## Prediction

In [43]:
small_eval_set = raw_datasets["validation"].select(range(4))
# trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer = AutoTokenizer.from_pretrained(output_dir)
eval_set = small_eval_set.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

In [44]:
# Pass our small evaluation set through the
# fine-tuned model saved in `output_dir`
# to get the output logits
eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
trained_model = AutoModelForQuestionAnswering.from_pretrained(output_dir).to(
    device
)

with torch.no_grad():
    outputs = trained_model(**batch)

In [45]:
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

In [46]:
example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(eval_set):
    example_to_features[feature["example_id"]].append(idx)

n_best = 20
max_answer_length = 30
predicted_answers = []


for example in small_eval_set:
    example_id = example["id"]
    context = example["context"]
    answers = []

    for feature_index in example_to_features[example_id]:
        start_logit = start_logits[feature_index]
        end_logit = end_logits[feature_index]
        offsets = eval_set["offset_mapping"][feature_index]

        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip answers that are not fully in the context
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                # Skip answers with a length that is either < 0 or > max_answer_length.
                if (
                    end_index < start_index
                    or end_index - start_index + 1 > max_answer_length
                ):
                    continue

                answers.append(
                    {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                )

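    # Fall back to an empty prediction if no valid span was found
    if answers:
        best_answer = max(answers, key=lambda x: x["logit_score"])
        predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})
    else:
        predicted_answers.append({"id": example_id, "prediction_text": ""})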

## Evaluation

In [47]:
theoretical_answers = [
    {"id": ex["id"], "answers": ex["answers"], "question": ex["question"]} for ex in small_eval_set
]

for x, y in zip(predicted_answers, theoretical_answers):
    print("Question: ", y['question'])
    print("Ground truth ", y['answers']['text'][0])
    print("Predictions ", x['prediction_text'])
    y.pop('question', None)
    print('\n')

Question:  How old is the patient?
Ground truth  56
Predictions  56


Question:  Does the patient have any complaints?
Ground truth  Positive for hot flashes. She also complains about snoring and occasional slight asthma. She does complain about peripheral ankle swelling and heartburn
Predictions  Significant for hypertension


Question:  What is the reason for this consultation?
Ground truth  gastric bypass surgery
Predictions  obesity related comorbidities


Question:  What other symptoms does the patient have?
Ground truth  hypertension
Predictions  hypertension




In [48]:
compute_metrics(start_logits, end_logits, eval_set, small_eval_set)

  0%|          | 0/4 [00:00<?, ?it/s]

{'exact_match': 50.0, 'f1': 52.0}

## Inference

In [49]:
from transformers import pipeline
import os
import pandas as pd

pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)

df = pd.read_csv(os.path.join(data_dir, "mtsamples.csv"), index_col=0)

# context = squad['validation'][0]['context']
q1 = "How old is the patient?"
q2 = "Does the patient have any complaints?"
q3 = "What is the reason for this consultation?"
q4 = "What does the echocardiogram show?"
q5 = "What other symptoms does the patient have?"

questions = [q1, q2, q3, q4, q5]

context = df.iloc[21]['transcription']
print(context)
print('\n')

question_answerer = pipeline("question-answering", model=output_dir)

for i, q in enumerate(questions):
  print(f"Question {i+1}: {q}\n")
  print(f"Answer: {question_answerer(question=q, context=context)['answer']}\n")

FINAL DIAGNOSES,1.  Morbid obesity, status post laparoscopic Roux-en-Y gastric bypass. ,2.  Hypertension. ,3.  Obstructive sleep apnea, on CPAP.,OPERATION AND PROCEDURE: , Laparoscopic Roux-en-Y gastric bypass.,BRIEF HOSPITAL COURSE SUMMARY:  ,This is a 30-year-old male, who presented recently to the Bariatric Center for evaluation and treatment of longstanding morbid obesity and associated comorbidities.  Underwent standard bariatric evaluation, consults, diagnostics, and preop Medifast induced weight loss in anticipation of elective bariatric surgery. ,Taken to the OR via same day surgery process for elective gastric bypass, tolerated well, recovered in the PACU, and sent to the floor for routine postoperative care.  There, DVT prophylaxis was continued with subcu heparin, early and frequent mobilization, and SCDs.  PCA was utilized for pain control, efficaciously, he utilized the CPAP, was monitored, and had no new cardiopulmonary complaints.  Postop day #1, labs within normal limit