# Callbacks and Streaming

What are “Callbacks & Streaming”?

- Streaming lets you receive model output token-by-token (or chunk-by-chunk) for real-time UIs and faster feedback.
- Callbacks are hooks that fire on events (chain start/end, LLM start/new token/end, tool start/end, retriever, etc.). They enable logging, metrics, tracing, and custom side-effects without tangling your core logic.

We'll cover:

- Basic token streaming with `streaming=True` and `.stream()`
- Built-in `StdOutCallbackHandler`
- Writing a custom `BaseCallbackHandler` (latency + token count)
- Attaching callbacks globally and per-run with `.with_config(callbacks=[...])`
- Callback events in parallel branches, a simulated tool, and chain errors

## Bootstrap

⚓--- Before proceeding further, it is very important that you do the following: --- 👾

Select the 🗝 (key) icon in the left pane and add your OpenAI API key as a secret, with Name "OPENAI_API_KEY" and the key as its value. Then grant it notebook access; the notebook cannot run without it.

Run the two cells below, in order, before running any other cell. Wait until a number replaces the '*' or '[ ]' next to each cell. Below the second cell you should see "✅ Ready: Chat model initialized".

In [None]:
!pip install -q langchain langchain-openai langchain-community

In [None]:
# Environment & imports
import time
from google.colab import userdata

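# userdata.get raises an error if the secret is missing or notebook access was not granted
try:
    key = userdata.get('OPENAI_API_KEY')
except Exception as err:
    raise RuntimeError("Add OPENAI_API_KEY as a Colab secret (🗝 icon) and grant this notebook access.") from err
if not key:
    raise RuntimeError("The OPENAI_API_KEY secret is empty.")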

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.callbacks import BaseCallbackHandler, CallbackManager, StdOutCallbackHandler
from langchain_core.runnables import RunnableLambda, RunnableParallel
from langchain_core.tools import tool

print("✅ Ready: Chat model initialized")

## Streaming

There are two ways to stream tokens as they are generated:

- `.stream()`: synchronous. Works well in plain Python scripts where you don't want to deal with async/await.
- `.astream()`: asynchronous. Useful when you are inside an async application (FastAPI, an async web server, etc.).

## Asynchronous streaming

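`.astream()` returns an async iterator, so consume it with `async for` inside an `async def` function. Jupyter and Colab allow a top-level `await`, which is why the cell below can call `await streamchain()` directly; in a plain Python script, run the coroutine with `asyncio.run(streamchain())`.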

In [None]:
# Create a small, fast chat model with streaming enabled
llm_stream = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.2,
    streaming=True,  # .astream()/.stream() stream regardless; this also makes .invoke() stream internally
    callbacks=[StdOutCallbackHandler()],  # optional: logs chain/agent events, NOT individual tokens
    api_key=key
)

prompt = ChatPromptTemplate.from_messages([
    ("system", "Answer concisely."),
    ("user", "List 5 short ideas to test prompt quality.")
])

chain = prompt | llm_stream | StrOutputParser()

async def streamchain():
    async for chunk in chain.astream({}):
        print(chunk, end="", flush=True)


# Tokens "type out" because the loop prints each chunk as soon as it arrives
print("\n\n--- Streaming output (async) ---")
await streamchain()

## Synchronous streaming

Here is the same chain, streamed synchronously with `.stream()`:

In [None]:
llm_stream = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.2,
    streaming=True,   # .stream() streams regardless; this also makes .invoke() stream internally
    callbacks=[StdOutCallbackHandler()],  # optional: logs chain/agent events, NOT individual tokens
    api_key=key
)

prompt = ChatPromptTemplate.from_messages([
    ("system", "Answer concisely."),
    ("user", "List 5 short ideas to test prompt quality.")
])

# Chain: prompt → llm → output parser
chain = prompt | llm_stream | StrOutputParser()

print("\n\n--- Streaming output (sync) ---")
final = []
for chunk in chain.stream({}):
    # Each chunk is parsed text from the LLM as it arrives
    print(chunk, end="", flush=True)
    final.append(chunk)

# Combine chunks into the final result
result = "".join(final)

## Callbacks

Let’s write a custom handler that records:
- wall-clock latency per generation
- count of streamed tokens
- the final text length

The handler below is attached to the model itself. This means the logging happens wherever this model instance is used (on its own or inside any chain), not just for one particular call such as `invoke` or `stream`.

In [None]:
class MetricsHandler(BaseCallbackHandler):
    # Constructor: called once when the handler object is created
    def __init__(self, label="run"):
        self.label = label      # label just for our reference in the logs
        self.t0 = None          # start time of the current generation (for latency)
        self.token_count = 0    # number of streamed chunks (roughly one token each)

    # Called when an LLM call starts. Chat models fire on_chat_model_start; when a
    # handler does not implement it, LangChain falls back to on_llm_start.
    def on_llm_start(self, serialized, prompts, **kwargs):
        self.t0 = time.perf_counter()
        self.token_count = 0
        print(f"\n[{self.label}] LLM start. Prompts={len(prompts)}")

    # Called on each new token (only when the model actually streams)
    def on_llm_new_token(self, token, **kwargs):
        self.token_count += 1

    # Called when the LLM call completes
    def on_llm_end(self, response, **kwargs):
        dt_ms = (time.perf_counter() - self.t0) * 1000 if self.t0 is not None else float("nan")
        # response is an LLMResult; for chat models each generation is a ChatGeneration
        # whose .text holds the AIMessage content
        try:
            text = response.generations[0][0].text
        except (IndexError, AttributeError):
            text = ""
        print(f"[{self.label}] LLM end. ~{self.token_count} tokens streamed. {dt_ms:.0f} ms. Final length={len(text)}")

metrics = MetricsHandler(label="demo")

# streaming=True makes .invoke() stream under the hood, so on_llm_new_token fires
llm_obs = ChatOpenAI(model="gpt-4o-mini", temperature=0.2, streaming=True, callbacks=[metrics], api_key=key)

chain_obs = (ChatPromptTemplate.from_messages([
    ("system", "Be concise."),
    ("user", "Explain {topic} in <= 60 words.")
]) | llm_obs | StrOutputParser())

print(chain_obs.invoke({"topic": "why callbacks are useful"}))

### How it works

- Subclass `BaseCallbackHandler` and implement the event methods you care about.
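- Attach it to a model with `callbacks=[...]` (older code wraps handlers in `CallbackManager([...])` and passes `callback_manager=`, which is deprecated).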
- The handler records runtime metrics without changing your chain logic.

For the full list of callback events such as `on_llm_start`, check out [LangChain's docs on callback events](https://python.langchain.com/docs/concepts/callbacks/).

## Per-runnable or run-scoped callbacks

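To attach a handler to one runnable instead of the model, use `.with_config(callbacks=[...])`. The handler then fires every time that runnable (here, the whole chain) is invoked, and it is passed down to every step inside it. For a truly one-off call, pass the handler to a single invocation instead: `chain.invoke(inputs, config={"callbacks": [...]})`. Run the previous cell before running this one, since it defines `MetricsHandler`.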

In [None]:
temp_handler = MetricsHandler(label="one-off")
llm_one_off = ChatOpenAI(model="gpt-4o-mini", temperature=0.2, streaming=True, api_key=key)

one_off = (ChatPromptTemplate.from_messages([
    ("user", "Give 4 crisp reasons to add callbacks only for this invocation.")
]) | llm_one_off | StrOutputParser()
).with_config(callbacks=[temp_handler])

print(one_off.invoke({}))

## Run-scoped callbacks in RunnableParallel

When you attach run-scoped callbacks to each branch of a `RunnableParallel`, they are scoped to that branch rather than to the entire chain.

In [None]:
def style_prompt(style):
    return ChatPromptTemplate.from_messages([
        ("system", "Rewrite in {style} style."),
        ("user", "{text}")
    ]).partial(style=style)

def styled_chain(style, label):
    return (style_prompt(style) | ChatOpenAI(model="gpt-4o-mini", temperature=0.2, streaming=True, api_key=key)
            | StrOutputParser()).with_config(callbacks=[MetricsHandler(label=label)])

parallel = RunnableParallel(
    formal = styled_chain("formal", "formal"),
    casual = styled_chain("casual", "casual"),
)

out = parallel.invoke({"text": "callbacks help you observe and debug chains"})
print("\n--- FORMAL ---\n", out["formal"])
print("\n--- CASUAL ---\n", out["casual"])


Here, each branch (formal/casual) becomes its own run context, and the callback handler runs independently there.

### How it works

- Each branch (formal/casual) has its own handler instance/label.
- You’ll see separate start/end logs and token counts.

## Callback Handler for tools

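Callback handlers and callback events also work for tools. The handler below implements `on_tool_start` and `on_tool_end`; because it is attached to the whole chain with `.with_config`, the tool step inherits it.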

In [None]:
# --- Custom handler for tool events ---
class ToolishHandler(BaseCallbackHandler):
    def on_tool_start(self, serialized, input_str, **kwargs):
        print(f"[tool] start: {serialized.get('name', 'unknown')} input={input_str!r}")

    def on_tool_end(self, output, **kwargs):
        print(f"[tool] end: output={str(output)[:40]!r}")


# --- Define a fake tool using @tool ---
@tool("knowledge_lookup")  # return_direct only matters inside agents, so it is omitted here
def fake_tool(q: str) -> str:
    """Lookup knowledge about a query."""
    time.sleep(0.3)  # simulate latency
    return f"FAKE_LOOKUP({q}) -> MMR balances relevance with diversity."


# --- Prompt template ---
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", "Answer using ONLY this context:\n{context}"),
    ("user", "{question}")
])


# --- Chain definition ---
qa_chain = (
    {"context": fake_tool, "question": lambda x: x["q"]}
    | qa_prompt
    | ChatOpenAI(model="gpt-4o-mini", temperature=0.2, streaming=True, api_key=key)
    | StrOutputParser()
).with_config(callbacks=[ToolishHandler()])


# --- Run ---
print("\n--- QA ---\n")
print(qa_chain.invoke({"q": "What does MMR do?"}))

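## Log Errors

`on_chain_error` fires whenever a runnable in the chain raises. The handler only observes the failure: the exception still propagates to the caller, so you handle it with an ordinary `try/except`.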

In [None]:
class ErrorReporter(BaseCallbackHandler):
    def on_chain_error(self, error, **kwargs):
        print(f"[ERROR] Chain failed: {error}")

def guard(inputs):
    q = inputs.get("question", "")
    if len(q) > 300:
        raise ValueError("Question too long for this demo (max 300 chars).")
    return inputs

guarded_prompt = ChatPromptTemplate.from_messages([
    ("user", "Answer briefly: {question}")
])

guarded = (
    RunnableLambda(guard)
    | guarded_prompt
    | ChatOpenAI(model="gpt-4o-mini", temperature=0.2, streaming=True, api_key=key)
    | StrOutputParser()
)

print(guarded.with_config(callbacks=[ErrorReporter()]).invoke({"question": "What is a callback in LangChain?"}))

# Trigger an error
try:
    long_q = "x" * 400
    guarded.with_config(callbacks=[ErrorReporter()]).invoke({"question": long_q})
except Exception as e:
    print("Raised as expected:", e)