In [None]:
import pandas as pd
import pickle as pkl
import json
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.auto import tqdm
tqdm.pandas()

In [None]:
# Zero-shot Mistral predictions -- a mapping from sample id to a record with a
# "rewrite_prompt" field (as used further down) -- and the labelled dataset.
# NOTE: pickle.load can execute arbitrary code; only load pickles this project produced.
with open("/home/pbernardo/github/llm-prompt-recovery/data/mistral_zero_shot_output.pkl", "rb") as f:
    zero_shot_out = pkl.load(f)
data = pd.read_csv("/home/pbernardo/github/llm-prompt-recovery/data/data.csv")

In [None]:
# Maps a cluster id (looked up as a string) to a split name such as "train" or "val".
with open("/home/pbernardo/github/llm-prompt-recovery/data/cluster_to_split.json") as f:
    cluster_to_split = json.load(f)

In [None]:
embedding_model = SentenceTransformer('sentence-transformers/sentence-t5-base')


def calc_score(rewrite_prompt, rewrite_prompt_pred):
    """Score a predicted rewrite prompt against the reference prompt.

    Both prompts are embedded with ``embedding_model`` (L2-normalised), and the
    cosine similarity of the pair is raised to the third power.

    Args:
        rewrite_prompt: reference (ground-truth) prompt text.
        rewrite_prompt_pred: predicted prompt text.

    Returns:
        The cubed cosine similarity of the two embeddings.
    """
    pair_embeddings = embedding_model.encode(
        [rewrite_prompt, rewrite_prompt_pred],
        normalize_embeddings=True,
    )
    # Off-diagonal entry of the 2x2 similarity matrix: reference vs prediction.
    similarity = cosine_similarity(pair_embeddings)[0, 1]
    return similarity ** 3

In [None]:
def get_rewrite_prompt(original_text: str, transformed_text: str):
    """Build the ``<s> [INST] ... [/INST]`` instruction prompt for prompt recovery.

    The original and rewritten texts are serialised as a 4-space-indented JSON
    object and appended after the "Input:" line of the instruction.

    Args:
        original_text: the text before rewriting.
        transformed_text: the text after the LLM rewrote it.

    Returns:
        str: the complete prompt string.
    """
    # NOTE(review): the instruction text contains typos ("Follow this steps",
    # "your self", "with following structure"), and its example structure is not
    # valid JSON (single quotes, no comma between keys). It is kept verbatim,
    # presumably so the DPO prompt matches the one used for the zero-shot run.
    # TODO: confirm before changing it.
    # The indentation inside the f-string is part of the prompt text; do not re-indent.
    payload = {"original_text": original_text, "rewritten_text": transformed_text}
    json_input = json.dumps(payload, indent=4)
    return f"""<s> [INST]
        I will give you a JSON with following structure:
        {{
            'original_text': 'An original piece of text.'
            'rewritten_text': 'A version of original_text that was rewritten by an LLM according to a specific prompt.'
        }}

        Given the task of understanding how text is rewritten by analyzing the original_text and rewritten_text, your goal is to deduce the specific instructions or prompt that was most likely used to generate the rewritten text from the original text. Consider the changes made in terms of style, tone, structure, and content. Assess whether the rewrite focuses on summarization, paraphrasing, stylistic alteration (e.g., formal to informal), or any specific content changes (e.g., making the text more concise, expanding on ideas, or altering the perspective). Follow this steps:

        1. Read the original_text: Start by thoroughly understanding the content, style, tone, and purpose of the original text. Note any key themes, technical terms, and the overall message.
        2. Analyze the rewritten_text: Examine how the rewritten text compares to the original. Identify what has been changed, added, or omitted. Pay close attention to changes in style (formal, informal), tone (serious, humorous), structure (paragraph order, sentence structure), and any shifts in perspective or emphasis.
        3. Infer the Prompt: Based on your analysis, infer the most likely prompt that guided the rewriting process. Your inference should account for the observed changes in style, tone, structure, and content. Specify the type of task (e.g., summarize, paraphrase, make more accessible to a general audience), any specific directions evident from the changes, and any specific stylistic choice (e.g., 'as a poem', 'as a song', 'in the style of Shakespeare', etc...)

        Based on your analysis return the prompt as if you were given the instruction your self like:
        "Rewrite this text..."
        "Transform this ... into ... based on the style of ..."
        
        Make the prompt short and direct using a maximum of 20 words.


        Return your answer using the following JSON structure:
        {{"prompt": "Your best guess for the prompt used"}}
        

            
        Return a valid JSON as output and nothing more.
        
        -----------------------
        Input: 
        
        {json_input} [/INST]
    """

def format_response(response):
    """Format a rewrite prompt as the target completion for DPO.

    Serialises ``response`` into the ``{"prompt": "..."}`` JSON object that the
    instruction asks the model to return, followed by the literal ``</s>`` marker.
    ``json.dumps`` quotes and escapes the value, so the completion is valid JSON
    even when the prompt contains quotes or backslashes. The previous version
    interpolated the raw text without quotes, which produced invalid JSON.
    Non-ASCII characters are kept as-is rather than ``\\u``-escaped, matching
    natural model output.

    Args:
        response: the rewrite prompt text.

    Returns:
        str: e.g. ``'{"prompt": "Rewrite this as a poem."} </s>'``.
    """
    return json.dumps({"prompt": response}, ensure_ascii=False) + " </s>"

In [None]:
# Keep only samples that have a zero-shot prediction (membership is checked against the dict's keys).
data = data[data["id"].isin(list(zero_shot_out))]

In [None]:
# Attach each sample's zero-shot prediction, then score it against the ground-truth prompt.
data["rewrite_prompt_pred"] = data["id"].progress_apply(
    lambda sample_id: zero_shot_out[sample_id]["rewrite_prompt"]
)
data["score"] = data.progress_apply(
    lambda row: calc_score(row["rewrite_prompt"], row["rewrite_prompt_pred"]),
    axis=1,
)

In [None]:
# Keep only samples where the zero-shot prediction scored poorly (score < SCORE_THRESHOLD).
# These become DPO pairs: ground-truth prompt = chosen, zero-shot prediction = rejected.
# (The old comment said 0.8, but the code has always used 0.72.)
SCORE_THRESHOLD = 0.72
# .copy() makes this an independent frame, so adding columns later does not raise SettingWithCopyWarning.
dpo_data = data.loc[data.score < SCORE_THRESHOLD].copy()

In [None]:
# Give each sample the split of its cluster; clusters missing from the mapping default to "train".
# `.assign` returns a new frame instead of writing into a filtered slice (no SettingWithCopyWarning).
dpo_data = dpo_data.assign(
    split=dpo_data.cluster.apply(lambda cluster: cluster_to_split.get(str(cluster), "train"))
)

In [None]:
# Containers for the DPO preference records, one list per split.
dpo_train, dpo_eval = [], []

In [None]:
def _build_dpo_records(frame):
    """Convert each row of ``frame`` into a DPO preference record.

    The prompt is built from the row's original and rewritten texts. The
    ground-truth ``rewrite_prompt`` is the chosen response and the zero-shot
    ``rewrite_prompt_pred`` is the rejected one.

    Args:
        frame: DataFrame with ``original_text``, ``rewritten_text``,
            ``rewrite_prompt`` and ``rewrite_prompt_pred`` columns.

    Returns:
        list[dict]: records with "prompt", "chosen" and "rejected" keys, in row order.
    """
    return [
        {
            "prompt": get_rewrite_prompt(row.original_text, row.rewritten_text),
            "chosen": format_response(row.rewrite_prompt),
            "rejected": format_response(row.rewrite_prompt_pred),
        }
        for row in frame.itertuples(index=False)
    ]


# Rebuild the lists rather than appending, so re-running this cell never duplicates records.
# Rows whose split is neither "train" nor "val" are not exported.
dpo_train = _build_dpo_records(dpo_data.loc[dpo_data.split == "train"])
dpo_eval = _build_dpo_records(dpo_data.loc[dpo_data.split == "val"])

In [None]:
# Persist both DPO splits. The `with` blocks guarantee each file is flushed and closed.
with open("/home/pbernardo/github/llm-prompt-recovery/data/dpo_train.json", "w") as f:
    json.dump(dpo_train, f, indent=4)
with open("/home/pbernardo/github/llm-prompt-recovery/data/dpo_eval.json", "w") as f:
    json.dump(dpo_eval, f, indent=4)