In [1]:
import networkx as nx
import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader

import graph_learning
import utils

In [2]:
# Load the two input tables used by the dataset-building step.
# NOTE(review): judging by the helper names, df_syn is a SynergyAge subset annotated
# with BioGRID gene names and df_biogrid is the BioGRID table — confirm in utils.
df_syn = utils.get_synergyage_subset_with_biogrid_gene_names()
df_biogrid = utils.load_biogrid()

In [3]:
# One-off dataset build, left commented out.
# NOTE(review): presumably this produces the "synergyage_1.1" datalist that a later
# cell loads from ./data/datalists/ — confirm in utils.save_datalist. Uncomment to rebuild.
# data_list = utils.build_datalist(df_syn, df_biogrid, subgraph=True, reduced_synergyage=True)
# utils.save_datalist(data_list, "synergyage_1.1", override=True)

In [5]:
# Load the saved datalist from disk and print the node-feature matrix shape of its
# first element as a quick sanity check.
data_list = utils.load_datalist("./data/datalists/synergyage_1.1")
print(f"Node feature matrix shape: {data_list[0].x.shape}")

Node feature matrix shape: torch.Size([4719, 1])


In [6]:
from typing import List
from torch_geometric.data import Data

def get_balanced_data_list(data_list, sample_size_per_class, seed=None) -> "List[Data]":
    """Draw the same number of graphs from every class present in ``data_list``.

    The class of an item is ``np.argmax(item.y.numpy())``, i.e. the position of the
    largest entry of its label tensor.

    Args:
        data_list: Sequence of items exposing a ``y`` attribute with a ``.numpy()``
            method (e.g. ``torch_geometric.data.Data`` with a tensor label).
        sample_size_per_class: Number of items to draw for each class.
        seed: Optional seed for a private random generator. When ``None`` (default)
            NumPy's global random state is used, so ``np.random.seed`` still
            controls the draw.

    Returns:
        A list with ``sample_size_per_class`` items per class, grouped by class in
        ascending class-index order. Items are distinct within a class whenever the
        class has at least ``sample_size_per_class`` members; smaller classes are
        oversampled with replacement so the result stays balanced.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)

    # One integer class index per item.
    labels = np.array([int(np.argmax(data.y.numpy())) for data in data_list], dtype=int)

    balanced = []
    # np.unique yields the classes sorted, which makes the output order deterministic.
    for label in np.unique(labels):
        members = np.flatnonzero(labels == label)
        # Only fall back to sampling with replacement when the class is too small;
        # otherwise duplicates would needlessly shrink the effective sample.
        needs_oversampling = len(members) < sample_size_per_class
        chosen = rng.choice(members, size=sample_size_per_class, replace=needs_oversampling)
        balanced.extend(data_list[idx] for idx in chosen)

    return balanced

In [7]:
# Experiment configuration.
SEED = 42  # fixed seed so the sampled subset is the same on every Restart & Run All
SAMPLE_PER_CLASS = 100  # number of graphs drawn per class
NUM_EPOCHS = 30
BATCH_SIZE = 8

# The balanced sampling below draws from NumPy's global random state, so seed it first.
# torch is seeded as well so torch-side randomness is repeatable
# (NOTE(review): presumably loader shuffling and weight init — confirm in graph_learning).
np.random.seed(SEED)
torch.manual_seed(SEED)

data_list_sample = get_balanced_data_list(data_list, SAMPLE_PER_CLASS)

loader: DataLoader = graph_learning.loader_from_datalist(data_list_sample, batch_size=BATCH_SIZE)

In [17]:
# NOTE(review): geo_nn and nn are not used anywhere in this cell.
import torch_geometric.nn as geo_nn
import torch.nn as nn

# Build the GNN. NOTE(review): the arguments are presumably (input feature dim, hidden
# size, number of classes) — the node feature matrix has 1 column and the logged
# confusion matrix is 3x3 — confirm against graph_learning.create_GNN_model.
model = graph_learning.create_GNN_model(1, 10, 3)

# Count only parameters with requires_grad set, i.e. the trainable ones.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params}")
# NOTE(review): the same `loader` is passed twice. If these two positions are the train
# and test loaders (the training log prints identical Train/Test metrics), the "Test"
# numbers are measured on the training data and say nothing about generalization —
# pass a held-out loader unless this is a deliberate overfitting sanity check.
graph_learning.train_graph_classification(
    model, 
    loader, 
    loader, 
    num_epochs = NUM_EPOCHS,
    learning_rate = 0.01,
    weighted_loss=False)

Number of trainable parameters: 103
[Epoch 1/30] Loss: 1.1917 | Train: 0.33 0.33 0.33 | Test: 0.33 0.33 0.33 |
Confusion_matrix: 
[[  0 100   0]
 [  0 100   0]
 [  0 100   0]]

[Epoch 2/30] Loss: 1.0998 | Train: 0.33 0.33 0.33 | Test: 0.33 0.33 0.33 |
Confusion_matrix: 
[[  0 100   0]
 [  0 100   0]
 [  0 100   0]]

[Epoch 3/30] Loss: 1.0788 | Train: 0.33 0.33 0.33 | Test: 0.33 0.33 0.33 |
Confusion_matrix: 
[[  0   0 100]
 [  0   0 100]
 [  0   0 100]]

[Epoch 4/30] Loss: 1.0716 | Train: 0.33 0.33 0.33 | Test: 0.33 0.33 0.33 |
Confusion_matrix: 
[[  0   0 100]
 [  0   0 100]
 [  0   0 100]]



KeyboardInterrupt: 