## Summarize & Extract text (PDF files + Images) using Spark and Gemini

### Overview

This notebook shows how to summarize a set of PDF files and images with Gemini.

#### **Steps**
Using Spark, the notebook:
1) Reads the PDF files of the [Contract Understanding Atticus Dataset (CUAD)](https://www.atticusprojectai.org/cuad) located in [gs://dataproc-metastore-public-binaries/cuad_v1/full_contract_pdf/](https://console.cloud.google.com/storage/browser/dataproc-metastore-public-binaries/cuad_v1).
2) Creates a metadata table (a Spark DataFrame) pointing to the paths of the files in the bucket.
3) Calls the [Vertex AI Gemini API](https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart#try_text_prompts) to summarize each file.
4) Saves the output to BigQuery.

#### Related content

- [Design summarization prompts](https://cloud.google.com/vertex-ai/docs/generative-ai/text/summarization-prompts)

## ENV Setup

#### Identity and Access Management (IAM)

Make sure the service account running this notebook has the required permissions:

- **Run the notebook**
  - AI Platform Notebooks Service Agent
  - Notebooks Admin
  - Vertex AI Administrator
- **Read files from bucket**
  - Storage Object Viewer
- **Run Dataproc jobs**
  - Dataproc Service Agent
  - Dataproc Worker
- **Call Google APIs (Gemini)**
  - Service Usage Consumer
  - VisionAI Admin
- **BigQuery**
  - BigQuery Data Editor

### Imports

In [None]:
from pyspark.sql.functions import udf

import google.auth
import google.auth.transport.requests
import requests

In [None]:
# When using Dataproc Serverless, installed packages are automatically available on all nodes
!pip3 install --upgrade -q google-cloud-aiplatform google-genai "protobuf~=4.25.3" "numpy~=1.26.4" 
# When using a Dataproc cluster, you will need to install these packages during cluster creation: https://cloud.google.com/dataproc/docs/tutorials/python-configuration

### Authentication

In [None]:
# Get credentials to authenticate with Google APIs
credentials, project_id = google.auth.default()
auth_req = google.auth.transport.requests.Request()
credentials.refresh(auth_req)

## Setup Spark Session

In [None]:
from pyspark.sql import SparkSession

In [None]:
spark = SparkSession.builder \
    .appName("PDF/Image files summarization using Gemini") \
    .enableHiveSupport() \
    .getOrCreate()

#### Parameters

In [None]:
# Change the maximum number of files you want to consider
limit_files = 5

# BigQuery
output_dataset_bq = "output_dataset" # create the BigQuery dataset beforehand
output_table_bq = "summaries"

## Read Online CUAD dataset

#### Read CUAD V1 dataset from Cloud Storage

In [None]:
BINARIES_BUCKET_PATH = "gs://dataproc-metastore-public-binaries/cuad_v1/full_contract_pdf/"
# Each row holds one file: path, modificationTime, length and the raw bytes in content
cuad_v1_df = spark.read.format("binaryFile").option("recursiveFileLookup", "true").load(BINARIES_BUCKET_PATH).limit(limit_files)
cuad_v1_df.show()

|                path|    modificationTime| length|             content|
|--------------------|--------------------|-------|--------------------|
|gs://dataproc-met...|2023-05-15 20:53:...|3683550|[25 50 44 46 2D 3...|
|gs://dataproc-met...|2023-05-15 20:53:...|2881262|[25 50 44 46 2D 3...|
|gs://dataproc-met...|2023-05-15 20:54:...|1778356|[25 50 44 46 2D 3...|
|gs://dataproc-met...|2023-05-15 20:53:...|1557129|[25 50 44 46 2D 3...|
|gs://dataproc-met...|2023-05-15 20:53:...|1452180|[25 50 44 46 2D 3...|
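The `binaryFile` reader loads every file under the prefix, while `gemini_predict` always sends `mime_type='application/pdf'`. If the bucket also holds images, one option is to derive the MIME type from the file extension. A minimal sketch, assuming the extension table below; the helper name `guess_gemini_mime_type` and the set of extensions are illustrative assumptions, not part of the original notebook:

```python
import posixpath
from typing import Optional

# Hypothetical extension-to-MIME-type table; extend it with the formats your bucket holds
SUPPORTED_MIME_TYPES = {
    ".pdf": "application/pdf",
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".webp": "image/webp",
}


def guess_gemini_mime_type(gcs_uri: str) -> Optional[str]:
    """Return the MIME type for a gs:// URI based on its extension, or None if unsupported."""
    _, ext = posixpath.splitext(gcs_uri)
    return SUPPORTED_MIME_TYPES.get(ext.lower())


print(guess_gemini_mime_type("gs://my-bucket/contracts/agreement.PDF"))  # application/pdf
print(guess_gemini_mime_type("gs://my-bucket/scans/page_01.jpg"))        # image/jpeg
print(guess_gemini_mime_type("gs://my-bucket/notes.txt"))                # None
```

The result could be passed as `mime_type` to `types.Part.from_uri`, skipping rows that return `None`. Alternatively, Spark's file sources accept a `pathGlobFilter` option (for example `*.pdf`) to restrict which files are loaded in the first place.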

### Summarize pages using Gemini API

In [None]:
def gemini_predict(gcs_pdf_uri, model_name="gemini-2.0-flash", max_retries=3, initial_delay=1):
    """Summarize one PDF stored in Cloud Storage with Gemini; returns None if every attempt fails."""
    # Import inside the function so the UDF is self-contained when shipped to the executors
    import time
    from google import genai
    from google.genai import types

    # project_id is captured on the driver and serialized together with the UDF.
    # A new client is created per call: simple, but not the cheapest option for large datasets.
    client = genai.Client(
        vertexai=True,
        project=project_id,
        location="us-central1"
    )

    generate_content_config = types.GenerateContentConfig(
        response_mime_type="text/plain"
    )

    contents = [
        types.Part.from_uri(
            file_uri=gcs_pdf_uri,
            mime_type="application/pdf",
        ),
        """You are an expert in reading contracts, articles, agreements, and text in general.
        You are able to create concise summaries of the text provided to you.
        Provide a summary of the attached PDF in about 3 sentences with the most important information from the text.
        Summary:
        """
    ]

    # Retry with exponential backoff: wait initial_delay, then double the wait after each failure
    delay = initial_delay
    for attempt in range(max_retries + 1):
        try:
            response = client.models.generate_content(model=model_name,
                                                      contents=contents,
                                                      config=generate_content_config)
            return response.text
        except Exception:
            if attempt == max_retries:
                return None  # becomes a null summary in the Spark DataFrame
            time.sleep(delay)
            delay *= 2
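The retry logic in `gemini_predict` is a plain exponential backoff: wait `initial_delay` seconds after the first failure, then double the wait after each further failure. One way to make that logic reusable and testable outside Spark is to pull it into a small helper. The name `call_with_backoff` and its injectable `sleep` parameter are assumptions for illustration:

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def call_with_backoff(
    fn: Callable[[], T],
    max_retries: int = 3,
    initial_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[T]:
    """Call fn(), retrying up to max_retries times with doubling delays.

    Returns fn()'s result, or None if every attempt raised an exception.
    """
    delay = initial_delay
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_retries:
                return None  # give up; inside a Spark UDF this becomes a null value
            sleep(delay)
            delay *= 2
    return None  # only reached when max_retries is negative


# Usage with a stand-in call that fails twice before succeeding
calls = {"n": 0}
waits = []


def flaky_call() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "summary text"


result = call_with_backoff(flaky_call, max_retries=3, initial_delay=1, sleep=waits.append)
print(result, waits)  # summary text [1, 2]
```

Inside the UDF, the request could then be wrapped as `call_with_backoff(lambda: client.models.generate_content(...).text)`, keeping the Gemini call and the retry policy separate.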

In [None]:
summarize_text = udf(gemini_predict)

In [None]:
summaries_df = cuad_v1_df.withColumn("summary", summarize_text(cuad_v1_df["path"]))

In [None]:
summaries_df.show(5,50)

|                                              path|       modificationTime| length|                                           content|                                           summary|
|--------------------------------------------------|-----------------------|-------|--------------------------------------------------|--------------------------------------------------|
|gs://dataproc-metastore-public-binaries/cuad_v1...|2023-05-15 20:53:55.891|3683550|[25 50 44 46 2D 31 2E 34 0A 25 E2 E3 CF D3 0A 3...|Here is a summary of the provided document:\n\n...|
|gs://dataproc-metastore-public-binaries/cuad_v1...|2023-05-15 20:53:57.195|2881262|[25 50 44 46 2D 31 2E 35 0A 25 E2 E3 CF D3 0A 0...|This document is a promotion and distribution a...|
|gs://dataproc-metastore-public-binaries/cuad_v1...|2023-05-15 20:54:00.609|1778356|[25 50 44 46 2D 31 2E 35 0A 25 E2 E3 CF D3 0A 0...|This document is a strategic alliance agreement...|
|gs://dataproc-metastore-public-binaries/cuad_v1...|2023-05-15 20:53:57.902|1557129|[25 50 44 46 2D 31 2E 35 0A 25 E2 E3 CF D3 0A 0...|This PDF is a collaboration agreement between t...|
|gs://dataproc-metastore-public-binaries/cuad_v1...|2023-05-15 20:53:57.659|1452180|[25 50 44 46 2D 31 2E 34 0D 25 C8 C8 C8 C8 C8 C...|This is a Transportation Services Agreement bet...|

## Retrieve from Local Vector DB

### Read from Local Vector Database

You can use `spark.read.format("jdbc")` to read chunks from your local PostgreSQL vector database.

**Note:** The `vector(768)` type from pgvector and `JSONB` types may need special handling. Consider casting them to text/arrays in your SQL query.


In [None]:
# Read chunks from the local PostgreSQL vector database
# Note: the PostgreSQL JDBC driver must be on the Spark classpath before the Spark session is created
# For Dataproc: pass --jars=gs://spark-lib/postgresql-42.7.1.jar when submitting the job or creating the session
# For local: download postgresql-42.7.1.jar and add it to the Spark jars directory,
#   or set spark.jars.packages=org.postgresql:postgresql:42.7.1 when building the session

import os

# Database connection parameters, read from environment variables (e.g. exported from your .env file)
db_host = os.getenv("DB_HOST", "localhost")
db_port = os.getenv("DB_PORT", "5432")
db_name = os.getenv("DB_NAME", "deep_rag")
db_user = os.getenv("DB_USER", "postgres")
db_pass = os.getenv("DB_PASS", "postgres")

# JDBC URL
jdbc_url = f"jdbc:postgresql://{db_host}:{db_port}/{db_name}"

# Read chunks table
# Note: Cast vector(768) to text/array and JSONB to text for Spark compatibility
chunks_df = spark.read.format("jdbc") \
    .option("url", jdbc_url) \
    .option("dbtable", """
        (SELECT 
            chunk_id,
            doc_id,
            page_start,
            page_end,
            section,
            text,
            is_ocr,
            is_figure,
            content_type,
            image_path,
            emb::text as emb_text,  -- Cast vector to text
            meta::text as meta_text  -- Cast JSONB to text
        FROM chunks
        LIMIT 1000) AS chunks_subquery
    """) \
    .option("user", db_user) \
    .option("password", db_pass) \
    .option("driver", "org.postgresql.Driver") \
    .load()

chunks_df.show(5, truncate=False)
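Because the query above casts `emb` and `meta` to text, each row carries strings such as `[0.1,0.2,...]` in `emb_text` and a JSON object in `meta_text`. pgvector's text form of a vector is a bracketed, comma-separated list of finite numbers, which is also valid JSON, so both columns can be decoded with `json.loads`. A minimal sketch, assuming that format and 768-dimension embeddings; the helper names are hypothetical. In Spark it could be wrapped in a UDF or applied to a small `collect()`ed result:

```python
import json
from typing import Any, List, Optional

EMBEDDING_DIM = 768  # matches the vector(768) column type described above


def parse_embedding(emb_text: Optional[str], dim: int = EMBEDDING_DIM) -> Optional[List[float]]:
    """Turn pgvector's text output (e.g. '[0.1,0.2,0.3]') into a list of floats."""
    if emb_text is None:
        return None
    values = [float(v) for v in json.loads(emb_text)]
    if len(values) != dim:
        raise ValueError(f"expected {dim} dimensions, got {len(values)}")
    return values


def parse_meta(meta_text: Optional[str]) -> Any:
    """Decode the JSONB column (cast to text); NULL becomes an empty dict."""
    return json.loads(meta_text) if meta_text else {}


print(parse_embedding("[0.5,-1,2e-3]", dim=3))  # [0.5, -1.0, 0.002]
print(parse_meta('{"source": "ocr"}'))          # {'source': 'ocr'}
```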


### Alternative: Read with Custom Query (More Flexible)

You can also use a custom SQL query to read specific chunks or to join them with the `documents` table:


In [None]:
# Alternative: Read chunks with documents metadata using custom query
chunks_with_docs_df = spark.read.format("jdbc") \
    .option("url", jdbc_url) \
    .option("query", """
        SELECT 
            c.chunk_id,
            c.doc_id,
            d.title as doc_title,
            d.source_path,
            c.page_start,
            c.page_end,
            c.text,
            c.content_type,
            c.is_ocr,
            c.is_figure,
            c.emb::text as embedding_text,  -- Vector as text
            c.meta::text as chunk_meta       -- JSONB as text
        FROM chunks c
        LEFT JOIN documents d ON c.doc_id = d.doc_id
        WHERE c.content_type = 'text'  -- Filter by content type
        LIMIT 1000
    """) \
    .option("user", db_user) \
    .option("password", db_pass) \
    .option("driver", "org.postgresql.Driver") \
    .load()

chunks_with_docs_df.show(5, truncate=50)


### Important Considerations:

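1. **PostgreSQL JDBC Driver**: You need the PostgreSQL JDBC driver (postgresql-42.7.1.jar) on your Spark classpath before the Spark session is created
   - For Dataproc: pass `--jars=gs://spark-lib/postgresql-42.7.1.jar` when submitting the job or creating the session
   - For local: download the JAR and add it to the Spark jars directory, or set `spark.jars.packages` to `org.postgresql:postgresql:42.7.1` when building the session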

2. **Vector Type Handling**: The `vector(768)` type from pgvector needs to be cast to text or array
   - Use `emb::text` to get vector as text representation
   - Or use `emb::real[]` to get it as a `float4` array, which Spark reads as `array<float>`

3. **JSONB Handling**: Cast JSONB to text: `meta::text`

4. **Performance**: For large datasets, consider:
   - Using `partitionColumn`, `lowerBound`, `upperBound` for parallel reads
   - Adding `numPartitions` option for better parallelism
   - Filtering in SQL query rather than after loading

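5. **Connection Limits**: Each partition's task opens its own JDBC connection, and `numPartitions` also caps the number of concurrent connections. Keep it well below PostgreSQL's `max_connections`, or put a connection pooler such as PgBouncer in front of the database for high concurrency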
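Following the advice to filter in SQL rather than after loading, nearest-neighbour retrieval can be pushed down to PostgreSQL instead of pulling every embedding into Spark: pgvector's `<=>` operator returns cosine distance, so `ORDER BY emb <=> <query vector> LIMIT k` selects the k closest chunks. A minimal sketch of building such a query for the JDBC `query` option; the function name is hypothetical, and the table and column names are assumed to match the `chunks` schema used above:

```python
import math
from typing import Sequence


def build_topk_query(query_embedding: Sequence[float], k: int = 5) -> str:
    """Build a pgvector cosine-distance top-k query for Spark's JDBC `query` option."""
    values = [float(x) for x in query_embedding]
    if not values or not all(math.isfinite(v) for v in values):
        raise ValueError("query embedding must be non-empty and finite")
    if k <= 0:
        raise ValueError("k must be positive")
    # pgvector parses a literal such as '[0.1,0.2,0.3]' when it is cast to ::vector
    vector_literal = "[" + ",".join(repr(v) for v in values) + "]"
    distance = f"emb <=> '{vector_literal}'::vector"  # <=> is pgvector's cosine distance
    return (
        f"SELECT chunk_id, doc_id, text, {distance} AS cosine_distance "
        "FROM chunks "
        f"ORDER BY {distance} "
        f"LIMIT {int(k)}"
    )


# Hypothetical 3-dimensional query vector; a real one must match the column's 768 dimensions
print(build_topk_query([0.1, 0.2, 0.3], k=3))
```

The string could be passed as `.option("query", build_topk_query(query_vec, k=10))` in the JDBC reader. Spark parenthesizes the query and uses it as a subquery, so the final row order is not guaranteed; sort by `cosine_distance` after loading if order matters. With an HNSW or IVFFlat index on `emb` built with `vector_cosine_ops`, PostgreSQL can answer the `ORDER BY ... LIMIT` from the index.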


In [None]:
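# Example: Parallel read with partitioning (for large datasets)
# This splits the read across multiple Spark partitions for better performance.
# partitionColumn must be a numeric, date or timestamp column, so a UUID chunk_id cannot be used directly.
# Instead, derive a bucket number 0-9 from chunk_id inside the query (hashtext is an internal
# PostgreSQL function that hashes text to an integer) and partition on that bucket.
num_partitions = 10

chunks_parallel_df = spark.read.format("jdbc") \
    .option("url", jdbc_url) \
    .option("dbtable", f"""
        (SELECT 
            chunk_id,
            doc_id,
            text,
            emb::text as emb_text,
            meta::text as meta_text,
            abs(mod(hashtext(chunk_id::text), {num_partitions})) as part_key
        FROM chunks) AS chunks_subquery
    """) \
    .option("user", db_user) \
    .option("password", db_pass) \
    .option("driver", "org.postgresql.Driver") \
    .option("partitionColumn", "part_key") \
    .option("lowerBound", "0") \
    .option("upperBound", str(num_partitions)) \
    .option("numPartitions", str(num_partitions)) \
    .load()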

print(f"Number of partitions: {chunks_parallel_df.rdd.getNumPartitions()}")
print(f"Total chunks: {chunks_parallel_df.count()}")


## Save to BigQuery

In [None]:
# summaries_df.write \
#             .format("bigquery") \
#             .option("table", f"{project_id}.{output_dataset_bq}.{output_table_bq}") \
#             .option("writeMethod", "direct") \
#             .mode("overwrite") \
#             .save()