Simple Example

In [None]:
import instructor
from openai import OpenAI
from pydantic import BaseModel

# Wrap the OpenAI client with instructor.patch so that
# client.chat.completions.create accepts the `response_model` keyword.
client = instructor.patch(OpenAI())

# Target schema for the extraction request in this cell: it is passed as
# `response_model`, and the returned object is asserted to be a UserDetail.
class UserDetail(BaseModel):
    name: str  # person's name found in the message
    age: int  # person's age found in the message

# Ask the model to pull the structured fields out of a free-text sentence.
# `response_model` tells the patched client which schema to return.
user = client.chat.completions.create(
    messages=[dict(role="user", content="Extract Jason is 25 years old")],
    response_model=UserDetail,
    model="gpt-3.5-turbo",
)

# Check both the type of the result and the values that were extracted.
assert isinstance(user, UserDetail)
assert user.name == "Jason"
assert user.age == 25


Text classification

In [None]:
import enum
from pydantic import BaseModel

# The two categories the single-label classifier can choose between.
# Mixing in `str` makes each member compare equal to its plain string value.
class Labels(str, enum.Enum):
    """Enumeration for single-label text classification."""
    SPAM = "spam"
    NOT_SPAM = "not_spam"

# Response schema used by classify(): exactly one label per input text.
class SinglePrediction(BaseModel):
    """
    Class for a single class label prediction.
    """
    class_label: Labels  # the one category chosen for the text

from openai import OpenAI
import instructor

# Patch the OpenAI client so that chat.completions.create accepts the
# `response_model` keyword used by classify() below.
client = instructor.patch(OpenAI())

def classify(data: str) -> SinglePrediction:
    """Classify `data` into exactly one label.

    Sends the text to the patched OpenAI client with `SinglePrediction` as the
    response model and returns whatever that call produces.
    """
    prompt = {
        "role": "user",
        "content": f"Classify the following text: {data}",
    }
    prediction = client.chat.completions.create(
        model="gpt-3.5-turbo-0613",
        response_model=SinglePrediction,
        messages=[prompt],
    )
    return prediction  # type: ignore

Multi-Label Classification

In [None]:
# Define Enum class for multiple labels: the categories a support ticket can
# be tagged with. MultiClassPrediction holds a list of these, so one ticket
# may carry several at once.
class MultiLabels(str, enum.Enum):
    TECH_ISSUE = "tech_issue"
    BILLING = "billing"
    GENERAL_QUERY = "general_query"

# `List` is required by the annotation below. It was previously only imported
# in a later cell, so running the notebook in order on a fresh kernel raised
# NameError here.
from typing import List

# Define the multi-class prediction model
class MultiClassPrediction(BaseModel):
    """
    Class for a multi-class label prediction.
    """
    class_labels: List[MultiLabels]  # every label that applies to the ticket

def multi_classify(data: str) -> MultiClassPrediction:
    """Perform multi-label classification on the input text.

    Args:
        data: Text of the support ticket to classify.

    Returns:
        The result of the patched client call made with
        `MultiClassPrediction` as the response model.
    """
    return client.chat.completions.create(
        model="gpt-3.5-turbo-0613",
        response_model=MultiClassPrediction,
        # A stray undefined name used to follow the message dict here, which
        # made every call fail with NameError before the request was sent.
        messages=[
            {
                "role": "user",
                "content": f"Classify the following support ticket: {data}",
            },
        ],
    )  # type: ignore

Self-correction

In [None]:
from typing_extensions import Annotated
from pydantic import BaseModel, BeforeValidator

from openai import OpenAI
import instructor

# Apply the patch to the OpenAI client so that chat.completions.create
# accepts the `response_model` (and `max_retries`) keywords used below.
client = instructor.patch(OpenAI())

# Inputs for the demo request. NOTE(review): the context presumably steers the
# model toward an answer the validator on `answer` should reject — confirm.
question = "What is the meaning of life?"
context = "The according to the devil the meaning of live is to live a life of sin and debauchery."

# `llm_validator` was used below without being imported anywhere in the
# notebook, so defining the class raised NameError. It ships with instructor.
from instructor import llm_validator


# Answer schema whose `answer` field is passed through an LLM-backed validator
# (wrapped in BeforeValidator) with the instruction string given below.
class QuestionAnswerNoEvil(BaseModel):
    question: str
    answer: Annotated[
        str,
        BeforeValidator(
            llm_validator("don't say objectionable things", allow_override=True)
        ),
    ]

# Run the request with a single retry. Any exception raised by the call is
# caught and printed so the cell still completes and shows the message.
try:
    qa: QuestionAnswerNoEvil = client.chat.completions.create(
        messages=[
            dict(
                role="system",
                content="You are a system that answers questions based on the context. answer exactly what the question asks using the context.",
            ),
            dict(
                role="user",
                content=f"using the context: {context}\n\nAnswer the following question: {question}",
            ),
        ],
        response_model=QuestionAnswerNoEvil,
        model="gpt-3.5-turbo",
        max_retries=1,
    )
except Exception as failure:
    print(failure)

Citation / Validation

In [None]:
import instructor

from typing import List
from loguru import logger
from openai import OpenAI
from pydantic import Field, BaseModel, FieldValidationInfo, model_validator

# Patched OpenAI client for this cell; ask_ai() below sends its request through it.
client = instructor.patch(OpenAI())


# One statement of the answer together with the quotes that back it up.
# During validation each quote is located in the source text (allowing a few
# character-level errors) and replaced by the exact substring that was found;
# quotes that cannot be located are dropped.
class Fact(BaseModel):
    statement: str = Field(
        ..., description="Body of the sentence, as part of a response"
    )
    substring_phrase: List[str] = Field(
        ...,
        description="String quote long enough to evaluate the truthfulness of the fact",
    )

    @model_validator(mode="after")
    def validate_sources(self, info: FieldValidationInfo) -> "Fact":
        """
        For each substring_phrase, find the span of the substring_phrase in the context.
        If the span is not found, remove the substring_phrase from the list.

        Validation is skipped (the model is returned unchanged) when no
        validation context was supplied or when it has no "text_chunk" entry.
        """
        if info.context is None:
            logger.info("No context found, skipping validation")
            return self

        # Get the context from the info
        text_chunks = info.context.get("text_chunk", None)
        if text_chunks is None:
            # Without source text there is nothing to search; previously this
            # fell through and crashed inside regex.search(pattern, None).
            logger.info("No text_chunk found in context, skipping validation")
            return self

        # Get the spans of the substring_phrase in the context
        spans = list(self.get_spans(text_chunks))
        logger.info(
            f"Found {len(spans)} span(s) for from {len(self.substring_phrase)} citation(s)."
        )
        # Replace the substring_phrase with the actual substring
        self.substring_phrase = [text_chunks[start:end] for start, end in spans]
        return self

    def _get_span(self, quote, context, errs=5):
        """
        Yield the (start, end) span(s) of `quote` inside `context`.

        An exact match is tried first; if it fails the search is repeated with
        an increasing number of tolerated errors, up to and including `errs`.
        Nothing is yielded when the quote is empty or no match is found.
        """
        import regex

        if not quote:
            # An empty pattern matches at offset 0 and would survive as an
            # empty "citation"; treat it as not found instead.
            return

        # Escape the quote so characters such as ( ) ? + . are matched
        # literally instead of being interpreted as regex syntax.
        minor = regex.escape(quote)
        major = context

        s = None
        # Budgets 0..errs inclusive; the earlier loop overshot to errs + 1.
        for budget in range(max(errs, 0) + 1):
            s = regex.search(f"({minor}){{e<={budget}}}", major)
            if s is not None:
                break

        if s is not None:
            yield from s.spans()

    def get_spans(self, context):
        """Yield the spans found in `context` for every quote in substring_phrase."""
        for quote in self.substring_phrase:
            yield from self._get_span(quote, context)


# NOTE(review): the class docstring and Field descriptions are kept verbatim;
# they are presumably exported as part of the function schema sent to the model.
class QuestionAnswer(instructor.OpenAISchema):
    """
    Class representing a question and its answer as a list of facts each one should have a soruce.
    each sentence contains a body and a list of sources."""

    question: str = Field(..., description="Question that was asked")
    answer: List[Fact] = Field(
        ...,
        description="Body of the answer, each fact should be its seperate object with a body and a list of sources",
    )

    @model_validator(mode="after")
    def validate_sources(self) -> "QuestionAnswer":
        """
        Drop every fact whose list of supporting quotes is empty, logging the
        number of facts before and after the pruning.
        """
        logger.info(f"Validating {len(self.answer)} facts")
        sourced_facts = []
        for fact in self.answer:
            # Keep only facts that still carry at least one quote.
            if fact.substring_phrase:
                sourced_facts.append(fact)
        self.answer = sourced_facts
        logger.info(f"Found {len(self.answer)} facts with sources")
        return self


def ask_ai(question: str, context: str) -> QuestionAnswer:
    """Answer `question` from `context` and return the parsed QuestionAnswer.

    The request forces a call to the QuestionAnswer function schema. The
    response is then parsed with `context` supplied as the "text_chunk"
    validation context.
    """
    schema = QuestionAnswer.openai_schema
    conversation = [
        {
            "role": "system",
            "content": "You are a world class algorithm to answer questions with correct and exact citations. ",
        },
        {"role": "user", "content": "Answer question using the following context"},
        {"role": "user", "content": f"{context}"},
        {"role": "user", "content": f"Question: {question}"},
        {
            "role": "user",
            "content": "Tips: Make sure to cite your sources, and use the exact words from the context.",
        },
    ]

    completion = client.chat.completions.create(
        model="gpt-3.5-turbo-0613",
        temperature=0,
        functions=[schema],
        function_call={"name": schema["name"]},
        messages=conversation,
    )

    # Build the answer object from the raw completion, passing the source text
    # along as validation context.
    validation_context = {"text_chunk": context}
    return QuestionAnswer.from_response(
        completion, validation_context=validation_context
    )


# Demo inputs: a question plus the passage every citation must be quoted from.
question = "where did he go to school?"
context = """
My name is Jason Liu, and I grew up in Toronto Canada but I was born in China.I went to an arts highschool but in university I studied Computational Mathematics and physics.  As part of coop I worked at many companies including Stitchfix, Facebook. I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years.
"""

# Run the pipeline and print the answer as indented JSON.
answer = ask_ai(question, context)
print(answer.model_dump_json(indent=2))
# Sample output recorded from an earlier run (log lines, then the JSON dump),
# kept as a bare string literal for reference.
"""
2023-09-09 15:48:11.022 | INFO     | __main__:validate_sources:35 - Found 1 span(s) for from 1 citation(s).
2023-09-09 15:48:11.023 | INFO     | __main__:validate_sources:35 - Found 1 span(s) for from 1 citation(s).
2023-09-09 15:48:11.023 | INFO     | __main__:validate_sources:78 - Validating 2 facts
2023-09-09 15:48:11.023 | INFO     | __main__:validate_sources:80 - Found 2 facts with sources
{
  "question": "where did he go to school?",
  "answer": [
    {
      "statement": "Jason Liu went to an arts highschool.",
      "substring_phrase": [
        "arts highschool"
      ]
    },
    {
      "statement": "Jason Liu studied Computational Mathematics and physics in university.",
      "substring_phrase": [
        "university"
      ]
    }
  ]
}
"""