In [1]:
# Use the %pip magic (not !pip) so packages are installed into the environment
# of the running kernel rather than whatever `pip` the sub-shell resolves to.
%pip install --quiet transformers datasets

In [2]:
# --- Step 1: Install and Import Necessary Libraries ---
import tensorflow as tf
from transformers import TFDistilBertForQuestionAnswering, DistilBertTokenizer
from datasets import load_dataset


In [None]:
# --- Step 2: Load the Dataset and Model ---
print("Loading SQuAD dataset...")
# SQuAD (Stanford Question Answering Dataset) is the standard benchmark for
# extractive QA. Only a small slice of each split is requested here.
raw_datasets = load_dataset(
    "squad",
    split={
        "train": "train[:1%]",            # first 1% of the training split
        "validation": "validation[:5%]",  # first 5% of the validation split
    },
)

print("Loading DistilBERT model and tokenizer...")
# DistilBERT: a smaller, faster distillation of BERT.
model_name = "distilbert-base-uncased"

# The "Fast" tokenizer variant is used; preprocessing below asks it for
# offset mappings and overflowing tokens.
from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = TFDistilBertForQuestionAnswering.from_pretrained(
    model_name,
    use_safetensors=False,
)

In [None]:


# --- Step 3: Preprocess the Data ---
# Preprocessing is the most involved part of QA: question and context are
# tokenized as a pair, then the answer text is mapped onto token positions.
#
#   max_length: maximum length of one feature (question + context).
#   doc_stride: overlap between consecutive windows when a long context is split.
max_length, doc_stride = 384, 128

def _find_answer_token_span(offsets, sequence_ids, start_char, end_char):
    """Locate the answer's token span inside one tokenized feature.

    Args:
        offsets: per-token ``(start_char, end_char)`` pairs for the feature.
        sequence_ids: per-token sequence id; ``1`` marks context tokens, any
            other value is treated as "not context" and skipped.
        start_char: character index in the context where the answer starts.
        end_char: character index just past the last answer character.

    Returns:
        ``(start_token, end_token)`` (both inclusive), or ``None`` when the
        answer is not fully contained in this feature's context window.
    """
    # Find the first and last context tokens. Offsets of non-context tokens
    # must not be compared against answer positions, which are relative to
    # the context string.
    context_start = 0
    while context_start < len(sequence_ids) and sequence_ids[context_start] != 1:
        context_start += 1
    context_end = len(sequence_ids) - 1
    while context_end >= 0 and sequence_ids[context_end] != 1:
        context_end -= 1

    if context_start > context_end:
        return None  # this feature holds no context tokens at all

    # The answer must lie entirely within the characters this window covers.
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return None

    # Walk forward past every token that starts at or before the answer start;
    # the token just before the stopping point is the first answer token.
    token_start = context_start
    while token_start <= context_end and offsets[token_start][0] <= start_char:
        token_start += 1

    # Walk backward past every token that ends at or after the answer end;
    # the token just after the stopping point is the last answer token.
    token_end = context_end
    while token_end >= context_start and offsets[token_end][1] >= end_char:
        token_end -= 1

    return token_start - 1, token_end + 1


def preprocess_function(examples):
    """Tokenize a batch of QA examples and attach answer token positions.

    Long contexts are split into overlapping windows, so one example can yield
    several features. A feature whose window does not fully contain the answer
    (or whose example has no answer) is labelled with the CLS token index.

    Args:
        examples: batch dict with ``"question"``, ``"context"`` and
            ``"answers"``; each answers entry has ``"answer_start"`` and
            ``"text"`` lists, of which only the first element is used.

    Returns:
        The tokenizer output with ``"overflow_to_sample_mapping"`` and
        ``"offset_mapping"`` removed and ``"start_positions"`` /
        ``"end_positions"`` added (one entry per feature).
    """
    # Question first, context second: only the context is ever truncated, and
    # context tokens therefore carry sequence id 1.
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Maps each feature back to the example it was cut from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # Maps each token to its character span in the text it came from.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sequence_ids = tokenized_examples.sequence_ids(i)

        answers = examples["answers"][sample_mapping[i]]

        span = None
        if len(answers["answer_start"]) > 0:
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            span = _find_answer_token_span(offsets, sequence_ids, start_char, end_char)

        if span is None:
            # No answer, or the answer is outside this window.
            start_positions.append(cls_index)
            end_positions.append(cls_index)
        else:
            start_positions.append(span[0])
            end_positions.append(span[1])

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples

print("Tokenizing the dataset...")
# Batched mapping: preprocess_function receives a batch of examples per call.
# The columns of the raw training split are removed from the result.
tokenized_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)

In [None]:
# --- Step 4: Prepare Data for Training with tf.data ---
# Both splits are converted with the same columns and batch size; only the
# training split is shuffled.
train_dataset, validation_dataset = (
    tokenized_datasets[split_name].to_tf_dataset(
        columns=["input_ids", "attention_mask", "start_positions", "end_positions"],
        shuffle=(split_name == "train"),
        batch_size=8,
    )
    for split_name in ("train", "validation")
)

In [None]:
# --- Step 5: Compile and Fine-Tune the Model ---
print("Compiling the model...")
# No loss is passed to compile(): the model is expected to compute its own
# loss from the start/end logits and the position labels in each batch.
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
model.compile(optimizer=optimizer)

print("Starting fine-tuning...")
# Three passes over the training set, evaluating on the validation set.
history = model.fit(train_dataset, validation_data=validation_dataset, epochs=3)

In [None]:
# --- Step 6: Inference ---
import numpy as np

def ask_question(model, tokenizer, question, context):
    """Answer a question from a context passage with an extractive QA model.

    Args:
        model: callable QA model; invoked on the tokenized inputs and expected
            to return an object exposing ``start_logits`` and ``end_logits``.
        tokenizer: tokenizer matching ``model``; used to encode the
            (question, context) pair and to decode the answer tokens.
        question: the question text.
        context: the passage in which to look for the answer.

    Returns:
        The decoded answer string. It is empty when the selected span holds
        only special tokens.
    """
    # Truncate only the context so an over-long passage cannot exceed the
    # tokenizer's maximum input length; the question is always kept intact.
    inputs = tokenizer(
        question,
        context,
        truncation="only_second",
        return_tensors="tf",
    )

    outputs = model(inputs)
    start_logits = outputs.start_logits.numpy()[0]
    end_logits = outputs.end_logits.numpy()[0]

    # Score every (start, end) pair jointly. Picking start and end with two
    # independent argmaxes can yield end < start, i.e. an empty answer, so
    # pairs below the diagonal (start > end) are excluded before the argmax.
    span_scores = start_logits[:, None] + end_logits[None, :]
    invalid = np.tri(len(start_logits), k=-1, dtype=bool)
    span_scores = np.where(invalid, -np.inf, span_scores)
    start_index, end_index = np.unravel_index(np.argmax(span_scores), span_scores.shape)

    # Decode the tokens of the chosen span (end inclusive). Special tokens
    # are dropped so they never leak into the returned answer.
    input_ids = inputs["input_ids"].numpy()[0]
    answer_tokens = input_ids[start_index : end_index + 1]
    return tokenizer.decode(answer_tokens, skip_special_tokens=True)

# Demo: run the model on a fresh question/context pair and print the result.
new_question = "Who is the Eiffel Tower named after?"
new_context = "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower."

print("\n--- Answering a New Question ---")
print(f"Context: {new_context}")
print(f"Question: {new_question}")

# The answer is computed after the inputs have been echoed.
answer = ask_question(model, tokenizer, new_question, new_context)
print(f"Predicted Answer: {answer}")