## Using a Pydantic model as the graph's `state_schema` to enforce input type validation at runtime



A StateGraph accepts a `state_schema` argument on initialization that specifies the "shape" of the state that the nodes in the graph can access and update. Pydantic's `BaseModel` can be used as the `state_schema` to add run-time validation on inputs. *(Note: LangChain's examples traditionally use a plain `TypedDict`, which performs no run-time validation, or a list (for `MessageGraph`).)*

Reference: https://langchain-ai.github.io/langgraph/how-tos/state-model/

### Limitations / Warnings

- The `output` of the graph will **NOT** be an instance of a pydantic model. 
- Run-time validation only occurs on **inputs** into nodes, not on the outputs. 
- The validation error trace from pydantic does not show which node the error arises in. 

### Input Validation 

In [13]:
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from pydantic import BaseModel, ValidationError


# The overall state of the graph (this is the public state shared across nodes).
# It is passed to StateGraph as the state schema.
class OverallState(BaseModel):
    a: str  # the only state key; declared as a string


# The graph's only node. It receives the current OverallState and returns a
# partial state update as a plain dict (not an OverallState instance).
def node(state: OverallState) -> dict:
    """Overwrite the state key ``a`` with ``"goodbye"``.

    Args:
        state: The current graph state (not read by this node).

    Returns:
        A dict holding the new value for the state key ``a``.
    """
    return {"a": "goodbye"}


# Build the state graph
builder = StateGraph(OverallState)
builder.add_node(node)  # registered without an explicit name; the edges below refer to it as "node"
builder.add_edge(START, "node")  # Start the graph at "node"
builder.add_edge("node", END)  # End the graph after "node"
graph = builder.compile()

# Test the graph with a valid input (the result is discarded, not printed)
graph.invoke({"a": "hello"})

# Test the graph with an invalid input: `a` is declared as str but a float is passed
try:
    print(graph.invoke({"a": 1.3910}))  # This should raise a validation error
except ValidationError as e:
    print(f"Validation error: {e}")

{'a': 'goodbye'}


### Multi-node Graphs with pydantic runtime validation

In [None]:
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict

from pydantic import BaseModel


# The overall state of the graph (this is the public state shared across nodes).
# Both bad_node and ok_node below take it as their input type.
class OverallState(BaseModel):
    a: str  # the only state key; declared as a string


def bad_node(state: OverallState):
    """Return an update that sets ``a`` to an int, violating ``a: str``."""
    invalid_update = {"a": 123}  # Invalid: deliberately not a string
    return invalid_update


def ok_node(state: OverallState):
    """Return a well-typed update that sets ``a`` to ``"goodbye"``."""
    valid_update = dict(a="goodbye")
    return valid_update


# Build the state graph
builder = StateGraph(OverallState)

# Register the node functions; each is added without an explicit name and is
# referred to by its function name in the edges below
builder.add_node(bad_node)
builder.add_node(ok_node)

# Define the graph structure: START -> bad_node -> ok_node -> END
builder.add_edge(START, "bad_node")
builder.add_edge("bad_node", "ok_node")
builder.add_edge("ok_node", END)
graph = builder.compile()

# Invoke the graph with a valid input. The input itself passes validation;
# the failure is caused by the integer that bad_node writes into `a`.
# NOTE(review): since validation is described as running on node inputs only,
# the error presumably surfaces when ok_node receives the state - confirm.
try:
    graph.invoke({"a": "hello"})
except Exception as e:
    print("An exception was raised because bad_node sets `a` to an integer.")
    print(e)

### Serialization Behavior 
When using Pydantic models as state schemas, it's important to understand how serialization works, especially when:

- Passing Pydantic objects as inputs
- Receiving non-Pydantic outputs from the graph that need to be converted back into a Pydantic model

In [None]:
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel

Input object type: <class '__main__.ComplexState'>
Input state type: <class '__main__.ComplexState'>
Nested type: <class '__main__.NestedModel'>
Output type: <class 'langgraph.pregel.io.AddableValuesDict'>
Output content: {'text': 'hello processed', 'count': 1, 'nested': {'value': 'test processed'}}
Converted back to Pydantic: <class '__main__.ComplexState'>


In [None]:
# Define the Pydantic state models.
# Inner model: embedded as the `nested` field of ComplexState.
class NestedModel(BaseModel):
    value: str  # string payload; process_node appends " processed" to it


# Graph state schema: two scalar fields plus a nested Pydantic model.
class ComplexState(BaseModel):
    text: str  # process_node appends " processed" to it
    count: int  # process_node increments it by 1
    nested: NestedModel  # nested Pydantic model, read via attribute access in the node

# Node function for the graph.
# The graph calls it with the current state whenever the node runs.
def process_node(state: ComplexState):
    """Print the types of the incoming state and return an update for every field.

    Args:
        state: The current graph state, received as a Pydantic object.

    Returns:
        A plain dict with the processed ``text``, the incremented ``count``
        and a dict replacement for ``nested``.
    """
    # Show what the node actually receives
    print(f"Input state type: {type(state)}")
    print(f"Nested type: {type(state.nested)}")

    # Assemble the state update piece by piece as a plain dictionary
    processed_text = state.text + " processed"
    next_count = state.count + 1
    processed_nested = {"value": state.nested.value + " processed"}
    return {"text": processed_text, "count": next_count, "nested": processed_nested}

In [None]:
# Build the graph

# Create a state graph with the ComplexState model as its schema
builder = StateGraph(ComplexState)
builder.add_node("process", process_node)

# Start the graph with the 'process' node
builder.add_edge(START, "process")

# End the graph after the 'process' node
# NOTE(review): the original note claims the edge to END is not strictly
# necessary but good practice - confirm against the LangGraph version in use
builder.add_edge("process", END)
graph = builder.compile()

# Create a Pydantic instance for input
input_state = ComplexState(text="hello"
                           , count=0
                           , nested=NestedModel(value="test")
                           )

# Show the type of the object that is about to be passed in
print(f"Input object type: {type(input_state)}")

# Invoke the graph with the Pydantic instance and inspect what comes back;
# the printed type shows whether the result is still a ComplexState
result = graph.invoke(input_state)
print(f"Output type: {type(result)}")
print(f"Output content: {result}")


# Convert the output back to a Pydantic model if needed.
# Unpacking the result into ComplexState re-validates it, which is useful
# when downstream code needs a valid ComplexState rather than a mapping.
output_model = ComplexState(**result)
print(f"Converted back to Pydantic: {type(output_model)}")

### Runtime Coercion 
*Warning: Pydantic performs runtime type coercion for certain data types*

In [None]:
# State schema used to demonstrate Pydantic's type coercion on string inputs.
class CoercionExample(BaseModel):
    # Pydantic will coerce numeric strings (e.g. "42") to integers
    number: int
    # Pydantic will parse boolean-like strings (e.g. "true") to bool
    flag: bool


def inspect_node(state: CoercionExample):
    """Print each field's value and runtime type; leave the state unchanged."""
    print(f"number: {state.number} (type: {type(state.number)})")
    print(f"flag: {state.flag} (type: {type(state.flag)})")
    # An empty update means no state key is modified
    no_update = dict()
    return no_update


# Build the graph: START -> inspect -> END
builder = StateGraph(CoercionExample)
builder.add_node("inspect", inspect_node)
builder.add_edge(START, "inspect")
builder.add_edge("inspect", END)
graph = builder.compile()

# Demonstrate coercion with string inputs: "42" is accepted for the int field
# and "true" for the bool field.
# NOTE(review): the recorded output of this cell shows both values still typed
# as str inside the node - confirm whether the coerced values actually reach
# the node in the LangGraph version in use.
result = graph.invoke({"number": "42", "flag": "true"})

# This fails with a validation error
# because "not-a-number" cannot be coerced to an integer
try:
    graph.invoke({"number": "not-a-number", "flag": "true"})
except Exception as e:
    print(f"\nExpected validation error: {e}")

number: 42 (type: <class 'str'>)
flag: true (type: <class 'str'>)

Expected validation error: 1 validation error for CoercionExample
number
  Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not-a-number', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/int_parsing


### Using Pydantic objects with Message Models

When working with LangChain message types in your state schema, you should use ***AnyMessage*** (rather than BaseMessage) for proper serialization/deserialization when sending message objects over the wire.

In [20]:
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage
from typing import List


# Chat state schema: a message history plus a free-form context string.
class ChatState(BaseModel):
    messages: List[AnyMessage] # List of messages, e.g. HumanMessage or AIMessage
    context: str # Context for the chat, can be used to store additional information

# Dummy node that simulates a chat reply
def add_message(state: ChatState):
    """Return the message history with a canned AI greeting appended."""
    reply = AIMessage(content="Hello there!")
    return {"messages": [*state.messages, reply]}


# Build the state graph with the ChatState model as its schema
builder = StateGraph(ChatState)

# Define nodes
builder.add_node("add_message", add_message)

# Define edges: START -> add_message -> END
builder.add_edge(START, "add_message")
builder.add_edge("add_message", END)
graph = builder.compile()

# Create the input state with a single human message
initial_state = ChatState(
    messages=[HumanMessage(content="Hi")], context="Customer support chat"
)

# Run the graph and print the raw result
result = graph.invoke(initial_state)
print(f"Output: {result}")

# Convert back to a Pydantic model, then print each message's concrete class
# name and content
output_model = ChatState(**result)
for i, msg in enumerate(output_model.messages):
    print(f"Message {i}: {type(msg).__name__} - {msg.content}")


Output: {'messages': [HumanMessage(content='Hi', additional_kwargs={}, response_metadata={}), AIMessage(content='Hello there!', additional_kwargs={}, response_metadata={})], 'context': 'Customer support chat'}
Message 0: HumanMessage - Hi
Message 1: AIMessage - Hello there!
