In [None]:
#Imports
from scripts import dataset_scripts as ds_s, mauve_quantization as mq, subset_selection as ss
from transformers import AutoTokenizer, AutoModel
import torch

In [7]:
# Dicts are in format: {ds_path, ds_name, under_ds_name}
human_ds_dict = {"ds_path":"data/human/news-fi-2019.jsonl", "ds_name":"news-fi-2019.jsonl", "under_ds_name":None}
clums_ds_dict = {"ds_path":"data/clumsified/news-fi-2019.jsonl_regeneration_5_mini_regen_round_1.jsonl", "ds_name":"news-fi-2019.jsonl_regeneration_5_mini_regen_round_1.jsonl", "under_ds_name":"news-fi-2019.jsonl"}

In [8]:
# Merge the human and clumsified corpora into a single collection of rows.
all_ds_dicts = [human_ds_dict, clums_ds_dict]
ds = ds_s.format_datasets(all_ds_dicts)

In [9]:
# Total number of formatted rows across both corpora.
n_rows_total = len(ds)
print(n_rows_total)

48030


In [10]:
# Draw the reference corpus (P, embedded as embeddings_p below) from the human-written
# dataset. Refer to it via its descriptor rather than repeating the file name, so
# changing the human dataset path in the descriptor cell cannot silently break this call.
# NOTE(review): the recorded outputs show len(ref) + len(remaining) = 43030 < len(ds) = 48030,
# so the helper drops some rows — presumably the clumsified counterparts of the sampled
# texts; confirm in ds_s.sample_reference_corpus.
# TODO: if the sampling is random, seed it so Restart & Run All reproduces the same split.
REF_CORPUS_SIZE = 5000
ref, remaining = ds_s.sample_reference_corpus(ds, human_ds_dict["ds_name"], REF_CORPUS_SIZE)

In [11]:
# Size of the reference sample, then of the rows left over for comparison.
for corpus_part in (ref, remaining):
    print(len(corpus_part))

5000
38030


In [12]:
# Hugging Face Hub id of the embedding model used to featurize the texts.
model_name: str = "intfloat/multilingual-e5-large-instruct"

In [None]:
# Load the model named in model_name from the Hugging Face Hub.
model_id = model_name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, device_map="auto")


# ## Getting the embeddings with the wanted (L)LM

def _embed_rows(rows):
    """Tokenize each row's 'text' field and featurize the token ids with `model`.

    Texts are truncated to 512 tokens. The token-id list is local to this helper,
    so it goes out of scope as soon as featurization returns (no manual `del` needed).
    Returns whatever mq.featurize_tokens_from_model returns.
    """
    token_ids = [
        tokenizer.encode(row['text'], return_tensors="pt", truncation=True, max_length=512)
        for row in rows
    ]
    # NOTE(review): the positional 1 is presumably the batch size (as in MAUVE's
    # featurize_tokens_from_model) — confirm against mq.
    # Inference only: no_grad avoids keeping autograd state; harmless if mq already disables it.
    with torch.no_grad():
        return mq.featurize_tokens_from_model(model, token_ids, 1, name="", verbose=False)


# Candidate rows (Q) first, then the reference corpus (P).
embeddings_q = _embed_rows(remaining)
embeddings_p = _embed_rows(ref)

In [None]:
# Estimate the number of clusters as MAUVE does for num_buckets='auto':
# about one cluster per 10 samples of the smaller set, with a floor of 2.
# Count feature rows after concatenating instead of counting list items, so the
# estimate stays right even if featurization returns multi-row chunks
# (e.g. when run with a batch size > 1).
p_features = torch.cat(embeddings_p)
q_features = torch.cat(embeddings_q)
del embeddings_p, embeddings_q  # only the concatenated tensors are needed from here on
num_of_clusters = max(2, int(round(min(p_features.shape[0], q_features.shape[0]) / 10)))

print(f'The number of clusters is {num_of_clusters}')
results = mq.CDOE(p_features, q_features, num_of_clusters)
del p_features, q_features

In [None]:
# Copy each row's cluster assignment from the CDOE results onto the row itself,
# so later steps can work with the rows directly.
# NOTE(review): this only persists if ref/remaining hold mutable row dicts (e.g. a
# plain list of dicts). If they are a container whose indexing returns copies
# (e.g. a Hugging Face Dataset), the assignment is silently lost — confirm the type.
def _attach_cluster_ids(rows, row_to_cluster):
    """Set rows[idx]['cluster_id'] = row_to_cluster[idx] for every index idx in the mapping."""
    for idx in row_to_cluster:
        rows[idx]['cluster_id'] = row_to_cluster[idx]


p2cluster = results['p2cluster']
_attach_cluster_ids(ref, p2cluster)
q2cluster = results['q2cluster']
_attach_cluster_ids(remaining, q2cluster)


In [None]:
# Per-cluster bin counts reported by CDOE for the reference (p) and remaining (q) sets.
p_distr, q_distr = results['p_bin_counts'], results['q_bin_counts']

In [None]:
# Compute per-cluster selection targets with both variants of the helper.
# NOTE(review): target_total is presumably the total number of rows to select, and the
# trailing True presumably switches to the "simple" allocation (inferred from the
# variable names only) — confirm against ss.get_target_n_per_cluster and prefer a keyword argument.
target_total = 5000
test_simple = ss.get_target_n_per_cluster(p_distr, q_distr, target_total, True)
test_complex = ss.get_target_n_per_cluster(p_distr, q_distr, target_total)