In [None]:
from pgml import Database
import os
import json

In [None]:
# Fallback connection string pointing at a local development database;
# the PGML_CONNECTION environment variable takes precedence when it is set.
local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"

conninfo = os.getenv("PGML_CONNECTION", local_pgml)
db = Database(conninfo, min_connections=4)

In [None]:
# Name of the collection used throughout this notebook; the same name is
# passed to db.delete_collection in the final cell.
collection_name = "test_pgml_sdk_1"
# Obtain the collection handle (created if missing, going by the method name).
collection = db.create_or_get_collection(collection_name)

In [None]:
from datasets import load_dataset

# Load the SQuAD training split and convert it to a pandas DataFrame.
data = load_dataset("squad", split="train")
data = data.to_pandas()
# Only the last expression of a cell is rendered, so intermediate previews
# have to be printed explicitly to show up in the output.
print(data.head())

# Keep one row per distinct `context` value (drop_duplicates keeps the first
# occurrence by default), then report how many rows remain.
data = data.drop_duplicates(subset=["context"])
print(len(data))
print(data.head())

# Build the records handed to collection.upsert_documents: the passage text
# plus its title as metadata. Iterating over just the two needed columns
# avoids materialising every column of every row as a dict.
documents = [
    {
        'text': context,
        'metadata': {
            'title': title
        }
    } for context, title in zip(data['context'], data['title'])
]
# Last expression of the cell: preview the first three documents.
documents[:3]

In [None]:
# Ingest only the first 200 documents, then chunk and embed them.
# No splitter_id / model_id is passed here, so whatever defaults the
# collection uses apply (presumably id 1 for each -- TODO confirm).
collection.upsert_documents(documents[:200])
collection.generate_chunks()
collection.generate_embeddings()

In [None]:
# Vector search without explicit model/splitter ids; request the top 2
# matches and pretty-print them as JSON.
results = collection.vector_search(
    "Who won 20 Grammy awards?",
    top_k=2,
)
print(json.dumps(results, indent=2))

In [None]:
# Register the paraphrase-MiniLM-L6-v2 model with the collection so that
# embeddings can be generated with it in the cells below.
collection.register_model(model_name="paraphrase-MiniLM-L6-v2")

In [None]:
# List the models registered for this collection (shown as the cell output).
collection.get_models()

In [None]:
# Pretty-print the collection's registered models as indented JSON.
print(json.dumps(collection.get_models(), indent=2))

In [None]:
# Generate embeddings with model id 2.
# NOTE(review): the id is hard-coded; presumably it is the id assigned to
# paraphrase-MiniLM-L6-v2 registered above -- verify against get_models().
collection.generate_embeddings(model_id=2)

In [None]:
# Repeat the same query, this time searching with model id 2.
results = collection.vector_search(
    "Who won 20 Grammy awards?",
    top_k=2,
    model_id=2,
)
print(json.dumps(results, indent=2))

In [None]:
# Register hkunlp/instructor-xl; the instruction string is supplied through
# model_params and describes the documents being embedded.
collection.register_model(
    model_name="hkunlp/instructor-xl",
    model_params={
        "instruction": "Represent the Wikipedia document for retrieval: ",
    },
)

In [None]:
# List the registered models again, after the hkunlp/instructor-xl registration.
collection.get_models()

In [None]:
# Generate embeddings with model id 3.
# NOTE(review): the id is hard-coded; presumably it is the id assigned to
# hkunlp/instructor-xl registered above -- verify against get_models().
collection.generate_embeddings(model_id=3)

In [None]:
# Search with model id 3. The query gets its own instruction string via
# query_parameters, distinct from the document-side instruction used when
# the model was registered.
results = collection.vector_search(
    "Who won 20 Grammy awards?",
    top_k=2,
    model_id=3,
    query_parameters={
        "instruction": "Represent the Wikipedia question for retrieving supporting documents: ",
    },
)
print(json.dumps(results, indent=2))

In [None]:
# Register a RecursiveCharacterTextSplitter configured with chunk_size 100
# and chunk_overlap 20.
collection.register_text_splitter(
    splitter_name="RecursiveCharacterTextSplitter",
    splitter_params={"chunk_size": 100, "chunk_overlap": 20},
)

In [None]:
# Generate chunks with splitter id 2.
# NOTE(review): the id is hard-coded; presumably it is the id assigned to the
# RecursiveCharacterTextSplitter registered in the previous cell -- confirm.
collection.generate_chunks(splitter_id=2)

In [None]:
# Generate embeddings for splitter id 2; no model_id is passed, so the
# model selection is left to the method's default.
collection.generate_embeddings(splitter_id=2)

In [None]:
# Run the same query against splitter id 2 and pretty-print the top 2 hits.
results = collection.vector_search(
    "Who won 20 Grammy awards?",
    top_k=2,
    splitter_id=2,
)
print(json.dumps(results, indent=2))

In [None]:
# Clean up: delete the collection named by collection_name (the one this
# notebook obtained via db.create_or_get_collection).
db.delete_collection(collection_name)