#### All functions wrapped in `RunnableLambda` must take a SINGLE argument. If a function needs several inputs, wrap it in a function that accepts one input (e.g. a dict) and unpacks it before calling the original.

In [3]:
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel
from langchain_community.chat_models import ChatOllama


def length_function(text):
    """Return ``len(text)`` (the character count when ``text`` is a string)."""
    character_count = len(text)
    return character_count


def _multiple_length_function(text1, text2):
    return len(text1) * len(text2)


def multiple_length_function(_dict):
    """Single-argument adapter around ``_multiple_length_function``.

    ``RunnableLambda`` calls its function with one input, so the two texts
    arrive together in a mapping under the keys ``"text1"`` and ``"text2"``.
    They are pulled out here and forwarded positionally.
    """
    first_text = _dict["text1"]
    second_text = _dict["text2"]
    return _multiple_length_function(first_text, second_text)


# Independent setup: the chat model and the prompt it will answer.
model = ChatOllama(model="mistral")
prompt = ChatPromptTemplate.from_template("what is {a} + {b}")

# Fan the input mapping out to two branches that fill the prompt's {a} and {b}.
# RunnableLambda functions accept a single argument, so branch "b" first packs
# "foo" and "bar" into one dict for multiple_length_function to unpack.
retrieval = RunnableParallel(
    a=itemgetter("foo") | RunnableLambda(length_function),
    b=(
        {"text1": itemgetter("foo"), "text2": itemgetter("bar")}
        | RunnableLambda(multiple_length_function)
    ),
)

# Parallel branches -> prompt -> model.
chain = retrieval | prompt | model

In [4]:
chain.invoke(dict(foo="bar", bar="gah"))  # a = len("bar") = 3, b = 3 * len("gah") = 9

AIMessage(content=' The sum of the numbers 3 and 9 is 12. In mathematical terms, you can represent this as:\n\n3 (the first number) + 9 (the second number) = 12 (the answer)')

# Accepting a `RunnableConfig`

In [10]:
import json
from langchain_community.chat_models import ChatOllama
from langchain_core.runnables import RunnableConfig
from langchain_core.output_parsers import StrOutputParser


def parse_or_fix(text: str, config: RunnableConfig, *, max_fixes: int = 3):
    """Parse ``text`` as JSON, asking an LLM to repair it when parsing fails.

    Meant to be wrapped in ``RunnableLambda``. Because the function declares a
    ``config`` parameter, the runnable's ``RunnableConfig`` is passed in here.
    Forwarding it to the inner ``chain.invoke`` makes the caller's callbacks
    and tags apply to the nested LLM calls.

    Args:
        text: The string to parse as JSON.
        config: Runnable configuration forwarded to the repair chain.
        max_fixes: Maximum number of LLM repair attempts. The output of every
            repair is re-parsed, so at most ``max_fixes`` model calls are made.

    Returns:
        The decoded JSON value on success, or the string ``"Failed to parse"``
        if the text is still invalid after ``max_fixes`` repairs.
    """
    prompt = ChatPromptTemplate.from_messages([
        ("system", "Fix the following text: {input}"),
        ("system", "Error: {error}"),
        ("system", "Don't narrate, just respond with the fixed data."),
    ])
    model = ChatOllama(model="mistral")
    chain = prompt | model | StrOutputParser()

    # Parse once more than we repair: first the original text, then the output
    # of each repair. This way the final model response is also checked
    # instead of being requested and then thrown away.
    for attempt in range(max_fixes + 1):
        try:
            return json.loads(text)
        except json.JSONDecodeError as err:
            if attempt == max_fixes:
                break
            # Send the message as plain text so the prompt and any callback
            # handlers see a string rather than an exception object.
            text = chain.invoke({"input": text, "error": str(err)}, config)
    return "Failed to parse"

In [27]:
from langchain_core.callbacks.base import BaseCallbackHandler

class CustomCallbackHandler(BaseCallbackHandler):
    """Print a trace line for each LLM lifecycle event it is notified of.

    The inherited constructor is used unchanged. Every hook accepts arbitrary
    positional and keyword arguments and ignores them, except the token hook,
    which also prints its keyword payload.
    """

    def on_llm_start(self, *_args, **_kwargs):
        print("llm start")

    def on_llm_new_token(self, *_args, **event):
        # Two output lines: a marker, then the keyword payload (token text plus
        # run metadata such as run_id and tags, per the cell output below).
        print("llm new token", event, sep="\n")

    def on_llm_end(self, *_args, **_kwargs):
        print("llm end")


runnable_config: RunnableConfig = {
    "callbacks": [CustomCallbackHandler()],
    "tags": ["my-tag"],
}

# RunnableLambda hands this config to parse_or_fix, which forwards it to its
# inner chain, so the handler and tag also see the nested LLM call.
output = RunnableLambda(parse_or_fix).invoke("{foo: bar}", config=runnable_config)
print(output)

llm start
llm new token
{'token': ' {', 'run_id': UUID('6333041b-e82d-4764-9084-c2d004814998'), 'parent_run_id': UUID('81eba2c1-5951-46d5-88ed-9de7b0f627f1'), 'tags': ['seq:step:2', 'my-tag'], 'chunk': None, 'verbose': False}
llm new token
{'token': ' "', 'run_id': UUID('6333041b-e82d-4764-9084-c2d004814998'), 'parent_run_id': UUID('81eba2c1-5951-46d5-88ed-9de7b0f627f1'), 'tags': ['seq:step:2', 'my-tag'], 'chunk': None, 'verbose': False}
llm new token
{'token': 'foo', 'run_id': UUID('6333041b-e82d-4764-9084-c2d004814998'), 'parent_run_id': UUID('81eba2c1-5951-46d5-88ed-9de7b0f627f1'), 'tags': ['seq:step:2', 'my-tag'], 'chunk': None, 'verbose': False}
llm new token
{'token': '":', 'run_id': UUID('6333041b-e82d-4764-9084-c2d004814998'), 'parent_run_id': UUID('81eba2c1-5951-46d5-88ed-9de7b0f627f1'), 'tags': ['seq:step:2', 'my-tag'], 'chunk': None, 'verbose': False}
llm new token
{'token': ' "', 'run_id': UUID('6333041b-e82d-4764-9084-c2d004814998'), 'parent_run_id': UUID('81eba2c1-5951-46