In [1]:
from sqlalchemy import create_engine

import os
from urllib.parse import quote_plus

# Create the SQL Server connection.
# Every part of the URL can be overridden through an environment variable so
# real credentials do not have to be saved in the notebook; the fallbacks
# reproduce the original local development settings.
db_user = os.environ.get("MSSQL_USER", "sa")
db_password = os.environ.get("MSSQL_PASSWORD", "1")
db_host = os.environ.get("MSSQL_HOST", "127.0.0.1")
db_name = os.environ.get("MSSQL_DATABASE", "database_name")
db_driver = os.environ.get("MSSQL_DRIVER", "ODBC Driver 17 for SQL Server")

# quote_plus escapes characters such as "@", ":" or "/" that would otherwise
# corrupt the URL, and turns the spaces of the driver name into "+".
connection_string = (
    f"mssql+pyodbc://{quote_plus(db_user)}:{quote_plus(db_password)}"
    f"@{db_host}/{db_name}?driver={quote_plus(db_driver)}"
)
engine = create_engine(connection_string)

In [10]:
from openai import OpenAI
from typing import Any, List
from llama_index.core.embeddings import BaseEmbedding
from pydantic import Field
# Embedding model configuration
base_url = "http://192.168.12.10:9997/v1"
emb_model = "bge-m3"
api_key = "222"
# Alternative model/endpoint pair, kept for reference:
# emb_model = "qwen-embedding-v1"
# base_url = "http://192.168.12.10:8000/v1"


class OurEmbeddings(BaseEmbedding):
    """Embedding model that requests vectors through the OpenAI client.

    Requests are sent with ``client.embeddings.create`` to ``base_url``,
    using ``model_name`` as the model identifier.
    """

    api_key: str = Field(default=api_key)
    base_url: str = Field(default=base_url)
    model_name: str = Field(default=emb_model)
    # Declared explicitly as a field so __init__ can assign it; excluded from
    # the model's serialized output.
    client: OpenAI = Field(default=None, exclude=True)

    def __init__(
        self,
        api_key: str = api_key,
        base_url: str = base_url,
        model_name: str = emb_model,
        **kwargs: Any,
    ) -> None:
        """Store the connection settings and create the OpenAI client.

        Args:
            api_key: Key passed to the OpenAI client.
            base_url: Base URL passed to the OpenAI client.
            model_name: Model identifier sent with every embeddings request.
            **kwargs: Forwarded unchanged to ``BaseEmbedding.__init__``.
        """
        super().__init__(**kwargs)
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with a single request, one vector per input.

        Raises:
            ValueError: If the response does not hold exactly one embedding
                per input text.
        """
        if not texts:
            # Nothing to embed; avoid a pointless round trip.
            return []
        response = self.client.embeddings.create(model=self.model_name, input=texts)
        items = response.data or []
        if len(items) != len(texts):
            raise ValueError(
                f"Failed to get embedding from {self.base_url}: expected "
                f"{len(texts)} vectors for model {self.model_name!r}, got {len(items)}"
            )
        return [item.embedding for item in items]

    @staticmethod
    async def _run_in_thread(func: Any, *args: Any) -> Any:
        """Run the blocking ``func(*args)`` in the loop's default executor."""
        import asyncio  # local import keeps this cell self-contained

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    def invoke_embedding(self, query: str) -> List[float]:
        """Return the embedding vector for a single string.

        Raises:
            ValueError: If the endpoint returns no embedding.
        """
        return self._embed_batch([query])[0]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a query string."""
        return self.invoke_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single text."""
        return self.invoke_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts with one request instead of one request per text."""
        return self._embed_batch(texts)

    # The async variants hand the blocking HTTP call to a worker thread so the
    # event loop stays free while the request is in flight.
    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async variant of ``_get_query_embedding``."""
        return await self._run_in_thread(self._get_query_embedding, query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async variant of ``_get_text_embedding``."""
        return await self._run_in_thread(self._get_text_embedding, text)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async variant of ``_get_text_embeddings``."""
        return await self._run_in_thread(self._get_text_embeddings, texts)

In [11]:
# Smoke test: embed one sample sentence, then display the vector length and type.
embedding = OurEmbeddings(
    api_key=api_key,
    base_url=base_url,
    model_name=emb_model,
)
emb = embedding.get_text_embedding("你好呀呀")
len(emb), type(emb)

(1024, list)

In [14]:
from openai import OpenAI
from typing import Any, List
from llama_index.core.embeddings import BaseEmbedding
from pydantic import Field

# API settings for Xinference
model_uid = "bge-m3"  # replace with the actual model_uid
base_url = "http://192.168.12.10:9997/v1"  # address of the Xinference service
api_key = "your_api_key"  # can be left empty if there is no API key

class XinferenceEmbeddings(BaseEmbedding):
    """Embedding model that requests vectors from Xinference through the OpenAI client.

    Requests are sent with ``client.embeddings.create`` to ``base_url``,
    using ``model_uid`` as the model identifier.
    """

    api_key: str = Field(default=api_key)
    base_url: str = Field(default=base_url)
    model_uid: str = Field(default=model_uid)
    # Declared explicitly as a field so __init__ can assign it; excluded from
    # the model's serialized output.
    client: OpenAI = Field(default=None, exclude=True)

    def __init__(
        self,
        api_key: str = api_key,
        base_url: str = base_url,
        model_uid: str = model_uid,
        **kwargs: Any,
    ) -> None:
        """Store the connection settings and create the OpenAI client.

        Args:
            api_key: Key passed to the OpenAI client.
            base_url: Base URL passed to the OpenAI client.
            model_uid: Model identifier sent with every embeddings request.
            **kwargs: Forwarded unchanged to ``BaseEmbedding.__init__``.
        """
        super().__init__(**kwargs)
        self.api_key = api_key
        self.base_url = base_url
        self.model_uid = model_uid
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with a single request, one vector per input.

        Client and transport errors propagate to the caller unchanged.

        Raises:
            ValueError: If the response does not hold exactly one embedding
                per input text.
        """
        if not texts:
            # Nothing to embed; avoid a pointless round trip.
            return []
        response = self.client.embeddings.create(model=self.model_uid, input=texts)
        items = response.data or []
        if len(items) != len(texts):
            raise ValueError(
                "Failed to get embedding from Xinference API: expected "
                f"{len(texts)} vectors for model {self.model_uid!r}, got {len(items)}"
            )
        return [item.embedding for item in items]

    @staticmethod
    async def _run_in_thread(func: Any, *args: Any) -> Any:
        """Run the blocking ``func(*args)`` in the loop's default executor."""
        import asyncio  # local import keeps this cell self-contained

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    def invoke_embedding(self, query: str) -> List[float]:
        """Return the embedding vector for a single string.

        Failures are raised rather than printed and replaced by ``[]``: an
        empty list is not a usable embedding, and hiding the error would let
        it flow into whatever consumes the vector.

        Raises:
            ValueError: If the endpoint returns no embedding.
        """
        return self._embed_batch([query])[0]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a query string."""
        return self.invoke_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single text."""
        return self.invoke_embedding(text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts with one request instead of one request per text."""
        return self._embed_batch(texts)

    # The async variants hand the blocking HTTP call to a worker thread so the
    # event loop stays free while the request is in flight.
    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async variant of ``_get_query_embedding``."""
        return await self._run_in_thread(self._get_query_embedding, query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async variant of ``_get_text_embedding``."""
        return await self._run_in_thread(self._get_text_embedding, text)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async variant of ``_get_text_embeddings``."""
        return await self._run_in_thread(self._get_text_embeddings, texts)

# Test code: embed one sample sentence and print the vector length and type.
embedding = XinferenceEmbeddings(
    api_key=api_key,
    base_url=base_url,
    model_uid=model_uid,
)
emb = embedding.get_text_embedding("你好呀呀")
print(len(emb), type(emb))

1024 <class 'list'>


In [12]:

from llama_index.embeddings.openai import OpenAIEmbedding

# Endpoint settings for this cell
base_url = "http://192.168.12.10:9997/v1"
api_key = "23231"
emb_model = "bge-m3"

# Use LlamaIndex's OpenAIEmbedding against the custom endpoint configured above.
embedding = OpenAIEmbedding(
    model=emb_model,
    api_key=api_key,
    api_base=base_url,  # note: the keyword here is `api_base`, not `base_url`
)

emb = embedding.get_text_embedding("你好呀呀")
len(emb), type(emb)

(1024, list)

In [16]:
from tenacity import retry, stop_after_attempt, wait_exponential
from llama_index.embeddings.openai import OpenAIEmbedding

# Endpoint settings for the retry example
base_url = "http://192.168.12.10:9997/v1"
api_key = "23231"
emb_model = "bge-m3"

@retry(
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(5),
)
def get_embedding_with_retry(embedding, text):
    """Return ``embedding.get_text_embedding(text)``, retrying when the call raises.

    The tenacity decorator allows up to 5 attempts and waits between them
    with an exponential back-off clamped to the 4-10 second range.
    """
    vector = embedding.get_text_embedding(text)
    return vector

embedding = OpenAIEmbedding(
    model=emb_model,
    api_base=base_url,
    api_key=api_key,
)

# Report a failure as text instead of letting the exception end the cell.
try:
    emb = get_embedding_with_retry(embedding, "你好呀呀")
    print(len(emb), type(emb))
except Exception as exc:
    print(f"Failed to get embedding: {exc}")

1024 <class 'list'>
