# Vector

> A module for writing vectors to Postgres and querying them by similarity

In [3]:
#| default_exp core

In [4]:
#| hide
from nbdev.showdoc import *

In [5]:
#| hide
import nbdev; nbdev.nbdev_export()

In [54]:
import json
import math
import uuid
from typing import (List, Optional)

import asyncpg
import psycopg.sql
from pgvector.asyncpg import register_vector

In [20]:
#| hide
from dotenv import load_dotenv, find_dotenv
import os
# Load variables from the .env file located by find_dotenv() into the process environment.
_ = load_dotenv(find_dotenv()) 
# Connection string used by the example cells below; a missing variable raises KeyError.
connection_string  = os.environ['PG_CONNECTION_STRING'] 

In [65]:
class Vector:
    """Async helper for storing embeddings in a Postgres table and searching them.

    Every database method opens its own asyncpg pool through `connect()` and
    closes it when the call finishes. The table has the columns
    ``id UUID``, ``metadata JSONB``, ``contents TEXT`` and
    ``embedding VECTOR(num_dimensions)`` (see `get_create_query`).
    """

    # Accepted ``distance_type`` spellings -> pgvector distance operator.
    _DISTANCE_OPERATORS = {
        'cosine': '<=>',
        '<=>': '<=>',
        'euclidean': '<->',
        'l2': '<->',
        '<->': '<->',
        'inner_product': '<#>',
        '<#>': '<#>',
    }

    # pgvector distance operator -> operator class used when building an index.
    _INDEX_OPERATOR_CLASSES = {
        '<->': 'vector_l2_ops',
        '<#>': 'vector_ip_ops',
        '<=>': 'vector_cosine_ops',
    }

    def __init__(
        self,
        connection_string: str,
        table_name: str,
        num_dimensions: int,
        distance_type: str = 'cosine') -> None:
        """
        Store the configuration; no database connection is opened here.

        Args:
            connection_string: DSN handed to asyncpg when a pool is created.
            table_name: Name of the table holding the vectors (quoted before use in SQL).
            num_dimensions: Dimension of the ``embedding`` column.
            distance_type: One of 'cosine'/'<=>', 'euclidean'/'l2'/'<->',
                or 'inner_product'/'<#>'.

        Raises:
            ValueError: If ``distance_type`` is not one of the accepted spellings.
        """
        self.connection_string = connection_string
        self.table_name = table_name
        self.num_dimensions = num_dimensions

        operator = None
        if isinstance(distance_type, str):
            operator = self._DISTANCE_OPERATORS.get(distance_type)
        if operator is None:
            raise ValueError(f"unrecognized distance_type {distance_type}")
        self.distance_type = operator

    async def __post_init__(
        self,
    ) -> None:
        """
        Test the connection by opening a pool and closing it again.

        This class is not a dataclass, so Python does not call this hook
        automatically; it has to be awaited explicitly.
        """
        # Use the pool as a context manager so the test connection is released
        # instead of being left open.
        async with self.connect():
            pass

    def _quote_ident(self, ident):
        """Return ``ident`` wrapped in double quotes, with embedded double quotes doubled."""
        return '"{}"'.format(ident.replace('"', '""'))

    def connect(self):
        """
        Create a connection pool for a PostgreSQL database using asyncpg.

        Each connection in the pool is initialised with pgvector's
        ``register_vector``.

        Returns:
            The object returned by ``asyncpg.create_pool``; the methods of this
            class use it as an async context manager
            (``async with self.connect() as pool``).
        """
        async def init(conn):
            await register_vector(conn)
        return asyncpg.create_pool(dsn=self.connection_string, init=init)

    def _get_row_exists_query(self):
        """Return SQL that selects a single row if the table has any rows."""
        return "SELECT 1 FROM {table_name} LIMIT 1".format(table_name=self._quote_ident(self.table_name))

    async def table_is_empty(self):
        """Return True when the table contains no rows."""
        query = self._get_row_exists_query()
        async with self.connect() as pool:
            rec = await pool.fetchrow(query)
            return rec is None

    def get_upsert_query(self):
        """
        Return the parameterised INSERT statement used by `upsert`.

        Parameters $1..$4 are id, metadata, contents and embedding. Because of
        ``ON CONFLICT DO NOTHING``, a conflicting row is left as it is.
        """
        return "INSERT INTO {table_name} (id, metadata, contents, embedding) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING".format(table_name=self._quote_ident(self.table_name))

    async def upsert(self, records):
        """
        Insert ``records`` into the table.

        Args:
            records: Sequence of ``(id, metadata, contents, embedding)`` tuples,
                in the parameter order of `get_upsert_query`.
        """
        query = self.get_upsert_query()
        async with self.connect() as pool:
            await pool.executemany(query, records)

    def get_create_query(self):
        """
        Return the SQL script that creates the vector extension, the table and
        a GIN index on ``metadata`` (named ``<table_name>_meta_idx``).
        """
        return "CREATE EXTENSION IF NOT EXISTS vector;" + "\n" + \
        '''
CREATE TABLE IF NOT EXISTS {table_name} (
    id UUID PRIMARY KEY,
    metadata JSONB,
    contents TEXT,
    embedding VECTOR({dimensions})
);


CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} USING GIN(metadata jsonb_path_ops);
'''.format(table_name=self._quote_ident(self.table_name), index_name=self._quote_ident(self.table_name+"_meta_idx"), dimensions=self.num_dimensions)

    async def create_tables(self):
        """Run the script from `get_create_query` against the database."""
        query = self.get_create_query()
        async with self.connect() as pool:
            await pool.execute(query)

    async def _get_approx_count(self):
        """Return the number of rows in the table (currently an exact COUNT(*))."""
        #todo optimize with approx
        query = "SELECT COUNT(*) as cnt FROM {table_name}".format(table_name=self._quote_ident(self.table_name))
        async with self.connect() as pool:
            rec = await pool.fetchrow(query)
            return rec[0]

    def create_ivfflat_index_query(self, num_records):
        """
        Return the CREATE INDEX statement for an ivfflat index on ``embedding``.

        The number of lists is ``num_records / 1000`` with a minimum of 10, or
        ``sqrt(num_records)`` above one million records, truncated to an
        integer in both cases.

        Args:
            num_records: Number of rows the index is sized for.

        Raises:
            ValueError: If ``self.distance_type`` has no matching operator class.
        """
        column_name = "embedding"

        index_method = self._INDEX_OPERATOR_CLASSES.get(self.distance_type)
        if index_method is None:
            raise ValueError(f"unrecognized operator {self.distance_type}")

        # ``lists`` has to be rendered as an integer, so truncate the heuristic.
        if num_records > 1000000:
            num_lists = int(math.sqrt(num_records))
        else:
            num_lists = max(int(num_records / 1000), 10)

        return "CREATE INDEX ON {table_name} USING ivfflat ({column_name} {index_method}) WITH (lists = {num_lists});"\
        .format(table_name=self._quote_ident(self.table_name), column_name=self._quote_ident(column_name), index_method=index_method, num_lists=num_lists)

    async def create_ivfflat_index(self, num_records=None):
        """
        Create an ivfflat index on the ``embedding`` column.

        Args:
            num_records: Row count used to size the index; when None the table
                is counted first.
        """
        if num_records is None:
            num_records = await self._get_approx_count()
        query = self.create_ivfflat_index_query(num_records)
        async with self.connect() as pool:
            await pool.execute(query)

    def get_similarity_query(self, query_embedding: List[float], k: int=10, filter: Optional[dict] = None):
        """
        Build the nearest-neighbour query and its positional parameters.

        Args:
            query_embedding: Vector to compare against; bound as ``$1``.
            k: Maximum number of rows. It is formatted directly into the SQL
                text, so pass a trusted integer.
            filter: Optional dict; when given it is JSON-encoded, bound as
                ``$2`` and matched with ``metadata @> $2``.

        Returns:
            Tuple ``(query, params)`` where ``params`` is a list in ``$n`` order.
        """
        params = [query_embedding]
        distance = "embedding {op} ${index}".format(op=self.distance_type, index=len(params))

        where = "TRUE"
        if filter is not None:
            params.append(json.dumps(filter))
            where = "metadata @> ${index}".format(index=len(params))
        query = '''
        SELECT
            id, metadata, contents, embedding, {distance}
        FROM
           {table_name}
        WHERE 
           {where}
        ORDER BY {distance} ASC
        LIMIT {k}
        '''.format(distance=distance, where=where, table_name=self._quote_ident(self.table_name), k=k)
        return (query, params)

    async def get_similarity(self, query_embedding: List[float], k: int=10, filter: Optional[dict] = None):
        """
        Run the query from `get_similarity_query` and return the fetched rows.

        Rows are ordered by ascending distance to ``query_embedding`` and hold
        id, metadata, contents, embedding and the distance expression.
        """
        (query, params) = self.get_similarity_query(query_embedding, k, filter)
        async with self.connect() as pool:
            return await pool.fetch(query, *params)

In [66]:
con = await asyncpg.connect(connection_string)
await con.execute("DROP TABLE IF EXISTS data_table;")
await con.close()

vec  = Vector(connection_string, "data_table", 2)
await vec.create_tables()
empty = await vec.table_is_empty()
assert empty
await vec.upsert([(uuid.uuid4(), '''{"key":"val"}''', "the brown fox", [1.0,1.2])])
empty = await vec.table_is_empty()
assert not empty

await vec.upsert([\
    (uuid.uuid4(), '''{"key":"val"}''', "the brown fox", [1.0,1.3]),\
    (uuid.uuid4(), '''{"key":"val2"}''', "the brown fox", [1.0,1.4]),\
    (uuid.uuid4(), '''{"key2":"val"}''', "the brown fox", [1.0,1.5]),\
    (uuid.uuid4(), '''{"key2":"val"}''', "the brown fox", [1.0,1.6]),\
    (uuid.uuid4(), '''{"key2":"val"}''', "the brown fox", [1.0,1.6]),\
    (uuid.uuid4(), '''{"key2":"val2"}''', "the brown fox", [1.0,1.7]),\
    (uuid.uuid4(), '''{"key2":"val"}''', "the brown fox", [1.0,1.8]),\
    (uuid.uuid4(), '''{"key2":"val"}''', "the brown fox", [1.0,1.9]),\
    (uuid.uuid4(), '''{"key2":"val"}''', "the brown fox", [1.0,100.8]),\
    (uuid.uuid4(), '''{"key2":"val"}''', "the brown fox", [1.0,101.8]),\
    (uuid.uuid4(), '''{"key2":"val"}''', "the brown fox", [1.0,1.8]),\
    (uuid.uuid4(), '''{"key2":"val"}''', "the brown fox", [1.0,1.8]),\
    (uuid.uuid4(), '''{"key_1":"val_1", "key_2":"val_2"}''', "the brown fox", [1.0,1.8]),\
])

await vec.create_ivfflat_index()

rec = await vec.get_similarity([1.0, 2.0])
assert len(rec) == 10
rec = await vec.get_similarity([1.0, 2.0], k=4)
assert len(rec) == 4
rec = await vec.get_similarity([1.0, 2.0], k=4, filter={"key2":"val2"})
assert len(rec) == 1
rec = await vec.get_similarity([1.0, 2.0], k=4, filter={"key2":"does not exist"})
assert len(rec) == 0
rec = await vec.get_similarity([1.0, 2.0], k=4, filter={"key_1":"val_1"})
assert len(rec) == 1
rec = await vec.get_similarity([1.0, 2.0], filter={"key_1":"val_1", "key_2":"val_2"})
assert len(rec) == 1
rec = await vec.get_similarity([1.0, 2.0], k=4, filter={"key_1":"val_1", "key_2":"val_3"})
assert len(rec) == 0