In [4]:
import wandb
import pandas as pd
import numpy as np
import scipy.stats as st

In [2]:
# Client for the Weights & Biases public API; used below to list runs.
api = wandb.Api()

# Placeholders identifying which W&B project to query and which
# experiment/run names to filter on. Replace them before running.
entity = "your-entity"
project = "your-project"
experiment_name = "your-experiment"
run_name = "your-run"

# Let's first extract the dataframe from the wandb project

In [82]:
# Build `runs_df`: one row per (run, metric rank) with the run's settings and
# its evaluation metrics, pulled from the W&B project configured above.

# Ask W&B for the runs of `entity/project` whose config has the requested
# experiment_name and run_name.
runs = api.runs(entity + "/" + project, {
    "$and": [{
        'config.experiment_name': experiment_name,
        'config.run_name': run_name
    }]
}) 

# Column names of the final dataframe. Metric columns and the
# "summary"/"config"/"name" columns are appended after the loop.
key_list = ["model", "user model", "logging policy", "temperature", "test policy", "seed", "metric rank"]
# Metric column names ("<split>/<metric>"), filled from the first kept run only.
metric_key_list = []
# Per-row accumulators. Each is wrapped in a one-element outer list and only
# element [0] is ever used.
name_list, summary_list, config_list, model_list, usermodel_list, logpi_list, testpi_list, rank_list, temp_list, seed_list = [[]], [[]], [[]], [[]], [[]], [[]], [[]], [[]], [[]], [[]]
# One (n_metrics, n_ranks) array per kept run.
metrics_list = []
# Rank labels "1".."10", plus "-" for metrics whose key carries no "@<rank>" suffix.
ranks = [str(i+1) for i in range(10)] + ["-"] # Hardcoded for now
# Latch: True until the first kept run has been used to discover metric names.
flip = True

for i, run in enumerate(runs):
    if "Metrics/test/loss" not in run.summary._json_dict:
        # Discarding unfinished runs (no test loss was logged in the summary)
        continue

    # Every kept run contributes len(ranks) rows, so each per-run value is
    # repeated len(ranks) times to keep all accumulators aligned.
    name_list[0] += [run.name] * len(ranks)
    summary_list[0] += [run.summary._json_dict] * len(ranks)
    # Config without the keys starting with an underscore.
    config_list[0]  += [
        {k: v for k,v in run.config.items()
         if not k.startswith('_')}] * len(ranks)
    
    # Keep only the last dotted component of each `_target_` path.
    model_list[0] += [run.config["model"]["_target_"].split(".")[-1]] * len(ranks)
    usermodel_list[0] += [run.config["data"]["train_simulator"]["user_model"]["_target_"].split(".")[-1]] * len(ranks)
    logpi_target = run.config["data"]["train_policy"]["_target_"].split(".")[-1]
    
    # A NoisyOraclePolicy with zero noise is relabelled "OraclePolicy".
    if logpi_target == "NoisyOraclePolicy":
        noise = run.config["data"]["train_policy"]["noise"]
        if noise == 0.0:
            logpi_list[0] += ["OraclePolicy"] * len(ranks)
        else:
            logpi_list[0] += ["NoisyOraclePolicy"] * len(ranks)
    else:
        logpi_list[0] += [logpi_target] * len(ranks)

    temp_list[0] += [run.config["data"]["train_simulator"]["temperature"]] * len(ranks)
    testpi_target = run.config["data"]["test_policy"]["_target_"].split(".")[-1]

    # Same relabelling for the test policy.
    if testpi_target == "NoisyOraclePolicy":
        noise = run.config["data"]["test_policy"]["noise"]
        if noise == 0.0:
            testpi_list[0] += ["OraclePolicy"] * len(ranks)
        else:
            testpi_list[0] += ["NoisyOraclePolicy"] * len(ranks)
    else:
        testpi_list[0] += [testpi_target] * len(ranks)

    seed_list[0] += [run.config["random_state"]] * len(ranks)
    rank_list[0] += ranks
    
    # Discover the metric columns once, from the first kept run.
    # NOTE(review): a later run that logs a metric absent from this first run
    # would make the `.index(...)` lookup below raise ValueError.
    if flip:
        flip = False
        
        for k, v in run.summary._json_dict.items():
            # Keep only keys starting with "Metrics" and skip those whose
            # characters 8..12 spell "train". Keys are presumably shaped like
            # "Metrics/<split>/<metric>[@<rank>]" (to be confirmed against the logger).
            if k[:7] != "Metrics" or k[8:13] == "train":
                continue
            metric_name = k.split("/")[1] + "/" + k.split("/")[2].split("@")[0]
            
            if metric_name not in metric_key_list:
                metric_key_list.append(metric_name)

    # Zero-initialised, so metric/rank combinations the run did not log stay 0.
    metrics_list.append(np.zeros((len(metric_key_list), len(ranks))))
    
    for k, v in run.summary._json_dict.items():
        
        if k[:7] != "Metrics" or k[8:13] == "train":
            continue
        metric_name = k.split("/")[1] + "/" + k.split("/")[2].split("@")[0]
        k_idx = metric_key_list.index(metric_name)
        
        if len(k.split("/")[2].split("@")) == 2:
            # "<metric>@<rank>": store in the column of that rank label.
            rank_idx = ranks.index(k.split("/")[2].split("@")[1])
        else: 
            # No "@<rank>" suffix: store in the last column, i.e. the "-" label.
            rank_idx = -1
        metrics_list[-1][k_idx, rank_idx] = v
        
key_list += metric_key_list + ["summary", "config", "name"]

# np.hstack joins the per-run arrays column-wise into (n_metrics, n_rows);
# .tolist() then yields one list per metric, aligned with the per-row lists.
# The order of value_list must match the order of key_list.
value_list = [model_list[0], usermodel_list[0], logpi_list[0], temp_list[0], testpi_list[0], seed_list[0], rank_list[0]] + np.hstack(metrics_list).tolist() + [summary_list[0], config_list[0], name_list[0]]

runs_df = pd.DataFrame(dict(zip(key_list, value_list)))

# Each kept run produced len(ranks) rows, hence the integer division.
print("%d runs selected" % (len(runs_df) // len(ranks)))

131 runs selected


# Regret in OPS

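Let's define the selection policies for OPS. Each one takes a DataFrame of candidate models (`cm_set`) and returns the selected model's name together with that model's `test/ppl`.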

In [3]:
def oracle(cm_set, n = None):
    """Select the candidate with the lowest ``test/ppl``.

    Args:
        cm_set: DataFrame of candidates with "model" and "test/ppl" columns.
        n: Ignored; accepted so that all policies share one call signature.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    best_row = cm_set.loc[cm_set["test/ppl"].idxmin()]
    return best_row["model"], best_row["test/ppl"]

def ppl(cm_set, n = None):
    """Select the candidate with the lowest ``val/ppl``.

    Args:
        cm_set: DataFrame of candidates with "model", "val/ppl" and
            "test/ppl" columns.
        n: Ignored; accepted so that all policies share one call signature.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    chosen_row = cm_set.loc[cm_set["val/ppl"].idxmin()]
    return chosen_row["model"], chosen_row["test/ppl"]

def ndcg(cm_set, n = None):
    """Select the candidate with the highest ``test/nDCG``.

    Args:
        cm_set: DataFrame of candidates with "model", "test/nDCG" and
            "test/ppl" columns.
        n: Ignored; accepted so that all policies share one call signature.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    chosen_row = cm_set.loc[cm_set["test/nDCG"].idxmax()]
    return chosen_row["model"], chosen_row["test/ppl"]

def cmip(cm_set, n = None):
    """Select the candidate with the lowest ``test/cmi``.

    Args:
        cm_set: DataFrame of candidates with "model", "test/cmi" and
            "test/ppl" columns.
        n: Ignored; accepted so that all policies share one call signature.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    chosen_row = cm_set.loc[cm_set["test/cmi"].idxmin()]
    return chosen_row["model"], chosen_row["test/ppl"]

def ppl_st_ndcg(cm_set, thresh, n = None):
    """Lowest ``val/ppl`` among candidates whose ``test/nDCG`` exceeds ``thresh``.

    If no candidate is strictly above the threshold, the constraint is
    dropped and the lowest ``val/ppl`` over all candidates is used.

    Args:
        cm_set: DataFrame of candidates.
        thresh: Strict lower bound on "test/nDCG".
        n: Ignored; accepted so that all policies share one call signature.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    above_thresh = cm_set[cm_set["test/nDCG"] > thresh]
    # Fall back to the unconstrained pool when nothing passes the filter.
    pool = cm_set if len(above_thresh) == 0 else above_thresh
    chosen_row = cm_set.loc[pool["val/ppl"].idxmin()]
    return chosen_row["model"], chosen_row["test/ppl"]

def ndcg_st_cmip(cm_set, thresh, n = None):
    """Highest ``test/nDCG`` among candidates whose ``test/cmi`` is below ``thresh``.

    If no candidate is strictly below the threshold, the candidate with the
    lowest ``test/cmi`` overall is selected instead.

    Args:
        cm_set: DataFrame of candidates.
        thresh: Strict upper bound on "test/cmi".
        n: Ignored; accepted so that all policies share one call signature.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    below_thresh = cm_set[cm_set["test/cmi"] < thresh]

    if len(below_thresh) == 0:
        # Nothing satisfies the constraint: take the smallest test/cmi.
        chosen_idx = cm_set["test/cmi"].idxmin()
    else:
        chosen_idx = below_thresh["test/nDCG"].idxmax()

    chosen_row = cm_set.loc[chosen_idx]
    return chosen_row["model"], chosen_row["test/ppl"]


def ndcg_st_debiased(cm_set, n = None):
    """Highest ``test/nDCG`` among candidates whose ``test/pointwise_ci`` equals False.

    If no candidate has ``test/pointwise_ci == False``, the candidate with
    the lowest ``test/cmi`` overall is selected instead.

    Args:
        cm_set: DataFrame of candidates.
        n: Ignored; accepted so that all policies share one call signature.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    # Element-wise comparison against False is intentional (not `~`).
    is_debiased = cm_set["test/pointwise_ci"] == False
    debiased = cm_set[is_debiased]

    if len(debiased) == 0:
        chosen_idx = cm_set["test/cmi"].idxmin()
    else:
        chosen_idx = debiased["test/nDCG"].idxmax()

    chosen_row = cm_set.loc[chosen_idx]
    return chosen_row["model"], chosen_row["test/ppl"]

def ppl_st_debiased(cm_set, n = None):
    """Lowest ``val/ppl`` among candidates whose ``test/pointwise_ci`` equals False.

    If no candidate has ``test/pointwise_ci == False``, the candidate with
    the lowest ``test/cmi`` overall is selected instead.

    Args:
        cm_set: DataFrame of candidates.
        n: Ignored; accepted so that all policies share one call signature.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    # Element-wise comparison against False is intentional (not `~`).
    is_debiased = cm_set["test/pointwise_ci"] == False
    debiased = cm_set[is_debiased]

    if len(debiased) == 0:
        chosen_idx = cm_set["test/cmi"].idxmin()
    else:
        chosen_idx = debiased["val/ppl"].idxmin()

    chosen_row = cm_set.loc[chosen_idx]
    return chosen_row["model"], chosen_row["test/ppl"]

def ndcg_st_topncmip(cm_set, n):
    """Highest ``test/nDCG`` among the ``n`` candidates with the lowest ``test/cmi``.

    Args:
        cm_set: DataFrame of candidates.
        n: Size of the shortlist. Because ``keep='all'`` is used, candidates
            tied at the cutoff are all retained, so the shortlist can hold
            more than ``n`` rows.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    shortlist = cm_set.nsmallest(n, "test/cmi", keep = 'all')
    chosen_row = cm_set.loc[shortlist["test/nDCG"].idxmax()]
    return chosen_row["model"], chosen_row["test/ppl"]

def ppl_st_topncmip(cm_set, n):
    """Lowest ``val/ppl`` among the ``n`` candidates with the lowest ``test/cmi``.

    Args:
        cm_set: DataFrame of candidates.
        n: Size of the shortlist. Because ``keep='all'`` is used, candidates
            tied at the cutoff are all retained, so the shortlist can hold
            more than ``n`` rows.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    shortlist = cm_set.nsmallest(n, "test/cmi", keep = 'all')
    chosen_row = cm_set.loc[shortlist["val/ppl"].idxmin()]
    return chosen_row["model"], chosen_row["test/ppl"]

def ppl_st_topnndcg(cm_set, n):
    """Lowest ``val/ppl`` among the ``n`` candidates with the highest ``test/nDCG``.

    Args:
        cm_set: DataFrame of candidates.
        n: Size of the shortlist. Because ``keep='all'`` is used, candidates
            tied at the cutoff are all retained, so the shortlist can hold
            more than ``n`` rows.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    shortlist = cm_set.nlargest(n, "test/nDCG", keep = 'all')
    chosen_row = cm_set.loc[shortlist["val/ppl"].idxmin()]
    return chosen_row["model"], chosen_row["test/ppl"]

def ndcg_st_topnppl(cm_set, n):
    """Highest ``test/nDCG`` among the ``n`` candidates with the lowest ``val/ppl``.

    Args:
        cm_set: DataFrame of candidates.
        n: Size of the shortlist. Because ``keep='all'`` is used, candidates
            tied at the cutoff are all retained, so the shortlist can hold
            more than ``n`` rows.

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    shortlist = cm_set.nsmallest(n, "val/ppl", keep = 'all')
    chosen_row = cm_set.loc[shortlist["test/nDCG"].idxmax()]
    return chosen_row["model"], chosen_row["test/ppl"]

def ppl_st_topnndcg_topn_cmip(cm_set, n):
    """Lowest ``val/ppl`` among candidates in both the top-``n`` nDCG and top-``n`` CMI shortlists.

    The two shortlists (highest ``test/nDCG``, lowest ``test/cmi``; ties at
    the cutoff are kept) are intersected through an inner join on the row
    index. If the intersection is empty, the lowest ``val/ppl`` within the
    CMI shortlist is used instead.

    Args:
        cm_set: DataFrame of candidates.
        n: Size of each shortlist (before ties are added).

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    by_ndcg = cm_set.nlargest(n, "test/nDCG", keep = 'all')
    by_cmip = cm_set.nsmallest(n, "test/cmi", keep = 'all')
    # Both sides share every column name, so the join suffixes the left
    # copies with "_x" and the right copies with "_y".
    in_both = by_ndcg.merge(by_cmip, how='inner', left_index=True, right_index=True)

    if len(in_both) == 0:
        chosen_idx = by_cmip["val/ppl"].idxmin()
    else:
        chosen_idx = in_both["val/ppl_x"].idxmin()

    chosen_row = cm_set.loc[chosen_idx]
    return chosen_row["model"], chosen_row["test/ppl"]

def ndcg_st_topnppl_topn_cmip(cm_set, n):
    """Highest ``test/nDCG`` among candidates in both the top-``n`` PPL and top-``n`` CMI shortlists.

    The two shortlists (lowest ``val/ppl``, lowest ``test/cmi``; ties at the
    cutoff are kept) are intersected through an inner join on the row index.
    If the intersection is empty, the highest ``test/nDCG`` within the CMI
    shortlist is used instead.

    Args:
        cm_set: DataFrame of candidates.
        n: Size of each shortlist (before ties are added).

    Returns:
        Tuple ``(model, test/ppl)`` of the selected row.
    """
    by_ppl = cm_set.nsmallest(n, "val/ppl", keep = 'all')
    by_cmip = cm_set.nsmallest(n, "test/cmi", keep = 'all')
    # Both sides share every column name, so the join suffixes the left
    # copies with "_x" and the right copies with "_y".
    in_both = by_ppl.merge(by_cmip, how='inner', left_index=True, right_index=True)

    if len(in_both) == 0:
        chosen_idx = by_cmip["test/nDCG"].idxmax()
    else:
        chosen_idx = in_both["test/nDCG_x"].idxmax()

    chosen_row = cm_set.loc[chosen_idx]
    return chosen_row["model"], chosen_row["test/ppl"]

Now, we can compute the regret of each policy in each configuration:

In [79]:
from itertools import chain, combinations
def powerset(items, nc):
    """List every combination of ``items`` whose size appears in ``nc``.

    Despite the name, this is the full power set only when ``nc`` covers
    every size from 0 to ``len(items)``.

    Args:
        items: Iterable of elements to combine.
        nc: Iterable of subset sizes. Sizes are processed in the given order.

    Returns:
        List of tuples, grouped by size in the order of ``nc`` and, within
        one size, in the order produced by ``itertools.combinations``.
    """
    pool = list(items)
    subsets = []
    for size in nc:
        subsets.extend(combinations(pool, size))
    return subsets
# Values of the experimental factors to iterate over. They are matched
# against the "user model", "logging policy", "test policy" and "seed"
# columns of runs_df.
usermodels = ["GradedPBM", "GradedDBN", "MixtureDBN", "GradedCarousel"]
logpis = ["NoisyOraclePolicy", "LightGBMRanker"]
testpis = ["NoisyOraclePolicy", "LightGBMRanker", "UniformPolicy"]
seeds = [2023, 3901, 2837, 47969, 3791, 3807, 8963, 11289, 75656, 31277]

# Display label -> selection function. Every function is called with n = 4
# below, which is what "top-4" in the labels refers to.
ops_policies = {"↓PPL": ppl, "↑nDCG": ndcg, "↓CMIP": cmip, "↓PPL s.t. top-4 CMIP": ppl_st_topncmip, 
                "↑nDCG s.t. top-4 CMIP": ndcg_st_topncmip,"↓PPL s.t. top-4 nDCG": ppl_st_topnndcg, 
                "↑nDCG s.t. top-4 PPL": ndcg_st_topnppl, "↓PPL s.t. top-4 nDCG, top-4 CMIP": ppl_st_topnndcg_topn_cmip, 
                "↑nDCG s.t. top-4 PPL, top-4 CMIP": ndcg_st_topnppl_topn_cmip}
# regrets[policy label][seed][user model][logging policy] -> list of regrets.
# The lists are replaced by their (scaled) mean further down.
regrets = {ops_pi: {seed: {usermodel: {logpi: [] for logpi in logpis} for usermodel in usermodels} for seed in seeds} for ops_pi in ops_policies.keys()}

for usermodel in usermodels:
    for logpi in logpis:
        for testpi in testpis:
            # Only evaluate under a test policy that differs from the logging policy.
            if testpi == logpi:
                continue
            for seed in seeds:
                # Candidates for this configuration: the rows with metric
                # rank "-" that match all four factors.
                cm_set = runs_df[(runs_df["user model"] == usermodel)
                                 & (runs_df["logging policy"] == logpi)
                                 & (runs_df["test policy"] == testpi)
                                 & (runs_df["seed"] == seed)
                                 & (runs_df["metric rank"] == "-")
                                ]

                if len(cm_set) != 0:
                    # Evaluate every subset of 5, 6 or 7 candidates. With fewer
                    # than 5 candidates no subset exists and nothing is recorded.
                    for candidates in powerset(range(len(cm_set)), [5, 6, 7]):
                        cm_set_sample = cm_set.iloc[list(candidates)]
                        # Reference reward: lowest test/ppl within the sample.
                        oracle_model, oracle_reward = oracle(cm_set_sample.copy(deep = True))

                        for key, ops_pi in ops_policies.items():
                            _, reward = ops_pi(cm_set_sample, n = 4)

                            # Regret = test/ppl of the selected model minus the
                            # oracle's. The key has no test policy, so regrets
                            # from all test policies are pooled in one list.
                            regrets[key][seed][usermodel][logpi].append(reward - oracle_reward)

# Collapse each list to its mean, scaled by 1000.
# Note: a list that stayed empty makes np.mean return nan.
for ops_pi in ops_policies.keys():
    for seed in seeds:
        for usermodel in usermodels:
            for logpi in logpis:
                regrets[ops_pi][seed][usermodel][logpi] = np.mean(regrets[ops_pi][seed][usermodel][logpi]) * 1000

# Add marginal averages. The tab characters in the "\tTotal" / "Total\t" keys
# are reused as-is when the table is printed in the next cell.
for ops_pi in ops_policies.keys():   
    for seed in seeds:
        for usermodel in usermodels:
            # Mean over logging policies. It is computed before "\tTotal" is
            # inserted, so the dict holds only the per-logging-policy values.
            regrets[ops_pi][seed][usermodel]["\tTotal"] = np.mean(list(regrets[ops_pi][seed][usermodel].values()))
        for logpi in logpis:
            if "Total\t" not in regrets[ops_pi][seed]:
                regrets[ops_pi][seed]["Total\t"] = {}
            # Mean over user models for this logging policy.
            regrets[ops_pi][seed]["Total\t"][logpi] = np.mean([regrets[ops_pi][seed][um][logpi] for um in usermodels])

        # Grand mean over all (user model, logging policy) cells.
        regrets[ops_pi][seed]["Total\t"]["\tTotal"] = np.mean([[regrets[ops_pi][seed][um][logpi] for um in usermodels] for logpi in logpis])


In [80]:
# Print one table per OPS policy: rows are user models (plus a total row),
# columns are logging policies (plus a total column). Each cell shows the
# mean regret over seeds and the half-width of its 95% Student-t confidence
# interval.
for ops_pi in ops_policies:
    print("----------------------")
    print("Policy: ", ops_pi)
    header_cells = [lp + "\t" for lp in regrets[ops_pi][seeds[0]][usermodels[0]].keys()]
    print("\t\t", *header_cells)
    for um in usermodels + ["Total\t"]:
        print(um + "\t", end = "")
        for lp in logpis + ["\tTotal"]:
            vals = [regrets[ops_pi][seed][um][lp] for seed in seeds]
            mean_val = np.mean(vals)
            interval = st.t.interval(
                confidence=0.95,
                df=len(vals) - 1,
                loc=mean_val,
                scale=st.sem(vals),
            )
            # The interval is symmetric around the mean, so the upper bound
            # minus the mean is its half-width.
            print("%.4f (+- %.4f)\t" % (mean_val, interval[1] - mean_val), end = "")
        print("")

----------------------
Policy:  ↓PPL
		 NoisyOraclePolicy	 LightGBMRanker	 	Total	
GradedPBM	8.1150 (+- 0.2738)	7.5491 (+- 0.2494)	7.8321 (+- 0.2244)	
GradedDBN	0.2938 (+- 0.2331)	0.6354 (+- 0.0998)	0.4646 (+- 0.1234)	
MixtureDBN	0.4555 (+- 0.1445)	0.0075 (+- 0.0093)	0.2315 (+- 0.0742)	
GradedCarousel	3.2976 (+- 0.2125)	0.0452 (+- 0.0546)	1.6714 (+- 0.1117)	
Total		3.0405 (+- 0.0765)	2.0593 (+- 0.0651)	2.5499 (+- 0.0506)	
----------------------
Policy:  ↑nDCG
		 NoisyOraclePolicy	 LightGBMRanker	 	Total	
GradedPBM	14.8202 (+- 6.9371)	1.4929 (+- 0.4453)	8.1566 (+- 3.4893)	
GradedDBN	0.3566 (+- 0.2199)	0.2656 (+- 0.1969)	0.3111 (+- 0.1624)	
MixtureDBN	7.6390 (+- 0.1797)	1.6404 (+- 0.4649)	4.6397 (+- 0.3030)	
GradedCarousel	14.2398 (+- 0.0663)	1.4700 (+- 0.2988)	7.8549 (+- 0.1650)	
Total		9.2639 (+- 1.7220)	1.2172 (+- 0.1705)	5.2406 (+- 0.8480)	
----------------------
Policy:  ↓CMIP
		 NoisyOraclePolicy	 LightGBMRanker	 	Total	
GradedPBM	1.7935 (+- 0.6628)	2.3073 (+- 0.2853)	2.0504 (+- 0.