# In [None]:
# Data settings: path of the UTF-8 text file used as the RAG source material.
data = "カブトムシ_raw.txt"

# Load the whole document into memory as a single string.
with open(data, mode="r", encoding="utf-8") as f:
    text = f.read()

# In [None]:
# Query settings: the question to ask the RAG pipeline.
# NOTE(review): `question` is reassigned in the later model-loading cell, so when
# all cells run in order this value is not the one that gets answered.
question = "洗えますか？"

# In [None]:
# Chunking settings: number of characters per chunk.
# NOTE(review): `chunk_size` and `prompt` are reassigned in the later
# model-loading cell, so these values are overridden when all cells run in order.
chunk_size = 100

# Prompt template; the {context} and {question} placeholders are filled with
# str.format by _safe_format/build_prompt.
prompt = """
以下の情報を参照して、質問に答えてください。
情報に含まれている内容のみで回答してください。

[情報]
{context}

[質問]
{question}

[回答]
"""

# In [None]:
# Module imports
import os
import re
import socket
import unicodedata
from pathlib import Path

import faiss
import numpy as np
import pykakasi
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig

# NOTE(review): the assignments below repeat the earlier settings cells and
# override them (with a different question and prompt). When all cells run in
# order, these are the values actually used; consider keeping only one copy.
data = "カブトムシ_raw.txt"  # RAG data file
with open(data, "r", encoding="utf-8") as f:
    text = f.read()

chunk_size = 100

question = "カブトムシのオスとメスの違いは？"  # example question

prompt = """
以下の情報をもとに質問に簡潔に回答してください。
与えられた情報の内容のみを使用し、他の情報は使用しないでください。

[情報]
{context}

[質問]
{question}

"""

# Model settings (local paths to the generator LLM and the embedding model).
BASE_MODEL_NAME = "./sbintuitions/sarashina2.2-3B-instruct-v0.1"
EMBEDDING_MODEL_NAME = "./bge-m3"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory where built FAISS indexes are cached between runs.
FAISS_INDEX_DIR = "faiss_index"
# exist_ok=True replaces the check-then-create pair, which could race and is
# the idiomatic way to ensure a directory exists.
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)

def load_models(llm_name, emb_name):
    """Load the generator LLM and the embedding encoder with their tokenizers.

    On CUDA the LLM is loaded 4-bit NF4-quantized (double quantization, fp16
    compute) through bitsandbytes. On CPU it is loaded unquantized in float32,
    since ``DEVICE`` may fall back to CPU but the bitsandbytes 4-bit path and
    fp16 inference are GPU-oriented (NOTE(review): confirm against the
    installed bitsandbytes version's backend support).

    Args:
        llm_name: Model id or local path of the causal language model.
        emb_name: Model id or local path of the embedding model.

    Returns:
        ``(llm_tok, llm_model, emb_tok, emb_model)``. Both models are in eval
        mode; the embedding model is moved to ``DEVICE``, while the LLM's
        placement is decided by ``device_map="auto"``.
    """
    use_cuda = DEVICE.type == "cuda"

    bnb_config = None
    if use_cuda:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

    llm_tok = AutoTokenizer.from_pretrained(llm_name)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    llm_model.eval()

    emb_tok = AutoTokenizer.from_pretrained(emb_name)
    # The embedding model is used with inputs explicitly moved to DEVICE
    # (embed_texts / vector_search), so it must live there too.
    emb_model = AutoModel.from_pretrained(emb_name).to(DEVICE)
    emb_model.eval()

    return llm_tok, llm_model, emb_tok, emb_model

def chunking(text, chunk_size, overlap=0):
    """Split ``text`` into consecutive fixed-size character windows.

    Args:
        text: Source document.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters each chunk shares with the previous one.
            The default 0 gives non-overlapping chunks, which is the original
            behavior. Must satisfy ``0 <= overlap < chunk_size`` so the window
            always advances.

    Returns:
        List of non-empty chunks covering ``text`` in order. The last chunk may
        be shorter than ``chunk_size``, and an empty ``text`` yields ``[]``.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range. A
            non-positive size previously produced either nothing or one bogus
            oversized slice (negative slice end).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
        )

    chunks = []
    start = 0
    length = len(text)
    while start < length:
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= length:
            # This chunk already reaches the end; with overlap > 0 another
            # step would only emit a tail fully contained in this chunk.
            break
        start = end - overlap
    return chunks

def convert_to_romaji(filename):
    """Return an ASCII-safe romanized stem of ``filename`` for use as a file name.

    The stem (extension dropped) is converted with pykakasi to Hepburn romaji.
    Every character outside ``[A-Za-z0-9]`` then becomes ``_``, runs of ``_``
    are collapsed, and leading/trailing ``_`` are stripped.

    Args:
        filename: File name (or path) whose stem should be romanized.

    Returns:
        The sanitized name, or ``"data"`` if the result is empty or the
        conversion fails.
    """
    try:
        kakasi = pykakasi.kakasi()
        stem = Path(filename).stem
        converted = kakasi.convert(stem)

        if isinstance(converted, list):
            # List results carry either dicts with a 'hepburn' reading or
            # plain strings; concatenate them in order.
            parts = []
            for item in converted:
                if isinstance(item, dict) and "hepburn" in item:
                    parts.append(item["hepburn"])
                elif isinstance(item, str):
                    parts.append(item)
            romaji_name = "".join(parts)
        else:
            romaji_name = str(converted)

        safe_name = re.sub(r"[^a-zA-Z0-9]", "_", romaji_name)
        safe_name = re.sub(r"_+", "_", safe_name)
        safe_name = safe_name.strip("_")

        return safe_name if safe_name else "data"

    except Exception as e:
        # Best effort: fall back to a generic name, but report the failure.
        # Silently swallowing it hid real bugs (e.g. a missing import) and made
        # every Japanese file name collapse onto the same "data" index.
        print(f"ローマ字変換に失敗しました ({filename}): {e}")
        return "data"

def get_faiss_index_path(data_filename):
    """Map a data file name to the path of its cached FAISS index.

    Names containing hiragana, katakana, or kanji in U+4E00..U+9FAF are
    romanized with convert_to_romaji, which drops the extension. Other names
    keep their extension and have every character outside ``[A-Za-z0-9._-]``
    replaced by ``_``. Either way the result is ``<name>.faiss`` inside
    ``FAISS_INDEX_DIR``. Debug lines are printed along the way.
    """
    has_japanese = (
        re.compile(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FAF]').search(data_filename)
        is not None
    )

    print(f"DEBUG: ファイル名: {data_filename}")
    print(f"DEBUG: 日本語判定: {has_japanese}")

    if has_japanese:
        base_name = convert_to_romaji(data_filename)
        print(f"DEBUG: ローマ字変換結果: {base_name}")
    else:
        base_name = re.sub(r'[^a-zA-Z0-9._-]', '_', data_filename)
        print(f"DEBUG: 英数字処理結果: {base_name}")

    return os.path.join(FAISS_INDEX_DIR, f"{base_name}.faiss")

def update_rag_data(data_path, text, chunk_size, emb_tok, emb_model):
    """Return a FAISS index over the chunks of ``text``, reusing a valid cache.

    The index is cached at ``get_faiss_index_path(<file name of data_path>)``.
    A ``.sha256`` sidecar stores a fingerprint of ``text`` and ``chunk_size``.
    A cached index is reused only when that fingerprint and the vector count
    both match the current data; otherwise it is rebuilt and overwritten.
    Keying the cache by file name alone would silently return stale search
    hits after the file contents or chunk size change.

    Args:
        data_path: Path of the source data file (only its name is used as the
            cache key).
        text: Contents of the data file.
        chunk_size: Characters per chunk, passed to chunking().
        emb_tok: Tokenizer of the embedding model.
        emb_model: Embedding model.

    Returns:
        ``(faiss_index, chunk_data)`` where index row ``i`` embeds
        ``chunk_data[i]``, or ``(None, None)`` if chunking or building fails.
    """
    import hashlib

    current_filename = Path(data_path).name
    print(f"DEBUG: データパス: {data_path}")
    print(f"DEBUG: 現在のファイル名: {current_filename}")

    faiss_index_path = get_faiss_index_path(current_filename)
    print(f"DEBUG: 最終的なFAISSパス: {faiss_index_path}")

    fingerprint = hashlib.sha256(f"{chunk_size}\n{text}".encode("utf-8")).hexdigest()
    fingerprint_path = Path(faiss_index_path + ".sha256")

    # Chunking is deterministic and cheap; compute it once for both paths.
    try:
        chunk_data = chunking(text, chunk_size)
    except Exception as e:
        print(f"FAISSインデックス構築エラー: {e}")
        return None, None

    if os.path.exists(faiss_index_path):
        try:
            cached_fingerprint = (
                fingerprint_path.read_text(encoding="utf-8").strip()
                if fingerprint_path.exists()
                else None
            )
            faiss_index = faiss.read_index(faiss_index_path)
        except Exception as e:
            print(f"既存インデックスの読み込みに失敗: {e}")
        else:
            if cached_fingerprint == fingerprint and faiss_index.ntotal == len(chunk_data):
                print(f"既存のFAISSインデックスを読み込みました: {faiss_index_path}")
                return faiss_index, chunk_data
            print(f"既存インデックスがデータと一致しないため再構築します: {faiss_index_path}")

    try:
        chunk_embs = embed_texts(chunk_data, emb_tok, emb_model)
        faiss_index = build_faiss_index(chunk_embs)
    except Exception as e:
        print(f"FAISSインデックス構築エラー: {e}")
        return None, None

    # Persisting is only a cache: a write failure should not discard a usable
    # in-memory index.
    try:
        faiss.write_index(faiss_index, faiss_index_path)
        fingerprint_path.write_text(fingerprint, encoding="utf-8")
        print(f"新しいFAISSインデックスを保存しました: {faiss_index_path}")
    except Exception as e:
        print(f"FAISSインデックスの保存に失敗: {e}")

    return faiss_index, chunk_data

@torch.no_grad()
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions of each sequence.

    Args:
        model_output: Encoder output exposing ``last_hidden_state``.
        attention_mask: Mask with 1 for real tokens and 0 for padding, matching
            the first two dimensions of ``last_hidden_state``.

    Returns:
        One pooled vector per sequence. Sequences whose mask is all zeros yield
        zeros, because the divisor is clamped to 1e-9 instead of 0.
    """
    hidden = model_output.last_hidden_state
    # Broadcast the mask over the hidden dimension and make it a float weight.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    token_counts = torch.clamp(weights.sum(dim=1), min=1e-9)
    return torch.sum(hidden * weights, dim=1) / token_counts

@torch.no_grad()
def embed_texts(texts, tok: AutoTokenizer, model: AutoModel, batch_size: int = 32) -> np.ndarray:
    """Encode ``texts`` into L2-normalized, mean-pooled embeddings.

    Args:
        texts: List of strings to embed.
        tok: Tokenizer of the embedding model.
        model: Embedding model; its inputs are moved to ``DEVICE``.
        batch_size: Number of texts encoded per forward pass.

    Returns:
        Array with one row per input text, L2-normalized along each row, so
        inner products between rows are cosine similarities. An empty
        ``texts`` returns a ``(0, hidden_size)`` float32 array instead of
        crashing in ``np.vstack([])``.
    """
    if not texts:
        # NOTE(review): assumes model.config.hidden_size is the embedding width,
        # as for Hugging Face encoder models.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    vecs = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tok(batch, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
        outputs = model(**inputs)
        emb = mean_pooling(outputs, inputs["attention_mask"])
        emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        vecs.append(emb.cpu().numpy())
    return np.vstack(vecs)

def build_faiss_index(embs: np.ndarray):
    """Create a flat (exhaustive) inner-product FAISS index holding ``embs``.

    With L2-normalized rows, as produced by embed_texts, inner-product scores
    are cosine similarities.

    Args:
        embs: 2-D array of shape (n_vectors, dim); ``dim`` is read from axis 1.

    Returns:
        A ``faiss.IndexFlatIP`` with every row of ``embs`` added in order, so
        the search id ``i`` refers to row ``i``.
    """
    dim = embs.shape[1]
    flat_ip = faiss.IndexFlatIP(dim)
    flat_ip.add(embs)
    return flat_ip

def vector_search(question, chunks, emb_tok: AutoTokenizer, emb_model, index, k=5):
    """Return the top-``k`` chunk ids for ``question`` with their similarity scores.

    The question is embedded the same way as the chunks (mean pooling plus L2
    normalization) and searched in ``index``. Ids FAISS returns outside
    ``range(len(chunks))`` (e.g. -1 when fewer than ``k`` vectors exist) are
    dropped.

    Returns:
        List of ``(chunk_id, score)`` tuples in FAISS result order.
    """
    encoded = emb_tok([question], padding=True, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        hidden = emb_model(**encoded)
        pooled = mean_pooling(hidden, encoded["attention_mask"])
        query_vec = torch.nn.functional.normalize(pooled, p=2, dim=1).cpu().numpy()

    scores, indices = index.search(query_vec, k)
    n_chunks = len(chunks)
    hits = []
    for idx, score in zip(indices[0], scores[0]):
        if idx < 0 or idx >= n_chunks:
            continue
        hits.append((int(idx), float(score)))
    return hits

def _safe_format(template, **kwargs):
    try:
        return template.format(**kwargs)
    except KeyError as e:
        print(f"フォーマットエラー: {e}")
        return template

def build_prompt(prompt_template, question, retrieved_chunks):
    """Render ``prompt_template`` with the retrieved chunks as context.

    The chunks are joined with a ``---`` separator line and substituted for
    ``{context}``; ``question`` fills ``{question}``. Rendering is done by
    _safe_format.

    Returns:
        The rendered prompt string.
    """
    separator = "\n---\n"
    joined_context = separator.join(retrieved_chunks)
    return _safe_format(prompt_template, context=joined_context, question=question)

def generate_response(llm_tok, llm_model, prompt, max_length=512):
    """Generate an answer to ``prompt`` with the causal LM.

    Fixes over a plain tokenize-generate-decode loop:
      * The prompt is no longer truncated to ``max_length`` tokens.
        Right-truncation cut off the question at the end of the prompt.
      * When the tokenizer defines a chat template (instruct models), the
        prompt is wrapped as a user turn with the generation prompt appended.
      * Only newly generated tokens are decoded. Stripping the prompt by string
        replacement fails whenever detokenization does not reproduce it exactly.
      * Inputs follow ``llm_model.device``, since placement is chosen by
        ``device_map="auto"``.

    Args:
        llm_tok: Tokenizer paired with ``llm_model``.
        llm_model: Causal language model used for sampling.
        prompt: Fully rendered prompt (context plus question).
        max_length: Maximum number of new tokens to generate.

    Returns:
        The generated answer text, or a fixed Japanese error message if
        generation raises.
    """
    try:
        if getattr(llm_tok, "chat_template", None):
            inputs = llm_tok.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
            )
        else:
            inputs = llm_tok(prompt, return_tensors="pt")
        inputs = inputs.to(llm_model.device)

        with torch.no_grad():
            outputs = llm_model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=0.7,
                do_sample=True,
                pad_token_id=llm_tok.eos_token_id,
                eos_token_id=llm_tok.eos_token_id,
            )

        prompt_len = inputs["input_ids"].shape[1]
        return llm_tok.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"回答生成エラー: {e}")
        return "回答の生成中にエラーが発生しました。"

def run_pipeline(question, chunks, emb_tok, emb_model, index, llm_tok, llm_model, prompt_template, k=5):
    """Answer ``question`` using the top-``k`` retrieved chunks as context.

    Steps: vector search, collect the hit chunks in result order, render the
    prompt, then generate. If no valid hit remains, a fixed "no relevant
    information" message is returned without calling the LLM.
    """
    hits = vector_search(question, chunks, emb_tok, emb_model, index, k)

    total = len(chunks)
    context_chunks = [chunks[idx] for idx, _score in hits if 0 <= idx < total]

    if not context_chunks:
        return "関連する情報が見つかりませんでした。"

    final_prompt = build_prompt(prompt_template, question, context_chunks)
    return generate_response(llm_tok, llm_model, final_prompt)

def main_rag_pipeline(data_path, text, chunk_size, question, prompt_template, k=5):
    """Run the full RAG flow: load models, prepare the index, answer ``question``.

    Args:
        data_path: Path of the source data file (its name keys the index cache).
        text: Contents of the data file.
        chunk_size: Characters per chunk.
        question: Question to answer.
        prompt_template: Template with ``{context}`` and ``{question}`` placeholders.
        k: Number of chunks to retrieve.

    Returns:
        The generated answer, or a fixed Japanese failure message if the RAG
        data could not be prepared.
    """
    print("モデルを読み込み中...")
    llm_tok, llm_model, emb_tok, emb_model = load_models(BASE_MODEL_NAME, EMBEDDING_MODEL_NAME)

    print("RAGデータを更新中...")
    faiss_index, chunk_data = update_rag_data(data_path, text, chunk_size, emb_tok, emb_model)

    if faiss_index is None or chunk_data is None:
        return "RAGデータの準備に失敗しました。"

    print("質問に回答中...")
    # Use the caller-supplied template, not the module-level `prompt` global;
    # otherwise the prompt_template argument would be silently ignored.
    response = run_pipeline(
        question,
        chunk_data,
        emb_tok,
        emb_model,
        faiss_index,
        llm_tok,
        llm_model,
        prompt_template,
        k,
    )

    return response

if __name__ == "__main__":
    # Run the end-to-end RAG flow with the notebook-level settings and show the answer.
    response = main_rag_pipeline(
        data_path=data,
        text=text,
        chunk_size=chunk_size,
        question=question,
        prompt_template=prompt,
    )
    print(response)