In [1]:
import os, collections, itertools

import numpy as np
import pandas as pd


In [2]:
# Root of the UKB sample-QC tree; the input and output tables both live below it.
pop_def_d = '/oak/stanford/groups/mrivas/ukbb24983/sqc'

# input: GWAS covariate table from the population_stratification_w24983_20200828 directory
GWAS_covar_f = os.path.join(
    pop_def_d, 'population_stratification_w24983_20200828', 'ukb24983_GWAS_covar.20200828.phe',
)

# output: same table plus the non-WB train/val/test split column, in a new directory
GWAS_covar_nonWBsplit_f = os.path.join(
    pop_def_d, 'population_stratification_w24983_20211020', 'ukb24983_GWAS_covar.20211020.phe',
)


In [3]:
def train_val_test_split(n_individuals, rg, split_ratio={'train':.7, 'val':.1, 'test':.2}):
    """Randomly partition the positions ``0 .. n_individuals-1`` into train/val/test.

    Parameters
    ----------
    n_individuals : int
        Number of items to split.
    rg : numpy.random.Generator
        Generator used to shuffle the positions; its state advances on every call.
    split_ratio : dict
        Fractions for the 'train', 'val' and 'test' subsets. The values must sum to 1
        (up to floating-point rounding). The dict is only read, never mutated, so the
        shared default is safe.

    Returns
    -------
    tuple of numpy.ndarray
        ``(train, val, test)``: disjoint arrays of positions that together contain every
        position ``0 .. n_individuals-1`` exactly once.

    Raises
    ------
    AssertionError
        If the ratios do not sum to 1.
    """
    # Compare with a tolerance: 0.6 + 0.3 + 0.1 evaluates to 0.9999999999999999, which an
    # exact ``== 1`` check wrongly rejects.
    ratio_sum = sum(split_ratio.values())
    assert np.isclose(ratio_sum, 1.0, rtol=0.0, atol=1e-9), f'split_ratio must sum to 1, got {ratio_sum}'

    indiv_idx = np.arange(n_individuals)
    rg.shuffle(indiv_idx)

    # Cut points are truncated with int(), but only after rounding off float noise:
    # 0.7 + 0.1 evaluates to 0.7999999999999999, so a bare int(10 * (0.7 + 0.1)) is 7 and a
    # 10-item split would get an empty validation set. The rounding only matters when the
    # product is within float noise of an integer; other cut points are unchanged.
    train_idx = int(round(n_individuals * split_ratio['train'], 6))
    train_val_idx = int(round(n_individuals * (split_ratio['train'] + split_ratio['val']), 6))

    train = indiv_idx[:train_idx]
    val   = indiv_idx[train_idx:train_val_idx]
    test  = indiv_idx[train_val_idx:]
    assert len(train) + len(val) + len(test) == n_individuals
    return train, val, test


In [4]:
# Read the covariate table verbatim: every column as str and no NA inference. With pandas'
# defaults, 'NA' fields would become NaN and be written back as empty strings by the output
# cell, and integer columns containing missing values would be re-emitted as floats
# ('1' -> '1.0'). This notebook only filters on `population` and copies `IID`, so string
# values are all it needs.
# NOTE(review): IID is therefore a str here; convert explicitly before joining with int IIDs.
GWAS_covar_df = pd.read_csv(GWAS_covar_f, sep='\t', dtype=str, keep_default_na=False)


In [5]:
# Population labels whose members receive a train/val/test assignment (everyone else gets none).
nonWB_pops = ['non_british_white', 'african', 's_asian', 'e_asian', 'related', 'others']

# IIDs collected per split name ('train' / 'val' / 'test'); unseen keys start as empty sets.
split_dict = collections.defaultdict(set)

# Reproducibility: an explicitly seeded MT19937 bit generator wrapped in a Generator,
# instead of the legacy global np.random state (see NEP 19).
# https://numpy.org/neps/nep-0019-rng-policy.html
bg = np.random.MT19937(20211020)
rg = np.random.Generator(bg)


In [6]:
# Split each non-WB population into train/val/test = 0.7/0.1/0.2 separately, so every
# population is divided in the same proportions.
# Guard: re-running this cell without re-running the seeding cell above would advance `rg`
# and union a second, different split into `split_dict`, silently corrupting it.
assert not any(split_dict.values()), \
    'split_dict already populated: re-run the cell that creates rg / split_dict first'

for pop in nonWB_pops:
    pop_IIDs = GWAS_covar_df.loc[GWAS_covar_df['population'] == pop, 'IID'].to_numpy()
    # train/val/test are positions within this population's rows (in file order)
    train, val, test = train_val_test_split(len(pop_IIDs), rg)

    split_dict['train'].update(pop_IIDs[train])
    split_dict['val'].update(pop_IIDs[val])
    split_dict['test'].update(pop_IIDs[test])

# No IID may end up in more than one subset (would happen if an IID were duplicated).
assert split_dict['train'].isdisjoint(split_dict['val'])
assert split_dict['train'].isdisjoint(split_dict['test'])
assert split_dict['val'].isdisjoint(split_dict['test'])


In [7]:
# Label each individual with the subset containing their IID, checking train, then val,
# then test (first match wins); IIDs in none of the three subsets get the string 'NA'.
GWAS_covar_df['split_nonWB'] = GWAS_covar_df['IID'].map(
    lambda iid: next(
        (subset for subset in ('train', 'val', 'test') if iid in split_dict[subset]),
        'NA',
    )
)


In [8]:
# check the results: every individual in a non-WB population must carry a split label,
# and everyone else (white_british, DO_NOT_PASS_SQC, ...) must be 'NA'.
is_nonWB = GWAS_covar_df['population'].isin(nonWB_pops)
assert (GWAS_covar_df.loc[is_nonWB, 'split_nonWB'] != 'NA').all(), 'non-WB individual without a split'
assert (GWAS_covar_df.loc[~is_nonWB, 'split_nonWB'] == 'NA').all(), 'individual outside nonWB_pops got a split'
GWAS_covar_df.groupby(['population', 'split_nonWB']).size()

population         split_nonWB
DO_NOT_PASS_SQC    NA               1931
african            test             1300
                   train            4547
                   val               650
e_asian            test              341
                   train            1192
                   val               171
non_british_white  test             4981
                   train           17433
                   val              2491
others             test             5732
                   train           20059
                   val              2865
related            test             8927
                   train           31242
                   val              4463
s_asian            test             1567
                   train            5481
                   val               783
white_british      NA             337129
dtype: int64

In [9]:
# write to a file, with split_nonWB inserted right after the first four input columns.
# Pick the remaining columns by name rather than by position (`columns[4:-1]`), which
# assumed split_nonWB is last: if the input already had that column, the positional slice
# would duplicate it and drop the real last column.
other_cols = [c for c in GWAS_covar_df.columns if c != 'split_nonWB']
out_cols = other_cols[:4] + ['split_nonWB'] + other_cols[4:]

# The output lives in a different directory than the input; to_csv does not create it.
os.makedirs(os.path.dirname(GWAS_covar_nonWBsplit_f), exist_ok=True)
# na_rep='NA' writes any missing values with the same 'NA' spelling used by split_nonWB
# (pandas' default would write them as empty fields).
GWAS_covar_df[out_cols].to_csv(GWAS_covar_nonWBsplit_f, sep='\t', index=False, na_rep='NA')


In [10]:
# NOTE(review): the saved output of this cell names a path under
# population_stratification_w24983_20200828/ that the config cell no longer builds; it
# presumably predates an edit there. Re-run the notebook to refresh it.
# Confirm the write actually produced the file, then display where it went.
assert os.path.isfile(GWAS_covar_nonWBsplit_f), f'output not found: {GWAS_covar_nonWBsplit_f}'
GWAS_covar_nonWBsplit_f


'/oak/stanford/groups/mrivas/ukbb24983/sqc/population_stratification_w24983_20200828/ukb24983_GWAS_covar.20200828.nonWBsplit.20211020.phe'