In [62]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn 
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.loader import DataLoader,RandomNodeSampler, NeighborLoader, NeighborSampler
from torch_geometric.utils import structured_negative_sampling, to_dense_adj, erdos_renyi_graph
from torch_geometric.nn import Node2Vec

torch.set_printoptions(precision=4, sci_mode=False)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare Dataset

## Loading biological graph data

In [14]:
def load_edges_and_feature(edge_list_file, node_features_file):
    edges = np.load(edge_list_file)
    assert edges.shape[1] == 2
    features = np.load(node_features_file)
    n_nodes = features.shape[0]
    assert edges.min() >= 0
    assert edges.max() <= n_nodes-1 

    print(f"Number of edges: {edges.shape[0]}, number of nodes: {n_nodes}, feature dim: {features.shape[1]}")
    
    data = Data(x = torch.from_numpy(features).float(), edge_index=torch.from_numpy(edges).t().contiguous())
    return data

def load_edges(edge_list_file):
    edges = np.load(edge_list_file)
    assert edges.shape[1] == 2
    print(f"Number of edges: {edges.shape[0]}, min node index: {edges.min()}, min node index: {edges.max()}")

    data = Data(edge_index=torch.from_numpy(edges).t().contiguous())
    return data 

In [15]:
data = load_edges_and_feature(
    "data/bio-yeast-protein/bio-yeast-protein-inter.edges_edge_list_mapped.npy", 
    "data/bio-yeast-protein/bio-yeast-protein-inter.edges.embeddings_features.npy")
data = data.to(device)

data_no_feat = load_edges("data/bio-yeast-protein/bio-yeast-protein-inter.edges_edge_list_mapped.npy")
data_no_feat = data_no_feat.to(device)


Number of edges: 8959, number of nodes: 1870, feature dim: 128
Number of edges: 8959, min node index: 0, min node index: 1869


## Generate DeepWalk features

https://github.com/phanein/deepwalk

https://github.com/shenweichen/GraphEmbedding

In [None]:
# def generative_graph_deepwalk(num_nodes, edge_index):
#     # er_graph = erdos_renyi_graph(num_nodes=256, edge_prob=0.1, directed=False)
#     er_graph = edge_index
#     model = Node2Vec(er_graph, embedding_dim=128, walk_length=10,
#                      context_size=10, walks_per_node=10,
#                      num_negative_samples=1, p=1, q=1, sparse=True).to(device)
#     loader = model.loader(batch_size=128, shuffle=True, num_workers=4)
#     optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)
    
#     for epoch in range(1, 101):
#         model.train()
#         total_loss = 0
#         for pos_rw, neg_rw in loader:
#             optimizer.zero_grad()
#             loss = model.loss(pos_rw.to(device), neg_rw.to(device))
#             loss.backward()
#             optimizer.step()
#             total_loss += loss.item()
#         train_loss = total_loss / len(loader)
#         print(f'Epoch: {epoch:02d}, Loss: {train_loss:.4f}')
#     z = model(torch.arange(num_nodes, device=device))
#     return Data(x=z, edge_index=er_graph)

In [24]:
# Directly use loaded feature
# TODO: Get deepwalk features on the fly
data 

Data(x=[1870, 128], edge_index=[2, 8959])

# Train GNN Model

## Define model

In [25]:
# Model parameters
feat_dim = data.x.shape[1]

In [33]:
class Network(torch.nn.Module):
    def __init__(self, feat_dim, out_dim):
        super().__init__()

        self.conv1 = GCNConv(feat_dim, 64)
        self.conv2 = GCNConv(64, 64)

        self.lin1 = nn.Linear(64, 32)
        self.lin2 = nn.Linear(32, out_dim)

    def forward(self, x, edge_index):

        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)

        x = self.lin1(x)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)

        return x 

def test_model():
    test_net = Network(128, k_colors).to(device)
    test_out = test_net(data.x, data.edge_index)
    print(test_out.shape)

test_model()

torch.Size([1870, 5])


## Train model

In [35]:
def inner_product_loss(network_out):
    batch_size = network_out.shape[0]//2
    a = network_out[:batch_size, :]
    b = network_out[batch_size:, :]
    mean_dot_prod = torch.mean(torch.sum(a * b, dim=1), dim=0)
    # print(b.argmax(dim=1).shape)
    # same_color = (a.argmax(dim=1) == b.argmax(dim=1)).int().sum()
    # return same_color/a.shape[0]
    return mean_dot_prod


def negative_entropy_loss(network_out):
    # num_nodes x k
    log = torch.log(network_out)
    prod = network_out * log 
    s = torch.sum(prod, dim=-1)
    return s.mean()


def percent_same_color(network_out):
    # percent of pairs having same color
    # [2 * batch_size, feature_size]
    arg_max = torch.argmax(network_out, dim=1)
    arg_max_reshaped = arg_max.reshape([2, -1])
    result = torch.sum((arg_max_reshaped[0,:] == arg_max_reshaped[1,:]).int())/arg_max_reshaped.shape[1]
    # print(result.shape)
    return result
    

def assign_color_sampled_edges(sampled_edges, k):
    # given a sampled batch of edges, assign random color to each end points
    all_nodes = sampled_edges.ravel()
    # print(sampled_edges.shape)
    # colors = torch.randint(low=0, high=k, size=all_nodes.shape)
    color_dict = {node.item():torch.randint(low=0, high=k, size=(1,)).item() for node in all_nodes}
    count_same = 0
    for i in range(sampled_edges.shape[1]):
        assert (sampled_edges[0,i] in all_nodes)
        assert (sampled_edges[1,i] in all_nodes)
        if color_dict[sampled_edges[0,i].item()] == color_dict[sampled_edges[1,i].item()]:
            count_same+=1
    return count_same/sampled_edges.shape[1]


def cosine_sim(node_probs, edge_index): 
    # assuming continous node numbering
    # for each edge, get a loss term, then take mean
    total_loss = 0.
    for i in range(edge_index.shape[1]):
        node1_prob = node_probs[edge_index[0, i],:]
        node2_prob = node_probs[edge_index[1, i],:]
        total_loss += torch.dot(node1_prob, node2_prob)
    return total_loss/edge_index.shape[1]


def same_color_prob(node_probs, edge_index):
    n_same_color = 0
    for i in range(edge_index.shape[1]):
        node1_prob = node_probs[edge_index[0, i],:]
        node2_prob = node_probs[edge_index[1, i],:]
        # print(torch.argmax(node1_prob), torch.argmax(node2_prob))
        if torch.argmax(node1_prob) == torch.argmax(node2_prob):
            # print(node1_prob, node2_prob)
            n_same_color+=1
    return n_same_color/edge_index.shape[1]


def mse_loss(node_probs, edge_index):
    total_loss = 0.
    for i in range(edge_index.shape[1]):
        node1_prob = node_probs[edge_index[0, i],:]
        node2_prob = node_probs[edge_index[1, i],:]
        total_loss += mse_criterion(node1_prob, node2_prob)
    return total_loss/edge_index.shape[1]


def ncut_loss(node_probs, edge_index):
    Y = node_probs 
    A = torch.squeeze(to_dense_adj(edge_index, max_num_nodes=node_probs.shape[0]))
    # print(A.shape, node_probs.shape)
    assert (A.shape[0] == node_probs.shape[0])
    
    D = torch.sum(A, dim=1)
    Gamma = Y.t()@D 
    sum_mat = torch.mul(torch.div(Y, Gamma)@((1-Y).t()), A)
    # sum_mat = torch.mul(Y@((1-Y).t()), A)
    # print(sum_mat)
    ncut = torch.sum(sum_mat)

    return ncut


def test_ncut_loss():
    # edge_index=torch.tensor([[0,1], [1, 0], [0, 3], [3, 0], [1, 2], [2, 1], [2, 3], [3, 2]])
    edge_index = torch.tensor([[0, 1, 0, 3, 1, 2, 2, 3],[1, 0, 3, 0, 2, 1, 3, 2]])
    # A = torch.tensor([
    #     [0, 1, 0, 1],
    #     [1, 0, 1, 0],
    #     [0, 1, 0, 1],
    #     [1, 0, 1, 0]
    # ])
    # print(edge_index)

    A = torch.squeeze(to_dense_adj(edge_index, max_num_nodes=4))
    # print(A)
    Y = torch.tensor([
        [0.5, 0.5],
        [0.25, 0.75],
        [0.75, 0.25],
        [0.1, 0.9]
    ])

    D = torch.tensor([2., 2., 2., 2.])

    Gamma = Y.t()@D
    loss_matrix = ((Y/Gamma)@((1-Y).t()))*A
    # loss_matrix = ((Y)@((1-Y).t()))*A
    # print(loss_matrix)
    correct_loss = loss_matrix.sum()

    # assert (ncut_loss(Y, edge_index) == correct_loss)


def size_reg(node_probs, k_colors):
    sizes = node_probs.sum(dim=1)
    mean_size = node_probs.shape[0]/k_colors
    size_reg = torch.dot(sizes-mean_size, sizes-mean_size)
    return size_reg


In [46]:
k_colors = 4
model = Network(data.x.shape[1], k_colors).to(device)

In [47]:
# loader = NeighborLoader(data, num_neighbors=[10] * 2, shuffle=True, batch_size=10)
lr = 0.001
weight_decay = 5e-4

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

cos_sim = torch.nn.CosineSimilarity()

In [51]:

for i in range(100):
    model.train()
    out = model(data.x, data.edge_index)
    probs = out.softmax(dim=1)
    # print(probs[data.edge_index[0,0]], probs[data.edge_index[1,0]])
    # loss = cosine_sim(probs, data.edge_index)
    # loss = mse_loss(probs, data.edge_index)
    loss = -ncut_loss(probs, data.edge_index)
    # print(i, loss)
    # loss = -ncut_loss(probs, data.edge_index) + 10e-9 * size_reg(probs, k_colors)
    # print(size_reg(probs, k_colors))
    # print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()
    out = model(data.x, data.edge_index)
    probs = out.softmax(dim=1)
    collide_prob = cosine_sim(probs, data.edge_index)

    print(f"iteration: {i} loss: {loss.item()}, same color prob {collide_prob}")

iteration: 0 loss: -3.4324069023132324, same color prob 0.18293415009975433
iteration: 1 loss: -3.4294562339782715, same color prob 0.20677298307418823
iteration: 2 loss: -3.458091974258423, same color prob 0.22450792789459229
iteration: 3 loss: -3.433349609375, same color prob 0.20599651336669922
iteration: 4 loss: -3.4459166526794434, same color prob 0.18340717256069183
iteration: 5 loss: -3.4536495208740234, same color prob 0.17122671008110046
iteration: 6 loss: -3.4480795860290527, same color prob 0.167527973651886
iteration: 7 loss: -3.459488868713379, same color prob 0.16992127895355225
iteration: 8 loss: -3.4509153366088867, same color prob 0.18261969089508057
iteration: 9 loss: -3.4428887367248535, same color prob 0.19409187138080597
iteration: 10 loss: -3.4625682830810547, same color prob 0.20731747150421143
iteration: 11 loss: -3.450997829437256, same color prob 0.2113988697528839
iteration: 12 loss: -3.4603829383850098, same color prob 0.2116210162639618
iteration: 13 loss: 

In [61]:

print(probs[:10,:])

tensor([[    0.2751,     0.1258,     0.3271,     0.2720],
        [    0.0082,     0.9753,     0.0090,     0.0075],
        [    0.3097,     0.0088,     0.3695,     0.3120],
        [    0.3146,     0.0009,     0.3724,     0.3121],
        [    0.3144,     0.0001,     0.3736,     0.3119],
        [    0.3146,     0.0010,     0.3724,     0.3121],
        [    0.0007,     0.9979,     0.0008,     0.0006],
        [    0.3142,     0.0022,     0.3712,     0.3123],
        [    0.2866,     0.0871,     0.3415,     0.2848],
        [    0.3023,     0.0252,     0.3666,     0.3060]], device='cuda:0',
       grad_fn=<SliceBackward>)


## Sample colors 
- Sample colors from trained model
- Sample colors from uniform distribution

# Dynamic Programming to find K-path

# Analysis and Plotting