In [None]:
!pip install tf-keras

In [None]:
!pip install FlagEmbedding

In [None]:
import psycopg2
from FlagEmbedding import BGEM3FlagModel
from tqdm import tqdm  # <-- Barre de progression

def main():
    """Embed each video's ten most-liked comments with BGE-M3 and store them.

    Connects to the local PostgreSQL database, selects up to ten comments per
    video ranked by like count (ties broken at random), encodes their text
    with the ``BAAI/bge-m3`` model and upserts one row per comment into
    ``comment_embedding`` keyed on ``(comment_id, model_name)``.

    All inserts happen in a single transaction that is committed at the end;
    the connection is closed whether or not an error occurs, so a failure
    leaves the table unchanged.
    """
    # -------------------------------------------------------------------
    # 1) Connection parameters
    # -------------------------------------------------------------------
    dbname = "postgres"
    user = "jeremie"
    password = ""
    host = "localhost"
    port = "5432"

    # -------------------------------------------------------------------
    # 2) Connect to PostgreSQL
    # -------------------------------------------------------------------
    conn = psycopg2.connect(
        dbname=dbname,
        user=user,
        password=password,
        host=host,
        port=port
    )

    try:
        with conn.cursor() as cur:
            # -------------------------------------------------------------------
            # 3) Fetch, for each video, its 10 most-liked comments
            #    (picked at random among ties).
            #    NULLS LAST matters: PostgreSQL sorts NULLs first under DESC,
            #    which would rank comments with an unknown like count above
            #    the genuinely most-liked ones.
            # -------------------------------------------------------------------
            query_top_10_comments = """
            SELECT *
            FROM (
                SELECT
                    v.video_id,
                    c.comment_id,
                    c.comment_text_original,
                    c.comment_like_count,
                    ROW_NUMBER() OVER (
                        PARTITION BY v.video_id
                        ORDER BY c.comment_like_count DESC NULLS LAST, RANDOM()
                    ) AS rn
                FROM video v
                JOIN comment c ON c.comment_video_id = v.video_id
            ) sub
            WHERE sub.rn <= 10;
            """
            cur.execute(query_top_10_comments)
            rows = cur.fetchall()
            # Each row is a tuple:
            # (video_id, comment_id, comment_text_original, comment_like_count, rn)

            # -------------------------------------------------------------------
            # 4) Prepare the data to encode
            # -------------------------------------------------------------------
            comment_ids = []      # comment_ids[i] owns texts_to_encode[i]
            texts_to_encode = []

            for _video_id, comment_id, comment_text, _like_count, _rank in rows:
                # A NULL text cannot be tokenized; skip it rather than
                # letting the encoder fail on None.
                if comment_text is None:
                    continue
                comment_ids.append(comment_id)
                texts_to_encode.append(comment_text)

            if not texts_to_encode:
                # Nothing to embed: avoid loading the model and calling
                # encode() on an empty list.
                print("No comments to embed.")
                return

            # -------------------------------------------------------------------
            # 5) Load the BGE-M3 model
            # -------------------------------------------------------------------
            model_name = "bge-m3"  # Name identifying the model in the table
            model = BGEM3FlagModel(
                "BAAI/bge-m3",  # Hugging Face model ID
                use_fp16=True   # Half precision to speed up inference
            )

            # -------------------------------------------------------------------
            # 6) Generate the embeddings in batches
            # -------------------------------------------------------------------
            batch_size = 16  # Tune to the available resources
            encoding_output = model.encode(
                texts_to_encode,
                batch_size=batch_size,
                max_length=512
            )
            # Presumably shaped [n, dim], one vector per input text, in order.
            embeddings = encoding_output["dense_vecs"]

            # -------------------------------------------------------------------
            # 7) Upsert the embeddings (with a progress bar)
            # -------------------------------------------------------------------
            insert_query = """
            INSERT INTO comment_embedding (comment_id, model_name, embedding)
            VALUES (%s, %s, %s)
            ON CONFLICT (comment_id, model_name)
            DO UPDATE SET embedding = EXCLUDED.embedding
            """

            progress = tqdm(
                zip(comment_ids, embeddings),
                desc="Inserting embeddings",
                total=len(comment_ids),
            )
            for comment_id, emb_vector in progress:
                # The encoder typically yields ndarray rows; convert them to
                # plain Python lists so the driver can adapt them.
                if hasattr(emb_vector, "tolist"):
                    emb_vector = emb_vector.tolist()

                cur.execute(
                    insert_query,
                    (comment_id, model_name, emb_vector)
                )

            # -------------------------------------------------------------------
            # 8) Commit the transaction
            # -------------------------------------------------------------------
            conn.commit()

    finally:
        # 9) Close the connection (uncommitted work is discarded)
        conn.close()

# Run the embedding pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()