# In [None]:
import os
from dotenv import load_dotenv
# Load variables from a local .env file (if present) into the environment.
load_dotenv()

# Fail fast when the API key is missing; otherwise confirm that it was found.
if not os.getenv('XPU_API_KEY'):
    raise ValueError("❌ XPU_API_KEY not found in environment. Please set it in .env file before running this script.")
print("✅ Found XPU_API_KEY in environment, using it")


from llama_index.core import SimpleDirectoryReader
import pprint

# Read everything under the data directory into a list of documents.
documents = SimpleDirectoryReader(input_dir='../../data/').load_data()

# Report how many entries the reader produced.
print(f"Loaded {len(documents)} chunks")

# print("Document [0].doc_id:", documents[0].doc_id)
# pprint.pprint (documents[0], indent=4)

from llama_index.core.embeddings import BaseEmbedding
from typing import List, Optional
import numpy as np
import requests
import json
import os

class OpenAICompatibleEmbedding(BaseEmbedding):
    """Embedding model that calls an OpenAI-style ``/embeddings`` HTTP endpoint.

    ``BaseEmbedding`` is a pydantic model, so the connection settings are
    declared as fields and populated through ``super().__init__`` instead of
    being assigned on ``self`` before the base initialiser has run. The model
    identifier is available both as ``model`` and as the inherited
    ``model_name`` field.
    """

    # Base URL of the API, stored without a trailing slash.
    api_base: str
    # Bearer token sent in the ``Authorization`` header.
    api_key: str
    # Model identifier sent in every request payload.
    model: str = "text-embedding-ada-002"
    # Seconds to wait for the HTTP request; ``None`` disables the timeout.
    timeout: Optional[float] = 60.0

    def __init__(
        self,
        api_base: str,
        api_key: str,
        model: str = "text-embedding-ada-002",
        embed_batch_size: int = 50,
        timeout: Optional[float] = 60.0,
        **kwargs
    ) -> None:
        """Create the embedding client.

        Args:
            api_base: Base URL of the API; a trailing ``/`` is stripped.
            api_key: API key sent as a bearer token.
            model: Name of the embedding model to request.
            embed_batch_size: Maximum number of texts sent per request.
            timeout: Per-request timeout in seconds (``None`` waits forever).
            **kwargs: Forwarded to ``BaseEmbedding.__init__``.
        """
        # ``model_name`` always mirrors ``model``, as the former property did.
        kwargs["model_name"] = model
        super().__init__(
            api_base=api_base.rstrip('/'),
            api_key=api_key,
            model=model,
            embed_batch_size=embed_batch_size,
            timeout=timeout,
            **kwargs
        )

    def _request_embeddings(self, model_input, expected: int, error_prefix: str) -> List[List[float]]:
        """POST ``model_input`` to the endpoint and return the embedding vectors.

        Args:
            model_input: A single string or a list of strings to embed.
            expected: Number of vectors the response must contain.
            error_prefix: Text placed in front of the HTTP failure message.

        Returns:
            One embedding per input, ordered to match the input order.

        Raises:
            RuntimeError: On a non-200 status, or when the response does not
                contain exactly ``expected`` non-empty embeddings.
        """
        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }
        payload = {
            'model': self.model,
            'input': model_input
        }

        response = requests.post(
            f"{self.api_base}/embeddings",
            headers=headers,
            data=json.dumps(payload),
            timeout=self.timeout  # never hang indefinitely on a stalled server
        )

        if response.status_code != 200:
            raise RuntimeError(f"{error_prefix}: {response.status_code}, {response.text}")

        # Expected response shape: {"data": [{"embedding": [...], "index": 0}, ...]}
        items = response.json().get('data') or []
        if len(items) != expected:
            # A short response would silently misalign vectors with their texts.
            raise RuntimeError(
                f"Embedding API returned {len(items)} item(s), expected {expected}"
            )

        # Order by the reported index so vectors line up with the inputs;
        # the sort is stable, so items without an index keep their order.
        items = sorted(items, key=lambda item: item.get('index') or 0)
        vectors = [item.get('embedding') for item in items]
        if not all(vectors):
            raise RuntimeError("Embedding API returned an item without an embedding")
        return vectors

    def _get_query_embedding(self, query: str) -> List[float]:
        """Return the embedding of a single query string."""
        return self._request_embeddings(query, 1, "API请求失败")[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async variant required by ``BaseEmbedding``; delegates to the sync call."""
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Return the embedding of a single text (required by ``BaseEmbedding``)."""
        return self._get_text_embeddings([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return embeddings for ``texts``, sending at most ``embed_batch_size`` per request."""
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.embed_batch_size):
            batch = texts[start:start + self.embed_batch_size]
            # The endpoint accepts a list of strings as ``input``.
            embeddings.extend(
                self._request_embeddings(batch, len(batch), "API批量请求失败")
            )
        return embeddings

# 使用示例：
from llama_index.core import Settings

# Configure the global embedding model against the OpenAI-style endpoint.
# The key is read from the environment only: credentials must never be
# hardcoded in source. CLOUD_API_KEY is preferred, and XPU_API_KEY (the
# variable checked at the top of this script) is accepted as a fallback.
_embedding_api_key = os.getenv("CLOUD_API_KEY") or os.getenv("XPU_API_KEY")
if not _embedding_api_key:
    raise ValueError(
        "No API key found: set CLOUD_API_KEY or XPU_API_KEY in the environment or .env file."
    )

Settings.embed_model = OpenAICompatibleEmbedding(
    api_base="https://xpulink.ai/v1",  # base URL of the OpenAI-style API
    api_key=_embedding_api_key,
    model="text-embedding-ada-002",  # replace with your model name
    embed_batch_size=50
)

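# NOTE: a literal API key was pasted here and has been redacted. Revoke that
# key and supply a new one via the environment, e.g. in .env: CLOUD_API_KEY=<your-api-key>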
