In [1]:
import torch
import torch.nn as nn 
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from torch_geometric.nn import GCNConv, SAGEConv 
import torch.nn.functional as F

from torch_geometric.loader import DataLoader,RandomNodeSampler, NeighborLoader, NeighborSampler
from torch_geometric.utils import structured_negative_sampling

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load data

In [2]:
def load_data(edge_list_file, node_features_file):
    # edge_list, [n_edges, 2]
    # node_features, [n_nodes, dim_features]
    edges = np.load(edge_list_file)
    assert edges.shape[1] == 2

    features = np.load(node_features_file)
    n_nodes = features.shape[0]
    
    assert edges.min() >= 0
    assert edges.max() <= n_nodes-1
    
    print(f"Number of edges: {edges.shape[0]}, number of nodes: {n_nodes}, feature dim: {features.shape[1]}")
    
    data = Data(x = torch.from_numpy(features).float(), edge_index=torch.from_numpy(edges).t().contiguous())
    return data


In [3]:
# data = load_data("data/PP/ppi_edge_list_mapped.npy", "data/PP/ppi_embeddings_features.npy")
data = load_data("data/bio-yeast-protein/bio-yeast-protein-inter.edges_edge_list_mapped.npy", 
"data/bio-yeast-protein/bio-yeast-protein-inter.edges.embeddings_features.npy")
data = data.to(device)

Number of edges: 8959, number of nodes: 1870, feature dim: 128


# Define Model

In [253]:
feat_dim = data.x.shape[1]
k_colors = 2
graph_layers_sizes = [feat_dim, 64, 64, 64]
mlp_sizes = [64, 64, 32, k_colors]

In [254]:
class Network(torch.nn.Module):
    def __init__(self, gl_sizes, ll_sizes):
        super().__init__()

        self.gl_sizes = gl_sizes
        self.ll_sizes = ll_sizes
        self.num_conv_layers = len(gl_sizes)-1
        self.num_lin_layers = len(ll_sizes)-1

        self.convs = torch.nn.ModuleList([SAGEConv(gl_sizes[i], gl_sizes[i+1]) for i in range(self.num_conv_layers)])
        self.mlp = torch.nn.ModuleList([nn.Linear(ll_sizes[i], ll_sizes[i+1]) for i in range(self.num_lin_layers)])

    def forward(self, x, adjs):
        # x node features of batch
        # adjs, sampled bipartite graph

        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]
            edge_index = edge_index.to(device)
            x = self.convs[i]((x, x_target), edge_index)
            if i < self.num_conv_layers-1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)  
        
        assert x.shape == (x.shape[0], self.gl_sizes[-1])

        # apply mlp for each node 
        for i, layer in enumerate(self.mlp):
            x = layer(x)
            if i < self.num_lin_layers-1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training) 
        
        # x: [num_nodes, k]

        assert x.shape == (x.shape[0], self.ll_sizes[-1])
        
        return x.softmax(dim=-1)



In [255]:
sampler = NeighborSampler(
    edge_index=data.edge_index,
    # sizes = [15, 15, 15] # same as number of graph layers 
    sizes = [-1, -1, -1]
)


test_net = Network(graph_layers_sizes, mlp_sizes).to(device)
print(test_net.convs)
print(test_net.gl_sizes)
print(test_net.mlp)
print(test_net.ll_sizes)

_, _n_id, _adjs = sampler.sample([1,2])
test_out = test_net(data.x[_n_id], _adjs)
print(test_out)

ModuleList(
  (0): SAGEConv(128, 64)
  (1): SAGEConv(64, 64)
  (2): SAGEConv(64, 64)
)
[128, 64, 64, 64]
ModuleList(
  (0): Linear(in_features=64, out_features=64, bias=True)
  (1): Linear(in_features=64, out_features=32, bias=True)
  (2): Linear(in_features=32, out_features=2, bias=True)
)
[64, 64, 32, 2]
tensor([[0.4576, 0.5424],
        [0.5058, 0.4942]], device='cuda:0', grad_fn=<SoftmaxBackward>)


In [256]:
# def cosine_loss 

def inner_product_loss(network_out):
    batch_size = network_out.shape[0]//2
    a = network_out[:batch_size, :]
    b = network_out[batch_size:, :]
    mean_dot_prod = torch.mean(torch.sum(a * b, dim=1), dim=0)
    # print(b.argmax(dim=1).shape)
    # same_color = (a.argmax(dim=1) == b.argmax(dim=1)).int().sum()
    # return same_color/a.shape[0]
    return mean_dot_prod

def negative_entropy_loss(network_out):
    # num_nodes x k
    log = torch.log(network_out)
    prod = network_out * log 
    s = torch.sum(prod, dim=-1)
    return s.mean()

def percent_same_color(network_out):
    # percent of pairs having same color
    # [2 * batch_size, feature_size]
    arg_max = torch.argmax(network_out, dim=1)
    arg_max_reshaped = arg_max.reshape([2, -1])
    result = torch.sum((arg_max_reshaped[0,:] == arg_max_reshaped[1,:]).int())/arg_max_reshaped.shape[1]
    # print(result.shape)
    return result
    
def assign_color_sampled_edges(sampled_edges, k):
    # given a sampled batch of edges, assign random color to each end points
    all_nodes = sampled_edges.ravel()
    # print(sampled_edges.shape)
    # colors = torch.randint(low=0, high=k, size=all_nodes.shape)
    color_dict = {node.item():torch.randint(low=0, high=k, size=(1,)).item() for node in all_nodes}
    count_same = 0
    for i in range(sampled_edges.shape[1]):
        assert (sampled_edges[0,i] in all_nodes)
        assert (sampled_edges[1,i] in all_nodes)
        if color_dict[sampled_edges[0,i].item()] == color_dict[sampled_edges[1,i].item()]:
            count_same+=1
    return count_same/sampled_edges.shape[1]

# test = negative_entropy_loss(out)
# test

In [257]:



# loader = NeighborLoader(data, num_neighbors=[10] * 2, shuffle=True, batch_size=10)


model = Network(graph_layers_sizes, mlp_sizes).to(device)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003, weight_decay=5e-4)

cos_sim = torch.nn.CosineSimilarity()



In [258]:
import random

num_edges = data.edge_index.shape[1]
# batch_size = 20000
# batch_size = data.x.shape[0]
batch_size = data.edge_index.shape[1]
print(batch_size)
# batch_size = 1000
num_batches = num_edges//batch_size + 1
index_list = list(range(num_edges))
alpha = 1.

model.train()
for epoch in range(10):
    random.shuffle(index_list)
    print(epoch)
    # for batch_index in range(num_batches):
    for batch_index in range(1):
        sampled_edge_index = index_list[batch_size*batch_index : min(batch_size*(batch_index+1), num_edges)]
        # print(len(sampled_edge_index))
        # print(min(batch_size*(batch_index+1), num_edges))
        sampled_edges = data.edge_index[:, sampled_edge_index]
        uniform_random_ratio = assign_color_sampled_edges(sampled_edges, k_colors)
        # if batch_index == 10:
        sampled_nodes = sampled_edges.ravel()
        # print(sampled_nodes.device)

        # print(sampled_nodes.shape)
        # print(sampled_nodes.dtype)
        _, n_id, adjs = sampler.sample(sampled_nodes.to('cpu'))
        # print(n_id.shape)
        optimizer.zero_grad()
        out = model(data.x[n_id], adjs)

        current_batch_size = out.shape[0]//2
        # print(current_batch_size)
            # print('======')
        # print(out[0,:])
        # print(out[current_batch_size,:])
        # print(current_batch_size.dtype)
        # loss = -criterion(out[:current_batch_size,:], out[current_batch_size:, :]) 
        # loss = -cos_sim(out[:current_batch_size,:], out[current_batch_size:, :]).mean() 
        
        # loss = -criterion(out[:batch_size,:], out[batch_size:, :]) + alpha * negative_entropy_loss(out)
        # loss = inner_product_loss(out) - alpha * negative_entropy_loss(out)
        loss = inner_product_loss(out)
        # print(loss.item(), percent_same_color(out).item(), uniform_random_ratio)
        print(loss.item(), inner_product_loss(out).item(), uniform_random_ratio)
        # print(loss.item(), inner_product_loss(out).item())
        # print()
        loss.backward()
        optimizer.step()
        

    
    



8959
0
0.5023680925369263 0.5023680925369263 0.5259515570934256
1
0.5000861287117004 0.5000861287117004 0.5018417234066302
2
0.5001647472381592 0.5001647472381592 0.5005022882018082
3
0.5003552436828613 0.5003552436828613 0.512557205045206
4
0.50022292137146 0.50022292137146 0.5182498046656993
5
0.5000616908073425 0.5000616908073425 0.4875544145551959
6
0.4999978244304657 0.4999978244304657 0.5093202366335529
7
0.5000165700912476 0.5000165700912476 0.5035160174126576
8
0.5000544190406799 0.5000544190406799 0.5138966402500279
9
0.5000529885292053 0.5000529885292053 0.5214867730773524


tensor([1, 2, 3, 4])


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [7]:
model.train()
margin = 0.01
train_running_loss = 0.0
for sample in loader:
    # print(device)
    sample.x = sample.x.to(device).to(dtype=torch.float32)
    sample.x.requires_grad=True
    # sample.x = requires_grad=True
    sample.edge_index = sample.edge_index.to(device).long()
    # print(sample.x.is_cuda)
    # print(type(sample.edge_index))
    # print(sample.edge_index.is_cuda)
    optimizer.zero_grad()
    i, j, k = structured_negative_sampling(sample.edge_index)
    negatives = (i,k)   #not neighbors
    positives = (i,j)   #neighbors 
    output = model(sample)
    #pos = model.similarity(sample.x[i], sample.x[j])
    #neg = model.similarity(sample.x[i], sample.x[k])
    pos = model.similarity(output[i], output[j])
    neg = model.similarity(output[i], output[k])
    diff =pos.diag() -neg.diag() +margin      # Note for coloring, we want negatives closer and positives further
    triplet_loss_matrix = diff.mean()
    loss = triplet_loss_matrix
    loss.backward()
    optimizer.step()
    train_running_loss += loss.detach().item()