In [1]:
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, roc_auc_score
import os
sns.set_theme(context="paper")
from matplotlib import rcParams
# Serif typography for paper figures. This must run after sns.set_theme(),
# which installs its own font settings. Fonts in 'font.serif' are tried in
# order, so the fallbacks keep text serif when Times New Roman is not
# installed (DejaVu Serif ships with matplotlib).
rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Times', 'Liberation Serif', 'DejaVu Serif'],
    'font.weight': 'normal',
})
# Font sizes shared by the ROC and density plotting helpers.
title_fontsize = 19
label_fontsize = 19
legend_title_fontsize = 12
legend_fontsize = 10

In [3]:
# Three hues from seaborn's 4-colour "bright" palette (entries 0, 2 and 3),
# one per model, so every panel uses the same colour for the same model.
colors = sns.color_palette("bright", n_colors=4)
colors = [colors[0], colors[2], colors[3]]

# Display names as produced by load_all_data; order pairs them with `colors`.
model_names = [
    'Prompt-Guard-86M',
    'LLaMA-2-7B-Chat',
    'LLaMA-2-7B',
]
color_match = {name: color for name, color in zip(model_names, colors)}

# Bandwidth multiplier passed to sns.kdeplot (values < 1 give a less smoothed KDE).
bw_adjust = 0.2

In [4]:
def load_data(file_path, positive_label=9):
    """Load one evaluation JSON dump and return binary labels and scores.

    The file must contain a ``'poisoned_probs'`` list (per-example scores)
    and a ``'targets'`` list (raw class ids).

    Parameters
    ----------
    file_path : str
        Path to the JSON file.
    positive_label : int, default 9
        Raw target value treated as the positive class; every other value
        becomes 0.

    Returns
    -------
    labels : np.ndarray
        Integer 0/1 array with the same length as ``targets``.
    probs : np.ndarray
        Array built from ``poisoned_probs``.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    probs = np.array(data['poisoned_probs'])
    # Binarize: positive_label -> 1, any other target -> 0. dtype=int keeps
    # the result integer even when 'targets' is empty.
    labels = np.array([int(t == positive_label) for t in data['targets']], dtype=int)
    return labels, probs

def plot_roc(ax, data_dict, title, show_ylabel=None):
    """Draw one ROC curve per model on ``ax``, plus the chance diagonal.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data_dict : dict
        Maps model name -> ``(labels, probs)``. Models listed in the
        module-level ``color_match`` get their fixed colour; others use
        matplotlib's default colour cycle.
    title : str
        Panel title suffix; the title is ``"ROC Curves - {title}"``.
    show_ylabel : bool or None
        Whether to label the y axis. ``None`` (default) keeps the original
        heuristic: label it unless ``title`` mentions "val" (e.g. so a
        right-hand validation panel can omit it).
    """
    for model_name, (labels, probs) in data_dict.items():
        fpr, tpr, _ = roc_curve(labels, probs)
        roc_auc = roc_auc_score(labels, probs)
        ax.plot(fpr, tpr, lw=2, label=f'{model_name} (AUC = {roc_auc:.3f})',
                color=color_match.get(model_name))

    # Chance-level reference line.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlim([-0.005, 1.0])
    ax.set_ylim([0.0, 1.01])
    ax.set_xlabel('False Positive Rate', fontsize=label_fontsize)
    if show_ylabel is None:
        show_ylabel = 'val' not in title.lower()
    if show_ylabel:
        ax.set_ylabel('True Positive Rate', fontsize=label_fontsize)
    ax.legend(loc="lower right", frameon=True, fontsize=legend_fontsize, framealpha=1, edgecolor='black', facecolor='white',
              title_fontsize=legend_title_fontsize, title="Model")
    ax.set_title(f'ROC Curves - {title}', fontsize=title_fontsize, fontweight='bold')

def plot_density(ax, labels, probs, model_name, split):
    """Draw filled KDEs of the scores for positive vs. negative examples.

    ``labels`` and ``probs`` are indexed with the boolean masks
    ``labels == 1`` / ``labels == 0``, so they are expected to be parallel
    NumPy arrays. Both curves use the model's colour from ``color_match``;
    the positive curve is the more opaque one. The axes are cleared first and
    titled ``"{model_name} - {split}"``.
    """
    ax.clear()
    model_color = color_match[model_name]

    # Keyword arguments common to both KDE curves.
    shared_kde_kwargs = dict(ax=ax, fill=True, color=model_color, clip=(0, 1), bw_adjust=bw_adjust)
    sns.kdeplot(probs[labels == 1], alpha=0.5, label="Positive", **shared_kde_kwargs)
    sns.kdeplot(probs[labels == 0], linewidth=2, alpha=0.25, label="Negative", **shared_kde_kwargs)

    ax.set_xlabel("Score / Probability", fontsize=label_fontsize)
    ax.set_ylabel("Density", fontsize=label_fontsize)
    legend_style = dict(facecolor='white', frameon=True, framealpha=1, edgecolor='black')
    ax.legend(title="Scores", fontsize=legend_fontsize, title_fontsize=legend_title_fontsize,
              loc='upper center', **legend_style)
    ax.set_xlim([-0.005, 1.0])
    ax.set_title(f'{model_name} - {split}', fontsize=title_fontsize, fontweight='bold')
    sns.despine(ax=ax)

def load_all_data(directories):
    all_data = {'test': {}, 'val': {}}
    steps = range(499, 9000, 500)
    
    for dir in directories:
        model_name = dir.split('_')[-3]
        if 'chat' in model_name.lower():
            model_name = 'LLaMA-2-7B-Chat'
        elif 'prompt-guard' in model_name.lower():
            model_name = 'Prompt-Guard-86M'
        else:
            model_name = 'LLaMA-2-7B'
        
        for split in ['test', 'val']:
            all_data[split][model_name] = {}
            for step in steps:
                file_path = os.path.join(dir, f'{split}_step_{step}.json')
                if os.path.exists(file_path):
                    labels, probs = load_data(file_path)
                    all_data[split][model_name][step] = {'labels': labels, 'probs': probs}
    
    return all_data, steps

In [5]:
def animate(frame, out_dir='animation'):
    """Render one checkpoint step as a multi-panel figure and save it as a PDF.

    Reads the module-level ``all_data`` and ``steps`` from ``load_all_data``.
    Row 0 holds the ROC curves (test on the left, validation on the right).
    Each following row holds one model's score densities for the two splits.

    Parameters
    ----------
    frame : int
        Index into ``steps``; also used in the output file name.
    out_dir : str, default 'animation'
        Directory the PDF is written to; it is created if missing.
    """
    step = steps[frame]
    models = list(all_data['test'])

    # One ROC row plus one density row per model (4 rows / 15in tall for 3 models).
    n_rows = 1 + len(models)
    # Work only with the figure created here (no pyplot "current figure"
    # clearing), and always close it, so no figures leak across frames.
    fig, axs = plt.subplots(n_rows, 2, figsize=(12, 3.75 * n_rows), sharex=True, sharey='row', squeeze=False)
    try:
        test_data = {model: (d[step]['labels'], d[step]['probs']) for model, d in all_data['test'].items()}
        val_data = {model: (d[step]['labels'], d[step]['probs']) for model, d in all_data['val'].items()}

        # Right-align the step number in a 4-character field so that
        # 3-digit steps (e.g. "Step  500") line up with 4-digit ones.
        padded_step = f'{step + 1:>4}'
        plot_roc(axs[0, 0], test_data, f'Test Set (Step {padded_step})')
        plot_roc(axs[0, 1], val_data, f'Validation Set (Step {padded_step})')

        for row, model in enumerate(models, start=1):
            test_entry = all_data['test'][model][step]
            val_entry = all_data['val'][model][step]
            plot_density(axs[row, 0], test_entry['labels'], test_entry['probs'], model, f'Test Set (Step {step+1})')
            plot_density(axs[row, 1], val_entry['labels'], val_entry['probs'], model, f'Validation Set (Step {step+1})')

        fig.tight_layout()
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f'roc_density_frame_{frame:03d}.pdf'))
    finally:
        plt.close(fig)

# Main execution: one run directory per base model. The globals `all_data`
# and `steps` defined here are read by animate().
directories = [
    './model_9000_steps_500_eval_shuffle_False_base_llama-2-7b_bs_6',
    './model_9000_steps_500_eval_shuffle_False_base_llama-2-7b-chat_bs_6',
    './model_9000_steps_500_eval_shuffle_False_base_Prompt-Guard-86M_bs_6'
]
all_data, steps = load_all_data(directories)

# Render one PDF frame per checkpoint step.
for frame, _step in enumerate(steps):
    animate(frame)

print("Individual PDF frames have been generated.")

Individual PDF frames have been generated.


<Figure size 640x480 with 0 Axes>