In [0]:
%pip install --upgrade pip
%pip install sentence-transformers==2.2.2 torch --quiet
%pip install "huggingface_hub<=0.24.0" "sentence-transformers>=2.6.1"


%restart_python 

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, substring
from sentence_transformers import SentenceTransformer
from delta.tables import DeltaTable

import pandas as pd
import numpy as np
import time

In [0]:
# Name of the Hugging Face sentence-embedding checkpoint used for every embedding below.
MODEL_NAME = "BAAI/bge-small-en"
print("Model loading:", MODEL_NAME)
# trust_remote_code is intentionally left at its default (off): enabling it lets the
# model repository execute arbitrary Python on this cluster. Turn it on only for a
# checkpoint whose model card explicitly requires it.
model = SentenceTransformer(MODEL_NAME)

# Ask the model for its output width instead of encoding a throw-away sentence.
EMBED_DIM = model.get_sentence_embedding_dimension()

@pandas_udf(ArrayType(FloatType()))
def embedding_pandas_udf(texts: pd.Series) -> pd.Series:
    """Embed one Arrow batch of strings with the notebook-level ``model``.

    Args:
        texts: Batch of input strings; nulls are embedded as the empty string.

    Returns:
        A Series aligned with ``texts`` whose elements are lists of floats
        (one embedding vector per input row), matching ``ArrayType(FloatType())``.

    Note:
        ``model`` is captured from the enclosing notebook scope, so it is
        shipped to the executors together with this function.
    """
    # An empty batch needs no model call; return an empty object Series so the
    # result type does not depend on pandas' default dtype for empty data.
    if texts.empty:
        return pd.Series([], dtype=object)

    vectors = model.encode(
        texts.fillna("").tolist(),
        # Progress bars are useless on executors and only bloat their logs.
        show_progress_bar=False,
        convert_to_numpy=True,
    )
    # Spark stores FloatType as 32-bit floats, so there is no benefit in
    # upcasting each row to float64 before handing the values back.
    rows = np.asarray(vectors, dtype=np.float32).tolist()
    return pd.Series(rows, index=texts.index)


# Source table holding the text to embed, keyed by job_id.
silver_table = "cvee.job_metadata_silver"
# No orderBy here: the repartition below reshuffles every row, so a global
# sort would be paid for and then immediately thrown away.
silver_df = spark.table(silver_table)
df_reduced = silver_df.select("job_id", "vector_text_input")
# Replace null text with an empty string so every row reaches the UDF as a string.
df_prepared = df_reduced.withColumn(
    "vector_text_input",
    F.coalesce(F.col("vector_text_input"), F.lit("")),
)

# Number of partitions the embedding work is spread over.
NUM_PARTITIONS = 100
df_repart = df_prepared.repartition(NUM_PARTITIONS)
print("Partitions:", NUM_PARTITIONS)

gold_df = df_repart.withColumn("embedding", embedding_pandas_udf(F.col("vector_text_input")))
# NOTE: the name `golf_df` (sic) is kept on purpose because the write cell
# below refers to it; rename both places together if you fix the typo.
golf_df = gold_df.select("job_id", "embedding")

In [0]:
# Fully qualified name of the Delta table that stores one embedding per job_id.
GOLD_TABLE_PATH = "cvee.job_embedding_gold"

# Choose between "merge" and "create" with an explicit existence check.
# A bare `except:` used to treat *any* failure (merge error, concurrent write,
# interrupted run) as "table missing" and then overwrote the gold table.
# Now real errors propagate and existing data is never replaced by accident.
if spark.catalog.tableExists(GOLD_TABLE_PATH):
    delta_table = DeltaTable.forName(spark, GOLD_TABLE_PATH)
    old_count = delta_table.toDF().count()

    # Insert-only merge: rows whose job_id already exists in the target are left untouched.
    (
        delta_table.alias("target")
        .merge(golf_df.alias("source"), "target.job_id = source.job_id")
        .whenNotMatchedInsertAll()
        .execute()
    )

    new_count = delta_table.toDF().count()
    print(f"Number of rows added: {new_count - old_count} / {golf_df.count()} read")
else:
    # First run: the table does not exist yet, so create it from the full result.
    golf_df.write.format("delta").mode("overwrite").saveAsTable(GOLD_TABLE_PATH)
    # Count what was actually written rather than re-evaluating the source DataFrame.
    print(f"Table created with {spark.table(GOLD_TABLE_PATH).count()} rows")