In [267]:
from transformers import pipeline, AutoTokenizer, BertForQuestionAnswering, BertForMaskedLM
from datasets import load_dataset

# Load the SQuAD dataset
squad_dataset = load_dataset("pubmed_qa", "pqa_labeled")
raw_datasets = squad_dataset

raw_datasets

DatasetDict({
    train: Dataset({
        features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
        num_rows: 1000
    })
})

In [268]:
# Load the pre-trained model and tokenizer
model_name = "adsabs/astroBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at adsabs/astroBERT and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [269]:
def tokenize_function(examples):
    return tokenizer(examples["question"],  examples["final_decision"], padding="max_length", truncation=True, max_length=206)

In [270]:
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['pubid', 'question', 'context', 'long_answer', 'final_decision', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1000
    })
})

In [271]:
tokenized_datasets = tokenized_datasets.rename_column("question", "text")
# tokenized_datasets = tokenized_datasets.rename_column("final_decision", "labels")

In [272]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000)) 
full_train_dataset = tokenized_datasets["train"]

In [273]:
full_train_dataset[0]

{'pubid': 21645374,
 'text': 'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?',
 'context': {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
   'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells 

In [274]:
full_train_dataset = full_train_dataset.map(
    lambda example: {"label": 1 if example["final_decision"] == "yes" else 0},
    remove_columns=["context"],
)

In [275]:
print(full_train_dataset.column_names)
print(full_train_dataset.features["input_ids"])

['pubid', 'text', 'long_answer', 'final_decision', 'input_ids', 'token_type_ids', 'attention_mask', 'label']
Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None)


In [276]:
from transformers import Trainer, TrainingArguments
# Define training arguments
training_args = TrainingArguments(
    output_dir="./qa_finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_steps=500,
    save_total_limit=2,
)

# training_args = TrainingArguments("test_trainer")

In [277]:
import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels),

In [278]:

# Define the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=full_train_dataset,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
model.save_pretrained("./qa_finetuned")
tokenizer.save_pretrained("./qa_finetuned")



TypeError: BertForQuestionAnswering.forward() got an unexpected keyword argument 'labels'

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    compute_metrics=compute_metrics,
)
trainer.evaluate()