In [1]:
class MainAgent:
    """Orchestrate retrieval, context building, drafting and answer synthesis.

    The public node methods (``retrieve_documents_simple``, ``make_context``,
    ``create_draft``, ``synthesize_response_default``) each read an
    ``AgentState`` mapping and return a partial state update.

    Args:
        user_idx: User identifier forwarded to every ``ChromaHandler``.
        prp_idx: Proposal identifier forwarded to every ``ChromaHandler``.
        model_type: ``"public"`` selects ``ChatOpenAI``; ``"private"`` selects
            ``PrivateLLMModel`` and also creates a private Chroma handler.
            Any other value makes ``_initialize_model`` raise ``ValueError``.
        params: Drafting parameters. ``create_draft`` reads ``"toc_list"`` and
            ``"previous_content"`` (strings passed to ``eval``) and
            ``_generate_content`` reads ``"user_msg"``. ``None`` (the default)
            means a fresh empty dict for this instance.
    """

    def __init__(self, user_idx, prp_idx, model_type="public", params=None):
        self.user_idx = user_idx
        self.prp_idx = prp_idx
        self.model_type = model_type
        # Fresh dict per instance; a `{}` default would be one dict shared by
        # every MainAgent created without explicit params.
        self.params = {} if params is None else params
        self.chroma_handler_public = ChromaHandler(
            user_idx=user_idx, prp_idx=prp_idx, file_type="public"
        )
        # Only private agents get a private collection; otherwise None.
        self.chroma_handler_private = (
            ChromaHandler(user_idx=user_idx, prp_idx=prp_idx, file_type="private")
            if model_type == "private"
            else None
        )
        self.query_handler = QueryHandler()
        self.prompt_handler = PromptHandler()
        self.model = self._initialize_model()

    def _initialize_model(self):
        """Return the chat model matching ``self.model_type``.

        Raises:
            ValueError: if ``model_type`` is neither "public" nor "private".
        """
        if self.model_type == "public":
            return ChatOpenAI(model="gpt-3.5-turbo", temperature=0, streaming=True)
        elif self.model_type == "private":
            return PrivateLLMModel()  # private LLM model class, to be developed later
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

    def retrieve_documents_simple(self, state: AgentState) -> AgentState:
        """Search the public Chroma store with the latest message's text.

        Returns:
            ``{"query": <text of the last message>, "documents": [<top-2 hits>]}``.
        """
        query = state["messages"][-1].content
        # NOTE(review): `_make_context` passes the plain string "query" as the
        # second argument of `search_db`, while this call passes a whole dict.
        # One of the two is likely wrong -- confirm against ChromaHandler.
        relevant_documents_db = self.chroma_handler_public.search_db(
            query, {"vectordb": "query"}, k=2
        )
        return {"query": query, "documents": list(relevant_documents_db)}

    def make_context(self, state: AgentState) -> AgentState:
        """Build ``context_info`` for ``state["query"]`` (see ``_make_context``)."""
        context_info = self._make_context(
            self.chroma_handler_public,
            self.chroma_handler_private,
            state["query"],
            {"vectordb": "query", "summary": "summary", "search": "search"},
        )
        return {"query": state["query"], "context_info": context_info}

    def create_draft(self, state: AgentState) -> AgentState:
        """Generate one draft section per table-of-contents entry and concatenate them.

        For each entry in ``params["toc_list"]``:
        1. Classify it via an LLM.
        2. Generate content with a category-dependent prompt and temperature.
        3. Optionally evaluate and refine it.
        4. Summarize it so later sections can see a short history.

        If ``context_info`` has an empty ``context`` or ``source``, no section
        is generated and the draft is the empty string.

        Returns:
            ``{"messages": <concatenated section texts>}``.
        """
        context_info = state["context_info"]
        # SECURITY NOTE(review): eval() executes arbitrary code. If these params
        # can come from users, switch to ast.literal_eval (literals only).
        toc_list = eval(self.params["toc_list"])
        previous_contents = eval(self.params["previous_content"])

        # Both are loop-invariant; check once instead of on every section.
        context = context_info["context"]
        source = context_info["source"]
        if not (context and source):
            # Nothing to ground the draft on, so skip the per-section LLM calls.
            return {"messages": ""}

        summarized_history = []
        sections = []
        last_idx = len(toc_list) - 1
        for idx, toc_idx in enumerate(toc_list):
            category = self._get_category(toc_idx)
            prompt_type, temperature = self._get_prompt_details(category)
            content = self._generate_content(
                toc_idx,
                context,
                summarized_history,
                previous_contents,
                prompt_type,
                temperature,
            )
            content = self._evaluate_and_refine_content(category, toc_idx, content)
            sections.append(content)

            # The last section's summary would never be consumed, so skip it.
            if idx != last_idx:
                summarized_history.append(self._summarize_history(content))

        return {"messages": "".join(sections)}

    def synthesize_response_default(
        self, state: AgentState, config: RunnableConfig
    ) -> AgentState:
        """Stream a short Korean answer to ``state["query"]`` from retrieved documents.

        Documents are read from ``state["documents"]``, which is where
        ``retrieve_documents_simple`` stores them. ``_make_context`` never puts
        a ``"documents"`` key into ``context_info``, so reading only
        ``state["context_info"]["documents"]`` always raised ``KeyError``.
        That location is still consulted as a fallback; if neither exists the
        prompt gets an empty context.

        Returns:
            ``{"messages": [...]}`` with every chunk yielded by the streaming
            chain, drained into a list.
        """
        prompt_template = """
        Task: Generate a comprehensive answer based on search results.
        Topic: Provided search results
        Language: Korean
        Length: 100 words or less
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", prompt_template),
                ("human", "{query}"),
                ("system", "{context}"),
            ]
        )
        documents = state.get("documents")
        if documents is None:
            documents = state.get("context_info", {}).get("documents", [])

        response_synthesizer = prompt | self.model
        synthesized_response = response_synthesizer.stream(
            {
                "query": state["query"],
                "context": self._format_docs(documents),
            },
            config,
        )
        return {"messages": list(synthesized_response)}

    def save_content(self):
        """Placeholder: not implemented yet."""
        pass

    def update_proposal(self):
        """Placeholder: not implemented yet."""
        pass

    def export_docx(self):
        """Placeholder: not implemented yet."""
        pass

    def export_pdf(self):
        """Placeholder: not implemented yet."""
        pass

    def track_status(self):
        """Placeholder: not implemented yet."""
        pass

    def _make_context(
        self,
        chroma_handler_public: ChromaHandler,
        chroma_handler_private: ChromaHandler,
        query: str,
        prompt: dict,
    ) -> dict:
        """Collect vector-DB and web-search context for ``query``.

        Args:
            chroma_handler_public: Handler for the public collection; also
                used for summarization and web search.
            chroma_handler_private: Handler for the private collection, or
                ``None`` when the agent has no private store.
            query: Search text.
            prompt: Must contain the keys ``"vectordb"``, ``"summary"`` and
                ``"search"``; each value is passed to the matching handler call.

        Returns:
            ``{"context": {...}, "source": {...}}``:
            - ``source`` always has ``"vectorDB"`` and ``"search"``.
            - ``context`` always has ``"search"``, and has ``"vectorDB"`` only
              when at least one Chroma document was found.
        """
        source = dict(vectorDB=[], search=[])
        context_info = dict(context={}, source=source)

        # Materialize both result sets: `search_db` results are wrapped in
        # list() elsewhere, so they may not be lists, and `+` needs lists.
        public_docs = list(
            chroma_handler_public.search_db(query, prompt["vectordb"], k=2)
        )
        private_docs = []
        if chroma_handler_private is not None:
            private_docs = list(
                chroma_handler_private.search_db(query, prompt["vectordb"], k=2)
            )

        all_docs = public_docs + private_docs

        if all_docs:
            summarized_chroma_docs = chroma_handler_public.summarize_context(
                all_docs, prompt["summary"]
            )
            context_info["context"]["vectorDB"] = summarized_chroma_docs["context"]
            context_info["source"]["vectorDB"] = summarized_chroma_docs["source"]

        search_docs = chroma_handler_public.search_web(query, prompt["search"], k=3)
        context_info["context"]["search"] = search_docs["context"]
        context_info["source"]["search"] = search_docs["source"]

        return context_info

    def _get_category(self, toc_idx: int) -> str:
        """Ask the LLM which category the table-of-contents entry ``toc_idx`` belongs to.

        The result is compared against "계획" downstream. It is returned
        exactly as ``ChatGPTHandler.generate`` produced it, with no stripping.
        """
        category_query = self.query_handler.make_category_query(toc_idx)
        category_prompt = self.prompt_handler.make_category_prompt(category_query)
        category_chatgpt = ChatGPTHandler(
            prompt=category_prompt, temperature=0, model_kwargs={"top_p": 0.1}
        )
        category = category_chatgpt.generate()
        return category

    def _get_prompt_details(self, category: str) -> tuple:
        """Map a category to ``(prompt_type, temperature)``.

        The planning category ("계획", i.e. "plan") gets ``("high", 0.6)``;
        everything else gets ``("low", 0)``.
        """
        if category == "계획":
            return "high", 0.6
        return "low", 0

    def _generate_content(
        self,
        toc_idx: int,
        context: dict,
        summarized_history: list,
        previous_contents: list,
        prompt_type: str,
        temperature: float,
    ) -> str:
        """Draft the section for ``toc_idx`` with gpt-4o.

        ``previous_contents`` is indexed by ``toc_idx``. ``summarized_history``
        is newline-joined into a single string for the prompt.
        """
        prompt = self.prompt_handler.make_draft_prompt(
            toc_idx,
            self.params["user_msg"],
            context,
            "\n".join(summarized_history),
            previous_contents[toc_idx],
            prompt_type,
        )
        draft_chatgpt = ChatGPTHandler(
            prompt=prompt, model_name="gpt-4o", temperature=temperature
        )
        content = draft_chatgpt.generate()
        return content

    def _evaluate_and_refine_content(
        self, category: str, toc_idx: int, content: str
    ) -> str:
        """Evaluate and rewrite ``content`` unless it is in the planning category.

        Non-"계획" sections are first scored by ``AgentHandler``, then
        regenerated by gpt-4o from the content plus that evaluation.
        Planning sections are returned unchanged.
        """
        if category != "계획":
            agent_handler = AgentHandler(
                toc_idx, category, content, self.chroma_handler_public
            )
            evaluation = agent_handler.start()
            final_prompt = self.prompt_handler.make_final_prompt(content, evaluation)
            final_chatgpt = ChatGPTHandler(
                prompt=final_prompt, model_name="gpt-4o", temperature=0
            )
            content = final_chatgpt.generate()
        return content

    def _summarize_history(self, content: str) -> str:
        """Summarize a finished section so it can be fed to later section prompts."""
        summarized_history_prompt = self.prompt_handler.make_summarized_history_prompt(
            content
        )
        summarized_chatgpt = ChatGPTHandler(
            prompt=summarized_history_prompt, temperature=0, model_kwargs={"top_p": 0.1}
        )
        summarized_content = summarized_chatgpt.generate()
        return summarized_content

    def _format_docs(self, docs: Sequence[Document]) -> str:
        """Wrap each document's ``page_content`` in ``<doc id='N'>`` tags.

        Ids are 1-based, one document per line.
        """
        return "\n".join(
            f"<doc id='{i}'>{doc.page_content}</doc>"
            for i, doc in enumerate(docs, start=1)
        )

NameError: name 'AgentState' is not defined

In [2]:
def create_graph(agent: MainAgent) -> StateGraph:
    """Wire ``agent``'s node methods into a linear StateGraph and compile it.

    Flow: retriever -> context_maker -> draft_creator -> response_synthesizer
    -> END. Each hop is routed by the matching ``route_to_*`` function.
    """
    graph = StateGraph(AgentState)

    node_actions = {
        "retriever": agent.retrieve_documents_simple,
        "context_maker": agent.make_context,
        "draft_creator": agent.create_draft,
        "response_synthesizer": agent.synthesize_response_default,
    }
    for node_name, action in node_actions.items():
        graph.add_node(node_name, action)

    graph.set_conditional_entry_point(route_to_retriever)

    routed_hops = (
        ("retriever", route_to_context_maker),
        ("context_maker", route_to_draft_creator),
        ("draft_creator", route_to_response_synthesizer),
    )
    for origin, router in routed_hops:
        graph.add_conditional_edges(origin, router)

    graph.add_edge("response_synthesizer", END)
    return graph.compile()


def route_to_retriever(state: AgentState) -> Literal["retriever"]:
    """Entry router: every run starts at the "retriever" node (``state`` is unused)."""
    next_node = "retriever"
    return next_node


def route_to_context_maker(state: AgentState) -> Literal["context_maker"]:
    """Unconditionally route to the "context_maker" node (``state`` is unused)."""
    next_node = "context_maker"
    return next_node


def route_to_draft_creator(state: AgentState) -> Literal["draft_creator"]:
    """Unconditionally route to the "draft_creator" node (``state`` is unused)."""
    next_node = "draft_creator"
    return next_node


def route_to_response_synthesizer(
    state: AgentState, config: RunnableConfig
) -> Literal["response_synthesizer"]:
    """Unconditionally route to the "response_synthesizer" node.

    Neither ``state`` nor ``config`` affects the decision.
    """
    next_node = "response_synthesizer"
    return next_node

NameError: name 'MainAgent' is not defined