In [23]:
import sqlalchemy
from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, String, text, Index
from datetime import datetime, timezone
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session

# Connection settings for the PostgreSQL instance; replace with your own details.
# NOTE(review): the credentials are hardcoded in the notebook. That is tolerable
# for a throwaway local database, but they should be read from the environment
# before this is pointed at anything shared.
config = {
    'user': 'spotify_playlist_generator',
    'password': 'spotify_playlist_generator',
    'host': 'localhost',
    'port': '5432',
    'database': 'spotify_playlist_generator',
}

# Every part of the URL, including the port, comes from `config` so that
# editing the dictionary above is enough to retarget the connection.
DATABASE_URL: str = (
    f"postgresql+psycopg2://{config['user']}:{config['password']}"
    f"@{config['host']}:{config['port']}/{config['database']}"
)

# Time zone of the machine running the notebook: take the current UTC time,
# convert it to local time and keep only the attached zone object.
local_timezone: timezone = datetime.now(timezone.utc).astimezone().tzinfo

# Engine shared by all sessions; `pool_pre_ping` checks that a pooled
# connection is still alive before handing it out.
engine: sqlalchemy.engine.Engine = create_engine(DATABASE_URL, pool_pre_ping=True)

# Declarative base class that the ORM models below inherit from.
Base = declarative_base()
class SongEmbedding(Base):
    """ORM mapping for the ``song_embedding`` table.

    Each row pairs a string identifier with a 128-dimensional pgvector
    embedding.
    """

    __tablename__ = "song_embedding"

    # The primary-key constraint already makes `id` NOT NULL and backs it with
    # a unique index, so no additional index is declared on the same column
    # (extra ones only add write overhead, and a generic index name such as
    # "idx_id" can clash with other tables in the same schema).
    id: str = Column(String, primary_key=True, nullable=False)

    # Fixed-size vector column; queries can order rows by distance to it.
    embedding: list[float | int] = Column(Vector(128))

# The `vector` extension has to be present before a table with a Vector column
# can be created. `engine.begin()` commits when the block exits cleanly.
with engine.begin() as connection:
    connection.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))

# Create every table registered on `Base` that does not exist yet.
Base.metadata.create_all(engine)

In [64]:
from sqlalchemy import select
from sqlalchemy.orm import Session

from src.db.schemas.song_embedding import SongEmbedding
from src.db.tables.embeddings import SongEmbedding as SongEmbeddingSQL


def add_object(db: Session, obj: SongEmbeddingSQL) -> SongEmbeddingSQL:
    """Persist a single ORM object and return it refreshed from the database.

    Args:
        db: Open session used for the insert.
        obj: ORM instance to add.

    Returns:
        The same instance, reloaded after the commit.

    Raises:
        Exception: Whatever the commit raised (for example a duplicate primary
            key). The session is rolled back first so it stays usable for
            subsequent calls.
    """
    db.add(obj)
    try:
        db.commit()
    except Exception:
        # Without a rollback the session stays in a failed state and every
        # later operation on it would raise as well.
        db.rollback()
        raise
    db.refresh(obj)
    return obj


def insert_embeddings(db: Session, songs_embedding: SongEmbedding | list[SongEmbedding]) -> None:
    """Insert one embedding, or a list of embeddings, into the database.

    Each item is converted to its ORM counterpart via ``model_dump()`` and
    stored through ``add_object``, so every row is committed individually.

    Args:
        db: Open session used for the inserts.
        songs_embedding: A single ``SongEmbedding`` or a list of them. Any
            other type is reported on stdout and ignored.
    """
    if not isinstance(songs_embedding, (list, SongEmbedding)):
        print("insert_embeddings - songs_embedding is not of type SongEmbedding or list[SongEmbedding]")
        return

    # Normalise to a list so both input shapes share one code path.
    batch = [songs_embedding] if isinstance(songs_embedding, SongEmbedding) else songs_embedding

    for song_embedding in batch:
        # Dump the individual item, not the enclosing list.
        add_object(db, SongEmbeddingSQL(**song_embedding.model_dump()))


def get_embeddings(db: Session, songs_id: str | list[str] | SongEmbedding | None = None) -> None | SongEmbeddingSQL | list[SongEmbeddingSQL]:
    """Fetch stored embeddings by id.

    Args:
        db: Open session used for the query.
        songs_id: Selects what to fetch:
            * ``None`` (default): every row in the table.
            * ``str``: the row with that id.
            * ``SongEmbedding``: the row whose id equals ``songs_id.id``.
            * ``list``: all rows whose id is in the list (an empty list yields
              an empty result rather than the whole table).

    Returns:
        A list of rows for ``None`` or list input, a single row or ``None``
        for ``str`` / ``SongEmbedding`` input, and ``None`` (after printing a
        message) for any unsupported type.
    """
    # Only an omitted argument means "no filter"; falsy values such as [] or ""
    # must not silently widen the query to the full table.
    if songs_id is None:
        return db.query(SongEmbeddingSQL).all()

    if not isinstance(songs_id, (list, str, SongEmbedding)):
        print("get_embeddings - songs_id is not of type str, list[str] or SongEmbedding")
        return None

    if isinstance(songs_id, str):
        return db.query(SongEmbeddingSQL).filter(SongEmbeddingSQL.id == songs_id).first()
    if isinstance(songs_id, SongEmbedding):
        return db.query(SongEmbeddingSQL).filter(SongEmbeddingSQL.id == songs_id.id).first()

    if not songs_id:
        # Nothing requested: skip the round trip for an empty IN () clause.
        return []
    return db.query(SongEmbeddingSQL).filter(SongEmbeddingSQL.id.in_(songs_id)).all()


def _update_single_embedding(db: Session, songs_embedding: SongEmbedding) -> None:
    """Overwrite the stored embedding of one existing song.

    Failures are reported on stdout instead of being raised, so a batch caller
    can continue with the remaining items.

    Args:
        db: Open session used for the lookup and the update.
        songs_embedding: Carries the id of the row to update and the new
            embedding value.
    """
    if not isinstance(songs_embedding, SongEmbedding):
        print("update_single_embedding - songs_embedding is not of type SongEmbedding")
        # Stop here: the attribute accesses below would fail on a foreign type.
        return

    db_song: SongEmbeddingSQL | None = get_embeddings(db, songs_embedding.id)
    if not db_song:
        print(f"Song {songs_embedding.id} not present in the db")
        return
    try:
        # The new value comes from the argument, never from notebook globals.
        db.query(SongEmbeddingSQL).filter(SongEmbeddingSQL.id == songs_embedding.id).update(
            {SongEmbeddingSQL.embedding.name: songs_embedding.embedding}
        )
        # Commit so the change survives the session, as add_object does for inserts.
        db.commit()
    except Exception as e:
        # Roll back so the session stays usable for the next item.
        db.rollback()
        print(f"update_single_embedding - failed to update song {songs_embedding.id}: {e}")



def update_embeddings(db: Session, songs_embedding: SongEmbedding | list[SongEmbedding]) -> None:
    """Update the stored embedding for one song or for each song in a list.

    Args:
        db: Open session forwarded to the per-song update helper.
        songs_embedding: A single ``SongEmbedding`` or a list of them. Any
            other type is reported on stdout and nothing is updated.
    """
    # Single item: delegate directly.
    if isinstance(songs_embedding, SongEmbedding):
        return _update_single_embedding(db, songs_embedding)

    # List: update the entries one after another, in order.
    if isinstance(songs_embedding, list):
        for entry in songs_embedding:
            _update_single_embedding(db, entry)
        return None

    print("update_embeddings - songs_embedding is not of type SongEmbedding or list[SongEmbedding]")
    return None


def get_closest_embedding(
        db: Session, songs_embedding: SongEmbedding | list[SongEmbedding], k: int = 3
) -> list[SongEmbeddingSQL]:
    """Return the stored rows closest (by cosine distance) to the given embeddings.

    Args:
        db: Open session used for the queries.
        songs_embedding: A single ``SongEmbedding`` or a list of them to use as
            query vectors.
        k: Maximum number of neighbours fetched per query vector.

    Returns:
        A list of rows. For list input the per-query results are merged in
        query order and duplicates are dropped, so the result always matches
        the annotated ``list`` type. An unsupported input type is reported on
        stdout and yields an empty list.
    """
    if not isinstance(songs_embedding, (list, SongEmbedding)):
        print("songs_embedding is not of type SongEmbedding or list[SongEmbedding]")
        return []

    # Normalise to a list so both input shapes share one query path.
    query_vectors = [songs_embedding] if isinstance(songs_embedding, SongEmbedding) else songs_embedding

    # A dict keeps first-seen order while discarding rows returned by more
    # than one query (the same de-duplication a set would give, but ordered).
    nearest_neighbors_songs: dict[SongEmbeddingSQL, None] = {}
    for song_embedding in query_vectors:
        neighbours = (
            db.query(SongEmbeddingSQL)
            .order_by(SongEmbeddingSQL.embedding.cosine_distance(song_embedding.embedding))
            .limit(k)
            .all()
        )
        nearest_neighbors_songs.update(dict.fromkeys(neighbours))

    return list(nearest_neighbors_songs)


In [53]:
import numpy as np
import uuid
# Seed the table with five rows, each a fresh UUID id paired with 128 values
# drawn from a standard normal distribution.
# NOTE(review): `SongEmbedding` here is presumably the schema class imported
# from src.db.schemas.song_embedding, which shadows the ORM class of the same
# name defined in the first cell — confirm against the actual run order.
# NOTE(review): `embedding`, `i` and `session` stay bound in the notebook
# namespace after this cell runs and are visible to other cells.
with Session(engine) as session:
    for i in range(5):
        embedding = SongEmbedding(id=str(uuid.uuid4()), embedding=np.random.randn(128))
        insert_embeddings(session, embedding)

In [67]:
# Build three random 128-value query embeddings (never inserted into the
# table) and look up the k=3 nearest stored rows for each of them.
# NOTE(review): no random seed is set, so the result differs on every run.
with Session(engine) as session:
    embeddings = [SongEmbedding(id=str(uuid.uuid4()), embedding=np.random.randn(128)) for i in range(3)]
    closest_embedding = get_closest_embedding(session, embeddings, k=3)
