In [8]:
import json

print("Script started.")

def parse_pubtator_file(path):
    print(f"Parsing PubTator file from: {path}")
    docs = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split('|')
            if len(parts) == 3:
                pmid, typ, text = parts
                if pmid not in docs:
                    docs[pmid] = {'title': '', 'abstract': '', 'entities': [], 'relations': []}
                if typ == 't':
                    docs[pmid]['title'] = text
                elif typ == 'a':
                    docs[pmid]['abstract'] = text
            elif '\t' in line:
                fields = line.split('\t')
                pmid = fields[0]
                if len(fields) == 6:
                    # Entity annotation
                    start, end, mention, typ, mesh = fields[1:]
                    docs[pmid]['entities'].append({
                        'start': int(start), 'end': int(end), 'text': mention, 'type': typ, 'mesh': mesh
                    })
                elif len(fields) == 4 and fields[1] == 'CID':
                    # Relation annotation
                    _, _, chem_mesh, dis_mesh = fields
                    docs[pmid]['relations'].append((chem_mesh, dis_mesh))
    print(f"Parsed {len(docs)} docs")
    return docs

def generate_qa_examples(docs, out_path):
    print("generate_qa_examples function called!")
    squad_data = []
    for doc_id, doc in docs.items():
        print(f"Processing doc_id: {doc_id}")
        context = doc['title'] + ' ' + doc['abstract']
        title = doc['title']
        mesh_to_text = {e['mesh']: e['text'] for e in doc['entities']}
        qas = []
        for chem_mesh, dis_mesh in doc['relations']:
            print(f"  relation: {chem_mesh} - {dis_mesh}")
            if chem_mesh in mesh_to_text and dis_mesh in mesh_to_text:
                chem_text = mesh_to_text[chem_mesh]
                dis_text = mesh_to_text[dis_mesh]
                ans_start_dis = context.find(dis_text)
                if ans_start_dis != -1:
                    print(f"    Adding QA: {chem_text} -> {dis_text}")
                    qas.append({
                        "id": f"{doc_id}_chem_{chem_mesh}_dis_{dis_mesh}",
                        "question": f"What diseases are associated with {chem_text}?",
                        "answers": [{"text": dis_text, "answer_start": ans_start_dis}]
                    })
                ans_start_chem = context.find(chem_text)
                if ans_start_chem != -1:
                    print(f"    Adding QA: {dis_text} -> {chem_text}")
                    qas.append({
                        "id": f"{doc_id}_dis_{dis_mesh}_chem_{chem_mesh}",
                        "question": f"What chemicals are associated with {dis_text}?",
                        "answers": [{"text": chem_text, "answer_start": ans_start_chem}]
                    })
        if qas:
            squad_data.append({
                "title": title,
                "paragraphs": [{
                    "context": context,
                    "qas": qas
                }]
            })
    print(f"Writing {len(squad_data)} docs to {out_path}")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"data": squad_data}, f, indent=2, ensure_ascii=False)

if __name__ == "__main__":
    docs = parse_pubtator_file('/content/bio-bert-medical-chatbot/data/CDR_TestSet.PubTator.txt')
    print(f"Loaded {len(docs)} documents.")
    generate_qa_examples(docs, "/content/bio-bert-medical-chatbot/data/cdr_qa_train.json")
    print("Script finished.")

Script started.
Parsing PubTator file from: /content/bio-bert-medical-chatbot/data/CDR_TestSet.PubTator.txt
Parsed 555 docs
Loaded 555 documents.
generate_qa_examples function called!
Processing doc_id: 8701013
  relation: D015738 - D003693
    Adding QA: famotidine -> delirium
    Adding QA: delirium -> famotidine
Processing doc_id: 439781
  relation: D007213 - D007022
    Adding QA: indomethacin -> hypotension
    Adding QA: hypotension -> indomethacin
Processing doc_id: 22836123
  relation: D016572 - D057049
    Adding QA: cyclosporine -> thrombotic microangiopathy
    Adding QA: thrombotic microangiopathy -> cyclosporine
  relation: D000305 - D012595
    Adding QA: corticosteroids -> SSc
    Adding QA: SSc -> corticosteroids
  relation: D016559 - D012595
    Adding QA: tacrolimus -> SSc
    Adding QA: SSc -> tacrolimus
Processing doc_id: 23433219
  relation: D008694 - D011605
    Adding QA: methamphetamine -> psychotic symptoms
    Adding QA: psychotic symptoms -> methamphetamine
 

In [9]:
# Show one sample document (title, abstract, entities, and relations)
sample = next(iter(docs.values()))
print("Title:", sample['title'])
print("Abstract:", sample['abstract'])
print("\nEntities:")
for ent in sample['entities']:
    print(f" - {ent['text']} ({ent['type']}, MeSH: {ent['mesh']}) [{ent['start']}-{ent['end']}]")
print("\nRelations:")
for rel in sample['relations']:
    print(f" - Chemical MeSH: {rel[0]}  <--> Disease MeSH: {rel[1]}")

Title: Famotidine-associated delirium. A series of six cases.
Abstract: Famotidine is a histamine H2-receptor antagonist used in inpatient settings for prevention of stress ulcers and is showing increasing popularity because of its low cost. Although all of the currently available H2-receptor antagonists have shown the propensity to cause delirium, only two previously reported cases have been associated with famotidine. The authors report on six cases of famotidine-associated delirium in hospitalized patients who cleared completely upon removal of famotidine. The pharmacokinetics of famotidine are reviewed, with no change in its metabolism in the elderly population seen. The implications of using famotidine in elderly persons are discussed.

Entities:
 - Famotidine (Chemical, MeSH: D015738) [0-10]
 - delirium (Disease, MeSH: D003693) [22-30]
 - Famotidine (Chemical, MeSH: D015738) [55-65]
 - ulcers (Disease, MeSH: D014456) [156-162]
 - delirium (Disease, MeSH: D003693) [324-332]
 - fam

In [10]:
import json

def generate_qa_examples(docs, out_path):
    squad_data = []
    for doc_id, doc in docs.items():
        context = doc['title'] + ' ' + doc['abstract']
        title = doc['title']
        mesh_to_text = {e['mesh']: e['text'] for e in doc['entities']}
        qas = []
        for chem_mesh, dis_mesh in doc['relations']:
            if chem_mesh in mesh_to_text and dis_mesh in mesh_to_text:
                chem_text = mesh_to_text[chem_mesh]
                dis_text = mesh_to_text[dis_mesh]
                # Q: diseases for chemical
                ans_start = context.find(dis_text)
                if ans_start != -1:
                    qas.append({
                        "id": f"{doc_id}_chem_{chem_mesh}",
                        "question": f"What diseases are associated with {chem_text}?",
                        "answers": [{"text": dis_text, "answer_start": ans_start}]
                    })
                # Q: chemicals for disease
                ans_start = context.find(chem_text)
                if ans_start != -1:
                    qas.append({
                        "id": f"{doc_id}_dis_{dis_mesh}",
                        "question": f"What chemicals are associated with {dis_text}?",
                        "answers": [{"text": chem_text, "answer_start": ans_start}]
                    })
        if qas:
            squad_data.append({
                "title": title,
                "paragraphs": [{
                    "context": context,
                    "qas": qas
                }]
            })
    with open("/content/bio-bert-medical-chatbot/data/cdr_qa_train.json", "w") as f:
        json.dump({"data": squad_data}, f, indent=2)

In [11]:
import os
print(os.getcwd())
os.listdir()

/content


['.config', 'bio-bert-medical-chatbot', 'sample_data']

In [12]:
pip install -U datasets



In [13]:
from datasets import load_dataset

train_file = "/content/bio-bert-medical-chatbot/data/cdr_qa_train.json"
dataset = load_dataset("json", data_files={"train": train_file}, field="data")
dataset

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['title', 'paragraphs'],
        num_rows: 498
    })
})

In [14]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

model_name = "dmis-lab/biobert-base-cased-v1.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.1 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
from datasets import Dataset

def flatten_squad(dataset):
    contexts = []
    questions = []
    answers = []
    for entry in dataset:  # direct iteration over dataset rows
        for paragraph in entry["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                if qa["answers"]:
                    answer = qa["answers"][0]
                else:
                    answer = {"text": "", "answer_start": 0}
                contexts.append(context)
                questions.append(qa["question"])
                answers.append([answer])  # HuggingFace expects a list of dicts
    return {"context": contexts, "question": questions, "answers": answers}

# 1. Load your nested SQuAD-style data (replace path as needed)
from datasets import load_dataset
dataset = load_dataset("json", data_files={"train": "/content/bio-bert-medical-chatbot/data/cdr_qa_train.json"}, field="data")

# 2. Flatten it
flattened = flatten_squad(dataset["train"])
flat_dataset = Dataset.from_dict(flattened)

# Print the first 2 flattened examples
print("First 2 flattened examples:")
for i in range(2):
    print(f"Example {i+1}:")
    print("  Context:", flat_dataset[i]["context"][:100], "...")
    print("  Question:", flat_dataset[i]["question"])
    print("  Answer:", flat_dataset[i]["answers"])
    print("------")

# 3. Preprocessing function (for HuggingFace tokenizer)
def preprocess_function(examples):
    questions = [q.lstrip() for q in examples["question"]]
    contexts = examples["context"]

    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=384,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        sample_index = sample_mapping[i]
        answer = examples["answers"][sample_index][0]  # <-- FIXED

        if answer["answer_start"] == 0 and answer["text"] == "":
            start_positions.append(cls_index)
            end_positions.append(cls_index)
        else:
            start_char = answer["answer_start"]
            end_char = start_char + len(answer["text"])

            sequence_ids = tokenized_examples.sequence_ids(i)
            context_start = sequence_ids.index(1)
            context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

            # If answer is outside this feature, label CLS
            if not (offsets[context_start][0] <= start_char and offsets[context_end - 1][1] >= end_char):
                start_positions.append(cls_index)
                end_positions.append(cls_index)
            else:
                token_start = context_start
                while token_start < context_end and offsets[token_start][0] <= start_char:
                    token_start += 1
                token_end = context_end - 1
                while token_end >= context_start and offsets[token_end][1] >= end_char:
                    token_end -= 1
                start_positions.append(token_start - 1)
                end_positions.append(token_end + 1)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples

# 4. Apply preprocessing
processed_datasets = flat_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=flat_dataset.column_names
)

print("\n✅ Preprocessing complete.")
print("Number of processed samples:", len(processed_datasets))
print("First processed example:")
for k in processed_datasets.features.keys():
    print(f"{k}: {processed_datasets[0][k]}")

First 2 flattened examples:
Example 1:
  Context: Famotidine-associated delirium. A series of six cases. Famotidine is a histamine H2-receptor antagon ...
  Question: What diseases are associated with famotidine?
  Answer: [{'answer_start': 22, 'text': 'delirium'}]
------
Example 2:
  Context: Famotidine-associated delirium. A series of six cases. Famotidine is a histamine H2-receptor antagon ...
  Question: What chemicals are associated with delirium?
  Answer: [{'answer_start': 395, 'text': 'famotidine'}]
------


Map:   0%|          | 0/2084 [00:00<?, ? examples/s]


✅ Preprocessing complete.
Number of processed samples: 3072
First processed example:
input_ids: [101, 1184, 8131, 1132, 2628, 1114, 175, 16931, 3121, 10399, 136, 102, 175, 16931, 3121, 10399, 118, 2628, 3687, 17262, 1818, 119, 170, 1326, 1104, 1565, 2740, 119, 175, 16931, 3121, 10399, 1110, 170, 1117, 27621, 177, 1477, 118, 10814, 19173, 1215, 1107, 1107, 27420, 11106, 1111, 13347, 1104, 6600, 23449, 14840, 1116, 1105, 1110, 4000, 4138, 5587, 1272, 1104, 1157, 1822, 2616, 119, 1780, 1155, 1104, 1103, 1971, 1907, 177, 1477, 118, 10814, 19173, 1116, 1138, 2602, 1103, 21146, 5026, 1785, 1106, 2612, 3687, 17262, 1818, 117, 1178, 1160, 2331, 2103, 2740, 1138, 1151, 2628, 1114, 175, 16931, 3121, 10399, 119, 1103, 5752, 2592, 1113, 1565, 2740, 1104, 175, 16931, 3121, 10399, 118, 2628, 3687, 17262, 1818, 1107, 2704, 2200, 4420, 1150, 5323, 2423, 1852, 8116, 1104, 175, 16931, 3121, 10399, 119, 1103, 185, 7111, 1918, 2528, 4314, 22259, 1104, 175, 16931, 3121, 10399, 1132, 7815, 117, 1114, 1185,

In [20]:
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

model_name = "dmis-lab/biobert-base-cased-v1.1"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

split = processed_datasets.train_test_split(test_size=0.1)
train_dataset = split["train"]   # Use all training samples
eval_dataset = split["test"]     # Use all eval samples

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=3e-5,
    per_device_train_batch_size=8,      # Lower if OOM
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=100,
    save_total_limit=1,
    save_strategy="epoch",
    fp16=True,                         # Remove if your GPU doesn't support fp16
    dataloader_num_workers=2,
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

trainer.train()

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.1 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Step,Training Loss
100,3.029
200,1.7207
300,1.468
400,1.2299
500,1.1018
600,1.2162
700,1.0177
800,0.7963
900,0.8134
1000,0.8422


TrainOutput(global_step=1038, training_loss=1.3072662096262437, metrics={'train_runtime': 268.1626, 'train_samples_per_second': 30.922, 'train_steps_per_second': 3.871, 'total_flos': 1625004530141184.0, 'train_loss': 1.3072662096262437, 'epoch': 3.0})

In [38]:
!pip install evaluate
from datasets import load_metric
metric = load_metric("squad_v2")  # or "squad" if no unanswerable questions

def compute_metrics(p):
    return metric.compute(predictions=p.predictions, references=p.label_ids)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
results = trainer.evaluate()
print(results)

Collecting evaluate
  Downloading evaluate-0.4.4-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.4-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.4


ImportError: cannot import name 'load_metric' from 'datasets' (/usr/local/lib/python3.11/dist-packages/datasets/__init__.py)

In [25]:
trainer.save_model("/content/bio-bert-medical-chatbot/data/my_biobert_qa_model")
tokenizer.save_pretrained("/content/bio-bert-medical-chatbot/data/my_biobert_qa_model")

('/content/bio-bert-medical-chatbot/data/my_biobert_qa_model/tokenizer_config.json',
 '/content/bio-bert-medical-chatbot/data/my_biobert_qa_model/special_tokens_map.json',
 '/content/bio-bert-medical-chatbot/data/my_biobert_qa_model/vocab.txt',
 '/content/bio-bert-medical-chatbot/data/my_biobert_qa_model/added_tokens.json',
 '/content/bio-bert-medical-chatbot/data/my_biobert_qa_model/tokenizer.json')

In [None]:
from huggingface_hub import HfApi

# You can also use `model.push_to_hub()` and `tokenizer.push_to_hub()` directly:
model.push_to_hub("my-bio-bert-qa-model")
tokenizer.push_to_hub("my-bio-bert-qa-model")