In [1]:
from matplotlib import pyplot as plt
import numpy as np
import pickle
import os

# Apply matplotlib's 'ggplot' style to figures created after this point.
plt.style.use('ggplot')

In [2]:
def compute_metrics_ablation(results_dict, dataset_list, methods, metrics, model=None):
    """Summarise the scores of the selected methods as a mean and a std.

    The results are read as
    ``results_dict[model]['results_dic'][<dataset level>][seed][<first key>][method][metric]``.

    Args:
        results_dict: Loaded results dictionary with the layout shown above.
        dataset_list: One row of statistics is computed per entry. An entry
            that is a key of ``results_dic`` selects that entry; any other
            value falls back to the first entry of ``results_dic`` (the
            previous behaviour, which ignored the names entirely).
        methods: Method names to report; the output arrays follow this order.
        metrics: Metric names read from each method's score dictionary.
        model: Key of ``results_dict`` to read. When ``None`` (the default)
            the notebook-level global ``model`` is used, which is what this
            function always did before the parameter existed.

    Returns:
        dict mapping every metric to ``[methods, means, stds]``. ``means`` and
        ``stds`` hold one value per method: the ``np.mean`` / ``np.std`` of the
        per-seed scores, averaged over the rows produced for ``dataset_list``.

    Raises:
        NameError: if ``model`` is not given and no global ``model`` exists.
        KeyError: if the model, a method's metric, or a level is missing.
    """
    if model is None:
        # Backward compatibility: existing cells set a global ``model`` and
        # call this function with four positional arguments.
        try:
            model = globals()['model']
        except KeyError:
            raise NameError(
                "name 'model' is not defined; pass model=... or define a global 'model'"
            ) from None

    per_dataset = results_dict[model]['results_dic']

    # Gather the per-seed score dicts once per dataset; they do not depend on
    # the metric, so there is no need to rebuild them inside the metric loop.
    collected = []
    for dataset_name in dataset_list:
        if dataset_name in per_dataset:
            outcomes = per_dataset[dataset_name]
        else:
            # Unknown label: keep the historical "first entry" behaviour.
            outcomes = per_dataset[next(iter(per_dataset))]

        scores_by_method = {method: [] for method in methods}
        for seed in outcomes:
            runs = outcomes[seed]
            # Only the first entry below each seed is read.
            outcome = runs[next(iter(runs))]
            for name in outcome:
                if name in scores_by_method:
                    scores_by_method[name].append(outcome[name])
        collected.append(scores_by_method)

    summary = dict()
    for metric in metrics:
        means, stds = [], []
        for scores_by_method in collected:
            values = [
                np.array([scores[metric] for scores in scores_by_method[method]])
                for method in methods
            ]
            means.append([np.mean(v) for v in values])
            stds.append([np.std(v) for v in values])

        summary[metric] = [
            methods,
            # mean score, averaged over all datasets
            np.mean(np.array(means), axis=0),
            # standard deviation (not variance), averaged over all datasets
            np.mean(np.array(stds), axis=0),
        ]

    return summary

In [3]:
# Default method name(s) and metric names passed to compute_metrics_ablation
# by the cells below. Note that the Llama baseline cell passes
# 'LR-6-average_voting' explicitly instead of using methods_to_compare.
methods_to_compare = ['LR-average_voting']
metrics_to_compare = ['accuracy','macro_F1']

In [30]:
results_dir = '../results'

model='Llama'
dataset_name = 'subjective'

# Build the path from results_dir. The previous code joined results_dir with
# an absolute "/hpc/home/..." path; os.path.join discards everything before an
# absolute component, so results_dir was silently ignored and the cell only
# worked on one user's machine.
# NOTE(review): assumes '../results' is the project's results directory, as
# the ablation cells ('../results/Ablation') imply - confirm.
file_path = os.path.join(results_dir, f"results_k_[8]_seeds_5_datasets_['{dataset_name}']_models_['{model}']_.pkl")
with open(file_path,'rb') as f:
    baseline_results = pickle.load(f)

# The baseline file is queried with 'LR-6-average_voting' rather than methods_to_compare.
print(compute_metrics_ablation(baseline_results, [dataset_name], ['LR-6-average_voting'], metrics_to_compare))

{'accuracy': [['LR-6-average_voting'], array([0.65234375]), array([0.04934879])], 'macro_F1': [['LR-6-average_voting'], array([0.63602752]), array([0.06791602])]}


In [31]:
results_dir = '../results'

model='Qwen'
dataset_name = 'sst5'

# Build the path from results_dir. The previous code joined results_dir with
# an absolute "/hpc/home/..." path; os.path.join discards everything before an
# absolute component, so results_dir was silently ignored and the cell only
# worked on one user's machine.
# NOTE(review): assumes '../results' is the project's results directory, as
# the ablation cells ('../results/Ablation') imply - confirm.
file_path = os.path.join(results_dir, f"results_k_[8]_seeds_5_datasets_['{dataset_name}']_models_['{model}']_.pkl")
with open(file_path,'rb') as f:
    baseline_results = pickle.load(f)

print(compute_metrics_ablation(baseline_results, [dataset_name], methods_to_compare, metrics_to_compare))

{'accuracy': [['LR-average_voting'], array([0.43515625]), array([0.04276229])], 'macro_F1': [['LR-average_voting'], array([0.41064634]), array([0.02801832])]}


## Ablation on Fixed Weights

In [17]:
# Fixed-weights ablation: which run to load.
model, dataset_name = 'Llama', 'subjective'
results_dir = '../results/Ablation'

# The pickle name encodes the dataset and the model of the run.
pickle_name = f"results_k_[8]_seeds_5_datasets_['{dataset_name}']_models_['{model}']_ablation_fix_weights.pkl"
file_path = os.path.join(results_dir, pickle_name)
with open(file_path, mode='rb') as f:
    results = pickle.load(f)

In [18]:
# Print the mean/std summary of methods_to_compare for whatever `results`
# currently holds (set by the loading cell directly above).
print(compute_metrics_ablation(results, ['subjective'], methods_to_compare, metrics_to_compare))

{'accuracy': [['LR-average_voting'], array([0.615625]), array([0.03988203])], 'macro_F1': [['LR-average_voting'], array([0.59646778]), array([0.06592935])]}


In [19]:
# Fixed-weights ablation: which run to load.
model, dataset_name = 'Qwen', 'sst5'
results_dir = '../results/Ablation'

# The pickle name encodes the dataset and the model of the run.
pickle_name = f"results_k_[8]_seeds_5_datasets_['{dataset_name}']_models_['{model}']_ablation_fix_weights.pkl"
file_path = os.path.join(results_dir, pickle_name)
with open(file_path, mode='rb') as f:
    results = pickle.load(f)

In [20]:
# `results` was loaded in the previous cell for dataset_name ('sst5'), so label
# the summary with that dataset instead of the copy-pasted 'subjective'.
print(compute_metrics_ablation(results, [dataset_name], methods_to_compare, metrics_to_compare))

{'accuracy': [['LR-average_voting'], array([0.4]), array([0.06218671])], 'macro_F1': [['LR-average_voting'], array([0.29417353]), array([0.04281093])]}


## Ablation on Invariance Constraints

In [21]:
# Invariance-constraint ablation: which run to load.
model, dataset_name = 'Llama', 'subjective'
results_dir = '../results/Ablation'

# One results pickle per invariance level, keyed by the level name.
results_dicts = {}
for level in ('no', 'small', 'large'):
    pickle_name = f"results_k_[8]_seeds_5_datasets_['{dataset_name}']_models_['{model}']_ablation_{level}_invar.pkl"
    file_path = os.path.join(results_dir, pickle_name)
    with open(file_path, mode='rb') as f:
        results_dicts[level] = pickle.load(f)

In [22]:
# The keys of results_dicts are invariance levels ('no', 'small', 'large'),
# not sample counts, so name the loop variable accordingly.
# NOTE(review): the recorded output for 'small' and 'large' is identical with
# a standard deviation of 0 - worth checking those runs.
for level, results in results_dicts.items():
    print(level)
    print(compute_metrics_ablation(results, ['subjective'], methods_to_compare, metrics_to_compare))

no
{'accuracy': [['LR-average_voting'], array([0.61875]), array([0.04153136])], 'macro_F1': [['LR-average_voting'], array([0.60367096]), array([0.06001094])]}
small
{'accuracy': [['LR-average_voting'], array([0.5390625]), array([0.])], 'macro_F1': [['LR-average_voting'], array([0.35025381]), array([0.])]}
large
{'accuracy': [['LR-average_voting'], array([0.5390625]), array([0.])], 'macro_F1': [['LR-average_voting'], array([0.35025381]), array([0.])]}


## Ablation on the Number of K-learners

In [33]:
# Use a project-relative path like the other cells instead of an absolute
# home-directory path that only exists on one machine.
# NOTE(review): assumes '../results' is the same directory as the former
# '/hpc/home/jd420/Projects/ICL/results' - confirm.
# Despite its name, results_dicts holds a single results dictionary here; it
# is passed directly to compute_metrics_ablation in the next cell.
file_path = os.path.join('..', 'results', "results_k_[16]_seeds_5_datasets_['subjective']_models_['Llama']_ablation.pkl")
with open(file_path,'rb') as f:
    results_dicts = pickle.load(f)

In [39]:
model = 'Llama'
# Method names 'LR-average_voting-first-2' ... 'LR-average_voting-first-11'.
methods_k_learners = [f'LR-average_voting-first-{n_learners}' for n_learners in range(2, 12)]
compute_metrics_ablation(results_dicts, ['subjective'], methods_k_learners, metrics_to_compare)

{'accuracy': [['LR-average_voting-first-2',
   'LR-average_voting-first-3',
   'LR-average_voting-first-4',
   'LR-average_voting-first-5',
   'LR-average_voting-first-6',
   'LR-average_voting-first-7',
   'LR-average_voting-first-8',
   'LR-average_voting-first-9',
   'LR-average_voting-first-10',
   'LR-average_voting-first-11'],
  array([0.52109375, 0.49140625, 0.49765625, 0.52265625, 0.5765625 ,
         0.62890625, 0.63828125, 0.64453125, 0.6703125 , 0.68515625]),
  array([0.04619295, 0.05586893, 0.06631888, 0.08169201, 0.07536532,
         0.06743233, 0.06937746, 0.06652104, 0.0391873 , 0.03957477])],
 'macro_F1': [['LR-average_voting-first-2',
   'LR-average_voting-first-3',
   'LR-average_voting-first-4',
   'LR-average_voting-first-5',
   'LR-average_voting-first-6',
   'LR-average_voting-first-7',
   'LR-average_voting-first-8',
   'LR-average_voting-first-9',
   'LR-average_voting-first-10',
   'LR-average_voting-first-11'],
  array([0.49323093, 0.40202427, 0.39615483, 0.41

## Ablation on the Number of ICL Examples

In [6]:
# ICL-example-count ablation: which run to load.
model, dataset_name = 'Llama', 'subjective'
results_dir = '../results/Ablation'

# One results pickle per sample count, keyed by the count.
results_dicts = {}
for sample in (4, 12, 24, 48):
    pickle_name = f"results_k_[8]_seeds_5_datasets_['{dataset_name}']_models_['{model}']_ablation_samples_{sample}.pkl"
    file_path = os.path.join(results_dir, pickle_name)
    with open(file_path, mode='rb') as f:
        results_dicts[sample] = pickle.load(f)

In [14]:
# Load the samples_4 pickle on its own into results_4.
# NOTE(review): with unchanged results_dir / dataset_name / model this is the
# same file the loop above stores in results_dicts[4]; results_4 is not used
# by any cell visible here.
file_path = os.path.join(results_dir, f"results_k_[8]_seeds_5_datasets_['{dataset_name}']_models_['{model}']_ablation_samples_4.pkl")
with open(file_path,'rb') as f:
    results_4 = pickle.load(f)

In [15]:
# Load the samples_12 pickle on its own into results_12.
# NOTE(review): with unchanged results_dir / dataset_name / model this is the
# same file the loop above stores in results_dicts[12]; results_12 is not used
# by any cell visible here.
file_path = os.path.join(results_dir, f"results_k_[8]_seeds_5_datasets_['{dataset_name}']_models_['{model}']_ablation_samples_12.pkl")
with open(file_path,'rb') as f:
    results_12 = pickle.load(f)

In [7]:
# Print the mean/std summary of methods_to_compare for every loaded sample count.
# NOTE(review): the recorded output is identical for all four sample counts,
# with a standard deviation of 0 - worth checking those runs.
for sample, results in results_dicts.items():
    print(sample)
    print(compute_metrics_ablation(results, ['subjective'], methods_to_compare, metrics_to_compare))

4
{'accuracy': [['LR-average_voting'], array([0.5390625]), array([0.])], 'macro_F1': [['LR-average_voting'], array([0.35025381]), array([0.])]}
12
{'accuracy': [['LR-average_voting'], array([0.5390625]), array([0.])], 'macro_F1': [['LR-average_voting'], array([0.35025381]), array([0.])]}
24
{'accuracy': [['LR-average_voting'], array([0.5390625]), array([0.])], 'macro_F1': [['LR-average_voting'], array([0.35025381]), array([0.])]}
48
{'accuracy': [['LR-average_voting'], array([0.5390625]), array([0.])], 'macro_F1': [['LR-average_voting'], array([0.35025381]), array([0.])]}
