In [1]:
from sentence_transformers import SentencesDataset, InputExample

In [2]:
import pandas as pd

example_hash = set()
train_examples = []

sent_pairs_df = pd.read_csv('../data/processed/word_freqs/freq_1000000_oshhamaho_translated.csv')
sent_pairs_df.dropna(inplace=True)
sent_pairs_df = sent_pairs_df[sent_pairs_df['kbd'] != sent_pairs_df['rus']]
sent_pairs_df = sent_pairs_df.sample(10000, random_state=42)

for kbd, rus in sent_pairs_df[['kbd', 'rus']].values:
    train_examples.append(
        InputExample(texts=[kbd, rus], label=1.0)
    )

In [3]:
from sentence_transformers import SentenceTransformer, losses
from torch.utils.data import DataLoader

model = SentenceTransformer('./sbert_from_mlm_bert_21')

# Создание и загрузка датасета
train_dataset = SentencesDataset(examples=train_examples, model=model)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

eval_examples = train_examples[:100]

eval_dataset = SentencesDataset(examples=eval_examples, model=model)

In [4]:
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    examples=eval_examples,
    name='sent_pairs_eval'
)

# Настройка процесса обучения
train_loss = losses.CosineSimilarityLoss(model=model)

# Обучение модели
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluation_steps=100,
    epochs=1
)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1250 [00:00<?, ?it/s]

In [4]:
model.save('./sbert_aligned')

0.9564071297645569

In [34]:
from scipy.spatial.distance import cosine

model.encode(
    ['привет', 'пока'],
)

for kbd, rus in sent_pairs_df[['kbd', 'rus']].values:
    v1, v2 = model.encode([kbd, rus])

0.06936454772949219
0.005946516990661621
0.03226339817047119
0.00693124532699585
0.007036864757537842
0.019952118396759033
0.028475522994995117
0.007917344570159912
0.01605844497680664
0.021491944789886475
0.007638096809387207
0.0194205641746521
0.03262108564376831
0.006696760654449463
0.007135868072509766
0.0091819167137146
0.010609030723571777
0.0012061595916748047
0.004625380039215088
0.005843162536621094
0.005463063716888428
0.001148223876953125
0.006772518157958984
0.021659553050994873
0.01405036449432373
0.010650932788848877
0.003496825695037842
0.0032950639724731445
0.003152132034301758
0.01787048578262329
0.0014916062355041504
0.020121216773986816
0.004021167755126953
0.006160259246826172
0.0017978549003601074
0.0032510757446289062
0.012011885643005371
0.003367781639099121
0.019430041313171387
0.01910710334777832
0.009813845157623291
0.004246354103088379
0.007838606834411621
0.05792325735092163
0.005336582660675049
0.02121037244796753
0.014858484268188477
0.031080186367034912
0

KeyboardInterrupt: 