<a href="https://colab.research.google.com/github/wandb/examples/blob/master/colabs/gemini/How_to_use_Gemini_Pro_API_with_WB_Weave.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<!--- @wandbcode{gemini-weave-intro} -->

# How to use Gemini Pro API with W&B Weave

Read [our article](https://wandb.ai/prompt-eng/gemini-weave/reports/How-to-use-Gemini-Pro-API-with-W-B-Weave--Vmlldzo3NzEwNTA1) and follow along in this Colab notebook.

## Installation

In [None]:
%%capture
# Install/upgrade (-U) the Gemini SDK and Weave quietly (-qq); %%capture hides the cell output.
!pip install google-generativeai weave -qqU

In [None]:
import google.generativeai as genai
import weave

## Set up your Google API key and log into W&B Weave

To run the following cell, your API key must be stored in a Colab Secret named `GOOGLE_API_KEY`. If you don't already have an API key, or you're not sure how to create a Colab Secret, see the [Authentication](https://github.com/google-gemini/cookbook/blob/main/quickstarts/Authentication.ipynb) quickstart for an example.

In [None]:
from google.colab import userdata

# Read the key from the Colab Secret named GOOGLE_API_KEY ...
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
# ... and register it with the Gemini SDK for all later calls in this notebook.
genai.configure(api_key=GOOGLE_API_KEY)

In [None]:
# Initialize Weave for the 'prompt-eng/gemini-weave' project so weave.op calls below are tracked.
weave.init('prompt-eng/gemini-weave')

## Generate a summary and track it in Weave

In [None]:
%%capture
!wget https://raw.githubusercontent.com/wandb/llm-workshop-fc2024/main/part_2_structured_outputs/longpaper.txt
with open('longpaper.txt', 'r') as file:
    long_paper_text = file.read()

In [None]:
# Fetch the model's metadata and print its input token limit.
model_info = genai.get_model('models/gemini-1.5-pro-latest')
print(model_info.input_token_limit)

In [None]:
# Plain-text Gemini client used by the first generate_summary below.
model = genai.GenerativeModel('models/gemini-1.5-pro-latest')
# Count the paper's tokens (shown as the cell output) to compare with the limit above.
model.count_tokens(long_paper_text)

In [None]:
@weave.op()
def generate_summary(text):
    """Summarize ``text`` with the global Gemini ``model``.

    Args:
        text: The document to summarize.

    Returns:
        A dict ``{'summary': <response text>}``.
    """
    prompt = "Generate a concise summary of below text:\n"
    # Use the ``text`` argument rather than the global ``long_paper_text`` so the
    # function summarizes whatever it is given and the traced input matches.
    response = model.generate_content(prompt + text)
    return {
        'summary': response.text
    }

In [None]:
# Summarize the paper with the plain-text model; generate_summary is a weave.op, so the call is tracked.
summary = generate_summary(long_paper_text)

## Gemini API JSON Mode

In [None]:
# Rebind `model` to a JSON-mode client: responses are requested as application/json.
model = genai.GenerativeModel(
    "gemini-1.5-pro-latest",
    generation_config={"response_mime_type": "application/json"},
)

In [None]:
from pydantic import BaseModel, Field

# Target shape for the structured summary. Intentionally no class docstring:
# pydantic uses a class docstring as the schema "description", which would
# change the schema text embedded in the prompt.
class Summary(BaseModel):
    title: str
    # The field description is carried into the JSON schema and so into the prompt.
    summary: str = Field(description="plain short text summary without markdown")

# JSON schema dict passed to the prompt builders below; displayed as the cell output.
schema = Summary.model_json_schema()
schema

In [None]:
import json

In [None]:
@weave.op()
def create_prompt(text, schema):
    """Build the JSON-mode summarization prompt.

    Args:
        text: The document to summarize, inserted after ``Text:``.
        schema: The JSON schema, inserted after ``JSON schema:`` via f-string formatting.

    Returns:
        The prompt string, ending with a newline.
    """
    return (
        "Generate a concise summary of below text using below JSON schema.\n"
        "Please output plain text without markdown and limit it to 200 words.\n"
        f"Text:\n{text}\n"
        f"JSON schema:\n{schema}\n"
    )


@weave.op()
def generate_summary(text, schema):
    """Summarize ``text`` as JSON following ``schema`` with the global JSON-mode ``model``.

    Args:
        text: The document to summarize.
        schema: JSON schema embedded in the prompt by ``create_prompt``.

    Returns:
        ``{'summary': output}``, where ``output`` is the parsed JSON response,
        or the raw response text if it is not valid JSON.
    """
    prompt = create_prompt(text, schema)
    response = model.generate_content(prompt)
    raw = response.text
    try:
        output = json.loads(raw)
    except json.JSONDecodeError:
        # Fall back to the raw text instead of failing the call; catching only
        # decode errors keeps KeyboardInterrupt and real bugs visible.
        output = raw
    return {
        'summary': output
    }

In [None]:
# Structured summary using the JSON-mode generate_summary(text, schema) defined above.
new_summary = generate_summary(long_paper_text, schema)

## Evaluation with Weave

In [None]:
from pydantic import model_validator
import asyncio
import os
import time

os.environ['WEAVE_PARALLELISM'] = '1' # remove parallelism due to our Gemini quota, remove it if not needed


class SummaryModel(weave.Model):
    """Weave model that asks Gemini (JSON mode) for a structured summary.

    ``model`` is built from ``model_name`` by ``create_model``; callers only
    pass ``model_name``, ``prompt_template`` and ``json_schema``.
    """

    model_name: str  # Gemini model id passed to genai.GenerativeModel
    prompt_template: str  # str.format template with {text} and {schema} placeholders
    json_schema: dict  # substituted for {schema} in the prompt
    model: genai.GenerativeModel  # populated by create_model

    @model_validator(mode="before")
    @classmethod
    def create_model(cls, v):
        """Create the JSON-mode Gemini client from ``model_name`` before validation.

        Returns a copy of the input mapping with ``model`` set, so the caller's
        dict is not mutated. If ``v`` is not a dict or lacks ``model_name``,
        it is returned unchanged and pydantic reports the missing fields.
        """
        if not isinstance(v, dict) or "model_name" not in v:
            return v
        client = genai.GenerativeModel(
            v["model_name"],
            generation_config={"response_mime_type": "application/json"},
        )
        return {**v, "model": client}

    @weave.op()
    async def predict(self, text: str) -> dict:
        """Summarize ``text`` and return the result as a dict.

        Returns:
            The parsed JSON object (the first element if Gemini returns a JSON
            array of objects), otherwise ``{'summary': <raw response text>}``.
        """
        # Throttle for a low Gemini quota; remove if your quota allows it :)
        # asyncio.sleep yields to the event loop, unlike time.sleep which blocks it.
        await asyncio.sleep(15)
        # The schema field is ``json_schema``; ``self.schema`` is not a field of this model.
        prompt = self.prompt_template.format(text=text, schema=self.json_schema)
        # NOTE(review): generate_content is called without await, so it blocks
        # while waiting for the API response.
        response = self.model.generate_content(prompt)
        raw = response.text
        try:
            output = json.loads(raw)
        except json.JSONDecodeError:
            return {'summary': raw}
        # JSON mode may return a single object or an array of objects.
        if isinstance(output, list) and output:
            output = output[0]
        if isinstance(output, dict):
            return output
        return {'summary': raw}


In [None]:
# Template filled by SummaryModel.predict via str.format(text=..., schema=...).
prompt_template = (
    "Generate a concise summary of below text using below JSON schema.\n"
    "Please output plain text without markdown and limit it to 200 words.\n"
    "Text:\n"
    "{text}\n"
    "JSON schema:\n"
    "{schema}\n"
)

In [None]:
# Rebind `model` to the Weave model wrapping the JSON-mode Gemini client.
model = SummaryModel(
    model_name='gemini-1.5-pro-latest',
    prompt_template=prompt_template,
    json_schema=schema,
)

In [None]:
# Single prediction on the paper before running the full evaluation (top-level await in the notebook).
await model.predict(long_paper_text)

In [None]:
# Load a previously published Weave dataset through its object ref URI.
dataset_uri = "weave:///prompt-eng/gemini-weave/object/long_papers:9N9vkE4XY1SYoXLbvbCtP0YKqyqXErilG4XW8jYmQgE"
dataset = weave.ref(dataset_uri).get()

In [None]:
# Scoring function checking format adherence
@weave.op()
def check_formatting(model_output: dict) -> dict:
    """Score whether ``model_output`` is a dict with a string ``'summary'``.

    A list output is judged by its first element; an empty list fails
    instead of raising IndexError.

    Returns:
        ``{'formatting': bool}``.
    """
    if isinstance(model_output, list):
        if not model_output:
            return {'formatting': False}
        model_output = model_output[0]
    result = isinstance(model_output, dict) and isinstance(model_output.get('summary'), str)
    return {'formatting': result}

In [None]:
# Scoring function checking length of summary
@weave.op()
def check_conciseness(model_output: dict) -> dict:
    """Score whether the ``'summary'`` in ``model_output`` has fewer than 300 words.

    Words are whitespace-separated tokens. A list output is judged by its
    first element; an empty list, a non-dict, or a missing/non-string
    summary scores False instead of raising.

    Returns:
        ``{'conciseness': bool}``.
    """
    if isinstance(model_output, list):
        if not model_output:
            return {'conciseness': False}
        model_output = model_output[0]
    summary = model_output.get('summary') if isinstance(model_output, dict) else None
    result = isinstance(summary, str) and len(summary.split()) < 300
    return {'conciseness': result}

In [None]:
# Weave evaluation over `dataset` using the two scorer functions defined above.
evaluation = weave.Evaluation(
    dataset=dataset, scorers=[check_formatting, check_conciseness],
)

In [None]:
# Run the evaluation against the SummaryModel instance bound to `model` (top-level await in the notebook).
await evaluation.evaluate(model)