In [None]:
import os
from os.path import dirname

import torch
import numpy as np
import pandas as pd

from collections import defaultdict

# Device on which the torch random generators for the bootstrap cells are created.
tch_device = torch.device("cpu")
# Floating-point dtype setting; not referenced by the cells visible in this notebook section.
tch_dtype = torch.float32

In [None]:
def get_unique_mask(prtinv, prtcnt):
    """
        Generates a mask that is True at the first occurrence of every
        unique protein sample and False at all later duplicates.
        
        Parameters
        ----------
        prtinv: (torch.tensor) indicates where elements in the original 
            input ended up in the unique list (integer ids in `[0, n_unq)`).
            This has a shape of `(n_seedsamps,)`.
        
        prtcnt: (torch.tensor) The count of each unique protein, ordered by
            unique id. This has a shape of `(n_unq,)`.
    
    
        Outputs
        -------
        mask1d: (torch.tensor) dtype (torch.bool) `(n_seedsamps,)`, on the
            same device as `prtinv`. Exactly `n_unq` entries are True.
        
    """
    n_seedsamps, = prtinv.shape
    n_unq, = prtcnt.shape

    mask1d = torch.zeros(n_seedsamps, device=prtinv.device, dtype=torch.bool)
    if n_seedsamps == 0:
        # No samples, hence nothing to mark (the cumsum logic below would
        # otherwise build a length-1 offset vector for zero unique ids).
        return mask1d

    # A stable argsort of `prtinv` is needed so that, inside every group of
    # equal ids, the earliest original position comes first.
    # `torch.argsort(..., stable=True)` only exists from torch 1.13.0, so an
    # exact integer composite key (id, position) is sorted instead. The key is
    # kept in int64: a float key such as `prtinv + position / n_seedsamps`
    # loses the fractional part once ids get large and breaks the ordering.
    position = torch.arange(n_seedsamps, device=prtinv.device, dtype=torch.int64)
    sort_key = prtinv.to(torch.int64) * n_seedsamps + position
    prtinvind_sorted = torch.argsort(sort_key)
    assert prtinvind_sorted.shape == (n_seedsamps,)

    # Offset of every group's first element inside the sorted order:
    # an exclusive cumulative sum of the group sizes.
    prtcnt_cs = prtcnt.cumsum(dim=0)
    assert prtcnt_cs.shape == (n_unq,)
    prtcnt_cszp = torch.cat([prtcnt.new_zeros(1), prtcnt_cs[:-1]])
    assert prtcnt_cszp.shape == (n_unq,)

    # Original positions of the first occurrence of each unique id.
    mask_idx = prtinvind_sorted[prtcnt_cszp]
    assert mask_idx.shape == (n_unq,)
    mask1d[mask_idx] = True
    assert mask1d.shape == (n_seedsamps,)

    return mask1d

In [None]:
def BStrap(df, grp_cols, rng_vars, target_vars, n_boot, q_low=0.05, q_high=0.95, tch_rng=None):
    """
        Bootstraps the per-setting mean of the target variables over seeds.

        The rows of `df` are grouped by `grp_cols`; every group is one
        "setting" whose rows are the individual seeds. For each setting the
        seeds are resampled with replacement `n_boot` times, the targets are
        averaged inside every resample, and the mean and the
        `q_low`/`q_high` quantiles of these bootstrapped averages are reported.
        The same resampled rows are used for all target variables.

        Parameters
        ----------
        df: (pd.DataFrame) one row per (setting, seed).

        grp_cols: (list of str) columns identifying a setting. Not modified.

        rng_vars: (list of str) columns that are allowed to vary inside a
            setting (e.g. the seed).

        target_vars: (list of str) numeric columns to bootstrap.

        n_boot: (int) number of bootstrap resamples per setting.

        q_low, q_high: (float) quantiles of the bootstrapped averages.

        tch_rng: (torch.Generator or None) generator for the resampling
            indices.

        Outputs
        -------
        df_bs: (pd.DataFrame) one row per setting with the columns of
            `grp_cols` followed by `{var}_mean`, `{var}_low`, `{var}_high`
            for every `var` in `target_vars`.

        Raises
        ------
        ValueError: if `df` yields no group.
        AssertionError: if a column outside `rng_vars`/`target_vars` varies
            inside a setting, or settings have different numbers of rows.
    """
    # Work on copies so the caller's lists are never mutated.
    grp_cols = list(grp_cols)
    target_vars = list(target_vars)
    free_cols = set(rng_vars) | set(target_vars)

    n_seeds = None
    all_setts, all_seed_y = [], []

    n_trg = len(target_vars)
    groups = list(df.groupby(grp_cols))
    n_sett = len(groups)
    if n_sett == 0:
        raise ValueError("BStrap: no groups found for the given grp_cols")

    for (setting, df_sett) in groups:
        #####################################
        # check the sanity of the subset df #
        #####################################
        for col in df_sett.columns:
            if col not in free_cols:
                assert len(df_sett[col].unique()) == 1, \
                    f"column {col!r} varies inside setting {setting!r}"

        # The number of seeds of the first setting is the reference; every
        # other setting must have the same number of rows (checked below).
        if n_seeds is None:
            n_seeds = df_sett.shape[0]
        all_setts.append(setting)
        seed_y_arr = df_sett[target_vars].values
        seed_y_arr = seed_y_arr.astype(np.float32)
        assert seed_y_arr.shape == (n_seeds, n_trg)
        all_seed_y.append(seed_y_arr)

    sett_seed_y = np.stack(all_seed_y, axis=0)
    assert sett_seed_y.shape == (n_sett, n_seeds, n_trg)

    sett_seed_y_tnsr = torch.tensor(sett_seed_y)
    assert sett_seed_y_tnsr.shape == (n_sett, n_seeds, n_trg)

    #########################
    ### bootstrap indices ###
    #########################

    # The trailing singleton dim broadcasts over the targets, so one drawn
    # seed index selects the whole row of target values.
    ind_boot = torch.randint(n_seeds, size=(n_sett, n_boot*n_seeds, 1), generator=tch_rng)
    assert ind_boot.shape == (n_sett, n_boot*n_seeds, 1)
    sett_seed_y_boot = torch.take_along_dim(sett_seed_y_tnsr, ind_boot, dim=-2)
    assert sett_seed_y_boot.shape == (n_sett, n_boot*n_seeds, n_trg)
    sett_seed_y_boot = sett_seed_y_boot.reshape(n_sett, n_boot, n_seeds, n_trg)
    assert sett_seed_y_boot.shape == (n_sett, n_boot, n_seeds, n_trg)
    sett_y_boot = sett_seed_y_boot.mean(dim=-2) # average for each bootstrapped set
    assert sett_y_boot.shape == (n_sett, n_boot, n_trg)

    # confidence interval and the mean for the bootstrapped stat
    sett_y_low = sett_y_boot.quantile(q=q_low, dim=-2).detach().cpu().numpy()
    assert sett_y_low.shape == (n_sett, n_trg)
    sett_y_high = sett_y_boot.quantile(q=q_high, dim=-2).detach().cpu().numpy()
    assert sett_y_high.shape == (n_sett, n_trg)
    sett_y_mean = sett_y_boot.mean(dim=-2).detach().cpu().numpy()
    assert sett_y_mean.shape == (n_sett, n_trg)

    #####################################################
    ### dataframe of settings with bootstrapped stats ###
    #####################################################

    dfs = pd.DataFrame(all_setts, columns = grp_cols)
    dfl = pd.DataFrame(sett_y_low, columns = [f"{var}_low" for var in target_vars])
    dfh = pd.DataFrame(sett_y_high, columns = [f"{var}_high" for var in target_vars])
    dfm = pd.DataFrame(sett_y_mean, columns = [f"{var}_mean" for var in target_vars])

    df_bs = pd.concat([dfs, dfl, dfh, dfm], axis=1)
    assert df_bs.shape == (n_sett, len(grp_cols) + 3*n_trg)

    # Build the column order in a new list (never extend `grp_cols` itself).
    col_order = list(grp_cols)
    for col in target_vars:
        col_order += [f'{col}_mean', f'{col}_low', f'{col}_high']
    df_bs = df_bs[col_order]

    return df_bs

In [None]:
# Builds (or reloads from cache) `df_stats`: one row per result file, seed and
# mbo step, holding the quantiles of the property over the de-duplicated
# samples collected before that mbo step, plus their change relative to step 0.
data_types = ["pinn"]
# Maps the textual "distance" setting to an ordinal separation level
# (used by the bootstrap cells to fill the "sep_level" column).
sep_dict = {'(30, 40)': 0, '(40, 50)': 1, '(50, 60)': 2, '(60, 70)': 3, '(70, 80)': 4}
qs = torch.tensor([0.5, 0.7, 0.9, 1.]) # quantile of property to compute
res_cfg_dict = defaultdict(list)

raw_dpath = "../summary/pinn/01_raw.tsv"
#!rm {raw_dpath}
if not os.path.exists(raw_dpath):
    ######################
    # pgvae+ experiments #
    ######################
    mbo_steps = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
    # Duplicate detection runs on the GPU when one is present and falls back
    # to the CPU otherwise.
    unq_device = "cuda" if torch.cuda.is_available() else "cpu"
    for data_type in data_types:
        resdir = f"../results/{data_type}"
        # Sorted so the row order of `df_stats` (and hence the seeded
        # bootstrap downstream) does not depend on the filesystem.
        for fname in sorted(os.listdir(resdir)):
            if not fname.endswith(".pt"):
                continue
            print(fname)
            # NOTE(review): torch.load unpickles the file; only load result
            # files from a trusted source.
            res = torch.load(f"{resdir}/{fname}", map_location="cpu")
            x_org, y_org = res["x"], res["y"]
            step, cfg_it = res["step"], res["cfg_it"]
            method_name = res["method_name"]
            n_samples_gen = res["n_samples_gen"]
            weighted_opt_firststep = res["weighted_opt_firststep"]
            orc_spec = res["orc_spec"]
            imb_ratio = orc_spec["ro"]
            tr_size = orc_spec["N"]
            high_perc_range = "fix"
            distance = orc_spec["high_perc_range"]

            for mbo_step in mbo_steps:
                ###############################################
                # get the samples upto the specified mbo step #
                ###############################################
                ns, xs, *xd = x_org.shape
                mask = step < mbo_step
                assert mask.shape == (ns, xs)
                assert step.shape == (ns, xs)
                mask_re = mask.reshape((ns*xs,))

                # The reshapes back to (ns, -1) rely on every seed keeping
                # the same number of samples after filtering.
                step_re = step.reshape((ns*xs, ))
                step_re = step_re[mask_re]
                step_re = step_re.reshape((ns, -1))

                # filter x
                x_re = x_org.reshape((ns*xs, *xd))
                y_re = y_org.reshape((ns*xs,))
                x_re, y_re = x_re[mask_re], y_re[mask_re]
                x = x_re.reshape((ns, -1, *xd))
                y = y_re.reshape((ns, -1))
                assert x.shape[1] == step_re.shape[1]

                #####################
                # get the quantiles #
                #####################
                ns, xs, *xd = x.shape

                #########################
                ### Remove duplicates ###
                #########################
                assert y.shape == (ns, xs)
                x_dev = x.to(unq_device)
                # NOTE(review): `apply_unique` is not defined in the visible
                # cells; it is presumably a torch.unique-style helper
                # returning (inverse indices, counts) -- confirm it is in scope.
                prtinv, prtcnt = apply_unique(x_dev, x_dev.device) #
                mask1d = get_unique_mask(prtinv, prtcnt)
                assert mask1d.shape == (ns*xs,)
                mask1d_re = mask1d.reshape((ns, xs))
                assert mask1d_re.shape == (ns, xs)
                mask1d_re = mask1d_re.to("cpu")

                # Duplicates are set to NaN so nanquantile ignores them.
                # Assigning through the boolean mask leaves genuine zero
                # values of y untouched (no sentinel value is needed).
                y_masked = y.detach().clone()
                y_masked[~mask1d_re] = float('nan')
                assert y_masked.shape == (ns, xs)


                #######################################################
                ### get quantiles for all samples upto the mbo step ###
                #######################################################

                yqs = torch.nanquantile(y_masked, qs, dim=-1)
                assert yqs.shape == (qs.numel(), ns)
                yqs = yqs.transpose(0, 1)
                assert yqs.shape == (ns, qs.numel())
                yqs_np = yqs.detach().numpy()

                for nss in range(ns):
                    # accumulating the stats for different runs
                    res_cfg_dict["cfg_it"].append(cfg_it)
                    res_cfg_dict["data_type"].append(data_type)
                    res_cfg_dict["method_name"].append(method_name)
                    res_cfg_dict["wopt_fst_step"].append(weighted_opt_firststep)
                    res_cfg_dict["mbo_step"].append(mbo_step)
                    res_cfg_dict["hrange"].append(high_perc_range)
                    res_cfg_dict["distance"].append(distance)
                    res_cfg_dict["imb_ratio"].append(imb_ratio)
                    res_cfg_dict["n_samples_gen"].append(n_samples_gen)
                    res_cfg_dict["tr_size"].append(tr_size)
                    res_cfg_dict["fname"].append(fname)
                    res_cfg_dict["seed"].append(nss)

                    for qid, q in enumerate(qs):
                        res_cfg_dict[f"y_{int(100*q)}"].append(yqs_np[nss, qid])
    
    df_stats = pd.DataFrame.from_dict(res_cfg_dict, orient="index").transpose()
    # df_stats: per setting, mbo_step and seed report the target variables (percentiles of property)
    # The weighted-first-step flag becomes a "/wopt" suffix of the method name.
    df_stats["wopt_fst_step"] = df_stats["wopt_fst_step"].fillna(False)
    df_stats["wopt_fst_step"] = df_stats["wopt_fst_step"].apply(lambda x: "/wopt" if x else "")
    df_stats["method_name"] = df_stats["method_name"] + df_stats["wopt_fst_step"]
    
    ####################################
    ### Compute relative performance ###
    ####################################
    
    sett_vars = ["data_type", "method_name", "wopt_fst_step", "hrange", "distance", "imb_ratio", "n_samples_gen", "seed"]
    groups = list(df_stats.groupby(sett_vars))
    target_vars = [f"y_{int(100*q)}" for q in qs]

    df_lst = []
    for i, (grp_sett, df_grp) in enumerate(groups):
        # every (setting, seed) must contain each mbo step exactly once
        assert len(df_grp["mbo_step"]) == len(df_grp["mbo_step"].unique())
        df_grp2 = df_grp.copy(deep=True)
        for target_var in target_vars:
            # improvement relative to the value at mbo step 0
            y0 = df_grp[df_grp["mbo_step"] == 0][target_var].values.item()
            df_grp2[f"d{target_var}"] = df_grp2[target_var] - y0
        df_lst.append(df_grp2)
    df_stats = pd.concat(df_lst, axis=0, ignore_index=True)
    
    os.makedirs(dirname(raw_dpath), exist_ok=True)
    df_stats.to_csv(raw_dpath, sep="\t", index=False)
else:
    print("loading stats from file")
    # Empty "wopt_fst_step" strings are written as empty fields and parsed
    # back as NaN; groupby would silently drop those rows, so the empty
    # strings are restored to match the freshly computed frame.
    df_stats = pd.read_csv(raw_dpath, sep="\t", dtype={"wopt_fst_step": str})
    df_stats["wopt_fst_step"] = df_stats["wopt_fst_step"].fillna("")

In [None]:
#########################################################
### Compute the confidence interval via bootstrapping ###
#########################################################
# Bootstraps over seeds separately for every setting (including the imbalance
# ratio) and mbo step; the result is cached as a parquet file.

bts_dpath = "../summary/pinn/02_bts.prq"
# !rm {bts_dpath}
if not os.path.exists(bts_dpath):
    # set of variables identifying a setting
    sett_vars = ["data_type", "method_name", "wopt_fst_step", "imb_ratio", "hrange", "distance", "n_samples_gen"]
    # target variables to compute the confidence interval for
    target_vars = [f"y_{int(100*q)}" for q in qs] + [f"dy_{int(100*q)}" for q in qs]
    # columns allowed to vary inside a setting (the resampled units)
    rng_vars = ["seed"]
    step_vars = ["mbo_step"]

    ### bootstrap settings ###
    n_boot = 1000
    q_low, q_high = 0.05, 0.95
    grp_cols = sett_vars + step_vars

    # fixed seed so the resampling is reproducible
    tch_seed = 1234567891011
    tch_rng = torch.Generator(device=tch_device)
    tch_rng.manual_seed(tch_seed)

    # pass the configured quantiles instead of repeating the literals
    df_bs = BStrap(df_stats, grp_cols, rng_vars, target_vars, n_boot, q_low=q_low, q_high=q_high, tch_rng=tch_rng)
    # define the separation level
    df_bs["sep_level"] = df_bs["distance"].apply(lambda x: sep_dict[x])

    os.makedirs(dirname(bts_dpath), exist_ok=True)
    df_bs.to_parquet(bts_dpath)
else:
    print("loading form file")
    df_bs = pd.read_parquet(bts_dpath)

In [None]:
######### Bootstrapping for imb ratios combined #########
#########################################################
### Compute the confidence interval via bootstrapping ###
#########################################################
# Same bootstrap as the per-setting one, but "imb_ratio" is moved from the
# setting variables to the resampled variables, so rows from all imbalance
# ratios are pooled inside one setting.

btsg_dpath = "../summary/pinn/03_btsg.prq"
# !rm {btsg_dpath}
if not os.path.exists(btsg_dpath):
    # set of variables identifying a setting
    sett_vars = ["data_type", "method_name", "wopt_fst_step", "hrange", "distance", "n_samples_gen"]
    # target variables to compute the confidence interval for
    target_vars = [f"y_{int(100*q)}" for q in qs] + [f"dy_{int(100*q)}" for q in qs]
    # columns allowed to vary inside a setting (the resampled units)
    rng_vars = ["seed", "imb_ratio"]
    step_vars = ["mbo_step"]

    ### bootstrap settings ###
    n_boot = 1000
    q_low, q_high = 0.05, 0.95
    grp_cols = sett_vars + step_vars

    # fixed seed so the resampling is reproducible
    tch_seed = 1234567891011
    tch_rng = torch.Generator(device=tch_device)
    tch_rng.manual_seed(tch_seed)

    # exclude these columns because it changes from one setting to another (imb_ratio is included in seeds)
    df_stats_ = df_stats.drop(["cfg_it", "fname", "tr_size"], axis=1)
    # pass the configured quantiles instead of repeating the literals
    df_bsg = BStrap(df_stats_, grp_cols, rng_vars, target_vars, n_boot, q_low=q_low, q_high=q_high, tch_rng=tch_rng)
    # define the separation level
    df_bsg["sep_level"] = df_bsg["distance"].apply(lambda x: sep_dict[x])
    
    os.makedirs(dirname(btsg_dpath), exist_ok=True)
    df_bsg.to_parquet(btsg_dpath)
else:
    print("loading form file")
    df_bsg = pd.read_parquet(btsg_dpath)