In [1]:
import json
import ast

import pandas as pd
from dotenv import load_dotenv
from openai import OpenAI
from tqdm.auto import tqdm

# Load variables from a local .env file into the process environment.
load_dotenv()
# No credentials are passed explicitly; presumably the client reads its API key
# from the environment populated above -- confirm against the OpenAI SDK docs.
client = OpenAI()

In [2]:
# Location of the policy definition (JSON) that drives this notebook, relative
# to the notebook's working directory.
POLICY_PATH = "../policies/hiring.json"

In [3]:
# Parse the policy definition into a dict. JSON text is UTF-8, so decode it
# explicitly instead of relying on the platform's default encoding (which would
# garble non-ASCII characters on systems whose default is not UTF-8).
with open(POLICY_PATH, "r", encoding="utf-8") as f:
    policy = json.load(f)

In [4]:
# Collapse each list of example prompts into one numbered, newline-separated
# string ("1. ...\n2. ...") so it can be substituted into the prompt template.
#
# The isinstance guard keeps this cell idempotent: once a value has been turned
# into a string, running the cell again would otherwise enumerate that string
# character by character and corrupt the prompt.
for key in ("violation_prompts", "edge_case_prompts"):
    if not isinstance(policy[key], str):
        policy[key] = "\n".join(
            f"{number}. {prompt}" for number, prompt in enumerate(policy[key], start=1)
        )

In [5]:
# Seed for the row shuffle below. Fixing it makes the order in which prompts are
# offered to the model (and therefore which ones are consumed first) repeatable
# across kernel restarts.
NORMAL_PROMPTS_SHUFFLE_SEED = 42

# Read the candidate prompts, shuffle every row (frac=1), and renumber the index.
normal_prompts_df = (
    pd.read_csv("../data/processed/normal_prompts.csv")
    .sample(frac=1, random_state=NORMAL_PROMPTS_SHUFFLE_SEED)
    .reset_index(drop=True)
)

In [6]:
# Split the prompt pool into one plain Python list per source category.
category_col = normal_prompts_df["category"]
jigsaw_normal_prompts = normal_prompts_df.loc[category_col == "jigsaw", "prompt"].tolist()
anthropic_normal_prompts = normal_prompts_df.loc[category_col == "anthropic", "prompt"].tolist()

In [7]:
# System message: tells the model to act as a policy-violation classifier and to
# answer with nothing but a bracketed list of the 1-based numbers of the
# violating messages (or [] when none violate).
system_prompt = """You are ViolationIdentifierGPT, an AI whose purpose is to identify messages that violate a specific policy. You are provided the policy title, description, examples of violating messages, and edge case messages. Additionally, you will be given the list of messages to evaluate. Your goal is to provide the exact numbers of the messages as seen in the provided list that clearly violate the specific policy. Even if you consider a message to be unethical or unsafe, it should only be included in the list if it violates the provided policy. Remember that edge case messages define the acceptable boundaries of the policy while violating messages are clear breaches. Respond using the following format without any additional text (if no messages violate the policy, respond with an empty list []):

[Index1, Index2, Index3, ...]
"""

# User message prefix: the template placeholders are filled from the `policy`
# dict (keys: title, description, violation_prompts, edge_case_prompts).
# It ends with the "Messages to Evaluate:" header; the numbered batch itself is
# appended when each request is built.
user_message = """Policy Title: {title}

Policy Description: {description}

Examples of Messages that Violate the Policy:
{violation_prompts}

Examples of Acceptable Edge Case Messages:
{edge_case_prompts}

Messages to Evaluate:""".format(**policy)

In [8]:
# Filter the jigsaw prompts through the model until N non-violating prompts have
# been collected (or the pool runs out).
#
# Produces:
#   filtered_jigsaw_prompts -- prompts the model did not flag, capped at N
#   rejected_jigsaw_prompts -- prompts the model flagged as violating the policy
N = policy["sdg_parameters"]["jigsaw_count"]
JIGSAW_BATCH_SIZE = 10  # prompts sent to the model per request
pbar = tqdm(total=N)

filtered_jigsaw_prompts = []
rejected_jigsaw_prompts = []

jigsaw_count = 0  # offset of the next unread prompt in jigsaw_normal_prompts

with pbar:
    while len(filtered_jigsaw_prompts) < N:
        jigsaw_batch = jigsaw_normal_prompts[jigsaw_count:jigsaw_count + JIGSAW_BATCH_SIZE]
        if not jigsaw_batch:
            # The pool ran dry before N prompts were accepted. Without this guard
            # the loop would keep sending empty batches to the API forever.
            tqdm.write(
                f"Ran out of jigsaw prompts after accepting {len(filtered_jigsaw_prompts)} of {N}."
            )
            break

        # Collapse whitespace so a multi-line prompt stays on its own numbered
        # line in the request. Only the text shown to the model is cleaned; the
        # accepted/rejected lists keep the original prompts.
        cleaned_jigsaw_batch = [" ".join(str(prompt).split()) for prompt in jigsaw_batch]
        numbered_batch = "\n".join(
            f"{number}. {prompt}" for number, prompt in enumerate(cleaned_jigsaw_batch, start=1)
        )

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
                {"role": "user", "content": [{"type": "text", "text": user_message + "\n" + numbered_batch}]},
            ],
            temperature=0.6,
            max_tokens=100,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        )
        jigsaw_count += JIGSAW_BATCH_SIZE

        # The reply is free text produced by the model, so parse it defensively:
        # take the outermost [...] span and convert the 1-based numbers to
        # 0-based positions. A batch whose reply cannot be parsed is skipped
        # entirely rather than being accepted unvetted or crashing the run.
        reply = response.choices[0].message.content or ""
        start, end = reply.find("["), reply.rfind("]")
        try:
            if start == -1 or end < start:
                raise ValueError("no bracketed list in reply")
            flagged_numbers = ast.literal_eval(reply[start:end + 1])
            indexes = {int(number) - 1 for number in flagged_numbers}
        except (ValueError, SyntaxError, TypeError):
            tqdm.write(
                f"Skipping jigsaw prompts {jigsaw_count - JIGSAW_BATCH_SIZE}-{jigsaw_count - 1}: "
                f"could not parse model reply {reply!r}"
            )
            continue

        # Positions outside the batch simply never match and are ignored.
        rejected_jigsaw_prompts += [prompt for i, prompt in enumerate(jigsaw_batch) if i in indexes]
        jigsaw_batch = [prompt for i, prompt in enumerate(jigsaw_batch) if i not in indexes]

        # Clamp the progress update so the bar never runs past its total.
        still_needed = N - len(filtered_jigsaw_prompts)
        filtered_jigsaw_prompts += jigsaw_batch
        pbar.update(min(len(jigsaw_batch), still_needed))

# The final batch can overshoot the target; keep exactly N (or fewer if the pool ran out).
filtered_jigsaw_prompts = filtered_jigsaw_prompts[:N]

  0%|          | 0/4000 [00:00<?, ?it/s]

In [9]:
# Filter the anthropic prompts through the model until N non-violating prompts
# have been collected (or the pool runs out).
#
# Produces:
#   filtered_anthropic_prompts -- prompts the model did not flag, capped at N
#   rejected_anthropic_prompts -- prompts the model flagged as violating the policy
N = policy["sdg_parameters"]["anthropic_count"]
ANTHROPIC_BATCH_SIZE = 10  # prompts sent to the model per request
pbar = tqdm(total=N)

filtered_anthropic_prompts = []
rejected_anthropic_prompts = []

anthropic_count = 0  # offset of the next unread prompt in anthropic_normal_prompts

with pbar:
    while len(filtered_anthropic_prompts) < N:
        anthropic_batch = anthropic_normal_prompts[anthropic_count:anthropic_count + ANTHROPIC_BATCH_SIZE]
        if not anthropic_batch:
            # The pool ran dry before N prompts were accepted. Without this guard
            # the loop would keep sending empty batches to the API forever.
            tqdm.write(
                f"Ran out of anthropic prompts after accepting {len(filtered_anthropic_prompts)} of {N}."
            )
            break

        # Collapse whitespace so a multi-line prompt stays on its own numbered
        # line in the request. Newlines become single spaces (deleting them
        # outright would fuse the words on either side). Only the text shown to
        # the model is cleaned; the accepted/rejected lists keep the originals.
        cleaned_anthropic_batch = [" ".join(str(prompt).split()) for prompt in anthropic_batch]
        numbered_batch = "\n".join(
            f"{number}. {prompt}" for number, prompt in enumerate(cleaned_anthropic_batch, start=1)
        )

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
                {"role": "user", "content": [{"type": "text", "text": user_message + "\n" + numbered_batch}]},
            ],
            temperature=0.6,
            max_tokens=100,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        )
        anthropic_count += ANTHROPIC_BATCH_SIZE

        # The reply is free text produced by the model, so parse it defensively:
        # take the outermost [...] span and convert the 1-based numbers to
        # 0-based positions. A batch whose reply cannot be parsed is skipped
        # entirely rather than being accepted unvetted or crashing the run.
        reply = response.choices[0].message.content or ""
        start, end = reply.find("["), reply.rfind("]")
        try:
            if start == -1 or end < start:
                raise ValueError("no bracketed list in reply")
            flagged_numbers = ast.literal_eval(reply[start:end + 1])
            indexes = {int(number) - 1 for number in flagged_numbers}
        except (ValueError, SyntaxError, TypeError):
            tqdm.write(
                f"Skipping anthropic prompts {anthropic_count - ANTHROPIC_BATCH_SIZE}-{anthropic_count - 1}: "
                f"could not parse model reply {reply!r}"
            )
            continue

        # Positions outside the batch simply never match and are ignored.
        rejected_anthropic_prompts += [prompt for i, prompt in enumerate(anthropic_batch) if i in indexes]
        anthropic_batch = [prompt for i, prompt in enumerate(anthropic_batch) if i not in indexes]

        # Clamp the progress update so the bar never runs past its total.
        still_needed = N - len(filtered_anthropic_prompts)
        filtered_anthropic_prompts += anthropic_batch
        pbar.update(min(len(anthropic_batch), still_needed))

# The final batch can overshoot the target; keep exactly N (or fewer if the pool ran out).
filtered_anthropic_prompts = filtered_anthropic_prompts[:N]

  0%|          | 0/4000 [00:00<?, ?it/s]

In [10]:
# Seed for the final shuffle, so the saved file has the same row order whenever
# the accepted prompt lists are the same.
FILTERED_SHUFFLE_SEED = 42

# Stack both accepted lists into one frame, labelling each row with its source.
filtered_normal_prompts = pd.DataFrame(
    {
        "prompt": filtered_jigsaw_prompts + filtered_anthropic_prompts,
        "category": ["jigsaw"] * len(filtered_jigsaw_prompts)
        + ["anthropic"] * len(filtered_anthropic_prompts),
    }
)
# Interleave the two sources (frac=1 shuffles every row) and renumber the index.
filtered_normal_prompts = filtered_normal_prompts.sample(
    frac=1, random_state=FILTERED_SHUFFLE_SEED
).reset_index(drop=True)
filtered_normal_prompts.to_csv("../data/processed/filtered_normal_prompts.csv", index=False)