In [None]:
import os
# Set the CUDA-related environment variables for this kernel in one call:
# device ordering is "PCI_BUS_ID" and the visible device list is "0".
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Forget01

In [None]:
import os
import json
import matplotlib.pyplot as plt


# Alpha values to evaluate, kept in ascending order.
alpha_list = sorted([0, 1, 2, 3, 4])  # your set of values

# Keys read from every TOFU_SUMMARY.json.
metric_keys = [
    'extraction_strength',
    'extraction_strength_forget_para',
    'extraction_strength_forget_para_pqa',
    'extraction_strength_forget_para_pqpa',
    'extraction_strength_ra',
    'extraction_strength_retain',
    'extraction_strength_retain_para',
    'extraction_strength_retain_para_pqa',
    'extraction_strength_retain_para_pqpa',
    'extraction_strength_wf',
]

# One independent {metric key -> list of per-alpha values} store per method.
metrics_original, metrics_ethos, metrics_proposed, metrics_proposedfull = (
    {key: [] for key in metric_keys} for _ in range(4)
)

# Result directory of each method; each one holds a sub-directory per alpha.
paths = dict(
    original='yourpath/src/tv/tofu_Llama-3.1-8B-Instruct/original_tv/weight',
    erasor='yourpath/src/tv/tofu_Llama-3.1-8B-Instruct/erasor/ratio90',
    erasorfull='yourpath/src/tv/tofu_Llama-3.1-8B-Instruct/erasor/ratiofull',
    ethos='yourpath/src/tv/tofu_Llama-3.1-8B-Instruct/ethos/rank005',
)

def collect_metrics(base_path, metrics_dict, alphas=None, keys=None):
    """Append per-alpha summary metrics found under ``base_path`` to ``metrics_dict``.

    For every alpha the file
    ``<base_path>/<prefix><alpha>/evals/TOFU_SUMMARY.json`` is read, where the
    prefix is ``forget01_alpha_`` when ``base_path`` contains the substring
    ``'original'`` and ``forget01_tv_alpha_`` otherwise.

    Args:
        base_path: Directory holding one sub-directory per alpha.
        metrics_dict: Mapping of metric key -> list. Mutated in place: exactly
            one value is appended to every requested key for every alpha.
        alphas: Alpha values to read. Defaults to the notebook-level
            ``alpha_list`` (resolved at call time).
        keys: Metric keys to collect. Defaults to the notebook-level
            ``metric_keys`` (resolved at call time). Passing it explicitly
            protects against ``metric_keys`` being rebound by a later cell.

    A missing summary file, or a key absent from the file, is recorded as
    ``None`` so the per-alpha lists stay aligned with ``alphas``; a missing
    file is additionally reported with a warning instead of being skipped
    silently.
    """
    if alphas is None:
        alphas = alpha_list
    if keys is None:
        keys = metric_keys
    keys = list(keys)

    # NOTE(review): substring test on the whole path; a parent directory that
    # happens to contain 'original' would select this prefix for every method.
    if 'original' in base_path:
        path_name = 'forget01_alpha_'  # Can change to forget05, forget10
    else:
        path_name = 'forget01_tv_alpha_'  # Can change to forget05, forget10

    for alpha in alphas:
        json_path = os.path.join(
            base_path,
            path_name + str(alpha),
            'evals',
            'TOFU_SUMMARY.json'
        )
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
        except FileNotFoundError:
            print(f"[WARN] Summary file not found, recording None: {json_path}")
            data = {}
        for key in keys:
            metrics_dict[key].append(data.get(key))

# Fill each method's store from its result directory
# (same order as before: original, ethos, erasor, erasorfull).
for method_name, method_store in (
    ('original', metrics_original),
    ('ethos', metrics_ethos),
    ('erasor', metrics_proposed),
    ('erasorfull', metrics_proposedfull),
):
    collect_metrics(paths[method_name], method_store)

In [None]:
import numpy as np

# Summary files of the two reference models reported next to the sweeps.
retain_path = "yourpath/saves/finetune/tofu_Llama-3.1-8B-Instruct_retain99/evals/TOFU_SUMMARY.json"  # Can change to retain95, retain90
full_path = "yourpath/saves/finetune/tofu_Llama-3.1-8B-Instruct_full/evals_forget01/TOFU_SUMMARY.json"  # Can change to forget05, forget10

# Parse each summary into a plain dict; the files are read one after the other.
with open(retain_path, "r") as retain_file:
    retain_metrics = json.load(retain_file)

with open(full_path, "r") as full_file:
    full_metrics = json.load(full_file)

In [None]:
# Retain floor used for alpha selection: 95% of the retain score stored in
# the summary loaded from full_path.
full_model_utility = full_metrics['extraction_strength_retain']
threshold = 0.95 * full_model_utility
print(
    f"[INFO] Full of extraction_strength_retain: {full_model_utility:.6f}",
    f"[INFO] Threshold for extraction_strength_retain (95% of full): {threshold:.6f}",
    sep="\n",
)

def find_best_alpha(metrics_dict, alpha_list, threshold):
    """Select the alpha with the lowest forget score among those meeting the retain floor.

    The forget score is ``metrics_dict["extraction_strength"]`` and the retain
    score is ``metrics_dict["extraction_strength_retain"]``; both are lists
    aligned position-by-position with ``alpha_list``. Entries that are
    ``None`` are ignored.

    Args:
        metrics_dict: Mapping holding the two score lists described above.
        alpha_list: Alpha value for each list position.
        threshold: Minimum retain score an alpha must reach to be eligible.

    Returns:
        Tuple ``(best_alpha, best_fq, best_mu)``:
        * normally, the eligible alpha with the smallest forget score (on a
          tie the later alpha wins) together with its forget/retain scores;
        * if no alpha is eligible, a warning is printed and the alpha with the
          highest retain score is returned instead;
        * if no retain score is available at all, ``(None, None, None)``.
    """
    forget_scores = metrics_dict["extraction_strength"]
    retain_scores = metrics_dict["extraction_strength_retain"]

    best_alpha = None
    best_fq = float('inf')
    best_mu = None

    for alpha, fq, mu in zip(alpha_list, forget_scores, retain_scores):
        if mu is not None and fq is not None and mu >= threshold:
            # '<=' on purpose: equal forget scores resolve to the later alpha.
            if fq <= best_fq:
                best_alpha, best_fq, best_mu = alpha, fq, mu

    if best_alpha is not None:
        return best_alpha, best_fq, best_mu

    print("[WARN] No alpha satisfying the threshold. Using fallback: max extraction_strength_retain.")
    # Fallback: highest retain score (first occurrence wins on ties). Iterating
    # with zip keeps the lookup inside the range shared by all three lists.
    fallback = None
    max_mu = -float('inf')
    for alpha, fq, mu in zip(alpha_list, forget_scores, retain_scores):
        if mu is not None and mu > max_mu:
            max_mu = mu
            fallback = (alpha, fq, mu)

    if fallback is None:
        # Nothing usable: report "no alpha" as None (previously an empty dict).
        return None, None, None
    return fallback


# Run the selection for each method (same call order as the printed warnings
# depend on: original, ethos, erasor full, erasor) and unpack the results.
(
    (alpha_tv, fq_tv, mu_tv),
    (alpha_ethos, fq_ethos, mu_ethos),
    (alpha_svd2, fq_svd2, mu_svd2),
    (alpha_svd, fq_svd, mu_svd),
) = [
    find_best_alpha(method_store, alpha_list, threshold)
    for method_store in (metrics_original, metrics_ethos, metrics_proposedfull, metrics_proposed)
]

# One report line per method.
print(
    f"- Original     : alpha = {alpha_tv}, forget = {fq_tv}, retain = {mu_tv}",
    f"- ETHOS  : alpha = {alpha_ethos}, forget = {fq_ethos}, retain = {mu_ethos}",
    f"- ERASOR    : alpha = {alpha_svd}, forget = {fq_svd}, retain = {mu_svd}",
    f"- ERASOR full   : alpha = {alpha_svd2}, forget = {fq_svd2}, retain = {mu_svd2}",
    sep="\n",
)

In [None]:
# Retain floor used for alpha selection: 90% of the retain score stored in
# the summary loaded from full_path.
full_model_utility = full_metrics['extraction_strength_retain']
threshold = 0.90 * full_model_utility
print(
    f"[INFO] Full of extraction_strength_retain: {full_model_utility:.6f}",
    f"[INFO] Threshold for extraction_strength_retain (90% of full): {threshold:.6f}",
    sep="\n",
)

# NOTE(review): this cell re-defines find_best_alpha, which is already defined
# in the previous cell; consider keeping a single definition.
def find_best_alpha(metrics_dict, alpha_list, threshold):
    """Select the alpha with the lowest forget score among those meeting the retain floor.

    The forget score is ``metrics_dict["extraction_strength"]`` and the retain
    score is ``metrics_dict["extraction_strength_retain"]``; both are lists
    aligned position-by-position with ``alpha_list``. Entries that are
    ``None`` are ignored.

    Args:
        metrics_dict: Mapping holding the two score lists described above.
        alpha_list: Alpha value for each list position.
        threshold: Minimum retain score an alpha must reach to be eligible.

    Returns:
        Tuple ``(best_alpha, best_fq, best_mu)``:
        * normally, the eligible alpha with the smallest forget score (on a
          tie the later alpha wins) together with its forget/retain scores;
        * if no alpha is eligible, a warning is printed and the alpha with the
          highest retain score is returned instead;
        * if no retain score is available at all, ``(None, None, None)``.
    """
    forget_scores = metrics_dict["extraction_strength"]
    retain_scores = metrics_dict["extraction_strength_retain"]

    best_alpha = None
    best_fq = float('inf')
    best_mu = None

    for alpha, fq, mu in zip(alpha_list, forget_scores, retain_scores):
        if mu is not None and fq is not None and mu >= threshold:
            # '<=' on purpose: equal forget scores resolve to the later alpha.
            if fq <= best_fq:
                best_alpha, best_fq, best_mu = alpha, fq, mu

    if best_alpha is not None:
        return best_alpha, best_fq, best_mu

    print("[WARN] No alpha satisfying the threshold. Using fallback: max extraction_strength_retain.")
    # Fallback: highest retain score (first occurrence wins on ties). Iterating
    # with zip keeps the lookup inside the range shared by all three lists.
    fallback = None
    max_mu = -float('inf')
    for alpha, fq, mu in zip(alpha_list, forget_scores, retain_scores):
        if mu is not None and mu > max_mu:
            max_mu = mu
            fallback = (alpha, fq, mu)

    if fallback is None:
        # Nothing usable: report "no alpha" as None (previously an empty dict).
        return None, None, None
    return fallback


# Run the selection for each method (same call order as the printed warnings
# depend on: original, ethos, erasor full, erasor) and unpack the results.
(
    (alpha_tv, fq_tv, mu_tv),
    (alpha_ethos, fq_ethos, mu_ethos),
    (alpha_svd2, fq_svd2, mu_svd2),
    (alpha_svd, fq_svd, mu_svd),
) = [
    find_best_alpha(method_store, alpha_list, threshold)
    for method_store in (metrics_original, metrics_ethos, metrics_proposedfull, metrics_proposed)
]

# One report line per method.
print(
    f"- Original     : alpha = {alpha_tv}, forget = {fq_tv}, retain = {mu_tv}",
    f"- ETHOS  : alpha = {alpha_ethos}, forget = {fq_ethos}, retain = {mu_ethos}",
    f"- ERASOR    : alpha = {alpha_svd}, forget = {fq_svd}, retain = {mu_svd}",
    f"- ERASOR full   : alpha = {alpha_svd2}, forget = {fq_svd2}, retain = {mu_svd2}",
    sep="\n",
)

In [None]:
import pandas as pd
import math


# Alpha to report for each method.
# NOTE(review): get_all_metrics looks these values up in alpha_list; a value
# that is not present there produces a row of None.
target_alphas = dict(
    Original=0.5,    # best number for original
    ETHOS=1.3,       # best number for ethos
    ERASOR=1.4,      # best number for erasor
    ERASORfull=3.4,  # best number for erasor full
)

# Per-alpha metric store of each method, keyed like target_alphas.
metrics_dicts = dict(
    Original=metrics_original,
    ETHOS=metrics_ethos,
    ERASOR=metrics_proposed,
    ERASORfull=metrics_proposedfull,
)

# Columns of the final table (rebinds metric_keys to this shorter list).
metric_keys = [
    'extraction_strength_retain',
    'extraction_strength_retain_para_pqa',
    "extraction_strength_ra",
    "extraction_strength_wf",
    'extraction_strength',
    'extraction_strength_forget_para_pqa',
]


def get_all_metrics(metrics_dict, alpha_list, target_alpha, metric_keys):
    """Return the value of every requested metric at ``target_alpha``.

    Args:
        metrics_dict: Mapping of metric key -> sequence of per-alpha values,
            aligned position-by-position with ``alpha_list``.
        alpha_list: Sequence of alpha values (any sequence, not only a list).
        target_alpha: Alpha whose metrics are wanted.
        metric_keys: Metric keys to include in the result.

    Returns:
        Dict ``{key: value}`` for each key in ``metric_keys``. A value is
        ``None`` when ``target_alpha`` is not found in ``alpha_list``, when the
        key is absent from ``metrics_dict``, or when its sequence is too short.

    The alpha is matched exactly first; if that fails, a numerically close
    alpha is accepted so that alphas produced by float arithmetic
    (e.g. ``0.1 + 0.2`` vs ``0.3``) still resolve.
    """
    idx = None
    for pos, alpha in enumerate(alpha_list):
        if alpha == target_alpha:
            idx = pos
            break

    if idx is None:
        for pos, alpha in enumerate(alpha_list):
            try:
                is_close = math.isclose(alpha, target_alpha, rel_tol=1e-9, abs_tol=1e-12)
            except TypeError:
                # Non-numeric alpha: only the exact comparison above applies.
                continue
            if is_close:
                idx = pos
                break

    if idx is None:
        return {k: None for k in metric_keys}

    results = {}
    for key in metric_keys:
        vals = metrics_dict.get(key, ())
        results[key] = vals[idx] if idx < len(vals) else None
    return results


def extract_metrics_direct(metrics_dict, metric_keys):
    """Pick ``metric_keys`` out of a flat summary dict.

    Returns a new dict with one entry per key, in the order of
    ``metric_keys``; keys missing from ``metrics_dict`` map to ``None``.
    """
    return {key: metrics_dict.get(key) for key in metric_keys}


# Reference rows: the two single-summary models have no alpha.
baseline_rows = [
    {
        "Method": label,
        "Alpha": "N/A",
        **extract_metrics_direct(summary, metric_keys),
    }
    for label, summary in (("RETAIN", retain_metrics), ("FULL", full_metrics))
]

# One row per method, taken at that method's target alpha.
method_rows = [
    {
        "Method": method,
        "Alpha": alpha,
        **get_all_metrics(metrics_dicts[method], alpha_list, alpha, metric_keys),
    }
    for method, alpha in target_alphas.items()
]

rows = baseline_rows + method_rows

df = pd.DataFrame(rows)
print(df.to_string(index=False))
