In [1]:
%%capture
%pip install --upgrade jupyter ipywidgets # due to warning: 
#'TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
# See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm'

In [2]:
import os

import numpy as np
from sentence_transformers import SentenceTransformer

# Switch the working directory to the parent folder before the project imports below.
# Presumably db_utils, embed and retrieval live there — TODO confirm.
# NOTE(review): the path is relative, so re-running this cell moves up one
# more level each time; run it only once per kernel session.
os.chdir('..')
import pandas as pd

from db_utils import (create_db, create_embeddings_table,
                      create_pgvector_extension, delete_db,
                      insert_data_into_table, pg_connection)
from embed import HFModels
from retrieval import retrieve_from_pgvector

In [3]:
def generate_encodings(
        sentences: list, 
        model: "SentenceTransformer | str" = HFModels.default.value,
        save_to_file: bool = True, 
        filename: str = 'example_embeddings.npy'
        ) -> np.ndarray:
    """Return embeddings for ``sentences``, reusing an on-disk cache when possible.

    The cache at ``filename`` is used only if it holds exactly one row per
    sentence; otherwise the embeddings are recomputed. The cache is not keyed
    on the sentence text or the model, so delete the file after changing
    either while keeping the same number of sentences.

    Args:
        sentences: Texts to embed.
        model: Either a ready ``SentenceTransformer`` instance or any value
            its constructor accepts (e.g. a model name). Defaults to
            ``HFModels.default.value``.
        save_to_file: When True, write freshly computed embeddings to
            ``filename``.
        filename: Path of the ``.npy`` cache file. It should end in ``.npy``:
            ``np.save`` appends that extension when it is missing, and the
            later ``np.load(filename)`` would then not find the file.

    Returns:
        Array of embeddings, one row per sentence. The row width depends on
        the model (the original note here said 384 for the default model).
    """
    try:
        cached = np.load(filename)
    except FileNotFoundError:
        print(f"File '{filename}' not found. Generating embeddings...")
    else:
        # Guard against a stale cache written for a different list of sentences.
        if cached.shape[:1] == (len(sentences),):
            return cached
        print(f"File '{filename}' does not match the {len(sentences)} input "
              "sentence(s). Regenerating embeddings...")

    # Honor the caller's model instead of always rebuilding the default one.
    encoder = model if isinstance(model, SentenceTransformer) else SentenceTransformer(model)
    embeddings: np.ndarray = encoder.encode(sentences=sentences)
    if save_to_file:
        # Write to the same path that is read above so the cache can actually hit.
        np.save(filename, embeddings)

    return embeddings

In [None]:
# Create the database and embeddings table.
# The helpers come from db_utils and are not shown here; they are called in
# the order database -> extension -> table -> connection.
db_name = 'test_db'  # also used by the insert and delete_db cells further down
create_db(db_name=db_name)
create_pgvector_extension(db_name)  # presumably enables pgvector inside that database — confirm in db_utils
create_embeddings_table(db_name)
CONN = pg_connection(db_name)  # stays open for the queries below; closed explicitly in a later cell
# Last expression of the cell, so the query result is rendered as the cell output.
pd.read_sql_query('SELECT * FROM pg_embeddings', CONN)

In [None]:
# Two sample sentences to embed and store.
# NOTE(review): "linke" looks like a typo for "like"; the text is left as-is
# because it is the data that gets embedded and inserted.
sentences = [
    "I'm a physicist and a Data Scientist",
    "I don't linke the Copenhagen interpretation",
]

# Encode, then convert the array to plain nested lists before inserting.
embeddings = generate_encodings(sentences).tolist()
insert_data_into_table(db_name, sentences, embeddings)

# Show the table contents after the insert.
pd.read_sql_query('SELECT * FROM pg_embeddings', CONN)

In [None]:
# Example of retrieval: look up stored sentences for a free-text query.
# Uses db_name rather than a repeated 'test_db' literal, so the lookup always
# targets the database that the earlier cells created and populated.
query = 'copenhagen'
print(retrieve_from_pgvector(query, db_name))

In [7]:
# Release the connection opened by pg_connection() in the setup cell.
# NOTE(review): presumably this must happen before delete_db() in the next cell,
# since PostgreSQL refuses to drop a database with open sessions — confirm.
CONN.close()

In [None]:
# Clean up: remove the database named by db_name ('test_db', created in the setup cell).
# Presumably this drops it together with the stored embeddings — see db_utils.delete_db.
delete_db(db_name)