In [None]:
import json
import pandas as pd
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import random

In [None]:
# Maximum accepted energy for a cluster: the largest |lb_score - average score| over all
# reference prompts must be <= TH
TH = 0.005
DESIRED_DATA_LEN = 1500 # target number of prompts in the final selection (sum over all clusters)
PROMPTS_SELECTED_JSON = 'prompts_selected.json' # file to save the selected prompts
# Seed the `random` module (used for prompt sampling and cluster seeds) so that
# Restart & Run All reproduces the same selection
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

In [None]:
DATA = "../../data/data.csv"
PROMPT_VARIATONS = "../../data/prompt_variations.json"
MAX_PROMPTS_PER_CLUSTER = 20  # upper bound on prompts sampled from each cluster

data = pd.read_csv(DATA)
# rewrite prompt -> cluster id (if a prompt appears in several clusters, the last one wins)
prompt_to_cluster = data[["rewrite_prompt", "cluster"]].drop_duplicates().set_index("rewrite_prompt")["cluster"].to_dict()
# cluster id -> list of unique rewrite prompts in that cluster
prompts_by_cluster = {k: v.tolist() for k, v in data.groupby("cluster")["rewrite_prompt"].unique().to_dict().items()}
with open(PROMPT_VARIATONS) as _variations_file:
    prompt_variations = json.load(_variations_file)
# Add each prompt's variations to the cluster of the original prompt
for p, v in prompt_variations.items():
    cluster = prompt_to_cluster.get(p)
    # Compare against None explicitly: a plain truthiness test would drop cluster id 0
    if cluster is not None:
        prompts_by_cluster[cluster].extend(v)

# Selects a maximum of MAX_PROMPTS_PER_CLUSTER prompts from each cluster
prompts = []
for cluster_prompts in prompts_by_cluster.values():
    sample_size = min(MAX_PROMPTS_PER_CLUSTER, len(cluster_prompts))
    prompts.extend(random.sample(cluster_prompts, sample_size))
# De-duplicate while keeping a deterministic order (set order depends on string hash
# randomisation, which would change every downstream index between runs)
prompts = list(dict.fromkeys(prompts))
print(len(prompts))

In [None]:
# Sentence-embedding model shared by the candidate prompts and the reference prompts
ST_MODEL_NAME = 'sentence-transformers/sentence-t5-base'
st_model = SentenceTransformer(ST_MODEL_NAME)

In [None]:
def process_prompt(prompt):
    """Normalise a prompt: lower-case it and keep only alphanumeric and whitespace characters."""
    lowered = prompt.lower()
    kept_chars = filter(lambda ch: ch.isalnum() or ch.isspace(), lowered)
    return ''.join(kept_chars)

In [None]:
# Target profile for the prompt selection. Each entry is a reference prompt with an `lb_score`;
# `calc_energy` measures how far the selection's average score against that reference
# (cosine similarity cubed, see `calc_score_on_ref`) is from `lb_score`.
# NOTE(review): `lb_score` is presumably the public leaderboard score obtained by submitting the
# prompt as the prediction — TODO confirm.
# The 'embeddings' key of every entry is filled in by the loop at the end of this cell.
reference_distribution = [
    # NOTE(review): kept verbatim (including "style.Enhance" and "style of ,"), presumably because
    # the score was measured on this exact string.
    {'id': 1, 'rewrite_prompt': 'Please improve the following text using the writing style of, maintaining the original meaning but altering the tone, diction, and stylistic elements to match the new style.Enhance the clarity, elegance, and impact of the following text by adopting the writing style of , ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style.', 'lb_score': 0.61},
    {'id': 2, 'rewrite_prompt': 'Improve the text to this.', 'lb_score': 0.6},    
    {'id': 3, 'rewrite_prompt': 'Improve the text.', 'lb_score': 0.59},    
    {'id': 4, 'rewrite_prompt': 'Improve that text.', 'lb_score': 0.58},    
    {'id': 5, 'rewrite_prompt': 'Rewrite the text into a rhyming, sea-shanty style with a playful tone while maintaining the original information.', 'lb_score': 0.57},    
    {'id': 6, 'rewrite_prompt': 'original text elegance clarity improve style word tone meaning', 'lb_score': 0.55},    
    {'id': 7, 'rewrite_prompt': 'rephrase new version convey tone', 'lb_score': 0.54},    
    {'id': 8, 'rewrite_prompt': 'rephrase increase writing', 'lb_score': 0.53},    
    {'id': 9, 'rewrite_prompt': 'Transform the text into a humorous shanty and include a catchy chorus.', 'lb_score': 0.5},
    {'id': 10, 'rewrite_prompt': "Kindly refine the text below to mirror the writing style of , preserving its original intent yet modifying its tone, vocabulary, and stylistic details to resemble the new style. Boost the text's clarity, sophistication, and effectiveness by mimicking the writing style of , keeping the essential meaning unchanged but altering the tone, terminology, and style in accordance with the desired style.", 'lb_score': 0.58},
    {'id': 11, 'rewrite_prompt': "Please enhance the text provided, emulating the writing style of , while keeping the original intent but changing the tone, vocabulary, and style elements to align with the desired style.", 'lb_score': 0.59},
    {'id': 12, 'rewrite_prompt': "Kindly refine the subsequent text by adopting the writing style of , preserving its inherent meaning while modifying the tone, language, and stylistic nuances to reflect the new style.", 'lb_score': 0.56},
    {"id": 13, 'rewrite_prompt': "Make this text more negative.", 'lb_score': 0.45},
    {"id": 14, 'rewrite_prompt': 'z', "lb_score": 0.40},
    {"id": 15, 'rewrite_prompt': "The composition stands as a sequence of words arranged for potential contemplation, devoid of explicit intent or discernible purpose. It exists within a framework of neutrality, offering neither direction nor conclusion, inviting observation without expectation. The arrangement facilitates a space for presence, unattached to specific outcomes or interpretations.", "lb_score": 0.44},
    {"id": 16, 'rewrite_prompt': "Improve rephrase text manner this written to has character in style.", "lb_score": 0.64},
]

# Attach each reference prompt's sentence embedding under the 'embeddings' key
for ref_entry in tqdm(reference_distribution):
    ref_text = ref_entry['rewrite_prompt']
    ref_entry['embeddings'] = st_model.encode(ref_text)

In [None]:
def calc_energy(score_on_ref):
    """Per-reference energy: absolute gap between each reference's lb_score and its score.

    Args:
        score_on_ref: mapping reference id -> score (e.g. an average over a cluster).

    Returns:
        dict reference id -> |lb_score - score_on_ref[id]|, in reference_distribution order.
    """
    energy = {}
    for ref in reference_distribution:
        ref_id = ref['id']
        energy[ref_id] = abs(ref['lb_score'] - score_on_ref[ref_id])
    return energy

def calc_score_on_ref(embed):
    """Score one embedding against every reference prompt.

    The score is the cosine similarity between ``embed`` and the reference's embedding, cubed
    (so the sign of the similarity is preserved).

    Returns:
        dict reference id -> score, in reference_distribution order.
    """
    query = embed.reshape(1, -1)
    scores = {}
    for ref in reference_distribution:
        similarity = cosine_similarity(query, ref['embeddings'].reshape(1, -1))[0][0]
        scores[ref['id']] = similarity ** 3
    return scores

def calc_avg_score_on_ref_from_individual_scores(scores: list[dict]):
    """Average per-prompt score dicts into one score per reference id.

    Args:
        scores: list of dicts reference id -> score, one per prompt.

    Returns:
        dict reference id -> mean score (np.mean), in reference_distribution order.
    """
    return {
        ref['id']: np.mean([per_prompt[ref['id']] for per_prompt in scores])
        for ref in reference_distribution
    }

def calc_avg_score_on_ref_from_new_score(n_items: int, current_avg_scores: dict, new_score: dict):
    """Incrementally update a running mean with one more item.

    Args:
        n_items: number of items already averaged into ``current_avg_scores``.
        current_avg_scores: dict reference id -> current mean score.
        new_score: dict reference id -> score of the item being added.

    Returns:
        dict reference id -> mean over the ``n_items + 1`` items.
    """
    denom = n_items + 1
    ref_ids = (ref['id'] for ref in reference_distribution)
    return {
        ref_id: (current_avg_scores[ref_id] * n_items + new_score[ref_id]) / denom
        for ref_id in ref_ids
    }

def print_report(scores: list[dict]):
    """Print each reference's target lb_score next to the average score of ``scores``."""
    averaged = calc_avg_score_on_ref_from_individual_scores(scores)
    report_lines = (
        f"Reference {ref['id']}: {ref['lb_score']:.2f} -> {averaged[ref['id']]:.2f}"
        for ref in reference_distribution
    )
    for line in report_lines:
        print(line)

def calculate_single_energy(n_items, prompt_score, current_avg_scores):
    """Max energy the cluster would have after adding one more prompt.

    Args:
        n_items: number of prompts currently in the cluster.
        prompt_score: dict reference id -> score of the candidate prompt.
        current_avg_scores: dict reference id -> current cluster average score.

    Returns:
        The maximum, over reference ids, of the energy computed by ``calc_energy`` on the
        updated average.
    """
    updated_avg = calc_avg_score_on_ref_from_new_score(n_items, current_avg_scores, prompt_score)
    per_ref_energy = calc_energy(updated_avg)
    return np.max(list(per_ref_energy.values()))

def add_prompt(available_prompts_scores, current_cluster_prompt_scores):
    """
    Greedily select the next prompt to add to the cluster.

    Args:
        available_prompts_scores: list of candidate score dicts (reference id -> score).
        current_cluster_prompt_scores: non-empty list of score dicts of the prompts already in
            the cluster.

    Returns:
        Index into ``available_prompts_scores`` of the chosen prompt, or ``None`` when no
        candidate lowers the cluster's max energy and none keeps it below ``TH``.

    Selection rule:
    - If the cluster's max energy is already <= TH, a random candidate that keeps the energy
      below TH is returned. This helps avoid getting stuck in a local minimum.
    - Otherwise, the candidate that lowers the max energy the most is returned.
    """
    current_score = calc_avg_score_on_ref_from_individual_scores(current_cluster_prompt_scores)
    current_energy = calc_energy(current_score)
    best_energy_max = np.max(list(current_energy.values()))
    n_items = len(current_cluster_prompt_scores)
    # (candidate index, max energy after adding it), best first
    candidates = sorted(
        (
            (idx, calculate_single_energy(n_items, prompt_score, current_score))
            for idx, prompt_score in enumerate(available_prompts_scores)
        ),
        key=lambda x: x[1],
    )
    # Keep candidates that improve the energy or keep it below the threshold
    candidates = [c for c in candidates if c[1] < best_energy_max or c[1] < TH]
    if not candidates:
        return None
    if best_energy_max <= TH:
        # Already below TH: every remaining candidate keeps it below TH, so pick one at random.
        # (Previously this value was computed but discarded, so the greedy pick was always used.)
        return random.choice(candidates)[0]
    return candidates[0][0]


def make_small_cluster(prompts_scores, pbar, desired_len):
    """
    Greedily build one cluster of prompts whose average reference scores match the targets.

    Starting from a random seed prompt, prompts are added one at a time via ``add_prompt`` while
    it finds one that lowers the cluster's max energy (or keeps it below ``TH``). The cluster is
    returned as soon as it holds ``desired_len`` prompts, or when no further prompt can be added
    and its max energy is <= ``TH``. Otherwise the search restarts from a seed prompt that has
    not been tried yet.

    Args:
        prompts_scores: list of per-prompt score dicts (reference id -> score), e.g. the output
            of ``calc_score_on_ref``.
        pbar: progress bar; only its ``set_description`` is used, to report progress.
        desired_len: maximum number of prompts in the returned cluster.

    Returns:
        A list of indices into ``prompts_scores``. Returns ``None`` when ``prompts_scores`` is
        empty, or when every seed prompt was tried without producing an acceptable cluster.
    """
    n_prompts = len(prompts_scores)
    if n_prompts == 0:
        return None

    def _max_energy(idxs):
        # Worst (max over reference prompts) energy of the cluster made of `idxs`
        cluster_scores = [prompts_scores[i] for i in idxs]
        energy = calc_energy(calc_avg_score_on_ref_from_individual_scores(cluster_scores))
        return np.max(list(energy.values()))

    def _report(cluster, energy, n_retries):
        # Show the current state of the search on the progress bar
        pbar.set_description(f"Current cluster size: {str(len(cluster)).zfill(3)} | Energy: {energy:.4f} | Retries: {n_retries}")

    cluster = random.sample(range(n_prompts), 1)
    tried_seeds = {cluster[0]}
    # Computed up-front so the "nothing to add" branch never reads an unbound variable
    # (this used to raise when the first add_prompt call found no candidate)
    cluster_energy = _max_energy(cluster)
    while True:
        # Never grow past the requested size
        if len(cluster) >= desired_len:
            return cluster
        in_cluster = set(cluster)
        other_prompt_idxs = [i for i in range(n_prompts) if i not in in_cluster]
        other_prompt_idx = None
        if other_prompt_idxs:
            other_prompt_scores = [prompts_scores[i] for i in other_prompt_idxs]
            cluster_scores = [prompts_scores[i] for i in cluster]
            other_prompt_idx = add_prompt(other_prompt_scores, cluster_scores)
        if other_prompt_idx is not None:
            cluster.append(other_prompt_idxs[other_prompt_idx])
            cluster_energy = _max_energy(cluster)
            _report(cluster, cluster_energy, len(tried_seeds) - 1)
            continue
        # Nothing (more) can be added: keep the cluster if it is within the threshold
        if cluster_energy <= TH:
            return cluster
        # Otherwise restart from a seed prompt that has not been tried yet
        untried_seeds = [i for i in range(n_prompts) if i not in tried_seeds]
        if not untried_seeds:
            return None
        cluster = random.sample(untried_seeds, 1)
        tried_seeds.add(cluster[0])
        cluster_energy = _max_energy(cluster)
        _report(cluster, cluster_energy, len(tried_seeds) - 1)


In [None]:
# Embed every candidate prompt with the same model used for the reference prompts
embeddings = st_model.encode(
    prompts,
    show_progress_bar=True,
)

In [None]:
# One score dict (reference id -> cubed cosine similarity) per candidate prompt embedding
prompts_scores = list(map(calc_score_on_ref, embeddings))

In [None]:
# Build the final selection by repeatedly carving a cluster out of the prompts not selected yet,
# checkpointing the selected prompt texts to PROMPTS_SELECTED_JSON after every cluster.
data = []  # indices into `prompts` of the selected prompts (note: rebinds the earlier DataFrame name)
available_prompts = list(range(len(prompts)))
pbar = tqdm(total=DESIRED_DATA_LEN)
try:
    while len(data) < DESIRED_DATA_LEN:
        _already_selected = set(data)
        available_prompts = [i for i in available_prompts if i not in _already_selected]
        if len(available_prompts) == 0:
            break
        _prompt_scores = [prompts_scores[i] for i in available_prompts]
        missing_len = DESIRED_DATA_LEN - len(data)
        _new_cluster_idxs = make_small_cluster(_prompt_scores, pbar, missing_len)
        if _new_cluster_idxs is None:
            break
        # make_small_cluster returns positions within `available_prompts`; map back to `prompts`
        _selected_idxs = [available_prompts[i] for i in _new_cluster_idxs]
        data.extend(_selected_idxs)
        current_selection = [prompts[i] for i in data]
        pbar.update(len(_selected_idxs))
        # `with` guarantees the checkpoint is flushed and the file handle closed each iteration
        with open(PROMPTS_SELECTED_JSON, 'w') as _selection_file:
            json.dump(current_selection, _selection_file)
finally:
    pbar.close()