In [1]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('./sbert_from_mlm_bert_45_aligned')

In [2]:
import nltk

with open('../data/processed/all_book.txt', 'r') as f:
    text = f.read()

len_min = 40
len_max = 50

short_sents = sorted(set([
    sent.replace('\n', ' ')
    for sent in nltk.sent_tokenize(text)
    if len_min < len(sent) < len_max
]))

In [3]:
from sklearn.cluster import AgglomerativeClustering
from collections import defaultdict


def get_clusters(vectors, n_clusters=20):
    agg_clustering = AgglomerativeClustering(n_clusters=n_clusters)
    labels = agg_clustering.fit_predict(vectors)
    return labels


def get_sents_by_clusters(words, labels):
    sents_by_clusters = defaultdict(list)
    for i, label in enumerate(labels):
        sents_by_clusters[label].append(words[i])

    return sents_by_clusters

In [None]:
import random
from tqdm import tqdm
import os

butch_size = 10000
cluster_num = 1000
cluster_factor = butch_size / cluster_num

for seed in range(111, 120):
    export_path = f'../data/processed/sent_clusters_book/{len_min}_{len_max}/seed_{seed}/{cluster_factor}_{butch_size}_{cluster_num}'
    os.makedirs(export_path, exist_ok=True)

    random.shuffle(short_sents)

    for offset in tqdm(range(0, len(short_sents), butch_size)):
        butch_sents = short_sents[offset:offset + butch_size]
        if len(butch_sents) < cluster_num:
            break

        word_vectors = [
            model.encode(sent)
            for sent in tqdm(butch_sents)
        ]
        labels = get_clusters(word_vectors, n_clusters=cluster_num)
        sents_by_clusters = get_sents_by_clusters(butch_sents, labels)

        for cluster_label, cluster_sents in sents_by_clusters.items():
            with open(f'{export_path}/cluster_{offset}_{offset + butch_size}_{cluster_label}.txt', 'w') as f:
                f.write('\n'.join(cluster_sents))

  0%|          | 0/8 [00:00<?, ?it/s]
  0%|          | 0/10000 [00:00<?, ?it/s][A
  0%|          | 7/10000 [00:00<02:30, 66.40it/s][A
  0%|          | 18/10000 [00:00<01:55, 86.18it/s][A
  0%|          | 30/10000 [00:00<01:40, 99.10it/s][A
  0%|          | 42/10000 [00:00<01:33, 106.35it/s][A
  1%|          | 55/10000 [00:00<01:28, 112.09it/s][A
  1%|          | 67/10000 [00:00<01:34, 105.21it/s][A
  1%|          | 80/10000 [00:00<01:28, 111.59it/s][A
  1%|          | 94/10000 [00:00<01:24, 117.68it/s][A
  1%|          | 108/10000 [00:00<01:20, 122.52it/s][A
  1%|          | 121/10000 [00:01<01:28, 111.64it/s][A
  1%|▏         | 133/10000 [00:01<01:32, 106.69it/s][A
  1%|▏         | 144/10000 [00:01<01:40, 98.02it/s] [A
  2%|▏         | 155/10000 [00:01<01:38, 100.15it/s][A
  2%|▏         | 166/10000 [00:01<01:52, 87.32it/s] [A
  2%|▏         | 178/10000 [00:01<01:48, 90.35it/s][A
  2%|▏         | 192/10000 [00:01<01:36, 101.34it/s][A
  2%|▏         | 206/10000 [00:01<