# Postgresql vector store and retrieve

In [1]:
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, concat, lit
from pyspark.sql.types import ArrayType, FloatType, IntegerType, StringType
from pyspark.sql.types import StructType, StructField, DataType
from sentence_transformers import SentenceTransformer
from duckduckgo_search import DDGS
from pgvector.sqlalchemy import Vector
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import declarative_base
from sqlalchemy.dialects.postgresql import ARRAY
import numpy as np


## Initialize Spark Session

In [2]:
spark = SparkSession.builder \
    .appName('SparkPostgreSQL') \
    .config('spark.jars', '/opt/spark_hadoop_3/jars/postgresql-42.6.0.jar') \
    .getOrCreate()


24/11/07 12:21:43 WARN Utils: Your hostname, FarmRaider2 resolves to a loopback address: 127.0.1.1; using 192.168.2.85 instead (on interface enp3s0)
24/11/07 12:21:43 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/11/07 12:21:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


## Database connection details

In [3]:
db_properties = {
    'url': 'jdbc:postgresql://localhost:5432/penman',
    'user': 'nuwan',
    'password': 'kotchchi7281',
    'driver': 'org.postgresql.Driver'
}

## SQLAlchemy setup for pgvector

In [5]:
Base = declarative_base()
engine = create_engine(f"postgresql://{db_properties['user']}:{db_properties['password']}@localhost:5432/penman")

class SearchResult(Base):
    __tablename__ = 'search_results'

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    body = Column(String)
    url = Column(String)
    embedding = Column(Vector(384))  # Dimension for 'all-MiniLM-L6-v2' model

Base.metadata.create_all(engine)

## Initialize the embedding model

In [4]:
model = SentenceTransformer('all-MiniLM-L6-v2')

## UDF for creating embeddings

In [37]:
from pyspark.sql.functions import explode

@udf(returnType=ArrayType(FloatType()))
def create_embedding(text):
    return model.encode(text).tolist()

def search_and_store(query: str, num_results: int = 5):
    ddgs = DDGS()
    results = list(ddgs.text(query, max_results=num_results))
    
    # Create DataFrame from search results
    df = spark.createDataFrame(results)
    df = df.withColumnRenamed('href','url')
    
    # Create embeddings
#     df = df.withColumn('text_to_embed', col('title') + ' ' + col('body'))
    df = df.withColumn('text_to_embed', concat(col('title'),lit(' '),col('body')))
    df = df.withColumn('embedding', create_embedding('text_to_embed'))
    df = df.drop('text_to_embed')

    # Write to PostgreSQL
    df.write.jdbc(url=db_properties['url'], table='search_results', mode='append', properties=db_properties)
    return df

def get_similar_results(query: str, limit: int = 5):
    query_embedding = model.encode(query).tolist()

    ''' create schema to convert pgsql vector to array'''
#         StructField('embedding',  DataType.createArrayType(StringType()), True),

    customSchema = StructType([
        StructField('id', IntegerType(), True),
        StructField('title', StringType(), True),
        StructField('body', StringType(), True),
        StructField('url', StringType(), True),
        StructField('embedding', StringType(), True), #ArrayType(FloatType()), True),
    ])
    # Read from PostgreSQL
#     df = spark.read.jdbc(
#         url=db_properties['url'], 
#         table='search_results', 
#         properties=db_properties,
#         options = {'schema':customSchema}
#     )
# .option('schema', customSchema) \\
    df = spark.read.format('jdbc') \
                .option('url', db_properties['url']) \
                .option('dbtable', 'search_results') \
                .option('user', db_properties['user']) \
                .option('password', db_properties['password']) \
                .option('numPartitions', '1') \
                .option('schema', customSchema) \
                .load()
    print("\nraw  from DB")
    df.printSchema()

    @udf(returnType=ArrayType(FloatType()))
    def conv(col):
        return udf(col, ArrayType(FloatType()))
    df = df.withColumn('embedding', conv('embedding'))
    print("\nafter conv")
    df.printSchema()

    # emb_data = [x['embedding'] for x in df.select(col('embedding')).collect()]
    # print("emb_data\n",type(emb_data),emb_data)
    
    # UDF for cosine similarity
    @udf(returnType=FloatType())
    def cosine_similarity(vec1, vec2):
        print(type(vec1),vec1)
        print(type(vec2),vec2)
#         return 0.0
        return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))

    # # Calculate similarities and sort
    # df = df.withColumn('similarity', cosine_similarity(emb_data, query_embedding))
    # #     df = df.withColumn('similarity', cosine_similarity(col('embedding'), query_embedding))
    # results = df.orderBy(col('similarity').desc()).limit(limit).collect()
    
    # return results

## Example usage

In [15]:
# if __name__ == '__main__':
# Search and store results
search_query = 'artificial intelligence'
df=search_and_store(search_query)
print("stored %d records" % df.count())

                                                                                

## Query similar items

In [38]:
query = 'machine learning applications'
print(f"Similar results for '{query}':")
similar_results = get_similar_results(query)
for result in similar_results:
    print(f"Title: {result.title}")
    print(f"URL: {result.url}")
    print(f"Snippet: {result.body[:100]}...")
    print(f"Similarity: {result.similarity}")
    print()


Similar results for 'machine learning applications':

raw  from DB
root
 |-- id: integer (nullable = true)
 |-- title: string (nullable = true)
 |-- body: string (nullable = true)
 |-- url: string (nullable = true)
 |-- embedding: string (nullable = true)


after conv
root
 |-- id: integer (nullable = true)
 |-- title: string (nullable = true)
 |-- body: string (nullable = true)
 |-- url: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)



24/11/07 12:40:54 ERROR PythonUDFRunner: Python worker exited unexpectedly (crashed)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/spark_hadoop_3/python/lib/pyspark.zip/pyspark/worker.py", line 810, in main
    eval_type = read_int(infile)
  File "/opt/spark_hadoop_3/python/lib/pyspark.zip/pyspark/serializers.py", line 596, in read_int
    raise EOFError
EOFError

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:561)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$2.read(PythonUDFRunner.scala:94)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$2.read(PythonUDFRunner.scala:75)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:514)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$

Py4JJavaError: An error occurred while calling o484.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 2.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2.0 (TID 2) (192.168.2.85 executor driver): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for functools.partial). This happens when an unsupported/unregistered class is being unpickled that requires construction arguments. Fix it by registering a custom IObjectConstructor for this class.
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:759)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:199)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:109)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:122)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.$anonfun$evaluate$6(BatchEvalPythonExec.scala:95)
	at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:1211)
	at scala.collection.Iterator$GroupedIterator.hasNext(Iterator.scala:1217)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:320)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.writeIteratorToStream(PythonUDFRunner.scala:57)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:440)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2088)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:274)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2785)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2721)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2720)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2720)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1206)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1206)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1206)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2984)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2923)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2912)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:971)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2263)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2284)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2303)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2328)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1019)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:405)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1018)
	at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:448)
	at org.apache.spark.sql.Dataset.$anonfun$collectToPython$1(Dataset.scala:3997)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4167)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:526)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4165)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:118)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:195)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:103)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:827)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:65)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4165)
	at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:3994)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:78)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:567)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:831)
Caused by: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for functools.partial). This happens when an unsupported/unregistered class is being unpickled that requires construction arguments. Fix it by registering a custom IObjectConstructor for this class.
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:759)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:199)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:109)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:122)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.$anonfun$evaluate$6(BatchEvalPythonExec.scala:95)
	at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:1211)
	at scala.collection.Iterator$GroupedIterator.hasNext(Iterator.scala:1217)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:320)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.writeIteratorToStream(PythonUDFRunner.scala:57)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:440)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2088)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:274)


## CRUD operations

In [8]:
def create_result(title: str, body: str, url: str):
    text_to_embed = f'{title} {body}'
    embedding = model.encode(text_to_embed).tolist()
    df = spark.createDataFrame([(title, body, url, embedding)], ['title', 'body', 'url', 'embedding'])
    df.write.jdbc(url=db_properties['url'], table='search_results', mode='append', properties=db_properties)

def read_result(id: int):
    df = spark.read.jdbc(url=db_properties['url'], table='search_results', properties=db_properties)
    return df.filter(col('id') == id).collect()

def update_result(id: int, title: str = None, body: str = None, url: str = None):
    df = spark.read.jdbc(url=db_properties['url'], table='search_results', properties=db_properties)
    row = df.filter(col('id') == id).collect()[0]
    
    title = title or row.title
    body = body or row.body
    url = url or row.url
    
    text_to_embed = f'{title} {body}'
    embedding = model.encode(text_to_embed).tolist()
    
    update_df = spark.createDataFrame([(id, title, body, url, embedding)], ['id', 'title', 'body', 'url', 'embedding'])
    update_df.write.jdbc(url=db_properties['url'], table='search_results', mode='overwrite', properties=db_properties)

def delete_result(id: int):
    df = spark.read.jdbc(url=db_properties['url'], table='search_results', properties=db_properties)
    df_filtered = df.filter(col('id') != id)
    df_filtered.write.jdbc(url=db_properties['url'], table='search_results', mode='overwrite', properties=db_properties)

In [None]:
# Example CRUD operations
# create_result('New AI Research', 'Breakthrough in natural language processing', 'https://example.com/ai-research')
# result = read_result(1)
# if result:
#     print('Read result:', result[0])
# update_result(1, title='Updated AI Research')
# delete_result(2)

# spark.stop()