## Train set and Oracle generation

For each benchmark task, generate train sets at varying imbalance ratios and separation levels, together with the corresponding oracles. \
Note that for PINN the oracle is read from a file, and for GMM
the oracle is constructed from its parameters, which are saved with the train set. \
For the semi-synthetic protein datasets, one oracle is created for each length of the appended sequence.

In [None]:
import os
import torch
import numpy as np
import pandas as pd

import itertools
import collections
from collections import defaultdict

import matplotlib.pyplot as plt

# Global seed shared by the cells below (numpy and torch are both seeded).
seed_rng = 12345
np.random.seed(seed_rng)
torch.manual_seed(seed_rng)

In [None]:
# taken from ECNet code: https://github.com/luoyunan/ECNet
from Bio import SeqIO
def _read_native_sequence(data_dir):
    '''
    Read `{data_dir}/native_sequence.fasta` with Bio.SeqIO.read (FASTA
    format) and return the record's sequence as a plain `str`.
    '''
    fasta_path = f'{data_dir}/native_sequence.fasta'
    return str(SeqIO.read(fasta_path, 'fasta').seq)

def _mutation_to_sequence(mutation, native_sequence):
    '''
    Parameters
    ----------
    mutation: ';'.join(WiM) (wide-type W at position i mutated to M)
    '''
    sequence = native_sequence
    mut_positions = []
    
    splitter = ";" if (";" in mutation) else ":"
    for mut in mutation.split(splitter):
        wt_aa = mut[0]
        mt_aa = mut[-1]
        pos = int(mut[1:-1])
        assert wt_aa == sequence[pos - 1],\
                "%s: %s->%s (fasta WT: %s)"%(pos, wt_aa, mt_aa, sequence[pos - 1])
        sequence = sequence[:(pos - 1)] + mt_aa + sequence[pos:]
        mut_positions.append(str(pos - 1))
    mut_positions = ",".join(mut_positions)
    return sequence, mut_positions

def _drop_invalid_mutation(df, native_sequence):
    '''
    Drop mutations WiM where
    - W is incosistent with the i-th AA in native_sequence
    - M is ambiguous, e.g., 'X'
    '''
    flags = []
    for mutation in df['mutation'].values:
        splitter = ";" if (";" in mutation) else ":"
        for mut in mutation.split(splitter):
            wt_aa = mut[0]
            mt_aa = mut[-1]
            pos = int(mut[1:-1])
            valid = True if wt_aa == native_sequence[pos - 1] else False
            valid = valid and (mt_aa not in ['X'])
        flags.append(valid)
    df = df[flags].reset_index(drop=True)
    return df

### Protein GB1_synth:
The length of the appended sequence controls the separation level. \
The length is controlled by the variable `ext_len`.

In [None]:
# Semi-synthetic GB1 benchmark: one oracle + one family of train sets per
# prefix length `ext_len`.
#
# Variants with normalized score < split_thr ("low") each get their own random
# prefix of length ext_len whose last residue is drawn from aas[:-1], so it is
# never aas[-1]. All other variants ("high") share a single random prefix whose
# last residue is forced to aas[-1]. Longer prefixes therefore make the two
# ends easier to separate.

CCMPRED_AMINO_ACID_INDEX = collections.OrderedDict(
    {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4,
     'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9,
     'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14,
     'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, '-': 20})
aas = list(CCMPRED_AMINO_ACID_INDEX.keys())[:-1]  # excluding '-'
nmax = len(aas)
aas_arr = np.array(aas)

data_type = "protein"
oracle_dir = "../oracles"
os.makedirs(oracle_dir, exist_ok=True)

########################
### Build the Oracle ###
########################

# Loaded once for all ext_lens. .copy() makes df_base an independent frame,
# so the column assignment below does not raise SettingWithCopyWarning.
orc_dir = "../datasets/GB1.txt"
df_base = pd.read_csv(orc_dir, sep="\t")[["Variants", "Fitness"]].copy()
df_base.columns = ["sequence", "score"]
df_base["score"] = df_base["score"]/df_base["score"].max()  # scale to max 1

split_thr = 0.001  # 0.01

# train-set grid (identical for every ext_len)
ros = [0.0125, 0.025, 0.05, 0.1, 0.2, 0.4, 0.8]
Ntrs = [1000]
high_end_ranges = [(0.1, 0.2)]
low_end_ranges = [(0.0001, 0.001)]
cfg_lists = [ros, Ntrs, high_end_ranges, low_end_ranges]

ext_lens = [3, 4, 5, 8]
for ext_len in ext_lens:
    # reseed per length so each dataset is reproducible on its own
    seed_rng = 12345
    np.random.seed(seed_rng)

    ds_name = f"protein_gb_synth_l{ext_len}"
    savedir = f"../sample_trainset/{ds_name}"
    os.makedirs(savedir, exist_ok=True)

    ###########################
    ##### concat the seqs #####
    ###########################

    # explicit copies: the "sequence" column is overwritten below
    df_orc_low = df_base.loc[df_base["score"] < split_thr].copy()
    df_orc_high = df_base.loc[df_base["score"] >= split_thr].copy()

    # low end: an independent random prefix per sequence
    slow = df_orc_low.shape[0]
    inds = np.random.randint(nmax, size=slow*(ext_len-1))
    inds_end = np.random.randint(nmax-1, size=slow)  # exclude the last aa for the last position
    seq_ext_arr = np.concatenate(
        (aas_arr[inds].reshape(slow, ext_len-1), aas_arr[inds_end].reshape(slow, 1)),
        axis=-1)
    assert seq_ext_arr.shape == (slow, ext_len)
    seq_ext_low = pd.Series(["".join(row) for row in seq_ext_arr], index=df_orc_low.index)
    df_orc_low["sequence"] = seq_ext_low + df_orc_low["sequence"]

    # high end: one shared prefix that always ends in aas[-1]
    inds = np.random.randint(nmax, size=ext_len)
    seq_ext_high = "".join(aas_arr[inds])[:-1] + aas[-1]
    print(seq_ext_high)
    df_orc_high["sequence"] = seq_ext_high + df_orc_high["sequence"]
    df_orc = pd.concat((df_orc_low, df_orc_high))

    # save the oracle
    orc_savedir = f"{oracle_dir}/{ds_name}"
    df_orc.to_csv(orc_savedir, sep="\t", index=False)

    #######################################################
    # train set generation for different imbalance ratios #
    #######################################################

    scores = df_orc["score"]
    orc_specs = defaultdict(list)
    for cfg_it, cfg in enumerate(itertools.product(*cfg_lists)):
        ro, Ntr, high_end_range, low_end_range = cfg
        print(f"cfg is {cfg}")

        h1_end, h2_end = high_end_range
        l1_end, l2_end = low_end_range
        trs_low = Ntr
        trs_high = int(Ntr * ro)
        N = int(Ntr*(1+ro))

        # np.where yields positional indices, hence .iloc below
        idx_low = np.where((l1_end <= scores) & (scores < l2_end))[0]
        rand_idx_low = np.random.choice(idx_low, size=trs_low, replace=False)

        idx_high = np.where((h1_end <= scores) & (scores <= h2_end))[0]
        rand_idx_high = np.random.choice(idx_high, size=trs_high, replace=False)

        df = pd.concat((df_orc.iloc[rand_idx_low, :], df_orc.iloc[rand_idx_high, :]))

        orc_spec = dict(cfg_it=cfg_it, ro=ro, high_end_range=high_end_range,
                        low_end_range=low_end_range, data_type=data_type,
                        N=N, orc_path=orc_savedir)
        np.savez(f"{savedir}/ds{cfg_it}", x=df["sequence"].values,
                 y=df["score"].values, orc_spec=orc_spec)

        # to get all configs in one file
        orc_specs["cfg_it"].append(cfg_it)
        orc_specs["ro"].append(ro)
        orc_specs["high_end_range"].append(high_end_range)
        orc_specs["data_type"].append(data_type)
        orc_specs["low_end_range"].append(low_end_range)
        orc_specs["N"].append(N)

    # write the config table once, after all train sets for this ext_len
    cfg_all = pd.DataFrame.from_dict(orc_specs, orient='index').transpose()
    cfg_all.columns = ["cfg_it", "ro", "high_end_range", "data_type", "low_end_range", "N"]
    cfg_all.to_csv(f"{savedir}/cfgs", sep="\t")

### Protein PhoQ_synth:
The length of the appended sequence controls the separation level. \
The length is controlled by the variable `ext_len`.

In [None]:
# Semi-synthetic PhoQ benchmark: one oracle + one family of train sets per
# prefix length `ext_len`.
#
# Variants with normalized score < split_thr ("low") each get their own random
# prefix of length ext_len whose last residue is drawn from aas[:-1], so it is
# never aas[-1]. All other variants ("high") share a single random prefix whose
# last residue is forced to aas[-1]. Longer prefixes therefore make the two
# ends easier to separate.

CCMPRED_AMINO_ACID_INDEX = collections.OrderedDict(
    {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4,
     'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9,
     'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14,
     'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, '-': 20})
aas = list(CCMPRED_AMINO_ACID_INDEX.keys())[:-1]  # excluding '-'
nmax = len(aas)
aas_arr = np.array(aas)

data_type = "protein"
# no trailing slash: keeps orc_path free of "//" (matches the GB1 cell)
oracle_dir = "../oracles"
os.makedirs(oracle_dir, exist_ok=True)

########################
### Build the Oracle ###
########################

# Loaded once for all ext_lens. .copy() makes df_base an independent frame,
# so the column assignment below does not raise SettingWithCopyWarning.
orc_dir = "../datasets/PhoQ.txt"
df_base = pd.read_csv(orc_dir, sep="\t")[["Variants", "Fitness"]].copy()
df_base.columns = ["sequence", "score"]
df_base["score"] = df_base["score"]/df_base["score"].max()  # scale to max 1

split_thr = 0.001  # 0.01

# train-set grid (identical for every ext_len)
ros = [0.0125, 0.025, 0.05, 0.1, 0.2, 0.4, 0.8]
Ntrs = [1000]
high_end_ranges = [(0.1, 0.2)]
low_end_ranges = [(0.0001, 0.001)]
cfg_lists = [ros, Ntrs, high_end_ranges, low_end_ranges]

ext_lens = [3, 4, 6, 8]
for ext_len in ext_lens:
    # reseed per length so each dataset is reproducible on its own
    seed_rng = 12345
    np.random.seed(seed_rng)

    ds_name = f"protein_phoq_synth_l{ext_len}"
    savedir = f"../sample_trainset/{ds_name}"
    os.makedirs(savedir, exist_ok=True)

    ###########################
    ##### concat the seqs #####
    ###########################

    # explicit copies: the "sequence" column is overwritten below
    df_orc_low = df_base.loc[df_base["score"] < split_thr].copy()
    df_orc_high = df_base.loc[df_base["score"] >= split_thr].copy()

    # low end: an independent random prefix per sequence
    slow = df_orc_low.shape[0]
    inds = np.random.randint(nmax, size=slow*(ext_len-1))
    inds_end = np.random.randint(nmax-1, size=slow)  # exclude the last aa for the last position
    seq_ext_arr = np.concatenate(
        (aas_arr[inds].reshape(slow, ext_len-1), aas_arr[inds_end].reshape(slow, 1)),
        axis=-1)
    assert seq_ext_arr.shape == (slow, ext_len)
    seq_ext_low = pd.Series(["".join(row) for row in seq_ext_arr], index=df_orc_low.index)
    df_orc_low["sequence"] = seq_ext_low + df_orc_low["sequence"]

    # high end: one shared prefix that always ends in aas[-1]
    inds = np.random.randint(nmax, size=ext_len)
    seq_ext_high = "".join(aas_arr[inds])[:-1] + aas[-1]
    print(seq_ext_high)
    df_orc_high["sequence"] = seq_ext_high + df_orc_high["sequence"]
    df_orc = pd.concat((df_orc_low, df_orc_high))

    # save the oracle
    orc_savedir = f"{oracle_dir}/{ds_name}"
    df_orc.to_csv(orc_savedir, sep="\t", index=False)

    #######################################################
    # train set generation for different imbalance ratios #
    #######################################################

    scores = df_orc["score"]
    orc_specs = defaultdict(list)
    for cfg_it, cfg in enumerate(itertools.product(*cfg_lists)):
        ro, Ntr, high_end_range, low_end_range = cfg
        print(f"cfg is {cfg}")

        h1_end, h2_end = high_end_range
        l1_end, l2_end = low_end_range
        trs_low = Ntr
        trs_high = int(Ntr * ro)
        N = int(Ntr*(1+ro))

        # np.where yields positional indices, hence .iloc below
        idx_low = np.where((l1_end <= scores) & (scores < l2_end))[0]
        rand_idx_low = np.random.choice(idx_low, size=trs_low, replace=False)

        idx_high = np.where((h1_end <= scores) & (scores <= h2_end))[0]
        rand_idx_high = np.random.choice(idx_high, size=trs_high, replace=False)

        df = pd.concat((df_orc.iloc[rand_idx_low, :], df_orc.iloc[rand_idx_high, :]))

        orc_spec = dict(cfg_it=cfg_it, ro=ro, high_end_range=high_end_range,
                        low_end_range=low_end_range, data_type=data_type,
                        N=N, orc_path=orc_savedir)
        np.savez(f"{savedir}/ds{cfg_it}", x=df["sequence"].values,
                 y=df["score"].values, orc_spec=orc_spec)

        # to get all configs in one file
        orc_specs["cfg_it"].append(cfg_it)
        orc_specs["ro"].append(ro)
        orc_specs["high_end_range"].append(high_end_range)
        orc_specs["data_type"].append(data_type)
        orc_specs["low_end_range"].append(low_end_range)
        orc_specs["N"].append(N)

    # write the config table once, after all train sets for this ext_len
    cfg_all = pd.DataFrame.from_dict(orc_specs, orient='index').transpose()
    cfg_all.columns = ["cfg_it", "ro", "high_end_range", "data_type", "low_end_range", "N"]
    cfg_all.to_csv(f"{savedir}/cfgs", sep="\t")

### Protein AAV Virus:
paper ref: https://www.nature.com/articles/s41587-020-00793-4

In [None]:
# Seed numpy for the AAV train-set sampling done in the following cells.
rng_seed = 12345
np.random.seed(rng_seed)

data_type, ds_name = "protein", "protein_aav"
savedir = f"../sample_trainset/{ds_name}"

########################
### Build the Oracle ###
########################

oracle_dir = "../oracles/"
# Raw data downloaded from the FLIP benchmark:
# https://github.com/J-SNACKKB/FLIP/tree/main/splits/aav
# https://www.biorxiv.org/content/10.1101/2021.11.09.467890v2.full.pdf
orc_dir = "../datasets/aav.csv"

os.makedirs(savedir, exist_ok=True)
os.makedirs(oracle_dir, exist_ok=True)
df_aav = pd.read_csv(orc_dir, sep=",")

In [None]:
# Clean the raw AAV table:
# - drop rows missing any of the listed fields,
# - keep only mutants whose mutated region has the reference region's length,
# - drop mutated regions that contain a '*' character.
df_aav = df_aav.dropna(subset=["number_of_mutations", "score", "full_aa_sequence", "mutation_mask"])
df_aav["full_len"] = df_aav["full_aa_sequence"].str.len()
# "mutated_region" is not in the dropna subset. .str.len() gives NaN (instead
# of raising TypeError) for missing values, and those rows then fail the
# length test below.
df_aav["mut_len"] = df_aav["mutated_region"].str.len()
ref_len = len(df_aav["reference_region"].values[0])
# subset to the mutations that do not change the length
df_aav = df_aav[df_aav["mut_len"] == ref_len]
assert np.all(df_aav["number_of_mutations"].values < 29)
# Literal match: "\*" was an invalid escape sequence in a non-raw string.
df_aav = df_aav[~df_aav["mutated_region"].str.contains("*", regex=False)]

In [None]:
# make the oracle which contains both designed and sampled mutants
# Make the oracle, which contains both designed and sampled mutants. Scores are
# min-max rescaled to [0, 1]. min_/max_ stay global because later cells use
# them to locate raw score 0 on the rescaled axis.
# .copy(): the score column is reassigned below, and assigning into a column
# selection would raise SettingWithCopyWarning.
df_orc = df_aav[["mutated_region", "score", "number_of_mutations", "mut_des_split"]].copy()
df_orc.columns = ["sequence", "score", "number_of_mutations", "mut_des_split"]
min_, max_ = df_orc["score"].min(), df_orc["score"].max()
df_orc["score"] = (df_orc["score"] - min_)/(max_ - min_)

# Save the oracle. os.path.join avoids the "//" that "../oracles/" + "/"
# would produce.
orc_savedir = os.path.join(oracle_dir, ds_name)
df_orc_ = df_orc[["sequence", "score"]].reset_index(drop=True)
df_orc_.to_csv(orc_savedir, sep="\t", index=False)

In [None]:
# Optional diagnostics for the sampled ("train") split: set to True to plot
# score vs. number of mutations, and mutation-count histograms for the
# positive / non-positive raw-score groups.
PLOT_AAV_DIAGNOSTICS = False
if PLOT_AAV_DIAGNOSTICS:
    df_ = df_orc[df_orc["mut_des_split"] == "train"]  # mut: sampled (train), des: designed (test)
    fig, ax = plt.subplots()
    hh = ax.hist2d(x=df_["score"], y=df_["number_of_mutations"], bins=10)
    fig.colorbar(hh[3], ax=ax)

    # raw score 0 in df_aav maps to -min_/(max_ - min_) in df_orc
    zero_score = -min_/(max_ - min_)
    df_pos = df_[df_["score"] > zero_score]
    df_neg = df_[df_["score"] <= zero_score]

    plt.figure()
    plt.hist(df_pos[["number_of_mutations"]].values)
    plt.title("score > 0")

    plt.figure()
    plt.hist(df_neg[["number_of_mutations"]].values)
    plt.title("score <= 0")  # matches the <= used for df_neg

In [None]:
# Raw score 0 in df_aav corresponds to -min_/(max_ - min_) after rescaling.
zero_score = -min_/(max_ - min_)

# Split the sampled mutants (mut: sampled = "train", des: designed = "test")
# by raw score sign.
df_ = df_orc[df_orc["mut_des_split"] == "train"]
df_pos = df_[df_["score"] > zero_score]
df_neg = df_[df_["score"] <= zero_score]

# Config grid: imbalance ratio x negatives per set x positive percentile band
# x minimum mutation count of the negatives.
Ntrs = [1000]
ros = [0.0125, 0.025, 0.05, 0.1, 0.2, 0.4, 0.8]
high_end_percs = [(5, 10)]
mut_thrs = [10, 8, 6]
cfg_lists = [ros, Ntrs, high_end_percs, mut_thrs]
cfg_lists_add = [ros, Ntrs, high_end_percs, [15, 12]]
cfgs_add = list(itertools.product(*cfg_lists_add))
cfgs = list(itertools.product(*cfg_lists)) + cfgs_add

In [None]:
# Build one train set per config:
# - Ntr negatives (raw score <= 0) with more than mut_thr mutations,
# - int(Ntr*ro) positives whose rescaled score lies in the
#   [h1_perc, h2_perc) percentile band of all positive scores.
score_pos = df_pos["score"]  # loop-invariant
orc_specs = defaultdict(list)
for cfg_it, cfg in enumerate(cfgs):
    ro, Ntr, high_end_perc, mut_thr = cfg
    print(f"cfg is {cfg}")
    tr_high = int(Ntr*ro)
    N = int(Ntr*(1+ro))

    df_neg_tmp = df_neg[df_neg["number_of_mutations"] > mut_thr]

    h1_perc, h2_perc = high_end_perc
    h1_score, h2_score = np.percentile(score_pos, h1_perc), np.percentile(score_pos, h2_perc)
    # positional indices -> .iloc below
    idx_high = np.where((score_pos >= h1_score) & (score_pos < h2_score))[0]

    # take sample from the high/pos section
    rand_idx_high = np.random.choice(idx_high, size=tr_high, replace=False)
    df_pos_samp = df_pos.iloc[rand_idx_high, :]

    # take sample from the low/neg region
    rand_idx_low = np.random.choice(np.arange(df_neg_tmp.shape[0]), size=Ntr, replace=False)
    df_neg_samp = df_neg_tmp.iloc[rand_idx_low, :]

    df_all = pd.concat((df_pos_samp, df_neg_samp), axis=0).reset_index(drop=True)

    orc_spec = dict(cfg_it=cfg_it, ro=ro, high_end_perc=high_end_perc,
                    mut_thr=mut_thr, data_type=data_type,
                    N=N, orc_path=orc_savedir)
    np.savez(f"{savedir}/ds{cfg_it}", x=df_all["sequence"].values,
             y=df_all["score"].values, orc_spec=orc_spec)

    # to get all configs in one file
    orc_specs["cfg_it"].append(cfg_it)
    orc_specs["ro"].append(ro)
    orc_specs["high_end_perc"].append(high_end_perc)
    orc_specs["data_type"].append(data_type)
    orc_specs["mut_thr"].append(mut_thr)
    orc_specs["N"].append(N)

# write the config table once, after all train sets are generated
cfg_all = pd.DataFrame.from_dict(orc_specs, orient='index').transpose()
cfg_all.columns = ["cfg_it", "ro", "high_end_perc", "data_type", "mut_thr", "N"]
cfg_all.to_csv(f"{savedir}/cfgs", sep="\t")

### GMM:

In [None]:
# gmm oracle
def oracle_gmm(x, mus, sigmas, weights):
    """Evaluate a weighted sum of unnormalized, axis-aligned Gaussian bumps.

    y[s, i] = sum_k weights[k] * exp(-0.5 * sum_d ((x[s, i, d] - mus[k, d]) / sigmas[k, d])**2)

    Parameters
    ----------
    x : torch.Tensor, shape (n_seeds, xs, d)
    mus, sigmas : torch.Tensor, shape (n_gmm, d)
    weights : torch.Tensor, shape (n_gmm,)

    Returns
    -------
    torch.Tensor, shape (n_seeds, xs)

    The bumps are not normalized, so component k peaks at weights[k]. Exactly
    one trailing feature axis is supported: the shape check after the
    feature-sum raises AssertionError otherwise.
    """
    n_seeds, xs, *feat = x.shape
    n_gmm = mus.shape[0]
    assert mus.shape == (n_gmm, *feat)
    assert sigmas.shape == (n_gmm, *feat)
    assert weights.shape == (n_gmm,)

    # Add a component axis so every point meets every mean:
    # x -> (n_seeds, xs, 1, *feat); mus/sigmas -> (1, 1, n_gmm, *feat).
    component_shape = (1, 1, n_gmm, *feat)
    z = (x.reshape((n_seeds, xs, 1, *feat)) - mus.reshape(component_shape)) \
        / sigmas.reshape(component_shape)
    sq_dist = torch.sum(z**2, dim=-1)
    assert sq_dist.shape == (n_seeds, xs, n_gmm)

    bumps = weights.reshape((1, 1, n_gmm)) * torch.exp(-0.5*sq_dist)
    return bumps.sum(dim=-1)

In [None]:
# Reset both RNGs so the GMM train sets are reproducible on their own.
seed_rng = 12345
np.random.seed(seed_rng)
torch.manual_seed(seed_rng)

data_type = "gmm"
ds_name = data_type
savedir = f"../sample_trainset/{ds_name}"
os.makedirs(savedir, exist_ok=True)

# Sweep: mean of the second GMM component x imbalance ratio.
mu_2nds = [4, 6, 8, 10, 12]  # list(np.arange(5, 51, 10))
ros = [0.05, 0.1, 0.2, 0.4, 0.8, 1.]
cfg_lists = [mu_2nds, ros]
cfgs = itertools.product(*cfg_lists)

In [None]:
# For each (mu_2nd, ro):
# - define a 2-component 1-D GMM oracle: weights 1 and 2.5, sigmas 0.25 and 1,
#   means mu_1st and mu_2nd;
# - sample N1 inputs around mu_1st and N2 = int(N1*ro) inputs in
#   [mu_2nd + 0.5, mu_2nd + 1);
# - evaluate the oracle on them and save the train set with the oracle
#   parameters.
orc_specs = defaultdict(list)
for cfg_it, cfg in enumerate(cfgs):
    mu_2nd, ro = cfg
    print(f"cfg is {cfg}")

    # --- oracle parameters ---
    mu_shift = 0
    mu_1st = 0 + mu_shift
    mu_2nd = mu_2nd + mu_shift
    xd = 1
    n_gmm = 2
    sigmas_gmm = torch.tensor([.25, 1.])
    weights = torch.tensor([1., 2.5])
    mus = torch.cat([mu_1st*torch.ones(1, xd), mu_2nd*torch.ones(1, xd)], dim=0)
    assert mus.shape == (n_gmm, xd)
    sigmas = sigmas_gmm.reshape(-1, 1).broadcast_to(n_gmm, xd)

    # --- samples around the first (left) component ---
    N1 = 200
    sigma1_sample = .6
    x1_samples = mu_1st + sigma1_sample*torch.randn((N1, 1))

    # --- samples just right of the second component's mean ---
    N2 = int(N1*ro)
    start = sigmas_gmm[-1]/2
    end = sigmas_gmm[-1]
    # The largest y among the right-hand samples must exceed the oracle value
    # at the first mean, i.e. the left peak's height. This compares y to a
    # y value, not to the location mu_1st, which was always satisfied.
    y_first_peak = torch.max(oracle_gmm(torch.full((1, 1, xd), float(mu_1st)),
                                        mus=mus, sigmas=sigmas, weights=weights))
    max_tries = 1000
    for _ in range(max_tries):
        x2_samples_shift = start + (end - start)*torch.rand((N2, 1))
        x2_samples = mu_2nd + x2_samples_shift
        if N2 == 0:
            break  # nothing to check; torch.max of an empty tensor raises
        y2_tmp = oracle_gmm(torch.unsqueeze(x2_samples, dim=0),
                            mus=mus, sigmas=sigmas, weights=weights)
        if torch.max(y2_tmp) > y_first_peak:
            break
    else:
        raise RuntimeError(
            f"cfg {cfg}: right-hand samples never exceeded the left peak "
            f"in {max_tries} tries")

    X = torch.unsqueeze(torch.cat((x2_samples, x1_samples)), dim=0)
    Y_mean = oracle_gmm(X, mus=mus, sigmas=sigmas, weights=weights)

    x = X.detach().numpy().reshape(-1, 1)
    y = Y_mean.detach().numpy().reshape(-1)
    sigmas_gmm = sigmas_gmm.detach().numpy()
    weights = weights.detach().numpy()

    orc_spec = dict(cfg_it=cfg_it, mu_1st=mu_1st, mu_2nd=mu_2nd, mu_shift=mu_shift,
                    ro=ro, data_type=data_type, xd=xd, n_gmm=n_gmm,
                    sigmas_gmm=sigmas_gmm, weights=weights, N1=N1, N=N1+N2)
    np.savez(f"{savedir}/ds{cfg_it}", x=x, y=y, orc_spec=orc_spec)

    # to get all configs in one file
    orc_specs["cfg_it"].append(cfg_it)
    orc_specs["mu_1st"].append(mu_1st)
    orc_specs["mu_2nd"].append(mu_2nd)
    orc_specs["mu_shift"].append(mu_shift)
    orc_specs["ro"].append(ro)
    orc_specs["data_type"].append(data_type)
    orc_specs["xd"].append(xd)
    orc_specs["n_gmm"].append(n_gmm)
    orc_specs["sigmas_gmm"].append(sigmas_gmm)
    orc_specs["weights"].append(weights)
    orc_specs["N1"].append(N1)
    orc_specs["N"].append(N1+N2)

# write the config table once, after all train sets are generated
cfg_all = pd.DataFrame.from_dict(orc_specs, orient='index').transpose()
cfg_all.columns = ["cfg_it", "mu_1st", "mu_2nd", "mu_shift", "ro", "data_type",
                   "xd", "n_gmm", "sigmas_gmm", "weights", "N1", "N"]
cfg_all.to_csv(f"{savedir}/cfgs", sep="\t")

### PINN Poisson:

In [None]:
seed_rng = 12345
np.random.seed(seed_rng)

data_type = "pinn"
ds_name = data_type
savedir = f"../sample_trainset/{ds_name}"
os.makedirs(savedir, exist_ok=True)

high_perc_ranges = [(30, 40), (40, 50), (50, 60), (60, 70), (70, 80)]
ros = [0.05, 0.1, 0.2, 0.4, 0.8]
cfg_lists = [high_perc_ranges, ros]
cfgs = itertools.product(*cfg_lists)

low_perc = 15
train_size = 1000

# Everything up to the config loop is config-independent, so it runs once.
# The `with` block closes the .npz file handle.
with np.load("../datasets/pinn_poisson.npz") as data:
    x = data["x"]
    gt = data["gt"]

modeld, sold = x.shape  # models by solutions
x = x - np.mean(x, axis=-1).reshape((modeld, 1))  # center each solution
assert x.shape == (modeld, sold)

# Per-point weights from the first (uncentered) ground truth, reshaped to
# dsr x dsr (assumes sold is a perfect square): negate, shift to min 0,
# square, normalize to sum 1.
dsr = int(sold**0.5)
w_re = -gt[0].reshape(dsr, dsr)
w_re = w_re - w_re.min()
w_re = w_re**2
w_re = w_re/w_re.sum()
w = w_re.reshape(1, sold)

# centered ground truth
orc = gt - np.mean(gt, axis=-1).reshape((modeld, 1))

# Weighted MSE per solution. Its -log10 is the target (higher = closer to gt).
y = np.sum(np.square(x - orc) * w, axis=-1)
mylog = -np.log10(y)

# Low end: solutions at or below the low_perc percentile of mylog.
thr1 = np.percentile(mylog, low_perc)
idx_low = np.where(mylog <= thr1)[0]

orc_specs = defaultdict(list)
for cfg_it, cfg in enumerate(cfgs):
    high_perc_range, ro = cfg
    print(f"cfg is {cfg}")

    # High end: solutions within the [low_end, high_end] percentile band.
    low_end, high_end = high_perc_range
    thr2 = np.percentile(mylog, low_end)
    thr3 = np.percentile(mylog, high_end)

    # round, don't truncate: 1 - 0.8 == 0.19999..., so int(1000*(1 - 0.8))
    # is 199, not 200
    train_size1 = int(round(train_size*(1 - ro)))
    train_size2 = train_size - train_size1

    rand_idx1 = np.random.choice(idx_low, size=train_size1, replace=False)

    idx_high = np.where((mylog >= thr2) & (mylog <= thr3))[0]
    rand_idx2 = np.random.choice(idx_high, size=train_size2, replace=False)

    rand_idx_comb = np.concatenate((rand_idx1, rand_idx2))
    x_tr = x[rand_idx_comb, :]
    y_tr = mylog[rand_idx_comb]

    orc_spec = dict(cfg_it=cfg_it, ro=ro, high_perc_range=high_perc_range,
                    data_type=data_type, N=train_size, low_perc=low_perc)
    np.savez(f"{savedir}/ds{cfg_it}", x=x_tr, y=y_tr, orc_spec=orc_spec)

    # to get all configs in one file
    orc_specs["cfg_it"].append(cfg_it)
    orc_specs["ro"].append(ro)
    orc_specs["high_perc_range"].append(high_perc_range)
    orc_specs["data_type"].append(data_type)
    orc_specs["low_perc"].append(low_perc)
    orc_specs["N"].append(train_size)

# write the config table once, after all train sets are generated
cfg_all = pd.DataFrame.from_dict(orc_specs, orient='index').transpose()
cfg_all.columns = ["cfg_it", "ro", "high_perc_range", "data_type", "low_perc", "N"]
cfg_all.to_csv(f"{savedir}/cfgs", sep="\t")