In [None]:
!pip install pyspark spark-nlp==5.3.0 faiss-cpu
from pyspark.sql import SparkSession

In [None]:
from pyspark.sql import SparkSession
from google.colab import drive
import os
import time
import numpy as np
import faiss
import pandas as pd

from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType, LongType

from sparknlp.base import DocumentAssembler
from sparknlp.annotator import BertSentenceEmbeddings
from pyspark.ml import Pipeline

# Mount Google Drive so the dataset folder configured below is reachable.
drive.mount('/content/drive', force_remount=True)

# --- Configuration ---
# Input data and both generated artifacts live in one QA dataset folder on Drive.
qa_dataset_base_path = '/content/drive/My Drive/QA_dataset'
input_parquet_filename = "data_qa_combined.parquet"            # source QA rows
output_embeddings_filename = "qa_combined_embeddings.parquet"  # rows + embedding vectors
output_faiss_index_filename = "qa_combined_embeddings.index"   # serialized FAISS index

# Resolve every file name against the dataset folder, in the same order.
(input_parquet_path,
 embeddings_parquet_path,
 faiss_index_path) = (
    os.path.join(qa_dataset_base_path, file_name)
    for file_name in (
        input_parquet_filename,
        output_embeddings_filename,
        output_faiss_index_filename,
    )
)

# --- SparkSession Initialization ---
# Local Spark using all cores. The Spark NLP JVM package is fetched through
# spark.jars.packages at the same 5.3.0 version as the pip-installed
# spark-nlp. Arrow is enabled to speed up toPandas() later on.
spark = (
    SparkSession.builder
    .appName("QA_Embeddings_FAISS_FullProcess")
    .master("local[*]")
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.3.0")
    .config("spark.driver.memory", "8G")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .getOrCreate()
)

print("Spark session created.")
overall_start_time = time.time()

# --- 1. Load QA Data from Parquet ---
# Reads the input Parquet, adds a qa_id and a combined qa_text column, and
# caches the result. Stops the script (after stopping Spark) if the file cannot
# be read (exit status 1) or contains no rows (exit status 0).
print(f"\nLoading QA data from: {input_parquet_path}")
load_data_start_time = time.time()
try:
    df = spark.read.parquet(input_parquet_path)
except Exception as e:
    # Top-level boundary: report, release Spark, and stop with a non-zero
    # status. The `exit()` helper is meant for interactive shells and would
    # exit with status 0, which reports this failure as success.
    print(f"Error loading Parquet file: {input_parquet_path}. Error: {e}")
    spark.stop()
    raise SystemExit(1) from e

# Add a unique ID column. monotonically_increasing_id() values are unique and
# increasing but not consecutive across partitions, so they are IDs, not row
# positions.
df = df.withColumn("qa_id", F.monotonically_increasing_id())

# 'Level' is shown below and carried into the embeddings output. Inputs that
# lack it get an all-null string column instead of failing the later selects.
if "Level" not in df.columns:
    df = df.withColumn("Level", F.lit(None).cast(StringType()))

# --- 2. Combine Question and Answer into qa_text ---
print("Combining Question and Answer columns into 'qa_text'...")
df = df.withColumn("qa_text", F.concat_ws(" ",
    F.lit("Question:"), F.col("Question"),
    F.lit("Answer:"), F.col("Answer")
))

# Replace null Question/Answer values with empty strings in the output
# columns. qa_text needs no fill: concat_ws skips null arguments and never
# returns null.
df = df.na.fill({
    "Question": "",
    "Answer": "",
})

# Cache: df is reused by count(), show(), pipeline fit and transform.
df.persist()
original_count = df.count()
print(f"Data loaded and 'qa_text' created. Total rows: {original_count}")

if original_count == 0:
    # Nothing to do is not an error: stop cleanly with status 0.
    print("No data to process after loading. Exiting.")
    spark.stop()
    raise SystemExit(0)

print(f"Data loading and preparation time: {time.time() - load_data_start_time:.2f} seconds.")
df.select("qa_id", "Question", "Answer", "Level", "qa_text").show(5, truncate=50)


# --- 3. Generate Embeddings for qa_text ---
# Runs a Spark NLP pipeline (DocumentAssembler -> BertSentenceEmbeddings) over
# qa_text. The result is cached in embeddings_df, which has qa_id, Question,
# Answer, Level and one vector per row in qa_embedding_vector.
print("\nStarting to generate embeddings for 'qa_text'...")
embedding_generation_start_time = time.time()

# Wrap each qa_text string into a Spark NLP document annotation.
document_assembler = DocumentAssembler() \
    .setInputCol("qa_text") \
    .setOutputCol("document_qa")

# Pretrained small BERT sentence encoder, lower-casing input.
# NOTE(review): long qa_text values are presumably truncated to the model's
# maximum sentence length -- confirm against setMaxSentenceLength if QA pairs
# are long.
embeddings_generator = BertSentenceEmbeddings.pretrained("sent_small_bert_L2_128") \
    .setInputCols(["document_qa"]) \
    .setOutputCol("qa_embedding_col") \
    .setCaseSensitive(False)

pipeline = Pipeline(stages=[document_assembler, embeddings_generator])

# Both stages are already transformers (the BERT weights come from
# `pretrained`), so fit() trains nothing and does not need a data sample. It
# only wraps the stages into a PipelineModel.
pipeline_model = pipeline.fit(df.select("qa_text"))

# Transform to get embeddings. qa_embedding_col is an array of annotations;
# only the first annotation's vector is kept for each row.
embeddings_df = pipeline_model.transform(df) \
    .select(F.col("qa_id"),
            F.col("Question"),
            F.col("Answer"),
            F.col("Level"),  # requires a 'Level' column on df
            F.col("qa_embedding_col.embeddings")[0].alias("qa_embedding_vector"))

# Persist, then count. The count forces the embedding computation, so the
# timing below covers it and the next step's write reuses the cache.
embeddings_df.persist()
embeddings_count = embeddings_df.count()
print(f"'qa_text' embeddings generated. Rows: {embeddings_count}. Time: {time.time() - embedding_generation_start_time:.2f} seconds.")

if embeddings_count == 0:
    # Getting no output rows here is a failure, so exit with a non-zero status
    # (the interactive `exit()` helper would exit with status 0).
    print("No embeddings were generated. Exiting.")
    spark.stop()
    raise SystemExit(1)

# --- 4. Save Embeddings to Parquet ---
# Write the cached embeddings (plus QA metadata), replacing any previous run's output.
print(f"\nSaving embeddings to Parquet file: {embeddings_parquet_path}")
save_parquet_start_time = time.time()

embeddings_df.write.parquet(embeddings_parquet_path, mode="overwrite")

save_parquet_elapsed = time.time() - save_parquet_start_time
print(f"Embeddings saved to Parquet. Time: {save_parquet_elapsed:.2f} seconds.")

# The source DataFrame is no longer needed; drop its cached blocks.
df.unpersist()

# --- 5. Load Embeddings from Parquet (Optional, for FAISS building) ---
# Read back the file just written so the FAISS index is built from the
# persisted output, then cache it for the FAISS step.
print(f"\nReloading embeddings from Parquet for FAISS index construction: {embeddings_parquet_path}")
reload_parquet_start_time = time.time()

loaded_embeddings_spark_df = spark.read.parquet(embeddings_parquet_path).persist()
reloaded_row_count = loaded_embeddings_spark_df.count()
print(f"Embeddings reloaded. Rows: {reloaded_row_count}. Time: {time.time() - reload_parquet_start_time:.2f} seconds.")

# Quick visual sanity check of schema and first rows.
loaded_embeddings_spark_df.printSchema()
loaded_embeddings_spark_df.show(5, truncate=50)

# --- 6. Build FAISS Index ---
# Builds an exact-L2 FAISS index over the reloaded embedding vectors. Labels
# returned by search() are qa_id values, not insertion positions.
print("\nBuilding FAISS index...")
faiss_build_start_time = time.time()

# Collect each vector together with its qa_id. Parquet read order is not
# guaranteed and qa_id values are not consecutive, so the positional labels
# 0..N-1 of a bare IndexFlatL2 could not be mapped back to rows reliably.
# Rows with a null vector cannot go into a float32 matrix, so they are dropped
# here and reported.
pandas_df_for_faiss = loaded_embeddings_spark_df \
    .select("qa_id", "qa_embedding_vector") \
    .where(F.col("qa_embedding_vector").isNotNull()) \
    .toPandas()

skipped_null_rows = loaded_embeddings_spark_df.count() - len(pandas_df_for_faiss)
if skipped_null_rows > 0:
    print(f"Warning: skipping {skipped_null_rows} row(s) with a null embedding vector.")

if pandas_df_for_faiss.empty:
    print("Error: Embedding data for FAISS is empty or has incorrect format. Exiting.")
    spark.stop()
    raise SystemExit(1)

embedding_dim = len(pandas_df_for_faiss['qa_embedding_vector'].iloc[0])
print(f"Embedding dimension for FAISS index: {embedding_dim}")

try:
    # Cells may be Python lists or (with Arrow) NumPy arrays; np.array accepts both.
    embeddings_np_for_faiss = np.array(
        pandas_df_for_faiss['qa_embedding_vector'].tolist(), dtype='float32'
    )
except ValueError as e:
    # Raised when the vectors have different lengths (ragged input).
    print(f"Error: embedding vectors could not be stacked into a matrix: {e}. Exiting.")
    spark.stop()
    raise SystemExit(1) from e

if (embeddings_np_for_faiss.ndim != 2
        or embeddings_np_for_faiss.shape[1] != embedding_dim
        or embedding_dim == 0):
    print(f"Error: NumPy embedding array has incorrect dimensions: {embeddings_np_for_faiss.shape}. Expected (N, {embedding_dim}). Exiting.")
    spark.stop()
    raise SystemExit(1)

print(f"Shape of NumPy array for FAISS: {embeddings_np_for_faiss.shape}")

# FAISS expects contiguous int64 ids, aligned row-for-row with the vectors.
qa_ids_for_faiss = np.ascontiguousarray(
    pandas_df_for_faiss['qa_id'].to_numpy(), dtype='int64'
)

# Exact (brute-force) L2 search, wrapped in IndexIDMap so that search results
# carry qa_id labels that join directly back to the embeddings Parquet.
faiss_index = faiss.IndexIDMap(faiss.IndexFlatL2(embedding_dim))
faiss_index.add_with_ids(embeddings_np_for_faiss, qa_ids_for_faiss)

print(f"FAISS index built. Contains {faiss_index.ntotal} vectors. Time: {time.time() - faiss_build_start_time:.2f} seconds.")

# --- 7. Save FAISS Index ---
# Serialize the index next to the embeddings, release caches, and shut Spark down.
print(f"\nSaving FAISS index to: {faiss_index_path}")
faiss_save_start_time = time.time()
faiss.write_index(faiss_index, faiss_index_path)
faiss_save_elapsed = time.time() - faiss_save_start_time
print(f"FAISS index saved. Time: {faiss_save_elapsed:.2f} seconds.")

# Release cached DataFrames: first the pre-save embeddings, then the reloaded copy.
for cached_spark_df in (embeddings_df, loaded_embeddings_spark_df):
    cached_spark_df.unpersist()

total_elapsed = time.time() - overall_start_time
print(f"\nTotal processing time: {total_elapsed:.2f} seconds.")

# --- Stop Spark Session ---
spark.stop()
print("Spark session stopped. Script finished.")
