In [1]:
import os

import gradio as gr
from dotenv import load_dotenv
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import (JsonOutputParser,
                                           PydanticOutputParser)
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from mod.O_prompt import (STORY_GENERATOR_SYSTEM_PROMPT,
                          STORY_GENERATOR_USER_PROMPT, SYSTEM_PROMPT)
from pydantic import BaseModel, Field

# Load variables from a local .env file into the process environment;
# the API key is read from os.environ further down in the notebook.
load_dotenv()

True

In [2]:
class TurtleSoupStory(BaseModel):
    """
    Story model for a "turtle soup" (situation / lateral-thinking) puzzle.

    The field descriptions are kept in Chinese on purpose: they are emitted
    verbatim into the parser's format instructions that are sent to the LLM.
    """
    # Story title. `examples` must be a list of sample values, not a bare
    # string, so that the generated JSON schema is well-formed.
    title: str = Field(description="故事標題", examples=["海龜湯"])
    # Difficulty label of the story.
    difficulty: str = Field(description="海龜湯故事的難易度")
    # Custom elements the user asked to have woven into the plot.
    custom: str = Field(description="故事情節的客製化要素")
    # The puzzle statement shown to the player (hides the key truth).
    question: str = Field(description="故事的謎面，隱藏了故事的關鍵真相，僅看到部分的結果")
    # The full solution, including the complete truth and reasoning chain.
    answer: str = Field(description="故事的謎底，包含所有的完整真相和邏輯環節")


In [None]:
class StoryGenerator:
    """Pairs a chat model with a message list seeded by a system prompt.

    NOTE(review): this class appears unfinished. ``generate_story`` only
    assembles its inputs into a dict and then falls off the end, so it
    returns ``None`` and never calls the model.
    """

    def __init__(self, llm, system_prompt):
        """Keep the model handle and start the message list.

        Args:
            llm: Model object; stored as-is and not called by this class yet.
            system_prompt: Text wrapped in the initial ``SystemMessage``.
        """
        self.llm = llm
        seed_message = SystemMessage(system_prompt)
        self.messages = [seed_message]

    def generate_story(self, difficulty, custom, pydantic_schema):
        """Collect the generation inputs into a single mapping.

        Args:
            difficulty: Requested difficulty value.
            custom: Free-form customisation request.
            pydantic_schema: Format-instruction text for the output schema.

        Returns:
            None. The assembled mapping is currently discarded
            (TODO: feed it to a generation chain and return the story).
        """
        data = dict(
            difficulty=difficulty,
            custom=custom,
            pydantic_schema=pydantic_schema,
        )


In [3]:
# Chat model shared by the story-generation chain and the chat bot below.
model_name = 'gemini-3-flash-preview'
# Read from the GOOGLE_API environment variable; this is None when it is unset.
api_key = os.getenv("GOOGLE_API")

llm = ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key)

In [4]:
# The Pydantic parser is used here only to render format instructions for
# the TurtleSoupStory schema; the JSON parser is what the chain applies to
# the model output further down.
pydantic_parser = PydanticOutputParser(pydantic_object=TurtleSoupStory)
pydantic_schema = pydantic_parser.get_format_instructions()
json_parser = JsonOutputParser()

# Interactive inputs: the prompts are shown to the user verbatim.
difficulty = input("請輸入困難度（簡單/普通/困難/隨機）")
custom = input("請輸入客製化故事需求")

# Variables handed to the prompt template when the chain is invoked.
data = dict(
    difficulty=difficulty,
    custom=custom,
    pydantic_schema=pydantic_schema,
)

In [5]:
# Two-message chat prompt for story generation (system + human turn).
# ChatPromptTemplate is imported in the notebook's top import cell so that
# all dependencies are visible in one place.
# NOTE(review): the templates presumably reference the keys of `data`
# (difficulty / custom / pydantic_schema) — confirm against mod.O_prompt.
story_gen_system = STORY_GENERATOR_SYSTEM_PROMPT
story_gen_user = STORY_GENERATOR_USER_PROMPT

prompt_template = ChatPromptTemplate.from_messages([
    ("system", story_gen_system),
    ("human", story_gen_user),
])

In [6]:
# Compose prompt -> model -> JSON parser into one runnable sequence.
story_gen_chain = prompt_template.pipe(llm, json_parser)

# Run the chain once with the user's inputs; the JSON parser's output is
# later unpacked with ** into SYSTEM_PROMPT.format.
story = story_gen_chain.invoke(data)

In [8]:
class Bot:
    """Chat wrapper that keeps the conversation history and streams replies.

    The history starts with one system message and grows by one human
    message and one AI message for every fully consumed ``chat_stream`` call.
    """

    def __init__(self, llm, system_prompt):
        """Store the model and seed the history with the system prompt.

        Args:
            llm: Chat model object; only ``llm.stream(messages)`` is used.
            system_prompt: Text wrapped in the leading ``SystemMessage``.
        """
        self.llm = llm
        self.messages = [SystemMessage(system_prompt)]

    def chat_stream(self, text):
        """Send one user turn and yield the reply's text pieces as they arrive.

        The user turn is appended to ``self.messages`` before streaming
        starts. Every content block received is collected, and once the
        stream is exhausted the collected blocks are appended to the history
        as a single ``AIMessage``.

        Args:
            text: The user's message.

        Yields:
            str: Non-empty text fragments of the reply, in arrival order.
            Blocks whose ``type`` is not ``"text"`` are recorded in the
            history but not yielded.
        """
        self.messages.append(HumanMessage(text))

        full_response = []

        for chunk in self.llm.stream(self.messages):
            content = chunk.content
            if not content:
                continue

            # Chunk content is either one string or a list whose items may be
            # plain strings or dict blocks; treat both shapes uniformly.
            blocks = content if isinstance(content, list) else [content]

            for block in blocks:
                if isinstance(block, str):
                    if not block:
                        continue
                    # Plain-string items have no .get(); wrap them as text
                    # blocks so the stored history has one consistent shape.
                    block = {"type": "text", "text": block}

                full_response.append(block)

                if block.get('type') == 'text':
                    # Guard against a text block carrying None so callers
                    # can always concatenate what is yielded.
                    piece = block.get('text') or ''
                    if piece:
                        yield piece

        self.messages.append(AIMessage(full_response))

In [9]:
# Fill the game-master prompt with the generated story's fields
# (the story mapping is unpacked as keyword arguments to str.format).
system_prompt = SYSTEM_PROMPT.format(**story)

# One Bot instance holds the conversation for the whole session.
bot = Bot(llm, system_prompt)

In [10]:
def chat_function_stream(message, history):
    """
    Handle one user turn and stream the reply back incrementally.

    Args:
        message (str): Text entered by the user.
        history (list): Conversation history supplied by the chat UI.
            It is not read here; the module-level ``bot`` keeps its own.

    Yields:
        str: The reply accumulated so far, re-emitted after every new
        fragment so the UI can refresh the displayed message in place.
    """
    pieces = []

    for fragment in bot.chat_stream(message):
        # Collect the fragment, then hand the whole text-so-far to the UI.
        pieces.append(fragment)
        yield "".join(pieces)

# Wrap the streaming handler in a Gradio chat UI and start the local server.
webui = gr.ChatInterface(fn=chat_function_stream)
webui.launch()

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


