In [47]:
from sentence_transformers import SentencesDataset, InputExample, losses
from torch.utils.data import DataLoader

In [48]:
import pandas as pd


def prepare_data_for_training(sent_pairs_df: pd.DataFrame = None, ndigits: int = 2):
    """Convert a sentence-pair DataFrame into sentence-transformers ``InputExample``s.

    Args:
        sent_pairs_df: frame with columns ``'sent1'``, ``'sent2'`` and ``'final_score'``.
            ``None`` (the default) is treated as "no data" and yields an empty list
            instead of crashing on subscription.
        ndigits: number of decimals the float label is rounded to (defaults to 2,
            the previously hard-coded value).

    Returns:
        list of ``InputExample(texts=[sent1, sent2], label=round(final_score, ndigits))``,
        one per row, in row order.

    Raises:
        KeyError: if any of the three required columns is missing.
    """
    if sent_pairs_df is None:
        return []

    # Iterate the three columns directly instead of building an object ndarray via .values.
    return [
        InputExample(texts=[sent_1, sent_2], label=round(float(final_score), ndigits))
        for sent_1, sent_2, final_score in zip(
            sent_pairs_df['sent1'],
            sent_pairs_df['sent2'],
            sent_pairs_df['final_score'],
        )
    ]

In [49]:
def run_training(
    from_model_path: str,
    to_model_path: str,
    train_examples,
    batch_size: int = 8,
    epochs: int = 1,
):
    """Fine-tune a SentenceTransformer with CosineSimilarityLoss and save it.

    Args:
        from_model_path: path (or name) of the SentenceTransformer to start from.
        to_model_path: directory the fine-tuned model is saved to.
        train_examples: non-empty sequence of ``InputExample`` pairs with float labels.
        batch_size: DataLoader batch size (previously hard-coded to 8).
        epochs: number of training epochs (previously hard-coded to 1).

    Returns:
        ``to_model_path``, unchanged.

    Raises:
        ValueError: if ``train_examples`` is empty, so an empty pair file fails
            loudly instead of quietly producing a "next iteration" model.
    """
    if not train_examples:
        raise ValueError('train_examples is empty; nothing to train on')

    # Local import so this function does not depend on a later notebook cell
    # having put SentenceTransformer into the global namespace.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(from_model_path)

    # Build the dataset and a shuffled loader over it.
    train_dataset = SentencesDataset(examples=train_examples, model=model)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Loss compares the cosine similarity of each pair's embeddings with its float label.
    train_loss = losses.CosineSimilarityLoss(model=model)

    # Train, then persist the result for the next iteration to start from.
    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=epochs)
    model.save(to_model_path)

    return to_model_path

In [50]:
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

In [51]:
import fasttext

# fastText skip-gram model used below to mine nearest-neighbour word pairs.
kbd_model = fasttext.load_model('../data/processed/embeddings/fasttext_skipgram_kbd_100.bin')

# keep_default_na=False keeps vocabulary tokens such as "nan", "null", "NA" or ""
# as strings; with the default they would become float NaN and later be passed
# to kbd_model.get_nearest_neighbors.
# NOTE(review): assumes the 'freq' column is always populated (no empty cells).
words_df = pd.read_csv(
    '../data/processed/word_freqs/freq_1000000_oshhamaho.csv',
    keep_default_na=False,
)
# Keep only words whose frequency is greater than 1.
words_df = words_df[words_df['freq'] > 1]
# Set for fast membership checks when filtering neighbours.
all_words = set(words_df['word'])

def select_word_pairs(iteration, samples=1000, k=20, random_state=None):
    """Mine (word, neighbour, similarity) pairs from the fastText model and save them.

    Samples words from the global ``words_df``, asks the global ``kbd_model`` for each
    word's ``k`` nearest neighbours, and keeps only neighbours present in the global
    ``all_words`` set. The pairs are written to
    ``../data/processed/word_pairs_{iteration}.csv`` with columns
    ``sent1``, ``sent2``, ``final_score`` (score rounded to 3 decimals).

    Args:
        iteration: used only to name the output CSV.
        samples: number of words to sample; capped at the vocabulary size so a
            small vocabulary does not make ``DataFrame.sample`` raise.
        k: neighbours requested per word (previously hard-coded to 20).
        random_state: seed forwarded to ``DataFrame.sample`` for reproducible runs;
            ``None`` keeps the original unseeded behaviour.

    Returns:
        The pairs DataFrame that was written to disk.
    """
    n_samples = min(samples, len(words_df))
    words = words_df.sample(n_samples, random_state=random_state)['word'].tolist()

    words_pair = [
        (word, neighbour, round(score, 3))
        for word in tqdm(words)
        # get_nearest_neighbors yields (score, word) tuples.
        for score, neighbour in kbd_model.get_nearest_neighbors(word, k)
        if neighbour in all_words
    ]
    words_pair_df = pd.DataFrame(words_pair, columns=['sent1', 'sent2', 'final_score'])
    words_pair_df.to_csv(f'../data/processed/word_pairs_{iteration}.csv', index=False)
    return words_pair_df



In [52]:
# Iterative self-training: model_i -> model_{i+1}, each round on freshly mined word pairs.
FIRST_ITERATION = 75
STOP_ITERATION = 80  # exclusive
WORDS_PER_ITERATION = 10000

for iteration in range(FIRST_ITERATION, STOP_ITERATION):
    from_model_path = f'./sbert_from_mlm_bert_{iteration}'
    to_model_path = f'./sbert_from_mlm_bert_{iteration + 1}'

    # run_training loads from_model_path itself; no need to load the model here too.
    select_word_pairs(iteration, samples=WORDS_PER_ITERATION)

    # keep_default_na=False: words such as "nan"/"null"/"NA" must stay strings, not NaN.
    sent_pairs_df = pd.read_csv(
        f'../data/processed/word_pairs_{iteration}.csv',
        keep_default_na=False,
    )
    train_examples = prepare_data_for_training(sent_pairs_df)
    run_training(from_model_path, to_model_path, train_examples)

100%|██████████| 10000/10000 [05:41<00:00, 29.32it/s]


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10314 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [05:56<00:00, 28.06it/s]


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10421 [00:00<?, ?it/s]

  3%|▎         | 260/10000 [00:08<05:30, 29.50it/s]

KeyboardInterrupt: 