In [1]:
!pip install -U sentence-transformers --quiet

In [2]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (31.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m57.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.12.0


In [22]:
from transformers import AutoTokenizer, AutoModel

import torch
import faiss
import numpy as np

# Load the tokenizer and the encoder from the same pretrained checkpoint so
# the token ids produced by one are the ones the other was trained on.
checkpoint = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)


def get_embeddings(texts):
  """Encode text with the notebook-level BERT model and return token embeddings.

  Args:
    texts: Input forwarded unchanged to the tokenizer. This notebook calls it
      with both a list of strings and a single string.

  Returns:
    A NumPy array holding the model's ``last_hidden_state`` for every token
    position, i.e. a 3-D array (the saved output for the three example facts
    is ``(3, 15, 768)``). No pooling is applied here; callers select the
    token they want, e.g. ``result[:, 0, :]`` for the first token.
  """
  # Pad to the longest item in the batch and truncate over-long inputs so the
  # batch can be stacked into PyTorch tensors.
  encoded = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

  # Inference only: no gradients are needed.
  with torch.no_grad():
    hidden_states = model(**encoded).last_hidden_state

  return hidden_states.numpy()

# Toy knowledge base: each entry is one short medical statement that can be
# retrieved later.
medical_facts = [
    "Fever and fatigue are symptoms of COVID-19.",
    "Shortness of breath is a common symptom of asthma.",
    "Chest pain can indicate a heart attack.",
]

# Encode all facts in one batch. The result still contains one vector per
# token, not one per sentence.
embeddings = get_embeddings(texts=medical_facts)

In [23]:
# Check the layout returned by get_embeddings:
# (num_sentences, num_tokens, embedding_dim).
embeddings.shape

(3, 15, 768)

In [24]:
# Build the FAISS index
#
# `embeddings` is 3-D (num_sentences, num_tokens, embedding_dim), but FAISS
# indexes one vector per row of a 2-D matrix. Keep only the first token's
# vector (usually the [CLS] token) as the embedding of each sentence.
#
# Note this is a slice, not a reshape: it produces a non-contiguous view of
# the 3-D array. FAISS operates on C-contiguous float32 matrices, so make that
# layout explicit instead of relying on an implicit conversion.
embeddings_2d = np.ascontiguousarray(embeddings[:, 0, :], dtype=np.float32)

# The index must be created with the dimensionality of the vectors it stores.
dimensions = embeddings_2d.shape[1]
index = faiss.IndexFlatL2(dimensions)  # exact search using L2 distance

index.add(embeddings_2d)

print(f"FAISS index created with {index.ntotal} vectors of dimension {index.d}")

FAISS index created with 3 vectors of dimension 768


In [27]:
# NOTE(review): "asthama" looks like a typo for "asthma"; it is left unchanged
# because the saved output below was produced with this exact query.
query = "symptoms of asthama"

# Encode the query with the same model that produced the indexed vectors.
query_embeddings = get_embeddings(query)

# Search with the first-token vector, matching how the index was built.
# FAISS operates on C-contiguous float32 matrices, so make the layout explicit.
query_vector = np.ascontiguousarray(query_embeddings[:, 0, :], dtype=np.float32)

k = 2  # number of nearest facts to retrieve
distances, indices = index.search(query_vector, k)

# FAISS fills missing neighbours with the label -1 when the index holds fewer
# than k vectors. Skip those: as a Python list index, -1 would silently return
# the last fact instead of signalling "no result".
retrieved_docs = [medical_facts[i] for i in indices[0] if i != -1]

print(retrieved_docs)

['Shortness of breath is a common symptom of asthma.', 'Fever and fatigue are symptoms of COVID-19.']
