In [1]:
from collections import defaultdict

from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from tqdm import tqdm

model_index = '7_3'
model = SentenceTransformer(f'./sbert_from_mlm_bert_{model_index}')


def split_sent_by_cluster(sentences, model):
    if len(sentences) < 8:
        return {0: sentences}
    
    embeddings = model.encode(sentences)
    
    clustering_model = KMeans(n_clusters=4, n_init=20, max_iter=1000)
    clustering_model.fit(embeddings)
    cluster_assignment = clustering_model.labels_

    clusters = defaultdict(list)
    for sentence, cluster_id in zip(sentences, cluster_assignment):
        clusters[cluster_id].append(sentence)

    return clusters

In [2]:
import os
import random

clusters_index = 9
clusters_path = f'../data/processed/sent_clusters_v{clusters_index}/20_30'

sent_pairs = []
for seed_dir in os.listdir(clusters_path):
    seed_path = f'{clusters_path}/{seed_dir}'
    if not os.path.isdir(seed_path):
        continue

    for cluster_dir in os.listdir(seed_path):
        clusters_path = f'{clusters_path}/{seed_dir}/{cluster_dir}'
        if not os.path.isdir(clusters_path):
            continue

        anti_pairs = []
        for cluster_file in tqdm(os.listdir(clusters_path)):
            if not cluster_file.endswith('.txt'):
                continue

            cluster_file_path = f'{clusters_path}/{cluster_file}'
            with open(cluster_file_path, 'r') as f:
                sents = f.read().split('\n')

            if len(sents) < 2:
                continue

            for group, sentences in split_sent_by_cluster(
                    sents, model
            ).items():
                for s1 in sentences:
                    for s2 in sentences:
                        if s1 == s2:
                            continue
                        sent_pairs.append(('pos', s1, s2)
                                          )

            if len(sent_pairs) < 1000:
                continue

            for sent in sents:
                for _ in range(len(sents)):
                    anti_pairs.append(('neg', sent, random.choice(sent_pairs)[2]))

        sent_pairs.extend(anti_pairs)

  0%|          | 0/3000 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 3000/3000 [01:08<00:00, 43.76it/s]


In [3]:
import pandas as pd

sent_pairs_df = pd.DataFrame(sent_pairs, columns=['type', 'sent1', 'sent2'])
sent_pairs_df.drop_duplicates(subset=['sent1', 'sent2'], inplace=True)

In [4]:
sent_pairs_df['final_score'] = sent_pairs_df['type'].apply(lambda x: 1 if x == 'pos' else 0)

In [5]:
sent_pairs_df

Unnamed: 0,type,sent1,sent2,final_score
0,pos,СлъэкIыу сыхъуа мыгъуэ?,Псапэ сщIэ хъунукъэ сэ?,1
1,pos,СлъэкIыу сыхъуа мыгъуэ?,"Арами, дэуэгъу сыхъуркъым.",1
2,pos,СлъэкIыу сыхъуа мыгъуэ?,куэдрэ сызэбгъэжьа!..,1
3,pos,Псапэ сщIэ хъунукъэ сэ?,СлъэкIыу сыхъуа мыгъуэ?,1
4,pos,Псапэ сщIэ хъунукъэ сэ?,"Арами, дэуэгъу сыхъуркъым.",1
...,...,...,...,...
473361,neg,Си нэгум къыщIэвмыгъахуэт!,Тхьэм гущIэгъу къыпхуищI!,0
473362,neg,Си нэгум къыщIэвмыгъахуэт!,Апхуэдэщ Вындыжь Марие.,0
473363,neg,Си нэгум къыщIэвмыгъахуэт!,Сыкъэзылъхуахэр адыгэщ.,0
473364,neg,Си нэгум къыщIэвмыгъахуэт!,Гуэдзыр дэнэ къыщагъуэтат?,0


In [6]:
slice_df = pd.DataFrame(
    sent_pairs_df.groupby('sent1').apply(lambda x: x.sample(n=min(len(x), 10), random_state=1)).reset_index(drop=True)
)
slice_df

Unnamed: 0,type,sent1,sent2,final_score
0,pos,!..» Зыри къэхъуакъым.,Сэ згъэпщкIуркъым зыри.,1
1,pos,!..» Зыри къэхъуакъым.,"– Ей, сэ зыри сыхуейкъым.",1
2,pos,!..» Зыри къэхъуакъым.,Мыдрейр хуейтэкъым зыри.,1
3,neg,!..» Зыри къэхъуакъым.,Ауэ зыми зыкъригъэщIакъым.,0
4,pos,!..» Зыри къэхъуакъым.,"– СлIо, зыри жыфIэркъыми?",1
...,...,...,...,...
256728,neg,…Щимыгъэтауэ уэсыр къос.,Къыстосэ ар си закъуэу.,0
256729,neg,…Щимыгъэтауэ уэсыр къос.,Апхуэдэуи узэрагъэкIуэнт!,0
256730,neg,…Щимыгъэтауэ уэсыр къос.,Си щIалэ цIыкIури щыIи!,0
256731,pos,…Щимыгъэтауэ уэсыр къос.,"Хыр мэпапщэ, зэхэпх къудейуэ.",1


In [7]:
slice_df.to_csv(f'../data/processed/{model_index}_{clusters_index}_sent_pairs_20_30.csv', index=False)