In [1]:
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset

# Load ScienceQA dataset (without streaming)
dataset = load_dataset("derek-thomas/ScienceQA", split={"train": "train", "validation": "validation"})

# Use DistilBERT for Question Answering
MODEL_NAME = "distilbert-base-cased-distilled-squad"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)

In [2]:
def preprocess_function(examples):
    """
    Tokenizes the question and context, finds the correct answer in the context, 
    and creates labels for start and end positions.
    """
    questions = examples["question"]
    contexts = examples["subject"]
    choices = examples["choices"]
#     answer_indices = examples["answer_index"]

    # Combine context and question
    question_and_context = [
        f"Question: {q} Context: {c}" for q, c in zip(questions, contexts)
    ]

    # Extract correct answer text using answer index
#     answers_text = [
#         a[idx] if 0 <= idx < len(a) else "" for a, idx in zip(choices, answer_indices)
#     ]
    
    answers_text = []
    questions = []
    index = 0
    for e in examples["question"]:
        choices = examples["choices"][index]
        question_and_choices = e +" Choices are "+ ", ".join(choices)+ "."
#         print(f'question {question_and_choices}')
        answer_index = examples["answer"][index]
        answer = choices[answer_index]
#         print(f'ans: {answer}')
        questions.append(question_and_choices)
        answers_text.append(answer)
#         question_answer_pairs.append({"question": question_and_choices, 
#                                      "answer": answer})
        index+=1

    # Tokenize
    tokenized_inputs = tokenizer(
        question_and_context, 
        padding="max_length", 
        truncation=True, 
        max_length=384, 
        return_tensors="pt"
    )

    # Find answer spans in the tokenized context
    start_positions = []
    end_positions = []

    for i, (context, answer) in enumerate(zip(contexts, answers_text)):
        start_index = context.find(answer)  # Find answer position in context
        if start_index == -1:
            start_positions.append(0)
            end_positions.append(0)
        else:
            before_answer = tokenizer(context[:start_index], truncation=True, max_length=384)
            answer_tokens = tokenizer(answer, truncation=True, max_length=384)
            start_positions.append(len(before_answer["input_ids"]) - 1)
            end_positions.append(start_positions[-1] + len(answer_tokens["input_ids"]) - 2)

    # Add labels
    tokenized_inputs["start_positions"] = torch.tensor(start_positions)
    tokenized_inputs["end_positions"] = torch.tensor(end_positions)

    return tokenized_inputs

In [3]:
# Apply preprocessing
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)

# Convert to PyTorch dataset
train_dataset = tokenized_dataset["train"].with_format("torch")
valid_dataset = tokenized_dataset["validation"].with_format("torch")

In [4]:
# Define training arguments
training_args = TrainingArguments(
    output_dir="./scienceqa_finetuned",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=3e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=500,
    fp16=torch.cuda.is_available(),
    max_steps=100,  # Avoids trainer length error
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("./scienceqa_finetuned_model")
tokenizer.save_pretrained("./scienceqa_finetuned_model")

print("Fine-tuning complete! Model saved at './scienceqa_finetuned_model'")


  trainer = Trainer(
max_steps is given, it will override any value given in num_train_epochs
[34m[1mwandb[0m: Currently logged in as: [33mshaddie77[0m ([33mshaddie77-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch,Training Loss,Validation Loss
0,No log,6.1e-05


Fine-tuning complete! Model saved at './scienceqa_finetuned_model'
