In [20]:
from sentence_transformers import SentencesDataset, InputExample, losses
from torch.utils.data import DataLoader

In [21]:
import pandas as pd


def prepare_data_for_training(sent_pairs_df: pd.DataFrame = None):
    """Turn aligned sentence pairs into training examples.

    Each row of ``sent_pairs_df`` becomes one ``InputExample`` whose two texts
    are the row's ``rus`` and ``kbd`` values (in that order) and whose label
    is ``1.0``.

    Args:
        sent_pairs_df: Frame containing at least the columns ``rus`` and
            ``kbd``. ``None`` (the default) is treated like an empty frame.

    Returns:
        A list with one ``InputExample`` per row, in row order. Empty when
        the frame is ``None`` or has no rows.

    Raises:
        KeyError: If the ``rus`` or ``kbd`` column is missing.
    """
    # The signature advertises None as a valid default, so do not crash on it.
    if sent_pairs_df is None:
        return []

    return [
        InputExample(texts=[sent_1, sent_2], label=1.0)
        for sent_1, sent_2 in sent_pairs_df[['rus', 'kbd']].values
    ]

In [25]:
def run_training(from_model_path: str, to_model_path: str, train_examples,
                 batch_size: int = 8, epochs: int = 1):
    """Fine-tune a sentence-embedding model on the given examples and save it.

    Args:
        from_model_path: Path (or name) passed to ``SentenceTransformer`` to
            load the starting model.
        to_model_path: Location the trained model is saved to.
        train_examples: Training examples handed to ``SentencesDataset``.
        batch_size: Mini-batch size used by the ``DataLoader``. Defaults to 8.
        epochs: Number of epochs passed to ``model.fit``. Defaults to 1.

    Returns:
        ``to_model_path``, unchanged, so the call can be chained or displayed.
    """
    model = SentenceTransformer(from_model_path)

    # Build the dataset and serve it in shuffled mini-batches.
    train_dataset = SentencesDataset(examples=train_examples, model=model)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Configure the training objective.
    train_loss = losses.CosineSimilarityLoss(model=model)

    # Train the model, then persist it.
    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=epochs)
    model.save(to_model_path)

    return to_model_path

In [26]:
from sentence_transformers import SentenceTransformer
import nltk
from sklearn.cluster import AgglomerativeClustering
from collections import defaultdict
import random
from tqdm import tqdm
import os


def get_sents(text_path: str, len_min: int, len_max: int, encoding: str = 'utf-8'):
    """Read a text file and return its unique sentences of acceptable length.

    The text is split with ``nltk.sent_tokenize``. A sentence is kept when its
    character length is strictly between ``len_min`` and ``len_max``; newlines
    inside a kept sentence are replaced by spaces.

    Args:
        text_path: Path of the text file to read.
        len_min: Exclusive lower bound on sentence length, in characters.
        len_max: Exclusive upper bound on sentence length, in characters.
        encoding: Encoding used to decode the file. Defaults to UTF-8 so the
            result does not depend on the platform's locale encoding.

    Returns:
        The kept sentences, de-duplicated and sorted.
    """
    with open(text_path, 'r', encoding=encoding) as f:
        text = f.read()

    unique_sents = {
        sent.replace('\n', ' ')
        for sent in nltk.sent_tokenize(text)
        if len_min < len(sent) < len_max
    }
    return sorted(unique_sents)


def get_clusters(vectors, n_clusters=20):
    """Cluster ``vectors`` agglomeratively and return their cluster labels.

    Args:
        vectors: Samples passed to ``AgglomerativeClustering.fit_predict``.
        n_clusters: Number of clusters to form. Defaults to 20.

    Returns:
        The result of ``fit_predict`` for ``vectors``.
    """
    clusterer = AgglomerativeClustering(n_clusters=n_clusters)
    return clusterer.fit_predict(vectors)


def get_sents_by_clusters(words, labels):
    """Group items by the cluster label found at the same position.

    Args:
        words: Indexable sequence of items; ``words[i]`` belongs to the
            cluster ``labels[i]``.
        labels: Iterable of hashable cluster labels.

    Returns:
        A ``defaultdict(list)`` mapping each label to its items, in their
        original order.
    """
    grouped = defaultdict(list)
    for position, cluster_id in enumerate(labels):
        bucket = grouped[cluster_id]
        bucket.append(words[position])

    return grouped


def clusterize_sents(sents, model, version, butch_size=10000, cluster_num=1000,
                     seeds=range(111, 115)):
    """Cluster sentences batch by batch and write every cluster to a file.

    For each seed the sentences are shuffled, cut into consecutive batches of
    ``butch_size``, embedded with ``model.encode`` and clustered into
    ``cluster_num`` groups. Each cluster is written, one sentence per line, to
    ``../data/processed/sent_clusters_{version}/seed_{seed}/
    {cluster_factor}_{butch_size}_{cluster_num}/
    cluster_{offset}_{offset + butch_size}_{label}.txt``.

    Args:
        sents: Sentences to cluster. The sequence is not modified.
        model: Object whose ``encode`` method maps a list of sentences to
            their vectors.
        version: Tag inserted into the output directory name.
        butch_size: Number of sentences clustered together in one batch.
        cluster_num: Number of clusters formed per batch.
        seeds: Seeds to run; each produces its own output directory and its
            own reproducible shuffle. Defaults to 111..114.

    Returns:
        None. Results are written to disk.
    """
    # True division: the value appears in the directory name as a float
    # (e.g. ``10.0``).
    cluster_factor = butch_size / cluster_num

    for seed in seeds:
        export_path = f'../data/processed/sent_clusters_{version}/seed_{seed}/{cluster_factor}_{butch_size}_{cluster_num}'
        os.makedirs(export_path, exist_ok=True)

        # Shuffle a private copy with an RNG seeded from `seed`, so the
        # contents of the seed_{seed} folder are reproducible and the caller's
        # list keeps its order.
        shuffled_sents = list(sents)
        random.Random(seed).shuffle(shuffled_sents)

        for offset in tqdm(range(0, len(shuffled_sents), butch_size)):
            butch_sents = shuffled_sents[offset:offset + butch_size]
            # A batch with fewer sentences than clusters cannot be clustered;
            # only the trailing batch can be short, so stop here.
            if len(butch_sents) < cluster_num:
                break

            # Encode the whole batch in one call instead of one per sentence.
            sent_vectors = model.encode(butch_sents)
            labels = get_clusters(sent_vectors, n_clusters=cluster_num)
            sents_by_clusters = get_sents_by_clusters(butch_sents, labels)

            for cluster_label, cluster_sents in sents_by_clusters.items():
                cluster_file = f'{export_path}/cluster_{offset}_{offset + butch_size}_{cluster_label}.txt'
                # Explicit UTF-8 keeps non-ASCII sentences writable regardless
                # of the platform's locale encoding.
                with open(cluster_file, 'w', encoding='utf-8') as f:
                    f.write('\n'.join(cluster_sents))


In [27]:
# Model to fine-tune and the location the trained model is saved to.
# NOTE(review): both paths are identical, so saving overwrites the model that
# was just loaded — confirm this is intended.
from_model_path = './sbert_from_mlm_bert_45_aligned'
to_model_path = './sbert_from_mlm_bert_45_aligned'

model_from = SentenceTransformer(from_model_path)

# Load the translated pairs and discard rows with missing values.
sent_pairs_df = pd.read_csv(
    '../data/processed/word_freqs/freq_1000000_oshhamaho_translated.csv'
).dropna()

train_examples = prepare_data_for_training(sent_pairs_df)
run_training(from_model_path, to_model_path, train_examples)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21567 [00:00<?, ?it/s]

'./sbert_from_mlm_bert_45_aligned'