In [21]:
import torch
import torch.nn as nn 
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
import torch.nn.functional as F

from torch_geometric.loader import DataLoader,RandomNodeSampler, NeighborLoader, NeighborSampler
from torch_geometric.utils import structured_negative_sampling, to_dense_adj, erdos_renyi_graph
from torch_geometric.nn import Node2Vec

# Compute device shared by every cell below: CUDA when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load data

In [47]:
def load_data(edge_list_file, node_features_file):
    """Load a graph stored as two ``.npy`` files into a PyG ``Data`` object.

    Args:
        edge_list_file: Path to an array of shape ``[n_edges, 2]``; each row
            holds the two node indices of one edge.
        node_features_file: Path to an array of shape
            ``[n_nodes, dim_features]``.

    Returns:
        ``Data`` whose ``x`` is a float32 ``[n_nodes, dim_features]`` tensor
        and whose ``edge_index`` is an int64 ``[2, n_edges]`` tensor.

    Raises:
        AssertionError: If the edge array does not have two columns, or if it
            refers to a node index outside ``[0, n_nodes - 1]``.
    """
    edges = np.load(edge_list_file)
    assert edges.shape[1] == 2, "edge list must have shape [n_edges, 2]"

    features = np.load(node_features_file)
    n_nodes = features.shape[0]
    # min()/max() are undefined on an empty array, so only range-check when
    # there is at least one edge.
    if edges.shape[0] > 0:
        assert edges.min() >= 0, "edge list contains a negative node index"
        assert edges.max() <= n_nodes - 1, "edge list refers to a node without features"
    print(f"Number of edges: {edges.shape[0]}, number of nodes: {n_nodes}, feature dim: {features.shape[1]}")

    # Cast explicitly: the file may store int32/uint indices, but the graph
    # layers index with int64 tensors.
    edge_index = torch.from_numpy(edges).long().t().contiguous()
    data = Data(x=torch.from_numpy(features).float(), edge_index=edge_index)
    return data



def generative_er_graph_deepwalk(num_nodes, edge_index, embedding_dim=128, epochs=100):
    """Train Node2Vec embeddings on ``edge_index`` and return them as features.

    Despite the name, no random graph is generated here: the supplied
    ``edge_index`` is used as-is and also returned unchanged.

    Args:
        num_nodes: Number of nodes to embed; embeddings are produced for node
            ids ``0 .. num_nodes - 1``.
        edge_index: ``[2, n_edges]`` edge tensor the random walks run on.
        embedding_dim: Size of each node embedding (default 128).
        epochs: Number of passes over the random-walk loader (default 100).

    Returns:
        ``Data`` with ``x`` set to the learned ``[num_nodes, embedding_dim]``
        embeddings (detached from autograd) and the input ``edge_index``.
    """
    # Pass num_nodes explicitly: otherwise the embedding table is sized from
    # the largest index in edge_index and trailing isolated nodes would make
    # the final lookup below go out of range.
    model = Node2Vec(edge_index, embedding_dim=embedding_dim, walk_length=10,
                     context_size=10, walks_per_node=10,
                     num_negative_samples=1, p=1, q=1,
                     num_nodes=num_nodes, sparse=True).to(device)
    loader = model.loader(batch_size=128, shuffle=True, num_workers=4)
    # sparse=True embeddings produce sparse gradients, hence SparseAdam.
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0
        for pos_rw, neg_rw in loader:
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_loss = total_loss / len(loader)
        print(f'Epoch: {epoch:02d}, Loss: {train_loss:.4f}')

    # The embeddings are consumed as fixed input features downstream, so do
    # not hand back a tensor that still carries autograd history.
    model.eval()
    with torch.no_grad():
        z = model(torch.arange(num_nodes, device=device))
    return Data(x=z, edge_index=edge_index)

In [45]:
# Load the yeast protein-interaction graph and move it to the compute device.
_data = load_data("data/bio-yeast-protein/bio-yeast-protein-inter.edges_edge_list_mapped.npy",
                  "data/bio-yeast-protein/bio-yeast-protein-inter.edges.embeddings_features.npy")
# Move the graph that was just loaded; the previous `data.to(device)` read a
# name that is not defined yet on a fresh kernel and discarded the load.
_data = _data.to(device)

Number of edges: 8959, number of nodes: 1870, feature dim: 128


In [48]:
# Take the node count from the loaded graph instead of hard-coding it, so the
# cell keeps working when a different dataset is loaded.
data = generative_er_graph_deepwalk(num_nodes=_data.num_nodes, edge_index=_data.edge_index)

Epoch: 01, Loss: 9.0841
Epoch: 02, Loss: 7.6387
Epoch: 03, Loss: 6.8107
Epoch: 04, Loss: 6.2232
Epoch: 05, Loss: 5.8814
Epoch: 06, Loss: 5.5677
Epoch: 07, Loss: 5.2296
Epoch: 08, Loss: 5.0321
Epoch: 09, Loss: 4.8172
Epoch: 10, Loss: 4.5485
Epoch: 11, Loss: 4.3611
Epoch: 12, Loss: 4.1457
Epoch: 13, Loss: 3.9777
Epoch: 14, Loss: 3.8263
Epoch: 15, Loss: 3.6091
Epoch: 16, Loss: 3.4510
Epoch: 17, Loss: 3.3191
Epoch: 18, Loss: 3.1439
Epoch: 19, Loss: 3.0125
Epoch: 20, Loss: 2.9010
Epoch: 21, Loss: 2.7806
Epoch: 22, Loss: 2.6766
Epoch: 23, Loss: 2.5406
Epoch: 24, Loss: 2.4355
Epoch: 25, Loss: 2.3520
Epoch: 26, Loss: 2.2701
Epoch: 27, Loss: 2.1848
Epoch: 28, Loss: 2.0998
Epoch: 29, Loss: 2.0315
Epoch: 30, Loss: 1.9587
Epoch: 31, Loss: 1.8836
Epoch: 32, Loss: 1.8295
Epoch: 33, Loss: 1.7620
Epoch: 34, Loss: 1.7199
Epoch: 35, Loss: 1.6574
Epoch: 36, Loss: 1.6091
Epoch: 37, Loss: 1.5666
Epoch: 38, Loss: 1.5250
Epoch: 39, Loss: 1.4836
Epoch: 40, Loss: 1.4403
Epoch: 41, Loss: 1.4168
Epoch: 42, Loss:

In [49]:
# Make sure the Node2Vec-embedded graph lives on the compute device.
data = data.to(device)

In [50]:
# Use the embeddings as fixed input features: drop their autograd history.
# `.detach()` is the supported way to do this; `.data` bypasses autograd's
# in-place-modification checks.
data.x = data.x.detach()

In [122]:
# data = load_data("data/PP/ppi_edge_list_mapped.npy", "data/PP/ppi_embeddings_features.npy")

# data = load_data("data/bio-WormNet/bio-WormNet-v3-benchmark.edges_edge_list_mapped.npy", 
# "data/bio-WormNet/bio-WormNet-v3-benchmark.edges.embeddings_features.npy")



Number of edges: 8959, number of nodes: 1870, feature dim: 128


# Define Model

In [60]:
# Number of colors (output classes) the network chooses between.
k_colors = 5
# Width of the per-node input features, read off the embedded graph.
feat_dim = data.x.size(1)

In [61]:
class Network(torch.nn.Module):
    """Two GCN layers followed by a two-layer MLP head.

    Maps per-node features to ``out_dim`` raw logits per node; callers apply
    softmax themselves when they need probabilities.

    Args:
        feat_dim: Width of the input node features.
        out_dim: Number of logits produced per node.
        hidden_dim: Output width of both graph convolutions (default 64).
        mlp_dim: Width of the hidden linear layer (default 32).
        dropout: Dropout probability applied after every hidden activation
            (default 0.5).
    """

    def __init__(self, feat_dim, out_dim, hidden_dim=64, mlp_dim=32, dropout=0.5):
        super().__init__()
        self.dropout = dropout

        # Attribute names and creation order are kept so existing state dicts
        # and seeded initializations stay compatible.
        self.conv1 = GCNConv(feat_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        self.lin1 = nn.Linear(hidden_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, out_dim)

    def forward(self, x, edge_index):
        """Return ``[num_nodes, out_dim]`` logits for the graph ``(x, edge_index)``."""
        for conv in (self.conv1, self.conv2):
            x = conv(x, edge_index).relu()
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.lin1(x).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)

        # No activation on the last layer: return logits instead of probs.
        return self.lin2(x)



In [62]:
# Sanity check: display the shape of the node-feature matrix.
data.x.shape 

torch.Size([1870, 128])

In [47]:


# test_net = Network(128, k_colors).to(device)

# test_out = test_net(data.x, data.edge_index)
# print(test_out.shape)

torch.Size([1870, 4])


In [63]:
# def cosine_loss 

def inner_product_loss(network_out):
    """Mean dot product between paired rows of ``network_out``.

    The rows are split at ``network_out.shape[0] // 2``; row ``i`` of the
    first part is paired with row ``i`` of the remainder, and the dot products
    of all pairs are averaged into a scalar tensor.
    """
    half = network_out.shape[0] // 2
    first, second = network_out[:half, :], network_out[half:, :]
    per_pair = (first * second).sum(dim=1)
    return per_pair.mean(dim=0)

def negative_entropy_loss(network_out):
    """Mean negative entropy, ``mean_i sum_c p[i, c] * log p[i, c]``.

    Args:
        network_out: ``[num_nodes, k]`` tensor of non-negative values
            (probabilities).

    Returns:
        Scalar tensor; it is <= 0 for probability rows and reaches 0 for
        one-hot rows.
    """
    # xlogy(p, p) equals p * log(p) but is defined as 0 at p == 0, where the
    # naive product evaluates 0 * -inf and turns the whole loss into NaN.
    plogp = torch.xlogy(network_out, network_out)
    return plogp.sum(dim=-1).mean()

def percent_same_color(network_out):
    """Fraction of row pairs whose arg-max color coincides.

    ``network_out`` has ``2 * batch_size`` rows; row ``i`` is paired with row
    ``i + batch_size``. Returns a scalar tensor in ``[0, 1]``.
    """
    colors = network_out.argmax(dim=1).reshape([2, -1])
    left, right = colors[0, :], colors[1, :]
    n_equal = (left == right).int().sum()
    return n_equal / colors.shape[1]
    
def assign_color_sampled_edges(sampled_edges, k):
    """Baseline: color nodes uniformly at random and measure edge collisions.

    Every distinct node appearing in ``sampled_edges`` gets one color drawn
    uniformly from ``0 .. k - 1``; the function reports how many edges end up
    with both endpoints sharing a color.

    Args:
        sampled_edges: Integer tensor of shape ``[2, n_edges]``.
        k: Number of colors.

    Returns:
        Python float: the fraction of edges whose endpoints share a color.
    """
    # `inverse` has the same shape as sampled_edges and maps every endpoint to
    # its position in `nodes`, so one vectorized draw colors each distinct
    # node exactly once (replacing a randint call per endpoint and the
    # always-true, quadratic `tensor in tensor` assertions).
    nodes, inverse = torch.unique(sampled_edges, return_inverse=True)
    colors = torch.randint(low=0, high=k, size=nodes.shape, device=sampled_edges.device)
    endpoint_colors = colors[inverse]
    count_same = int((endpoint_colors[0] == endpoint_colors[1]).sum().item())
    return count_same / sampled_edges.shape[1]

# test = negative_entropy_loss(out)
# test

In [64]:



# loader = NeighborLoader(data, num_neighbors=[10] * 2, shuffle=True, batch_size=10)


# Coloring network: one logit per color for every node.
model = Network(data.x.size(1), k_colors).to(device)

# Adam with L2 weight decay; used by the full-batch training loop below.
optimizer = torch.optim.Adam(model.parameters(), lr=0.003, weight_decay=5e-4)

# Pairwise criteria referenced only by the alternative (commented-out) losses.
criterion = nn.MSELoss()
cos_sim = nn.CosineSimilarity()



In [65]:


def cosine_sim(node_probs, edge_index):
    """Mean dot product of the endpoint rows of every edge.

    Note that, despite the name, no normalization is applied: this is a plain
    inner product. For rows that are probability vectors it is the chance that
    both endpoints of an edge independently pick the same color.

    Args:
        node_probs: ``[num_nodes, k]`` tensor; node ids are assumed to be the
            contiguous row indices.
        edge_index: ``[2, n_edges]`` integer tensor.

    Returns:
        Scalar tensor (NaN when ``edge_index`` holds no edges).
    """
    # Gather both endpoints of all edges at once instead of looping edge by
    # edge in Python.
    src, dst = edge_index[0].long(), edge_index[1].long()
    return (node_probs[src] * node_probs[dst]).sum(dim=1).mean()

def same_color_prob(node_probs, edge_index):
    """Fraction of edges whose endpoints have the same arg-max color.

    Args:
        node_probs: ``[num_nodes, k]`` tensor of per-node color scores.
        edge_index: ``[2, n_edges]`` integer tensor.

    Returns:
        Python float in ``[0, 1]``.
    """
    # Hard color per node, computed once rather than twice per edge.
    colors = node_probs.argmax(dim=1)
    src, dst = edge_index[0].long(), edge_index[1].long()
    n_same_color = int((colors[src] == colors[dst]).sum().item())
    return n_same_color / edge_index.shape[1]

mse_criterion = torch.nn.MSELoss()

def mse_loss(node_probs, edge_index):
    """Mean squared difference between the endpoint rows of every edge.

    Args:
        node_probs: ``[num_nodes, k]`` tensor.
        edge_index: ``[2, n_edges]`` integer tensor.

    Returns:
        Scalar tensor (NaN when ``edge_index`` holds no edges).
    """
    src, dst = edge_index[0].long(), edge_index[1].long()
    # Every edge contributes the same number of elements, so the element-wise
    # mean over the stacked endpoints equals the mean of the per-edge MSEs,
    # without a Python loop over edges.
    return mse_criterion(node_probs[src], node_probs[dst])

def ncut_loss(node_probs, edge_index):
    """Expected normalized cut of a soft node assignment.

    Computes ``sum_{(i, j) in E} sum_c Y[i, c] / Gamma[c] * (1 - Y[j, c])``
    with ``Gamma = Y^T D`` and ``D[i]`` the number of edges leaving node ``i``.
    This is the same quantity as ``sum(((Y / Gamma) @ (1 - Y)^T) * A)`` for the
    dense adjacency ``A``, evaluated per edge so that no ``N x N`` matrix is
    materialized.

    Args:
        node_probs: ``[num_nodes, k]`` soft assignment ``Y``.
        edge_index: ``[2, n_edges]`` integer tensor; repeated edges count
            once per occurrence.

    Returns:
        Scalar tensor.
    """
    Y = node_probs
    src, dst = edge_index[0].long(), edge_index[1].long()
    num_nodes = Y.shape[0]

    # Row sums of the adjacency matrix are the per-node counts of outgoing edges.
    D = torch.bincount(src, minlength=num_nodes).to(Y.dtype)
    assert D.shape[0] == num_nodes, "edge_index refers to a node outside node_probs"

    Gamma = Y.t() @ D
    ncut = ((Y[src] / Gamma) * (1 - Y[dst])).sum()

    return ncut


def test_ncut_loss():
    """Check ``ncut_loss`` against the dense formula on a 4-node cycle."""
    # Directed edge list of the undirected cycle 0-1-2-3-0.
    edge_index = torch.tensor([[0, 1, 0, 3, 1, 2, 2, 3],
                               [1, 0, 3, 0, 2, 1, 3, 2]])
    # Adjacency matrix written out by hand so the reference value does not
    # depend on the helper used inside ncut_loss.
    A = torch.tensor([
        [0., 1., 0., 1.],
        [1., 0., 1., 0.],
        [0., 1., 0., 1.],
        [1., 0., 1., 0.],
    ])
    Y = torch.tensor([
        [0.5, 0.5],
        [0.25, 0.75],
        [0.75, 0.25],
        [0.1, 0.9]
    ])

    D = A.sum(dim=1)  # every node of the cycle has degree 2

    Gamma = Y.t()@D
    loss_matrix = ((Y/Gamma)@((1-Y).t()))*A
    correct_loss = loss_matrix.sum()

    # Compare with a tolerance: exact float equality is too strict when the
    # two sides accumulate in a different order.
    actual_loss = ncut_loss(Y, edge_index)
    assert torch.isclose(actual_loss, correct_loss), (
        f"ncut_loss returned {actual_loss.item()}, expected {correct_loss.item()}"
    )

test_ncut_loss()

def size_reg(node_probs, k_colors):
    """Squared deviation of the soft color-class sizes from a balanced split.

    Args:
        node_probs: ``[num_nodes, k_colors]`` soft assignment.
        k_colors: Number of colors.

    Returns:
        Scalar tensor ``sum_c (size_c - num_nodes / k_colors) ** 2``, which is
        0 when every color holds the same total probability mass.
    """
    # Class sizes are column sums (over nodes). Summing over dim=1 instead
    # yields one value per node, which does not measure color balance.
    sizes = node_probs.sum(dim=0)
    mean_size = node_probs.shape[0] / k_colors
    deviation = sizes - mean_size
    return torch.dot(deviation, deviation)



In [70]:

# Full-batch training. The loss is the negated ncut_loss, so each optimizer
# step pushes the normalized cut up (neighbors toward different colors).
for i in range(100):
    model.train()
    out = model(_data.x, _data.edge_index)
    probs = out.softmax(dim=1)
    loss = -ncut_loss(probs, _data.edge_index)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluation pass with dropout disabled. No gradients are needed here, so
    # skip building an autograd graph on every iteration.
    model.eval()
    with torch.no_grad():
        out = model(_data.x, _data.edge_index)
        probs = out.softmax(dim=1)
        # data.edge_index is the same edge tensor that was passed in from
        # _data when the embeddings were generated.
        collide_prob = cosine_sim(probs, data.edge_index)

    print(f"iteration: {i} loss: {loss.item()}, same color prob {collide_prob}")

iteration: 0 loss: -4.480984687805176, same color prob 0.1934989094734192
iteration: 1 loss: -4.475341796875, same color prob 0.19673869013786316
iteration: 2 loss: -4.459436416625977, same color prob 0.17213159799575806
iteration: 3 loss: -4.474292755126953, same color prob 0.15670636296272278
iteration: 4 loss: -4.468151092529297, same color prob 0.144439235329628
iteration: 5 loss: -4.452651023864746, same color prob 0.14735068380832672
iteration: 6 loss: -4.470206260681152, same color prob 0.1559072732925415
iteration: 7 loss: -4.472042560577393, same color prob 0.16430611908435822
iteration: 8 loss: -4.476047515869141, same color prob 0.18292781710624695
iteration: 9 loss: -4.470345497131348, same color prob 0.18367737531661987
iteration: 10 loss: -4.485378265380859, same color prob 0.17225442826747894
iteration: 11 loss: -4.463540077209473, same color prob 0.1613338738679886
iteration: 12 loss: -4.453773498535156, same color prob 0.15815061330795288
iteration: 13 loss: -4.4913587

In [77]:
# Copy the final per-node color probabilities to host memory as a NumPy array.
np_probs = probs.detach().cpu().numpy()

In [81]:
# print(-ncut_loss(probs, data.edge_index))
# size_reg(probs, k_colors)*1e-9
# Persist the color probabilities; np.save appends the ".npy" extension.
np.save('bio-yeast-protein_color_probs', np_probs)

0.8535

In [258]:
# import random

# num_edges = data.edge_index.shape[1]
# # batch_size = 20000
# # batch_size = data.x.shape[0]
# batch_size = data.edge_index.shape[1]
# print(batch_size)
# # batch_size = 1000
# num_batches = num_edges//batch_size + 1
# index_list = list(range(num_edges))
# alpha = 1.

# model.train()
# for epoch in range(10):
#     random.shuffle(index_list)
#     print(epoch)
#     # for batch_index in range(num_batches):
#     for batch_index in range(1):
#         sampled_edge_index = index_list[batch_size*batch_index : min(batch_size*(batch_index+1), num_edges)]
#         # print(len(sampled_edge_index))
#         # print(min(batch_size*(batch_index+1), num_edges))
#         sampled_edges = data.edge_index[:, sampled_edge_index]
#         uniform_random_ratio = assign_color_sampled_edges(sampled_edges, k_colors)
#         # if batch_index == 10:
#         sampled_nodes = sampled_edges.ravel()
#         # print(sampled_nodes.device)

#         # print(sampled_nodes.shape)
#         # print(sampled_nodes.dtype)
#         _, n_id, adjs = sampler.sample(sampled_nodes.to('cpu'))
#         # print(n_id.shape)
#         optimizer.zero_grad()
#         out = model(data.x[n_id], adjs)

#         current_batch_size = out.shape[0]//2
#         # print(current_batch_size)
#             # print('======')
#         # print(out[0,:])
#         # print(out[current_batch_size,:])
#         # print(current_batch_size.dtype)
#         # loss = -criterion(out[:current_batch_size,:], out[current_batch_size:, :]) 
#         # loss = -cos_sim(out[:current_batch_size,:], out[current_batch_size:, :]).mean() 
        
#         # loss = -criterion(out[:batch_size,:], out[batch_size:, :]) + alpha * negative_entropy_loss(out)
#         # loss = inner_product_loss(out) - alpha * negative_entropy_loss(out)
#         loss = inner_product_loss(out)
#         # print(loss.item(), percent_same_color(out).item(), uniform_random_ratio)
#         print(loss.item(), inner_product_loss(out).item(), uniform_random_ratio)
#         # print(loss.item(), inner_product_loss(out).item())
#         # print()
#         loss.backward()
#         optimizer.step()
        

    
    



8959
0
0.5023680925369263 0.5023680925369263 0.5259515570934256
1
0.5000861287117004 0.5000861287117004 0.5018417234066302
2
0.5001647472381592 0.5001647472381592 0.5005022882018082
3
0.5003552436828613 0.5003552436828613 0.512557205045206
4
0.50022292137146 0.50022292137146 0.5182498046656993
5
0.5000616908073425 0.5000616908073425 0.4875544145551959
6
0.4999978244304657 0.4999978244304657 0.5093202366335529
7
0.5000165700912476 0.5000165700912476 0.5035160174126576
8
0.5000544190406799 0.5000544190406799 0.5138966402500279
9
0.5000529885292053 0.5000529885292053 0.5214867730773524


tensor([1, 2, 3, 4])


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): `GCN` is not defined in the visible part of this notebook, so
# this cell fails on a fresh kernel unless it is defined elsewhere - confirm.
# It also rebinds `model` and `optimizer`, replacing the Network-based ones
# created earlier in the notebook.
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [7]:
# One pass over `loader`, minimizing mean(sim(i, j) - sim(i, k) + margin),
# where (i, j) are edges and (i, k) are sampled non-edges.
# NOTE(review): `loader` and `model.similarity` are not defined in the visible
# part of this notebook; the comments below assume `model.similarity(a, b)`
# returns a square pairwise matrix whose diagonal holds the row-wise
# similarities - confirm against their definitions.
model.train()
margin = 0.01
train_running_loss = 0.0
for sample in loader:
    # print(device)
    sample.x = sample.x.to(device).to(dtype=torch.float32)
    sample.x.requires_grad=True
    # sample.x = requires_grad=True
    sample.edge_index = sample.edge_index.to(device).long()
    # print(sample.x.is_cuda)
    # print(type(sample.edge_index))
    # print(sample.edge_index.is_cuda)
    optimizer.zero_grad()
    # For every edge (i, j), also draw a node k so that (i, k) is a negative pair.
    i, j, k = structured_negative_sampling(sample.edge_index)
    negatives = (i,k)   #not neighbors
    positives = (i,j)   #neighbors 
    output = model(sample)
    #pos = model.similarity(sample.x[i], sample.x[j])
    #neg = model.similarity(sample.x[i], sample.x[k])
    pos = model.similarity(output[i], output[j])
    neg = model.similarity(output[i], output[k])
    # Minimizing `diff` lowers neighbor similarity and raises non-neighbor
    # similarity. No clamp is applied, so the objective is not bounded below.
    diff =pos.diag() -neg.diag() +margin      # Note for coloring, we want negatives closer and positives further
    triplet_loss_matrix = diff.mean()
    loss = triplet_loss_matrix
    loss.backward()
    optimizer.step()
    # Accumulate the scalar loss over all batches seen so far.
    train_running_loss += loss.detach().item()