In [1]:
import json
import os
import pickle
import re

import numpy as np
from dotenv import load_dotenv
from openai import OpenAI

# Run python-dotenv's loader before reading the settings below.
load_dotenv()

# Each setting is None when the corresponding environment variable is not set.
API_KEY = os.environ.get("OPENAI_API_KEY")
MODEL_NAME = os.environ.get("MODEL_NAME")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL")

# Shared client used for both the embedding and chat requests in this notebook.
client = OpenAI(api_key=API_KEY)

In [2]:
# Load every .txt file in TEXT_DIR and strip markup so that only the body text
# remains. Result: `texts`, a list of {"filename", "text"} dicts.
TEXT_DIR = "kenji_novels"
texts = []

# os.listdir returns entries in arbitrary order; sorting keeps the order of
# `texts` (and every chunk/embedding index derived from it) stable across runs.
for filename in sorted(os.listdir(TEXT_DIR)):
    if not filename.endswith(".txt"):
        continue

    filepath = os.path.join(TEXT_DIR, filename)

    with open(filepath, encoding="utf-8") as f:
        text = f.read()

    # --- Preprocessing ---
    # NOTE(review): the ｜ / 《》 markers look like Aozora Bunko ruby notation
    # (｜base《reading》) - confirm against the source files. The reading is
    # removed together with its brackets; deleting only the brackets would leave
    # the reading glued onto the base text and duplicate the word.
    text = re.sub(r"《[^》]*》", "", text)
    # Drop any unpaired leftovers of the markers as well.
    text = text.replace("｜", "")
    text = text.replace("《", "").replace("》", "")
    text = text.replace("\u3000", "")  # full-width space
    text = text.replace("\n", "")

    texts.append({
        "filename": filename,
        "text": text
    })


In [None]:
# Sanity check: number of loaded files, then the first file's name and the
# first 200 characters of its cleaned text.
(
    len(texts),
    texts[0]["filename"],
    texts[0]["text"][:200],
)


In [3]:
# Split every document into chunks, cutting only at sentence ends ("。").
# Sentences are accumulated while the running length stays below
# MAX_CHUNK_CHARS; a single sentence longer than that becomes its own chunk.
# Result: `chunks`, a list of {"text", "source"} dicts.
MAX_CHUNK_CHARS = 500

chunks = []

for item in texts:
    filename = item["filename"]
    current = ""

    for sentence in item["text"].split("。"):
        # split() yields an empty piece after a trailing "。" (and between
        # consecutive "。"); skipping it avoids appending stray "。" characters.
        if not sentence:
            continue

        if len(current) + len(sentence) < MAX_CHUNK_CHARS:
            current += sentence + "。"
            continue

        # Flush only when something has been accumulated; otherwise an
        # over-long first sentence would emit an empty chunk.
        if current:
            chunks.append({
                "text": current,
                "source": filename
            })
        current = sentence + "。"

    # Emit whatever is left over for this document.
    if current:
        chunks.append({
            "text": current,
            "source": filename
        })


In [None]:
# Sanity check: chunk count, the first chunk's source file, and the first 300
# characters of its text. Everything is printed explicitly because a notebook
# cell only echoes its last expression; bare expressions above it are discarded.
print(len(chunks), chunks[0]["source"])
print(chunks[0]["text"][:300])


In [4]:
# Vectorize the chunks: embeddings[i] is the embedding of chunks[i]["text"].
# Chunks are sent in batches instead of one request per chunk, which cuts the
# number of API round trips by a factor of EMBED_BATCH_SIZE.
# NOTE(review): the batch size must stay within the endpoint's per-request
# input-count and token limits - confirm for the configured EMBEDDING_MODEL.
EMBED_BATCH_SIZE = 100

embeddings = []

for start in range(0, len(chunks), EMBED_BATCH_SIZE):
    batch_texts = [item["text"] for item in chunks[start:start + EMBED_BATCH_SIZE]]
    res = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=batch_texts
    )
    # Each returned item carries the index of its input within the request;
    # ordering by it keeps embeddings aligned with chunks.
    embeddings.extend(
        d.embedding for d in sorted(res.data, key=lambda d: d.index)
    )

assert len(embeddings) == len(chunks), "embedding count does not match chunk count"

In [5]:
# Save to pickle: store the embeddings and their chunks together in one file
# under the keys "embeddings" and "chunks".
with open("kenji_embeddings.pkl", "wb") as pkl_file:
    pickle.dump(dict(embeddings=embeddings, chunks=chunks), pkl_file)

In [None]:
# Alternative entry point: restore embeddings and chunks from the pickle file.
# NOTE: pickle.load can execute arbitrary code - only load files that this
# notebook itself wrote.
with open("kenji_embeddings.pkl", "rb") as f:
    data = pickle.load(f)

chunks = data["chunks"]
# Convert to an ndarray before inspecting it: the value read from the file may
# be a plain list, which has no .shape. The array form is also what the cosine
# similarity search needs.
embeddings = np.array(data["embeddings"])

# Last expression of the cell, so the notebook displays it.
embeddings.shape, len(chunks)


In [7]:
# Save as JSON: one record per chunk holding its text, source and embedding.
# Each embedding is converted to a plain list of floats because json cannot
# serialize NumPy arrays; without this the cell fails whenever `embeddings`
# has already been turned into an ndarray by one of the loading cells.
data = [
    {
        "text": chunk["text"],
        "source": chunk["source"],
        "embedding": np.asarray(embedding).tolist()
    }
    for chunk, embedding in zip(chunks, embeddings)
]

# ensure_ascii=False keeps the Japanese text human-readable in the file.
with open("kenji_embeddings.json", "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)


In [None]:
# Alternative entry point: restore chunks and embeddings from the JSON file.
with open("kenji_embeddings.json", encoding="utf-8") as json_file:
    data = json.load(json_file)

# Split each record back into its chunk metadata and its embedding vector.
chunks = [{key: record[key] for key in ("text", "source")} for record in data]
embeddings = np.array([record["embedding"] for record in data])


In [None]:
# Normalization: divide every embedding row by its L2 norm so each row has
# unit length.
row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norm_embeddings = embeddings / row_norms

# Conversation turns collected so far (starts empty).
history = []

# System prompt (Japanese) that sets the Kenji Miyazawa persona for the chat model.
system_prompt = (
    "\n"
    "あなたは宮沢賢治の文体・語彙・世界観を強く反映して話す対話AIです。\n"
    "自然、宇宙、心象風景を大切にし、詩的だが会話として自然な返答をしてください。\n"
)

# Interactive retrieval-augmented chat loop. For each user message: embed it,
# pick the most similar chunks by cosine similarity, and send them to the chat
# model as reference text together with the recent conversation history.
# Typing exactly "さようなら。" ends the loop.
TOP_K = 3                  # number of most similar chunks used as reference text
MAX_HISTORY_MESSAGES = 10  # messages (user + assistant) kept for later turns

while True:
    user_input = input("あなた：")
    if user_input == "さようなら。":
        break

    # The embeddings endpoint does not accept an empty string; a blank line
    # would raise and end the session, so ask again instead.
    if not user_input.strip():
        continue

    # Embed the user's input
    res = client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=user_input
    )
    query_embedding = np.array(res.data[0].embedding)

    # Cosine-similarity search: both sides are unit length, so the dot product
    # is the cosine similarity.
    norm_query = query_embedding / np.linalg.norm(query_embedding)

    scores = np.dot(norm_embeddings, norm_query)
    # argsort is ascending; reverse it and keep the TOP_K best-scoring chunks.
    top_indices = np.argsort(scores)[::-1][:TOP_K]
    retrieved_texts = [chunks[i] for i in top_indices]

    context = "\n\n".join(c["text"] for c in retrieved_texts)

    # Generate the reply with the LLM
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "system", "content": f"参考文章:\n{context}"},
            *history,
            {"role": "user", "content": user_input}
        ]
    )

    reply = response.choices[0].message.content
    print("賢治bot：", reply)

    # Report token usage for this request
    usage = response.usage
    print(
        f"tokens | input: {usage.prompt_tokens}, "
        f"output: {usage.completion_tokens}, "
        f"total: {usage.total_tokens}"
    )

    # Remember this exchange, keeping only the most recent messages so the
    # prompt does not grow without bound.
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": reply})
    history = history[-MAX_HISTORY_MESSAGES:]

    # Author's note: roughly 0.7 yen per query

