<h1>Graph Construction Functions</h1>

This notebook collects functions for the individual stages of building graph representations of whole-slide images (WSIs) with tile-level nodes. Each function is given with a usage example. An end-to-end function that goes straight from WSI to a PyTorch Geometric graph is provided in a separate document and is preferable for practical use. Note that some of these functions (especially those after node feature extraction) are still rough and may require some data manipulation outside the provided functions; cross-compatibility with the end-to-end function is limited.

<h3>Overview of Necessary Installations</h3>

In [2]:
#Use pip to install the following packages in CLI to the appropriate conda environment/kernel before running 

#For running hovernet nucleus segmentation
# !pip install scanpy
# !pip install torchvision
# !pip install opencv-python

#For running graph construction
# !pip install pyflann
# !pip install networkx
# !pip install torch_sparse torch_scatter
# !pip install git+https://github.com/rusty1s/pytorch_geometric.git

In [None]:
#General imports 
import numpy as np
from tqdm import tqdm
import copy
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
from torch.optim.lr_scheduler import StepLR
from PIL import Image
import cv2
import skimage 
from torchvision import models, transforms
import itertools
import math, random
import pandas as pd
import seaborn as sns
import scanpy as sc
from glob import glob
import sys,os
%matplotlib inline

#For hovernet
import albumentations as A
from pathml.datasets.pannuke import PanNukeDataModule
from pathml.ml.hovernet import HoVerNet, loss_hovernet, post_process_batch_hovernet, _HoverNetDecoder
from pathml.ml.utils import wrap_transform_multichannel, dice_score
from pathml.utils import plot_segmentation

#For graph construction
from collections import OrderedDict
from pyflann import *
import skimage.feature
import networkx as nx
import torchvision.transforms.functional as F
import torch_geometric.data as data
import torch_geometric.utils as utils
import torch_geometric

<h3>Nucleus Extraction</h3>

In [None]:
#General setup

#prepare the model, the GPU, and send the model to the GPU
device = torch.device("cuda:0")
checkpoint = torch.load("/path/to/hovernet_pannuke.pt", map_location='cpu')

n_classes_pannuke = 6

hovernet = HoVerNet(n_classes=n_classes_pannuke) #nuclei will be classified into 1 of 6 classes after segmentation
hovernet = torch.nn.DataParallel(hovernet) # wrap model to use multi-GPU
hovernet.load_state_dict(checkpoint) #load the best checkpoint for prediction/finetuning 

hovernet = hovernet.module
hovernet.to(device);
hovernet.eval();

In [None]:
#General function set for nucleus segmentation
def pil_loader(path):
    """
    Open single image as file to avoid ResourceWarning 
    (https://github.com/python-pillow/Pillow/issues/835).
    
    Input: string path name; Output: PIL image.
    
    (For segmentation purposes, this function is rarely called alone.)
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
    
def as_array(path):
    """Open single image as np array.
    
    Input: string path name; Output: np array.
    """
    x = np.asarray(pil_loader(path))
    return x 

def extract_nuclei(df, save_path=None, save_name=None, num_workers=12, **kwargs):

    """
    This function generates segmentation masks from a tile dataframe.
    
    Setup: 
    - Initialize a hovernet PathML model under global variable name hovernet.
    - Ensure GPU is prepared and hovernet model is sent to GPU in eval mode.
    - Ensure you have proper read/write permissions for any directory you intend
    to write your masks to!
    
    Inputs:
    - df: a Pandas DataFrame object containing one row per tile. There should
    be one column titled 'full_path' whose entries are the paths to the image 
    of each tile. 
    - save_path (optional): string; if specified, this will be the path 
    to which the segmentation data will be written; this string MUST be an absolute
    path and should not contain the file name. If save_path is left blank while 
    save_name is specified, then the file will be saved to the same directory
    in which the function is run (via notebook or script).
    - save_name (optional): string; if specified, this will be the filename 
    to which the segmentation data will be written; this string should not include
    a file extension. (File will be written as .npy) Must be specified if save_path
    is also specified.
    - num_workers (optional): int; number of processes for running inference. 
    Default is 12.
    
    Outputs:
    - a NumPy array of dimensions (#num tiles, 6, tile_dim, tile_dim) containing 
    six-channel segmentation masks for each tile (in the order of rows in the 
    original DataFrame). If save_name is specified, this NumPy array will be saved
    to the given path; otherwise, it will be returned. 
    """
    #prepare an array of tiles from the df
    tiles = np.array([as_array(path) for path in df.full_path.to_numpy()])
    #rearrange axes for torch inference
    tiles = np.moveaxis(tiles, 3, 1)
#     print(tiles.shape)

    #Build simple dataloader
    tile_data = torch.utils.data.DataLoader(tiles, batch_size=10, num_workers=num_workers)

    #pass tiles to the GPU for segmentation inference in small batches
    #fill up arrays with predictions. 
    mask_pred = None

    with torch.no_grad():
        for i, data in tqdm(enumerate(tile_data)):
            # send the data to the GPU
            images = data.float().to(device)

            # pass thru network to get predictions
            outputs = hovernet(images)
            _, preds_classification = post_process_batch_hovernet(outputs, n_classes=n_classes_pannuke)

            #add model results to our prediction-storing arrays
            if i == 0:
                mask_pred = preds_classification
            else:
                mask_pred = np.concatenate([mask_pred, preds_classification], axis=0)

    #if no save_name specified, return the array. 
    if save_name is None:
        return mask_pred
    
    #if save_name is given, but no save_path, then save a copy to the local directory. 
    elif save_path is None:
        np.save(save_name + '.npy', mask_pred)
    
    #if save_name and save_path are given, create the directory (and any missing parents)
    #and save the array with the given name to the specified path.
    else:
        os.makedirs(save_path, exist_ok=True)
        np.save(os.path.join(save_path, save_name + '.npy'), mask_pred)

In [None]:
#Example code block:

#prepare the model, the GPU, and send the model to the GPU
device = torch.device("cuda:0")
checkpoint = torch.load("/path/to/hovernet_pannuke.pt", map_location='cpu')

n_classes_pannuke = 6

hovernet = HoVerNet(n_classes=n_classes_pannuke) #nuclei will be classified into 1 of 6 classes after segmentation
hovernet = torch.nn.DataParallel(hovernet) # wrap model to use multi-GPU
hovernet.load_state_dict(checkpoint) #load the best checkpoint for prediction/finetuning 

hovernet = hovernet.module
hovernet.to(device);
hovernet.eval();

#load in a df of tiles
df = pd.read_pickle('010721_Master_df_filtered_TCGA_annotated.pkl')

#Run segmentation on tiles for the first 10 WSIs and store to a persistent disk
for sample in df.groupby('sample_id').count().index.to_numpy()[:10]:
    wsi = df.loc[df.sample_id == sample]
    #Although storage-intensive, it is useful to save a copy of slide-specific dfs to ensure a matching order 
    #of rows in df to tiles to segmentation masks.
    wsi.to_pickle('/mnt/disks/data/slide_dfs/'+sample+'.pkl')
    extract_nuclei(wsi, save_path='/mnt/disks/data/segmentation_masks/', save_name=sample, num_workers=12)
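The `save_name`/`save_path` convention used by `extract_nuclei` can be sketched on its own with the standard library. This is a minimal illustration, not part of the notebook's function set: `resolve_output_path` is a hypothetical helper name, and the temporary directory stands in for a real output location.

```python
import os
import tempfile

def resolve_output_path(save_name, save_path=None, ext=".npy"):
    """Return the file path to write to, creating save_path if needed.

    Hypothetical helper for illustration; it mirrors the convention above,
    where save_name carries no file extension.
    """
    if save_path is None:
        # No directory given: write into the current working directory.
        return save_name + ext
    # exist_ok=True makes this a no-op when the directory already exists,
    # and missing parent directories are created in the same call.
    os.makedirs(save_path, exist_ok=True)
    return os.path.join(save_path, save_name + ext)

# Usage with a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    target = resolve_output_path("sample_01", os.path.join(tmp, "masks", "run1"))
    created = os.path.isdir(os.path.dirname(target))
    filename = os.path.basename(target)
```

Creating the directory in one call avoids looping over path components, and `os.path.join` avoids doubled separators when `save_path` already ends in a slash.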

<h3>Feature Extraction/Node-building</h3>

Each node represents one tile of the WSI, with the following features (averaged across all nuclei in the tile):
 - (x, y) coordinates of centroid (averaged and used to plot nodes relative to each other in the final graph)
 - Short axis length 
 - Long axis length
 - Angle
 - Area 
 - Arc length
 - Eccentricity
 - Roundness
 - Solidity
 - Intensity
 - Dissimilarity
 - Homogeneity
 - Energy
 - ASM
 
Additional neural network features for each tile to be added later. 
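As a rough worked example of three of the shape features listed above, the sketch below evaluates the eccentricity, roundness and solidity formulas used in this notebook on idealized inputs. The function name `shape_descriptors` and the test values are assumptions made for illustration; in the notebook the inputs come from OpenCV contour measurements.

```python
import math

def shape_descriptors(short_axis, long_axis, area, hull_area, arc_length):
    """Evaluate the shape formulas on plain numbers (illustrative only)."""
    # 0 for a circle, approaching 1 as the ellipse gets more elongated.
    eccentricity = math.sqrt(1 - (short_axis / long_axis) ** 2)
    # Radius implied by the perimeter divided by radius implied by the area.
    roundness = (arc_length / (2 * math.pi)) / math.sqrt(area / math.pi)
    # Fraction of the convex hull that the contour actually covers.
    solidity = area / hull_area
    return eccentricity, roundness, solidity

# A perfect circle of radius 10: all three descriptors take their ideal values.
r = 10.0
circle = shape_descriptors(2 * r, 2 * r, math.pi * r ** 2, math.pi * r ** 2, 2 * math.pi * r)

# An ellipse with axes 6 and 10 only changes the eccentricity term here.
ellipse_ecc = shape_descriptors(6.0, 10.0, 1.0, 1.0, 1.0)[0]
```

For the circle, eccentricity is 0 and roundness and solidity are both 1 (up to floating-point error); the 6-by-10 ellipse has eccentricity 0.8.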

In [1]:
#IMPORTANT GLOBAL VARIABLE FOR NAMING THE FINAL DATAFRAME
COLUMNS_LIST = ['sample_id', 'full_path', 'num_nuclei', 'x_coord', 'y_coord', 'avg_short_axis', 
                'avg_long_axis', 'avg_angle', 'avg_area', 'avg_arc_length', 'avg_eccentricity', 'avg_roundness', 
                'avg_solidity', 'avg_intensity', 'avg_dissimilarity', 'avg_homogeneity', 'avg_energy', 'avg_ASM']


#General function set 
def compare(image, path, mask):
    """
    Visualize a tile and its corresponding binary mask. 
    input: np array img of dimensions (size, size, 3), img path string, 
    np array mask of dimensions (size, size)
    
    output: nothing returned; matplotlib object created/displayed
    """
    fig, ax = plt.subplots(ncols = 2, figsize = (16, 7))
    mask = np.repeat(np.where(mask > 0, 255, 0)[:,:,np.newaxis], 3, axis=2)
    name = path.split('/')[-1]
    print(f'Image name: {name}')
    ax[0].imshow(image)
    ax[1].imshow(mask)
    ax[0].set_title("Tile")
    ax[1].set_title("Nucleus mask")

def get_cell_image(img, cx, cy, size=512):
    """
    Extract a "context-window" around a specified nucleus of size 64x64px. 
    
    input: np array img of dimensions (size, size, 3), (cx, cy) centroid 
    coordinates of nucleus, (optional) size param for img size (default 512px)
    output: a 64x64px np array. 
    """
    cx = 32 if cx < 32 else size - 32 if cx > size - 32 else cx
    cy = 32 if cy < 32 else size - 32 if cy > size - 32 else cy
    if len(img.shape) == 3:
        return img[cy - 32:cy + 32, cx - 32:cx + 32, :]
    else:
        return img[cy - 32:cy + 32, cx - 32:cx + 32]

def get_basic_cell_features(img, grayscale, contour):
    """
    Modified feature extractor for single nucleus basic (non-neural) features.
    :param img: np array image
    :param grayscale: grayscale version of np array image
    :param contour: contour produced from nucleus segmentation (list of (x,y) 
    coordinates for one nucleus)
    
    :return: x coordinate, y coordinate of nucleus centroid, concatenated
    feature vector for successful contour; None for unsuccessful contour (various
    cases apply)
    """
    # Get contour coordinates from contour
    
    #Contours with fewer than 5 points cannot be fit to ellipse - return None
    if contour.shape[0] < 5:
        return None
    (cx, cy), (short_axis, long_axis), angle = cv2.fitEllipse(contour)
    
    #contours without valid centroids cannot be processed - return None
    if math.isnan(cx) or math.isnan(cy):
        return None
    cx, cy = int(cx), int(cy)
    
    # Get a 64 x 64 center crop about each nucleus for GLCM features
    img_cell = get_cell_image(grayscale, cx, cy)
    img_cell_grey = np.pad(img_cell, [(0, 64 - img_cell.shape[0]), (0, 64 - img_cell.shape[1])],
                           mode='reflect')
    
    # 1. Generate contour features
    convex_hull = cv2.convexHull(contour)
    area, hull_area = cv2.contourArea(contour), cv2.contourArea(convex_hull)
    
    #it's possible in rare cases for the area (or an ellipse axis) to be evaluated as 0 - return None
    #before any division takes place, to avoid div-by-0 errors
    if area == 0 or hull_area == 0 or long_axis == 0:
        return None
    eccentricity = math.sqrt(1 - (short_axis / long_axis) ** 2)
    solidity = float(area) / hull_area
    arc_length = cv2.arcLength(contour, True)
    roundness = (arc_length / (2 * math.pi)) / (math.sqrt(area / math.pi))
    intensity = get_mean_contour_intensity_grayscale(grayscale, contour)
    intensity = get_mean_contour_intensity_grayscale(grayscale, contour)

    # 2. Generating GLCM features
    # (scikit-image 0.19 and later name these functions graycomatrix/graycoprops)
    out_matrix = skimage.feature.greycomatrix(img_cell_grey, [1], [0])
    dissimilarity = skimage.feature.greycoprops(out_matrix, 'dissimilarity')[0][0]
    homogeneity = skimage.feature.greycoprops(out_matrix, 'homogeneity')[0][0]
    energy = skimage.feature.greycoprops(out_matrix, 'energy')[0][0]
    ASM = skimage.feature.greycoprops(out_matrix, 'ASM')[0][0]
    # Concatenate + Return all features
    x = [[short_axis, long_axis, angle, area, arc_length, eccentricity, roundness, solidity, intensity],
         [dissimilarity, homogeneity, energy, ASM]]
    return cx, cy, np.array(list(itertools.chain(*x)), dtype=np.float64)

def get_mean_contour_intensity_grayscale(gray_img, contours):
    """
    Proxy for how dark/light the pixels are within the segmented nucleus area
    :param gray_img: cv2 converted grayscale image of segmented nucleus area
    :param contours: contour produced from nucleus segmentation
    :return: mean pixel intensity over nucleus contour area
    """
    img_size = gray_img.shape[0]
    assert gray_img.shape[0] == gray_img.shape[1]
    nucl_mask = np.zeros((img_size, img_size))
    cv2.fillPoly(nucl_mask, pts=contours, color=(1.));  # use contour area to make a mask
    z = np.ma.masked_array(data=gray_img,
                           mask=(nucl_mask != 1.))  # mask masked version to select out only the contour area
    return z.compressed().mean()  # take mean grayscale pixel intensity over contour area

def tile(slide, masks, i):
    """
    Prepares a single tile for feature extraction. 
    :param slide: Pandas DataFrame object in which one row represents one tile. Ensure 
    at least one column in this df has name 'full_path' for accessing tile image, and 
    one column has name 'sample_id' for identifying from which slide the tiles originate.
    :param masks: np array object of dimensions (num_df_rows, tile_dim_1, tile_dim_2)
    :param i: indexing integer. Important is that the slide df and the masks array 
    are ordered identically for corresponding tiles and binary masks to be processed
    together; this shouldn't be a problem if previous functions are used. 
    
    :return: np array image of dimensions (tile_dim_1, tile_dim_2, 3); np array mask 
    of dimensions (tile_dim_1, tile_dim_2); list contours (each element in this list
    is a list of (x,y) coordinates, representing the boundary of one nucleus); string 
    path (full path to image); string sample_id (the associated sample ID for identifying
    tiles); x_coord, y_coord of the upper-left-hand corner of the tile (relative to its
    position on the WSI)
    """
    #Get image path
    path = slide.iloc[i]['full_path']
    
    #Get out tile x, y coordinates with some very filthy one-liners
    x_coord = int(path.split('/')[-1].split('_')[0])
    y_coord = path.split('/')[-1].split('_')[1]
    y_coord = int(y_coord[:y_coord.find('.')])
    
    #Retrieve image
    image = as_array(path)
    #Re-type mask for cv2 to be happy
    #This is performed per mask, not for all masks at once, for memory preservation
    mask = masks[i].astype('uint8')

    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    sample_id = slide.iloc[i]['sample_id']
    return image, mask, contours, path, sample_id, x_coord, y_coord

def tile_level_feats(contours, image, x_coord, y_coord, tile_size=512):
    """
    Computes features for each nucleus in a tile and averages them to get feature 
    set for the entire tile. A future version of this function will incorporate
    neural-extracted features.
    
    :param contours: list of all contours (boundaries of nuclei) in a given tile.
    :param image: np array of tile, with dimensions (tile_dim_1, tile_dim_2, 3).
    :param x_coord: int for x_coord of upper-left-hand corner of tile.
    :param y_coord: int for y_coord of upper-left-hand corner of tile.
    :param tile_size: (optional) int for tile dimensions. Used for computing average
    centroid of tile. Default is 512.
    
    :return: feature vector as np array; integer count of number of nuclei in tile. 
    For the rare tile where no nuclei are (successfully) called from the segmentation
    mask, a dummy array will be returned with all entries -1, and a count of 0 nuclei. 
    This tile will then be removed in downstream processing and will not be included 
    in the final graph. Later versions of this function may elect to keep such tiles, 
    where neural features can make them more informative for graph analysis.
     
    """
    temp_data = []
    #Make grayscale image copy
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    for contour in contours:
        output = get_basic_cell_features(image, grayscale, contour)
        if output is not None:
            cent_x, cent_y, features = output
            
            #re-type the data as a list - all lists will have identical len
            temp_data.append([cent_x, cent_y] + list(features))
    if len(temp_data) == 0:
        return np.array([-1 for i in range(15)]), 0
    
    #the non-ragged list can then be quickly averaged and converted to np array format
    result = np.average(np.asarray(temp_data), axis=0)
    
    #Compute the position for each node: the averaged centroid of the tile's nuclei, rescaled by
    #tile_size and offset by the tile's coordinates.
    result[0] = result[0]/tile_size + x_coord
    result[1] = result[1]/tile_size + y_coord
    return result, len(temp_data)

def slide_level_feats(slide, masks):
    """
    Computes features for all tiles in a given slide. 
    
    :param slide: Pandas DataFrame object in which one row represents one tile. Ensure 
    at least one column in this df has name 'full_path' for accessing tile image, and 
    one column has name 'sample_id' for identifying from which slide the tiles originate.
    Also ensure that this df has tiles for ONLY one slide at a given time. 
    :param masks: np array object of dimensions (num_df_rows, tile_dim_1, tile_dim_2)
    
    :return: Pandas DataFrame object in which one row represents one node in graph. 
    Columns 'sample_id' and 'full_path' can be used to identify tile/WSI of origin; remaining 
    columns are features of the node. 
    """
    data = []
    for i in tqdm(range(slide.shape[0])):
        #Extract tile data from each row of the df
        image, mask, contours, path, sample_id, x_coord, y_coord = tile(slide, masks, i)
        #Extract features for the tile
        features, num_nuclei = tile_level_feats(contours, image, x_coord, y_coord)
        #Add features to data list
        data.append([sample_id, path] + [num_nuclei] + list(features))
    #Re-type the data list into a DataFrame object
    data = pd.DataFrame(data=data, columns=COLUMNS_LIST)
    return data

def process_whole_slide(wsi, mask_path, save_path=None, save_name=None, **kwargs):
    """
    Computes features for all tiles in a given slide - starting from dataframe or path to dataframe,
    and path to array file. 
    
    :param wsi: Either a string, which will be interpreted as a path to a dataframe pickle file, or
    a dataframe with tiles for a single WSI (only one WSI at a time!). If the dataframe is built at
    a different time from the mask array, the order of tiles may differ, so ensure that the order of
    rows in the dataframe is reproducible and matches the order of elements in the mask array. 
    :param mask_path: A string representing a path to a numpy array of six-channel, default hovernet
    segmentation masks. 
    :param save_path: (optional) string; if specified, this will be the path 
    to which the node DataFrame will be written; this string MUST be an absolute
    path and should not contain the file name. If save_path is left blank while 
    save_name is specified, then the file will be saved to the same directory
    in which the function is run (via notebook or script).
    :param save_name: (optional) string; if specified, this will be the filename 
    to which the node DataFrame will be written; this string should not include
    a file extension. (File will be written as .pkl) Must be specified if save_path
    is also specified.
    
    :return: Pandas Dataframe, if save_name is not specified. 
    """
    #Setting up slide object
    slide = None
    if isinstance(wsi, pd.DataFrame):
        slide = wsi 
        slide = slide.reset_index()
    else:
        slide = pd.read_pickle(wsi)
        slide = slide.reset_index()
    
    #Load mask array, process into binary masks.  
    masks = np.load(mask_path)
    masks = np.sum(masks, axis=1)
    masks = np.where(masks > 0, 255, 0)
    
    #Write the df object, removing tiles with no nuclei.
    #Future versions may give option to keep tiles with no nuclei if they have additional features 
    #(e.g. ResNet feature vector representation)
    nodes = slide_level_feats(slide, masks)
    nodes = nodes.loc[nodes.num_nuclei > 0]
    
    #if no save_name specified, return the DataFrame. 
    if save_name is None:
        return nodes
    
    #if save_name is given, but no save_path, then save a copy to the local directory. 
    elif save_path is None:
        nodes.to_pickle(save_name + '.pkl')
    
    #if save_name and save_path are given, create the directory (and any missing parents)
    #and save the DataFrame with the given name to the specified path.
    else:
        os.makedirs(save_path, exist_ok=True)
        nodes.to_pickle(os.path.join(save_path, save_name + '.pkl'))

In [None]:
#Example code block:

#IMPORTANT GLOBAL VARIABLE FOR NAMING THE FINAL DATAFRAME
COLUMNS_LIST = ['sample_id', 'full_path', 'num_nuclei', 'x_coord', 'y_coord', 'avg_short_axis', 
                'avg_long_axis', 'avg_angle', 'avg_area', 'avg_arc_length', 'avg_eccentricity', 'avg_roundness', 
                'avg_solidity', 'avg_intensity', 'avg_dissimilarity', 'avg_homogeneity', 'avg_energy', 'avg_ASM']

#Set up df and mask directories
df_dir = sorted(glob('/mnt/disks/data/slide_dfs/*'))
mask_dir = sorted(glob('/mnt/disks/data/segmentation_masks/*'))

#Store all data in a single large df
all_nodes = None

#Run node feature extraction for first 10 slides
for i in range(10):
    if i == 0:
        all_nodes = process_whole_slide(df_dir[i], mask_dir[i]) 
    else:
        all_nodes = pd.concat([all_nodes, process_whole_slide(df_dir[i], mask_dir[i])], ignore_index=True)

all_nodes.to_pickle('/mnt/disks/data/graph_nodes.pkl')
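The averaging and filtering logic behind node-building can be illustrated without images. The sketch below is a simplified, pure-Python stand-in: `average_nucleus_features` and the toy `tiles` dictionary are hypothetical names and data, and the real functions operate on NumPy arrays and DataFrames.

```python
def average_nucleus_features(rows, n_features=15):
    """Average per-nucleus feature rows into one tile-level vector.

    Hypothetical stand-in for the averaging step; rows is a list of
    equal-length lists, one per successfully processed nucleus.
    """
    if not rows:
        # No usable nuclei: sentinel vector of -1 and a nucleus count of 0.
        return [-1.0] * n_features, 0
    n = len(rows)
    # zip(*rows) walks the rows column by column.
    return [sum(col) / n for col in zip(*rows)], n

tiles = {
    "tile_a": [[1.0, 2.0], [3.0, 4.0]],  # two nuclei, two features each
    "tile_b": [],                        # segmentation found no nuclei
}
nodes = {name: average_nucleus_features(rows, n_features=2) for name, rows in tiles.items()}

# Keep only tiles with at least one nucleus, as process_whole_slide does.
kept = {name: feats for name, (feats, count) in nodes.items() if count > 0}
```

Here `tile_b` receives the sentinel vector and is dropped by the final filter, so only `tile_a` would become a graph node.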

<h3>Graph-building</h3>

In [None]:
#General function set 
def df2graph(df):
    """
    This function takes in a dataframe with columns ordered as follows:
        
        ['full_path', 'x_coord', 'y_coord', 'num_nuclei', 'avg_short_axis',
       'avg_long_axis', 'avg_angle', 'avg_area', 'avg_arc_length',
       'avg_eccentricity', 'avg_roundness', 'avg_solidity', 'avg_intensity',
       'avg_dissimilarity', 'avg_homogeneity', 'avg_energy', 'avg_ASM',
       'class_1_prob'] - index column is 'slide_id'
       
    NOTE: You may need to reorder/drop some columns if you use this function
    on a raw dataframe of nodes. This format is ESSENTIAL to ensuring that 
    only numerical features get added for each node in the graph, which prevents
    huge ugly torch-typing errors from the downstream conversion of networkx graphs
    to pytorch-geometric graphs. The code below reads node positions from columns
    named 'node_x' and 'node_y', so rename 'x_coord' and 'y_coord' accordingly
    before calling this function.
       
    It returns a networkx Graph object in which nodes are the avg-centroids 
    of tiles (with coordinates [node_x, node_y]) and features are all numerical
    values except for the coords (not incl. full_path or slide_id)
       """
    
    """Silly bunch of code, but here's what happens. 
    -First, drop NaNs as necessary, then reset the index (this moves the 
    'slide_id' index into a column).
    -Sort nodes by their position on the WSI (makes for more organized 
    graph node names).
    -Reset the index again with drop=True, so that rows are numbered 
    0..n-1 in the sorted order and no leftover 'index' column is created."""
    df = df.dropna()
    df = df.reset_index()
    df = df.sort_values(by=['node_x', 'node_y'])
    df = df.reset_index(drop=True)

    #initialize graph
    G = nx.Graph()
    
    #use iterrows method - index will be node name, features and centroid will be extracted appropriately
    for v, node in df.iterrows():
        G.add_node(v, centroid=np.array([node['node_x'], node['node_y']], dtype=np.float32), x=node[3:].to_numpy(dtype=np.float32))
    
    return G

def KNN(G):
    """Condensed version of KNN which adds edges to graph."""
    #Code directly transplanted from pathomic-fusion notebook 
    #First, build a "dataset"  using the centroid data - collect node data into a new array
    centroids = []
    for u, attrib in G.nodes(data=True):
        centroids.append(attrib['centroid'])

    cell_centroids = np.array(centroids).astype(np.float64)
    dataset = cell_centroids


    #Run the nearest-neighbour search (FLANN's 'kmeans' option selects the index type used for the search)
    for idx, attrib in tqdm(list(G.nodes(data=True))):

        #initialize the FLANN object 
        flann = FLANN()

        #Query one node at a time, adding one node's worth of edges to the graph
        testset = np.array([attrib['centroid']]).astype(np.float64)

        #Find the 5 nearest centroids 
        results, dists = flann.nn(dataset, testset, num_neighbors=5, algorithm = 'kmeans', branching = 32, iterations = 100, checks = 16)
        results, dists = results[0], dists[0]

        #Use results to draw in edges in the graph; the first result is skipped
        #(it is expected to be the query node itself)
        for i in range(1, len(results)):
            G.add_edge(idx, int(results[i]), weight = float(dists[i]))

    return G

def from_networkx(G):
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.
    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
    """

    G = G.to_directed() if not nx.is_directed(G) else G
    edge_index = torch.tensor(list(G.edges)).t().contiguous()

    keys = []
    keys += list(list(G.nodes(data=True))[0][1].keys())
    keys += list(list(G.edges(data=True))[0][2].keys())
    data = {key: [] for key in keys}

    for _, feat_dict in G.nodes(data=True):
        for key, value in feat_dict.items():
            data[key].append(value)
    for _, _, feat_dict in G.edges(data=True):
        for key, value in feat_dict.items():
            data[key].append(value)

    #Hopefully I can re-type the final dictionary value manually to avoid issues
    weights = data['weight']
    weights = [float(x) for x in weights]
    # weights = np.array(weights)
    data['weight'] = weights

    # THIS IS THE PROBLEMATIC PART
    for key in data.keys():
        data[key] = torch.tensor(data[key])
    #     print(key)

    data['edge_index'] = edge_index
    data = torch_geometric.data.Data.from_dict(data)
    data.num_nodes = G.number_of_nodes()

    return data

def save_to_pytorch_geometric(G, save_path='test', **kwargs):
    """
    Code block for writing the networkx graph into a pytorch geometric graph, and then saving it.
    Note - the save_path parameter should NOT have a filename extension to it. """
    G = from_networkx(G)

    edge_attr_long = (G.weight.unsqueeze(1)).type(torch.LongTensor)
    G.edge_attr = edge_attr_long 

    edge_index_long = G['edge_index'].type(torch.LongTensor)
    G.edge_index = edge_index_long

    x_float = G['x'].type(torch.FloatTensor)
    G.x = x_float

    G['weight'] = None
    G['nn'] = None
    torch.save(G, save_path+'.pt')
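To make the edge-building step easier to follow, here is a brute-force sketch of what a k-nearest-neighbour graph over tile centroids looks like. It deliberately swaps FLANN's approximate search for an exhaustive one using only the standard library, and it uses plain Euclidean distance as the edge weight, which is an assumption; `knn_edges` and the toy `points` are hypothetical names for illustration.

```python
import math

def knn_edges(centroids, k=4):
    """Connect every point to its k nearest other points (undirected).

    Brute-force illustration only: O(n^2) distance evaluations, whereas
    the notebook's KNN function delegates the search to FLANN.
    """
    edges = {}
    for i, p in enumerate(centroids):
        # Distances to every other point; the point itself is excluded,
        # mirroring how the loop above skips the first search result.
        dists = sorted((math.dist(p, q), j) for j, q in enumerate(centroids) if j != i)
        for d, j in dists[:k]:
            # Store each undirected edge once, keyed by the sorted node pair.
            edges[(min(i, j), max(i, j))] = d
    return edges

points = [(0, 0), (1, 0), (0, 1), (5, 5)]
edges = knn_edges(points, k=1)
```

Because edges are undirected, a node can end up with more than k neighbours when other nodes select it; the same holds for the networkx graph built by `KNN`.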