<a href="https://colab.research.google.com/github/rawanamrrr/DeepLearning1/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install transformers datasets torch

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_c

In [14]:
import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizerFast, Trainer, TrainingArguments
from datasets import load_dataset
import os
os.environ["WANDB_DISABLED"] = "true"
dataset = load_dataset("squad_v2")
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

small_train_dataset = dataset["train"].shuffle(seed=42).select(range(500))
small_val_dataset = dataset["validation"].shuffle(seed=42).select(range(100))

def preprocess_data(example):
    inputs = tokenizer(example["question"], example["context"], truncation=True, padding="max_length", max_length=512)

    if example["answers"]["answer_start"] and len(example["answers"]["answer_start"]) > 0:
        inputs["start_positions"] = example["answers"]["answer_start"][0]
        inputs["end_positions"] = example["answers"]["answer_start"][0] + len(example["answers"]["text"][0])
    else:
        inputs["start_positions"] = 0
        inputs["end_positions"] = 0

    return inputs

train_dataset = small_train_dataset.map(preprocess_data)
val_dataset = small_val_dataset.map(preprocess_data)

model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

training_args = TrainingArguments(
    output_dir="./qa_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    logging_dir="./logs",
    logging_steps=200
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

trainer.train()
trainer.save_model("./trained_qa_model")

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Epoch,Training Loss,Validation Loss
1,No log,3.521897


In [31]:

def ask_question(question, context):
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        outputs = model(**inputs)

    start_idx = torch.argmax(outputs.start_logits)
    end_idx = torch.argmax(outputs.end_logits)

    if start_idx >= end_idx:
        return "I'm not sure about the answer."

    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start_idx:end_idx + 1])
    )

    return answer.strip()

context_text = """Diabetes is a chronic disease that occurs when the pancreas is no longer able to make insulin,
or when the body cannot make good use of the insulin it produces. Common symptoms include increased thirst,
frequent urination, and extreme fatigue."""

question_text = input("Enter your question: ")
answer = ask_question(question_text, context_text)

print(f"Answer: {answer}")


Enter your question: What are the symptoms of diabetes
Answer: increased thirst, frequent urination, and extreme fatigue
