## Load & Preprocess Dataset

In [None]:
import json
import random
import matplotlib.pyplot as plt

from typing import Dict, List, Any, Union

from dotenv import load_dotenv
from openai import AsyncOpenAI
from datasets import load_dataset

load_dotenv()

# Load the GPQA "gpqa_main" train split as a list of plain dicts, then shuffle
# it with a fixed seed so the train/test split below is reproducible.
ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"].to_list()
random.seed(42)
random.shuffle(ds)

from ape.common.types import DatasetItem


def _to_dataset_item(row):
    """Wrap one raw GPQA row as a DatasetItem.

    The question is the input; the explanation (as "thought") and the correct
    answer are the expected outputs.
    """
    return DatasetItem(
        inputs={"question": row["Question"]},
        outputs={"thought": row["Explanation"], "answer": row["Correct Answer"]},
    )


# The first 100 shuffled rows are used for training.
trainset = [_to_dataset_item(row) for row in ds[:100]]

# Rows 100..200 form the test set, minus items whose answer contains "ATCG".
# NOTE(review): the 201 bound yields 101 rows before filtering -- confirm that
# this is intended (e.g. to end up with 100 items after the filter).
testset = [
    item
    for item in map(_to_dataset_item, ds[100:201])
    if "ATCG" not in item["outputs"]["answer"]
]

## Prepare the Prompt to Optimize

In [2]:
from ape.common import Prompt

# define prompt
# System-message template for the student prompt. "{question}" is filled in
# per dataset item when the generator calls prompt.format(**inputs).
system_prompt = (
    "For given science question, solve it step by step.\n"
    "\n"
    "Question: {question}\n"
    "\n"
    "You MUST respond in JSON format with the following fields:\n"
    "thought: the reasoning process of the problem solving.\n"
    "answer: only return the answer without any explanation.\n"
)

# Strict structured-output schema for the student's reply: an object with
# exactly two string fields, "thought" and "answer" (both required).
_answer_fields = {
    "thought": {
        "type": "string",
        "description": "The reasoning process of the problem solving",
    },
    "answer": {
        "type": "string",
        "description": "The answer to the question",
    },
}

json_schema = {
    "type": "json_schema",
    "json_schema": {
        "name": "ScienceProblemSolving",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": _answer_fields,
            "required": list(_answer_fields),
            "additionalProperties": False,
        },
    },
}

# The prompt under optimization: a single system message rendered from
# system_prompt, served by gpt-4o-mini at temperature 0 and constrained to
# the json_schema response format.
student_prompt = Prompt(
    name="Science Problem Solver",
    model="gpt-4o-mini",
    temperature=0.0,
    response_format=json_schema,
    messages=[{"role": "system", "content": system_prompt}],
)

## Prepare Generator, Metric, and Global Metric

In [3]:
import asyncio
import time

from ape.common.generator import BaseGenerator
from ape.common.metric import BaseMetric
from ape.common.global_metric import BaseGlobalMetric
from ape.common.types import MetricResult, GlobalMetricResult

# Define the generator, metric and global metric.

# Shared async OpenAI client, used by ScienceSolver.generate and
# GPQAMetric.compute below (both refer to the module-level name ``openai``).
# NOTE(review): credentials presumably come from the environment populated by
# load_dotenv() in the first cell. This name shadows the ``openai`` package
# name, so rename it only together with its users.
openai = AsyncOpenAI()

class ScienceSolver(BaseGenerator):
    """Generator that answers one science question via a streamed chat completion.

    The prompt's messages are formatted with the item's inputs and sent to the
    prompt's model with the prompt's ``response_format``. The streamed text is
    concatenated and parsed as JSON.
    """

    async def generate(
        self,
        prompt: Prompt,
        inputs: Dict[str, Any],
    ) -> Union[Dict[str, Any], str]:
        """Solve a single question, retrying failed attempts up to 3 times.

        Args:
            prompt: Prompt to format with ``inputs``; its ``model`` and
                ``response_format`` are forwarded to the API.
            inputs: Template values for the prompt (here ``{"question": ...}``).

        Returns:
            The parsed JSON response on success. If all 3 attempts fail, a dict
            ``{"thought": "error: ...", "answer": ""}`` describing the last
            failure ("error: stream timeout" for timeouts).
        """
        messages = prompt.format(**inputs).messages
        model = prompt.model
        response_format = prompt.response_format

        last_error = "error: max retries reached"
        for _ in range(3):
            try:
                # Bound the whole request, i.e. opening the stream AND reading
                # it, so a stream that stalls between chunks cannot hang forever.
                full_response = await asyncio.wait_for(
                    self._stream_content(model, messages, response_format),
                    timeout=30.0,
                )
                return json.loads(full_response)
            except asyncio.TimeoutError:
                last_error = "error: stream timeout"
            except Exception as e:
                # API errors and malformed/truncated JSON: try again.
                last_error = f"error: {str(e)}"

        return {
            "thought": last_error,
            "answer": "",
        }

    async def _stream_content(
        self,
        model: Any,
        messages: Any,
        response_format: Any,
    ) -> str:
        """Open a streamed chat completion and return its concatenated text.

        Waiting for the stream to open is limited to 10 seconds. The stream is
        always closed (on success, error or cancellation) so its connection is
        released.
        """
        stream_response = await asyncio.wait_for(
            openai.chat.completions.create(
                model=model,
                messages=messages,
                response_format=response_format,
                temperature=0.0,
                stream=True,
                frequency_penalty=0.1,
            ),
            timeout=10.0,
        )
        parts: List[str] = []
        try:
            async for chunk in stream_response:
                # Skip chunks that carry no choices.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    parts.append(content)
        finally:
            await stream_response.close()
        return "".join(parts)



# Strict structured-output schema for the LLM judge used by GPQAMetric.
# "correctness" is restricted to the three verdicts the judge prompt asks for,
# so the metric's exact comparison against "Correct" cannot be defeated by
# casing or extra text in a free-form string.
eval_json_schema = {
    "type": "json_schema",
    "json_schema": {
        "name": "ScienceQuestionSolvingEvaluation",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {
                "thought": {
                    "type": "string",
                    "description": "The reasoning process of the problem solving evaluation"
                },
                "correctness": {
                    "type": "string",
                    "enum": ["Correct", "Wrong", "Unknown"],
                    "description": "The correctness of the problem solving"
                }
            },
            "required": ["thought", "correctness"],
            "additionalProperties": False
        }
    }
}


class GPQAMetric(BaseMetric):
    """LLM-as-judge metric for GPQA answers.

    Asks gpt-4o-mini (with ``eval_json_schema`` as the response format) whether
    the predicted answer matches the ground-truth answer. It scores 1.0 only
    when the judge's "correctness" is exactly "Correct".
    """

    async def compute(
        self,
        dataset_item: DatasetItem,
        pred: Dict[str, Any],
    ) -> MetricResult:
        """Score one prediction, making up to 3 judge attempts.

        Args:
            dataset_item: Item providing ``inputs["question"]`` and
                ``outputs["answer"]`` (the ground truth).
            pred: Generator output; its ``"answer"`` entry is graded.

        Returns:
            ``MetricResult(score=1.0)`` if the judge replies "Correct",
            otherwise ``MetricResult(score=0.0)``. The 0.0 result also covers
            a judge prompt that cannot be built and attempts that all fail
            (timeout, API or JSON error).
        """
        try:
            messages = self._judge_messages(dataset_item, pred)
        except Exception:
            # e.g. ``pred`` is a str or lacks "answer": retrying cannot help.
            return MetricResult(score=0.0)

        for _ in range(3):
            try:
                # Bound opening AND reading the stream so a stalled stream
                # cannot hang the evaluation.
                full_response = await asyncio.wait_for(
                    self._stream_judgement(messages),
                    timeout=30.0,
                )
                res_json = json.loads(full_response)
                if res_json["correctness"] == "Correct":
                    return MetricResult(score=1.0)
                return MetricResult(score=0.0)
            except Exception:
                # Timeouts, API errors and malformed JSON each use up one attempt.
                continue

        return MetricResult(score=0.0)

    def _judge_messages(
        self,
        dataset_item: DatasetItem,
        pred: Dict[str, Any],
    ) -> List[Dict[str, str]]:
        """Build the single system message asking the judge to grade ``pred``."""
        template = """\
        YOU ARE one of the GREATEST scientists. You are intelligent and rational. You are prudent and cautious. Your mastery over science is unparalleled. You THINK NATURAL, BROAD AND DEEP. Let's think step by step. 
        Your job is to judge whether the "final_answer" is correct based on "ground_truth_answer", do not be strict on the format, but check the content. Notice that unsolved half results are not Correct. 
        Question: {question_content}
        Is the final_answer correct, given the ground truth answer? Reply with Correct, Wrong or Unknown. 
        "final_answer": "{final_answer}", "ground_truth_answer": "{ground_truth_answer}"

        You MUST respond in JSON format like below:
        {{
            "thought": "...",
            "correctness": "<correctness>", One of "Correct", "Wrong", "Unknown"
        }}
        """
        content = template.format(
            question_content=dataset_item["inputs"]["question"],
            final_answer=pred["answer"],
            ground_truth_answer=dataset_item["outputs"]["answer"],
        )
        return [{"role": "system", "content": content}]

    async def _stream_judgement(self, messages: List[Dict[str, str]]) -> str:
        """Stream the judge's reply and return its concatenated text.

        Waiting for the stream to open is limited to 10 seconds. The stream is
        always closed (on success, error or cancellation) so its connection is
        released.
        """
        stream_response = await asyncio.wait_for(
            openai.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
                response_format=eval_json_schema,
                temperature=0.0,
                stream=True,
                frequency_penalty=0.1,
            ),
            timeout=10.0,
        )
        parts: List[str] = []
        try:
            async for chunk in stream_response:
                # Skip chunks that carry no choices.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    parts.append(content)
        finally:
            await stream_response.close()
        return "".join(parts)

class GlobalGPQAMetric(BaseGlobalMetric):
    """Global metric: the mean of the per-item scores."""

    async def compute(
        self,
        results: List[MetricResult],
    ) -> GlobalMetricResult:
        """Average ``result.score`` over ``results``.

        Returns a score of 0.0 for an empty list. Any exception raised while
        averaging or building the result is swallowed, and 0.0 is reported
        instead.
        """
        try:
            if len(results) == 0:
                return GlobalMetricResult(score=0.0)
            total = sum(result.score for result in results)
            return GlobalMetricResult(score=total / len(results))
        except Exception:
            # Malformed input degrades to a zero score instead of raising.
            return GlobalMetricResult(score=0.0)

## Select Trainer & Run

In [None]:
from ape.core.trainer import (
    TextGradientTrainer,
    ExpelTrainer,
    FewShotTrainer,
    EvoPromptTrainer,
    DspyMiproTrainer,
    OptunaTrainer,
)

# define trainer 
trainer = FewShotTrainer(
    generator=ScienceSolver(),
    metric=GPQAMetric(),
    global_metric=GlobalGPQAMetric(),
    testmode=True # If True, trainer will run prompts for validation set and save results.
)

# run trainer
optimized_prompt, report = await trainer.train(
    prompt=student_prompt,
    trainset=trainset,
    valset=testset,  
)


## Print Optimized Prompt

In [None]:
# Show every message of the optimized prompt, one print() call each.
final_messages = optimized_prompt.messages
for msg in final_messages:
    print(msg)

## Plot Benchmark Results (Training vs. Validation Scores)

In [None]:
# Visualize experiment results.
def visualize_scores(report):
    """Plot training and validation scores per iteration from ``report.scores``.

    Each entry of ``report.scores`` is indexed with "score" (training set) and
    "val_score" (validation set). Every point is annotated with its value to
    two decimals, and the figure is shown with ``plt.show()``.
    """
    history = report.scores
    train_values = [entry["score"] for entry in history]
    val_values = [entry["val_score"] for entry in history]
    steps = range(1, len(train_values) + 1)

    plt.figure(figsize=(10, 6))
    plt.plot(steps, train_values, label='Training Set', marker='o')
    plt.plot(steps, val_values, label='Validation Set', marker='s')

    # Label each point with its score, placed just above the marker.
    for step, train_value, val_value in zip(steps, train_values, val_values):
        plt.text(step, train_value, f'{train_value:.2f}',
                 ha='center', va='bottom', fontsize=8, color='blue')
        plt.text(step, val_value, f'{val_value:.2f}',
                 ha='center', va='bottom', fontsize=8, color='green')

    plt.title('Training and Validation Scores over Iterations')
    plt.xlabel('Iteration')
    plt.ylabel('Score')
    plt.legend()
    plt.show()

# Plot the per-iteration scores from the trainer's report (produced by the
# trainer cell above).
visualize_scores(report)