In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Gemini Batch Prediction

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/tvaroska/gemini-batch/blob/main/gemini_batch_prediction.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Run in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2Ftvaroska%2Fgemini-batch%2Fmain%2Fgemini_batch_prediction.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Run in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://github.com/tvaroska/gemini-batch/blob/main/gemini_batch_prediction.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/tvaroska/gemini-batch/main/gemini_batch_prediction.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
</table>


| | |
|-|-|
|Author | [Boris Tvaroska](https://github.com/tvaroska)

## Overview

### Vertex AI Gemini Batch Inference API

For more information, see the [Generative AI on Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/overview) documentation.


### Objectives

In this tutorial, you will learn how to use the Vertex AI Gemini Batch Inference API with BigQuery as both the input source and the output destination.

### Costs

This tutorial uses billable components of Google Cloud:

- Vertex AI
- BigQuery

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.


## Getting Started


### Install Vertex AI SDK for Python


In [None]:
# Install or upgrade the client libraries this notebook imports (Vertex AI SDK, BigQuery client,
# pydantic, tenacity) into the user site-packages; the kernel must be restarted afterwards.
! pip3 install --upgrade --user google-cloud-aiplatform google-cloud-bigquery pydantic tenacity

### Restart current runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which will restart the current kernel.

In [None]:
# Shut the kernel down with restart=True so the freshly installed packages become importable
import IPython

IPython.Application.instance().kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>
</div>



### Authenticate your notebook environment (Colab only)

If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench).


In [None]:
import sys

# The presence of the google.colab module is used as the signal that we are running on Colab,
# where additional authentication is required
running_in_colab = "google.colab" in sys.modules

if running_in_colab:
    # Imported here because the module is only needed (and looked up) in the Colab case
    from google.colab import auth

    # Authenticate user to Google Cloud
    auth.authenticate_user()

### Set Google Cloud project information and initialize Vertex AI SDK

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
# Define project information
# Google Cloud project: passed to vertexai.init and used as the prefix of every BigQuery table ID
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}
# Region passed to vertexai.init and set as the location of the BigQuery dataset
LOCATION = "us-central1"  # @param {type:"string"}
# BigQuery dataset that holds the raw, request, and result tables
DATASET = 'batch' # @param {type:"string"}
# Gemini model name used to build the GenerativeModel for the online calls
PRO_MODEL = 'gemini-1.5-pro-001' # @param {type:"string"}

In [None]:
# Fully qualified BigQuery table IDs; all three live in the same project and dataset
_table_prefix = f'{PROJECT_ID}.{DATASET}'

RAW_TABLE = f'{_table_prefix}.raw'        # generated call transcripts
SOURCE_TABLE = f'{_table_prefix}.batch'   # one Gemini request per transcript
TARGET_TABLE = f'{_table_prefix}.result'  # batch prediction output

In [None]:
# Initialize Vertex AI
import vertexai

# Bind the SDK to the project and region chosen in the configuration cell
vertexai.init(
    location=LOCATION,
    project=PROJECT_ID,
)

### Import libraries


In [None]:
from typing import Literal, List

from tqdm.notebook import tqdm
from tenacity import retry, wait_exponential, retry_if_exception_type
from pydantic import BaseModel, Field

from google.api_core.exceptions import ResourceExhausted
from google.cloud.exceptions import Conflict, NotFound
from google.cloud import bigquery

from vertexai.generative_models import GenerationConfig, GenerativeModel
from vertexai.batch_prediction._batch_prediction import BatchPredictionJob

In [None]:
# BigQuery client bound explicitly to PROJECT_ID, matching vertexai.init(project=PROJECT_ID, ...).
# Without the argument the client falls back to whatever default project the environment
# resolves, which may be unset (e.g. on Colab) or differ from the project that owns the tables.
bq = bigquery.Client(project=PROJECT_ID)

# Gemini model handle used for the online calls that generate the synthetic transcripts
pro = GenerativeModel(model_name=PRO_MODEL)

### Helper functions

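- `generate_content`: calls the online API and retries the call when the quota is exhausted (`ResourceExhausted`)
- `flatten_openapi`: rewrites the OpenAPI schema generated by pydantic so that nested models are inlined, as needed for Gemini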

In [None]:
# Retry policy: only ResourceExhausted errors trigger a retry; waits grow exponentially
# (multiplier=1) within the configured bounds min=4 and max=10. No stop condition is configured.
_retry_on_quota = retry(
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(ResourceExhausted),
)


@_retry_on_quota
def generate_content(model, prompt, generation_config=None):
    """Forward ``prompt`` to ``model.generate_content`` under the quota retry policy.

    Args:
        model: Object exposing ``generate_content(prompt, generation_config=...)``.
        prompt: Prompt passed through unchanged.
        generation_config: Optional generation settings passed through unchanged.

    Returns:
        Whatever ``model.generate_content`` returns.
    """
    response = model.generate_content(prompt, generation_config=generation_config)
    return response

In [None]:
# https://gist.github.com/tvaroska/20362bd56a5060f1bc21933bd4fee657
def flatten_openapi(schema):
    """Inline every local ``$ref`` of a JSON schema and drop its definitions section.

    pydantic describes nested models through a ``$defs`` section plus ``$ref``
    pointers; this helper replaces each pointer with the definition it names so the
    result is a single self-contained schema.

    Compared with a property-by-property substitution this version:
      * resolves references at any depth (properties, ``items``, ``anyOf`` lists, ...),
      * accepts definitions without ``properties`` (for example enums),
      * does not depend on the order in which definitions are listed,
      * supports both ``$defs`` and the older ``definitions`` key,
      * keeps keys written next to a ``$ref`` (such as ``description``); they
        override the same key of the referenced definition,
      * never mutates the caller's dictionary.

    Args:
        schema: JSON schema as a dict, e.g. the output of a pydantic model's
            ``schema()``.

    Returns:
        ``schema`` itself when it has no definitions section, otherwise a new dict
        with all local references inlined and the definitions section removed.

    Raises:
        KeyError: A ``$ref`` names a definition that does not exist.
        ValueError: A definition refers to itself (directly or indirectly) and
            therefore cannot be inlined.
    """
    # '$defs' is the key the original code handled; 'definitions' is the older spelling
    defs_key = next((key for key in ('$defs', 'definitions') if key in schema), None)
    if defs_key is None:
        # Nothing to inline: hand the schema back untouched
        return schema

    definitions = schema[defs_key]
    ref_prefix = f'#/{defs_key}/'

    def _resolve(node, active):
        """Return a copy of ``node`` with local refs inlined; ``active`` holds the
        names of the definitions currently being expanded (cycle detection)."""
        if isinstance(node, list):
            return [_resolve(item, active) for item in node]
        if not isinstance(node, dict):
            return node

        ref = node.get('$ref')
        if isinstance(ref, str) and ref.startswith(ref_prefix):
            name = ref[len(ref_prefix):]
            if name in active:
                raise ValueError(f'Recursive reference {ref} cannot be flattened')
            if name not in definitions:
                raise KeyError(ref)
            inlined = _resolve(definitions[name], active | {name})
            siblings = {
                key: _resolve(value, active)
                for key, value in node.items()
                if key != '$ref'
            }
            return {**inlined, **siblings}

        # Plain object (or a reference that is not local): copy and recurse
        return {key: _resolve(value, active) for key, value in node.items()}

    body = {key: value for key, value in schema.items() if key != defs_key}
    return _resolve(body, frozenset())

### Create BQ dataset and tables

In [None]:
# Create dataset

# Dataset object for PROJECT_ID.DATASET, placed in the configured LOCATION
dataset = bigquery.Dataset(f'{PROJECT_ID}.{DATASET}')
dataset.location = LOCATION

# exists_ok=True is passed so that re-running the cell with an existing dataset does not fail
bq.create_dataset(dataset, exists_ok=True)

In [None]:
# Table to store transcripts

# One row per call; `transcript` is a repeated record holding the ordered utterances
raw_table = bigquery.Table(
    table_ref=RAW_TABLE,
    schema=[
        bigquery.SchemaField('id', 'INTEGER', mode='REQUIRED'),
        bigquery.SchemaField('agent', 'STRING', mode='REQUIRED'),
        bigquery.SchemaField('customer', 'STRING', mode='REQUIRED'),
        bigquery.SchemaField('transcript', 'RECORD', mode='REPEATED', fields = [
            bigquery.SchemaField('role', 'STRING', mode='REQUIRED'),
            bigquery.SchemaField('content', 'STRING', mode='REQUIRED')
        ])
    ]
)

try:
    bq.create_table(raw_table)
except Conflict:
    # Table exists, delete everything.
    # The table path is back-quoted (as in the final result query) so that project IDs
    # containing characters such as '-' are always parsed as a single identifier.
    bq.query_and_wait(f'DELETE FROM `{RAW_TABLE}` WHERE True')

In [None]:
# Table to prepare prompt and parameters for Gemini

# One row per call: the call id and the complete request as a JSON value
source_table = bigquery.Table(
    table_ref=SOURCE_TABLE,
    schema=[
        bigquery.SchemaField('id', 'INTEGER', mode='REQUIRED'),
        bigquery.SchemaField('request', 'JSON', mode='REQUIRED')
    ]
)
try:
    bq.create_table(source_table)
except Conflict:
    # Table exists, delete everything.
    # The table path is back-quoted (as in the final result query) so that project IDs
    # containing characters such as '-' are always parsed as a single identifier.
    bq.query_and_wait(f'DELETE FROM `{SOURCE_TABLE}` WHERE True')

In [None]:
# Delete table with Gemini outputs

# The table does not have to exist, so a "not found" answer is ignored
bq.delete_table(TARGET_TABLE, not_found_ok=True)

### Create synthetic call transcripts

Use Gemini 1.5 Pro online API calls with controlled generation (a response schema) to create hypothetical call transcripts.

In [None]:
# Scenario dimensions: a transcript is requested for every combination of the values below,
# and each value is inserted verbatim into the generation prompt
products = ['credit card', 'car insurance']
issues = ['unknown fee', 'lost password']
customers = ['angry and hard to understand', 'calm']
agents = ['helpfull', 'profesional, but distant']
outputs = ['solved', 'escalated to supervisor']

# Number of transcripts requested per combination
count = 2

In [None]:
# One utterance of a call, attributed either to the agent or to the customer
class TranscriptLine(BaseModel):
    role: Literal['agent', 'customer']
    content: str

# Complete call record; the field names and types mirror the columns of the raw BigQuery table
class Transcript(BaseModel):
    id: int
    agent: str = Field(description = 'Name of the agent, first and last name')
    customer: str = Field(description = 'Name of the customer, first and last name')
    transcript: List[TranscriptLine]

# Schema of Transcript after flatten_openapi has processed the nested TranscriptLine model
raw_schema = flatten_openapi(Transcript.schema())

# Generation settings for the synthetic transcripts: temperature 2 and JSON output
# constrained by raw_schema
synthetic_config = GenerationConfig(temperature=2, response_mime_type='application/json', response_schema=raw_schema)

In [None]:
import itertools

# Running call id embedded in each prompt (named call_id so the builtin `id` is not shadowed)
call_id = 1
transcripts = []
# Responses that could not be turned into a Transcript
skipped = 0

# Every combination of product, issue, customer type, agent type and call outcome
scenarios = list(itertools.product(products, issues, customers, agents, outputs))
total = len(scenarios) * count

with tqdm(total = total) as pb:
    for product, issue, customer, agent, output in scenarios:
        for _ in range(count):
            prompt = f'Generate an example for call center transcripts. Call is about {issue} with {product} product. Customer is {customer} and agent is {agent}. Output of the call is {output}. Call id = {call_id}'
            response = generate_content(pro, prompt, generation_config=synthetic_config)
            call_id += 1
            try:
                raw_text = response.candidates[0].content.parts[0].text
                transcripts.append(Transcript.parse_raw(raw_text).dict())
            except (IndexError, ValueError):
                # IndexError: the response has no candidate or no part.
                # ValueError: also covers pydantic's ValidationError, so malformed or
                # schema-violating JSON skips one sample instead of aborting the whole run.
                skipped += 1
            pb.update()

if skipped:
    # Report dropped samples instead of discarding them silently
    print(f'Skipped {skipped} of {total} responses that could not be parsed into a Transcript')

In [None]:
# Load the generated transcript dictionaries into the raw table and wait for the load job
job = bq.load_table_from_json(transcripts, RAW_TABLE)
job.result()

### Transform

Use BigQuery SQL to transform the transcripts in RAW_TABLE into the request format expected by the Gemini API.

In [None]:
# Structure the model is asked to return for every transcript; its JSON schema is embedded
# into each batch request as the response schema.
# (Plain comments are used on purpose: the schema text is generated from this class.)
class Response(BaseModel):
    # Free-form category chosen by the model
    category: str = Field(description='Create apropriate category')
    # Fixed set of labels; 'order' was previously misspelled as 'oder', which forced that
    # misspelling into every result row
    product: Literal['product', 'order', 'receipt', 'other']
    sentiment: Literal['positive', 'neutral', 'negative']

In [None]:
# Builds one request JSON per transcript and inserts it into SOURCE_TABLE.
# For every id three key/value rows are produced and then folded into a single JSON object:
#   generationConfig   - temperature, JSON mime type and the Response schema
#   system_instruction - the classification instruction
#   contents           - the transcript rendered as "role:content" lines
#
# Changes compared with the first version:
#   * the ids are read from RAW_TABLE instead of a hard-coded `boris001.batch.raw`, so the
#     query works in the project and dataset configured above;
#   * table paths are back-quoted, like in the final result query;
#   * the transcript lines are aggregated in their array order (WITH OFFSET + ORDER BY),
#     because STRING_AGG without ORDER BY does not promise any particular order.
#
# NOTE: the Response schema is pasted into a single-quoted SQL literal, so it must not
# contain single quotes or backslashes.
transform_query = f"""INSERT INTO `{SOURCE_TABLE}`(id, request)
  (
SELECT id, JSON_OBJECT(ARRAY_AGG(k), ARRAY_AGG(v)) as request FROM
(
SELECT id, k, v FROM 
(
  (
    SELECT 'generationConfig' as k, JSON_OBJECT('temperature', '0.5', 'responseMimeType', 'application/json', 'responseSchema', PARSE_JSON('{Response.schema_json()}')) as v UNION ALL
    SELECT 'system_instruction',  JSON_OBJECT('parts', JSON_ARRAY(JSON_OBJECT('text', 'You are analyzing customer support calls from the call center. Clasify transcript.')))
  ) CROSS JOIN (SELECT id FROM `{RAW_TABLE}`)
) UNION ALL
      SELECT id, 'contents', JSON_ARRAY(JSON_OBJECT('role', 'user', 'parts', JSON_ARRAY(JSON_OBJECT('text', conversation)))) FROM
      (
        SELECT id, STRING_AGG(line, '\\n' ORDER BY pos) as conversation FROM (SELECT id, pos, CONCAT(role, ':', content) as line FROM `{RAW_TABLE}`, UNNEST(transcript) WITH OFFSET AS pos) GROUP BY id)
      )
    GROUP BY id
  )"""

In [None]:
# Execute the INSERT built above and wait until the request table is populated
bq.query_and_wait(transform_query)

### Run batch prediction

Submit the batch prediction job and wait for the operation to finish.

In [None]:
# Submit the batch job: requests are read from SOURCE_TABLE and results written to TARGET_TABLE.
# The model comes from the PRO_MODEL setting instead of a second hard-coded copy of the name,
# so changing the configuration cell changes the batch job as well.
batch_job = BatchPredictionJob.submit(
    source_model=PRO_MODEL,
    input_dataset=f'bq://{SOURCE_TABLE}',
    output_uri_prefix=f'bq://{TARGET_TABLE}',
)

In [None]:
import time

# Seconds to wait between two status checks
POLL_INTERVAL_SECONDS = 30

# Report the current state, wait, then refresh; the loop exits right after the refresh that
# observes the final state, so no extra sleep is spent once the job has ended
while not batch_job.has_ended:
    print(f"Batch job state: {batch_job.state}")
    time.sleep(POLL_INTERVAL_SECONDS)
    batch_job.refresh()
print(f"Batch job state: {batch_job.state}")

# Stop here when the job ended without success: the result table would be missing or
# incomplete and the following query would fail with a less helpful error
if not batch_job.has_succeeded:
    raise RuntimeError(f"Batch prediction job did not succeed: {batch_job.error}")

### Get the output in structured format

In [None]:
# The query text is assembled from adjacent literals so each selected column is on its own line.
# The inner SELECT parses the text of the first candidate's first part as JSON; the outer SELECT
# pulls the three classification fields out of it.
result_query = (
    "SELECT id, "
    "JSON_VALUE(text.category) as category, "
    "JSON_VALUE(text.product) as product, "
    "JSON_VALUE(text.sentiment) as sentiment "
    "FROM (SELECT id, PARSE_JSON(JSON_VALUE(response.candidates[0].content.parts[0].text)) as text "
    f"FROM `{TARGET_TABLE}`)"
)

output = bq.query_and_wait(result_query).to_dataframe()

In [None]:
# Display the classification results (id, category, product, sentiment)
output