<a href="https://colab.research.google.com/github/erika-cardenas/recipes/blob/main/generative_search_palm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dependencies

In [None]:
!pip install weaviate-client

## Configuration

In [None]:
import weaviate
import json

client = weaviate.Client(
  url="WEAVIATE-INSTANCE-URL",  # URL of your Weaviate instance
  auth_client_secret=weaviate.AuthApiKey(api_key="AUTH-KEY"), # (Optional) If the Weaviate instance requires authentication
  additional_headers={
    "X-PALM-Api-Key": "PALM-API-KEY", # Replace with your PALM key
  }
)

client.schema.get()  # Get the schema to test connection

### Expired Google Cloud Token

Google Cloud OAuth 2.0 access tokens have a lifetime of only **one** hour. Once a token expires, you have to obtain a valid one and pass it to Weaviate by re-instantiating the client.

#### Option 1: With Google Cloud CLI

In [None]:
import subprocess
import weaviate

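def refresh_token() -> str:
    # Ask the gcloud CLI for a fresh OAuth 2.0 access token
    result = subprocess.run(["gcloud", "auth", "print-access-token"], capture_output=True, text=True)
    if result.returncode != 0:
        # Fail loudly instead of passing None to Weaviate as the API key
        raise RuntimeError(f"Error refreshing token: {result.stderr}")
    return result.stdout.strip()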

def re_instantiate_weaviate() -> weaviate.Client:
    token = refresh_token()

    client = weaviate.Client(
      url = "https://some-endpoint.weaviate.network",  # Replace with your Weaviate URL
      additional_headers = {
        "X-Palm-Api-Key": token,
      }
    )
    return client

# Run this every ~60 minutes
client = re_instantiate_weaviate()

Then rerun the cell below periodically, before the current token expires.

In [None]:
client = re_instantiate_weaviate()
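
Rerunning a cell by hand is easy to forget. As a minimal sketch (not part of the Weaviate client), the hypothetical `RefreshingHolder` below wraps any factory function and rebuilds the client once it is older than a chosen age. The class name, the 55-minute default, and the injectable `clock` are all assumptions made for illustration.

```python
import time
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class RefreshingHolder(Generic[T]):
    """Hypothetical helper: rebuilds a client once it is older than max_age_seconds."""

    def __init__(
        self,
        factory: Callable[[], T],
        max_age_seconds: float = 55 * 60,  # assumed margin below the one-hour token lifetime
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._factory = factory
        self._max_age = max_age_seconds
        self._clock = clock
        self._client: Optional[T] = None
        self._created_at = 0.0

    def get(self) -> T:
        now = self._clock()
        # Build on first use, and rebuild once the current client is too old
        if self._client is None or now - self._created_at >= self._max_age:
            self._client = self._factory()
            self._created_at = now
        return self._client
```

With the function defined above, this could be used as `holder = RefreshingHolder(re_instantiate_weaviate)`, followed by `client = holder.get()` before each query.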

#### Option 2: With `google-auth`

See the google-auth libraries for [Python](https://google-auth.readthedocs.io/en/master/index.html) and [Node.js](https://cloud.google.com/nodejs/docs/reference/google-auth-library/latest).

In [None]:
from google.auth.transport.requests import Request
from google.oauth2.service_account import Credentials
import weaviate

def get_credentials() -> Credentials:
    credentials = Credentials.from_service_account_file('path/to/your/service-account.json', scopes=['openid'])
    request = Request()
    credentials.refresh(request)
    return credentials

def re_instantiate_weaviate() -> weaviate.Client:
    credentials = get_credentials()
    token = credentials.token

    client = weaviate.Client(
      url = "https://some-endpoint.weaviate.network",  # Replace with your Weaviate URL
      additional_headers = {
        "X-Palm-Api-Key": token,
      }
    )
    return client

# Run this every ~60 minutes
client = re_instantiate_weaviate()

Then rerun the cell below periodically, before the current token expires:

In [None]:
client = re_instantiate_weaviate()
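
To decide when a rerun is actually needed, you could compare the token's expiry time against the current time. The sketch below is a hypothetical helper, not a google-auth function. It assumes you pass in the expiry yourself (google-auth credentials expose an `expiry` attribute after a refresh, which I believe is a naive UTC datetime), and the 5-minute safety margin is an arbitrary choice.

```python
from datetime import datetime, timedelta
from typing import Optional


def needs_refresh(
    expiry: Optional[datetime],
    now: datetime,
    skew: timedelta = timedelta(minutes=5),  # assumed safety margin
) -> bool:
    """Return True when a token with the given expiry should be replaced."""
    if expiry is None:
        # No known expiry: treat the token as stale
        return True
    # Refresh slightly early so a request never goes out with an expired token
    return now >= expiry - skew
```

Both datetimes must use the same convention (both naive UTC, or both timezone-aware), otherwise the comparison is meaningless or raises a `TypeError`.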

## Schema

In [None]:
# resetting the schema. CAUTION: THIS WILL DELETE YOUR DATA 
client.schema.delete_all()

schema = {
   "classes": [
       {
           "class": "JeopardyQuestion",
           "description": "List of jeopardy questions",
           "vectorizer": "text2vec-palm",
           "moduleConfig": { # specify the model you want to use
               "generative-palm": {
                  "projectId": "YOUR-GOOGLE-CLOUD-PROJECT-ID",    # required. Replace with your value: (e.g. "cloud-large-language-models")
                  "apiEndpoint": "YOUR-API-ENDPOINT",             # optional. Defaults to "us-central1-aiplatform.googleapis.
                  "modelId": "YOUR-GOOGLE-CLOUD-ENDPOINT-ID",     # optional. Defaults to "chat-bison"
                }
           },
           "properties": [
               {
                  "name": "Category",
                  "dataType": ["text"],
                  "description": "Category of the question",
               },
               {
                  "name": "Question",
                  "dataType": ["text"],
                  "description": "The question",
               },
               {
                  "name": "Answer",
                  "dataType": ["text"],
                  "description": "The answer",
                }
            ]
        }
    ]
}

client.schema.create(schema)

print("Successfully created the schema.")

## Import the Data

In [None]:
import requests
url = 'https://raw.githubusercontent.com/weaviate/weaviate-examples/main/jeopardy_small_dataset/jeopardy_tiny.json'
resp = requests.get(url)
data = json.loads(resp.text)

if client.is_ready():
  # Configure a batch process
  with client.batch as batch:
      batch.batch_size = 100
      # Batch import all questions
      for i, d in enumerate(data):
          print(f"importing question: {i+1}")

          properties = {
              "answer": d["Answer"],
              "question": d["Question"],
              "category": d["Category"],
          }

          # Add through the context-managed batch so it is flushed on exit
          batch.add_data_object(properties, "JeopardyQuestion")
else:
  print("The Weaviate cluster is not connected.")

## Generative Search Queries

### Single Result

Single Result generates a response for each individual search result.

In the example below, I want to create a Facebook ad from the Jeopardy question about elephants.

In [None]:
generatePrompt = "Turn the following Jeopardy question into a Facebook Ad: {question}"

result = (
  client.query
  .get("JeopardyQuestion", ["question"])
  .with_generate(single_prompt = generatePrompt)
  .with_near_text({
    "concepts": ["Elephants"]
  })
  .with_limit(1)
).do()

print(json.dumps(result, indent=1))
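
The printed JSON is deeply nested. As a rough sketch, the hypothetical helper below pulls out just the generated text for each object. It assumes the response nests the generation under `data.Get.<class>[i]._additional.generate.singleResult` and reports failures under `generate.error`; check the printed output above to confirm the shape for your Weaviate version.

```python
from typing import Any, Dict, List


def single_results(result: Dict[str, Any], class_name: str) -> List[str]:
    """Collect the per-object generated texts from a query response (assumed shape)."""
    objects = result.get("data", {}).get("Get", {}).get(class_name) or []
    texts: List[str] = []
    for obj in objects:
        generate = (obj.get("_additional") or {}).get("generate") or {}
        if generate.get("error"):
            # Skip objects whose generation failed
            continue
        text = generate.get("singleResult")
        if text is not None:
            texts.append(text)
    return texts
```

For the query above, `single_results(result, "JeopardyQuestion")` would then return a list with one entry per retrieved question.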

### Grouped Result

Grouped Result generates a single response from all the search results.

The example below asks the model to explain why the 3 retrieved Jeopardy questions belong in the Animals category.

In [None]:
generateTask = "Explain why these Jeopardy questions are under the Animals category."

result = (
  client.query
  .get("JeopardyQuestion", ["question"])
  .with_generate(grouped_task = generateTask)
  .with_near_text({
    "concepts": ["Animals"]
  })
  .with_limit(3)
).do()

print(json.dumps(result, indent=1))
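
For a grouped task there is only one generated text for the whole result set. The sketch below is a hypothetical helper that assumes the text appears under `_additional.generate.groupedResult` on one of the returned objects (usually the first); it scans all of them rather than relying on the position.

```python
from typing import Any, Dict, Optional


def grouped_result(result: Dict[str, Any], class_name: str) -> Optional[str]:
    """Return the single grouped generation from a query response (assumed shape)."""
    objects = result.get("data", {}).get("Get", {}).get(class_name) or []
    for obj in objects:
        generate = (obj.get("_additional") or {}).get("generate") or {}
        text = generate.get("groupedResult")
        if text is not None:
            return text
    # No object carried a grouped generation
    return None
```

For the query above, `print(grouped_result(result, "JeopardyQuestion"))` would show only the explanation, without the surrounding JSON.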