In [None]:
import os
# Order GPUs by PCI bus id and expose a single device to this process.
# The second value is the GPU id to use (usually "0" or "1"; this run uses "2").
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "2",
})

In [None]:
import numpy as np
import tensorflow as tf
import random as rn

# Reproducibility setup: seed NumPy, Python's `random` and TensorFlow, and
# register a single-threaded TensorFlow session with the Keras backend.
# NOTE(review): tf.ConfigProto / tf.set_random_seed / tf.Session are TF 1.x
# names — presumably this notebook targets TensorFlow 1.x; confirm before
# upgrading TensorFlow.

# The below is necessary for starting Numpy generated random numbers
# in a well-defined initial state.

np.random.seed(42)

# The below is necessary for starting core Python generated random numbers
# in a well-defined state.

rn.seed(12345)

# Force TensorFlow to use single thread.
# Multiple threads are a potential source of non-reproducible results.
# For further details, see: https://stackoverflow.com/questions/42022950/

session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
                              inter_op_parallelism_threads=1)

from tensorflow.keras import backend as K

# The below tf.set_random_seed() will make random number generation
# in the TensorFlow backend have a well-defined initial state.
# For further details, see:
# https://www.tensorflow.org/api_docs/python/tf/set_random_seed

tf.set_random_seed(1234)

# Create the session on the default graph with the single-thread config and
# hand it to Keras so that Keras uses this session.
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)

In [None]:
import networkx as nx
import pandas as pd
import numpy as np
import os
import random
import matplotlib.pyplot as plt

from tqdm import tqdm
from scipy.spatial import cKDTree as KDTree
from tensorflow.keras.utils import to_categorical

import stellargraph as sg
from stellargraph.data import EdgeSplitter
from stellargraph.mapper import GraphSAGELinkGenerator
from stellargraph.layer import GraphSAGE, link_classification
from stellargraph.layer.graphsage import AttentionalAggregator
from stellargraph.data import UniformRandomWalk
from stellargraph.data import UnsupervisedSampler
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from sklearn import preprocessing, feature_extraction, model_selection
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.metrics import accuracy_score

from stellargraph import globalvar

In [None]:
# NOTE(review): this cell seeds NumPy with 42 a second time and calls
# set_random_seed(42) after the earlier tf.set_random_seed(1234), so the
# TensorFlow seed in effect is presumably 42 — confirm which value is intended.
from numpy.random import seed
seed(42)
from tensorflow import set_random_seed
set_random_seed(42)

In [None]:
def plotNeighbor(barcodes_df, percentile=97):
    """Plot nearest-neighbour distances between barcodes and return a threshold.

    For every (experiment, cellID) group the distance from each barcode to its
    closest other barcode in the same cell is computed from the
    ``RNACentroidX`` / ``RNACentroidY`` columns. The distances of all cells are
    pooled, shown as a 200-bin histogram with a red line at the threshold, and
    the threshold is printed and returned.

    Parameters
    ----------
    barcodes_df : pandas.DataFrame
        Needs the columns ``experiment``, ``cellID``, ``RNACentroidX`` and
        ``RNACentroidY``.
    percentile : float, default 97
        Percentile of the pooled nearest-neighbour distances used as threshold.

    Returns
    -------
    float
        The requested percentile of the nearest-neighbour distances.

    Raises
    ------
    ValueError
        If no cell contains at least two barcodes.
    """
    d_list = []
    # Grouping replaces one full-table boolean mask per cell. The order of the
    # groups does not matter here because the distances are pooled.
    for _, cell_df in barcodes_df.groupby(['experiment', 'cellID'], sort=False):
        # A cell with a single barcode has no neighbour: query(k=2) would
        # report an infinite distance for it, which corrupts the percentile
        # and makes the histogram range non-finite.
        if len(cell_df) < 2:
            continue
        xy = cell_df[['RNACentroidX', 'RNACentroidY']].values
        # Column 0 is the point itself (distance 0), column 1 its nearest neighbour.
        d, _ = KDTree(xy).query(xy, k=2)
        d_list.append(d)

    if not d_list:
        raise ValueError("plotNeighbor needs at least one cell with two or more barcodes")

    d = np.vstack(d_list)
    print(d.shape)

    nn_dist = d[:, 1]
    d_th = np.percentile(nn_dist, percentile)  # computed once, reused below

    plt.hist(nn_dist, bins=200);
    plt.axvline(x=d_th, c='r')
    print(d_th)
    return d_th

## Download spatial gene expression data

In [None]:
! wget http://zhuang.harvard.edu/MERFISHData/140genesData.xlsx -O ../data/MERFISH_Chen_et_al_2015/barcodes.xlsx

## Load spatial gene expression data

In [17]:
# Column names for the barcode table; with header=0 they replace the header row of the sheet.
barcode_columns = ['experiment','library','cellID','intCodeword','geneName','isExactMatch','isCorrectedMatch','CellPositionX','CellPositionY','RNACentroidX','RNACentroidY']
# `sep` is a CSV option, not a read_excel parameter: pandas versions whose
# read_excel has no catch-all keyword arguments reject it with a TypeError.
barcodes_df = pd.read_excel("../data/MERFISH_Chen_et_al_2015/barcodes.xlsx", names=barcode_columns, header=0)
barcodes_df.shape

Unnamed: 0,experiment,library,cellID,intCodeword,geneName,isExactMatch,isCorrectedMatch,CellPositionX,CellPositionY,RNACentroidX,RNACentroidY
0,1,1,0,33796,SCUBE3,0,1,475.5,630.6,78.723714,154.452489
1,1,1,0,34048,SCUBE3,0,1,475.5,630.6,81.297819,229.918727
2,1,1,0,33794,SON,0,1,475.5,630.6,92.627268,212.018163
3,1,1,0,32802,AFF4,0,1,475.5,630.6,101.404081,220.580093
4,1,1,0,33856,FOSB,0,1,475.5,630.6,107.676392,173.956553


In [None]:
# Remove unassigned control barcodes ("blank*" / "notarget*") and renumber the rows.
remove_genes = (['blank%03d' % k for k in range(1, 6)]
                + ['notarget%03d' % k for k in range(1, 6)])
is_control = barcodes_df.geneName.isin(remove_genes)
barcodes_df = barcodes_df.loc[~is_control].reset_index(drop=True)
barcodes_df.shape

In [None]:
# Distance threshold derived from the nearest-neighbour distance distribution;
# it is passed to buildGraph below to decide which barcodes get connected.
d_th = plotNeighbor(barcodes_df)

In [None]:
def buildGraph(barcodes_df, d_th):
    """Build a spatial neighbourhood graph with one node per barcode.

    Barcodes are processed cell by cell (experiments in order of first
    appearance, then cells in order of first appearance within the
    experiment). Within a cell, two barcodes are connected when the distance
    between their ``RNACentroidX`` / ``RNACentroidY`` positions is at most
    ``d_th``; barcodes of different cells are never connected.

    Node ids are consecutive integers assigned in processing order, so they
    equal the row positions of ``barcodes_df`` only when the table is already
    grouped by experiment and cell.

    Parameters
    ----------
    barcodes_df : pandas.DataFrame
        Needs the columns ``experiment``, ``cellID``, ``geneName``,
        ``RNACentroidX`` and ``RNACentroidY``.
    d_th : float
        Maximum distance between two barcodes of a cell that get an edge.

    Returns
    -------
    networkx.Graph
        Every node carries ``test=False``, ``val=False``, ``label=0`` and a
        ``feature`` attribute holding the one-hot encoding (list of floats)
        of the barcode's gene over all genes in ``barcodes_df``.
    """
    # The gene vocabulary and its one-hot table are identical for every cell,
    # so they are built once instead of once per cell. np.eye yields the same
    # float32 identity rows that to_categorical(arange(k), num_classes=k) gives.
    gene_list = barcodes_df.geneName.unique()
    one_hot_rows = np.eye(gene_list.shape[0], dtype='float32').tolist()
    one_hot_encoding = dict(zip(gene_list, one_hot_rows))

    G = nx.Graph()
    n = 0  # id offset: number of nodes added by the cells processed so far

    # Nested groupby (instead of a full-table mask per cell) keeps the
    # original visiting order and therefore the original node numbering.
    for _, experiment_df in barcodes_df.groupby('experiment', sort=False):
        for _, cell_df in experiment_df.groupby('cellID', sort=False):
            n_barcodes = cell_df.shape[0]
            node_ids = np.arange(n_barcodes) + n
            features = cell_df['geneName'].map(one_hot_encoding).tolist()

            # Index pairs (i < j, local to this cell) of barcodes within d_th.
            xy = cell_df[['RNACentroidX', 'RNACentroidY']].values
            local_pairs = KDTree(xy).query_pairs(d_th)

            # Add nodes
            G.add_nodes_from(node_ids, test=False, val=False, label=0)
            nx.set_node_attributes(G, dict(zip(node_ids, features)), 'feature')
            # Add edges, shifted from cell-local to global node ids
            G.add_edges_from([(i + n, j + n) for i, j in local_pairs])

            n = n + n_barcodes

    return G

In [None]:
# Build the neighbourhood graph over all barcodes, using d_th as the edge distance threshold.
G = buildGraph(barcodes_df, d_th)

In [None]:
# Shape of the barcode table the graph was built from.
barcodes_df.shape

In [None]:
import matplotlib.pyplot as plt

# Three stacked panels: barcode centroids of cell 0 in experiments 1, 2 and 3.
plt.figure(figsize=(10,10))
for experiment_id in (1, 2, 3):
    # Build the row selection once and reuse it for both coordinates.
    in_cell = (barcodes_df.cellID==0) & (barcodes_df.experiment==experiment_id)
    X = barcodes_df[in_cell].RNACentroidX
    Y = barcodes_df[in_cell].RNACentroidY

    plt.subplot(3,1,experiment_id)
    plt.scatter(X,Y,s=0.1)
    plt.axis('scaled')

In [None]:
# Edge count of the graph as built.
G.number_of_edges()

In [None]:
# Node count of the graph as built.
G.number_of_nodes()

In [None]:
# Prune every connected component that has fewer than N nodes.
N = 3
# The components are materialised into a list first, so removing nodes while
# looping does not disturb the iteration.
for component in tqdm(list(nx.connected_components(G))):
    if len(component) >= N:
        continue
    G.remove_nodes_from(component)

In [None]:
# Edge count after the components with fewer than N nodes were removed.
G.number_of_edges()

In [None]:
# Node count after the components with fewer than N nodes were removed.
G.number_of_nodes()

In [None]:
# Average node degree of the pruned graph.
np.mean([degree for _, degree in G.degree()])

#### 1. Create the StellarGraph with node features.

In [None]:
# Wrap the NetworkX graph in a StellarGraph, taking node features from the
# "feature" node attribute.
# NOTE(review): G is rebound from the NetworkX graph to the StellarGraph, so
# re-running this cell alone passes the StellarGraph back in — rebuild G first.
G = sg.StellarGraph(G, node_features="feature")

In [None]:
# Print the StellarGraph's own summary of the graph.
print(G.info())

#### 2. Specify the other optional parameter values: root nodes, the number of walks to take per node, the length of each walk, and random seed.

In [None]:
# Root nodes of the random walks: every node of the graph.
nodes = list(G.nodes())
# Number of walks to take per root node.
number_of_walks = 1
# Length of each walk.
length = 2

#### 3. Create the UnsupervisedSampler instance with the relevant parameters passed to it.

In [None]:
# Sampler configured with the walk parameters defined above and a fixed seed (42).
unsupervised_samples = UnsupervisedSampler(G, nodes=nodes, length=length, number_of_walks=number_of_walks, seed=42)

#### 4. Create a node pair generator:

In [None]:
# Batch size handed to the link / node generators.
batch_size = 50
# Number of training epochs passed to fit_generator.
epochs = 3
# One entry per GraphSAGE layer (checked against layer_sizes further down) —
# presumably the number of neighbours sampled at each hop; confirm in the
# GraphSAGELinkGenerator documentation.
num_samples = [20, 10]

In [None]:
# Training generator: feeds the sampler's node pairs to the model in batches of batch_size.
train_gen = GraphSAGELinkGenerator(G, batch_size, num_samples, seed=42).flow(unsupervised_samples)

In [None]:
# Output size of each GraphSAGE layer; there must be exactly one layer per
# entry of num_samples.
layer_sizes = [50, 50]
assert len(num_samples) == len(layer_sizes)

graphsage = GraphSAGE(
    layer_sizes=layer_sizes,
    generator=train_gen,
    aggregator=AttentionalAggregator,
    bias=True,
    dropout=0.0,
    normalize="l2",
    kernel_regularizer='l1',
)

In [None]:
# Build the model and expose input and output sockets of graphsage, for node pair inputs:
# x_inp is the list of input tensors, x_out the list of output tensors; both
# are sliced again further down to extract a single-node embedding model.
x_inp, x_out = graphsage.build()

In [None]:
# Link-classification head on top of the node-pair outputs: one sigmoid output
# per pair. NOTE(review): 'ip' presumably selects the inner product of the two
# node embeddings — confirm in the link_classification documentation.
prediction = link_classification(
    output_dim=1, output_act="sigmoid", edge_embedding_method='ip'
)(x_out)

In [None]:
import os, datetime

# TensorBoard logs go to a fresh, timestamped folder under ./logs for each run.
logdir = os.path.join("logs", datetime.datetime.now().strftime("MERFISH-%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir)
# Stop training once the training loss has not improved for one epoch (patience=1).
earlystop_callback = tf.keras.callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, patience=1)

# End-to-end link-prediction model: GraphSAGE inputs -> sigmoid link score.
model = keras.Model(inputs=x_inp, outputs=prediction)

# Binary cross-entropy on the link labels, Adam with learning rate 5e-5.
model.compile(
    optimizer=keras.optimizers.Adam(lr=0.5e-4),
    loss=keras.losses.binary_crossentropy,
    metrics=[keras.metrics.binary_accuracy]
)

model.summary()

In [None]:
import tensorflow as tf
# Request minimal TensorFlow C++ logging.
# NOTE(review): TensorFlow has already been imported by earlier cells; this
# variable is presumably only honoured when set before the import — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Train the link-prediction model on the generated node pairs, with 12 worker
# processes feeding batches and the TensorBoard / early-stopping callbacks.
history = model.fit_generator(
    train_gen,
    epochs=epochs,
    verbose=1,
    use_multiprocessing=True,
    workers=12,
    shuffle=True,
    callbacks=[tensorboard_callback,earlystop_callback]
)

### Extracting node embeddings

In [None]:
from stellargraph.mapper import GraphSAGENodeGenerator
import pandas as pd
import numpy as np

In [None]:
# Take every second input tensor and the first output tensor of the node-pair
# model — presumably the half that belongs to the source node of each pair
# (as the *_src names suggest); confirm against GraphSAGE.build().
x_inp_src = x_inp[0::2]
x_out_src = x_out[0]
# Model that maps those inputs to that output, i.e. to one embedding per node.
embedding_model = keras.Model(inputs=x_inp_src, outputs=x_out_src)


In [None]:
# Save the model
model_path = '../models/MERFISH_Chen_et_al/nn_model.h5'
# The output folder may be missing on a fresh checkout; create it so that
# save() cannot fail on a non-existent directory.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
embedding_model.save(model_path)

# Recreate the exact same model purely from the file; the custom aggregator
# layer has to be supplied so Keras can deserialise it.
embedding_model = keras.models.load_model(model_path, custom_objects={'AttentionalAggregator':AttentionalAggregator})

In [None]:
# Layer summary of the reloaded embedding model.
embedding_model.summary()

In [None]:
# Compile the reloaded embedding model with the same optimizer, loss and metric
# as the link model.
# NOTE(review): the model is only used for prediction below, so this loss is
# presumably never evaluated — confirm whether compiling is needed at all.
embedding_model.compile(
    optimizer=keras.optimizers.Adam(lr=0.5e-4),
    loss=keras.losses.binary_crossentropy,
    metrics=[keras.metrics.binary_accuracy]
)

In [None]:
# Generator over every node of the graph, with the same batch size and
# neighbour sampling sizes that were used for training.
nodes = list(G.nodes())
node_gen = GraphSAGENodeGenerator(G, batch_size, num_samples, seed=42).flow(nodes)

In [None]:
# Run the embedding model over the node generator.
# NOTE(review): rows presumably follow the order of `nodes` — confirm that
# flow() does not shuffle before mapping rows back to barcodes.
node_embeddings = embedding_model.predict_generator(node_gen, workers=12, verbose=1)

In [None]:
embedding_path = '../results/MERFISH_Chen_et_al/embedding_MERFISH_Chen_et_al.npy'
# Create the results folder if needed so the embeddings computed above are not
# lost to a missing-directory error.
os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
np.save(embedding_path, node_embeddings)

In [None]:
# Exit the interpreter once the embeddings are saved.
# NOTE(review): presumably here to release the GPU held by this session — confirm.
quit()