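### Updated MEMIT result aggregation (10-07-23 run)

This notebook does the following:
- Loads the MEMIT validation-sweep results from `log_memit_100723`.
- For each (subject type, model, dataset, league), selects the config with the lowest intervention-score delta among runs whose general score stays within the league's tolerance.
- Saves those selections to `memit_results.val.final.json`.
- Optionally writes sbatch scripts that rerun the selected configs in test mode.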

#### Setup: paths, sweep grid, and imports

In [1]:
# Directory holding the per-config JSON results of the 10-07-23 MEMIT sweep.
result_dir = 'log_memit_100723'

# Relative tolerances on general_score: a run is "in league" when its edited
# general_score is below the no-edit general_score * (1 + league).
LEAGUES = [1e-3, 1e-4, 1e-5]

# Dataset short names (prefix of the `dataset` field in result filenames).
dnames = ['company', 'country', 'verbs', 'temporal', 'stereoset', 'gender']

# Models covered by the sweep, smallest to largest.
model_names = [
    'backpack-gpt2',
    'pythia-70m', 'pythia-160m', 'pythia-410m',
    'pythia-1b', 'pythia-1.4b', 'pythia-2.8b', 'pythia-6.9b',
]

# Subject variants (suffix of the `dataset` field in result filenames).
subject_types = ['true_subject', 'prefix_subject']
import pandas as pd
import json
import os
from collections import defaultdict, Counter
import matplotlib.pyplot as plt


#### Get results

In [2]:
# Nested store: results_dict[subject_type][model][dataset][param_str] -> parsed run JSON.
results_dict = {
    st: {m: {d: {} for d in dnames} for m in model_names}
    for st in subject_types
}

# Result filenames are '__'-joined hyper-parameter values in this order, plus '.json'.
param_keys = ['model', 'dataset', 'layers', 'v_num_grad_steps', 'clamp_norm_factor',
              'mom2_update_weight', 'kl_factor']

for root, _, files in os.walk(result_dir):
    for fname in files:
        if 'noedit' in fname:
            continue  # skip no-edit
        if not fname.endswith('.json'):
            continue  # ignore logs / non-result files that the '.json' stripping below would mis-parse

        param_dict = dict(zip(param_keys, fname[:-len('.json')].split('__')))

        # Only clamp_norm_factor / mom2_update_weight / kl_factor form the run key;
        # layers and v_num_grad_steps are excluded (presumably fixed across this sweep).
        param_str = '__'.join(param_dict[x] for x in param_keys[4:])

        # `dataset` is '<dname>-<subject_type>', e.g. 'country-prefix_subject'.
        dataset_parts = param_dict['dataset'].split('-')
        dname, subject_type = dataset_parts[0], dataset_parts[1]

        # Use a context manager so each result file is closed after parsing.
        with open(os.path.join(root, fname), 'r') as result_fh:
            results_dict[subject_type][param_dict['model']][dname][param_str] = json.load(result_fh)


In [3]:
# For every prefix_subject run, replace its 'noedit' block with the one from the
# true_subject run that has the same hyper-parameter string, so both subject
# types are scored against the same no-edit baseline.
for m in model_names:
    for d in dnames:
        oracle_runs = results_dict['true_subject'][m][d]
        for param_str, prefix_run in results_dict['prefix_subject'][m][d].items():
            prefix_run['noedit'] = oracle_runs[param_str]['noedit']

            

In [4]:
# Annotate every run's 'edit' block with:
#   - in_league_<league>: edited general_score < noedit general_score * (1 + league)
#   - intervention_score_delta: edited minus no-edit intervention_score
for st in subject_types:
    for m in model_names:
        for d in dnames:
            for run in results_dict[st][m][d].values():
                baseline, edited = run['noedit'], run['edit']

                for league in LEAGUES:
                    cutoff = baseline['general_score'] * (1 + league)
                    edited['in_league_{}'.format(league)] = edited['general_score'] < cutoff

                edited['intervention_score_delta'] = edited['intervention_score'] - baseline['intervention_score']


In [5]:
# For each (subject type, model, dataset, league), pick the in-league run with the
# lowest intervention_score_delta. Results:
#   best_configs[model]['<dname>-<subject_type>'][league] -> that run's override_params (or None)
#   best_results[subject_type][model][dname][league]      -> that run's full record (or None)
best_configs = {
    m: {'{}-{}'.format(d, st): {} for d in dnames for st in subject_types}
    for m in model_names
}
best_results = {
    st: {m: {d: {} for d in dnames} for m in model_names}
    for st in subject_types
}

for st in subject_types:
    for m in model_names:
        for d in dnames:
            full_dname = '{}-{}'.format(d, st)
            candidate_runs = list(results_dict[st][m][d].values())

            for league in LEAGUES:
                league_flag = 'in_league_{}'.format(league)
                in_league_runs = [run for run in candidate_runs if run['edit'][league_flag]]

                # Manual argmin: strict '<' keeps the earliest run on ties.
                best_delta = float('inf')
                best_config, best_data = None, None
                for run in in_league_runs:
                    delta = run['edit']['intervention_score_delta']
                    if delta < best_delta:
                        best_delta = delta
                        best_config = run['edit']['override_params']
                        best_data = run

                best_configs[m][full_dname][league] = best_config
                best_results[st][m][d][league] = best_data

                if not in_league_runs:
                    continue

                # Sanity check: the argmin by delta must match the argmin by raw
                # intervention_score (true when all runs share one noedit baseline).
                deltas = [run['edit']['intervention_score_delta'] for run in in_league_runs]
                delta_argmin = deltas.index(min(deltas))
                raw_scores = [run['edit']['intervention_score'] for run in in_league_runs]
                assert raw_scores.index(min(raw_scores)) == delta_argmin


In [6]:
# Rename subject-type keys to the labels used downstream:
#   'true_subject' -> 'oracle', 'prefix_subject' -> 'prefix'.
# The membership guard makes the cell idempotent: on a re-run the old keys are
# already gone, so the original `del` would raise KeyError.
for old_key, new_key in (('true_subject', 'oracle'), ('prefix_subject', 'prefix')):
    if old_key in best_results:
        best_results[new_key] = best_results.pop(old_key)

In [7]:
# Spot-check one selection: best prefix-subject run for pythia-1.4b on country, 1e-3 league.
best_results['prefix']['pythia-1.4b']['country'][1e-3]


{'noedit': {'intervention_score': 0.7989690721649485,
  'general_score': 2.643992486733253,
  'rest_of_prompt_score': 2.6830669109093,
  'hard_negative_score': None},
 'edit': {'intervention_score': 0.7783505154639175,
  'general_score': 2.644253354275785,
  'rest_of_prompt_score': 2.3759197688479055,
  'hard_negative_score': None,
  'override_params': {'layers': None,
   'v_num_grad_steps': 20,
   'clamp_norm_factor': 0.6828108071671678,
   'mom2_update_weight': 62225,
   'kl_factor': 0.07505701446082196},
  'in_league_0.001': True,
  'in_league_0.0001': True,
  'in_league_1e-05': False,
  'intervention_score_delta': -0.020618556701030966}}

In [8]:
# Persist the selected validation runs. JSON turns the float league keys into
# strings ('0.001', '0.0001', '1e-05') and None (no in-league run) into null.
output_path = "memit_results.val.final.json"
with open(output_path, "w") as out_fh:
    json.dump(best_results, out_fh)

#### Make test scripts (using the best validation config per league)

In [7]:
make_test_scripts = False  # set to True to (re)write the sbatch test scripts

In [8]:
# Test-run settings: each selected config is run with seeds 0..num_trials-1.
num_trials = 5
out_log_dir = 'log_memit_100723_test_results'      # passed to run_memit.py --log_dir
sweep_script_dir = 'sbatches_100723'
sweep_script_write_dir = '{}/test_scripts'.format(sweep_script_dir)  # where .sbatch files go

from make_sweep import get_sbatch_header, model_name_to_short, model_name_to_full
# model_to_queue, model_to_jags


def model_to_queue(model_name):
    """Return the sbatch partition for a model.

    Small models (410m / 160m / 70m pythia and backpack) go to 'jag-standard';
    everything else goes to 'jag-lo'. Matching is by substring of ``model_name``.
    """
    small_model_tags = ('410m', '160m', '70m', 'backpack')
    is_small = any(tag in model_name for tag in small_model_tags)
    return 'jag-standard' if is_small else 'jag-lo'
    
def model_to_jags(model_name):
    """Return the candidate jagupard nodes for a model's sbatch jobs.

    Large models (pythia 1b and up, gpt-j) share jagupard37-39; small models
    (pythia 70m/160m/410m, backpack) share jagupard32-36. Matching is by
    substring of ``model_name``, large groups checked first.

    Raises:
        ValueError: if ``model_name`` matches neither group (the message names it).
    """
    large_model_tags = ('6.9b', '2.8b', '1.4b', '1b', 'gpt-j')
    small_model_tags = ('410m', '160m', '70m', 'backpack')

    if any(tag in model_name for tag in large_model_tags):
        return ['jagupard37', 'jagupard38', 'jagupard39']
    if any(tag in model_name for tag in small_model_tags):
        return ['jagupard32', 'jagupard33', 'jagupard34', 'jagupard35', 'jagupard36']
    raise ValueError('model_to_jags: no node group for model {!r}'.format(model_name))


In [9]:
# Write one sbatch script per (model, '<dname>-<subject_type>', league) selected
# validation config. Each script runs that config in test mode for seeds
# 0..num_trials-1. The paths of the generated scripts are printed at the end.
if make_test_scripts:
    # Round-robin counter over each model's candidate nodes, shared across all scripts.
    machine_choosing_index = 0

    dname_cfg_map = {
        'company': 'company_ceo', 'country': 'country_capital', 'verbs': 'verb_conjugation',
        'temporal': 'temporal', 'stereoset': 'stereoset', 'gender': 'pronoun_gender_bias'
    }

    test_log_dir = f"{sweep_script_dir}/test_logs"
    # Both dirs must exist: the scripts are opened for writing below, and the job
    # logs / '>>' redirections inside each script write into test_log_dir.
    os.makedirs(sweep_script_write_dir, exist_ok=True)
    os.makedirs(test_log_dir, exist_ok=True)

    run_cmds = []
    filenames = []
    for model_name, configs_by_dataset in best_configs.items():
        for full_dname, configs_by_league in configs_by_dataset.items():
            for league, cur_config in configs_by_league.items():
                if cur_config is None:
                    print(">> WARNING: NO CONFIG YIELDED FOR", model_name, full_dname, league)
                    continue
                dname, subject_type = full_dname.split('-')

                short_model_name = model_name_to_short(model_name)
                jag_options = model_to_jags(model_name)
                nodelist = jag_options[machine_choosing_index % len(jag_options)]
                machine_choosing_index += 1

                script_path = f"{sweep_script_write_dir}/{short_model_name}_{full_dname}_{league}.sbatch"
                with open(script_path, "w") as fh:
                    filenames.append(script_path)

                    print(
                        get_sbatch_header(
                            run_name=f'{short_model_name}_{dname[:3]}_test-sweep',
                            partition=model_to_queue(model_name),
                            nodelist=nodelist,
                            log_output_dir=test_log_dir,
                            num_hrs=12,
                        ),
                        file=fh
                    )

                    for t in range(num_trials):
                        # NOTE(review): v_num_grad_steps is hard-coded to 20 rather than read from
                        # cur_config; the loading cell excludes it from the run key, so it is
                        # presumably fixed across the sweep — confirm.
                        test_command = (
                            f'python3 run_memit.py "{model_name_to_full[model_name]}" --v_num_grad_steps 20 '
                            f'--clamp_norm_factor {cur_config["clamp_norm_factor"]} '
                            f'--mom2_update_weight {cur_config["mom2_update_weight"]} '
                            f'--kl_factor {cur_config["kl_factor"]} '
                            f'--dataset_names {dname} '
                            f'--subject_types {subject_type} '
                            f'--log_dir {out_log_dir} --test_mode '
                            f'--override_exp_name {short_model_name}__{full_dname}__{league}__trial{t} '
                            f'--seed {t}')

                        run_cmd = (
                            f"{test_command} >> {test_log_dir}/log.{short_model_name}_{full_dname}_{league}_{t}.txt"
                        )
                        print(run_cmd, file=fh)
                        run_cmds.append(run_cmd)

    for x in filenames:
        print('sbatch', x)
    print(len(filenames))