In [1]:
import dgl
from dgl.data import DGLDataset
import torch
import os
import pandas as pd
import numpy as np
import dgl
import torch
import itertools
import numpy as np
import scipy.sparse as sp
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from sklearn.metrics import roc_auc_score
# Device the model is placed on (CPU for this notebook).
device = torch.device('cpu')
print(str(device))

cpu


In [10]:
class COMP4222Dataset(DGLDataset):
    """Single-graph DGL dataset of investors, startups and funding rounds.

    The heterograph has two node types, ``investor`` and ``startup``, and two
    edge types built from the funding-round table (one edge per row):
    ``raise`` (investor -> startup) and its reverse ``israised``
    (startup -> investor). Every node and edge type carries a ``feat`` tensor.
    """

    def __init__(self):
        # DGLDataset.__init__ runs process(), which reads the CSVs and builds
        # self.graph, so the CSVs are not read a second time here.
        super().__init__(name='comp-4222')

    def process(self):
        """Read the CSV tables and build the featured investor/startup heterograph."""
        self.df_startups = pd.read_csv('./data/startups_formatted.csv')
        self.df_investors = pd.read_csv('./data/investors_formatted.csv')
        self.df_investments = pd.read_csv('./data/funding_round_formatted.csv')
        self.startup_node = len(self.df_startups)
        self.investor_node = len(self.df_investors)
        self.investments_edge = len(self.df_investments)

        # NOTE(review): assumes the *_object_id columns are 0-based row indices
        # into the investor / startup tables — confirm against the preprocessing.
        investor_ids = torch.tensor(self.df_investments.investor_object_id.values.tolist())
        startup_ids = torch.tensor(self.df_investments.funded_object_id.values.tolist())
        data_dict = {
            ("investor", "raise", "startup"): (investor_ids, startup_ids),
            ("startup", "israised", "investor"): (startup_ids, investor_ids),
        }
        # Pin node counts to the table lengths (one feature row per node) rather
        # than letting DGL infer them from the largest id that appears in an edge.
        self.graph = dgl.heterograph(
            data_dict,
            num_nodes_dict={'investor': self.investor_node, 'startup': self.startup_node})

        edge_feature = [c for c in self.df_investments.columns
                        if c not in ["funding_round_id", "funded_object_id", "investor_object_id"]]

        # The first two columns of each node table are skipped
        # (presumably identifiers — TODO confirm).
        investor_x = self.df_investors.iloc[:, 2:].to_numpy()
        startup_x = self.df_startups.iloc[:, 2:].to_numpy()
        # One GraphConv input width is shared by every relation, so zero-pad the
        # investor features on the right up to the startup feature width.
        pad_width = max(startup_x.shape[1] - investor_x.shape[1], 0)
        investor_x = np.pad(investor_x, [(0, 0), (0, pad_width)], mode='constant', constant_values=0)

        edge_x = self.df_investments[edge_feature].to_numpy()

        self.graph.nodes['investor'].data['feat'] = torch.tensor(investor_x)
        self.graph.nodes['startup'].data['feat'] = torch.tensor(startup_x)
        # Separate tensors per edge type so in-place edits to one do not leak into the other.
        self.graph.edges['raise'].data['feat'] = torch.tensor(edge_x)
        self.graph.edges['israised'].data['feat'] = torch.tensor(edge_x)

    def __getitem__(self, i):
        """Return the dataset's only graph; valid indices are 0 and -1."""
        # Raising IndexError out of range lets plain `for g in dataset` terminate.
        if i not in (0, -1):
            raise IndexError("COMP4222Dataset holds a single graph; got index %r" % (i,))
        return self.graph

    def __len__(self):
        """The dataset always contains exactly one graph."""
        return 1

# Load the dataset and take its single heterograph, then show a summary of it.
dataset = COMP4222Dataset()
graph = dataset[0]
print(str(graph))

Graph(num_nodes={'investor': 7594, 'startup': 21485},
      num_edges={('investor', 'raise', 'startup'): 60983, ('startup', 'israised', 'investor'): 60983},
      metagraph=[('investor', 'startup', 'raise'), ('startup', 'investor', 'israised')])


In [12]:
class RGCN(nn.Module):
    """Two-layer relational GCN over a heterograph.

    Each relation gets its own ``GraphConv``; ``HeteroGraphConv`` sums the
    per-relation results arriving at each destination node type.

    Args:
        in_feats: input feature width, shared by every source node type.
        hid_feats: hidden width produced by the first layer.
        out_feats: width of the output embeddings.
        rel_names: edge-type names; one GraphConv is created per name.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv(
            {rel: dglnn.GraphConv(in_feats, hid_feats) for rel in rel_names},
            aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv(
            {rel: dglnn.GraphConv(hid_feats, out_feats) for rel in rel_names},
            aggregate='sum')

    def forward(self, graph, inputs):
        """Compute node embeddings.

        Args:
            graph: heterograph to convolve over.
            inputs: dict mapping node type -> input feature tensor.

        Returns:
            dict mapping node type -> output embedding tensor.
        """
        # Scope on the graph argument (previously an undefined global `g`).
        with graph.local_scope():
            h = self.conv1(graph, inputs)
            h = {ntype: F.relu(x) for ntype, x in h.items()}
            h = self.conv2(graph, h)
        return h

class HeteroDotProductPredictor(nn.Module):
    """Score the edges of one relation by the dot product of endpoint embeddings."""

    def forward(self, graph, h, etype):
        """Return one score per ``etype`` edge of ``graph``.

        ``h`` holds the node representations per node type (as produced by the
        encoder); scratch data is confined to a local scope so ``graph`` is
        left unchanged.
        """
        with graph.local_scope():
            graph.ndata['h'] = h
            dot_product = fn.u_dot_v('h', 'h', 'score')
            graph.apply_edges(dot_product, etype=etype)
            edge_data = graph.edges[etype].data
            return edge_data['score']


def construct_negative_graph(graph, k, etype, generator=None):
    """Build a graph of ``k`` corrupted edges per real edge of ``etype``.

    Each source node of a real edge is repeated ``k`` times and paired with a
    destination drawn uniformly from all nodes of the destination type. Sampled
    pairs are not filtered, so some may coincide with real edges.

    Args:
        graph: heterograph holding the positive edges.
        k: number of negative edges per positive edge.
        etype: canonical edge type ``(src_type, relation, dst_type)``.
        generator: optional ``torch.Generator`` for reproducible sampling.

    Returns:
        A heterograph containing only ``etype`` edges, with the same node count
        per node type as ``graph``.
    """
    _, _, vtype = etype
    src, _ = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    # Sample on the edge tensors' device so neg_src and neg_dst always match.
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,),
                            generator=generator, device=src.device)
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

class Model(nn.Module):
    """Link-prediction model: RGCN encoder plus dot-product edge scorer."""

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        # `sage` actually holds an RGCN; the attribute name is kept so that
        # state_dict keys stay compatible with existing checkpoints.
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype):
        """Return ``(positive scores on g, negative scores on neg_g)`` for ``etype``.

        Node embeddings are computed once on ``g`` and reused to score both graphs.
        """
        node_emb = self.sage(g, x)
        scorer = self.pred
        positive = scorer(g, node_emb, etype)
        negative = scorer(neg_g, node_emb, etype)
        return positive, negative

def compute_loss(pos_score, neg_score):
    """Hinge (margin 1) loss between each positive score and its negatives.

    ``neg_score`` is reshaped to one row per positive edge, so its length must
    be a multiple of ``pos_score.shape[0]``. Returns the mean of
    ``max(0, 1 - pos + neg)`` over all positive/negative pairs.
    """
    num_pos = pos_score.shape[0]
    neg_per_pos = neg_score.view(num_pos, -1)
    margins = 1 - pos_score + neg_per_pos
    return torch.clamp(margins, min=0).mean()
  
def compute_auc(pos_score, neg_score):
    """ROC-AUC for separating positive-edge scores from negative-edge scores.

    Args:
        pos_score: scores of real edges (any shape; flattened).
        neg_score: scores of sampled negative edges (any shape; flattened).

    Returns:
        The ROC-AUC with positives labelled 1 and negatives labelled 0.
    """
    # Flatten so the label counts below always match the number of scores.
    pos = pos_score.detach().reshape(-1)
    neg = neg_score.detach().reshape(-1)
    # .cpu() so this also works when the model runs on a GPU (.numpy() needs host memory).
    scores = torch.cat([pos, neg]).cpu().numpy()
    labels = torch.cat([torch.ones(pos.shape[0]), torch.zeros(neg.shape[0])]).numpy()
    return roc_auc_score(labels, scores)

In [4]:
# Reuse the dataset loaded above instead of re-reading every CSV and rebuilding the graph.
graph = dataset[0]
graph.edges(etype=('startup', 'israised', 'investor'))

(tensor([13832, 13832, 16402,  ...,  2868, 19442,  7179]),
 tensor([   0, 2241, 3887,  ..., 6297, 1875, 2317]))

In [8]:
SEED = 0
torch.manual_seed(SEED)  # model init and negative sampling are both random

k = 5  # negative edges sampled per positive edge
target_etype = ('startup', 'israised', 'investor')

investor_feats = graph.nodes['investor'].data['feat']
startup_feats = graph.nodes['startup'].data['feat']
# Derive the input width from the data instead of hard-coding it (221 here).
in_feats = startup_feats.shape[1]
assert investor_feats.shape[1] == in_feats, "investor/startup feature widths differ"

# Keep the graph and features on the same device as the model.
train_graph = graph.to(device)
node_features = {'investor': investor_feats.float().to(device),
                 'startup': startup_feats.float().to(device)}
model = Model(in_feats, 20, 1, train_graph.etypes).to(device)
opt = torch.optim.Adam(model.parameters())

# NOTE(review): AUC is computed on the training edges themselves, so it is a
# training metric, not a held-out link-prediction score.
for epoch in range(100):
    negative_graph = construct_negative_graph(train_graph, k, target_etype)
    pos_score, neg_score = model(train_graph, negative_graph, node_features, target_etype)
    loss = compute_loss(pos_score, neg_score)
    auc = compute_auc(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch % 5 == 0:
        print("epoch:%3d, loss: %3f, auc:%3f " % (epoch, loss.item(), auc))

epoch:  0, loss: 1.000016, auc:0.473668 
epoch:  5, loss: 0.994850, auc:0.712508 
epoch: 10, loss: 0.983599, auc:0.765705 
epoch: 15, loss: 0.964864, auc:0.787621 
epoch: 20, loss: 0.936221, auc:0.800731 
epoch: 25, loss: 0.895072, auc:0.809990 
epoch: 30, loss: 0.840024, auc:0.816009 
epoch: 35, loss: 0.770826, auc:0.821177 
epoch: 40, loss: 0.693689, auc:0.825770 
epoch: 45, loss: 0.620002, auc:0.828342 
epoch: 50, loss: 0.553994, auc:0.831597 
epoch: 55, loss: 0.501721, auc:0.832870 
epoch: 60, loss: 0.461781, auc:0.834489 
epoch: 65, loss: 0.435270, auc:0.834825 
epoch: 70, loss: 0.416396, auc:0.835700 
epoch: 75, loss: 0.405865, auc:0.836372 
epoch: 80, loss: 0.397375, auc:0.836888 
epoch: 85, loss: 0.391261, auc:0.837956 
epoch: 90, loss: 0.389722, auc:0.837937 
epoch: 95, loss: 0.387516, auc:0.838343 
