In [None]:
# Install the packages this notebook imports: SQLAlchemy (ORM), psycopg2 (the driver named in the connection URLs) and pgvector (Vector column type).
%pip install sqlalchemy psycopg2 pgvector


In [None]:
import os
from urllib.parse import quote_plus

import numpy as np
import sqlalchemy
import torch
from pgvector.sqlalchemy import Vector
from sqlalchemy.orm import Mapped, mapped_column, Session, DeclarativeBase, sessionmaker

In [None]:
# Connection settings.
# URL shape: postgresql+psycopg2://username:password@databasehost:port/databasename
# Every part except the data database name can be overridden through an
# environment variable, so real credentials never have to be written into the
# notebook. The fallbacks reproduce the local test setup.
dbname = 'test_ormalchemy'
# quote_plus keeps special characters in user/password from breaking the URL.
_pguser = quote_plus(os.environ.get('ORMALCHEMY_PG_USER', 'root'))
_pgpassword = quote_plus(os.environ.get('ORMALCHEMY_PG_PASSWORD', 'root'))
_pghost = os.environ.get('ORMALCHEMY_PG_HOST', 'localhost')
_pgport = os.environ.get('ORMALCHEMY_PG_PORT', '54322')
# Database used for administrative statements such as CREATE DATABASE.
_pgadmindb = os.environ.get('ORMALCHEMY_PG_ADMIN_DB', 'root')
_pgserverurl = f'postgresql+psycopg2://{_pguser}:{_pgpassword}@{_pghost}:{_pgport}'
pgdburl = f'{_pgserverurl}/{dbname}'
pgrootdburl = f'{_pgserverurl}/{_pgadmindb}'


In [None]:
# Engine for the administrative database (used to create the data database).
# AUTOCOMMIT makes every statement take effect immediately instead of inside a transaction.
rootengine = sqlalchemy.create_engine(
    pgrootdburl,
    echo=False,
    isolation_level='AUTOCOMMIT',
)
# Engine for the data database that holds the ORM tables; same settings.
dataengine = sqlalchemy.create_engine(
    pgdburl,
    echo=False,
    isolation_level='AUTOCOMMIT',
)

# Session factory bound to the data engine, kept in the same scope as the engine.
# NOTE(review): this rebinds the name `Session` imported from sqlalchemy.orm;
# from here on `Session()` creates a session from this factory.
Session = sessionmaker(bind=dataengine)

In [None]:
# ORM type definitions
class Base(DeclarativeBase):
    """Declarative base class; its metadata registers the ORM tables declared in this notebook."""

class TensorItem(Base):
    """ORM row of the 'tensors' table: a caller-supplied integer key plus a 5-dimensional pgvector embedding."""
    __tablename__ = 'tensors'
    # BIGINT primary key; autoincrement is disabled, so callers must supply the key themselves.
    key: Mapped[int] = mapped_column(type_=sqlalchemy.BigInteger, primary_key=True, autoincrement=False)
    # pgvector column with a fixed width of 5 components.
    embedding: Mapped[Vector] = mapped_column(Vector(5))

    def __repr__(self) -> str:
        """Return a three-line summary: key, embedding value and the Python type of the embedding.

        The lines are joined explicitly so the result carries no source-code
        indentation and no trailing whitespace.
        """
        return '\n'.join((
            f'key: {self.key}',
            f'embedding: {self.embedding}',
            f'embedding type: {type(self.embedding)}',
        ))

In [None]:

def init_database():
    """Create the database `dbname` if the server does not have it yet.

    Returns:
        True if the database was missing and has just been created,
        False if it already existed.
    """
    with rootengine.connect() as rootconnection:
        # Left disabled: would drop the database before the existence check.
        # rootconnection.execute(sqlalchemy.text(f'DROP DATABASE "{dbname}";'))

        # Look the name up with a bound parameter instead of interpolating it
        # into the SQL literal, so quotes in the name cannot break the query.
        lookup = sqlalchemy.text('SELECT 1 FROM pg_database WHERE datname = :name;')
        rows = rootconnection.execute(lookup, {'name': dbname})
        if rows.first() is not None:
            return False

        print(f"Database '{dbname}' does not exist and is beeing created.")
        # Identifiers cannot be bound as parameters; escape embedded double
        # quotes so the name stays a single quoted identifier.
        quoted_name = dbname.replace('"', '""')
        rootconnection.execute(sqlalchemy.text(f'CREATE DATABASE "{quoted_name}";'))
    return True

def init_tables():
    """Create every table registered on `Base.metadata` in the data database.

    The pgvector extension is enabled first: a newly created database does not
    necessarily have it, and without it the `vector` column type used by the
    ORM tables is unknown to the server. `IF NOT EXISTS` makes the statement a
    no-op when the extension is already present.

    Returns:
        True, always.
    """
    with dataengine.connect() as dataconnection:
        print(f"Creating tables for empty database '{dbname}'.")
        # enable the vector type before any table that uses it is created
        dataconnection.execute(sqlalchemy.text('CREATE EXTENSION IF NOT EXISTS vector;'))
        # create tables
        Base.metadata.create_all(dataconnection)
    return True

def table_size():
    """Return the number of rows currently stored in the table mapped by `TensorItem`."""
    with Session() as datasession:
        # SELECT count(*) FROM <TensorItem table>
        count_query = sqlalchemy.select(sqlalchemy.func.count()).select_from(TensorItem)
        result = datasession.execute(count_query)
        return result.scalar()

# A database that was just created (and had its tables set up) starts out
# empty; an existing one is asked for its current row count.
# NOTE(review): if the database exists but its tables do not, table_size()
# will presumably fail because init_tables() is skipped - confirm.
if init_database() and init_tables():
    nrows = 0
else:
    nrows = table_size()
print(f'Database {dbname} #rows: {nrows}.')

In [None]:
# Random demo data: 10,000 vectors with 5 components each, matching the
# Vector(5) width of TensorItem.embedding.
a = torch.rand(10_000, 5)
print(a.shape)

In [None]:
# Wrap every row of `a` in an ORM object, using the row index as primary key.
items = [
    TensorItem(key=row_index, embedding=row)
    for row_index, row in enumerate(a)
]
print(len(items))

In [None]:
# Insert the generated rows only into an empty table. Their keys are fixed
# (0 .. len(items) - 1), so inserting them again when the notebook is re-run
# against an already populated database would violate the primary key.
existing_rows = table_size()
if existing_rows:
    print(f"Table '{TensorItem.__tablename__}' already holds {existing_rows} rows; skipping insert.")
else:
    with Session() as datasession:
        with datasession.begin():
            datasession.add_all(items)

In [None]:
# Retrieve a few rows by primary key and combine their embeddings into one
# float32 tensor. Keys that are not stored simply yield no row.
lookup_keys = [1, 2, 7, 8, 12, 241231]
stmt = sqlalchemy.select(TensorItem).where(TensorItem.key.in_(lookup_keys))

with Session() as datasession:
    res = datasession.scalars(stmt)
    # read the embedding values while the session is still open
    embeddings = [item.embedding for item in res]

arr = np.array(embeddings)
tensors = torch.tensor(arr, dtype=torch.float32)
print(tensors.shape)

In [None]:
# Release the connection pools held by both engines.
for engine in (rootengine, dataengine):
    engine.dispose()

In [None]:
# Notes on full-text search with BM25 (example queries using the `@@@` operator)

# SELECT *
# FROM my_table
# WHERE my_table @@@ '"my query string"'

# SELECT *
# FROM my_table
# WHERE my_table @@@ 'description:keyboard^2 OR electronics:::fuzzy_fields=description&distance=2'