In [121]:
from pymilvus import connections

connections.connect(host='localhost', port='19530')

In [122]:
from pymilvus import FieldSchema, CollectionSchema, Collection, DataType

dim = 384

schema = CollectionSchema(fields=[
    FieldSchema(name="id", dtype=DataType.INT64, description="id", is_primary=True),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim, description="embedding")
])

collection_name = "sbert_from_mlm_bert_80"
collection = Collection(name=collection_name, schema=schema, using='default')

In [123]:
# создание индекса
collection.create_index(
    field_name="embedding",
    index_params={"metric_type": "COSINE", "index_type": "IVF_FLAT", "nlist": 16384}
)

Status(code=0, message=)

In [125]:
import nltk
import pandas as pd
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('./sbert_from_mlm_bert_19')

model.encode(['hello world'])

array([[ 8.59407604e-01,  8.49813759e-01,  1.61552429e-01,
        -9.29081202e-01,  3.98176312e-02,  3.95493656e-02,
        -2.18836164e+00, -9.78214800e-01, -1.61439944e-02,
         3.50391448e-01,  4.51187551e-01,  2.47718528e-01,
         2.41391331e-01,  1.16843653e+00,  4.15819943e-01,
        -2.85165161e-02,  1.24654539e-01, -6.36771321e-01,
        -1.51432037e+00, -8.54755118e-02, -2.59021968e-02,
         8.74955893e-01, -9.65558052e-01,  9.18002844e-01,
         3.48503441e-01,  1.36290276e+00,  8.67882371e-01,
        -3.47709477e-01,  1.83938360e+00, -5.18177867e-01,
         4.73967269e-02,  5.41395903e-01, -1.81146550e+00,
        -3.00502390e-01, -1.21891630e+00,  3.42858374e-01,
         5.16730212e-02,  1.97092414e+00,  1.85212731e+00,
        -7.76450634e-02,  6.80921435e-01, -8.98889542e-01,
         1.36631894e+00,  7.88382053e-01,  9.49835956e-01,
        -9.72384572e-01, -1.27880585e+00,  8.10197443e-02,
         1.07867837e+00, -3.25738490e-01, -8.11832190e-0

In [126]:
def get_sents(text_path: str, len_min: int, len_max: int):
    with open(text_path, 'r') as f:
        text = f.read()

    sents = sorted(set([
        sent.replace('\n', ' ')
        for sent in nltk.sent_tokenize(text)
        if len_min < len(sent) < len_max
    ]))
    return sents


sentences = get_sents('../data/processed/oshhamaho.txt', 20, 40)
print(len(sentences))

df = pd.DataFrame(sentences, columns=['text'])

58477


In [127]:
from tqdm import tqdm

tqdm.pandas()

df['embedding'] = df['text'].progress_apply(lambda x: model.encode([x])[0])

data = [
    {"id": i, "embedding": emb}
    for i, emb in zip(df.index, df['embedding'])
]

100%|██████████| 58477/58477 [05:43<00:00, 170.05it/s]


In [128]:
# split data into chunks
chunk_size = 10000

for i in range(0, len(data), chunk_size):
    collection.insert(data[i:i + chunk_size])

collection.flush()

In [129]:
# Загрузка коллекции в память
collection.load()

In [136]:
# Поиск векторов в коллекции
q = 'Сыжеящ сэ, зыри слъэгъуакъым!' 

query_embedding = model.encode([q])[0]
search_params = {"metric_type": "COSINE", "params": {"nprobe": 16}}
results = collection.search([query_embedding], "embedding", search_params, limit=100)

ids = results[0].ids

In [137]:
scores = results[0].distances

df.iloc[ids].assign(score=scores)

Unnamed: 0,text,embedding,score
32846,"Сыжеящ сэ, зыри слъэгъуакъым!","[0.9200104, -0.39659, 0.026366716, 0.8062754, ...",1.000000
19107,Зыри схуещIэнукъым сэ а гурыщIэм.,"[0.8800354, -0.36040992, -0.061464228, 0.82607...",0.990564
34992,Сэ зыми сыщIэупщIэркъым.,"[0.8470928, -0.48494545, 0.020302482, 0.908391...",0.990510
21195,Иджы сэ зыри сыхуеижкъым.,"[1.1014924, -0.49872944, -0.093424186, 0.92347...",0.990269
18358,Зулий Тэмазэ Ар сэ зыми къызжиIакъым.,"[0.8809128, -0.5940377, -0.074895926, 0.977269...",0.989486
...,...,...,...
49101,– Зыри къызгурыIуэркъым.,"[0.78755003, -0.3477541, -0.55958515, 0.779013...",0.898831
18947,"Зыми ущымышынэ, дахэ.","[0.81828624, -0.6319867, -0.49326023, 0.819835...",0.898524
39132,Уэ пхуэдэу зыми хуэщIыркъым ар.,"[0.92883253, -0.65393406, -0.51447546, 0.80085...",0.898388
7783,Ар зыми хуэздэнукъым.,"[0.89433694, -0.64419645, -0.53091425, 0.85696...",0.898111


In [120]:
collection.drop()