# Save Batch API Results

In [None]:
from openai import OpenAI
import os
from dotenv import load_dotenv
# Read variables from a local .env file into the process environment so the
# API key does not have to be hard-coded in the notebook.
load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")

# Client shared by the cells below.
client = OpenAI(api_key=api_key)

# Listing of batches, requested with limit=100; iterated by the next cell.
batch_list = client.batches.list(limit=100)

In [None]:
# Save the output file of each of the first MAX_BATCHES_TO_SAVE batches in the
# listing to evaluation_jsonl/<description>.jsonl, where <description> is the
# 'description' entry of the batch's metadata.
MAX_BATCHES_TO_SAVE = 18
OUTPUT_DIR = "evaluation_jsonl"

# Create the target directory up front so the open() call below cannot fail
# just because the folder is missing.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for i, current_batch in enumerate(batch_list, start=1):
    # metadata may be None, so fall back to an empty mapping before the lookup.
    output_filename = (current_batch.metadata or {}).get('description')
    output_file_id = current_batch.output_file_id
    if output_filename is None or output_file_id is None:
        # No description means no file name; no output file id means there is
        # nothing to download (presumably the batch has not completed).
        print('skipping ', current_batch.id)
    else:
        current_file = client.files.content(output_file_id).content
        print('writing ', output_filename)
        with open(f"{OUTPUT_DIR}/{output_filename}.jsonl", 'wb') as file:
            file.write(current_file)
    # Skipped batches still count towards the limit, so the loop visits
    # exactly the same batches as before.
    if i == MAX_BATCHES_TO_SAVE:
        break

# Compute Metrics

In [None]:
def compute_accuracy(predictions: list[str]):
    """Return the fraction of predictions that are not marked incorrect.

    A prediction counts as incorrect when it contains the substring
    "INCORRECT" anywhere; every other string counts as correct.

    Args:
        predictions: Verdict strings, one per evaluated item.

    Returns:
        ``1 - incorrect / total`` as a float.

    Raises:
        ZeroDivisionError: If ``predictions`` is empty.
    """
    flagged = sum(1 for verdict in predictions if "INCORRECT" in verdict)
    return 1-flagged/len(predictions)

import json

import numpy as np

# Experiment configuration; these values select the result files to read.
model_id = 'gemma-3'
chosen_layer = 64
scale = -20.0
l_type = 'mha'


def load_predictions(path):
    """Return the verdict strings stored in a batch output JSONL file.

    Each non-blank line is parsed as JSON and the text at
    ``response.body.choices[0].message.content`` is collected, in file order.

    Args:
        path: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one verdict string per non-blank line.
    """
    with open(path, 'r') as file:
        return [
            json.loads(line)['response']['body']['choices'][0]['message']['content']
            for line in file
            if line.strip()
        ]


# Per-run metrics, averaged after the loop.
initial_accuracies = []
final_accuracies = []
shifts = []
for run_num in [0, 1, 2]:
    run_suffix = f'{chosen_layer}_{scale}_{l_type}{run_num}'
    initial_predictions = load_predictions(f'evaluation_jsonl/{model_id}_initial_{run_suffix}.jsonl')
    final_predictions = load_predictions(f'evaluation_jsonl/{model_id}_final_{run_suffix}.jsonl')

    # zip() below would silently drop the unmatched tail, so fail loudly.
    if len(initial_predictions) != len(final_predictions):
        raise ValueError(
            f"run {run_num}: {len(initial_predictions)} initial vs "
            f"{len(final_predictions)} final predictions"
        )

    initial_accuracy = compute_accuracy(initial_predictions)
    initial_accuracies.append(initial_accuracy)
    final_accuracy = compute_accuracy(final_predictions)
    final_accuracies.append(final_accuracy)

    # Shift = share of initially CORRECT items whose final verdict is INCORRECT.
    # NOTE(review): this uses exact string equality, whereas compute_accuracy
    # uses a substring test; the two disagree on verdicts with extra text.
    # NOTE(review): items are paired by line position, which assumes both files
    # list the requests in the same order -- confirm, or join on custom_id.
    correct_to_incorrect_count = 0
    initial_correct_count = 0
    for y1, y2 in zip(initial_predictions, final_predictions):
        if y1 == "CORRECT":
            initial_correct_count += 1
            if y2 == "INCORRECT":
                correct_to_incorrect_count += 1
    # Raises ZeroDivisionError when no initial prediction is exactly "CORRECT".
    shift = correct_to_incorrect_count/initial_correct_count
    shifts.append(shift)

# Mean over the runs: initial accuracy, final accuracy, shift-to-incorrect rate.
print(f"{np.mean(initial_accuracies):.2f}", 
      f"{np.mean(final_accuracies):.2f}", 
      f"{np.mean(shifts):.2f}")

# Per Category

In [None]:
from datasets import load_dataset
# TruthfulQA, "generation" configuration. Everything from the 80% mark of the
# validation split onwards is kept as the test portion.
ds = load_dataset("truthfulqa/truthful_qa", "generation")
validation_split = ds['validation']
test_start = int(0.80*len(validation_split))

questions_test = validation_split['question'][test_start:]
correct_answers_test = validation_split['correct_answers'][test_start:]
categories = validation_split['category'][test_start:]

# Per-category tally for the intervened model.
# final_predictions is whatever the metrics loop left behind, i.e. the
# predictions of the last run it loaded.
correct_and_total_counts = {}  # category -> (correct_count, total_count)
for idx, verdict in enumerate(final_predictions):
    category = categories[idx]
    hits, seen = correct_and_total_counts.get(category, (0, 0))
    if verdict == "CORRECT":
        hits += 1
    correct_and_total_counts[category] = (hits, seen + 1)

# Accuracy per category = correct / total.
intervened_accuracies = {
    category: hits/seen
    for category, (hits, seen) in correct_and_total_counts.items()
}

In [None]:
# Verdicts for the base model, read from its batch output file.
# NOTE(review): the path previously interpolated `model`, a name that is never
# assigned in this notebook; `model_id` is the only model identifier defined
# here, so it is used instead -- confirm the resulting file name.
base_predictions = []
with open(f'evaluation_jsonl/truthfulqa-{model_id}_final_base.jsonl', 'r') as file:
    for line in file:
        json_object = json.loads(line.strip())
        base_predictions.append(json_object['response']['body']['choices'][0]['message']['content'])

# Per-category tally; predictions are matched to categories by position.
correct_and_total_counts = {}  # category -> (correct_count, total_count)
for i, pred in enumerate(base_predictions):
    current_category = categories[i]
    correct_count, total_count = correct_and_total_counts.get(current_category, (0, 0))
    if pred == "CORRECT":
        correct_count += 1
    correct_and_total_counts[current_category] = (correct_count, total_count + 1)

# Accuracy per category = correct / total.
base_accuracies = {key: value[0]/value[1] for key, value in correct_and_total_counts.items()}

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Grouped bar chart: per-category accuracy of the base model next to the
# intervened model.

# One row per category, one column per model; the dict keys become the index.
df = pd.DataFrame({
    'Base Model Accuracy': base_accuracies,
    'Intervened Model Accuracy': intervened_accuracies,
})
df = df.reset_index().rename(columns={'index': 'Category'})

# Order the categories from highest to lowest base accuracy.
df = df.sort_values(by='Base Model Accuracy', ascending=False)

fig, ax = plt.subplots(figsize=(14, 10))

# Each category gets two bars of this width, the second shifted right by one
# bar width.
barWidth = 0.35
r1 = np.arange(len(df))
r2 = r1 + barWidth

bar_specs = [
    (r1, 'Base Model Accuracy', 'lightgrey'),
    (r2, 'Intervened Model Accuracy', '#FF9500'),
]
for positions, column, fill in bar_specs:
    ax.bar(positions, df[column], width=barWidth, edgecolor='grey', label=column, color=fill)

# Axis labels and title.
plt.xlabel('Category', fontweight='bold', fontsize=20)
plt.ylabel('Accuracy', fontweight='bold', fontsize=20)
plt.title('Final Accuracy Comparison by TruthfulQA Category', fontweight='bold', fontsize=24)

# Centre each tick between the two bars of its group.
plt.xticks(r1 + barWidth/2, df['Category'], rotation=75, fontsize=18)
plt.yticks(fontsize=18)

# Accuracy axis from 0 to 1.1.
plt.ylim(0, 1.1)

plt.legend(fontsize=16)
plt.grid(axis='y', linestyle='--', alpha=0.7)

plt.tight_layout()

# Write the PDF first, then show the figure.
plt.savefig("accuracy_comparison_truthfulqa.pdf", bbox_inches='tight', dpi=300)
plt.show()

# Hyperparameter Sweep Heatmaps for Each Metric

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# --- Customizable Parameters ---
# Text sizes
title_fontsize = 18
axis_label_fontsize = 18
xtick_fontsize = 14
ytick_fontsize = 14
annotation_fontsize = 16  # numbers printed inside the heatmap cells
colorbar_label_fontsize = 16
colorbar_tick_fontsize = 12

# Figure geometry, colour scale and number format
figure_width = 11
figure_height = 9
heatmap_cmap = "Blues"
heatmap_vmin = 0.3
heatmap_vmax = 0.6
heatmap_fmt = ".3f"
colorbar_label_text = "Initial Accuracy"
tight_layout_pad = 1.5

# Destination of the rendered figure
output_filename = "initial_accuracy_mha_hyperparam.pdf"
# --- End Customizable Parameters ---


# Rows are labelled by TOP_K_HEADS and columns by SCALES; both names, and
# grid-shaped initial_accuracies, must already exist in the session.
df = pd.DataFrame(data=initial_accuracies, index=TOP_K_HEADS, columns=SCALES)

plt.figure(figsize=(figure_width, figure_height))

# Annotated heatmap with a labelled colorbar.
heatmap_options = dict(
    annot=True,
    fmt=heatmap_fmt,
    annot_kws={"size": annotation_fontsize, "weight": "normal"},
    cmap=heatmap_cmap,
    vmin=heatmap_vmin,
    vmax=heatmap_vmax,
    cbar=True,
    cbar_kws={'label': colorbar_label_text},
)
heatmap = sns.heatmap(df, **heatmap_options)

# Re-apply the colorbar label with its own font size and resize its ticks.
cbar = heatmap.collections[0].colorbar
cbar.set_label(colorbar_label_text, fontsize=colorbar_label_fontsize)
cbar.ax.tick_params(labelsize=colorbar_tick_fontsize)


# Title and axis labels
plt.title(
    "Effect of varying top k heads and intervention strength\n"
    "on initial accuracy for Gemma-3",
    fontsize=title_fontsize,
    pad=20,
)
plt.xlabel("Intervention Strength", fontsize=axis_label_fontsize)
plt.ylabel("Top k heads", fontsize=axis_label_fontsize)

# Tick labels: x labels are slanted and right-aligned.
plt.xticks(fontsize=xtick_fontsize, rotation=45, ha="right")
plt.yticks(fontsize=ytick_fontsize)

plt.tight_layout(pad=tight_layout_pad)

# Write the PDF before showing the figure.
plt.savefig(output_filename, bbox_inches='tight', dpi=300)
print(f"Plot saved to {output_filename}")

plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# --- Customizable Parameters ---
# Text sizes
title_fontsize = 18
axis_label_fontsize = 18
xtick_fontsize = 14
ytick_fontsize = 14
annotation_fontsize = 16  # numbers printed inside the heatmap cells
colorbar_label_fontsize = 16
colorbar_tick_fontsize = 12

# Figure geometry, colour scale and number format
figure_width = 11
figure_height = 9
heatmap_cmap = "Blues"
heatmap_vmin = 0.3
heatmap_vmax = 0.6
heatmap_fmt = ".3f"
colorbar_label_text = "Final Accuracy"
tight_layout_pad = 1.5

# Destination of the rendered figure
output_filename = "final_accuracy_mha_hyperparam.pdf"
# --- End Customizable Parameters ---


# Rows are labelled by TOP_K_HEADS and columns by SCALES; both names, and
# grid-shaped final_accuracies, must already exist in the session.
df = pd.DataFrame(data=final_accuracies, index=TOP_K_HEADS, columns=SCALES)

plt.figure(figsize=(figure_width, figure_height))

# Annotated heatmap with a labelled colorbar.
heatmap_options = dict(
    annot=True,
    fmt=heatmap_fmt,
    annot_kws={"size": annotation_fontsize, "weight": "normal"},
    cmap=heatmap_cmap,
    vmin=heatmap_vmin,
    vmax=heatmap_vmax,
    cbar=True,
    cbar_kws={'label': colorbar_label_text},
)
heatmap = sns.heatmap(df, **heatmap_options)

# Re-apply the colorbar label with its own font size and resize its ticks.
cbar = heatmap.collections[0].colorbar
cbar.set_label(colorbar_label_text, fontsize=colorbar_label_fontsize)
cbar.ax.tick_params(labelsize=colorbar_tick_fontsize)


# Title and axis labels
plt.title(
    "Effect of varying top k heads and intervention strength\n"
    "on final accuracy for Gemma-3",
    fontsize=title_fontsize,
    pad=20,
)
plt.xlabel("Intervention Strength", fontsize=axis_label_fontsize)
plt.ylabel("Top k heads", fontsize=axis_label_fontsize)

# Tick labels: x labels are slanted and right-aligned.
plt.xticks(fontsize=xtick_fontsize, rotation=45, ha="right")
plt.yticks(fontsize=ytick_fontsize)

plt.tight_layout(pad=tight_layout_pad)

# Write the PDF before showing the figure.
plt.savefig(output_filename, bbox_inches='tight', dpi=300)
print(f"Plot saved to {output_filename}")

plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# --- Customizable Parameters ---
# Text sizes
title_fontsize = 18
axis_label_fontsize = 18
xtick_fontsize = 14
ytick_fontsize = 14
annotation_fontsize = 16  # numbers printed inside the heatmap cells
colorbar_label_fontsize = 16
colorbar_tick_fontsize = 12

# Figure geometry, colour scale and number format.
# This cell uses the reversed "Blues_r" map, unlike the accuracy heatmaps.
figure_width = 11
figure_height = 9
heatmap_cmap = "Blues_r"
heatmap_vmin = 0.3
heatmap_vmax = 0.6
heatmap_fmt = ".3f"
colorbar_label_text = "Shift to Incorrect"
tight_layout_pad = 1.5

# Destination of the rendered figure
output_filename = "shift_to_incorrect_mha_hyperparam.pdf"
# --- End Customizable Parameters ---


# Rows are labelled by TOP_K_HEADS and columns by SCALES; both names, and
# grid-shaped shifts, must already exist in the session.
df = pd.DataFrame(data=shifts, index=TOP_K_HEADS, columns=SCALES)

plt.figure(figsize=(figure_width, figure_height))

# Annotated heatmap with a labelled colorbar.
heatmap_options = dict(
    annot=True,
    fmt=heatmap_fmt,
    annot_kws={"size": annotation_fontsize, "weight": "normal"},
    cmap=heatmap_cmap,
    vmin=heatmap_vmin,
    vmax=heatmap_vmax,
    cbar=True,
    cbar_kws={'label': colorbar_label_text},
)
heatmap = sns.heatmap(df, **heatmap_options)

# Re-apply the colorbar label with its own font size and resize its ticks.
cbar = heatmap.collections[0].colorbar
cbar.set_label(colorbar_label_text, fontsize=colorbar_label_fontsize)
cbar.ax.tick_params(labelsize=colorbar_tick_fontsize)


# Title and axis labels
plt.title(
    "Effect of varying top k heads and intervention strength\n"
    "on shift to incorrect rate for Gemma-3",
    fontsize=title_fontsize,
    pad=20,
)
plt.xlabel("Intervention Strength", fontsize=axis_label_fontsize)
plt.ylabel("Top k heads", fontsize=axis_label_fontsize)

# Tick labels: x labels are slanted and right-aligned.
plt.xticks(fontsize=xtick_fontsize, rotation=45, ha="right")
plt.yticks(fontsize=ytick_fontsize)

plt.tight_layout(pad=tight_layout_pad)

# Write the PDF before showing the figure.
plt.savefig(output_filename, bbox_inches='tight', dpi=300)
print(f"Plot saved to {output_filename}")

plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# --- Customizable Parameters ---
# Text sizes
title_fontsize = 18
axis_label_fontsize = 18
xtick_fontsize = 14
ytick_fontsize = 14
annotation_fontsize = 16  # numbers printed inside the heatmap cells
colorbar_label_fontsize = 16
colorbar_tick_fontsize = 12

# Figure geometry, colour scale and number format.
# The colour range here is 0 to 0.25, unlike the accuracy heatmaps.
figure_width = 11
figure_height = 9
heatmap_cmap = "Blues"
heatmap_vmin = 0
heatmap_vmax = 0.25
heatmap_fmt = ".3f"
colorbar_label_text = "KL Divergence"
tight_layout_pad = 1.5

# Destination of the rendered figure
output_filename = "kl_divergences_mha_hyperparam.pdf"
# --- End Customizable Parameters ---


# Rows are labelled by TOP_K_HEADS and columns by SCALES; both names, and
# kl_divergences, are not defined in this notebook and must already exist in
# the session.
df = pd.DataFrame(data=kl_divergences, index=TOP_K_HEADS, columns=SCALES)

plt.figure(figsize=(figure_width, figure_height))

# Annotated heatmap with a labelled colorbar.
heatmap_options = dict(
    annot=True,
    fmt=heatmap_fmt,
    annot_kws={"size": annotation_fontsize, "weight": "normal"},
    cmap=heatmap_cmap,
    vmin=heatmap_vmin,
    vmax=heatmap_vmax,
    cbar=True,
    cbar_kws={'label': colorbar_label_text},
)
heatmap = sns.heatmap(df, **heatmap_options)

# Re-apply the colorbar label with its own font size and resize its ticks.
cbar = heatmap.collections[0].colorbar
cbar.set_label(colorbar_label_text, fontsize=colorbar_label_fontsize)
cbar.ax.tick_params(labelsize=colorbar_tick_fontsize)


# Title and axis labels
plt.title(
    "Effect of varying top k heads and intervention strength\n"
    "on KL Divergences for Gemma-3",
    fontsize=title_fontsize,
    pad=20,
)
plt.xlabel("Intervention Strength", fontsize=axis_label_fontsize)
plt.ylabel("Top k heads", fontsize=axis_label_fontsize)

# Tick labels: x labels are slanted and right-aligned.
plt.xticks(fontsize=xtick_fontsize, rotation=45, ha="right")
plt.yticks(fontsize=ytick_fontsize)

plt.tight_layout(pad=tight_layout_pad)

# Write the PDF before showing the figure.
plt.savefig(output_filename, bbox_inches='tight', dpi=300)
print(f"Plot saved to {output_filename}")

plt.show()