In [None]:
import functools
import warnings

from utils import OuputFilter, InputFilter, HallucinationFilter

class GuardRails:
    """Decorator that wraps a method with input, output and hallucination checks.

    Use an instance as a decorator: ``@GuardRails(tool_llm=llm, stream=False)``.
    The decorated callable must receive the user input as its second
    positional argument (``args[1]``, i.e. the first argument after ``self``).

    The wrapper replaces the decorated callable's return value:

    * non-stream mode returns the ``"response"`` entry of the mapping the
      callable returned;
    * stream mode consumes the whole iterable the callable returned and
      returns the concatenated text as a single ``str``;
    * either mode returns ``"Input Invalid!"`` / ``"Output Invalid!"`` when
      the corresponding validator rejects the text.

    The three validators are stored on the class and shared by every
    instance; they are built the first time an instance is created while any
    of them is still ``None``.
    """

    # Shared validator callables; populated by ``validator_configure``.
    _inputt_validator = None
    output_validator = None
    hallucination_validator = None

    @classmethod
    def validator_configure(cls, tool_llm):
        """Build all three shared validators from ``tool_llm``.

        Args:
            tool_llm: Object forwarded unchanged to the ``InputFilter``,
                ``OuputFilter`` and ``HallucinationFilter`` constructors.
        """
        cls._inputt_validator = InputFilter(tool_llm)
        cls.output_validator = OuputFilter(tool_llm)
        cls.hallucination_validator = HallucinationFilter(tool_llm)

    def __init__(self, tool_llm, stream=False):
        """Remember the mode and lazily configure the shared validators.

        Args:
            tool_llm: Used to build the validators if any is still unset.
            stream: Truthy to wrap generator-style callables with
                ``guardrails_stream``; falsy to use ``guardrails``.
        """
        self.stream = stream
        validators = (
            GuardRails._inputt_validator,
            GuardRails.output_validator,
            GuardRails.hallucination_validator,
        )
        if any(validator is None for validator in validators):
            GuardRails.validator_configure(tool_llm)

    @staticmethod
    def _warn_on_hallucination(hallucination_value):
        """Emit the warning that matches a hallucination validator code.

        Codes handled: ``-1`` (faithfulness), ``1`` or ``3`` (factuality) and
        ``4`` (factuality could not be validated). Any other value is silent.
        """
        if hallucination_value == -1:
            warnings.warn("Faithfulness Hallucination!")
        if hallucination_value in (1, 3):
            warnings.warn("Factuality Hallucination!")
        if hallucination_value == 4:
            warnings.warn("Can't Validate Factuality Hallucination.")

    @staticmethod
    def guardrails(func):
        """Wrap a callable that returns a mapping with a ``"response"`` key.

        Args:
            func: Callable whose second positional argument is the user input
                and whose result supports ``.get("response")`` and
                ``.get("vector_results")``.

        Returns:
            A wrapper returning the response text, or ``"Input Invalid!"`` /
            ``"Output Invalid!"`` when a validator rejects the text.
        """
        @functools.wraps(func)
        def _func(*args, **kwargs):
            _inputt = args[1]

            # A truthy result from the input validator means "reject".
            if GuardRails._inputt_validator(_inputt):
                return "Input Invalid!"

            output = func(*args, **kwargs)
            result = output.get("response")

            # The output validator signals rejection with 0.
            if GuardRails.output_validator(result) == 0:
                return "Output Invalid!"

            hallucination_value = GuardRails.hallucination_validator(
                result, output.get("vector_results"), _inputt
            )
            GuardRails._warn_on_hallucination(hallucination_value)

            return result
        return _func

    @staticmethod
    def guardrails_stream(func):
        """Wrap a callable that returns an iterable of chunks.

        Each chunk may be a ``str``, a ``dict`` carrying ``"response"`` (and
        optionally ``"vector_results"``), or any object exposing ``.content``.
        The iterable is consumed completely before validation, so the wrapper
        returns a ``str`` rather than a generator.

        Args:
            func: Callable whose second positional argument is the user input
                and whose result is iterable.

        Returns:
            A wrapper returning the concatenated text, or ``"Input Invalid!"``
            / ``"Output Invalid!"`` when a validator rejects the text.
        """
        @functools.wraps(func)
        def _func(*args, **kwargs):
            _inputt = args[1]

            # A truthy result from the input validator means "reject".
            if GuardRails._inputt_validator(_inputt):
                return "Input Invalid!"

            pieces = []
            # Latest non-None retrieval context seen on a dict chunk; stays
            # None for plain-text streams or an empty stream.
            vector_results = None
            for chunk in func(*args, **kwargs):
                if isinstance(chunk, str):
                    pieces.append(chunk)
                elif isinstance(chunk, dict):
                    pieces.append(chunk.get("response") or "")
                    chunk_context = chunk.get("vector_results")
                    if chunk_context is not None:
                        vector_results = chunk_context
                else:
                    pieces.append(chunk.content)
            output = "".join(pieces)

            # The output validator signals rejection with 0.
            if GuardRails.output_validator(output) == 0:
                return "Output Invalid!"

            hallucination_value = GuardRails.hallucination_validator(
                output, vector_results, _inputt
            )
            GuardRails._warn_on_hallucination(hallucination_value)

            return output
        return _func

    def __call__(self, func):
        """Decorate ``func`` with the wrapper matching ``self.stream``.

        Truthiness is used instead of identity checks so that a non-bool
        ``stream`` value can never leave ``func`` unguarded.
        """
        if self.stream:
            return self.guardrails_stream(func)
        return self.guardrails(func)


In [None]:
import typing as t

class State(t.TypedDict):
    """Typed dictionary returned by the RAG methods and read by GuardRails.

    Within this file only ``origin_query``, ``response`` and
    ``vector_results`` are populated; the GuardRails wrappers read
    ``response`` and ``vector_results`` via ``.get``.
    """

    # Query text; RAG passes its ``query`` argument here.
    origin_query: str
    # Presumably rewritten/expanded variants of the query; unused in this
    # file -- TODO confirm.
    similar_queries: t.Optional[t.List[str]]
    # Context that GuardRails forwards to the hallucination validator.
    vector_results: t.Optional[t.List[str]]
    # Answer text that GuardRails validates and returns to the caller.
    response: t.Optional[str]
    # Presumably earlier conversation turns; not populated in this file --
    # TODO confirm.
    history: t.List[str]
    # Presumably a running token count; not populated in this file --
    # TODO confirm.
    token_usage: int

In [None]:
from llms import tool_llm
class RAG:
    """Echo-style demo pipeline used to exercise the GuardRails decorator.

    Both methods simply reflect the incoming query back as the response and
    as the retrieval context; ``history`` is accepted but not used.
    """

    @GuardRails(tool_llm=tool_llm, stream=False)
    def query(self, query, history=None):
        """Return a State whose response and vector results are the query."""
        state = State(
            origin_query=query,
            response=query,
            vector_results=query,
        )
        return state

    @GuardRails(tool_llm=tool_llm, stream=True)
    def stream(self, query, history=None):
        """Yield one State per element of ``query``, echoing it as response."""
        for piece in query:
            state = State(
                origin_query=query,
                response=piece,
                vector_results=query,
            )
            yield state

In [None]:
# Build the demo pipeline and run one guarded, non-streaming query.
rag = RAG()
# The decorated ``query`` returns the validated response text (or an
# "... Invalid!" message), not the State dict the undecorated method builds.
result = rag.query("你好")
print(result)

-1
你好


In [None]:
# The stream wrapper consumes the generator and returns one string, so this
# loop walks that string character by character rather than live chunks.
for chunk in rag.stream("你好"):
    print(chunk, end="|")

-1
你|好|