In [2]:
import vertexai
from vertexai.language_models import TextEmbeddingModel
from vertexai.language_models import TextEmbeddingInput
from sklearn.neighbors import NearestNeighbors
from dotenv import load_dotenv
import os
import numpy as np
import faiss

In [3]:
# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

# Vertex AI project id and region; each is None when its variable is unset.
project = os.environ.get('GOOGLE_VERTEX_PROJECT')
region = os.environ.get('REGION_GOOGLE')

# Configure the Vertex AI SDK with the project and location read above.
vertexai.init(
    project=project,
    location=region,
)

In [None]:
# TODO: no task type is passed to the model below. An earlier draft wrapped the
# query in TextEmbeddingInput(task_type=..., text=...); the documented task type
# for search queries is spelled "RETRIEVAL_QUERY".

# Pretrained Vertex AI text-embedding model used for both corpus and query.
model = TextEmbeddingModel.from_pretrained('textembedding-gecko@003')

# Tiny corpus to search over, and the question to match against it.
textToEmbed = [
    'we fought the battle for 12 hours',
    'there were horrible casualties',
    'in the end we were victorious in the fight'
]
query = "was the battle won?"

# One request for the corpus: each returned item exposes its vector as `.values`;
# stacking them gives one row per sentence.
corpus_embeddings = model.get_embeddings(texts=textToEmbed)
embedded_text = np.array([item.values for item in corpus_embeddings])

# A second request for the single query; keep its only embedding as a 1-D array.
query_embeddings = model.get_embeddings([query])
embedded_query = np.array(query_embeddings[0].values)


In [5]:
# Print the query embedding as a flat 1-D vector, then again with a leading
# axis added so it becomes a single-row 2-D matrix.
print("showing embedded query needs axis increase")
print(embedded_query)
print(embedded_query[np.newaxis, :])

showing embedded query needs axis increase
[ 1.94935109e-02 -1.45250780e-03 -6.43334771e-03 -6.80354387e-02
  3.15410793e-02  4.85962303e-03 -1.32725192e-02 -5.14191051e-04
  4.13768105e-02 -8.37741420e-03 -1.73517969e-02  3.73026319e-02
 -3.43762897e-03 -9.31092817e-03  8.33463622e-04 -1.24842320e-02
  1.29655562e-02  4.91709858e-02  4.57978956e-02 -5.70242852e-03
 -1.29037043e-02  2.69559138e-02 -3.98038849e-02 -4.61396389e-02
  3.11291544e-03 -3.74394841e-02 -1.31389089e-02 -9.52515230e-02
 -3.61240171e-02  3.98861356e-02 -8.06753412e-02  2.68909335e-02
 -3.94430757e-02 -1.58107579e-02  3.26812193e-02 -8.60233977e-02
  3.46312076e-02  4.68937904e-02  1.08206300e-02  1.04509545e-02
  4.37067309e-03 -3.09134368e-02 -2.55491175e-02  2.06173100e-02
 -3.05911116e-02 -1.41051803e-02  1.00234207e-02  5.10836579e-03
 -2.35101171e-02 -5.38653545e-02  4.46330346e-02  2.00878978e-02
  5.00939451e-02 -5.27797937e-02  3.19011472e-02  1.97084807e-03
  1.89302489e-02  1.20100658e-03 -4.24677432e-0

In [9]:
# Compare three nearest-neighbour searches over the same sentence embeddings:
# scikit-learn brute force, scikit-learn ball tree, and a FAISS HNSW graph index.
TOP_K = 2    # number of closest sentences to return for the query
HNSW_M = 3   # FAISS HNSW "M": number of neighbour links kept per node in the graph

# kneighbors()/search() take a 2-D (n_queries, dim) matrix, so wrap the single
# query vector in a leading axis once and reuse it for every search.
query_matrix = np.expand_dims(embedded_query, axis=0)

brute = NearestNeighbors(n_neighbors=TOP_K, algorithm='brute').fit(embedded_text)
ball = NearestNeighbors(n_neighbors=TOP_K, algorithm='ball_tree').fit(embedded_text)

# FAISS works on float32 data, so cast explicitly instead of relying on an
# implicit conversion of the float64 arrays. The scikit-learn searches keep
# using the original arrays so their results are unaffected.
faiss_text = np.ascontiguousarray(embedded_text, dtype=np.float32)
faiss_query = np.ascontiguousarray(query_matrix, dtype=np.float32)

# Take the index dimension from the data rather than hard-coding 768, so an
# embedding model with a different output size does not break add().
faissThing = faiss.IndexHNSWFlat(faiss_text.shape[1], HNSW_M)
faissThing.add(faiss_text)

bruteDistance, bruteIndexes = brute.kneighbors(query_matrix)
ballDistance, ballIndexes = ball.kneighbors(query_matrix)
faissDistance, faissIndexes = faissThing.search(faiss_query, k=TOP_K)

# NOTE: the FAISS index uses the L2 metric and reports *squared* distances,
# while scikit-learn reports plain Euclidean distances. The FAISS numbers are
# therefore the squares of the other two, not a disagreement between methods.
print("brute then ball then faiss")
print(bruteDistance)
print(bruteIndexes)
print(ballDistance)
print(ballIndexes)
print(faissDistance)
print(faissIndexes)

brute then ball then faiss
[[0.56751815 0.65166942]]
[[2 0]]
[[0.56751815 0.65166942]]
[[2 0]]
[[0.32207686 0.42467302]]
[[2 0]]
