In [39]:
from typing import Dict, Iterable, Tuple, List
import numpy as np
import pandas as pd
import os
import torch 

def split_indices_by_chrom(indices, ratios, seed, chrom_ranges=None):
    """
    Split the given indices into train/val/test with the requested ratios,
    performed independently within each chromosome range, then concatenated.

    Stratifying per chromosome keeps every chromosome represented in each
    split in (approximately) the requested proportions.

    Parameters
    ----------
    indices : iterable of int or pd.Index
        The indices you want to split. Converted to a ``pd.Index`` internally,
        so lists, ranges and numpy arrays are accepted as well.
    ratios : tuple of 3 floats
        (train, val, test) ratios; must be non-negative and sum to 1.0
        (within float tolerance).
    seed : int
        RNG seed for reproducibility.
    chrom_ranges : dict[str, tuple[int, int]], optional
        Mapping chromosome name -> (first_index, last_index), both inclusive.
        Chromosomes are processed in dict order. Defaults to the built-in
        chr2L..chrY table below, whose last bin is 137551.
        NOTE(review): those defaults look like 1 kb bin indices. The 5 kb node
        table in this notebook ends at index 27511 (chrY), so at 5 kb the
        defaults put every bin into the "chr2L"/"chr2R" ranges. Pass ranges
        that match the data's resolution (e.g. derived from
        ``node_df.groupby('chr').groups``).

    Returns
    -------
    dict
        {"train": pd.Index, "val": pd.Index, "test": pd.Index}, each sorted
        ascending with dtype int64.

    Raises
    ------
    ValueError
        If ``ratios`` does not have 3 entries, has a negative entry, or does
        not sum to 1.0.

    Warns
    -----
    UserWarning
        If some indices fall outside every chromosome range. Those indices are
        not assigned to any split.
    """
    if chrom_ranges is None:
        chrom_ranges = {
            "chr2L": (0,      23513),
            "chr2R": (23514,  48800),
            "chr3L": (48801,  76911),
            "chr3R": (76912,  108991),
            "chr4":  (108992, 110340),
            "chrX":  (110341, 133883),
            "chrY":  (133884, 137551),
        }

    ratios = tuple(float(r) for r in ratios)
    if len(ratios) != 3:
        raise ValueError(
            f"ratios must have exactly 3 entries (train, val, test), got {len(ratios)}"
        )
    if any(r < 0 for r in ratios) or not np.isclose(sum(ratios), 1.0):
        raise ValueError(f"ratios must be non-negative and sum to 1.0, got {ratios}")

    # Accept any iterable of ints; the comparisons below need Index semantics.
    indices = pd.Index(indices)

    # Lists of each training/validation/testing set.
    train_all: List[int] = []
    val_all: List[int] = []
    test_all: List[int] = []

    rng = np.random.default_rng(seed)  # set the seed
    covered = np.zeros(len(indices), dtype=bool)  # which indices hit some range

    for chrom, (start, end) in chrom_ranges.items():
        # Range filter: keep indices within this chromosome's (inclusive) bounds.
        in_chrom = np.asarray((indices >= start) & (indices <= end), dtype=bool)
        covered |= in_chrom
        # Explicit copy so the in-place shuffle never touches pandas-owned
        # (possibly read-only) memory.
        chrom_idx = indices[in_chrom].to_numpy(dtype=np.int64, copy=True)
        if chrom_idx.size == 0:  # no indices fall within this chromosome
            continue

        # Shuffle indices within the chromosome.
        rng.shuffle(chrom_idx)

        # Sizes for each split. Test takes the remainder, so the three slices
        # always partition chrom_idx exactly.
        n = chrom_idx.size
        n_train = int(np.floor(n * ratios[0]))
        n_val = int(np.floor(n * ratios[1]))

        train_all.extend(chrom_idx[:n_train])
        val_all.extend(chrom_idx[n_train:n_train + n_val])
        test_all.extend(chrom_idx[n_train + n_val:])

    if not covered.all():
        import warnings
        warnings.warn(
            f"{int((~covered).sum())} indices fall outside every chromosome range "
            "and were not assigned to any split"
        )

    return {
        # Sort so each split's indices are in ascending order.
        "train": pd.Index(sorted(train_all), dtype="int64"),
        "val":   pd.Index(sorted(val_all),   dtype="int64"),
        "test":  pd.Index(sorted(test_all),  dtype="int64"),
    }

In [40]:
def load_datasets(p_value, resolution,
                  data_dir='/oscar/data/larschan/shared_data/BindGPS/data/datasets'):
    """
    Load the node and edge tables for one resolution / p-value combination.

    Parameters
    ----------
    p_value : str
        p-value tag as it appears in the edge file name (e.g. '0_1' for 0.1).
    resolution : str
        Resolution tag as it appears in both file names (e.g. '5kb').
    data_dir : str or path-like, optional
        Directory containing the pickles. Defaults to the shared BindGPS
        dataset directory.

    Returns
    -------
    tuple
        (node_df, edge_df) read from ``node_dataset_{resolution}.pkl`` and
        ``edge_dataset_{p_value}_{resolution}.pkl`` in ``data_dir``.
    """
    # Build full paths instead of calling os.chdir(): changing the process-wide
    # working directory is hidden kernel state that silently redirects every
    # later relative path in the notebook.
    node_path = os.path.join(data_dir, f"node_dataset_{resolution}.pkl")
    edge_path = os.path.join(data_dir, f"edge_dataset_{p_value}_{resolution}.pkl")
    # NOTE: unpickling can execute arbitrary code - only load trusted files.
    node_df = pd.read_pickle(node_path)
    edge_df = pd.read_pickle(edge_path)
    return node_df, edge_df

In [41]:
# Load the node table and the edge table for this run. This cell must be active:
# every later cell reads `node_df` / `edge_df`, so a fresh kernel fails without it.
p_value = '0_1'  # 0.1 pvalue, encoded as it appears in the edge file name
resolution = '5kb'
node_df, edge_df = load_datasets(p_value, resolution)

In [42]:
# Inspect the node table; the rich repr is truncated to head/tail rows by pandas.
node_df

Unnamed: 0,chr,start,end,counts,expression_level,gene_in_bin,gene_id,DNA sequence,clamp,gaf,...,h3k27me3,h3k36me3,h3k4me1,h3k4me2,h3k4me3,h3k9me3,h4k16ac,psq,gene_labels,mre_labels
0,chr2L,0,5000,0.0,0.0,False,0,Cgacaatgcacgacagaggaagcagaacagatatttagattgcctc...,0.011338,0.029695,...,0.000000,20.003200,0.0,0.000000,0.0,0.000000,0.0,30.074890,0,1
1,chr2L,5000,10000,19.0,2.0,True,FBgn0002121,TGCCTCTCATTCTGTCTTATTTTACCGCAAACCCAAatcgacaatg...,20.016980,20.031764,...,10.088679,290.071502,1.0,0.000000,0.0,20.028571,0.0,70.098778,1,1
2,chr2L,10000,15000,0.0,0.0,False,0,GAGGAGAATGCAAAAAAGCTAAGAACAAAACAATTACTACAAATCG...,10.023666,10.011295,...,0.000000,140.053661,0.0,110.071154,0.0,0.000000,0.0,70.025733,1,0
3,chr2L,15000,20000,0.0,0.0,False,0,ATTCGACGGCGGTTCTGGGTTATCTATGCTCCAAGTGGCGTATGAA...,0.072610,0.058189,...,0.000000,200.062504,0.0,130.005405,0.0,0.000000,0.0,60.001147,1,1
4,chr2L,20000,25000,0.0,1.0,True,"FBgn0031209,FBgn0263584",GTGGCCGAATTTATTCTAAACTGAAAATAATAATAAAAATTAATCA...,0.079141,0.071863,...,10.025000,160.083882,0.0,0.000000,0.0,0.000000,0.0,70.087586,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
27508,chrY,3645000,3650000,0.0,1.0,True,FBgn0267592,NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...,0.000641,0.000349,...,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.0,0.000000,0,0
27509,chrY,3650000,3655000,0.0,0.0,False,0,GTTCTCCACACAAAAAGAATTTTTTCATATACCCTATATAAACGAA...,0.059530,0.050137,...,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.0,10.009552,0,1
27510,chrY,3655000,3660000,0.0,0.0,False,0,acacggagtaaaaatccgcccagtttgcttagcctccgccaaacgt...,0.000308,0.000322,...,0.000000,0.000000,0.0,0.000000,1.0,0.000000,0.0,0.017110,0,1
27511,chrY,3660000,3665000,0.0,0.0,False,0,NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...,0.000460,0.000196,...,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.0,0.000000,0,0


In [43]:
# Peek at the first few significant-contact edges (bin1 -> bin2 plus weights).
edge_df.head(n=5)

Unnamed: 0,chr1,start1,end1,chr2,start2,end2,contactCount,p-value,q-value,bias1,bias2,ExpCC,loop_size,bin1,bin2,p-value_transformed
0,chr2L,5000,10000,chr2L,15000,20000,142,3.591059e-15,1.69565e-10,0.945941,0.921175,68.049704,5.0,1,3,14.444777
1,chr2L,5000,10000,chr2L,55000,60000,24,0.01801108,1.0,0.945941,1.265691,14.8872,45.0,1,11,1.74446
2,chr2L,5000,10000,chr2L,60000,65000,16,0.07213755,1.0,0.945941,1.003579,10.5853,50.0,1,12,1.141839
3,chr2L,5000,10000,chr2L,80000,85000,13,0.0539825,1.0,0.945941,1.046688,7.784579,70.0,1,16,1.267747
4,chr2L,5000,10000,chr2L,90000,95000,14,0.02098138,1.0,0.945941,1.151248,7.472074,80.0,1,18,1.678166


In [45]:
# Define variables: regression target, gene mask, and node input features.
target = node_df['counts']
mask = node_df['gene_in_bin']
# mask = node_df['expression_level']
input_features = node_df.loc[:, 'clamp':'psq']

# p-value thresholding (sampler will take care of this)
######################################################################

y = target
all_mask = mask
# NOTE(review): this is every row index, not only rows where `all_mask` is True.
# The split therefore covers all bins, and `all_mask` is kept as a separate mask.
# Use `all_mask[all_mask].index` instead if only masked bins should be split.
mask_idx = all_mask.index
splits = split_indices_by_chrom(mask_idx, ratios=(0.70, 0.15, 0.15), seed=42)  # splitting by chrom, see cell above
train_idx, val_idx, test_idx = splits["train"], splits["val"], splits["test"]

# Boolean masks aligned with node_df's row order. `isin` matches index labels,
# so this stays correct even if node_df's index is not a 0..n-1 RangeIndex.
# Positional assignment `arr[idx] = True` would silently misplace rows in that
# case, or raise IndexError for labels >= len(node_df).
train_mask = all_mask.index.isin(train_idx)
val_mask = all_mask.index.isin(val_idx)
test_mask = all_mask.index.isin(test_idx)

# Sanity checks: the splits partition the node table.
assert len(train_idx) + len(val_idx) + len(test_idx) == len(all_mask), "splits contain indices not in node_df"
assert not (train_mask & val_mask).any(), "train/val overlap"
assert not (train_mask & test_mask).any(), "train/test overlap"
assert not (val_mask & test_mask).any(), "val/test overlap"
assert (train_mask | val_mask | test_mask).all(), "some bins are in no split"

# X = input_features
# edges = edge_df.loc[:,'bin1':'bin2']
# edge_weight = edge_df['p-value_transformed']

# # to torch
# X = torch.tensor(X.to_numpy())
# y = torch.tensor(y.to_numpy())
# all_mask= torch.tensor(all_mask.to_numpy())
# train_mask = torch.tensor(train_mask)
# val_mask = torch.tensor(val_mask)
# test_mask = torch.tensor(test_mask)
# edge_index = torch.tensor(edges.transpose().to_numpy())
# edge_weight = torch.tensor(edge_weight.transpose().to_numpy())

# # out
# print("Node features:\t",X.shape)
# print("All Mask:\t",all_mask.shape)
# print("Train Mask:\t",train_mask.shape)
# print("Val Mask:\t",val_mask.shape)
# print("Test Mask:\t",test_mask.shape)
# print("Target Variable:",y.shape)
# print("Edge weights:\t",edge_weight.shape)
# print("Edge indices:\t",edges.shape)

Int64Index([    2,     4,     5,     6,     7,     8,     9,    11,    12,
               13,
            ...
            27496, 27497, 27502, 27503, 27505, 27506, 27507, 27508, 27510,
            27511],
           dtype='int64', length=19258)
[False False False ... False False False]
[False False  True ...  True  True False]


In [16]:
# Inspect the train/val/test index splits returned by split_indices_by_chrom.
splits

{'train': Int64Index([    2,     4,     5,     6,     7,     8,     9,    11,    12,
                13,
             ...
             27496, 27497, 27502, 27503, 27505, 27506, 27507, 27508, 27510,
             27511],
            dtype='int64', length=19258),
 'val': Int64Index([    0,     1,     3,    19,    29,    38,    40,    44,    46,
                67,
             ...
             27399, 27426, 27439, 27445, 27474, 27479, 27481, 27504, 27509,
             27512],
            dtype='int64', length=4126),
 'test': Int64Index([   10,    22,    23,    27,    33,    52,    53,    55,    60,
                61,
             ...
             27458, 27459, 27478, 27486, 27488, 27493, 27498, 27499, 27500,
             27501],
            dtype='int64', length=4129)}

In [None]:
# Assemble a torch_geometric `Data` graph from the node/edge tensors (currently
# disabled; it expects X, y, the masks, edge_index and edge_weight as tensors).
# NOTE(review): this previously passed the undefined name `yf`; the target is `y`.
# from torch_geometric.data import Data

# data = Data(
#     x=X,
#     train_mask=train_mask,
#     all_mask=all_mask,
#     val_mask=val_mask,
#     test_mask=test_mask,
#     edge_index=edge_index,
#     edge_weight=edge_weight,
#     y=y
# )
