# Generating Structured Output with Loop-Based Auto Correction

Build a system that extracts structured data from unstructured text, formats it according to a JSON schema, and automatically corrects errors in the JSON output of a large language model (LLM) so that it follows the specified structure.

#### Loops in Pipelines
Components in a pipeline can run in iterative loops, which you can cap at a desired number of iterations. This is handy for self-correcting loops, where a generator produces some output and a validator component then checks whether that output is correct.
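The idea can be sketched in plain Python, without any pipeline framework. Everything here is hypothetical and for illustration only: `fake_generator` stands in for the LLM (it returns truncated JSON until it is given an error message), `validate` only checks that the reply parses as JSON, and `run_with_retries` is the capped loop.

```python
import json
from typing import Optional, Tuple


def fake_generator(error: Optional[str]) -> str:
    """Hypothetical stand-in for an LLM: broken JSON first, fixed JSON once told about an error."""
    return '{"cities": [' if error is None else '{"cities": []}'


def validate(reply: str) -> Optional[str]:
    """Return None if the reply parses as JSON, otherwise the parser's error text."""
    try:
        json.loads(reply)
        return None
    except ValueError as e:
        return str(e)


def run_with_retries(max_loops: int) -> Tuple[Optional[str], int]:
    """Generate, validate, and feed the error back until the reply is valid or the cap is hit."""
    error = None
    for attempt in range(1, max_loops + 1):
        reply = fake_generator(error)
        error = validate(reply)
        if error is None:
            return reply, attempt
    return None, max_loops


reply, attempts = run_with_retries(max_loops=5)
print(reply, attempts)
```

With a cap of 1, the same function gives up after the first invalid reply and returns `None`, which is the behavior the loop limit is meant to guarantee.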

This notebook uses OpenAI's gpt-3.5-turbo model.

In [59]:
import logging

logging.basicConfig()
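# The pipeline output below shows the class living in haystack.core.pipeline.pipeline,
# so that is the logger to configure here
logging.getLogger("haystack.core.pipeline.pipeline").setLevel(logging.DEBUG)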


##### Define Schema to Parse the JSON object

A Pydantic model defines the schema of the data we want to extract from the passage.

In [60]:
from typing import List
from pydantic import BaseModel

class City(BaseModel):
    name: str
    country: str
    population: int
    
class CitiesData(BaseModel):
    cities: List[City]

In [61]:
json_schema = CitiesData.model_json_schema()
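To see what the LLM will be shown, it can help to inspect the generated schema. The following is a small sketch assuming Pydantic v2 (which `model_json_schema()` implies); the variable names `top_level_keys` and `city_fields` are introduced here for illustration only.

```python
from typing import List

from pydantic import BaseModel


class City(BaseModel):
    name: str
    country: str
    population: int


class CitiesData(BaseModel):
    cities: List[City]


json_schema = CitiesData.model_json_schema()

# The nested City model is described under "$defs" and referenced from "cities"
top_level_keys = sorted(json_schema)
city_fields = sorted(json_schema["$defs"]["City"]["properties"])

print(top_level_keys)
print(city_fields)
print(json_schema["required"])
```

Because the schema contains definitions as well as properties, the prompt later asks the model to return only the actual instances and not to echo the schema definition.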

#### Create a Custom Component: OutputValidator
OutputValidator is a custom component that validates whether the JSON object the LLM generates complies with the provided Pydantic model. If it doesn't, OutputValidator returns an error message along with the incorrect JSON object so that it can be fixed in the next loop.

In [62]:
import json
from typing import List, Optional, Type

from colorama import Fore
from haystack import component
from pydantic import BaseModel, ValidationError


@component
class OutputValidator:
    # pydantic_model is the model class (not an instance) to validate replies against
    def __init__(self, pydantic_model: Type[BaseModel]):
        self.pydantic_model = pydantic_model
        self.iteration_counter = 0

    # Define the component output
    @component.output_types(valid_replies=List[str], invalid_replies=Optional[List[str]], error_message=Optional[str])
    def run(self, replies: List[str]):

        self.iteration_counter += 1

        ## Try to parse the LLM's reply ##
        # If the LLM's reply is a valid object, return `"valid_replies"`
        try:
            output_dict = json.loads(replies[0])
            # model_validate is the Pydantic v2 replacement for the deprecated parse_obj
            self.pydantic_model.model_validate(output_dict)
            print(
                Fore.GREEN
                + f"OutputValidator at Iteration {self.iteration_counter}: Valid JSON from LLM - No need for looping: {replies[0]}"
            )
            return {"valid_replies": replies}

        # If the LLM's reply is corrupted or not valid, return "invalid_replies" and the "error_message" for LLM to try again
        except (ValueError, ValidationError) as e:
            print(
                Fore.RED
                + f"OutputValidator at Iteration {self.iteration_counter}: Invalid JSON from LLM - Let's try again.\n"
                f"Output from LLM:\n {replies[0]} \n"
                f"Error from OutputValidator: {e}"
            )
            return {"invalid_replies": replies, "error_message": str(e)}


In [63]:
output_validator = OutputValidator(pydantic_model=CitiesData)
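The validation logic can also be tried in isolation, without Haystack. The sketch below assumes Pydantic v2 and uses a hypothetical helper, `check_reply`, that mirrors the two outcomes of the component's `run` method; the sample replies are made up for illustration.

```python
import json
from typing import List

from pydantic import BaseModel, ValidationError


class City(BaseModel):
    name: str
    country: str
    population: int


class CitiesData(BaseModel):
    cities: List[City]


def check_reply(reply: str) -> dict:
    """Hypothetical helper: return the same keys OutputValidator.run would return."""
    try:
        CitiesData.model_validate(json.loads(reply))
    except (ValueError, ValidationError) as e:
        # json.loads raises a ValueError subclass on malformed JSON;
        # Pydantic raises ValidationError on schema mismatches
        return {"invalid_replies": [reply], "error_message": str(e)}
    return {"valid_replies": [reply]}


ok = check_reply('{"cities": [{"name": "Berlin", "country": "Germany", "population": 3850809}]}')
truncated = check_reply('{"cities": [')
wrong_type = check_reply(
    '{"cities": [{"name": "Paris", "country": "France", "population": "2.161 million"}]}'
)

print(sorted(ok))
print(truncated["error_message"])
print(wrong_type["error_message"])
```

The error text for the third reply names the offending field, which is the information the prompt passes back to the LLM on the next iteration.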

### Creating the Prompt

Write an instruction that tells the LLM to convert a passage into JSON. The prompt should also explain how to identify and correct errors when the JSON doesn't match the required schema.

In [64]:
from haystack.components.builders import PromptBuilder

prompt_template = """
Create a JSON object from the information present in this passage: {{passage}}.
Only use information that is present in the passage. Follow this JSON schema, but only return the actual instances without any additional schema definition:
{{schema}}
Make sure your response is a dict and not a list.
{% if invalid_replies and error_message %}
    You already created the following output in a previous attempt: {{invalid_replies}}
    However, this doesn't comply with the format requirements above and triggered this Python exception: {{error_message}}
    Correct the output and try again. Just return the corrected output without any extra explanations.
{% endif %}
"""

prompt_builder = PromptBuilder(template=prompt_template)
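PromptBuilder templates are Jinja2 templates, so the effect of the `{% if %}` block can be previewed by rendering a template directly. This is a sketch only: it assumes `jinja2` is installed, uses a shortened stand-in for the template above, and the passage, reply, and error text are made-up placeholders.

```python
from jinja2 import Template

# Shortened stand-in for the tutorial's template, for illustration only
template = Template(
    "Create a JSON object from this passage: {{passage}}.\n"
    "{% if invalid_replies and error_message %}"
    "Previous attempt: {{invalid_replies}}\n"
    "Exception: {{error_message}}\n"
    "{% endif %}"
)

# First pass: no feedback variables are supplied, so the conditional block is skipped
first_prompt = template.render(passage="Berlin is the capital of Germany")

# Retry pass: both feedback variables are supplied, so the block is rendered
retry_prompt = template.render(
    passage="Berlin is the capital of Germany",
    invalid_replies=['{"cities": ['],
    error_message="placeholder error text",
)

print(first_prompt)
print(retry_prompt)
```

Undefined variables are falsy in Jinja2's default configuration, which is why the same template works for both the first attempt and the retries.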

### Initializing the Generator

In [65]:
import os

from dotenv import load_dotenv
from haystack.components.generators import OpenAIGenerator

# Load environment variables (such as OPENAI_API_KEY) from a local .env file, if one exists
load_dotenv()

if not os.getenv("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY is required.")

generator = OpenAIGenerator()

### Building the Pipeline

Add all components to the pipeline and connect them.

Connect output_validator back to prompt_builder so that replies that don't comply with the JSON schema are sent back for correction. Set max_loops_allowed to cap the number of iterations and avoid infinite looping.

In [66]:
from haystack import Pipeline

pipeline = Pipeline(max_loops_allowed=5)

# Add components to your pipeline
pipeline.add_component(instance=prompt_builder, name="prompt_builder")
pipeline.add_component(instance=generator, name="llm")
pipeline.add_component(instance=output_validator, name="output_validator")

# Now, connect the components to each other
pipeline.connect("prompt_builder", "llm")
pipeline.connect("llm", "output_validator")
# If a component has more than one output or input, explicitly specify the connections:
pipeline.connect("output_validator.invalid_replies", "prompt_builder.invalid_replies")
pipeline.connect("output_validator.error_message", "prompt_builder.error_message")

<haystack.core.pipeline.pipeline.Pipeline object at 0x7709f7a3c6d0>
🚅 Components
  - prompt_builder: PromptBuilder
  - llm: OpenAIGenerator
  - output_validator: OutputValidator
🛤️ Connections
  - prompt_builder.prompt -> llm.prompt (str)
  - llm.replies -> output_validator.replies (List[str])
  - output_validator.invalid_replies -> prompt_builder.invalid_replies (Optional[List[str]])
  - output_validator.error_message -> prompt_builder.error_message (Optional[str])

### Visualise the Pipeline

In [67]:
pipeline.draw("auto-correct-pipeline.png")

### Testing the Pipeline

In [68]:
type(json_schema)

dict

In [70]:
passage = "Berlin is the capital of Germany. It has a population of 3,850,809. Paris, France's capital, has 2.161 million residents. Lisbon is the capital and the largest city of Portugal with the population of 504,718."
result = pipeline.run({"prompt_builder": {"passage": passage, "schema": json_schema}})

OutputValidator at Iteration 2: Valid JSON from LLM - No need for looping: {
    "cities": [
        {
            "name": "Berlin",
            "country": "Germany",
            "population": 3850809
        },
        {
            "name": "Paris",
            "country": "France",
            "population": 2161000
        },
        {
            "name": "Lisbon",
            "country": "Portugal",
            "population": 504718
        }
    ]
}


In [71]:
valid_reply = result["output_validator"]["valid_replies"][0]
valid_json = json.loads(valid_reply)
print(valid_json)


{'cities': [{'name': 'Berlin', 'country': 'Germany', 'population': 3850809}, {'name': 'Paris', 'country': 'France', 'population': 2161000}, {'name': 'Lisbon', 'country': 'Portugal', 'population': 504718}]}
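Once the reply has been parsed, it is an ordinary Python dict and can be processed like any other data. As a minimal sketch, the snippet below uses a compacted copy of the reply from the run above; `total_population` and `largest` are names introduced here for illustration.

```python
import json

# Same data as the valid reply shown above, written on fewer lines
valid_reply = (
    '{"cities": ['
    '{"name": "Berlin", "country": "Germany", "population": 3850809}, '
    '{"name": "Paris", "country": "France", "population": 2161000}, '
    '{"name": "Lisbon", "country": "Portugal", "population": 504718}]}'
)

cities = json.loads(valid_reply)["cities"]

# Populations are already integers because the schema required them to be
total_population = sum(city["population"] for city in cities)
largest = max(cities, key=lambda city: city["population"])

print(total_population)  # 6516527
print(largest["name"])   # Berlin
```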
