# PgVectorizer

In [None]:
#| default_exp pgvectorizer

In [None]:
#| export
import psycopg2.pool
from contextlib import contextmanager
import psycopg2.extras
import pgvector.psycopg2
import numpy as np
import re

from timescale_vector import client

In [None]:
#| export
def _create_ident(base: str, suffix: str):
    """Build a sanitized ``<base>_<suffix>`` name.

    When ``base`` and ``suffix`` together are longer than 62 characters,
    ``base`` is cut down to ``62 - len(suffix)`` characters before joining
    (presumably so the result, including the joining underscore, fits
    PostgreSQL's 63-character identifier limit). Afterwards every character
    outside ``[a-zA-Z0-9_]`` is replaced with an underscore.
    """
    base_budget = 62 - len(suffix)
    if len(base) > base_budget:
        base = base[:base_budget]
    joined = base + "_" + suffix
    return re.sub(r'[^a-zA-Z0-9_]', '_', joined)

class Vectorize:
    """Track changes to a PostgreSQL table through a work queue so embeddings can be refreshed.

    ``register`` creates a work-queue table and a row-level trigger on the
    source table; every INSERT, UPDATE or DELETE then queues the affected row
    id. ``process`` claims a batch of queued ids, passes the matching source
    rows to a callback and removes the claimed ids from the queue, all inside
    one transaction.

    NOTE: the work queue stores ids in an ``int`` column and passes them to
    ``pg_try_advisory_xact_lock``, so the source table's id column presumably
    has to be integer-compatible.
    """

    def __init__(self,
                 service_url: str, 
                 table_name: str,
                 schema_name: str='public',
                 id_column_name: str='id', 
                 work_queue_table_name: str=None, 
                 trigger_name: str='track_changes_for_embedding', 
                 trigger_name_fn: str=None) -> None:
        """Store connection info and the (quoted) identifiers used in generated SQL.

        Args:
            service_url: Connection string passed to ``psycopg2.connect``.
            table_name: Name of the source table to track.
            schema_name: Schema holding the source table; the work queue and
                trigger function are created in the same schema.
            id_column_name: Column of the source table whose value is queued.
            work_queue_table_name: Name of the work-queue table. Defaults to a
                name derived from ``table_name`` with ``_create_ident``.
            trigger_name: Name of the trigger created on the source table.
            trigger_name_fn: Name of the trigger function. Defaults to a name
                derived from ``table_name`` with ``_create_ident``.
        """
        self.service_url = service_url
        # Unquoted names are kept for callers that need to derive other names.
        self.table_name_unquoted = table_name
        self.schema_name_unquoted = schema_name
        self.table_name = client.QueryBuilder._quote_ident(table_name)
        self.schema_name = client.QueryBuilder._quote_ident(schema_name)
        self.id_column_name = client.QueryBuilder._quote_ident(id_column_name)
        if work_queue_table_name is None:
            work_queue_table_name = _create_ident(table_name, 'embedding_work_queue')
        self.work_queue_table_name = client.QueryBuilder._quote_ident(work_queue_table_name)

        self.trigger_name = client.QueryBuilder._quote_ident(trigger_name)

        if trigger_name_fn is None:
            trigger_name_fn = _create_ident(table_name, 'wq_for_embedding')
        self.trigger_name_fn = client.QueryBuilder._quote_ident(trigger_name_fn)

    @contextmanager
    def _connection(self):
        """Yield a connection that is committed (or rolled back) and then closed.

        psycopg2's own ``with connection:`` block only ends the transaction and
        leaves the connection open, so it is closed explicitly here to avoid
        leaking one connection per call.
        """
        conn = psycopg2.connect(self.service_url)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _work_queue_regclass(self) -> str:
        """Return the schema-qualified, quoted work-queue name for ``to_regclass``."""
        return f"{self.schema_name}.{self.work_queue_table_name}"

    def register(self):
        """Create the work queue, trigger function and trigger if they do not exist yet.

        The existence of the work-queue table is used as the marker for "already
        registered"; when it exists nothing is changed. On first registration
        the ids of all rows currently in the source table are queued as well.
        """
        with self._connection() as conn:
            with conn.cursor() as cursor:
                # Bound parameter instead of an interpolated literal, so quotes
                # inside identifiers cannot break out of the SQL string.
                cursor.execute(
                    "SELECT to_regclass(%s) is not null;",
                    (self._work_queue_regclass(),),
                )
                table_exists = cursor.fetchone()[0]
                if table_exists:
                    return

                # The inserts inside the trigger function are schema-qualified:
                # the function runs with the search_path of whichever session
                # modifies the source table, which need not include our schema.
                cursor.execute(f"""
                    CREATE TABLE {self.schema_name}.{self.work_queue_table_name} (
                        id int
                    );

                    CREATE INDEX ON {self.schema_name}.{self.work_queue_table_name}(id);

                    CREATE OR REPLACE FUNCTION {self.schema_name}.{self.trigger_name_fn}() RETURNS TRIGGER LANGUAGE PLPGSQL AS $$ 
                    BEGIN 
                        IF (TG_OP = 'DELETE') THEN
                            INSERT INTO {self.schema_name}.{self.work_queue_table_name} 
                            VALUES (OLD.{self.id_column_name});
                        ELSE
                            INSERT INTO {self.schema_name}.{self.work_queue_table_name} 
                            VALUES (NEW.{self.id_column_name});
                        END IF;
                        RETURN NULL;
                    END; 
                    $$;

                    CREATE TRIGGER {self.trigger_name} 
                    AFTER INSERT OR UPDATE OR DELETE
                    ON {self.schema_name}.{self.table_name} 
                    FOR EACH ROW EXECUTE PROCEDURE {self.schema_name}.{self.trigger_name_fn}();

                    INSERT INTO {self.schema_name}.{self.work_queue_table_name} SELECT {self.id_column_name} FROM {self.schema_name}.{self.table_name};
                """)

    def process(self, embed_and_write_cb, batch_size:int=10, autoregister=True) -> int:
        """Claim up to ``batch_size`` queued ids and hand their rows to the callback.

        Args:
            embed_and_write_cb: Called as ``embed_and_write_cb(rows, self)`` when
                at least one id was claimed. Each row has a ``locked_id`` key
                plus the source table's columns; because of the LEFT JOIN the
                source columns are NULL when the source row no longer exists.
            batch_size: Maximum number of work-queue rows to select per call.
            autoregister: When true, ``register()`` is called first.

        Returns:
            The number of rows passed to the callback (0 when nothing was claimed).

        Raises:
            RuntimeError: If the work-queue table does not exist (only possible
                with ``autoregister=False``).

        The callback runs inside the transaction that deletes the claimed ids,
        so an exception raised by it rolls the deletion back and the ids stay
        queued.
        """
        if autoregister:
            self.register()

        with self._connection() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
                cursor.execute(
                    "SELECT to_regclass(%s)::oid;",
                    (self._work_queue_regclass(),),
                )
                table_oid = cursor.fetchone()[0]
                if table_oid is None:
                    raise RuntimeError(
                        f"work queue table {self._work_queue_regclass()} does not exist; "
                        "call register() first or use autoregister=True"
                    )

                # selected_rows: grab a batch, skipping rows other workers hold.
                # locked_items: take a per-id advisory lock (keyed by the queue
                #   table's oid) so the same id is not handled twice concurrently.
                # deleted_rows: drop every queue entry for the ids we locked.
                cursor.execute(f"""
                    WITH selected_rows AS (
                        SELECT id
                        FROM {self.schema_name}.{self.work_queue_table_name}
                        LIMIT {int(batch_size)}
                        FOR UPDATE SKIP LOCKED
                    ), 
                    locked_items AS (
                        SELECT id, pg_try_advisory_xact_lock({int(table_oid)}, id) AS locked
                        FROM (SELECT DISTINCT id FROM selected_rows ORDER BY id) as ids
                    ),
                    deleted_rows AS (
                        DELETE FROM {self.schema_name}.{self.work_queue_table_name}
                        WHERE id IN (SELECT id FROM locked_items WHERE locked = true ORDER BY id)
                    )
                    SELECT locked_items.id as locked_id, {self.table_name}.*
                    FROM locked_items
                    LEFT JOIN {self.schema_name}.{self.table_name} ON {self.table_name}.{self.id_column_name} = locked_items.id
                    WHERE locked = true
                    ORDER BY locked_items.id
                """)
                res = cursor.fetchall()
                if res:
                    embed_and_write_cb(res, self)
                return len(res)

In [None]:
#| hide
from dotenv import load_dotenv, find_dotenv
import os

In [None]:
# Load variables from the .env file located by find_dotenv(); override=True lets
# values from that file replace variables already set in the environment.
_ = load_dotenv(find_dotenv(), override=True)
# Connection string used by every cell below; raises KeyError if the variable is unset.
service_url = os.environ['TIMESCALE_SERVICE_URL']

In [None]:

#| hide
# Start from a clean slate: remove tables left over from earlier runs, then
# drop and recreate the schemas the tests below write into.
with psycopg2.connect(service_url) as conn, conn.cursor() as cursor:
    for table in ('blog', 'blog_embedding_work_queue', 'blog_embedding'):
        cursor.execute(f"DROP TABLE IF EXISTS {table};")

    for schema in ('public', 'test'):
        cursor.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")
        cursor.execute(f"CREATE SCHEMA {schema};")

In [None]:
# Create the source table and seed it with a single published post.
with psycopg2.connect(service_url) as conn, conn.cursor() as cursor:
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS blog (
            id              SERIAL PRIMARY KEY NOT NULL,
            title           TEXT NOT NULL,
            author          TEXT NOT NULL,
            contents        TEXT NOT NULL,
            category        TEXT NOT NULL,
            published_time  TIMESTAMPTZ NULL --NULL if not yet published
        );
        ''')
    cursor.execute('''
            insert into blog (title, author, contents, category, published_time) VALUES ('first', 'mat', 'first_post', 'personal', '2021-01-01');
        ''')


vectorizer = Vectorize(service_url, 'blog')
# Register twice: the second call must find the existing work queue and do
# nothing, i.e. registration should be idempotent.
for _attempt in range(2):
    vectorizer.register()

In [None]:
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from timescale_vector import client
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.timescalevector import TimescaleVector
from datetime import timedelta

In [None]:
def get_document(blog):
    """Turn one blog row into a list of ``Document`` chunks.

    The row's ``contents`` are split with a ``CharacterTextSplitter``
    (chunk size 1000, overlap 200). Each chunk becomes a ``Document`` whose
    text is prefixed with the author and title and whose metadata carries a
    fresh id from ``client.uuid_from_time(published_time)``, the blog id,
    author, category and the ISO-formatted publication time.

    ``blog['published_time']`` must not be None: ``.isoformat()`` is called on it.
    """
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    def as_document(chunk):
        # One Document per chunk; a new id is generated for every chunk.
        return Document(
            page_content=f"Author {blog['author']}, title: {blog['title']}, contents:{chunk}",
            metadata={
                "id": str(client.uuid_from_time(blog['published_time'])),
                "blog_id": blog['id'],
                "author": blog['author'],
                "category": blog['category'],
                "published_time": blog['published_time'].isoformat(),
            },
        )

    return [as_document(chunk) for chunk in splitter.split_text(blog['contents'])]

def embed_and_write(blog_instances, vectorizer):
    """Refresh the stored embeddings for a batch of changed blog rows.

    Intended as the callback for ``Vectorize.process``.

    Args:
        blog_instances: Rows claimed from the work queue. Each has a
            ``locked_id`` key plus the blog columns; the blog columns are None
            when the source row was deleted (the rows come from a LEFT JOIN).
        vectorizer: The ``Vectorize`` instance driving the processing. Its
            ``table_name_unquoted`` and ``service_url`` select the embedding
            collection ``<table>_embedding`` and the database to write to.
    """
    vector_store = TimescaleVector(
        collection_name=vectorizer.table_name_unquoted + "_embedding",
        # Use the vectorizer's own connection string rather than a notebook
        # global, so the callback always writes to the database being processed.
        service_url=vectorizer.service_url,
        embedding=OpenAIEmbeddings(),
        time_partition_interval=timedelta(days=30),
    )

    # Remove existing embeddings for every claimed id first; this covers
    # updates (stale chunks) and deletes (row no longer exists) alike.
    vector_store.delete_by_metadata(
        [{"blog_id": blog['locked_id']} for blog in blog_instances]
    )

    documents = []
    for blog in blog_instances:
        # Skip blogs that are not published yet, or were deleted
        # (published_time is None in both cases).
        if blog['published_time'] is not None:
            documents.extend(get_document(blog))

    if not documents:
        return

    vector_store.add_texts(
        [d.page_content for d in documents],
        [d.metadata for d in documents],
        [d.metadata["id"] for d in documents],
    )

# --- Initial backfill -------------------------------------------------------
# Registration queued the one pre-existing 'first' post, so the first pass
# processes exactly one row and the second finds the queue empty.
vectorizer = Vectorize(service_url, 'blog')
assert vectorizer.process(embed_and_write) == 1
assert vectorizer.process(embed_and_write) == 0

# Separate handle on the same collection that embed_and_write writes to
# ('blog' + '_embedding'), used only to inspect the results.
TABLE_NAME = "blog_embedding"
embedding = OpenAIEmbeddings()
vector_store = TimescaleVector(
    collection_name=TABLE_NAME,
    service_url=service_url,
    embedding=embedding,
    time_partition_interval=timedelta(days=30),
)

res = vector_store.similarity_search_with_score("first", 10)
assert len(res) == 1


# --- Inserts ----------------------------------------------------------------
# Each insert fires the change-tracking trigger, so two new ids are queued.
with psycopg2.connect(service_url) as conn:
    with conn.cursor() as cursor:
        cursor.execute('''
            insert into blog (title, author, contents, category, published_time) VALUES ('2', 'mat', 'second_post', 'personal', '2021-01-01');
            insert into blog (title, author, contents, category, published_time) VALUES ('3', 'mat', 'third_post', 'personal', '2021-01-01');
        ''')
assert vectorizer.process(embed_and_write) == 2
assert vectorizer.process(embed_and_write) == 0

res = vector_store.similarity_search_with_score("first", 10)
assert len(res) == 3

# --- Delete -----------------------------------------------------------------
# The deleted row's id is queued; the callback removes its embeddings and has
# nothing to re-add, leaving two documents in the store.
with psycopg2.connect(service_url) as conn:
    with conn.cursor() as cursor:
        cursor.execute('''
            DELETE FROM blog WHERE title = '3';
        ''')
assert vectorizer.process(embed_and_write) == 1
assert vectorizer.process(embed_and_write) == 0
res = vector_store.similarity_search_with_score("first", 10)
assert len(res) == 2

# --- Update -----------------------------------------------------------------
# Before the update the best match for "second" holds the old contents...
res = vector_store.similarity_search_with_score("second", 10)
assert len(res) == 2
content = res[0][0].page_content
assert "new version" not in content
with psycopg2.connect(service_url) as conn:
    with conn.cursor() as cursor:
        cursor.execute('''
            update blog set contents = 'second post new version' WHERE title = '2';
        ''')
# ...and after processing, the old embedding has been replaced by the new text.
assert vectorizer.process(embed_and_write) == 1
assert vectorizer.process(embed_and_write) == 0
res = vector_store.similarity_search_with_score("second", 10)
assert len(res) == 2
content = res[0][0].page_content
assert "new version" in content


# --- Long table name in a non-default schema ----------------------------------
# Exercises the derived work-queue/trigger-function names (shortened by
# _create_ident) and schema-qualified SQL; process() registers automatically.
with psycopg2.connect(service_url) as conn:
    with conn.cursor() as cursor:
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS test.blog_table_name_that_is_really_really_long_and_i_mean_long (
            id              SERIAL PRIMARY KEY NOT NULL,
            title           TEXT NOT NULL,
            author          TEXT NOT NULL,
            contents        TEXT NOT NULL,
            category        TEXT NOT NULL,
            published_time  TIMESTAMPTZ NULL --NULL if not yet published
        );
        ''')
        cursor.execute('''
            insert into test.blog_table_name_that_is_really_really_long_and_i_mean_long (title, author, contents, category, published_time) VALUES ('first', 'mat', 'first_post', 'personal', '2021-01-01');
        ''')

vectorizer = Vectorize(service_url, 'blog_table_name_that_is_really_really_long_and_i_mean_long', schema_name='test')
assert vectorizer.process(embed_and_write) == 1
assert vectorizer.process(embed_and_write) == 0