In [1]:
from langchain.schema.runnable import RunnableBranch

In [3]:
rb = RunnableBranch(
    (lambda age: age < 6, lambda age: f'幼儿, {age}'),
    (lambda age: age < 18, lambda age: f'儿童, {age}'),
    (lambda age: age < 40, lambda age: f'青年, {age}'),
    (lambda age: age < 80, lambda age: f'中年, {age}'),
    lambda age: f'老年, {age}'
)

In [7]:
rb.batch([41, 37, 11])

['中年, 41', '青年, 37', '儿童, 11']

In [10]:
from langchain.schema.runnable import RunnableWithFallbacks, RunnableLambda

In [22]:
rfb = RunnableWithFallbacks(
    runnable=RunnableLambda(lambda x: 100/x),
    fallbacks=[
        RunnableLambda(lambda x: 50/x),
        RunnableLambda(lambda x: x)
    ]
)

In [25]:
rfb.invoke(2)

50.0

In [29]:
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableBranch
from typing import Literal

from langchain.pydantic_v1 import BaseModel
from langchain.output_parsers.openai_functions import PydanticAttrOutputFunctionsParser
from langchain.utils.openai_functions import convert_pydantic_to_openai_function

from langchain.prompts import PromptTemplate


physics_template = """You are a very smart physics professor. \
You are great at answering questions about physics in a concise and easy to understand manner. \
When you don't know the answer to a question you admit that you don't know.

Here is a question:
{input}"""
physics_prompt = PromptTemplate.from_template(physics_template)

math_template = """You are a very good mathematician. You are great at answering math questions. \
You are so good because you are able to break down hard problems into their component parts, \
answer the component parts, and then put them together to answer the broader question.

Here is a question:
{input}"""
math_prompt = PromptTemplate.from_template(math_template)

general_prompt = PromptTemplate.from_template(
    "You are a helpful assistant. Answer the question as accurately as you can.\n\n{input}"
)
prompt_branch = RunnableBranch(
  (lambda x: x["topic"] == "math", math_prompt),
  (lambda x: x["topic"] == "physics", physics_prompt),
  general_prompt
)

class TopicClassifier(BaseModel):
    "Classify the topic of the user question"
    
    topic: Literal["math", "physics", "general"]
    "The topic of the user question. One of 'math', 'phsyics' or 'general'."


classifier_function = convert_pydantic_to_openai_function(TopicClassifier)
llm = ChatOpenAI().bind(functions=[classifier_function], function_call={"name": "TopicClassifier"}) 
parser = PydanticAttrOutputFunctionsParser(pydantic_schema=TopicClassifier, attr_name="topic")
classifier_chain = llm | parser

In [30]:
classifier_chain

RunnableBinding(bound=ChatOpenAI(client=<class 'openai.api_resources.chat_completion.ChatCompletion'>, openai_api_key='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', openai_api_base='', openai_organization='', openai_proxy=''), kwargs={'functions': [{'name': 'TopicClassifier', 'description': 'Classify the topic of the user question', 'parameters': {'title': 'TopicClassifier', 'description': 'Classify the topic of the user question', 'type': 'object', 'properties': {'topic': {'title': 'Topic', 'enum': ['math', 'physics', 'general'], 'type': 'string'}}, 'required': ['topic']}}], 'function_call': {'name': 'TopicClassifier'}})
| PydanticAttrOutputFunctionsParser(pydantic_schema=<class '__main__.TopicClassifier'>, attr_name='topic')

In [33]:
from prettyprinter import cpprint
cpprint(llm)

RunnableBinding(bound=ChatOpenAI(client=<class 'openai.api_resources.chat_completion.ChatCompletion'>, openai_api_key='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', openai_api_base='', openai_organization='', openai_proxy=''), kwargs={'functions': [{'name': 'TopicClassifier', 'description': 'Classify the topic of the user question', 'parameters': {'title': 'TopicClassifier', 'description': 'Classify the topic of the user question', 'type': 'object', 'properties': {'topic': {'title': 'Topic', 'enum': ['math', 'physics', 'general'], 'type': 'string'}}, 'required': ['topic']}}], 'function_call': {'name': 'TopicClassifier'}})


In [41]:
from operator import itemgetter

from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

# r = RunnablePassthrough.assign(topic=itemgetter("input") | classifier_chain)
# final_chain = (
#     RunnablePassthrough.assign(topic=itemgetter("input") | classifier_chain) 
#     | prompt_branch 
#     | ChatOpenAI()
#     | StrOutputParser()
# )

# r = itemgetter("input") | classifier_chain
# resp = r.invoke(
#     {"input": "What is the first prime number greater than 40 such that one plus the prime number is divisible by 3?"}
# )
# resp
classifier_chain.invoke(itemgetter("input"))

ValueError: Invalid input type <class 'operator.itemgetter'>. Must be a PromptValue, str, or list of BaseMessages.

In [43]:
llm({"input": "What is the first prime number greater than 40 such that one plus the prime number is divisible by 3?"}, verbose=True)

TypeError: 'RunnableBinding' object is not callable

In [45]:
ChatOpenAI()

ChatOpenAI(client=<class 'openai.api_resources.chat_completion.ChatCompletion'>, openai_api_key='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', openai_api_base='', openai_organization='', openai_proxy='')

In [49]:
ChatOpenAI().bind()

RunnableBinding(bound=ChatOpenAI(client=<class 'openai.api_resources.chat_completion.ChatCompletion'>, openai_api_key='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', openai_api_base='', openai_organization='', openai_proxy=''), kwargs={})

In [53]:
llm = ChatOpenAI().bind(functions=[classifier_function], function_call={"name": "TopicClassifier"})
r = llm.invoke("What is the first prime number greater than 40 such that one plus the prime number is divisible by 3?")
r

AIMessage(content='', additional_kwargs={'function_call': {'name': 'TopicClassifier', 'arguments': '{\n  "topic": "math"\n}'}})

In [55]:
parser.parse_obj(r.schema())

KeyError: 'pydantic_schema'