In [None]:
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import nltk
from tqdm.auto import tqdm
# Register tqdm's `progress_apply` on pandas objects; the final cell relies on it
# to show a progress bar while the per-row optimisation runs.
tqdm.pandas()

In [None]:
# Location of the input CSV, relative to this notebook. It is loaded into `data`
# further down, where its `rewrite_prompt` column is read.
DATASET_PATH = "../../data/data.csv"
# Fixed baseline prompt. Its embedding provides the starting score in
# `optimise_single`, and its text is the prefix of every candidate prompt.
MEAN_PROMPT = "Rewrite this text convey manner human evokes text better exude genre plath tone cut include object being about please further wise this individuals could originally convey here."
# Template built from MEAN_PROMPT: the last character (the closing ".") is dropped
# and a "{}" slot followed by a new "." is appended, so extra words can be
# inserted at the end of the prompt with `str.format`.
MEAN_PROMPT_BASE = MEAN_PROMPT[:-1] + " {}."
# English stop words from NLTK (the `stopwords` corpus must be available locally).
# NOTE(review): not referenced anywhere in the visible cells - confirm whether it
# is still needed or was meant to filter candidate words.
STOP_WORDS = set(nltk.corpus.stopwords.words('english'))

In [None]:
# Sentence-T5 (base) encoder; every embedding in this notebook is produced by it.
st_model = SentenceTransformer('sentence-transformers/sentence-t5-base')

In [None]:
# Embed the baseline prompt once up front; `optimise_single` compares it with each
# rewrite prompt to obtain the score that candidate words have to beat.
MEAN_PROMPT_EMBD = st_model.encode(MEAN_PROMPT)

In [None]:
def calc_score(emb_1, emb_2):
    """Return the cubed cosine similarity between two embedding vectors.

    Each vector is wrapped in a list so that ``cosine_similarity`` receives two
    single-row inputs; the only entry of the resulting matrix is the
    similarity, which is then raised to the third power.
    """
    similarity_matrix = cosine_similarity([emb_1], [emb_2])
    similarity = similarity_matrix[0][0]
    return similarity ** 3

def _normalise_words(text):
    """Split ``text`` on whitespace, lower-case each token and strip non-alphanumerics.

    Tokens that end up empty (for example a lone "-" or "...") are dropped.
    """
    cleaned = (
        "".join(char for char in word.lower() if char.isalnum())
        for word in text.split()
    )
    return [word for word in cleaned if word]


def optimise_single(rewrite_prompt, target_score=0.9):
    """Greedily pick words from ``rewrite_prompt`` to append to the mean prompt.

    The starting score is ``calc_score`` between the embedding of
    ``MEAN_PROMPT`` and the embedding of ``rewrite_prompt``. In each round every
    remaining candidate word is tentatively appended (via ``MEAN_PROMPT_BASE``)
    and the word that raises the score the most is kept. The search stops when
    no word improves the score, when no candidates are left, or when the best
    score exceeds ``target_score``.

    Args:
        rewrite_prompt: The prompt text to approximate. Non-string values
            (e.g. NaN/None from a missing CSV cell) and blank strings yield
            an empty result instead of raising.
        target_score: Score above which the search stops early. Checked after
            each round, so at least one round always runs when candidates
            exist.

    Returns:
        The selected words joined by single spaces, in the order they were
        chosen; ``""`` when no word improves on the starting score.
    """
    # Missing or blank input has no words to choose from; skip the model calls.
    if not isinstance(rewrite_prompt, str) or not rewrite_prompt.strip():
        return ""

    base_prompt_embedding = st_model.encode(rewrite_prompt)
    best_score = calc_score(MEAN_PROMPT_EMBD, base_prompt_embedding)

    # Normalise the mean prompt exactly like the candidates; otherwise tokens
    # such as "Rewrite" or "here." never match their cleaned forms and words
    # already present in the mean prompt are offered again.
    mean_prompt_words = set(_normalise_words(MEAN_PROMPT))
    # dict.fromkeys de-duplicates while keeping first-occurrence order, so the
    # result does not depend on set iteration order (ties go to the earliest word).
    available_words = [
        word
        for word in dict.fromkeys(_normalise_words(rewrite_prompt))
        if word not in mean_prompt_words
    ]

    best_words = []
    while available_words:
        # Encode all candidates of this round in one call instead of one call
        # per word; the scores are still computed through calc_score.
        candidate_prompts = [
            MEAN_PROMPT_BASE.format(" ".join(best_words + [word]))
            for word in available_words
        ]
        candidate_embeddings = st_model.encode(candidate_prompts)

        best_word = None
        for word, embedding in zip(available_words, candidate_embeddings):
            score = calc_score(embedding, base_prompt_embedding)
            if score > best_score:
                best_score = score
                best_word = word

        if best_word is None:
            # No remaining word improves the score.
            break
        best_words.append(best_word)
        available_words.remove(best_word)
        if best_score > target_score:
            break
    return " ".join(best_words)

In [None]:
# Load the dataset; the next cell reads its `rewrite_prompt` column.
data = pd.read_csv(DATASET_PATH)

In [None]:
# Run the greedy word search on every rewrite prompt (with a tqdm progress bar)
# and store the selected words in a new `subject` column. Each row triggers
# several model encodes, so this is likely the slowest cell of the notebook.
data["subject"] = data.rewrite_prompt.progress_apply(optimise_single)