In [1]:
import os

import psycopg2
import torch

# Connection settings. Each value can be overridden through the standard libpq
# environment variables (PGHOST, PGPORT, PGDATABASE, PGUSER, PGPASSWORD); the
# fallbacks are the original development values, so an unconfigured run behaves
# exactly as before. Prefer exporting PGPASSWORD over editing a secret into
# this notebook.
db_host = os.environ.get("PGHOST", "192.168.1.2")
db_port = int(os.environ.get("PGPORT", "55432"))
db_name = os.environ.get("PGDATABASE", "lnc")
db_user = os.environ.get("PGUSER", "postgres")
db_password = os.environ.get("PGPASSWORD", "postgres")

# Establish a connection to the PostgreSQL database
connection = psycopg2.connect(
    host=db_host,
    port=db_port,
    dbname=db_name,
    user=db_user,
    password=db_password
)

In [2]:
# Open a cursor on the connection; the next cell uses it for the version check.
cursor = connection.cursor()

In [3]:
# Ask the server for its version string and keep the single result row.
cursor.execute("SELECT version();")
db_version = cursor.fetchone()

In [4]:
# db_version is the row returned by fetchone(), so it prints as a one-element tuple.
print("Connected to PostgreSQL database. Version:", db_version)

Connected to PostgreSQL database. Version: ('PostgreSQL 15.2 (Debian 15.2-1.pgdg110+1) on x86_64-pc-linux-gnu, compiled by gcc (Debian 10.2.1-6) 10.2.1 20210110, 64-bit',)


In [5]:
import spacy
# Load the spaCy pipeline that process_paragraphs uses to split text into sentences.
nlp = spacy.load("el_core_news_sm")

In [6]:
from sentence_transformers import SentenceTransformer
# Load the sentence-embedding model that process_paragraphs uses to encode each sentence.
model = SentenceTransformer('lighteternal/stsb-xlm-r-greek-transfer')

In [7]:
# Move the model to the GPU when one is available; fall back to the CPU instead
# of raising on machines without CUDA. The next cell prints the chosen device.
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
# Show which device the model currently lives on.
print(model.device)

cuda:0


In [9]:
from tqdm import tqdm
import numpy as np
from sentence_transformers import SentenceTransformer
import spacy

In [10]:
def process_paragraphs(paragraphs):
    """Return one averaged sentence embedding per input paragraph.

    Each paragraph is split into sentences with the module-level ``nlp``
    pipeline, every sentence is encoded with the module-level ``model``, and
    the sentence vectors belonging to a paragraph are averaged.

    Args:
        paragraphs: Iterable of paragraph strings.

    Returns:
        A list with one entry per paragraph, in input order; each entry is the
        mean sentence embedding as a plain list of floats. An empty input
        yields an empty list.
    """
    paragraphs = list(paragraphs)
    if not paragraphs:
        # Nothing to encode; avoid splitting/averaging an empty array.
        return []

    # Split each paragraph into sentences. If the splitter finds no sentence
    # (e.g. an empty string), encode the paragraph text as a whole so that the
    # paragraph still gets a finite vector instead of the mean of an empty slice.
    sentences = []
    for text, doc in zip(paragraphs, nlp.pipe(paragraphs)):
        sents = [sent.text for sent in doc.sents]
        sentences.append(sents or [text])

    # Encode all sentences in a single call, then cut the result back into
    # per-paragraph groups at the cumulative sentence counts.
    sentences_flat = [sent for sents in sentences for sent in sents]
    embeddings = np.asarray(model.encode(sentences_flat))
    boundaries = np.cumsum([len(sents) for sents in sentences[:-1]], dtype=int)
    groups = np.split(embeddings, boundaries)

    # Aggregate each group by taking the mean over its sentences.
    return [group.mean(axis=0).tolist() for group in groups]


In [11]:
# Writes one averaged embedding back to its item row (parameters: embedding, id).
# Despite the name, this is an UPDATE of an existing row, not an INSERT.
insert_query = "UPDATE item SET paragraph_embeddings_avg = %s where id = %s"
# Fetches the next batch (at most 2048 rows) of items that have no embedding yet.
# The ORDER BY is commented out inside the SQL, so no ordering is requested.
select_query = """
        SELECT id, title, summary, content 
        FROM item 
        WHERE paragraph_embeddings_avg IS NULL
        /* ORDER BY created_at ASC */
        LIMIT 2048
        """
# Selects the id of every item that has no embedding yet (one row per pending item).
count_query = """
        SELECT id
        FROM item 
        WHERE paragraph_embeddings_avg IS NULL
        """

In [None]:
from tqdm.notebook import tqdm

# Get total rows count.
# count_query returns one row per pending item, so rowcount is the number of
# items still waiting for an embedding.
with connection.cursor() as cursor:
    cursor.execute(count_query)
    total_rows = cursor.rowcount
#total_rows = 7549902

print (f"Total rows: {total_rows}")

pbar = tqdm(total=total_rows, desc="Processing rows", dynamic_ncols=True)

# Process pending items batch by batch until the select query comes back empty.
# Every batch is committed on its own, and select_query only returns rows whose
# embedding is still NULL, so committed rows are not picked up again.
try:
    while True:
        with connection.cursor() as cursor, connection.cursor() as bulk_cursor:
            cursor.execute(select_query)
            rows = cursor.fetchall()
            if not rows:
                # Nothing left without an embedding: normal completion, not an error.
                print("No rows returned from select query.")
                break

            ids = []
            paragraphs = []
            for item_id, title, summary, content in rows:
                ids.append(item_id)
                # Prefer the full content, fall back to the summary, then to "".
                # NOTE(review): a NULL title would render as the literal "None"
                # here - confirm the column is NOT NULL.
                paragraphs.append(f"{title}. {(content or summary or '')}")

            paragraph_embeddings = process_paragraphs(paragraphs)

            # zip() below would silently drop rows on a length mismatch, so
            # verify that every id received exactly one embedding.
            if len(paragraph_embeddings) != len(ids):
                raise ValueError("Mismatch between number of embeddings and ids.")

            bulk_cursor.executemany(insert_query, zip(paragraph_embeddings, ids))
            connection.commit()

            # Update the progress bar
            pbar.update(len(rows))
except Exception as e:
    print(f"An error occured: {e}")
    # Discard the failed batch so the connection is not left in an aborted
    # transaction and stays usable for a retry.
    connection.rollback()
finally:
    # Close the bar on every exit path (completion, error, or interrupt).
    pbar.close()
    

Total rows: 56611


Processing rows:   0%|                                                                                        …