In [2]:
from typing import List, Any
from sentence_transformers import SentenceTransformer


class Embedding(object):

    def __init__(
            self,
            model_name_or_path: str,
            device: str = "cuda",
            normalized: bool = False,
    ):
        super(Embedding, self).__init__()
        # self.m = text2vec.SentenceModel(
        #     model_name_or_path,
        #     max_seq_length=max_seq_length,
        # )
        self.normalized = normalized
        self.m = SentenceTransformer(
            model_name_or_path=model_name_or_path,
            device=device,
        )

    def encode(
            self,
            text: str | List[str],
            normalized: bool = False,
    ):
        text = (
            text if isinstance(text, list) else [text]
        )
        normalized = normalized or self.normalized
        embeddings = self.m.encode(
            sentences=text,
            normalize_embeddings=normalized
        )
        ret = [
            {"embedding": embed.tolist()} for embed in embeddings
        ]
        return ret

    def __call__(
            self,
            text: str | List[str],
    ):
        return self.encode(
            text=text,
        )


# Load the local text2vec checkpoint on CPU; normalized=True makes every
# encode() call on this instance request normalized embeddings.
model = Embedding("./models/text2vec-base-chinese", device="cpu", normalized=True)

In [3]:
# Embed a single sentence; the result is a one-element list of
# {"embedding": [...]} dicts.
ret = model.encode("你好", normalized=True)
print(ret)

[{'embedding': [-0.011796600185334682, 0.020661162212491035, 0.03626355528831482, 0.003833576338365674, 0.01988031528890133, -0.028167380020022392, 0.02026396431028843, -0.01783526875078678, -0.03674688935279846, 0.04176722466945648, 0.027834801003336906, 0.026619063690304756, 0.0023146173916757107, -0.021494880318641663, -0.00721411406993866, 0.00788100715726614, -0.007364068645983934, -0.019263694062829018, -0.01964900828897953, -0.05119998753070831, -0.03278687223792076, 0.042289793491363525, 0.002700143028050661, 0.012581350281834602, -0.016569823026657104, 0.05623943731188774, -0.004821450915187597, 0.024360671639442444, 0.07893326133489609, 0.03437742590904236, -0.009767141193151474, 0.043593768030405045, -0.014650587923824787, 0.00045120486174710095, -0.05232773721218109, -0.017314201220870018, -0.006627094466239214, 0.05541662871837616, -0.037477217614650726, -0.015246674418449402, 0.006722163874655962, 0.030116006731987, -0.03538242354989052, 0.0339144803583622, 0.037600360810