In [None]:
import faiss
import numpy as np
import torch
from transformers import DPRContextEncoderTokenizer
from transformers import DPRContextEncoder
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

In [None]:
def read_split(filename):
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Every newline-separated line is treated as one paragraph. Lines that are
    empty or contain only whitespace are dropped; the rest have surrounding
    whitespace removed.

    Args:
        filename: Path of the text file to read.

    Returns:
        list[str]: The stripped, non-empty lines, in file order.
    """
    with open(filename, encoding='utf-8') as handle:
        raw_text = handle.read()
    stripped_lines = (line.strip() for line in raw_text.split('\n'))
    return [line for line in stripped_lines if line]

In [None]:
# Corpus file: one paragraph per non-blank line.
CORPUS_PATH = 'about_ai.txt'
paragraphs = read_split(CORPUS_PATH)
# Fail fast: an empty corpus would otherwise only surface as a confusing
# failure later, in the encoding or search cells.
assert paragraphs, f"no non-empty paragraphs found in {CORPUS_PATH}"

In [None]:
# Inspect the parsed corpus: one list entry per non-blank, stripped line of the file.
paragraphs

In [None]:
# Passage-side DPR checkpoint. The tokenizer and encoder are loaded from the
# same name; a single assignment keeps them in sync (the original re-assigned
# the identical value a second time).
model_name = "facebook/dpr-ctx_encoder-single-nq-base"
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(model_name)
context_encoder = DPRContextEncoder.from_pretrained(model_name)

In [None]:
def encode_contexts(text_list, tokenizer=None, encoder=None, max_length=256):
    """Encode each passage in ``text_list`` into a DPR context embedding.

    Passages are encoded one at a time. The ``pooler_output`` rows are stacked
    into a single NumPy matrix, one row per input text, in input order.

    Args:
        text_list: Iterable of passage strings.
        tokenizer: Tokenizer to use; defaults to the notebook-level
            ``context_tokenizer``.
        encoder: Encoder model to use; defaults to the notebook-level
            ``context_encoder``.
        max_length: Maximum tokens per passage; longer passages are truncated.

    Returns:
        numpy.ndarray: Matrix of shape ``(len(text_list), hidden_size)``.
        For an empty ``text_list`` this is an empty ``(0, hidden_size)`` float32
        matrix, where ``hidden_size`` is read from ``encoder.config``.
    """
    if tokenizer is None:
        tokenizer = context_tokenizer
    if encoder is None:
        encoder = context_encoder

    embeddings = []
    # Inference only: without no_grad, each stored pooler_output would keep its
    # autograd graph (and saved activations) alive until the final concat.
    with torch.no_grad():
        for text in text_list:
            inputs = tokenizer(text, return_tensors='pt', padding=True,
                               truncation=True, max_length=max_length)
            outputs = encoder(**inputs)
            embeddings.append(outputs.pooler_output)

    if not embeddings:
        # torch.cat([]) raises; return a well-formed empty matrix instead.
        return np.empty((0, encoder.config.hidden_size), dtype=np.float32)
    # .cpu() makes the conversion safe if the encoder was moved to a GPU.
    return torch.cat(embeddings).cpu().numpy()

In [None]:
# Embed every paragraph with the DPR context encoder (one row per paragraph).
context_embeddings = encode_contexts(text_list=paragraphs)

In [None]:
# Expect one row per paragraph; row width should equal the embedding dimension
# used to build the index (768 for this checkpoint — TODO confirm).
context_embeddings

In [None]:
# faiss works on contiguous float32 matrices. Derive the dimensionality from
# the data instead of hard-coding 768, so a different checkpoint cannot
# silently mismatch the index.
context_embeddings_np = np.ascontiguousarray(context_embeddings, dtype=np.float32)
assert context_embeddings_np.ndim == 2, "expected a (num_paragraphs, dim) matrix"
embedding_dim = context_embeddings_np.shape[1]

<img src='./images/IMG_0360.jpg' width="800">

In [None]:
# Build a faiss IndexFlatL2 over the paragraph embeddings. Vectors are added
# in paragraph order, so index id i corresponds to paragraphs[i].
# NOTE(review): DPR scores passages by inner product; faiss.IndexFlatIP may
# match the model's intended similarity better. If switched, the reported
# values become similarity scores (larger = better), not distances.
index = faiss.IndexFlatL2(embedding_dim)
index.add(context_embeddings_np)

In [None]:
# The bare index repr is an opaque SWIG proxy; show its dimensionality and
# the number of stored vectors instead (should equal len(paragraphs)).
(index.d, index.ntotal)

In [None]:
# The question tokenizer and question encoder share a single DPR checkpoint.
tokenizer_model = encoder_model = 'facebook/dpr-question_encoder-single-nq-base'

In [None]:
# Load the question-side DPR tokenizer and encoder. Queries are embedded with
# these, while the paragraphs were embedded with the context encoder above.
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_model)
question_encoder = DPRQuestionEncoder.from_pretrained(encoder_model)

In [None]:
question = 'research areas of ai'
# truncation=True only affects over-long queries, which would otherwise error
# past the model's maximum sequence length.
question_inputs = question_tokenizer(question, return_tensors='pt', truncation=True)
# Inference only: no_grad skips building an autograd graph (so no .detach()
# is needed), and .cpu() keeps .numpy() safe if the encoder runs on a GPU.
with torch.no_grad():
    question_embedding = question_encoder(**question_inputs).pooler_output.cpu().numpy()

In [None]:
# Number of nearest paragraphs to retrieve. Never request more neighbours than
# there are indexed vectors: faiss pads missing results with label -1, which
# would silently index paragraphs[-1] when printing.
TOP_K = 4
D, I = index.search(question_embedding, k=min(TOP_K, index.ntotal))

In [None]:
# Raw search output: D holds the distances, I holds the paragraph ids (row 0 = our query).
print(f"D {D} \nI {I}")

In [None]:
# Print the retrieved paragraphs for the (single) query, best match first.
# IndexFlatL2 reports *squared* L2 distances, so smaller means closer.
for rank, (idx, dist) in enumerate(zip(I[0], D[0]), start=1):
    if idx < 0:
        # faiss uses -1 for "no result"; paragraphs[-1] would be a wrong hit.
        continue
    print(f"{rank}: {paragraphs[idx]}")
    print(f"squared L2 distance {dist}")