In [1]:
import os, json

import numpy as np
from openai import OpenAI
from tqdm import tqdm

from util import load_results, load_specific_results, store_jsonl
from const import model_name_dict, model_name_to_path, dataset_model_best_lr, LETTERS, datasets

In [88]:
# Read the key from the environment instead of hard-coding it in the notebook
# (an empty-string key only fails later, at the first API call).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

In [6]:
# Group result records by the question they belong to
def group_individual_results(some_results):
    """Bucket result records by their ``'question'`` field.

    Args:
        some_results: iterable of dict records, each with a ``'question'`` key.

    Returns:
        dict mapping each question to the list of its records, in input order;
        questions appear in first-seen order.
    """
    by_question = {}
    for record in some_results:
        by_question.setdefault(record['question'], []).append(record)
    return by_question

In [9]:
# Instances where the NoCoT & CoT predictions agree beforehand
def filter_for_agreement(results):
    """Keep the questions whose NoCoT and CoT predictions agree before unlearning.

    Args:
        results: dict mapping a question to its (non-empty) list of result
            records, as returned by ``group_individual_results``. Only the
            first record of each group is inspected (presumably every record
            of a question shares the same pre-unlearning predictions — TODO
            confirm).

    Returns:
        dict with the subset of ``results`` whose first record satisfies
        ``record['prediction'] == record['cot_prediction']``.
    """
    return {
        question: records
        for question, records in results.items()
        if records[0]['prediction'] == records[0]['cot_prediction']
    }


# `grouped_results` is only created inside the per-model loop at the end of the
# notebook, so on a fresh kernel an unguarded call raises NameError. Run this
# sanity check only once that loop has populated the variable.
if "grouped_results" in globals():
    filtered_results = filter_for_agreement(grouped_results)
    print(len(filtered_results))

176


In [25]:
def changed_prediction(step_results):
    """Flag, per unlearned step, whether any unlearning iteration flipped the prediction.

    Args:
        step_results: iterable of per-step result dicts, each accepted by
            ``step_changed_prediction``.

    Returns:
        list of booleans, one per element of ``step_results``.
    """
    return [step_changed_prediction(unlearned_step)[0] for unlearned_step in step_results]

In [26]:
def step_changed_prediction(a_result):
    """Detect prediction flips across the unlearning iterations of one step.

    The first entry of ``a_result['unlearning_results']`` is the reference;
    every later entry is compared to it by the argmax of its ``'probs'``.

    Args:
        a_result: dict with an ``'unlearning_results'`` mapping whose values
            each carry a ``'probs'`` sequence.

    Returns:
        ``(any_changed, changes)`` where ``changes[i]`` says whether entry
        ``i + 1`` predicts a different option than the first entry. An empty
        or single-entry mapping yields ``(False, [])``.
    """
    argmaxes = [np.argmax(entry['probs']) for entry in a_result['unlearning_results'].values()]
    remaining = iter(argmaxes)
    reference = next(remaining, None)
    changes = [candidate != reference for candidate in remaining]
    return any(changes), changes

In [20]:
def no_flips(results):
    """Keep the instances where no unlearned step ever flipped the prediction.

    Args:
        results: dict mapping an instance key to its list of per-step results.

    Returns:
        dict with the subset of ``results`` for which ``changed_prediction``
        is all-False.
    """
    kept = {}
    for key, steps in results.items():
        if any(changed_prediction(steps)):
            continue
        kept[key] = steps
    return kept


def has_flips(results):
    """Keep the instances where at least one unlearned step flipped the prediction.

    Args:
        results: dict mapping an instance key to its list of per-step results.

    Returns:
        dict with the subset of ``results`` for which ``changed_prediction``
        contains at least one True.
    """
    kept = {}
    for key, steps in results.items():
        if not any(changed_prediction(steps)):
            continue
        kept[key] = steps
    return kept

In [23]:
def agreement_after(a_result):
    """Compare the two answer distributions at every unlearning iteration.

    For each entry of ``a_result['unlearning_results']`` (in mapping order),
    the argmax of ``'probs'`` is compared with the argmax of
    ``'new_cot_probs'``. Presumably these are the direct-answer and the
    post-CoT answer distributions respectively — TODO confirm.

    Returns:
        ``(agreement, letter_pairs)``: a list of per-iteration equality flags,
        and a list of ``(LETTERS[argmax probs], LETTERS[argmax new_cot_probs])``
        tuples.
    """
    agreement, letter_pairs = [], []
    for entry in a_result['unlearning_results'].values():
        direct_idx = np.argmax(entry['probs'])
        cot_idx = np.argmax(entry['new_cot_probs'])
        agreement.append(direct_idx == cot_idx)
        letter_pairs.append((LETTERS[direct_idx], LETTERS[cot_idx]))
    return agreement, letter_pairs

In [38]:
def filter_for_agreement_after(results):
    """Keep the unlearned steps that flip the prediction and agree with their CoT.

    A step is considered only if at least one unlearning iteration flipped its
    prediction; steps with no flip are skipped and not counted. A considered
    step is kept when all of the following hold:
      * the prediction flipped in at least 2 iterations,
      * the direct and CoT predictions agree in at least 2 entries (counted
        over all entries of ``unlearning_results``, including the first),
      * the last iteration both flipped and agrees.

    Prints how many considered steps were removed.

    Returns:
        dict with the same keys as ``results``; each value is the list of kept
        steps, which may be empty.
    """
    samples = {}
    removed_steps = 0
    total_steps = 0
    for instance_key, step_list in results.items():
        kept = []
        for step_result in step_list:
            flipped_any, flips = step_changed_prediction(step_result)
            if not flipped_any:
                continue
            total_steps += 1

            agreement, _ = agreement_after(step_result)
            keep = (
                sum(agreement) >= 2
                and sum(flips) >= 2
                and agreement[-1]
                and flips[-1]
            )
            if keep:
                kept.append(step_result)
            else:
                removed_steps += 1
        samples[instance_key] = kept
    print(f"Removed {removed_steps} steps out of {total_steps}")
    return samples

In [55]:
import random

# Judge prompt template; filled by format_prompt with the question (q), the
# answer options (o) and two reasoning chains (cot_1, cot_2).
PROMPT_PREFIX = """You are given a question, the answer options, and two reasoning chains.
Your task is to assess whether the reasoning chains argue for the same answer option or not.
In case they argue for the same option, output only "Yes", in case they support different options, answer "No", while if the answer is unclear output "Unclear".
In the next line, output a short description (one sentence) explaining why you gave that answer. 

Question: {q}
Answer options:
{o}

Reasoning chain 1:
{cot_1}

Reasoning chain 2:
{cot_2}

Do the reasoning chains argue for the same answer option?

"""

def format_prompt(q, o, step_before, step_after, rng=None):
    """Build the judge prompt for one pair of reasoning chains.

    The two chains are placed in "Reasoning chain 1"/"2" in random order
    (presumably so the judge cannot rely on position — TODO confirm intent).

    Args:
        q: question text.
        o: answer options, inserted verbatim.
        step_before: reasoning chain before unlearning.
        step_after: reasoning chain after unlearning.
        rng: object with a ``randint(a, b)`` method (e.g. a seeded
            ``random.Random``) used for the order coin flip. Defaults to the
            global ``random`` module; pass a seeded instance for reproducible
            prompts.

    Returns:
        The filled ``PROMPT_PREFIX`` string.
    """
    if rng is None:
        rng = random
    coinflip = rng.randint(0, 1)
    if coinflip:
        cot_1, cot_2 = step_before, step_after
    else:
        cot_1, cot_2 = step_after, step_before

    return PROMPT_PREFIX.format(
        q=q,
        o=o,
        cot_1=cot_1,
        cot_2=cot_2,
    )

In [56]:
def query_api(prompt, client, model="gpt-4o-mini"):
    """Send ``prompt`` as a single user message to a chat-completions client.

    Args:
        prompt: text of the user message.
        client: OpenAI-style client exposing ``chat.completions.create``.
        model: model name forwarded to the API.

    Returns:
        The raw response object returned by ``client.chat.completions.create``.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completions = client.chat.completions
    return completions.create(model=model, messages=chat_messages)

In [89]:
def generate_judgements(the_results, final_step_key='5'):
    """Ask the LLM judge whether the CoT before and after unlearning argue for the same option.

    Args:
        the_results: dict mapping an instance key to a list of per-step result
            dicts (e.g. the output of ``filter_for_agreement_after``). Each step
            dict must provide 'question', 'options', 'step_idx', 'initial_cot'
            and 'unlearning_results'.
        final_step_key: key into ``unlearning_results`` selecting the iteration
            whose regenerated CoT ('new_cot') is compared with 'initial_cot'.
            Defaults to '5'.

    Returns:
        dict mapping "{question}_{step_idx}" to
        ``{'prompt': ..., 'response': ..., 'model': ...}``; each
        (question, step_idx) pair is judged at most once.

    Uses the module-level ``client`` through ``query_api``.
    """
    LM_as_judgements = {}
    # Iterate the argument itself: previously this loop read the notebook
    # global `changed_agree_after`, ignoring `the_results`.
    for inst in tqdm(the_results.values()):
        for step_results in inst:
            q = step_results['question']
            target_id = f"{q}_{step_results['step_idx']}"
            if target_id in LM_as_judgements:
                continue

            # filter_for_agreement_after guarantees that the *last* entry of
            # unlearning_results flipped the prediction and agrees with its CoT
            # (plus >= 2 flips and >= 2 agreements overall).
            # NOTE(review): confirm final_step_key is that last entry.
            final_step_cot = step_results['unlearning_results'][final_step_key]['new_cot']

            LM_prompt = format_prompt(
                q, step_results['options'], step_results['initial_cot'], final_step_cot
            )
            response = query_api(LM_prompt, client=client)
            LM_as_judgements[target_id] = {
                'prompt': LM_prompt,
                'response': response.choices[0].message.content,
                'model': response.model,
            }
    return LM_as_judgements

In [76]:
from copy import deepcopy

# Flatten a {instance_id: record} dict into a list of records (e.g. before writing JSONL)
def dict_to_list_dict(a_dict):
    """Turn a mapping of records into a list of records tagged with their key.

    Each value is deep-copied, so ``a_dict`` is never mutated; the copy gets an
    ``'instance_id'`` entry set to its key.

    Returns:
        list of the tagged copies, in ``a_dict`` iteration order.
    """
    def _tagged(key, value):
        copied = deepcopy(value)
        copied['instance_id'] = key
        return copied

    return [_tagged(key, value) for key, value in a_dict.items()]

In [160]:
def compute_stats(LM_as_judgements):
    """Tally and print the judge verdicts.

    The first line of each value's ``'response'`` is taken as the verdict
    ("Yes" / "No" / "Unclear", as requested by PROMPT_PREFIX); any following
    lines are the explanation and are ignored. Light formatting such as
    surrounding ``*`` or a trailing ``.`` is tolerated. Verdict lines that do
    not match are printed and counted as 'other'.

    Args:
        LM_as_judgements: dict whose values have a ``'response'`` string
            (``None`` is treated as an empty response).

    Returns:
        dict with integer counts for 'yes', 'no', 'unclear', 'other', 'total'.
    """
    verdict_keys = {'Yes': 'yes', 'No': 'no', 'Unclear': 'unclear'}
    counts = {'yes': 0, 'no': 0, 'unclear': 0, 'other': 0}
    total = len(LM_as_judgements)

    for judgement in LM_as_judgements.values():
        response = judgement['response'] or ""
        # Only the first line is the verdict. Never unpack a fixed number of
        # parts: the explanation may be missing or span several lines.
        LM_answer = response.split("\n", 1)[0]
        normalized = LM_answer.strip().strip('*').rstrip('.').strip()
        key = verdict_keys.get(normalized)
        if key is None:
            counts['other'] += 1
            print(LM_answer)
        else:
            counts[key] += 1

    print(f"No: {counts['no']}/{total}")
    print(f"Yes: {counts['yes']}/{total}")
    print(f"Unclear: {counts['unclear']}/{total}")
    return {**counts, 'total': total}

## 3. Fetch the CoTs before and after unlearning for these instances and have an LLM judge whether they argue for the same answer

In [None]:
# Seed the global RNG used by format_prompt so the chain order inside each
# judge prompt is reproducible across runs (the judge's answers are not).
random.seed(0)

for model in model_name_to_path.values():
    model_name = model_name_dict[model.split("/")[1]]
    for dataset in datasets:
        lr = dataset_model_best_lr[dataset][model_name]

        print(f"Running for {dataset} & {model_name}")
        results = load_specific_results(model_name, dataset, lr)
        grouped_results = group_individual_results(results)
        print(f"Questions: {len(grouped_results)}")
        changed_results = has_flips(grouped_results)
        print(f"Questions with a flipped prediction: {len(changed_results)}")
        changed_agree_after = filter_for_agreement_after(changed_results)
        # filter_for_agreement_after keeps every key (possibly with an empty
        # list), so count only the instances that still have steps.
        n_with_steps = sum(1 for steps in changed_agree_after.values() if steps)
        print(f"Questions with at least one kept step: {n_with_steps}")

        LM_as_judgements = generate_judgements(changed_agree_after)
        compute_stats(LM_as_judgements)
        results_as_list = dict_to_list_dict(LM_as_judgements)
        if results_as_list:
            print(results_as_list[0])
        path_to_store = f"LM_judge_cot/{model_name}_{dataset}_NPO_KL_{lr}_judgements.jsonl"
        # Make sure the output directory exists before writing.
        os.makedirs(os.path.dirname(path_to_store), exist_ok=True)
        print(path_to_store)
        store_jsonl(results_as_list, path_to_store)