<a href="https://colab.research.google.com/github/mshumer/gpt-oss-pro-mode/blob/main/OpenAI_Open_Source_Pro_Mode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Made by Matt Shumer ([@mattshumer_](https://x.com/mattshumer_) on X).

In [1]:
# @title Run this cell to set up Pro Mode
!pip -q install groq

from typing import List, Dict, Any
import time, os
import concurrent.futures as cf
from groq import Groq, GroqError

# Model identifier sent as `model=` on every chat completion request below.
MODEL = "openai/gpt-oss-120b"
# Value sent as `max_completion_tokens=` on every request (candidate and synthesis calls alike).
MAX_COMPLETION_TOKENS = 30000

def _one_completion(client: Groq, prompt: str, temperature: float) -> str:
    """Run one non-streaming chat completion for `prompt` and return its text.

    The request is attempted up to three times. A GroqError on the first or
    second attempt is followed by a pause (0.5s, then 1.0s) and a retry; a
    GroqError on the third attempt propagates to the caller.
    """
    request_kwargs = dict(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
        top_p=1,
        stream=False,
    )

    # One pause per retryable attempt; the final attempt below has no safety net.
    backoff_schedule = (0.5, 1.0)
    for pause in backoff_schedule:
        try:
            response = client.chat.completions.create(**request_kwargs)
            return response.choices[0].message.content
        except GroqError:
            time.sleep(pause)

    # Last attempt: any error is allowed to surface unchanged.
    response = client.chat.completions.create(**request_kwargs)
    return response.choices[0].message.content

def _build_synthesis_messages(candidates: List[str]) -> List[Dict[str, str]]:
    numbered = "\n\n".join(
        f"<cand {i+1}>\n{txt}\n</cand {i+1}>" for i, txt in enumerate(candidates)
    )
    system = (
        "You are an expert editor. Synthesize ONE best answer from the candidate "
        "answers provided, merging strengths, correcting errors, and removing repetition. "
        "Do not mention the candidates or the synthesis process. Be decisive and clear."
    )
    user = (
        f"You are given {len(candidates)} candidate answers delimited by <cand i> tags.\n\n"
        f"{numbered}\n\nReturn the single best final answer."
    )
    return [{"role": "system", "content": system},
            {"role": "user", "content": user}]

def pro_mode(prompt: str, n_runs: int, groq_api_key: str | None = None) -> Dict[str, Any]:
    """
    Fan out n_runs parallel generations at T=0.9 and synthesize a final answer at T=0.2.

    If groq_api_key is provided, it will be used; otherwise GROQ_API_KEY env var is used.

    The synthesis request is retried with the same policy as the candidate
    requests (up to 3 attempts, 0.5s then 1.0s backoff on GroqError), so a
    transient failure at the last step does not discard the candidates that
    were already generated.

    Args:
        prompt: The user prompt sent verbatim for every candidate generation.
        n_runs: Number of candidate generations to run; must be >= 1.
        groq_api_key: Optional API key passed to the Groq client.

    Returns: {"final": str, "candidates": List[str]}
        "candidates" is ordered by submission index, not by completion time.

    Raises:
        AssertionError: if n_runs < 1.
        GroqError: if a candidate request or the synthesis request still
            fails after its retries.
    """
    assert n_runs >= 1, "n_runs must be >= 1"
    client = Groq(api_key=groq_api_key) if groq_api_key else Groq()

    # Parallel candidate generations (threaded; Colab-friendly)
    max_workers = min(n_runs, 16)
    # Pre-sized so each result can be stored at its submission index.
    candidates: List[str | None] = [None] * n_runs
    with cf.ThreadPoolExecutor(max_workers=max_workers) as ex:
        fut_to_idx = {
            ex.submit(_one_completion, client, prompt, 0.9): i
            for i in range(n_runs)
        }
        for fut in cf.as_completed(fut_to_idx):
            # fut.result() re-raises any exception from the worker thread.
            candidates[fut_to_idx[fut]] = fut.result()

    # Synthesis pass, with the same retry/backoff policy as the candidate calls.
    messages = _build_synthesis_messages(candidates)
    delay = 0.5
    for attempt in range(3):
        try:
            final_resp = client.chat.completions.create(
                model=MODEL,
                messages=messages,
                temperature=0.2,
                max_completion_tokens=MAX_COMPLETION_TOKENS,
                top_p=1,
                stream=False,
            )
            break
        except GroqError:
            if attempt == 2:
                raise
            time.sleep(delay)
            delay *= 2
    final = final_resp.choices[0].message.content

    return {"final": final, "candidates": candidates}



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import os

PROMPT = "Explain self-play in reinforcement learning with a concrete example."
NUMBER_OF_CANDIDATES = 5 # start with five, go up if you need more intelligence!

# python-dotenv is optional here: the setup cell only installs `groq`, so when
# dotenv is not installed we rely on whatever is already in the environment
# instead of failing with ModuleNotFoundError.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    load_dotenv()  # Load environment variables from a .env file

# May be None; pro_mode then constructs the Groq client without an explicit key.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

result = pro_mode(PROMPT, NUMBER_OF_CANDIDATES, groq_api_key=GROQ_API_KEY)

print("\n=== FINAL ===\n", result["final"])
# To inspect candidates:
# for i, c in enumerate(result["candidates"], 1): print(f"\n--- Candidate {i} ---\n{c}")