In [45]:
import os
import re
import faiss
import pickle
import nltk
from typing import List
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer


class Retriever:
    """
    Sentence-chunk semantic retriever backed by a flat L2 FAISS index.

    Text files are split into windows of ``sentences_per_chunk`` sentences
    (consecutive windows share ``overlap`` sentences), embedded with a
    SentenceTransformer model, and added to an ``IndexFlatL2``. Row ``i`` of
    the index corresponds to ``documents[i]``.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", sentences_per_chunk: int = 2, overlap: int = 1):
        """
        Args:
            model_name: SentenceTransformer model name/path used for embeddings.
            sentences_per_chunk: number of sentences per chunk; must be >= 1.
            overlap: sentences shared by consecutive chunks; must satisfy
                0 <= overlap < sentences_per_chunk.

        Raises:
            ValueError: if the chunking parameters are invalid. This is checked
                before the model is loaded so a bad configuration fails fast.
        """
        self._validate_chunking(sentences_per_chunk, overlap)
        self.model = SentenceTransformer(model_name)
        self.sentences_per_chunk = sentences_per_chunk
        self.overlap = overlap
        self.index = None     # FAISS index; created on first add_txt_files() or by load()
        self.documents = []   # chunk texts; documents[i] is row i of the index
        self.embeddings = []  # embedding rows added in this session (not written by save())

    @staticmethod
    def _validate_chunking(sentences_per_chunk: int, overlap: int) -> None:
        """Raise ValueError unless sentences_per_chunk >= 1 and 0 <= overlap < sentences_per_chunk."""
        if sentences_per_chunk < 1:
            raise ValueError(f"sentences_per_chunk must be >= 1, got {sentences_per_chunk}")
        if not 0 <= overlap < sentences_per_chunk:
            raise ValueError(
                "overlap must satisfy 0 <= overlap < sentences_per_chunk "
                f"({sentences_per_chunk}), got {overlap}"
            )

    def _clean_text(self, text: str) -> str:
        """
        Clean the input text by collapsing whitespace.

        Every run of whitespace (including newlines) becomes a single space,
        and leading/trailing whitespace is removed.
        """
        return re.sub(r'\s+', ' ', text).strip()

    @classmethod
    def _window_chunks(cls, sentences: List[str], size: int, overlap: int) -> List[str]:
        """
        Group ``sentences`` into space-joined windows of ``size`` sentences.

        Consecutive windows share ``overlap`` sentences. Windowing stops at the
        first window that reaches the last sentence, so no trailing chunk is a
        mere suffix of the previous one.

        Raises:
            ValueError: if ``size``/``overlap`` are invalid (see __init__).
        """
        cls._validate_chunking(size, overlap)
        step = size - overlap
        chunks = []
        for start in range(0, len(sentences), step):
            chunk = " ".join(sentences[start:start + size])
            if chunk:
                chunks.append(chunk)
            # Any later window would only repeat the tail of this one.
            if start + size >= len(sentences):
                break
        return chunks

    def _chunk_text(self, text: str) -> List[str]:
        """
        Split text into sentence-based chunks with optional overlap.

        Sentences come from NLTK's ``sent_tokenize``; windowing follows the
        instance's ``sentences_per_chunk`` and ``overlap`` settings.
        """
        return self._window_chunks(sent_tokenize(text), self.sentences_per_chunk, self.overlap)

    def _load_txt(self, filepath: str) -> str:
        """
        Load a UTF-8 .txt file and return its whitespace-collapsed contents.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            return self._clean_text(f.read())

    def add_txt_files(self, filepaths: List[str]):
        """
        Load .txt files, chunk them, embed, and add to the FAISS index.

        Paths that do not end in ``.txt`` (case-insensitive) are skipped with a
        message. ``self.documents`` is extended only after the embeddings have
        been added to the index, so a failure while reading, encoding or
        indexing cannot leave documents and index rows out of sync.

        Raises:
            ValueError: if the embedding dimension differs from that of an
                existing index (e.g. one loaded for a different model).
        """
        all_chunks = []
        for filepath in filepaths:
            if not filepath.lower().endswith('.txt'):
                print(f"Skipping unsupported file type: {filepath}")
                continue
            all_chunks.extend(self._chunk_text(self._load_txt(filepath)))

        if not all_chunks:
            return

        embeddings = self.model.encode(all_chunks, show_progress_bar=True)
        dim = embeddings[0].shape[0]
        if self.index is None:
            self.index = faiss.IndexFlatL2(dim)
        elif self.index.d != dim:
            raise ValueError(
                f"Embedding dimension {dim} does not match existing index dimension {self.index.d}"
            )
        self.index.add(embeddings)
        # Index rows and documents must stay aligned: extend only after a successful add.
        self.documents.extend(all_chunks)
        self.embeddings.extend(embeddings)

    def query(self, text: str, top_k: int = 5) -> List[str]:
        """
        Perform a semantic search against the indexed chunks.

        Args:
            text: the query string to embed and search for.
            top_k: maximum number of chunks to return; it is capped at the
                number of indexed chunks, and values <= 0 return [].

        Returns:
            Up to ``top_k`` chunk texts, in the order the index search returns them.

        Raises:
            RuntimeError: if nothing has been indexed or loaded yet.
        """
        if self.index is None:
            raise RuntimeError("No index available; call add_txt_files() or load() first.")
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []
        query_vec = self.model.encode([text])
        _, ids = self.index.search(query_vec, k)
        # FAISS pads missing results with label -1; without this filter,
        # documents[-1] would be returned as a bogus hit.
        return [self.documents[i] for i in ids[0] if i >= 0]

    def save(self, path: str):
        """
        Save the FAISS index and documents to disk.

        Writes ``faiss.index`` and ``docs.pkl`` under ``path``, creating the
        directory if needed. ``self.embeddings`` is not saved.

        Raises:
            RuntimeError: if there is no index to save.
        """
        if self.index is None:
            raise RuntimeError("Nothing to save: no documents have been indexed or loaded.")
        os.makedirs(path, exist_ok=True)
        faiss.write_index(self.index, os.path.join(path, "faiss.index"))
        with open(os.path.join(path, "docs.pkl"), "wb") as f:
            pickle.dump(self.documents, f)

    def load(self, path: str):
        """
        Load the FAISS index and documents from disk.

        Replaces ``self.index`` and ``self.documents`` with the files that
        ``save()`` wrote under ``path``. ``self.embeddings`` is left untouched.

        Security: ``pickle.load`` can execute arbitrary code. Only load
        directories you created or otherwise trust.
        """
        self.index = faiss.read_index(os.path.join(path, "faiss.index"))
        # NOTE(security): unpickling untrusted data allows arbitrary code execution.
        with open(os.path.join(path, "docs.pkl"), "rb") as f:
            self.documents = pickle.load(f)


In [46]:
# Build a retriever over the sample corpus, show the top 3 hits for a
# question, then persist the index and chunk texts to disk.
retriever = Retriever()
retriever.add_txt_files(filepaths=["sample_file.txt"])

top_hits = retriever.query(text="What is book generation?", top_k=3)
print(top_hits)

retriever.save(path="retriever_data")

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

["Book generation Not an NLP task proper but an extension of natural language generation and other NLP tasks is the creation of full-fledged books. The first machine-generated book was created by a rule-based system in 1984 (Racter, The policeman's beard is half-constructed).", '[37] Natural-language generation (NLG): Convert information from computer databases or semantic intents into readable human language. Book generation Not an NLP task proper but an extension of natural language generation and other NLP tasks is the creation of full-fledged books.', 'The first machine-generated science book was published in 2019 (Beta Writer, Lithium-Ion Batteries, Springer, Cham). [39] Unlike Racter and 1 the Road, this is grounded on factual knowledge and based on text summarization.']


In [47]:
def test_retriever_returns_expected_chunk():
    """
    Smoke test: the single top hit for a question about book generation
    should be a chunk that mentions "Book generation".

    Reads ``sample_file.txt`` from the working directory and builds a fresh
    Retriever (which loads the default embedding model) with non-overlapping
    2-sentence chunks.
    """
    # Setup
    retriever = Retriever(sentences_per_chunk=2, overlap=0)
    retriever.add_txt_files(["sample_file.txt"])

    # Query and expected phrase
    query = "What is book generation?"
    results = retriever.query(query, top_k=1)

    # An empty result (e.g. nothing was indexed) is a different failure from a wrong hit.
    assert len(results) == 1, f"Expected exactly one retrieved chunk, got {len(results)}"

    # Include what was actually retrieved so a failure is diagnosable from the message alone.
    expected_phrase = "Book generation"
    assert any(expected_phrase in chunk for chunk in results), (
        f"Expected phrase {expected_phrase!r} not found in retrieved chunk(s): {results!r}"
    )

    print("✅ Test passed: Relevant chunk retrieved.")

# Run the test
test_retriever_returns_expected_chunk()

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Test passed: Relevant chunk retrieved.
