The documentation for this notebook can be found here: https://cloud.google.com/gemini/docs/conversational-analytics-api/build-agent-sdk

# Initialisation

In [None]:
!pip install google-cloud-geminidataanalytics
!pip install google-cloud-iam
!pip install googleapis-common-protos

In [30]:
from google.cloud import geminidataanalytics

# Billing project and the BigQuery dataset whose tables the agent will query
project_id = "rocketech-de-pgcp-sandbox"
dataset_id = "ecommerce_analytics"

# System instructions, brief, not very useful.
# Kept for comparison only: it is overwritten by the detailed version below.
system_instruction = "Help the user analyse their e-commerce data."

# System instructions, add key information to fill the gap in metadata that otherwise is not obvious.
# Every table named in a relationship must match a table that is registered
# with the agent, so the order-item table is spelled `order_item` (singular)
# throughout.
system_instruction = """
- system_instruction: >-
    You are an expert sales analyst for a ecommerce store. You will answer questions about sales, orders, and customer data. Your responses should be concise and data-driven.
    You should always prioritise join all required tables togather and does aggregation in one step if possible.
- tables:
    - table:
        - name: ecommerce_analytics.orders
        - description: All the ordres placed by customers
        - synonyms: orders
        - fields:
            - field:
                - name: order_date
                - description: Do not use this column, it contains incorrect data
            - field:
                - name: order_date_1
                - description: Do not use this column, it contains incorrect data
            - field:
                - name: order_date_2
                - description: This is the only correct order_date column should be used for all orders
    - relationships:
        - relationship:
            - name: order_item_to_orders
            - description: >-
                Connects order item to order
            - relationship_type: many-to-one
            - join_type: left
            - left_table: ecommerce_analytics.order_item
            - right_table: ecommerce_analytics.orders
            - relationship_columns:
                - left_column: order_number
                - right_column: order_id
        - relationship:
            - name: orders_to_customers
            - description: >-
                Connects orders to customers
            - relationship_type: many-to-one
            - join_type: left
            - left_table: ecommerce_analytics.orders
            - right_table: ecommerce_analytics.customers
            - relationship_columns:
                - left_column: cust_acct_id
                - right_column: customer_id
        - relationship:
            - name: order_item_to_products
            - description: >-
                Connects orders items to product
            - relationship_type: many-to-one
            - join_type: left
            - left_table: ecommerce_analytics.order_item
            - right_table: ecommerce_analytics.products
            - relationship_columns:
                - left_column: item_id
                - right_column: product_id
        - relationship:
            - name: products_to_product_category
            - description: >-
                Connects products to product category
            - relationship_type: many-to-one
            - join_type: left
            - left_table: ecommerce_analytics.products
            - right_table: ecommerce_analytics.product_category
            - relationship_columns:
                - left_column: prod_cat_id
                - right_column: category_id
"""

In [31]:
def register_table_references(project_id, dataset_id, table_id):
  """Build a BigQuery table reference for one table.

  Args:
    project_id: Project that contains the dataset.
    dataset_id: Dataset that contains the table.
    table_id: Name of the table.

  Returns:
    A `geminidataanalytics.BigQueryTableReference` whose `project_id`,
    `dataset_id` and `table_id` fields are set from the arguments.
  """
  return geminidataanalytics.BigQueryTableReference(
    project_id=project_id,
    dataset_id=dataset_id,
    table_id=table_id,
  )

In [32]:
# Client used for the data agent calls in this notebook (delete, create, get,
# list, get IAM policy).
data_agent_client = geminidataanalytics.DataAgentServiceClient()
# Client used to send chat requests.
data_chat_client = geminidataanalytics.DataChatServiceClient()

In [None]:
# Tables of the dataset that are made available as data sources
tables_to_register = ["customers", "order_item", "orders", "product_category", "products"]

# One BigQuery table reference per table name, in the same order
registered_tables = [
    register_table_references(project_id, dataset_id, name)
    for name in tables_to_register
]

print("\n--- Registration Complete ---")

# Connect to your data source
datasource_references = geminidataanalytics.DatasourceReferences()
datasource_references.bq.table_references = registered_tables # Up to 10 tables

In [36]:
# Context for stateless chat: the system instruction plus the data sources
inline_context = geminidataanalytics.Context(
    system_instruction=system_instruction,
    datasource_references=datasource_references,
)

# Optional: this line turns on advanced analysis with Python
inline_context.options.analysis.python.enabled = True

In [None]:
# Delete a data agent

data_agent_id = "ecommerce_data_agent_100"

request = geminidataanalytics.DeleteDataAgentRequest()
request.name = f"projects/{project_id}/locations/global/dataAgents/{data_agent_id}"

# Any failure is reported and the notebook carries on
try:
    data_agent_client.delete_data_agent(request=request)
except Exception as e:
    print(f"Error deleting Data Agent: {e}")
else:
    print("Data Agent Deleted")

In [None]:
# Create a data agent

data_agent_id = "ecommerce_data_agent_100"

# The agent publishes the inline context built earlier in the notebook
data_agent = geminidataanalytics.DataAgent(
    name=f"projects/{project_id}/locations/global/dataAgents/{data_agent_id}",  # Optional
)
data_agent.data_analytics_agent.published_context = inline_context

request = geminidataanalytics.CreateDataAgentRequest(
    parent=f"projects/{project_id}/locations/global",
    data_agent_id=data_agent_id,  # Optional
    data_agent=data_agent,
)

# Any failure is reported and the notebook carries on
try:
    data_agent_client.create_data_agent(request=request)
except Exception as e:
    print(f"Error creating Data Agent: {e}")
else:
    print("Data Agent created")


# Chat context that refers to the agent by its resource name
data_agent_context = geminidataanalytics.DataAgentContext(
    data_agent=f"projects/{project_id}/locations/global/dataAgents/{data_agent_id}",
)

# Util Functions

In [None]:
# Get a data agent

request = geminidataanalytics.GetDataAgentRequest()
request.name = f"projects/{project_id}/locations/global/dataAgents/{data_agent_id}"

# Fetch the agent and print what comes back
response = data_agent_client.get_data_agent(request=request)
print(response)

In [None]:
# List data agents

request = geminidataanalytics.ListDataAgentsRequest()
request.parent = f"projects/{project_id}/locations/global"

# Iterate over the result and print each entry
page_result = data_agent_client.list_data_agents(request=request)
for response in page_result:
    print(response)

# IAM (Optional if already set at project level)

In [None]:
# Build the request that would set the IAM policy for a data agent

from google.iam.v1 import policy_pb2
from google.iam.v1 import iam_policy_pb2

role = "roles/geminidataanalytics.dataAgentEditor"
users = "[user1]@gmail.com,[user2]@gmail.com" # comma separated if more than one
resource = f"projects/{project_id}/locations/global/dataAgents/{data_agent_id}"

# One binding that grants `role` to every listed user: each address is
# trimmed of surrounding whitespace and prefixed with "user:"
binding = policy_pb2.Binding(
    role=role,
    members= [f"user:{i.strip()}" for i in users.split(",")]
)

# Create the policy object
policy = policy_pb2.Policy(bindings=[binding])

# Create the request object to set the policy
request = iam_policy_pb2.SetIamPolicyRequest(
    resource=resource,
    policy=policy,
)

# NOTE(review): the request is only printed; nothing in this cell sends it to
# the service, so no policy is changed here. Confirm whether that is intended
# (the `users` value above is still a placeholder).
print(request)

In [None]:
# Get the IAM policy for a data agent
# Relies on `resource` and `iam_policy_pb2` from the IAM request cell above.

request = iam_policy_pb2.GetIamPolicyRequest(resource=resource)
try:
    response = data_agent_client.get_iam_policy(request=request)
    print("IAM Policy fetched successfully!")
    print(f"Response: {response}")
except Exception as e:
    # This cell reads the policy, so a failure is reported as a fetch error
    print(f"Error fetching IAM policy: {e}")

# Make calls to the Conversational Analytics API

In [41]:
# Helper functions

from pygments import highlight, lexers, formatters
import pandas as pd
import requests
import json as json_lib
import altair as alt
import IPython
from IPython.display import display, HTML

import proto
from google.protobuf.json_format import MessageToDict, MessageToJson

def handle_text_response(resp):
  """Print the parts of a text response joined into a single string."""
  text = ''.join(resp.parts)
  print(text)

def display_schema(data):
  """Display the fields of a schema as a table with one row per field.

  The table has the columns Column (field name), Type, Description and Mode.
  """
  schema_fields = data.fields
  table = pd.DataFrame({
    "Column": [f.name for f in schema_fields],
    "Type": [f.type for f in schema_fields],
    # '-' is used only when a field has no `description` attribute at all
    "Description": [getattr(f, 'description', '-') for f in schema_fields],
    "Mode": [f.mode for f in schema_fields],
  })
  display(table)

def display_section_title(text):
  """Display `text` as an HTML <h2> heading in the cell output."""
  heading = f'<h2>{text}</h2>'
  display(HTML(heading))

def format_bq_table_ref(table_ref):
  """Return the table reference as 'project_id.dataset_id.table_id'."""
  return f'{table_ref.project_id}.{table_ref.dataset_id}.{table_ref.table_id}'

def display_datasource(datasource):
  """Print the name of a data source, then display its schema.

  The name is the `studio_datasource_id` when that field is present in
  `datasource`; otherwise it is the formatted BigQuery table reference.
  """
  source_name = (
    datasource.studio_datasource_id
    if 'studio_datasource_id' in datasource
    else format_bq_table_ref(datasource.bigquery_table_reference)
  )

  print(source_name)
  display_schema(datasource.schema)

def handle_schema_response(resp):
  """Show a schema message.

  Prints the question when the message carries a `query`; otherwise, when it
  carries a `result`, lists every resolved data source. A message with
  neither produces no output.
  """
  if 'query' in resp:
    print(resp.query.question)
    return
  if 'result' not in resp:
    return

  display_section_title('Schema resolved')
  print('Data sources:')
  for source in resp.result.datasources:
    display_datasource(source)

def handle_data_response(resp):
  """Show a data message.

  Depending on which field the message carries, this shows the retrieval
  query (name, question and data sources), the generated SQL, or the
  retrieved rows as a DataFrame. A message with none of them produces no
  output.
  """
  if 'query' in resp:
    retrieval = resp.query
    display_section_title('Retrieval query')
    print(f'Query name: {retrieval.name}')
    print(f'Question: {retrieval.question}')
    print('Data sources:')
    for source in retrieval.datasources:
      display_datasource(source)
    return

  if 'generated_sql' in resp:
    display_section_title('SQL generated')
    print(resp.generated_sql)
    return

  if 'result' in resp:
    display_section_title('Data retrieved')

    column_names = [field.name for field in resp.result.schema.fields]
    # Pivot the row-oriented result into one list of values per column.
    # Columns are created on first use, so a result without rows yields an
    # empty frame with no columns.
    columns = {}
    for row in resp.result.data:
      for name in column_names:
        columns.setdefault(name, []).append(row[name])

    display(pd.DataFrame(columns))

def handle_chart_response(resp):
  """Show a chart message.

  Prints the chart instructions when the message carries a `query`;
  otherwise, when it carries a `result`, converts the Vega config into plain
  Python data and renders it with Altair.
  """
  def _to_plain(value):
    # Convert one value from the Vega config into plain Python data.
    if isinstance(value, proto.marshal.collections.maps.MapComposite):
      return _map_to_plain(value)
    if isinstance(value, proto.marshal.collections.RepeatedComposite):
      return [_to_plain(item) for item in value]
    if isinstance(value, (int, float, str, bool)):
      return value
    return MessageToDict(value)

  def _map_to_plain(mapping):
    # Convert a map-like value into a dict, converting every entry in turn.
    return {key: _to_plain(mapping[key]) for key in mapping}

  if 'query' in resp:
    print(resp.query.instructions)
  elif 'result' in resp:
    vega_spec = _map_to_plain(resp.result.vega_config)
    alt.Chart.from_json(json_lib.dumps(vega_spec)).display()

def show_message(msg):
  """Display one message by handing its payload to the matching handler.

  The system message is checked for `text`, `schema`, `data` and `chart`, in
  that order, and only the first one present is handled. A blank separator
  is printed afterwards in every case.
  """
  system_message = msg.system_message
  handlers = (
    ('text', handle_text_response),
    ('schema', handle_schema_response),
    ('data', handle_data_response),
    ('chart', handle_chart_response),
  )
  for kind, handler in handlers:
    if kind in system_message:
      handler(getattr(system_message, kind))
      break
  print('\n')

In [None]:
# Make calls to the API - Single turn stateless conversation

# The question is sent as the only user message of the request
question = "Tell me Daily Order Count and Total Revenue from the past 7 days in decending order"
messages = [geminidataanalytics.Message()]
messages[0].user_message.text = question

# Assemble the request field by field
request = geminidataanalytics.ChatRequest()
request.parent = f"projects/{project_id}/locations/global"
request.messages = messages
request.data_agent_context = data_agent_context

# Send it and show every message of the streamed reply
stream = data_chat_client.chat(request=request)
for response in stream:
    show_message(response)

In [42]:
# Make calls to the API - Multi turn stateless conversation

def multi_turn_Conversation(msg):
    """Send `msg` as the next user turn and display the streamed reply.

    The new user message and every message streamed back are appended to the
    module-level `conversation_messages` list, so consecutive calls share
    their history until that list is reset.
    """
    user_turn = geminidataanalytics.Message()
    user_turn.user_message.text = msg

    # The request carries the previous turns plus the new message
    conversation_messages.append(user_turn)

    request = geminidataanalytics.ChatRequest(
        parent=f"projects/{project_id}/locations/global",
        messages=conversation_messages,
        # Use data agent context
        data_agent_context=data_agent_context,
        # Use inline context
        # inline_context=inline_context,
    )

    # Show each streamed message and keep it as history for the next turn
    for reply in data_chat_client.chat(request=request):
        show_message(reply)
        conversation_messages.append(reply)

In [None]:
# Start with an empty history; each follow-up builds on the previous turns
conversation_messages = []

for prompt in (
    "Total order value placed in the past 7 days",
    "How about the past 14 days",
    "How about the past month",
):
    multi_turn_Conversation(prompt)

In [None]:
# Reset the history so this question starts a new conversation
conversation_messages = []

multi_turn_Conversation("Tell me the Total Revenue by Product Category from the past 7 days")

In [None]:
# Start with an empty history; the follow-up refines the first question
conversation_messages = []

for prompt in (
    "Top 5 Most Sold Products",
    "Also include the product category",
):
    multi_turn_Conversation(prompt)
multi_turn_Conversation("Also include the product category")

In [None]:
# Reset the history so this question starts a new conversation
conversation_messages = []

# Shorter variant of the question, without the definition of shipping time:
# multi_turn_Conversation("Average Shipping Time by Signup Month.")
multi_turn_Conversation("Average Shipping Time by Signup Month. Shipping time is calculated between delivery date and shipping date")