<a href="https://colab.research.google.com/github/saumya112-IN/sagemaker-flight-prices-prediction/blob/master/SLM_QA_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl (30.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m23.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0


In [43]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import faiss
import numpy as np

class SLMQuestionAnswering:
    """Retrieval-augmented extractive question answering over one text.

    Pipeline:
      1. ``build_index`` splits the text into overlapping word chunks, embeds
         each chunk with the QA model's encoder and stores the vectors in an
         exact (brute-force) L2 FAISS index.
      2. ``retrieve_context`` embeds a question the same way and returns the
         nearest chunk(s).
      3. ``answer_question`` runs the QA head on each retrieved chunk and
         extracts the highest-scoring valid answer span.
    """

    def __init__(self, model_name="distilbert-base-uncased-distilled-squad"):
        """Load the tokenizer and QA model named ``model_name``; the index starts empty."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        self.model.eval()  # inference only (from_pretrained already does this; kept explicit)
        self.index = None  # FAISS index, created by build_index()
        self.contexts = []  # chunk texts; position i matches vector i in the index

    def chunk_text(self, text, chunk_size=512, overlap=50):
        """Split ``text`` into overlapping chunks of whitespace-separated words.

        NOTE(review): 512 *words* usually exceeds DistilBERT's 512-*token* input
        limit, so the tail of a full-size chunk is truncated at embedding/QA
        time; consider passing a smaller ``chunk_size`` for long texts.

        Args:
            text: Input text.
            chunk_size: Maximum number of words per chunk (>= 1).
            overlap: Number of words shared by consecutive chunks
                (0 <= overlap < chunk_size).

        Returns:
            List of chunk strings; empty if ``text`` contains no words.

        Raises:
            ValueError: if ``chunk_size`` or ``overlap`` is out of range.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be >= 1")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
        words = text.split()
        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(words), step):
            chunks.append(" ".join(words[start:start + chunk_size]))
            # Once a window reaches the end of the text, any later window would
            # be a strict subset of this one, so stop instead of emitting it.
            if start + chunk_size >= len(words):
                break
        return chunks

    def build_index(self, text):
        """Chunk ``text``, embed every chunk and store them in a fresh exact-L2 FAISS index.

        Raises:
            ValueError: if ``text`` contains no words.
        """
        contexts = self.chunk_text(text)
        if not contexts:
            raise ValueError("Cannot build an index from empty text.")
        embeddings = self.get_embeddings(contexts)
        # Size the index from the embeddings rather than hard-coding 768, so
        # encoders with a different hidden size also work.
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        # Only replace state once everything succeeded.
        self.contexts = contexts
        self.index = index

    def get_embeddings(self, texts):
        """Embed each string in ``texts`` as the encoder's first-token final hidden state.

        Inputs longer than the model's maximum length are truncated, so only
        the beginning of a long chunk contributes to its embedding.

        Args:
            texts: Non-empty list of strings.

        Returns:
            C-contiguous float32 numpy array of shape (len(texts), hidden_size),
            the layout FAISS expects.
        """
        with torch.no_grad():
            inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
            # ``base_model`` is the bare encoder without the QA head
            # (``self.model.distilbert`` for DistilBERT), so other
            # architectures work too.
            outputs = self.model.base_model(**inputs)
            # Position 0 is the [CLS] token for BERT-style tokenizers.
            # NOTE(review): a QA-fine-tuned [CLS] vector is not trained as a
            # sentence embedding; a dedicated embedding model may retrieve better.
            first_token = outputs.last_hidden_state[:, 0, :]
        return np.ascontiguousarray(first_token.cpu().numpy(), dtype=np.float32)

    def retrieve_context(self, question, top_k=1):
        """Return up to ``top_k`` indexed chunks nearest to ``question``, closest first.

        Raises:
            RuntimeError: if ``build_index`` has not been called yet.
        """
        if self.index is None:
            raise RuntimeError("Index not built; call build_index(text) first.")
        # FAISS pads missing results with label -1 when k exceeds the number of
        # stored vectors; clamp k and drop any -1 so we never silently return
        # self.contexts[-1].
        k = min(top_k, self.index.ntotal)
        if k < 1:
            return []
        question_embedding = self.get_embeddings([question])
        _, indices = self.index.search(question_embedding, k)
        return [self.contexts[i] for i in indices[0] if i >= 0]

    def answer_question(self, question, top_k=1, max_answer_len=30):
        """Answer ``question`` from the ``top_k`` most relevant indexed chunks.

        Args:
            question: Question string.
            top_k: Number of chunks to retrieve (default 1).
            max_answer_len: Maximum answer length in tokens.

        Returns:
            List with one answer string per retrieved chunk, nearest chunk first.
        """
        answers = []
        for context in self.retrieve_context(question, top_k=top_k):
            # Truncate only the context, never the question.
            inputs = self.tokenizer(
                question, context, return_tensors="pt", truncation="only_second"
            )
            with torch.no_grad():  # inference only; don't build an autograd graph
                outputs = self.model(**inputs)
            start, end = self._best_span(
                outputs.start_logits[0],
                outputs.end_logits[0],
                self._context_token_mask(inputs),
                max_answer_len,
            )
            answer_ids = inputs["input_ids"][0][start:end + 1]
            answers.append(
                self.tokenizer.convert_tokens_to_string(
                    self.tokenizer.convert_ids_to_tokens(answer_ids)
                )
            )
        return answers

    def _context_token_mask(self, inputs):
        """Return a bool tensor marking which encoded tokens belong to the context (second sequence)."""
        if getattr(self.tokenizer, "is_fast", False):
            # sequence_ids(): None for special tokens, 0 for the question, 1 for the context.
            return torch.tensor([sid == 1 for sid in inputs.sequence_ids(0)], dtype=torch.bool)
        # Slow tokenizers don't expose sequence_ids(); allow every position.
        return torch.ones(inputs["input_ids"].shape[1], dtype=torch.bool)

    @staticmethod
    def _best_span(start_logits, end_logits, context_mask, max_answer_len=30):
        """Pick the highest-scoring answer span.

        A span (s, e) scores ``start_logits[s] + end_logits[e]``. It is valid
        when s <= e, e - s < max_answer_len, and both s and e are context tokens.
        If no valid span exists, fall back to the independent argmax of each
        logit vector, with e clamped to be >= s.

        Args:
            start_logits: 1-D float tensor.
            end_logits: 1-D float tensor of the same length.
            context_mask: 1-D bool tensor of the same length.
            max_answer_len: Maximum span length in tokens.

        Returns:
            (start, end) as Python ints; ``end`` is inclusive.
        """
        seq_len = start_logits.shape[0]
        scores = start_logits[:, None] + end_logits[None, :]
        positions = torch.arange(seq_len)
        span_len = positions[None, :] - positions[:, None]  # end - start
        valid = (span_len >= 0) & (span_len < max_answer_len)
        valid = valid & context_mask[:, None] & context_mask[None, :]
        if not bool(valid.any()):
            start = int(torch.argmax(start_logits))
            return start, max(start, int(torch.argmax(end_logits)))
        best = int(torch.argmax(scores.masked_fill(~valid, float("-inf"))))
        return best // seq_len, best % seq_len

# Example usage: index a tiny "book" and ask one question about it.
book_text = (
    "This is a sample book content for testing our SLM-based QA model. "
    "The model should accurately retrieve and answer questions from the given text."
)
question = "What should the model do?"

slm = SLMQuestionAnswering()
slm.build_index(book_text)
answers = slm.answer_question(question)
print("Answer:", answers)


Answer: ['accurately retrieve and answer questions from the given text']


In [44]:
# Fresh pipeline over a two-sentence "book" about the sun.
book_text = (
    "The sun rises in the east and sets in the west. "
    "It provides light and energy to the Earth."
)
question = "Where does the sun rise?"

slm = SLMQuestionAnswering()
slm.build_index(book_text)

answers = slm.answer_question(question)
print("Predicted Answer:", answers)


Predicted Answer: ['the east']


In [29]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.3


In [45]:
#!pip install evaluate
import evaluate

class SLMQuestionAnswering:
    """Retrieval-augmented extractive question answering over one text.

    Pipeline:
      1. ``build_index`` splits the text into overlapping word chunks, embeds
         each chunk with the QA model's encoder and stores the vectors in an
         exact (brute-force) L2 FAISS index.
      2. ``retrieve_context`` embeds a question the same way and returns the
         nearest chunk(s).
      3. ``answer_question`` runs the QA head on each retrieved chunk and
         extracts the highest-scoring valid answer span.
      4. ``evaluate`` scores the pipeline with the SQuAD EM / F1 metric.
    """

    def __init__(self, model_name="distilbert-base-uncased-distilled-squad"):
        """Load the tokenizer and QA model named ``model_name``; the index starts empty."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        self.model.eval()  # inference only (from_pretrained already does this; kept explicit)
        self.index = None  # FAISS index, created by build_index()
        self.contexts = []  # chunk texts; position i matches vector i in the index

    def chunk_text(self, text, chunk_size=512, overlap=50):
        """Split ``text`` into overlapping chunks of whitespace-separated words.

        NOTE(review): 512 *words* usually exceeds DistilBERT's 512-*token* input
        limit, so the tail of a full-size chunk is truncated at embedding/QA
        time; consider passing a smaller ``chunk_size`` for long texts.

        Args:
            text: Input text.
            chunk_size: Maximum number of words per chunk (>= 1).
            overlap: Number of words shared by consecutive chunks
                (0 <= overlap < chunk_size).

        Returns:
            List of chunk strings; empty if ``text`` contains no words.

        Raises:
            ValueError: if ``chunk_size`` or ``overlap`` is out of range.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be >= 1")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
        words = text.split()
        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(words), step):
            chunks.append(" ".join(words[start:start + chunk_size]))
            # Once a window reaches the end of the text, any later window would
            # be a strict subset of this one, so stop instead of emitting it.
            if start + chunk_size >= len(words):
                break
        return chunks

    def build_index(self, text):
        """Chunk ``text``, embed every chunk and store them in a fresh exact-L2 FAISS index.

        Raises:
            ValueError: if ``text`` contains no words.
        """
        contexts = self.chunk_text(text)
        if not contexts:
            raise ValueError("Cannot build an index from empty text.")
        embeddings = self.get_embeddings(contexts)
        # Size the index from the embeddings rather than hard-coding 768, so
        # encoders with a different hidden size also work.
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        # Only replace state once everything succeeded.
        self.contexts = contexts
        self.index = index

    def get_embeddings(self, texts):
        """Embed each string in ``texts`` as the encoder's first-token final hidden state.

        Inputs longer than the model's maximum length are truncated, so only
        the beginning of a long chunk contributes to its embedding.

        Args:
            texts: Non-empty list of strings.

        Returns:
            C-contiguous float32 numpy array of shape (len(texts), hidden_size),
            the layout FAISS expects.
        """
        with torch.no_grad():
            inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
            # ``base_model`` is the bare encoder without the QA head
            # (``self.model.distilbert`` for DistilBERT), so other
            # architectures work too.
            outputs = self.model.base_model(**inputs)
            # Position 0 is the [CLS] token for BERT-style tokenizers.
            # NOTE(review): a QA-fine-tuned [CLS] vector is not trained as a
            # sentence embedding; a dedicated embedding model may retrieve better.
            first_token = outputs.last_hidden_state[:, 0, :]
        return np.ascontiguousarray(first_token.cpu().numpy(), dtype=np.float32)

    def retrieve_context(self, question, top_k=1):
        """Return up to ``top_k`` indexed chunks nearest to ``question``, closest first.

        Raises:
            RuntimeError: if ``build_index`` has not been called yet.
        """
        if self.index is None:
            raise RuntimeError("Index not built; call build_index(text) first.")
        # FAISS pads missing results with label -1 when k exceeds the number of
        # stored vectors; clamp k and drop any -1 so we never silently return
        # self.contexts[-1].
        k = min(top_k, self.index.ntotal)
        if k < 1:
            return []
        question_embedding = self.get_embeddings([question])
        _, indices = self.index.search(question_embedding, k)
        return [self.contexts[i] for i in indices[0] if i >= 0]

    def answer_question(self, question, top_k=1, max_answer_len=30):
        """Answer ``question`` from the ``top_k`` most relevant indexed chunks.

        Args:
            question: Question string.
            top_k: Number of chunks to retrieve (default 1).
            max_answer_len: Maximum answer length in tokens.

        Returns:
            List with one answer string per retrieved chunk, nearest chunk first.
        """
        answers = []
        for context in self.retrieve_context(question, top_k=top_k):
            # Truncate only the context, never the question.
            inputs = self.tokenizer(
                question, context, return_tensors="pt", truncation="only_second"
            )
            with torch.no_grad():  # inference only; don't build an autograd graph
                outputs = self.model(**inputs)
            start, end = self._best_span(
                outputs.start_logits[0],
                outputs.end_logits[0],
                self._context_token_mask(inputs),
                max_answer_len,
            )
            answer_ids = inputs["input_ids"][0][start:end + 1]
            answers.append(
                self.tokenizer.convert_tokens_to_string(
                    self.tokenizer.convert_ids_to_tokens(answer_ids)
                )
            )
        return answers

    def _context_token_mask(self, inputs):
        """Return a bool tensor marking which encoded tokens belong to the context (second sequence)."""
        if getattr(self.tokenizer, "is_fast", False):
            # sequence_ids(): None for special tokens, 0 for the question, 1 for the context.
            return torch.tensor([sid == 1 for sid in inputs.sequence_ids(0)], dtype=torch.bool)
        # Slow tokenizers don't expose sequence_ids(); allow every position.
        return torch.ones(inputs["input_ids"].shape[1], dtype=torch.bool)

    @staticmethod
    def _best_span(start_logits, end_logits, context_mask, max_answer_len=30):
        """Pick the highest-scoring answer span.

        A span (s, e) scores ``start_logits[s] + end_logits[e]``. It is valid
        when s <= e, e - s < max_answer_len, and both s and e are context tokens.
        If no valid span exists, fall back to the independent argmax of each
        logit vector, with e clamped to be >= s.

        Args:
            start_logits: 1-D float tensor.
            end_logits: 1-D float tensor of the same length.
            context_mask: 1-D bool tensor of the same length.
            max_answer_len: Maximum span length in tokens.

        Returns:
            (start, end) as Python ints; ``end`` is inclusive.
        """
        seq_len = start_logits.shape[0]
        scores = start_logits[:, None] + end_logits[None, :]
        positions = torch.arange(seq_len)
        span_len = positions[None, :] - positions[:, None]  # end - start
        valid = (span_len >= 0) & (span_len < max_answer_len)
        valid = valid & context_mask[:, None] & context_mask[None, :]
        if not bool(valid.any()):
            start = int(torch.argmax(start_logits))
            return start, max(start, int(torch.argmax(end_logits)))
        best = int(torch.argmax(scores.masked_fill(~valid, float("-inf"))))
        return best // seq_len, best % seq_len

    def evaluate(self, dataset):
        """Score the pipeline on SQuAD-style examples with Exact Match and F1.

        Args:
            dataset: Iterable of dicts with keys "id", "question" and "answers",
                where "answers" is {"text": [...], "answer_start": [...]}.

        Returns:
            The dict returned by the ``evaluate`` library's "squad" metric
            (contains "exact_match" and "f1").
        """
        # ``evaluate`` here is the module imported at notebook level, not this method.
        metric = evaluate.load("squad")
        predictions = []
        references = []

        for example in dataset:
            answers = self.answer_question(example["question"])
            # Score a missing answer as an empty prediction instead of crashing
            # with IndexError.
            predicted_answer = answers[0] if answers else ""
            predictions.append({"id": example["id"], "prediction_text": predicted_answer})
            references.append({"id": example["id"], "answers": example["answers"]})

        return metric.compute(predictions=predictions, references=references)

In [46]:
# Toy SQuAD-style evaluation set. ``answer_start`` is the character offset of
# the answer inside ``book_text``. It is computed with ``str.index`` rather than
# hard-coded, so it cannot drift out of sync with the text (the previous
# hard-coded offsets 13 and 38 were wrong; the correct ones are 14 and 60).
book_text = "The sun rises in the east and sets in the west. It provides light and energy to the Earth."
qa_pairs = [
    ("1", "Where does the sun rise?", "in the east"),
    ("2", "What does the sun provide?", "light and energy"),
]
dataset = [
    {
        "id": example_id,
        "question": question_text,
        # str.index raises ValueError if the gold answer is not in the text.
        "answers": {"text": [answer_text], "answer_start": [book_text.index(answer_text)]},
    }
    for example_id, question_text, answer_text in qa_pairs
]

# Re-instantiate the model to pick up changes to the class definition
slm = SLMQuestionAnswering()
slm.build_index(book_text)

# Evaluate the model
results = slm.evaluate(dataset)

# Print evaluation metrics
print("Exact Match (EM):", results["exact_match"])
print("F1 Score:", results["f1"])

Exact Match (EM): 50.0
F1 Score: 83.33333333333333


### Evaluation Metrics
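- **Exact Match (EM):** Percentage of questions whose predicted answer matches a ground-truth answer exactly, after SQuAD-style normalization (lowercasing and removing punctuation, articles, and extra whitespace).
- **F1 Score:** Token-level overlap between the predicted and ground-truth answers: the harmonic mean of precision and recall, averaged over all questions.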

### Key Learnings
- Implemented **FAISS** for context retrieval.
- Split the book text into overlapping word chunks so each chunk can be embedded and searched independently.
- Used **DistilBERT** for lightweight but effective question answering.
- Obtained reasonable answers on small toy examples (EM 50.0, F1 83.3 on a two-question evaluation set) using a lightweight model on CPU.
- Integrated **evaluation metrics** to assess model performance.

### Future Enhancements
- Extend support for **multi-turn conversations**.
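- Improve **long-context handling** (e.g., sliding-window tokenization with a stride, so text beyond the model's 512-token input limit is not silently truncated), and explore generative **RAG (Retrieval-Augmented Generation)** on top of the existing retriever.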
- Deploy as a **web API using FastAPI or Flask**.
- Experiment with **other evaluation metrics** like BLEU and ROUGE.