In [1]:
%load_ext autoreload
%autoreload 2 
# !apt-get install -y xvfb
import time
import torch
import scipy
import scipy.sparse
from collections import Counter
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader, ImbalancedSampler
from torch_geometric.data import Dataset
# https://www.youtube.com/watch?v=QLIkOtKS4os --> creating custom dataset in pytorch geometric
from torch.utils.data import Dataset, random_split
import torch_geometric
from torch_geometric.data import Data, InMemoryDataset
import torch_geometric
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GlobalAttention, SAGEConv

import torch
import torch.nn.functional as F
from torch.nn import Linear

from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.utils import to_networkx, from_networkx
from sklearn.model_selection import StratifiedKFold
import networkx as nx
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import re
import pickle
import seaborn as sn
import random
import os
from typing import Optional

from torch_scatter import scatter_add

from torch_geometric.utils import softmax

from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from collections import Counter
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score

from graph_utils import set_device_and_seed, set_seed, show, visualize_graph, visualize_embedding, _count_parameters, visualise_airway_tree_matplotlib
from graph_datasets import CustomDataset
from graph_models import CustomGlobalAttention, GAT
from graph_training import train_model, test_model, _vis_graph_example, train_test_split

In [2]:
def dice_loss(input, target):
    smooth = 1.

    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    
    return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))



In [3]:


class CustomDatasetNode(torch_geometric.data.Dataset):
    def __init__(self,
                 root,
                 filename_data,
                 filename_labels,
                 test=False,
                 transform=None,
                 pre_transform=None,
                 label_col_name = 'binaryLL_1',
                 node_level = True,
                args = {'node_feature_names': [], 'edge_feature_names': []}):
        '''
        For NODE LEVEL CLASSIFICATION
        root = where dataset should be stored, folder is split into raw_dir and processed_dir
        filename_data = contains X features for nodes + edges (csv)
        filename_labels= contains Y labels for graphs (csv)
        
        
        '''

            
            
        self.test = False
        self.filename_data = os.path.abspath(filename_data)
        self.filename_labels = os.path.abspath(filename_labels)
        self.node_map = {}
        self.node_level = node_level
        self.y = None
        #         super(CustomDataset, self).__init__(root, transform, pre_transform)
        if len(args['node_feature_names']) > 0:
            self.node_feature_names = args['node_feature_names']
        else:
            self.node_feature_names = None
        
        if len(args['edge_feature_names']) > 0:
            self.edge_feature_names = args['edge_feature_names']
        else:
            self.edge_feature_names = None
        print(f"Using Node features: {self.node_feature_names}, Edge features: {self.edge_feature_names}")
        self.label_col = label_col_name
        print(f"Getting labels from: {self.label_col}")
        super(CustomDatasetNode, self).__init__(root, transform, pre_transform)
        
    @property
    def raw_file_names(self):
        return self.filename_data
    
    @property
    def processed_file_names(self):
        """ If these files are found in raw_dir, processing is skipped: NOTE NOT SURE WHAT THIS IS"""
        self.data = pd.read_csv(self.raw_paths[0]).reset_index()

        if self.test:
            return [f'data_test_{i}.pt' for i in list(self.data.index)]
        else:
            return [f'data_{i}.pt' for i in list(self.data.index)]
    
    def _download(self):
        pass
    
    def process(self):
        self.data = pd.read_csv(os.path.abspath(self.raw_paths[0]))
        label_df = self._process_labels()
        graph_ids = self.data.idno.unique()
        
        for i, idno in tqdm(list(enumerate(graph_ids))):
            # iterate through each patricipant (resetindex is important for node relabelling)
            df = self.data.loc[self.data.idno == idno].copy().reset_index()
            # build a graph out of the df with node features, edge_features and edge_adjacency
            x = self._get_node_features(df)
            edge_adjacency = self._get_edge_adjacency(df,index=i)
            edge_features = self._get_edge_features(df)
            y = self._get_label(idno, label_df, x)
            data = Data(x=x,
                        edge_index=edge_adjacency,
                        edge_attr = edge_features,
                        y = y
                       )
            if self.test:
                torch.save(data, 
                            os.path.join(self.processed_dir, 
                                         f'data_test_{i}.pt'))
            else:
                 torch.save(data, 
                            os.path.join(self.processed_dir, 
                                         f'data_{i}.pt'))
    
    
    def _process_labels(self):
        '''
        Reads label df - checks its binary 0,1 labels (REQUIRES COLUMN self.label_col) if not returns an error 
        Keeps only rows that match idnos in self.data_df which is processed first
        measures class proportions and saves to self.class_proportions
        returns a df which has a label per idno with label in col called 'y'
        '''
        label_df = pd.read_csv(os.path.abspath(self.filename_labels))
        # binarise 
        assert self.label_col in label_df.columns, f"The column {self.label_col} cannot be found"
        # drop unnecessary cols
        
        # drop rows not matching to data ids
        data_df =  pd.read_csv(os.path.abspath(self.filename_data))
        label_df_small = label_df.loc[label_df.idno.isin(data_df.idno.unique())]
        self.y = label_df_small
        print("# Graphs", len(label_df_small), "Label Frequency", Counter(label_df_small[self.label_col].to_list()))
        self.class_proportions = {k:v/len(label_df_small) for k,v in Counter(label_df_small[self.label_col].to_list()).items()}
        print(f"Class proportions: {self.class_proportions}")
        return label_df_small
    
    def _get_edge_adjacency(self,df, index):
        '''
        NOTE pytorch requires nodes to start from 0 and go up in integers so need to remap start and end bpids from df
        Turns endbpid 1 > 1, if the next one in the df is endbpid=6 it becomes 2 etc so that the endbpids are in order from 1 to max (index+1)
        Adds trachea (node 0)
        applyys relabelling to start and endbpid based on the dict 
        returns torch tensor in COO format which is a parallel list [[source_node_list], [corresponding_end_node_list]]
        in this format, at list[0][2] and list[1][2] will be source-end node of the second edge in the list
        '''
        # reindex start / endbpids
        
        relabel_map = {v:k+1 for k,v in df.endbpid.to_dict().items()}
        # add trachea map (node 0)
        relabel_map[-1] = 0
        # save mapping
        self.node_map[index] = (df.idno.unique().item(), relabel_map)
        # apply relabelling to source and end nodes
        source_nodes = df.startbpid.apply(lambda x: relabel_map[x]).to_list()
        end_nodes = df.endbpid.apply(lambda x: relabel_map[x]).to_list()
        # return in COO format
        return torch.tensor([source_nodes, end_nodes], dtype=torch.long)

    def _get_node_features(self,df):
        '''
        DF already contains normalised features
        Feature names in node_features
        Format to return is a tensor of shape num_nodes x node_feature_dim with dtype float 
        assumes ordered in ascending order with nodes 0 and up in integers (so node = index / row num)
            node_features = ['x_norm', 'y_norm', 'z_norm', 'dircosx_norm',
       'dircosy_norm', 'dircosz_norm','lobe_norm',
       'sublobe_norm','angle_norm', 'weibel_generation_norm','dist_nn_in_lobe_norm', 'num_desc_norm','max_path_length_norm']

        '''
        # nodes and features (pos, direction,lobe + sublobe categorical, angle to parent, weibel gen, dist to nearest neighbor in lobe, max path length to leaf, num descendents of node)
        node_features = self.node_feature_names

        # adding trachea info to top of list
        trachea_dict = dict.fromkeys(node_features, 0)
        for i in ['nx', 'ny', 'nz']:
            trachea_dict[i] = df.loc[df.endbpid==1][str('parent_loc_'+i)].item()
        
        # currently usnig as features 
        list_of_nodes = df[node_features].to_dict(orient='records')
        # add trachea to nodes
        list_of_lists_nodes =[list(trachea_dict.values())]+ [list(node_feature.values()) for node_feature in list_of_nodes]
        x = torch.tensor(list_of_lists_nodes, dtype=torch.float)
#         print("Shape node features", x.shape)
        return x

    def _get_edge_features(self,df):
        '''
        get matrix of shape [# edges, edge feature size] with type float
        '''
        if self.edge_feature_names is not None:
            edge_feature_names = self.edge_feature_names
        else:
            edge_feature_names = ['centerlinelength_norm','avginnerarea_norm']
            
        edge_norm = df[edge_feature_names].values
        return torch.tensor(edge_norm, dtype=torch.float)

    def _get_label(self, idno, label_df,node_features):
        '''
        for the selected idno, returns value in self.label_col as an integer
        '''
        # Assume trachea node is always non anomalous (0)
        if self.node_level:
            labels = label_df.loc[label_df.idno==idno,
                 self.label_col].values
            
            num_nodes = node_features.shape[0]
            if num_nodes -1 == labels.shape[0]:
#                 print('adding in label for node -1 = 0 in pytorch method')
                labels = [0] + list(labels)
            
#                 print('Num nodes, num labels', num_nodes, len(labels))
            
            return torch.tensor(labels, dtype=torch.int64)
        else:
            # return single graph label
            labels = label_df.loc[label_df.idno==idno,
                 self.label_col].unique()
#             print('num labels per graph', labels.shape[0])
        
            return torch.tensor(labels, dtype=torch.int64)
    
    def len(self):
        return int(self.data.idno.nunique())
    
    def get(self, idx):
        '''
        Equivalent to __getitem__ in pytorch
        '''
        if self.test:
            data = torch.load(os.path.join(self.processed_dir, 
                                 f'data_test_{idx}.pt'))
        else:
            data = torch.load(os.path.join(self.processed_dir, 
                                 f'data_{idx}.pt'))   
        return data
            

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

In [4]:
orig_df = pd.read_csv(os.path.abspath('/home/sneha/toy_normalised_1407.csv'))
orig_df.describe()

label_df = pd.read_csv(os.path.abspath('/home/sneha/toy_labels_1407.csv'))
label_df.describe()

Unnamed: 0,idno,startbpid,endbpid,graph_label,node_label
count,120976.0,120976.0,120976.0,120976.0,120976.0
mean,5474256.0,49.485708,75.583455,0.570989,0.142045
std,1769592.0,63.740793,79.533403,0.494937,0.349097
min,3010007.0,-1.0,1.0,0.0,0.0
25%,4014340.0,10.0,20.0,0.0,0.0
50%,5022045.0,25.0,48.0,1.0,0.0
75%,7016549.0,62.0,104.0,1.0,0.0
max,8024979.0,655.0,744.0,1.0,1.0


In [5]:
label_col_name = 'node_label'
train_ids, test_ids = train_test_split(label_df, n_splits_test = 5, label_col_name= label_col_name, seed=0)

pilot_df_train = orig_df.loc[orig_df.idno.isin(train_ids)]
pilot_df_test = orig_df.loc[orig_df.idno.isin(test_ids)]
binary_label_df_train = label_df.loc[label_df.idno.isin(train_ids)]
binary_label_df_test = label_df.loc[label_df.idno.isin(test_ids)]
print("Overall Label frequency distribution", [(x, binary_label_df_test[label_col_name].tolist().count(x)) for x in set(binary_label_df_test[label_col_name].tolist())])


# SAVE PILOT DF for training
pilot_df_train.to_csv('/home/sneha/toy_lobe_cleaned_normalised_w_labels_train.csv')
binary_label_df_train.to_csv('/home/sneha/toy_lobe_binary_labels_train.csv')
pilot_df_test.to_csv('/home/sneha/toy_lobe_cleaned_normalised_w_labels_test.csv')
binary_label_df_test.to_csv('/home/sneha/toy_lobe_binary_labels_test.csv')

Overall Label frequency distribution [(0, 103792), (1, 17184)]
Getting train test split stratified on the 120976 labels
Overall Label frequency distribution [(0, 103792), (1, 17184)]


In [6]:
label_col_name='node_label'

#  removed from features for toy example, added edge features to node ones too just in case
node_features = ['nx', 'ny', 'nz', 'dircosx_norm',
               'dircosy_norm', 'dircosz_norm','angle_norm', 'weibel_generation_norm','dist_nn_in_lobe_norm',
                 'num_desc_norm','max_path_length_norm','centerlinelength_norm','avginnerarea_norm',
                'lobe_norm', 'num_children']
edge_feature_names = ['centerlinelength_norm','avginnerarea_norm']

args = {'node_feature_names': node_features, 'edge_feature_names':edge_feature_names}
# DATASETS


my_data_train  = CustomDatasetNode('data_train_toy/',
                               '/home/sneha/toy_lobe_cleaned_normalised_w_labels_train.csv',
                               '/home/sneha/toy_lobe_binary_labels_train.csv',
                               args = args,
                               label_col_name=label_col_name,
                                node_level=True
                              )

my_data_test  = CustomDatasetNode('data_test_toy/',
                               '/home/sneha/toy_lobe_cleaned_normalised_w_labels_test.csv',
                               '/home/sneha/toy_lobe_binary_labels_test.csv',
                               args = args,
                               label_col_name=label_col_name,
                                node_level=True
                              )

Using Node features: ['nx', 'ny', 'nz', 'dircosx_norm', 'dircosy_norm', 'dircosz_norm', 'angle_norm', 'weibel_generation_norm', 'dist_nn_in_lobe_norm', 'num_desc_norm', 'max_path_length_norm', 'centerlinelength_norm', 'avginnerarea_norm', 'lobe_norm', 'num_children'], Edge features: ['centerlinelength_norm', 'avginnerarea_norm']
Getting labels from: node_label


Processing...


# Graphs 120976 Label Frequency Counter({0: 103792, 1: 17184})
Class proportions: {0: 0.857955296918397, 1: 0.14204470308160297}


100%|██████████████████████████████████████████████████████████████████████████████| 2093/2093 [00:12<00:00, 166.70it/s]
Done!


Using Node features: ['nx', 'ny', 'nz', 'dircosx_norm', 'dircosy_norm', 'dircosz_norm', 'angle_norm', 'weibel_generation_norm', 'dist_nn_in_lobe_norm', 'num_desc_norm', 'max_path_length_norm', 'centerlinelength_norm', 'avginnerarea_norm', 'lobe_norm', 'num_children'], Edge features: ['centerlinelength_norm', 'avginnerarea_norm']
Getting labels from: node_label


Processing...


# Graphs 120976 Label Frequency Counter({0: 103792, 1: 17184})
Class proportions: {0: 0.857955296918397, 1: 0.14204470308160297}


100%|██████████████████████████████████████████████████████████████████████████████| 2093/2093 [00:12<00:00, 168.45it/s]
Done!


In [19]:


from torch_geometric.loader import NeighborLoader, ImbalancedSampler, NeighborSampler
from torch_geometric.data import Batch
train_obj = Batch.from_data_list(my_data_train) # loading all graphs into batch object
train_obj.n_id = torch.arange(train_obj.num_nodes)
sampler = ImbalancedSampler(train_obj)
# train_loader = NeighborLoader(train_obj,
#                                num_neighbors=[15, 10, 5], batch_size=128,
#                               sampler=sampler,
#                                shuffle=False)
train_loader = NeighborSampler(train_obj.edge_index,
                               sizes=[15, 10, 5], batch_size=128,
                              sampler=sampler,
                               shuffle=False)

test_obj = Batch.from_data_list(my_data_test) # loading all graphs into batch object
test_obj.n_id = torch.arange(test_obj.num_nodes)
test_loader = NeighborSampler(test_obj.edge_index, node_idx=None, sizes=[-1], batch_size=128, shuffle=False)

# print(my_data_train.y)
# sampler_train = ImbalancedSampler(my_data_train)


# indices = torch.where(train_obj.y==0)[0]#filtering majority class
# print(len(indices)/len(train_obj.y), indices)
# train_loader = NeighborLoader(train_obj,num_neighbors=[-1]*3, input_nodes=indices, batch_size=128, shuffle=True)  
next(iter(test_loader))

(128,
 tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127]),
 EdgeIndex(edge_index=tensor([[  0,   1,   1,   3,   3,   2,   2,   4,   4,   5,   5,   6,   6,   7,
            7,   9,   9,   8,   8,  10,  10,  11,  11,  13,  14,  14,  15,  15,
           16,  16,  18,  18,  20,  20,  21,  21,  24,  24,  26,  26,  29,  29,
 

In [9]:
data = next(iter(train_loader))
batch_size, n_id, adjs = data
all_data = train_obj.x
print(all_data[n_id].shape)
print(batch_size)


torch.Size([504, 15])
128


In [10]:
class SAGE2(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = x.relu_()
                x = F.dropout(x, p=0.5, training=self.training)
        return x.log_softmax(dim=-1)

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
        pbar = tqdm(total=len(subgraph_loader.dataset) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch:
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id.to(x_all.device)].to(device)
                x = conv(x, batch.edge_index.to(device))
                if i < len(self.convs) - 1:
                    x = x.relu_()
                xs.append(x[:batch.batch_size].cpu())
                pbar.update(batch.batch_size)
            x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all
    
    


In [16]:
class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
        super(SAGE, self).__init__()

        self.num_layers = num_layers

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        
        self.lin =  nn.Sequential(
                                        nn.Linear(hidden_channels, out_channels),
                                        nn.Sigmoid()
                                    )

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adjs):
        # `train_loader` computes the k-hop neighborhood of a batch of nodes,
        # and returns, for each layer, a bipartite graph object, holding the
        # bipartite edges `edge_index`, the index `e_id` of the original edges,
        # and the size/shape `size` of the bipartite graph.
        # Target nodes are also included in the source nodes so that one can
        # easily apply skip-connections or add self-loops.
        for i, (edge_index, e_id, size) in enumerate(adjs):
            xs = []
#             print(i, edge_index, e_id, size, sep='\n')
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            
            if i != self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
            xs.append(x)
            if i == 0: 
                x_all = torch.cat(xs, dim=0)
                layer_1_embeddings = x_all
            elif i == 1:
                x_all = torch.cat(xs, dim=0)
                layer_2_embeddings = x_all
            elif i == 2:
                x_all = torch.cat(xs, dim=0)
                layer_3_embeddings = x_all    
        #return x.log_softmax(dim=-1)
        
        x_out = self.lin(layer_3_embeddings)
       
        return layer_1_embeddings, layer_2_embeddings, layer_3_embeddings, x_out

    def inference(self, x_all):
        pbar = tqdm(total=x_all.size(0) * self.num_layers)
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        total_edges = 0
        for i in range(self.num_layers):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                total_edges += edge_index.size(1)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                xs.append(x)

                pbar.update(batch_size)

            if i == 0: 
                x_all = torch.cat(xs, dim=0)
                layer_1_embeddings = x_all
            elif i == 1:
                x_all = torch.cat(xs, dim=0)
                layer_2_embeddings = x_all
            elif i == 2:
                x_all = torch.cat(xs, dim=0)
                layer_3_embeddings = x_all
                
        pbar.close()

        return layer_1_embeddings, layer_2_embeddings, layer_3_embeddings

In [17]:
device = set_device_and_seed(GPU=True, gpu_name='cuda:0')
in_features = my_data_train[0].x.shape[1]
out_features= 1
model = SAGE(in_features, 8, out_features)
print(model)
model = model.to(device)
x = train_obj.x.to(device)
y = train_obj.y.squeeze().to(device)
print(x.shape, y.shape)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(epoch):
    # https://towardsdatascience.com/a-comprehensive-case-study-of-graphsage-algorithm-with-hands-on-experience-using-pytorchgeometric-6fc631ab1067
    model.train()
    pbar = tqdm(total=int(len(train_loader)))
    pbar.set_description(f'Epoch {epoch:02d}')
    #pbar = tqdm(total=train_idx.size(0))
    #pbar.set_description(f'Epoch {epoch:02d}')
    denom = 0
    total_loss = total_correct = 0
    for batch_size, n_id, adjs in train_loader:
        denom += batch_size
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        adjs = [adj.to(device) for adj in adjs]
        optimizer.zero_grad()    
        l1_emb, l2_emb, l3_emb, x_out = model(x[n_id], adjs)
#         print("Layer 1 embeddings", l1_emb.shape)
#         print("Layer 2 embeddings", l2_emb.shape)
#         print("Layer 3 embeddings", l3_emb.shape, l3_emb)
#         print('xout', x_out.shape, x_out)
#         out = l3_emb.log_softmax(dim=-1)
#         print(out)
        labels = y[n_id[:batch_size]]
#         print('labels', labels.shape, labels)
#         loss = F.nll_loss(x_out, y[n_id[:batch_size]])
        loss = F.binary_cross_entropy(x_out.squeeze(),labels.float())
#         print('loss',loss)
        loss.backward()
        optimizer.step()
        total_loss += float(loss.item())
        pred = x_out > 0.5
        pred = pred.long()
#             print(pred.unique())
        total_correct += pred.eq(labels.view_as(pred)).sum().item()
#         total_correct += int(out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum())
        #pbar.update(batch_size)

    #pbar.close()

    loss = total_loss / denom
    approx_acc = total_correct / denom

    return loss, approx_acc

# def train(epoch):
#     https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py
#     model.train()

#     pbar = tqdm(total=int(len(train_loader.dataset)))
#     pbar.set_description(f'Epoch {epoch:02d}')

#     total_loss = total_correct = total_examples = 0
#     for batch in train_loader:
#         batch = batch.to(device)
#         optimizer.zero_grad()
#         y = batch.y[:batch.batch_size]
#         y_hat = model(batch.x, batch.edge_index)[:batch.batch_size]
#         print('preds', y_hat, 'orig', y)
#         loss = F.cross_entropy(y_hat, y)
#         loss.backward()
#         optimizer.step()

#         total_loss += float(loss) * batch.batch_size
#         total_correct += int((y_hat.argmax(dim=-1) == y).sum())
#         total_examples += batch.batch_size
#         pbar.update(batch.batch_size)
#     pbar.close()

#     return total_loss / total_examples, total_correct / total_examples
losses = []
for epoch in range(1, 11):
    loss, acc = train(epoch)
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')

Using cuda:0
Setting torch, cuda, numpy and random seeds to 0
SAGE(
  (convs): ModuleList(
    (0): SAGEConv(15, 8)
    (1): SAGEConv(8, 8)
    (2): SAGEConv(8, 8)
  )
  (lin): Sequential(
    (0): Linear(in_features=8, out_features=1, bias=True)
    (1): Sigmoid()
  )
)
torch.Size([123069, 15]) torch.Size([123069])





  0%|                                                                                           | 0/962 [00:00<?, ?it/s][A[A[A


Epoch 01:   0%|                                                                                 | 0/962 [00:05<?, ?it/s][A[A[A


Epoch 01, Loss: 0.0014, Approx. Train: 0.9341





  0%|                                                                                           | 0/962 [00:00<?, ?it/s][A[A[A


Epoch 02:   0%|                                                                                 | 0/962 [00:05<?, ?it/s][A[A[A


Epoch 02, Loss: 0.0011, Approx. Train: 0.9477





  0%|                                                                                           | 0/962 [00:00<?, ?it/s][A[A[A


Epoch 03:   0%|                                                                                 | 0/962 [00:05<?, ?it/s][A[A[A


Epoch 03, Loss: 0.0011, Approx. Train: 0.9505





  0%|                                                                                           | 0/962 [00:00<?, ?it/s][A[A[A


Epoch 04:   0%|                                                                                 | 0/962 [00:05<?, ?it/s][A[A[A


Epoch 04, Loss: 0.0010, Approx. Train: 0.9518





  0%|                                                                                           | 0/962 [00:00<?, ?it/s][A[A[A


Epoch 05:   0%|                                                                                 | 0/962 [00:05<?, ?it/s][A[A[A


Epoch 05, Loss: 0.0011, Approx. Train: 0.9549





  0%|                                                                                           | 0/962 [00:00<?, ?it/s][A[A[A


Epoch 06:   0%|                                                                                 | 0/962 [00:05<?, ?it/s][A[A[A


Epoch 06, Loss: 0.0010, Approx. Train: 0.9547





  0%|                                                                                           | 0/962 [00:00<?, ?it/s][A[A[A


Epoch 07:   0%|                                                                                 | 0/962 [00:00<?, ?it/s][A[A[A

KeyboardInterrupt: 

In [None]:
data = next(iter(train_loader)).to(device)
# for data in train_loader:
print(data.edge_index[:data.batch_size])
y = data.y[:data.batch_size]


In [None]:

model = SAGE(dataset.num_features, 256, dataset.num_classes).to(device)

