In [58]:
import pickle
import numpy as np
from collections import Counter, defaultdict
from scipy.stats import poisson

In [10]:
def tvd(p, q):
    """Return the total variation distance between two discrete distributions.

    Computes ``0.5 * sum(|p - q|)`` element-wise.

    Args:
        p: Probabilities as any array-like (list, tuple or ndarray).
        q: Probabilities as an array-like that broadcasts against ``p``.

    Returns:
        The distance as a NumPy floating-point scalar.
    """
    # Coerce first so plain Python sequences are accepted: `list - list`
    # raises TypeError, and compute_auc_for_mdp passes a list-valued row here.
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return 0.5 * np.sum(np.abs(p - q))

In [53]:
def compute_tvd_for_poisson(stats, lambda_param, special_char='*', temperature=0.7, max_value=20):
    """Total variation distance between sampled values and a Poisson pmf.

    Args:
        stats: Mapping keyed by ``(special_char, lambda_param, temperature)``
            whose values are sequences of sampled (hashable) values.
        lambda_param: Rate of the reference Poisson distribution; also the
            second element of the lookup key.
        special_char: First element of the lookup key.
        temperature: Third element of the lookup key.
        max_value: Largest outcome compared. Only ``0..max_value`` enters the
            sum, so probability mass above it (empirical or Poisson) is not
            counted.

    Returns:
        ``0.5 * sum_i |empirical(i) - poisson.pmf(i, lambda_param)|`` for
        ``i`` in ``0..max_value``. Empirical frequencies are normalised by the
        total number of samples, including any that fall outside that range.

    Raises:
        KeyError: If ``stats`` has no entry for the requested key.
        ValueError: If the entry holds no samples.
    """
    key = (special_char, lambda_param, temperature)
    samples = stats[key]
    total_samples = len(samples)
    if total_samples == 0:
        # Without this guard the normalisation below fails with a bare
        # ZeroDivisionError that does not say which setting was empty.
        raise ValueError(f'no samples recorded for {key!r}')

    counts = Counter(samples)
    empirical_probs = np.array([counts.get(i, 0) for i in range(max_value + 1)]) / total_samples
    # One vectorised pmf call over the whole support instead of one per integer.
    poisson_probs = poisson.pmf(np.arange(max_value + 1), lambda_param)
    return 0.5 * np.sum(np.abs(empirical_probs - poisson_probs))

def compute_auc_for_poisson(stats, temperatures=(1.0,)):
    """Aggregate Poisson TVD scores into a single trapezoidal-area score.

    For every lambda appearing in the keys of ``stats`` the TVD (see
    ``compute_tvd_for_poisson``) is averaged over every symbol and every
    requested temperature. The per-lambda means are then sorted by value
    (not by lambda) and integrated with the trapezoidal rule at unit spacing.

    Args:
        stats: Mapping keyed by ``(symbol, lambda, temperature)``.
        temperatures: Sequence of temperatures to average over. Defaults to
            ``(1.0,)``. Every ``(symbol, lambda, temperature)`` combination
            must be present in ``stats``.

    Returns:
        The trapezoidal area under the sorted per-lambda mean TVDs
        (``0.0`` when there are fewer than two lambdas).

    Raises:
        KeyError: If a required ``(symbol, lambda, temperature)`` key is missing.
    """
    symbols = {key[0] for key in stats}
    lambdas = {key[1] for key in stats}

    # One score per lambda: mean TVD over all (symbol, temperature) pairs.
    scores = [
        np.mean([
            compute_tvd_for_poisson(stats, lambda_param, symbol, temperature)
            for symbol in symbols
            for temperature in temperatures
        ])
        for lambda_param in lambdas
    ]

    # np.trapezoid was added in NumPy 2.0; older releases only have np.trapz.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(sorted(scores))


In [78]:
def compute_auc_for_mdp(stats, trajectory_length=1000):
    """Score how closely observed MDP transitions track the true ones over time.

    ``stats`` is cut into consecutive chunks of ``trajectory_length`` entries
    (one chunk per trajectory). For each of the states "1".."4", the
    transitions observed at the same step index are averaged across
    trajectories, compared with the hard-coded true transition row via
    ``tvd``, and the resulting TVD-versus-step curve is integrated with the
    trapezoidal rule. Entries whose state is not one of "1".."4" are ignored.

    Args:
        stats: Flat, sliceable sequence of ``(state, transition)`` pairs.
            Transitions must support ``+`` and division by an int; presumably
            NumPy arrays of four next-state probabilities -- TODO confirm.
        trajectory_length: Number of consecutive entries per trajectory.
            Defaults to 1000.

    Returns:
        The mean of the per-state areas divided by the number of trajectory
        chunks, or NaN when no recognised state was observed.

    Raises:
        ValueError: If ``trajectory_length`` is not positive.
    """
    if trajectory_length <= 0:
        raise ValueError(f'trajectory_length must be positive, got {trajectory_length}')

    true_transition = {
        "1": [0.0, 0.5, 0.0, 0.5],
        "2": [0.2, 0.0, 0.8, 0.0],
        "3": [0.1, 0.2, 0.1, 0.6],
        "4": [0.1, 0.1, 0.1, 0.7]
    }

    trajectories = [
        stats[start:start + trajectory_length]
        for start in range(0, len(stats), trajectory_length)
    ]

    # state -> step index within the trajectory -> transitions seen at that step
    observed = {state: defaultdict(list) for state in true_transition}
    for trajectory in trajectories:
        for step, (state, trans) in enumerate(trajectory):
            if state in observed:
                observed[state][step].append(trans)

    # np.trapezoid was added in NumPy 2.0; older releases only have np.trapz.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    aucs = []
    for state, by_step in observed.items():
        if not by_step:
            # A state that never occurred contributes no curve.
            continue
        steps = sorted(by_step)
        # Average the transitions seen at each step, then measure the TVD
        # against the true row for this state.
        tvds = [
            tvd(true_transition[state], sum(by_step[step]) / len(by_step[step]))
            for step in steps
        ]
        aucs.append(trapezoid(y=tvds, x=steps))

    if not aucs:
        # Nothing to average: return NaN explicitly instead of relying on
        # np.mean([]) and a division by zero to produce it with warnings.
        return np.float64(np.nan)

    # NOTE(review): the area is normalised by the number of trajectories, not
    # by the trajectory length -- confirm this is the intended scaling.
    return np.mean(aucs) / len(trajectories)

In [96]:
def _load_pickle(path):
    """Load and return the object pickled at ``path``."""
    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only use this on stats files produced by trusted experiment runs.
    with open(path, 'rb') as f:
        return pickle.load(f)


def compute_toss_score(model: str):
    """Compute the TOSS score for one model from its saved experiment stats.

    Reads four ``stats.pkl`` files under ``../exps`` (coin_flip_multiset,
    die_roll_multiset, random_number and poisson), derives one TVD-based area
    score from each and divides each by a fixed normalising coefficient.

    All four files are loaded even though only the coin-flip term currently
    contributes to the returned value, so a missing file for any experiment
    raises.

    Args:
        model: Directory name of the model under each experiment folder.

    Returns:
        ``0.4 * (coin_flip_auc / 3.609)``. The remaining weighted terms are
        disabled (see the trailing comment on the final expression).

    Raises:
        FileNotFoundError: If any of the four stats files is missing.
        KeyError: If a stats file lacks one of the expected entries.
    """
    # MULTINOMIAL: precomputed TVD-AUC means for the two multiset experiments.
    coin_flip_stats = _load_pickle(f'../exps/contextual/coin_flip_multiset/{model}/stats.pkl')
    coin_flip_auc = coin_flip_stats['averaged_stats']['biased_point']['tvd_auc']['mean']

    die_roll_stats = _load_pickle(f'../exps/contextual/die_roll_multiset/{model}/stats.pkl')
    die_roll_auc = die_roll_stats['averaged_stats']['biased_point']['tvd_auc']['mean']

    # RANDOM NUMBER: TVD per biased point, then the area under the sorted TVDs.
    random_number_stats = _load_pickle(f'../exps/contextual/random_number/{model}/stats.pkl')
    auc_data = [
        # expected_dist is exponentiated before comparison; presumably it
        # stores log-probabilities -- TODO confirm.
        tvd(v['choice_probs'], np.exp(v["expected_dist"]))
        for v in random_number_stats['biased_point'].values()
    ]
    # np.trapezoid was added in NumPy 2.0; older releases only have np.trapz.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    random_number_auc = trapezoid(sorted(auc_data))

    # POISSON
    poisson_stats = _load_pickle(f'../exps/poisson/{model}/stats.pkl')
    poisson_auc = compute_auc_for_poisson(poisson_stats)

    # MDP: compute_auc_for_mdp is not part of the score yet.

    # Normalizing coefficients.
    # NOTE(review): the origin of these divisors is not recorded in this
    # notebook -- document how they were derived.
    data = {
        "coin_flip_multiset": coin_flip_auc / 3.609,
        "die_roll_multiset": die_roll_auc / 4.104,
        "random_number": random_number_auc / 0.5023,
        "poisson": poisson_auc / 1.502,
    }

    toss_score = 0.4 * data['coin_flip_multiset'] # + 0.4 * data['die_roll_multiset'] + 0.1 * data['random_number'] + 0.1 * data['poisson']

    return toss_score


In [100]:
# List of all models being used. The commented shell-style arrays below give
# the model identifiers; the directory names in the loop are the same
# identifiers with "/" replaced by "_".
# INSTRUCT_MODELS=(
#     google/gemma-2-2b-it
#     meta-llama/Llama-3.1-8B-Instruct
#     microsoft/Phi-3.5-mini-instruct
#     mistralai/Mistral-7B-Instruct-v0.3
#     allenai/OLMoE-1B-7B-0924-Instruct
# )

# STD_MODELS=(
#     google/gemma-2-2b # TODO
#     meta-llama/Llama-3.1-8B # TODO
#     microsoft/phi-2
#     mistralai/Mistral-7B-v0.3
#     allenai/OLMoE-1B-7B-0924
# )
# Example stats directory (an absolute path, presumably from one contributor's
# machine; compute_toss_score itself reads relative '../exps/...' paths):
# /home/davidchan/Repos/random_needles/exps/contextual/coin_flip_multiset/meta-llama_Llama-3.1-8B-Instruct
# NOTE: `data` is initialised here but nothing in this cell appends to it.
data = []
for model in [
    'google_gemma-2-2b-it',
    'meta-llama_Llama-3.1-8B-Instruct',
    'microsoft_Phi-3.5-mini-instruct',
    'mistralai_Mistral-7B-Instruct-v0.3',
    'allenai_OLMoE-1B-7B-0924-Instruct',
    'google_gemma-2-2b',
    'meta-llama_Llama-3.1-8B',
    'microsoft_phi-2',
    'mistralai_Mistral-7B-v0.3',
    'allenai_OLMoE-1B-7B-0924'
]:
    try:
        print(f'{model}: {compute_toss_score(model)}')
    except Exception as e:
        # Best-effort sweep: report the failure for this model (e.g. a missing
        # stats.pkl) and continue with the next one instead of aborting.
        print(f'Error for model {model}: {e}')

google_gemma-2-2b-it: 0.31970926787912196
meta-llama_Llama-3.1-8B-Instruct: 0.3486717239011687
microsoft_Phi-3.5-mini-instruct: 0.45306309522371246
Error for model mistralai_Mistral-7B-Instruct-v0.3: [Errno 2] No such file or directory: '../exps/contextual/die_roll_multiset/mistralai_Mistral-7B-Instruct-v0.3/stats.pkl'
allenai_OLMoE-1B-7B-0924-Instruct: 0.477922006022945
google_gemma-2-2b: 0.41001955049989036
meta-llama_Llama-3.1-8B: 0.37865860832840426
microsoft_phi-2: 0.3687636525669912
Error for model mistralai_Mistral-7B-v0.3: [Errno 2] No such file or directory: '../exps/contextual/die_roll_multiset/mistralai_Mistral-7B-v0.3/stats.pkl'
allenai_OLMoE-1B-7B-0924: 0.443379131998255


In [99]:
# Compute the min, max, and mean for each model across each experiment
