In [6]:
import collections
import math
import os
import pickle
import random

import ipywidgets as widgets
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from ipywidgets import interact, fixed
from scipy.sparse import coo_matrix
from sklearn.neighbors import NearestNeighbors

In [7]:
def _sample_unoccupied(candidates, placedGroups, count):
    """Draw `count` sites from `candidates` that none of `placedGroups` already uses."""
    free = set(candidates)
    for group in placedGroups:
        free = free - set(group)
    # random.sample() rejects sets on Python >= 3.11 (deprecated since 3.9),
    # so the remaining sites are materialised into a sequence first; tuple()
    # keeps the set's own iteration order.
    return random.sample(tuple(free), count)


def create_df(tumorList, stromaList, TILList1, TILList2, NK, MP, 
              numtumor=500, numstroma=500, numTIL1=0, numTIL2=0, numNK=0, numMP=0):
    """Randomly place cells of up to six types on distinct candidate sites.

    Tumor and stroma cells are sampled independently from their own lists.
    TIL1, TIL2, NK and MP cells are then sampled, in that order, from their
    candidate lists after removing every site taken by an earlier group.
    Sampling uses the global `random` module state, so seed it beforehand
    for repeatable output.

    Parameters
    ----------
    tumorList, stromaList, TILList1, TILList2, NK, MP : sequence of (x, y) tuples
        Candidate sites for each cell type (elements must be hashable).
    numtumor, numstroma, numTIL1, numTIL2, numNK, numMP : int
        Number of cells of each type to place.

    Returns
    -------
    df : pandas.DataFrame
        One row per placed cell with columns 'x', 'y' and 'label'; rows are
        grouped in the order Tumor, Stroma, TIL1, TIL2, NK, MP.
    pos : list of (x, y) tuples
        Site of every cell, in the same order as the rows of `df`.

    Raises
    ------
    ValueError
        Propagated from `random.sample` when a group asks for more cells
        than there are sites available to it.
    """
    tumor = random.sample(tumorList, numtumor)
    stroma = random.sample(stromaList, numstroma)

    # Each later group may only use sites that no earlier group occupies.
    placed = [tumor, stroma]
    for candidates, count in ((TILList1, numTIL1), (TILList2, numTIL2),
                              (NK, numNK), (MP, numMP)):
        placed.append(_sample_unoccupied(candidates, list(placed), count))

    pos = []
    label = []
    groupNames = ['Tumor', 'Stroma', 'TIL1', 'TIL2', 'NK', 'MP']
    for labelName, sites in zip(groupNames, placed):
        pos.extend(sites)
        # One label per placed cell of this group (empty groups add nothing).
        label.extend([labelName] * len(sites))

    df = pd.DataFrame(
        {'x': [site[0] for site in pos],
         'y': [site[1] for site in pos],
         'label': label},
        columns=['x', 'y', 'label'],
    )
    return df, pos


def create_graph(df, pos, radius=60):
    """Build a radius-neighbourhood graph over the cells in `df`.

    Every cell becomes a node; two cells are linked when one is returned as a
    radius neighbour of the other by scikit-learn's `NearestNeighbors`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns 'x', 'y' and 'label', one row per cell.
    pos : sequence
        Position of each cell, in the same order as the rows of `df`; stored
        on node i as the 'pos' attribute.
    radius : float, optional
        Neighbourhood radius handed to `NearestNeighbors` (default 60, the
        value that used to be hard-coded).

    Returns
    -------
    networkx.Graph
        Graph with nodes 0..len(df)-1, each carrying 'pos' and 'cell_types'.
    """
    dfXY = df[['x', 'y']].copy()
    N = len(dfXY)
    nn = NearestNeighbors(radius=radius)
    nn.fit(dfXY)
    # Only the neighbour indices are needed; the distances are discarded.
    _, ids = nn.radius_neighbors(dfXY)

    # Flatten the per-node neighbour arrays into COO row / column indices.
    rows = [i for i, neighbours in enumerate(ids) for _ in neighbours]
    cols = [j for neighbours in ids for j in neighbours]
    w = np.ones(len(rows))

    _W = coo_matrix((w, (rows, cols)), shape=(N, N))
    # Reset the diagonal: a point queried against itself may be reported as
    # its own neighbour.
    _W.setdiag(0)
    # Symmetrise the adjacency matrix.
    _W = 1/2*(_W + _W.T)

    # networkx 3.0 removed from_scipy_sparse_matrix; prefer its replacement
    # (available since 2.7) and fall back to the old name on older releases.
    fromSparse = getattr(nx, 'from_scipy_sparse_array', None)
    if fromSparse is None:
        fromSparse = nx.from_scipy_sparse_matrix
    G = fromSparse(_W)

    # Read labels positionally so they line up with `pos` even if `df` does
    # not carry a default 0..N-1 index.
    labels = df['label'].tolist()
    for i in range(len(G.nodes)):
        G.nodes[i]['pos'] = pos[i]
        G.nodes[i]['cell_types'] = labels[i]
    return G
    
    
def add_data(id_, range_, nums=[1500, 1500, 0, 0, 0, 0], count=1):
    """Generate `count` random graphs and append them to patientDict[id_].

    `range_` holds four [xStart, xStop, yStart, yStop] entries, giving the
    candidate region for TIL1, TIL2, NK and MP cells respectively; each
    region is expanded into 1-based (x, y) sites.  `nums` lists the number of
    tumor, stroma, TIL1, TIL2, NK and MP cells per graph.  Tumor and stroma
    candidates come from the notebook-level `tumorList` / `stromaList`.
    """
    def _sites(bounds):
        # Half-open ranges shifted by one so that coordinates start at 1.
        return [(x+1, y+1) for x in range(bounds[0], bounds[1]) for y in range(bounds[2], bounds[3])]

    tilSites1 = _sites(range_[0])
    tilSites2 = _sites(range_[1])
    nkSites = _sites(range_[2])
    mpSites = _sites(range_[3])

    for _ in range(count):
        df, pos = create_df(
            tumorList, stromaList, tilSites1, tilSites2, nkSites, mpSites,
            numtumor=nums[0], numstroma=nums[1],
            numTIL1=nums[2], numTIL2=nums[3],
            numNK=nums[4], numMP=nums[5],
        )
        G = create_graph(df, pos)
        patientDict[id_].append(G)
        

# Data creation

In [3]:
# Seed the stdlib RNG so the dataset is repeatable:
# 123 = training, 124 = validation, 125 = test.
random.seed(123)

patientKeys = ['{:0>4d}'.format(n) for n in range(1, 11)]
patientDict = collections.defaultdict(list)

# Candidate sites (1-based): tumor fills x in 1..500, stroma x in 501..1000,
# both over y in 1..1000.
tumorList = [(x, y) for x in range(1, 501) for y in range(1, 1001)]
stromaList = [(x, y) for x in range(501, 1001) for y in range(1, 1001)]

# Graphs shared by every patient: three tumor/stroma-only graphs each.
emptyRegions = [[0, 0, 0, 0] for _ in range(4)]
for key in patientKeys:
    add_data(key, emptyRegions, nums=[500, 500, 0, 0, 0, 0], count=3)

# Candidate regions per pattern, ordered [TIL1, TIL2, NK, MP];
# each region is [xStart, xStop, yStart, yStop].
patch1 = [[425, 575, 0, 1000], [425, 575, 0, 1000], [0, 0, 0, 0], [0, 0, 0, 0]]
patch2 = [[0, 500, 0, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
patch3 = [[0, 500, 0, 1000], [0, 500, 0, 1000], [0, 0, 0, 0], [0, 0, 0, 0]]
patch4 = [[0, 500, 0, 1000], [0, 500, 0, 1000], [0, 0, 0, 0], [0, 1000, 0, 1000]]
patch5 = [[0, 500, 0, 1000], [0, 500, 0, 1000], [0, 1000, 0, 1000], [0, 1000, 0, 1000]]

# Cell counts per pattern, ordered [tumor, stroma, TIL1, TIL2, NK, MP].
num1 = [400, 400, 100, 100, 0, 0]
num2 = [300, 300, 400, 0, 0, 0]
num3 = [300, 300, 200, 200, 0, 0]
num4 = [300, 300, 150, 150, 0, 100]
num5 = [300, 300, 100, 100, 100, 100]

patterns = [(patch1, num1), (patch2, num2), (patch3, num3), (patch4, num4), (patch5, num5)]

for fold in range(1):
    # Discriminative graphs: each of the ten patients in a fold receives one
    # distinct unordered pair of patterns, enumerated as (1,2), (1,3), ...,
    # (4,5).
    patientIndex = 10 * fold
    for first in range(len(patterns)):
        for second in range(first + 1, len(patterns)):
            key = patientKeys[patientIndex]
            add_data(key, patterns[first][0], patterns[first][1])
            add_data(key, patterns[second][0], patterns[second][1])
            patientIndex += 1

# Visualization

In [12]:
# import pickle
# with open(r'./data/patient_gumbel_test.pickle', 'rb') as handle:
#     patientDict = pickle.load(handle)

In [13]:
# Widgets: a patient selector and a slider over that patient's graphs
id_ = widgets.Dropdown(options=patientDict.keys(), description='Patient ID: ')

graphs = widgets.IntSlider(
    description='Graph Index: ',
    min=0,
    # last valid index for the patient selected when the slider is built
    max=len(patientDict[id_.value]) - 1,
    step=1,
    orientation='horizontal',
    continuous_update=False,
)

def update_graphs(*args):
    """Reset the slider's upper bound to the selected patient's last graph index."""
    selectedPatient = id_.value
    graphs.max = len(patientDict[selectedPatient]) - 1


# Re-run the callback whenever the dropdown's value changes
id_.observe(update_graphs, names='value')

# Matplotlib colour used for each cell type
nodeColorsDict = {
    'Tumor': 'c',
    'Stroma': 'y',
    'TIL1': 'r',
    'TIL2': 'b',
    'NK': 'g',
    'MP': 'orange',
}


def graph_visualization(id_, graphs):
    """Draw graph number `graphs` of patient `id_`, nodes coloured by cell type."""
    plt.figure(figsize=(8, 8))
    G = patientDict[id_][graphs]

    # Single-point lines that exist only to create one legend entry per type.
    for cellType, colour in nodeColorsDict.items():
        plt.plot([0], [0], color=colour, label=cellType)

    cellTypes = nx.get_node_attributes(G, 'cell_types')
    nodeColorList = [nodeColorsDict[cellType] for cellType in cellTypes.values()]
    posDict = nx.get_node_attributes(G, 'pos')

    nx.draw_networkx(G, pos=posDict, with_labels=False, node_size=10, node_color=nodeColorList)
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()

# Show the viewer: the dropdown and slider created in this cell are passed in
# as the controls for graph_visualization's two arguments.
_ = interact(graph_visualization, id_=id_, graphs=graphs)
    

interactive(children=(Dropdown(description='Patient ID: ', options=('0001', '0002', '0003', '0004', '0005', '0…

# Save data

In [5]:
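# Uncomment exactly one of the three blocks below, matching the random seed used
# in the data-creation cell (123 = train, 124 = val, 125 = test).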
# if not os.path.exists(r'./data/patient_gumbel_train.pickle'):
#     with open(r'./data/patient_gumbel_train.pickle', 'wb') as handle:
#         pickle.dump(patientDict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
# if not os.path.exists(r'./data/patient_gumbel_val.pickle'):
#     with open(r'./data/patient_gumbel_val.pickle', 'wb') as handle:
#         pickle.dump(patientDict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
# if not os.path.exists(r'./data/patient_gumbel_test.pickle'):
#     with open(r'./data/patient_gumbel_test.pickle', 'wb') as handle:
#         pickle.dump(patientDict, handle, protocol=pickle.HIGHEST_PROTOCOL)