# Using Palm and W&B

You will need to setup Palm access and the corresponding VertexAI service account. 

See the documentation on how to get started using [Generative AI on Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/overview)


In [25]:
import vertexai
from vertexai.language_models import TextGenerationModel

You will  need to setup the GCP variables related to your project and zone:

In [26]:
project_id = "wandb-growth"
zone = "us-central1"

Let's try the quickstart example from the [PALM](https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/quickstart-text)

In [27]:
def interview(
    temperature: float,
    project_id: str = project_id,
    location: str = zone,
) -> str:
    """Ideation example with a Large Language Model"""

    vertexai.init(project=project_id, location=location)
    # TODO developer - override these parameters as needed:
    parameters = {
        "temperature": temperature,  # Temperature controls the degree of randomness in token selection.
        "max_output_tokens": 256,  # Token limit determines the maximum amount of text output.
        "top_p": 0.8,  # Tokens are selected from most probable to least until the sum of their probabilities equals the top_p value.
        "top_k": 40,  # A top_k of 1 means the selected token is the most probable among all tokens.
    }

    model = TextGenerationModel.from_pretrained("text-bison@001")
    response = model.predict(
        "Give me ten interview questions for the role of program manager.",
        **parameters,
    )
    return response.text

In [28]:
res = interview(0.7)
print(res)

1. What is your definition of a program manager?
2. What are your key responsibilities as a program manager?
3. What are your biggest challenges as a program manager?
4. What are your biggest successes as a program manager?
5. How do you measure the success of your programs?
6. What are your leadership style and management style?
7. How do you communicate with your team and stakeholders?
8. How do you handle conflict and difficult situations?
9. What are your goals for your career as a program manager?
10. Why are you interested in this position?


## Let's use Weights and Biases Tables to store our model predictions

In [29]:
import time, wandb
from tqdm.auto import tqdm

queries = [
    "The planet earth is the ",
    "Implement a Python function to compute the Fibonacci numbers.",
    "Write a Rust function that performs binary exponentiation.",
    "How do I allocate memory in C?",
    "What are the differences between Javascript and Python?",
    "How do I find invalid indices in Postgres?",
    "How can you implement a LRU (Least Recently Used) cache in Python?",
    "What approach would you use to detect and prevent race conditions in a multithreaded application?",
    "Can you explain how a decision tree algorithm works in machine learning?",
    "How would you design a simple key-value store database from scratch?",
    "How do you handle deadlock situations in concurrent programming?",
    "What is the logic behind the A* search algorithm, and where is it used?",
    "How can you design an efficient autocomplete system?",
    "What approach would you take to design a secure session management system in a web application?",
    "How would you handle collision in a hash table?",
    "How can you implement a load balancer for a distributed system?",
    "What is the fable involving a fox and grapes?",
    "Write a story in the style of James Joyce about a trip to the Australian outback in 2083, to see robots in the beautiful desert.",
    "Who does Harry turn into a balloon?",
    "Write a tale about a time-traveling historian who's determined to witness the most significant events in human history.",
    "Describe a day in the life of a secret agent who's also a full-time parent.",
]

In [30]:
def palm_call(
    prompt: str,
    temperature: float = 0.7,
    max_output_tokens: int = 256,
    top_p: float = 0.8,
    top_k: int = 40,
    project_id: str = project_id,
    location: str = zone,
) -> str:
    vertexai.init(project=project_id, location=location)
    parameters = {
        "temperature": temperature,  # Temperature controls the degree of randomness in token selection.
        "max_output_tokens": max_output_tokens,  # Token limit determines the maximum amount of text output.
        "top_p": top_p,  # Tokens are selected from most probable to least until the sum of their probabilities equals the top_p value.
        "top_k": top_k,  # A top_k of 1 means the selected token is the most probable among all tokens.
    }

    model = TextGenerationModel.from_pretrained("text-bison@001")
    response = model.predict(
        prompt,
        **parameters,
    )
    return response.text

In [7]:
table = wandb.Table(columns=["model", "time", "temperature", "max_output_tokens", "top_p", "top_k", "prompt", "response"])

In [8]:
# let's define some configuration parameters
config = dict(
    temperature = 1.0,
    max_output_tokens = 128,
    top_p = 0.8,
    top_k = 40,
)

# we iterate through the queries and call the model
# adding the results to a table
for q in tqdm(queries):
    t0 = time.perf_counter()
    res = palm_call(q, **config)
    table.add_data(
        "text-bison@001", 
        time.perf_counter() - t0, 
        config["temperature"], 
        config["max_output_tokens"], 
        config["top_p"], 
        config["top_k"], 
        q, 
        res)

  0%|          | 0/21 [00:00<?, ?it/s]

100%|██████████| 21/21 [00:52<00:00,  2.51s/it]


We can now save the table to W&B

In [9]:
wandb.init(project="wandb-palm", config=config)

wandb.log({"palm_samples": table})

wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m. Use [1m`wandb login --relogin`[0m to force relogin


## LangChain Integration

Weights and Biases supports Langchain, this means that you can use Palm in your Langchain application and benefit from the full suite of features that Weights and Biases offers.

In [14]:
import os, random, time, wandb

from langchain.llms import VertexAI
from langchain.agents import AgentType, initialize_agent
from langchain import PromptTemplate, LLMChain
from langchain.tools import BaseTool

from typing import Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)

os.environ["LANGCHAIN_WANDB_TRACING"] = "true"

In [15]:
wandb.init(project="wandb-palm", job_type="generation")

In [16]:
llm = VertexAI(model="text-bison@001", project="wandb-growth", location="us-central1")

In [17]:
class WorldPickerTool(BaseTool):
    name = "pick_world"
    description = "pick a virtual game world for your character or item naming"
    worlds = [
                "a mystic medieval island inhabited by intelligent and funny frogs",
                "a modern anthill featuring a cyber-ant queen and her cyber-ant-workers",
                "a digital world inhabited by friendly machine learning engineers"
            ]

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Use the tool."""
        time.sleep(1)
        return random.choice(self.worlds)

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("pick_world does not support async")
        
class NameValidatorTool(BaseTool):
    name = "validate_name"
    description = "validate if the name is properly generated"

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Use the tool."""
        time.sleep(1)
        if len(query) < 20:
            return f"This is a correct name: {query}"
        else:
            return f"This name is too long. It should be shorter than 20 characters."

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("validate_name does not support async")

In [18]:
tools = [WorldPickerTool(), NameValidatorTool()]
agent = initialize_agent(
    tools, 
    llm, 
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    handle_parsing_errors=True,
    verbose=True
)

In [19]:
agent.run(
    "Find a virtual game world for me and imagine the name of a hero in that world"
)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m pick a virtual game world
Action: pick_world
Action Input: None[0m
Observation: [36;1m[1;3ma modern anthill featuring a cyber-ant queen and her cyber-ant-workers[0m
Thought:[32;1m[1;3m pick a name for the hero
Action: pick_name
Action Input: None[0m
Observation: pick_name is not a valid tool, try one of [pick_world, validate_name].
Thought:[32;1m[1;3m pick a name for the hero
Action: pick_name
Action Input: None[0m
Observation: pick_name is not a valid tool, try one of [pick_world, validate_name].
Thought:[32;1m[1;3m pick a name for the hero
Action: validate_name
Action Input: "Cyber-ant Queen"[0m
Observation: [33;1m[1;3mThis is a correct name: Cyber-ant Queen[0m
Thought:[32;1m[1;3m I now know the final answer
Final Answer: Cyber-ant Queen in a modern anthill featuring a cyber-ant queen and her cyber-ant-workers[0m

[1m> Finished chain.[0m


'Cyber-ant Queen in a modern anthill featuring a cyber-ant queen and her cyber-ant-workers'

In [20]:
agent.run(
    "Find a virtual game world for me and imagine the name of a jewel in that world"
)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m pick_world
Action: pick_world
Action Input: None[0m
Observation: [36;1m[1;3ma modern anthill featuring a cyber-ant queen and her cyber-ant-workers[0m
Thought:[32;1m[1;3m pick_world
Action: pick_world
Action Input: None[0m
Observation: [36;1m[1;3ma modern anthill featuring a cyber-ant queen and her cyber-ant-workers[0m
Thought:[32;1m[1;3m pick_world
Action: pick_world
Action Input: None[0m
Observation: [36;1m[1;3ma modern anthill featuring a cyber-ant queen and her cyber-ant-workers[0m
Thought:[32;1m[1;3m pick_world
Action: pick_world
Action Input: None[0m
Observation: [36;1m[1;3ma digital world inhabited by friendly machine learning engineers[0m
Thought:[32;1m[1;3m pick_world
Action: pick_world
Action Input: None[0m
Observation: [36;1m[1;3ma mystic medieval island inhabited by intelligent and funny frogs[0m
Thought:[32;1m[1;3m pick_world
Action: pick_world
Action Input: None[0m
Observation: [3

'Agent stopped due to iteration limit or time limit.'

In [21]:
agent.run(
    "Find a virtual game world for me and imagine the name of food in that world."
)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m Let me pick a virtual game world first.
Action: pick_world
Action Input: None[0m
Observation: [36;1m[1;3ma modern anthill featuring a cyber-ant queen and her cyber-ant-workers[0m
Thought:[32;1m[1;3m Now I need to imagine the name of food in that world.
Action: pick_world
Action Input: None[0m
Observation: [36;1m[1;3ma modern anthill featuring a cyber-ant queen and her cyber-ant-workers[0m
Thought:[32;1m[1;3m Let me validate the name.
Action: validate_name
Action Input: cyber-ant-queen-cake[0m
Observation: [33;1m[1;3mThis name is too long. It should be shorter than 20 characters.[0m
Thought:[32;1m[1;3m Let me try another name.
Action: pick_world
Action Input: None[0m
Observation: [36;1m[1;3ma mystic medieval island inhabited by intelligent and funny frogs[0m
Thought:[32;1m[1;3m Let me validate the name.
Action: validate_name
Action Input: frog-pie[0m
Observation: [33;1m[1;3mThis is a correct name: f

'frog-pie'

In [22]:
wandb.finish()

## Prompting LLMs correctly to get the best performance

Let's run some toy mathematical operations to see how LLMs work.

In [31]:
expressions = [
    "5 + 3 * 2",
    "10 / (4 - 2)",
    "(7 + 3) * 2 - 5",
    "2^3 + 4",
    "8 / (2 + 2) - 1",
    "3 * (5 - 2^2)",
    "2 * 4 + (6 - 3)",
    "10 / (2 + 2) + 1",
    "(4 - 2) * (6 + 1)",
    "8 / (2 + 2) + (3 - 1)",
    "2^3 * 2 - 4",
    "6 / (2 + 1) - 2",
    "(3 + 2) * 4 - 6",
    "9 / (3 + 2) + 1",
    "(2^3 - 1) * (6 - 2)",
    "12 / (4 + 2) - 1",
    "2 * (3 + 4) - 5",
    "15 / (3 + 2) - 1",
    "(2^2 + 3) * (6 - 1)",
    "18 / (6 + 2) - 2",
]

We can try different prompts and ask the LLM to compute the expression

In [32]:
prompt_template1 = """
The following is the mathematical expression provided by the user.
{question}

Find the answer using the BODMAS return the answer as a float:
"""

In [61]:
prompt_template1.format(question=expressions[0])

'\nThe following is the mathematical expression provided by the user.\n5 + 3 * 2\n\nFind the answer using the BODMAS return the answer as a float:\n'

In [43]:
palm_call(prompt_template1.format(question=expressions[0]))

'11.0'

we need to convert some operations to Python like `^` to `**`

In [44]:
def correct_expression(expr: str) -> str:
    expr = expr.replace(" ", "")
    expr = expr.replace("[", "(")
    expr = expr.replace("]", ")")
    expr = expr.replace("{", "(")
    expr = expr.replace("}", ")")
    expr = expr.replace("^", "**")
    
    return expr

In [45]:
import numexpr, math

def evaluate_expr(expr: str) -> str:
    local_dict = {"pi": math.pi, "e": math.e}

    try:
        expr = correct_expression(expr)
        output = str(
            numexpr.evaluate(
                expr.strip(),
                global_dict={},  # restrict access to globals
                local_dict=local_dict,  # add common mathematical functions
            )
        )
        return float(output)
    except:
        return None

In [46]:
evaluate_expr(expressions[0])

11.0

Let's create a Table with the results

In [48]:
config = dict(
    temperature = 1.0,
    max_output_tokens = 128,
    top_p = 0.8,
    top_k = 40,
)

In [101]:
def evaluate_with_palm(expressions, prompt_template, config):
    "Evaluate Palm on solving simple math expressions"
    accuracy = 0.
    table = wandb.Table(columns = ["prompt", "expression", "true_answer", 
                                   "pred_answer", "temperature", "max_output_tokens", 
                                   "top_p", "top_k" ])
    for exp in tqdm(expressions):
        prompt = prompt_template.format(question=exp)
        palm_answer = palm_call(
            prompt,
            temperature=config["temperature"], 
            max_output_tokens=config["max_output_tokens"], 
            top_p=config["top_p"], 
            top_k=config["top_k"])
        try:
            palm_answer = float(palm_answer)
        except:
            pass
        true_answer = eval(exp)
        table.add_data(
            prompt,
            exp, 
            true_answer, 
            palm_answer,
            config["temperature"], 
            config["max_output_tokens"], 
            config["top_p"], 
            config["top_k"]
        )
        accuracy += palm_answer == true_answer
        wandb.log({"accuracy": 100 * accuracy / len(expressions)})
        wandb.log({"calculated_expression": table})

In [102]:
wandb.init(project="wandb-palm", job_type="msth_calculator")

In [103]:
evaluate_with_palm(expressions, prompt_template1, config)

100%|██████████| 20/20 [00:35<00:00,  1.80s/it]


In [104]:
wandb.finish()

0,1
accuracy,▁▃▃▃▅▅▅▅▆▆▆▆▆▆▆█████

0,1
accuracy,25.0


The model has a hard time solving this simple expression, let's improve the prompt and maybe play with the temperature parameter:

In [105]:
prompt_template2 = """
You are an expert mathematician. You can solve a given mathematical expression using the BODMAS rule.
BODMAS stands for Bracket, Orders of Indices, Division, Multiplication, Addition and Subtraction. The computation should happen in that order.
The dorder is as follows:
B: Solve expressions inside brackets in this order -> small bracket followed by curly bracket and finally square bracket.
O: Solve the indices such as roots, powers, etc.
D: Divide the numbers which are given
M: Multiply the numbers next
A: Sum up the next numbers
S: Subtract the numbers left in the end

The following is the mathematical expression provided by the user.
{question}

Think about it step-by-step. Don't skip steps.

When ready with the answer return the answer as a float:
"""

We can use W&B Sweeps to explore the different combination of parameters:

In [106]:
sweep_config = dict(
    method="random",
    name="palm_sweep",
    metric={"name": "accuracy", "goal": "maximize"},
    parameters=dict(
        prompt_template={"values": [1,2]},
        temperature={"min": 0.1, "max": 1.0},)
)

Let's refactor the evaulation on a single function that depends on the configuration.

In [110]:
config = dict(
    temperature = 1.0,
    max_output_tokens = 128,
    top_p = 0.8,
    top_k = 40,
    prompt_template=1,
)

def sweep_func(config=config):
    wandb.init(config=config)
    config = wandb.config

    if config.prompt_template == 1:
        prompt_template = prompt_template1
    else:
        prompt_template = prompt_template2

    evaluate_with_palm(expressions, prompt_template, config)

    wandb.finish()
    

In [111]:
sweep_func()

100%|██████████| 20/20 [00:36<00:00,  1.81s/it]


0,1
accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,5.0


In [112]:
sweep_id = wandb.sweep(sweep=sweep_config, project="wandb-palm")

Create sweep with ID: 9poww5en
Sweep URL: https://wandb.ai/capecape/wandb-palm/sweeps/9poww5en


In [113]:
wandb.agent(sweep_id=sweep_id, function=sweep_func, count=5)

[34m[1mwandb[0m: Agent Starting Run: kvo67oeu with config:
[34m[1mwandb[0m: 	prompt_template: 1
[34m[1mwandb[0m: 	temperature: 0.41114374043542057
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


100%|██████████| 20/20 [00:37<00:00,  1.90s/it]




0,1
accuracy,▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅███

0,1
accuracy,15.0
