In [0]:
%pip install openai


In [0]:
import os

# Seed the process environment with the Azure OpenAI connection settings.
# NOTE(review): the endpoint and key values below look like placeholders --
# replace them before running (ideally from a secret store) and never commit
# a real key to the notebook.
os.environ.update(
    {
        "AOAI_ENDPOINT": "AOAI_Endpoint",
        "AOAI_API_KEY": "AOAI_API_KEY",
        "AOAI_EMBED_DEPLOYMENT": "text-embedding-3-small",
    }
)


In [0]:
import os

# Copy the Azure OpenAI settings from the environment into module-level
# constants. A missing variable raises KeyError.
AOAI_ENDPOINT, AOAI_API_KEY, AOAI_EMBED_DEPLOYMENT = (
    os.environ[name]
    for name in ("AOAI_ENDPOINT", "AOAI_API_KEY", "AOAI_EMBED_DEPLOYMENT")
)

In [0]:
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType, FloatType

# Delta locations: chunk text to read, embedded chunks to write.
CHUNK_TEXT_DIR  = "dbfs:/silver/jsai2025/chunks_text"
CHUNK_EMBED_DIR = "dbfs:/silver/jsai2025/chunks_embed"

# Load the output of stage 1 (chunking).
chunk_df = spark.read.format("delta").load(CHUNK_TEXT_DIR)

# Output schema; every column is declared non-nullable.
embed_schema = (
    StructType()
    .add("pdf_path", StringType(), False)
    .add("chunk_id", IntegerType(), False)
    .add("text", StringType(), False)
    .add("embedding", ArrayType(FloatType()), False)
)

def embed_partition(rows_iter, max_chars=8000):
    """
    Embed the rows of one partition with Azure OpenAI.

    Written for ``rdd.mapPartitions``: one client is created per call and one
    embeddings request is sent for each row that is kept. Rows whose text is
    empty after stripping, or longer than ``max_chars``, are logged and
    skipped.

    Args:
        rows_iter: iterable of rows exposing ``pdf_path``, ``chunk_id`` and
            ``text`` attributes.
        max_chars: rough upper bound on the stripped text length, in
            characters (adjust as needed).

    Yields:
        ``Row(pdf_path, chunk_id, text, embedding)`` for every embedded row.
        ``text`` is the original (unstripped) value; ``embedding`` is the
        vector returned for the stripped text.
    """
    from openai import AzureOpenAI

    client = AzureOpenAI(
        api_key=AOAI_API_KEY,
        azure_endpoint=AOAI_ENDPOINT,
        api_version="2024-02-15-preview",
    )

    for row in rows_iter:
        # Normalize first: None becomes "", surrounding whitespace is dropped.
        text = (row.text or "").strip()

        # The embeddings API does not accept an empty string, and a single
        # failed request would fail the whole partition, so skip such rows.
        # This also keeps None out of the non-nullable output columns.
        if not text:
            print(
                f"[WARN] skip empty text: "
                f"pdf_path={row.pdf_path}, chunk_id={row.chunk_id}"
            )
            continue

        # Skip chunks that are too long.
        if len(text) > max_chars:
            print(
                f"[WARN] skip long text: "
                f"pdf_path={row.pdf_path}, chunk_id={row.chunk_id}, len={len(text)}"
            )
            continue

        # Only non-empty text of acceptable size reaches this point.
        resp = client.embeddings.create(
            model=AOAI_EMBED_DEPLOYMENT,  # deployment name of text-embedding-3-small
            input=text,
        )
        emb = resp.data[0].embedding  # list[float]

        yield Row(
            pdf_path=row.pdf_path,
            chunk_id=int(row.chunk_id),
            text=row.text,
            embedding=emb,
        )

# Redistribute the chunks into 32 partitions keyed by pdf_path; each
# partition is handled by one embed_partition call.
chunk_df_repart = chunk_df.repartition(32, "pdf_path")

embedded_rdd = chunk_df_repart.rdd.mapPartitions(embed_partition)
embedded_df  = spark.createDataFrame(embedded_rdd, schema=embed_schema)

# Persist first. embedded_df is lazy and not cached, so every action on it
# re-runs embed_partition and its embedding requests. Writing once and
# previewing from the saved table avoids requesting the same rows twice.
(
    embedded_df
    .write
    .mode("overwrite")
    .format("delta")
    .save(CHUNK_EMBED_DIR)
)

# Preview a few of the persisted rows.
display(spark.read.format("delta").load(CHUNK_EMBED_DIR).limit(5))
