# Split the gold standard into training and holdout datasets

In [1]:
import os
from itertools import product

import numpy as np
import pandas as pd

## Define inputs

In [2]:
# Input files (tab-separated); all three are read with pd.read_csv(..., sep='\t') below.

# Node table, loaded into `nodes`.
nodes_fname = "data/min_hetionet/minhet_nodes.tsv"

# Edge table, loaded into `edges`.
edges_fname = "data/min_hetionet/minhet_edges.tsv"

# Gold-standard chemical/disease pairs, loaded into `gold`.
gold_fname = "data/hetionet/goldstd.tsv"

In [3]:
# Load the tab-separated node table into a DataFrame.
# NOTE(review): the preview below shows an "Unnamed: 0" column, which suggests
# the TSV was written together with its index; index_col=0 would avoid it —
# confirm nothing relies on that column before changing.
nodes = pd.read_csv(nodes_fname, sep='\t')

In [4]:
# Preview the first rows of the node table to check how the columns were parsed.
nodes.head()

Unnamed: 0,node_uid,node_id,name,het_type
0,0,10,NAT2,Gene
1,1,100,ADA,Gene
2,2,10000,AKT3,Gene
3,3,10005,ACOT8,Gene
4,4,10007,GNPDA1,Gene


In [5]:
# Load the tab-separated edge table into a DataFrame.
# NOTE(review): neither `edges` nor `nodes` is referenced by the splitting
# code in the visible cells — confirm whether these loads are still needed.
edges = pd.read_csv(edges_fname, sep='\t')

In [6]:
# Preview the first rows of the edge table to check how the columns were parsed.
edges.head()

Unnamed: 0,start_id,end_id,het_etype,start_htype,end_htype
0,DB00643,51547,UPREGULATES_CuG,Compound,Gene
1,DB08881,10450,UPREGULATES_CuG,Compound,Gene
2,DB01211,10450,DOWNREGULATES_CdG,Compound,Gene
3,DB00374,10450,DOWNREGULATES_CdG,Compound,Gene
4,DB00398,10450,UPREGULATES_CuG,Compound,Gene


In [7]:
# Read the gold-standard table and expose its "category" column under the
# name "etype", which is the label column the splitting cells select later.
gold = pd.read_csv(gold_fname, sep="\t").rename(columns={"category": "etype"})

In [8]:
# Preview the first rows of the gold standard (after the category -> etype rename).
gold.head()

Unnamed: 0,disease_id,chemical_id,disease_name,drug_name,etype
0,DOID:10652,DB00843,Alzheimer's disease,Donepezil,1
1,DOID:10652,DB00674,Alzheimer's disease,Galantamine,1
2,DOID:10652,DB01043,Alzheimer's disease,Memantine,1
3,DOID:10652,DB00989,Alzheimer's disease,Rivastigmine,1
4,DOID:9206,DB00736,Barrett's esophagus,Esomeprazole,1


---

In [9]:
def all_pairs(df):
    """Return every (chemical_id, disease_id) combination found in ``df``.

    The result is the Cartesian product of the distinct values in the
    "chemical_id" column with the distinct values in the "disease_id"
    column, as a set of 2-tuples (chemical first, disease second).
    """
    unique_chemicals = set(df["chemical_id"])
    unique_diseases = set(df["disease_id"])

    return {
        (chemical, disease)
        for chemical in unique_chemicals
        for disease in unique_diseases
    }

def df_to_pairs(df):
    """Return the set of (chemical_id, disease_id) tuples that occur as rows of ``df``.

    Duplicate rows collapse to a single tuple; any other columns are ignored.
    """
    # Walk the two key columns in lockstep instead of building a row tuple
    # for every column of the frame.
    return set(zip(df["chemical_id"], df["disease_id"]))

def pairs_to_df(pairs):
    """Build a two-column DataFrame from an iterable of (chemical_id, disease_id) tuples.

    Rows appear in the iteration order of ``pairs``; an empty iterable gives
    an empty frame that still has both columns.
    """
    column_names = ["chemical_id", "disease_id"]
    rows = [*pairs]
    return pd.DataFrame(rows, columns=column_names)

In [10]:
def split_data(gold, holdout_ratio, random_state=None):
    """Split the gold standard into positive/negative training and holdout sets.

    Any (chemical, disease) combination built from the chemicals and diseases
    that occur in ``gold`` but that is NOT itself a gold-standard row is
    presumed to be a negative example.

    Parameters
    ----------
    gold : pandas.DataFrame
        Gold-standard positives; must have "chemical_id" and "disease_id"
        columns.
    holdout_ratio : float
        Fraction of the rows of ``gold`` to draw as positive holdout examples
        (passed to ``DataFrame.sample(frac=...)``).
    random_state : optional
        Forwarded to ``DataFrame.sample``. The default ``None`` keeps the
        original unseeded behaviour; pass an int for a reproducible split.

    Returns
    -------
    tuple of four DataFrames ``(pos_train, neg_train, pos_holdout, neg_holdout)``
        ``pos_*`` keep the columns of ``gold``. ``neg_*`` contain only
        "chemical_id" and "disease_id", sorted by those two columns so that
        their row order does not depend on set iteration order (assumes the
        IDs are mutually orderable, e.g. all strings).

    Raises
    ------
    AssertionError
        If positives and negatives overlap, or if the training and holdout
        pair sets are not disjoint.
    """
    gold_pairs = df_to_pairs(gold)

    # Presumed negatives: every chemical x disease combination that is not a
    # gold-standard pair.
    assumed_false = all_pairs(gold) - gold_pairs

    pos_holdout = gold.sample(frac=holdout_ratio, random_state=random_state)
    pos_holdout_pairs = df_to_pairs(pos_holdout)

    # Holdout negatives are limited to combinations of the chemicals and
    # diseases that appear among the holdout positives.
    neg_holdout_pairs = assumed_false & all_pairs(pos_holdout)

    # Gold-standard positives that are not in the holdout. Rows are matched
    # on the (chemical_id, disease_id) key, so any gold row sharing a pair
    # with a holdout row is excluded from training as well.
    pos_train = (gold
        .merge(
            pos_holdout[["chemical_id", "disease_id"]],
            how="outer", on=["chemical_id", "disease_id"],
            indicator=True
        )
        .query("_merge == 'left_only'")
        .drop("_merge", axis=1)
    )
    pos_train_pairs = df_to_pairs(pos_train)

    # Training negatives: presumed negatives among the training chemicals and
    # diseases, minus anything already assigned to the holdout. Holdout
    # positives need no separate removal: they are gold pairs, and
    # assumed_false contains no gold pairs by construction.
    neg_train_pairs = (assumed_false & all_pairs(pos_train)) - neg_holdout_pairs

    # validation (checking we did things right)
    assert pos_train_pairs.isdisjoint(neg_train_pairs)
    assert pos_holdout_pairs.isdisjoint(neg_holdout_pairs)

    train_pairs = pos_train_pairs | neg_train_pairs
    holdout_pairs = pos_holdout_pairs | neg_holdout_pairs

    assert train_pairs.isdisjoint(holdout_pairs)

    # Sort before building the frames: set iteration order for string tuples
    # changes between interpreter sessions, which would otherwise make the
    # row order (and any later positional sampling) irreproducible.
    neg_train = pairs_to_df(sorted(neg_train_pairs))
    neg_holdout = pairs_to_df(sorted(neg_holdout_pairs))

    return (pos_train, neg_train, pos_holdout, neg_holdout)

## Main loop

**TODO:** Should we subsample the holdout negatives too? More than 95% of the holdout examples are negatives, and that imbalance might distort downstream results.

In [11]:
K = 5                # number of train/holdout splits to write
HOLDOUT_RATIO = 0.2  # fraction of gold-standard rows drawn as holdout positives
NEG_TRAIN_RATIO = 4  # training negatives requested per training positive
SEED = 42            # seed for NumPy's global RNG (see below)

train_fname = "data/min_hetionet/test/train_{}.tsv"
holdout_fname = "data/min_hetionet/test/holdout_{}.tsv"

PAIR_COLUMNS = ["chemical_id", "disease_id"]
OUTPUT_COLUMNS = PAIR_COLUMNS + ["etype"]

# to_csv does not create missing parent directories.
for out_template in (train_fname, holdout_fname):
    os.makedirs(os.path.dirname(out_template), exist_ok=True)

# DataFrame.sample draws from NumPy's global random state when it is given no
# random_state, so seeding once here makes every split below repeatable while
# still giving each fold a different draw.
np.random.seed(SEED)

# NOTE: each iteration is an independent random split of the full gold
# standard, so holdout sets of different folds may overlap; this is not a
# partition into K disjoint folds.
for idx in range(K):
    print("Splitting data for fold {}".format(idx))

    pos_train, neg_train, pos_holdout, neg_holdout = split_data(gold, HOLDOUT_RATIO)

    # sample() raises if asked for more rows than exist (without replacement),
    # so cap the request at the number of available negatives.
    n_neg_train = min(len(pos_train) * NEG_TRAIN_RATIO, len(neg_train))

    # The negative frames are built from sets, whose iteration order changes
    # between interpreter sessions; sort first so the seeded sample selects
    # the same rows on every run.
    neg_train = (neg_train
        .sort_values(PAIR_COLUMNS)
        .sample(n = n_neg_train)
        .assign(etype = 0)
    )

    #------------------------------------------------------------------

    # DataFrame.append was removed in pandas 2.0; pd.concat with
    # ignore_index=True is the equivalent of append + reset_index(drop=True).
    train = pd.concat(
        [pos_train[OUTPUT_COLUMNS], neg_train],
        ignore_index=True,
    )

    holdout = pd.concat(
        [
            pos_holdout[OUTPUT_COLUMNS],
            neg_holdout.sort_values(PAIR_COLUMNS).assign(etype = 0),
        ],
        ignore_index=True,
    )

    train.to_csv(train_fname.format(idx), sep='\t', index=False)
    holdout.to_csv(holdout_fname.format(idx), sep='\t', index=False)

Splitting data for fold 0
Splitting data for fold 1
Splitting data for fold 2
Splitting data for fold 3
Splitting data for fold 4
