# Splitting Train and Test Sets based on clusters

### Create a DataFrame from the CSV file

In [3]:
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

In [4]:
# Load the clustered-chemicals table. Later cells read its 'PubchemID' and
# 'Cluster' columns.
clustered_chemicals = pd.read_csv(
    filepath_or_buffer="/home/raldisi/Desktop/Clustered_chemicals.csv"
)

### Splitting training and testing sets based on clusters

In [6]:
# Split by cluster so that every chemical of a given cluster lands entirely in
# either the training set or the test set. For GroupShuffleSplit, test_size is
# the proportion of *groups* (clusters) held out, not of rows. Only the first
# of the two generated splits is consumed via next().
train_inds, test_inds = next(
    GroupShuffleSplit(n_splits=2, test_size=0.20, random_state=7).split(
        clustered_chemicals,
        groups=clustered_chemicals['Cluster'],
    )
)

# Positional indices from the splitter -> row subsets of the original frame.
train, test = clustered_chemicals.iloc[train_inds], clustered_chemicals.iloc[test_inds]

### Splitting training and testing sets based on weighted clusters

In [1]:
import pandas as pd
import pybel
import networkx as nx

In [2]:
# Load the pre-built PyBEL graph. from_pickle unpickles the file, so only
# point it at a pickle produced by a trusted source.
full_graph = pybel.from_pickle(
    "/home/raldisi/Desktop/full_graph.pickle"
)

In [None]:
# Map each cluster label -> list of PubChem IDs of the chemicals in it.
# A single groupby pass covers exactly the labels present in the data. The
# former range(1, nunique() + 1) silently assumed the labels were exactly
# 1..k, so a cluster labelled 0 (or any non-contiguous label) was dropped.
# It also rescanned the whole frame once per cluster.
# Keys come out sorted (groupby's default). Each list keeps the original row
# order of its members.
clusters_dict = (
    clustered_chemicals.groupby('Cluster', sort=True)['PubchemID']
    .apply(list)
    .to_dict()
)

In [None]:
# For every cluster, collect the graph neighbours of its member chemicals.
# Chemicals that are not nodes of full_graph are skipped. Duplicates are
# dropped while keeping order of first appearance (dict.fromkeys preserves
# insertion order).
# NOTE(review): the chemical nodes themselves are not added to the list, and
# on a directed graph neighbors() yields successors only. Confirm both are
# intended before these lists are used to build per-cluster subgraphs.
subgraphs_dict = {}
for cluster_id, pubchem_ids in clusters_dict.items():
    neighbourhood = []
    for pubchem_id in pubchem_ids:
        node = pybel.dsl.Abundance(namespace='pubchem', name=str(pubchem_id))
        if node in full_graph.nodes():
            neighbourhood.extend(full_graph.neighbors(node))
    subgraphs_dict[cluster_id] = list(dict.fromkeys(neighbourhood))

In [None]:
# Weight each cluster by the fraction of all edges of full_graph that lie
# inside the subgraph induced by that cluster's node list (from
# subgraphs_dict).
# The loop iterates the dict directly: tqdm was referenced here but is never
# imported in this notebook, so the cell raised NameError.
# Raises ZeroDivisionError if full_graph has no edges.
fullgraph_edges = full_graph.number_of_edges()
cluster_weights = {}
for cluster, nodes in subgraphs_dict.items():
    subgraph = full_graph.subgraph(nodes)
    edges = subgraph.number_of_edges()
    cluster_weights[cluster] = edges / fullgraph_edges

In [None]:
# Next steps:
## count the total number of edges in each cluster
## compute each cluster's weight from its edge count
## split the train/test sets based on these weights