In [1]:
import networkx as nx
import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader

import graph_learning
import utils

In [2]:
df_syn = utils.get_synergyage_subset_with_biogrid_gene_names()
df_biogrid = utils.load_biogrid()

data_list = utils.build_datalist(df_syn, df_biogrid, subgraph=True)
utils.save_datalist(data_list, "synergyage_1", override=True)

In [3]:
data_list = utils.load_datalist("./data/datalists/synergyage_1")

In [4]:
from typing import List
from torch_geometric.data import Data

def get_balanced_data_list(data_list, sample_size_per_class) -> List[Data]:
    class_array: np.array = np.array(list(map(lambda data: data.y.numpy(), data_list)))
    class_indexes = np.array(list(map(lambda row_array: np.argmax(row_array), class_array)))
    classes = set(class_indexes)
    samples = {}
    
    for c in classes:
        indexes = (np.argwhere(class_indexes == c))
        choice = np.random.choice(indexes.flatten(), sample_size_per_class)
        samples[c] = choice
    
    sample_data_list = []
    for choice_list in samples.values():
        for idx in choice_list:
            sample_data_list += [data_list[idx]]
    
    return sample_data_list


In [7]:
SAMPLE_PER_CLASS = 100
NUM_EPOCHS = 30
BATCH_SIZE = 8

data_list_sample = get_balanced_data_list(data_list, SAMPLE_PER_CLASS)

loader: DataLoader = graph_learning.loader_from_datalist(data_list_sample, batch_size=BATCH_SIZE)

In [9]:
model = graph_learning.create_GNN_model(1, 10, 3)
# model = graph_learning.create_GAT_model(1, 10, 3)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params}")
graph_learning.train_graph_classification(model, loader, loader, num_epochs = NUM_EPOCHS, weighted_loss=True)

Number of trainable parameters: 163
[Epoch 1/30] Loss: 0.3879 | Train: 0.33 0.33 0.33 | Test: 0.33 0.33 0.33 |
Confusion_matrix: 
[[100   0   0]
 [100   0   0]
 [100   0   0]]



KeyboardInterrupt: 

In [10]:
loader.dataset

[Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=1066, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[4], source_idx=2563, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=96, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=67, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=2087, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=2549, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=2962, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=668, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=606, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=2857, y=[3]),
 Data(edge_index=[2, 61342], x=[6893, 1], intervention_genes=[2], source_idx=673, y=[3]),
 Data(

In [11]:
synergy_genes = utils.get_synergyage_genes(df_syn)

In [15]:
np.random.choice(list(synergy_genes), size=10)

array(['ced-4', 'par-5', 'sqt-1', 'daf-2', 'nuc-1', 'skn-1', 'mes-4',
       'daf-18', 'nuo-5', 'cmk-1'], dtype='<U10')