
Link Prediction on Heterogeneous Graphs with Heterogeneous Graph Learning #3958

Closed · sophiakrix opened this issue Jan 26, 2022 · 151 comments

@sophiakrix commented Jan 26, 2022

🚀 The feature, motivation and pitch

I am working with heterogeneous knowledge graphs and am trying to do link prediction on them. The specific issue I am facing is that I cannot find any working implementation that would allow me to do link prediction on a graph with multiple node types and multiple edge types, and to predict both the existence and the type of an edge between nodes.

It would be great to have a working example of how to do link prediction on a heterogeneous graph with the heterogeneous graph learning module.

@rusty1s (Member) commented Jan 27, 2022

You are right that we can strengthen our set of examples in that regard. Currently, we only have a single example of link-level prediction for movie recommendation, see here. Do you have some specific example in mind? Happy to work with you on this together.

@sophiakrix (Author) commented Jan 31, 2022

Hi @rusty1s , thanks for your reply! And yes, I have a specific example in mind. I have been trying to work on link prediction with the public biomedical knowledge graph ogbl-biokg. I think this would be a good example, since it's also a benchmarking graph for many other approaches.

Indeed, I'd be happy to work on this together. When I check the link prediction example for movie recommendation that you mentioned, I see that there are the following tasks to do:

  • use the ogbl-biokg PyG class PygLinkPropPredDataset
  • add node features (torch.nn.Embedding(num_nodes, embedding_dim))
  • adapt the decoder to consider edges of different edge types
  • modify loss (contrastive loss, either hinge or cross entropy)

Is there anything I am missing?

@rusty1s (Member) commented Feb 1, 2022

Sounds amazing. I'm wondering whether we need to make ogbl-biokg directly available in PyG, as OGB already supports PyG datasets out-of-the-box. Also not sure what you mean by "adding node features" - please clarify :)

@sophiakrix (Author)

That would be a good idea! I saw that you also contributed to the availability of a pytorch-geometric formatted ogb dataset here. It looks like this has been done exactly in the way that is required for pytorch geometric - so could we directly use this class?

As to the node features: Since @anniekmyatt did it in her example, I thought it might be necessary to have node features. But of course, if this is not necessary we can leave it out.

@anniekmyatt (Contributor)

Thanks for mentioning me, I’m happy to help out if you’d like me to!

I think in your case you need to use a contrastive loss (either hinge or cross entropy), using examples of positive and negative edges.

I think the ogbl-biokg dataset doesn't have features on the nodes, @rusty1s, so I thought @sophiakrix needs to add some if she wants to use a GNN.

@anniekmyatt (Contributor) commented Feb 1, 2022

@sophiakrix , the encoder in the movielens example does have a heterogeneous encoder due to this line:

self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')

I think you mainly need to change the decoder and loss function, and think about how you feed in your batches of edges (are you just looking to predict on one edge type or would it be beneficial to optimise for different edge types in your use case?).

@rusty1s (Member) commented Feb 2, 2022

@sophiakrix Yes, you can directly use the PyG dataset from the ogb suite.

Indeed, the ogbl-biokg dataset does not provide node features, and people usually use torch.nn.Embedding(num_nodes, embedding_dim) to learn them jointly.
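A minimal sketch of this (one jointly learned embedding table per node type; the embedding_dim value is illustrative):

import torch

embedding_dim = 64  # illustrative value

# One learnable embedding table per node type:
emb_dict = torch.nn.ModuleDict({
    node_type: torch.nn.Embedding(num_nodes, embedding_dim)
    for node_type, num_nodes in data.num_nodes_dict.items()
})

# The GNN input features are then simply the embedding weights:
x_dict = {node_type: emb.weight for node_type, emb in emb_dict.items()}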

@sophiakrix (Author)

Thanks for all the tips! And thanks @anniekmyatt for the nice implementation of your MovieLens example :) I'll start working on adapting it to the ogbl-biokg now. I'll keep you updated :)

@sophiakrix (Author) commented Feb 2, 2022

When trying to use the ogbl-biokg dataset with the PygLinkPropPredDataset class, I realised that this is a Data object, and not a HeteroData object. Do I need to convert it to one in order to set the node features?

>>> data # MovieLens
HeteroData(
  movie={ x=[9742, 404] },
  user={ num_nodes=610 },
  (user, rates, movie)={
    edge_index=[2, 100836],
    edge_label=[100836]
  }
)

And when I look at the ogbl-biokg data object, it gives me the following:

>>> data # ogbl-biokg
Data(
  num_nodes_dict={
    disease=10687,
    drug=10533,
    function=45085,
    protein=17499,
    sideeffect=9969
  },
  edge_index_dict={
    (disease, disease-protein, protein)=[2, 73547],
...
  },
  edge_reltype={
    (disease, disease-protein, protein)=[73547, 1],
...
  }
)

@anniekmyatt did access the node types explicitly for the MovieLens dataset, and in the case of ogbl-biokg, I cannot access the nodes and their features like this. Is there another way to set them?

data['user'].x = torch.eye(data['user'].num_nodes, device=device)

@rusty1s (Member) commented Feb 3, 2022

Currently, ogb does not yet make use of our newly released HeteroData class. As such, it's a good idea to manually convert it to HeteroData, e.g.:

hetero_data = HeteroData()
for node_type, num_nodes in data.num_nodes_dict.items():
    hetero_data[node_type].num_nodes = num_nodes
for edge_type, edge_index in data.edge_index_dict.items():
    hetero_data[edge_type].edge_index = edge_index

@sophiakrix (Author) commented Feb 3, 2022

Great, I thought that that's the best way to go, too. When I then tried to train the model, I encountered a few issues:

1) Message passing error for direct edge types

When passing in a direct edge type to the model

# direct edge
pred = model(train_data.x_dict, train_data.edge_index_dict, train_data['disease', 'protein'].edge_label_index)

the following error gets thrown:

Exception has occurred: ValueError
`MessagePassing.propagate` only supports `torch.LongTensor` of shape `[2, num_messages]` or `torch_sparse.SparseTensor` for argument `edge_index`.
  File "/Users/krixs/git/PHC-1196/src/my_pytorch_geometric/hetero_link_pred.py", line 148, in forward
    z_dict = self.encoder(x_dict, edge_index_dict)
  File "/Users/krixs/git/PHC-1196/src/my_pytorch_geometric/hetero_link_pred.py", line 165, in train
    train_data['disease', 'protein'].edge_label_index)
  File "/Users/krixs/git/PHC-1196/src/my_pytorch_geometric/hetero_link_pred.py", line 186, in <module>
    loss = train()

The edge_index_dict that is passed here is a torch.LongTensor of the required shape, so I am not sure what the exact problem is. Maybe that it is a dictionary and not an individual tensor?

>>> edge_index_dict
{('disease', 'disease-protein', 'protein'): tensor([[ 3936,  511...,   741]]), ...}
>>> edge_index_dict[('disease', 'disease-protein', 'protein')].shape
torch.Size([2, 58839])
>>> edge_index_dict[('disease', 'disease-protein', 'protein')].type()
'torch.LongTensor'

2) Error for reverse edge types

The reverse edge types are lacking the edge_label and edge_label_index attribute. This causes an error during forwarding:

# reverse edge
pred = model(train_data.x_dict, train_data.edge_index_dict, train_data['protein', 'disease'].edge_label_index)

Error:


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/anaconda3/envs/env_pyg/lib/python3.7/site-packages/torch_geometric/data/storage.py", line 51, in __getattr__
    f"'{self.__class__.__name__}' object has no attribute '{key}'")
AttributeError: 'EdgeStorage' object has no attribute 'edge_label_index'

Would it be okay to simply add these attributes from the corresponding direct edge type?

@rusty1s (Member) commented Feb 3, 2022

What does your model look like? It seems you are mixing encoder and decoder parts. The encoder takes x_dict and edge_index_dict as input, and will probably use a model based on to_hetero. The decoder operates on single edge types only and will, e.g., in your case, take the output embeddings of disease and protein as input.

@sophiakrix (Author) commented Feb 3, 2022

I did not change the model from @anniekmyatt, since she said the encoder would already accept heterogeneous input. It looks as follows:

class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
 
class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['disease'][row], z_dict['protein'][col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, hetero_data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)
        # binary cross entropy loss with logits
        self.loss = torch.nn.BCEWithLogitsLoss()

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

I will adapt the decoder in a following step, but I am still unsure how to handle the fact that there are different node type pairs, and not only one ('user', 'movie') as in the example.

@anniekmyatt (Contributor) commented Feb 3, 2022

Yes, I hard-coded the edge type in the movielens example because we were only interested in one edge type and there's nothing in that decoder to handle other edge types.

To handle different edge types, you can make a decoder in which you pass the edge type as an input to the forward function. You could use a bilinear decoder, like the DistMult decoder, because it learns parameters related to the different edge types. You then need to loop over the different edge types and aggregate the loss.
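For illustration, a rough, untested sketch of such a decoder (all names are illustrative):

import torch

class DistMultDecoder(torch.nn.Module):
    def __init__(self, edge_types, hidden_channels):
        super().__init__()
        # One learnable relation vector per edge type (the diagonal of the
        # bilinear form in DistMult):
        self.rel_emb = torch.nn.ParameterDict({
            '__'.join(edge_type): torch.nn.Parameter(torch.randn(hidden_channels))
            for edge_type in edge_types
        })

    def forward(self, z_dict, edge_label_index, edge_type):
        src_type, _, dst_type = edge_type
        row, col = edge_label_index
        rel = self.rel_emb['__'.join(edge_type)]
        # DistMult score: sum_i z_src[i] * rel[i] * z_dst[i]
        return (z_dict[src_type][row] * rel * z_dict[dst_type][col]).sum(dim=-1)

In the training loop, you would then call the decoder once per edge type and sum the per-type losses.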

Let me know if this makes sense or if you'd like more detail.

@anniekmyatt (Contributor) commented Feb 3, 2022

Just a few more thoughts:

  • One for now: looking at your code, you are trying to get train_data['protein', 'disease'].edge_label_index. I don't think this exists, but instead it should be: train_data[('disease', 'disease-protein', 'protein')].edge_label_index, you need the edge type in there.
  • A question: Are you trying to get predictions for both the out and return edge, or just ('disease', 'disease-protein', 'protein')? If it's just one way, you may not need a DistMult decoder and also no edge_label_index for ('protein', 'rev_disease-protein', 'disease')?
  • Related to the above, but for later: have a look at the disjoint_train_ratio input to the RandomLinkSplit transformation and see if you want to set it to something non-zero (e.g. 0.2). Doing this ensures that the training supervision edges are not also message passing edges. If you don't want disjoint training edges, I think it's OK to manually set the edge_label_index of the return edge as the reverse of train_data[('disease', 'disease-protein', 'protein')].edge_label_index, i.e. swap the first and second row (see the sketch below). If you do want the supervision edges to be disjoint from the rest of the graph used for message passing, it's worth checking that your reverse edge_label_index edges don't turn out to be message passing edges, which would mean information leakage. My recommendation is to set the parameter to 0 to start with, get your code to run and learn, and then look at the impact on performance of setting it to non-zero later.
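For the row swap mentioned in the last point, a short sketch (assuming the direct edge type carries an edge_label_index of shape [2, num_edges]):

direct = train_data['disease', 'disease-protein', 'protein'].edge_label_index
# Flipping along dim 0 swaps the first and second row, giving the reverse supervision edges:
train_data['protein', 'rev_disease-protein', 'disease'].edge_label_index = direct.flip(0)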

@sophiakrix (Author)

Hi @anniekmyatt ! Thanks for your thoughts :)

As to your question: I actually only need prediction for the out edge ('disease', 'disease-protein', 'protein'), and not the return edge ('protein', 'rev_disease-protein', 'disease'). What kind of decoder would you suggest in this case?
Also, maybe I did not quite understand ... What should the edge_label_index indicate?
And thanks for the hint with the disjoint_train_ratio, I will definitely try that out!

@sophiakrix (Author) commented Feb 4, 2022

Something to consider for the message passing error for direct edge types:

The same error already happens during lazy initialization in this line here:

# Due to lazy initialization, we need to run one model step so the number
# of parameters can be inferred:
with torch.no_grad():
    model.encoder(train_data.x_dict, train_data.edge_index_dict)

Error:

Traceback (most recent call last):
  File "[/Users/krixs/.vscode/extensions/ms-python.python-2021.12.1559732655/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_vars.py]()", line 420, in evaluate_expression
    compiled = compile_as_eval(expression)
  File "[/Users/krixs/.vscode/extensions/ms-python.python-2021.12.1559732655/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_vars.py]()", line 374, in compile_as_eval
    return compile(_expression_to_evaluate(expression), '<string>', 'eval')
  File "<string>", line 1
    with torch.no_grad():
       ^
SyntaxError: invalid syntax

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<string>", line 2, in <module>
  File "[/opt/anaconda3/envs/env_pyg/lib/python3.7/site-packages/torch/fx/graph_module.py]()", line 616, in wrapped_call
    raise e.with_traceback(None)
ValueError: `MessagePassing.propagate` only supports `torch.LongTensor` of shape `[2, num_messages]` or `torch_sparse.SparseTensor` for argument `edge_index`.

I commented these lines out since they were throwing the error (bad decision, I know now)... Do you have an idea why the encoder has a problem with this input? As mentioned in the documentation, after converting a homogeneous model to a heterogeneous model with to_hetero(), a dictionary with edge types should be a valid input:

As a result, the model now expects dictionaries with node and edge types as keys as input arguments, rather than single tensors utilized in homogeneous graphs.

The edge_index_dict is again the same as mentioned above:

>>> train_data.edge_index_dict
{('disease', 'disease-protein', 'protein'): tensor([[ 4882,  113..., 13954]]),  ...}

>>> train_data.edge_index_dict['disease','disease-protein','protein'].shape
torch.Size([2, 58839])

>>> train_data.edge_index_dict['disease','disease-protein','protein'].type()
'torch.LongTensor'

@anniekmyatt (Contributor)

I'm trying to reproduce this but don't think I have enough info to do so. Your edge_index_dict looks fine to me.

Out of curiosity, what does your train_data.x_dict look like? I know the error complains about the edge_index but it would help me to try and reproduce your error.

@anniekmyatt (Contributor) commented Feb 6, 2022

Ha, turns out I could reproduce your error! 🥳

To be honest, I'm not sure what's happening; something is going wrong in the cls_call of the wrapped_call() method in graph_module.py. However, when I use a HeteroConv wrapper instead of a homogeneous GNN followed by the .to_hetero() method, as described here, the lazy initialization of the model completes without issues, so I recommend doing that. You also get better control over adding things like dropout and batch normalisation using the HeteroConv wrapper.

@anniekmyatt (Contributor) commented Feb 6, 2022

I realised you had asked a few questions earlier that are still unanswered:

What should the edge_label_index indicate?

The edge_label_index indicates which edges are your supervision edges, so the ones that you are going to get model predictions for, so the samples that you will optimise your model on. On the other hand, edge_index defines all the edges in your graph and is used for message passing in your GNN. If you have the disjoint_train_ratio set to zero, the supervision edges are a subset of the full graph, so they are also included in edge_index, but if you set it to non-zero, your supervision edges and message passing edges will not overlap (they'll be disjoint).
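As a sketch of this (parameter values are arbitrary):

import torch_geometric.transforms as T

train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.2,  # 20% of training edges become supervision-only edges
    edge_types=[('disease', 'disease-protein', 'protein')],
    rev_edge_types=[('protein', 'rev_disease-protein', 'disease')],
)(hetero_data)

# train_data[edge_type].edge_index       -> message passing edges
# train_data[edge_type].edge_label_index -> supervision edges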

What kind of decoder would you suggest in this case?

This is a good question. I have to admit that I am not sure that we are going about this the best way for the bio-kg by adapting the movielens example. Maybe @rusty1s can weigh in with an opinion here? Looking at the OGB leaderboard, common ways for learning knowledge graph embeddings like ComplEx or even DistMult, which are available out of the box in something like DGL-KE, seem to do well.

Our approach appears to be to learn the graph structure through both learnable node embeddings and a GNN. Given you are interested in only one edge type the decoder could either be just a dot product, or you could tack on something like a DistMult as well, to spend some more time learning parameters for the different edge types, but then you have a model with loads of parameters, which will overfit very easily. I'm not sure if this is a good idea. It depends on how much time you have, @sophiakrix, can you try out different things and see if it works, or are you tight for time?

I'm wondering whether it might be better to ditch the GNN completely and just work on implementing one or more of the KG algorithms (like DistMult, RESCAL, ComplEx...) combined with learnable node embeddings ? @rusty1s , is this something that's on your roadmap at all or is it out of scope? Or do you think it's worth trying out the GNN+learnable embeddings approach?

@rusty1s (Member) commented Feb 7, 2022

@anniekmyatt

Ha, turns out I could reproduce your error!

Can you show me how?

so I recommend doing that. You also get better control over adding things like dropout and batch normalisation using the HeteroConv wrapper.

I wouldn't say this is true at all, so I am interested to find out more about this error. The only reason to use HeteroConv and manually hassle around with dictionaries is if one wants to utilize different GNN layers/operators across different types. You can make good use of dropout and batch norm inside to_hetero as well :)
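For example, something along these lines should work (an untested sketch; BatchNorm needs a known number of channels):

import torch
import torch.nn.functional as F
from torch_geometric.nn import BatchNorm, SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.norm = BatchNorm(hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.norm(x)
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, edge_index)

model = to_hetero(GNN(32, 32), hetero_data.metadata(), aggr='sum')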

@sophiakrix edge_label_index should indicate the edges for which you have ground-truth information, and edge_label indicates their "class". For a KG, this would be its relation type. ogbl-biokg provides pre-defined splits via dataset.get_edge_split(), so you can either use them directly or map them to the edge_label_index/edge_label scenario of PyG.
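A rough sketch of that mapping (assuming the split dictionary keys follow the OGB documentation for ogbl-biokg):

import torch

split_edge = dataset.get_edge_split()
train = split_edge['train']  # keys: 'head', 'head_type', 'relation', 'tail', 'tail_type'

head = torch.as_tensor(train['head'])
tail = torch.as_tensor(train['tail'])
rel = torch.as_tensor(train['relation'])

# For a single relation type, e.g. relation id 0, build supervision edges:
mask = rel == 0
edge_label_index = torch.stack([head[mask], tail[mask]], dim=0)
edge_label = rel[mask]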

@anniekmyatt Yes, we could think about adding different encoder and decoder parts to PyG directly. Currently, we leave the task of KGEs (without multi-hop reasoning), e.g., ComplEx, etc, to other libraries.

@sophiakrix (Author) commented Feb 7, 2022

@rusty1s Here is the code I used and which is throwing the error, so that you can reproduce it:

import os.path as osp
import argparse

import torch
from torch.nn import Linear
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
from ogb.linkproppred import PygLinkPropPredDataset
from torch_geometric.data import HeteroData

parser = argparse.ArgumentParser()
parser.add_argument('--use_weighted_loss', action='store_true',
                    help='Whether to use weighted MSE loss.')
parser.add_argument('--dataset', default='ogbl-biokg', help='The dataset to use.')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# variables
embedding_dim = 10
hidden_channels = 32

dataset = PygLinkPropPredDataset(name = 'ogbl-biokg', root='../../dataset/')
data = dataset[0].to(device)
hetero_data = HeteroData()
for node_type, num_nodes in data.num_nodes_dict.items():
    hetero_data[node_type].num_nodes = num_nodes
for edge_type, edge_index in data.edge_index_dict.items():
    hetero_data[edge_type].edge_index = edge_index
# Create node types holding a feature matrix (note: one-hot features create
# very large dense matrices for node types with many nodes):
for ntype, num_nodes in data.num_nodes_dict.items():
    hetero_data[ntype].x = torch.eye(num_nodes, device=device)


# get all edge types of ogbl-biokg graph before converting it to undirected
edge_types = hetero_data.edge_types


# Add a reverse relation for every type (merge has to be set to False) for message passing:
hetero_data = T.ToUndirected(merge=False)(hetero_data)

# get all reverse edge types of ogbl-biokg graph
reverse_edge_types = [(x[2], "{}_{}".format('rev', x[1]), x[0]) for x in edge_types]

# Perform a link-level split into training, validation, and test edges:
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=edge_types,
    rev_edge_types=reverse_edge_types,
)(hetero_data)


# The original MovieLens example uses a weighted MSE loss for its unbalanced
# rating labels; kept here for reference:
if args.use_weighted_loss:
    weight = torch.bincount(train_data['disease', 'disease-protein', 'protein'].edge_label)
    weight = weight.max() / weight
else:
    weight = None


def weighted_mse_loss(pred, target, weight=None):
    weight = 1. if weight is None else weight[target].to(pred.dtype)
    return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()

def calculate_loss(loss, pred, target):
    # Apply the given loss (BCEWithLogitsLoss) to the actual predictions and targets:
    return loss(pred, target.to(pred.dtype))

class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
 
class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['disease'][row], z_dict['protein'][col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, hetero_data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)
        # binary cross entropy loss with logits
        self.loss = torch.nn.BCEWithLogitsLoss()

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)


model = Model(hidden_channels).to(device)

# Due to lazy initialization, we need to run one model step so the number
# of parameters can be inferred:
with torch.no_grad():
    model.encoder(train_data.x_dict, train_data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    pred = model(train_data.x_dict, train_data.edge_index_dict,
                 train_data['disease', 'disease-protein', 'protein'].edge_label_index)
    target = train_data['disease', 'disease-protein', 'protein'].edge_label
    loss = calculate_loss(model.loss, pred, target)
    #loss = weighted_mse_loss(pred, target, weight)
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(data):
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict,
                 data['disease', 'disease-protein', 'protein'].edge_label_index)
    pred = pred.clamp(min=0, max=1)  # edge labels are binary here
    target = data['disease', 'disease-protein', 'protein'].edge_label.float()
    rmse = F.mse_loss(pred, target).sqrt()
    return float(rmse)


for epoch in range(1, 301):
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    test_rmse = test(test_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')

@rusty1s (Member) commented Feb 7, 2022

Thank you. I think this is resolved in master, as it does not throw an error for me. There was an issue in to_hetero which could not handle edge types with - or whitespace in their names. You can install from master and see if this fixes it for you as well:

pip install git+https://github.com/pyg-team/pytorch_geometric.git

@sophiakrix (Author) commented Feb 7, 2022

Thanks for the info! I think it solves the error. The issue is that now when I try the lazy initialization, I get a CUDA out of memory error. I think this is due to the fact that I try to pass in the entire dataset in one go.

Therefore, I thought of doing mini-batching for the training, but I am facing an issue here. I followed the example for mini-batch training with the HGTLoader:

train_loader = HGTLoader(
    train_data,
    # Sample 512 nodes per type and per iteration for 2 iterations:
    num_samples={key: [512] * 2 for key in train_data.node_types},
    # Use a batch size of 16 for sampling training nodes of type 'disease':
    batch_size=16,
    input_nodes=('disease', None),
)

def train():
    model.train()
    total_examples = total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to('cuda:0')
        batch_size = batch['disease'].batch_size
        # This should operate on the sampled `batch`, not on the full `train_data`:
        pred = model(batch.x_dict, batch.edge_index_dict,
                     batch['disease', 'disease-protein', 'protein'].edge_label_index)
        target = batch['disease', 'disease-protein', 'protein'].edge_label
        loss = calculate_loss(model.loss, pred, target)
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        total_loss += float(loss) * batch_size

    return total_loss / total_examples 

When trying to call the next batch, it gives me the following error:

/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [15,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
...
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [55,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
...
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:646: indexSelectSmallIndex: block: [37,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

Traceback (most recent call last):
  File "src/my_pytorch_geometric/hetero_link_pred.py", line 237, in <module>
    loss = train()
  File "src/my_pytorch_geometric/hetero_link_pred.py", line 208, in train
    for batch in train_loader:
  File "/home/krixs/.conda/envs/env_pyg_gpu/lib/python3.8/site-packages/torch_geometric/loader/base.py", line 21, in __next__
    return self.transform_fn(next(self.iterator))
  File "/home/krixs/.conda/envs/env_pyg_gpu/lib/python3.8/site-packages/torch_geometric/loader/hgt_loader.py", line 139, in transform_fn
    data = filter_hetero_data(self.data, node_dict, row_dict, col_dict,
  File "/home/krixs/.conda/envs/env_pyg_gpu/lib/python3.8/site-packages/torch_geometric/loader/utils.py", line 152, in filter_hetero_data
    filter_edge_store_(data[edge_type], out[edge_type],
  File "/home/krixs/.conda/envs/env_pyg_gpu/lib/python3.8/site-packages/torch_geometric/loader/utils.py", line 120, in filter_edge_store_
    out_store[key] = index_select(value, perm[index], dim=0)
  File "/home/krixs/.conda/envs/env_pyg_gpu/lib/python3.8/site-packages/torch_geometric/loader/utils.py", line 24, in index_select
    return torch.index_select(value, 0, index, out=out)
RuntimeError: CUDA error: device-side assert triggered

Do you have an idea why that could be? Also, lazy initialization is not possible here, no?

Update:
I tried it on cpu to see what happens there, and I found that this error is thrown:

File "/Users/krixs/git/PHC-1196/src/my_pytorch_geometric/hetero_link_pred.py", line 207, in train
    for batch in train_loader:
  File "/opt/anaconda3/envs/env_pyg/lib/python3.7/site-packages/torch_geometric/loader/base.py", line 21, in __next__
    return self.transform_fn(next(self.iterator))
  File "/opt/anaconda3/envs/env_pyg/lib/python3.7/site-packages/torch_geometric/loader/hgt_loader.py", line 140, in transform_fn
    edge_dict, self.perm_dict)
  File "/opt/anaconda3/envs/env_pyg/lib/python3.7/site-packages/torch_geometric/loader/utils.py", line 148, in filter_hetero_data
    node_dict[node_type])
KeyError: 'drug'

Is the input_nodes parameter for the HGTLoader that I am using the right one, or could this be causing the issue?

Or is the issue that I am using a node sampler and not an edge sampler? I then tried the NeighborLoader from the docs for the samplers:
When using the NeighborLoader with num_neighbors={key: [1] for key in train_data.edge_types} as a parameter, it does not sample all node types. This might be why it cannot access drug in the node_dict. Is there a workaround? @rusty1s

@anniekmyatt (Contributor)

You can make well use of dropout and batch norm inside to_hetero as well :)

Ah, thanks! I had some issues with getting this to run in the past. I must admit I didn't check recently and just went straight to my workaround using the HeteroConv wrapper.

Adding dropout to the homogeneous GNN before .to_hetero() does work for me but I'm still getting an error when including batch normalisation. I'll raise a separate issue for this, with the code that I ran. It's possible I'm doing something wrong or maybe a fix is needed.

@rusty1s (Member) commented Feb 8, 2022

@sophiakrix Please see my newly created issue here: #4026.
@anniekmyatt Fixed this bug in #4027.

@sophiakrix (Author)

@anniekmyatt Thanks for all your answers and for putting in the effort to reproduce the error. Could you possibly show how you implemented the HeteroConv for this case?

As to your question about what my x_dict looks like:

train_data.x_dict
{'disease': tensor([[1., 0., 0.,... 0., 1.]]), 'drug': tensor([[1., 0., 0.,... 0., 1.]]), 'function': tensor([[1., 0., 0.,... 0., 1.]]), 'protein': tensor([[1., 0., 0.,... 0., 1.]]), 'sideeffect': tensor([[1., 0., 0.,... 0., 1.]])}

About the decoder part: I think a simple dot product would be sufficient, just to try it out. Of course, it would be good to compare more decoders, but I'm quite tight on time for this, since I only have until the end of this month.

@anniekmyatt (Contributor)

Sure! As @rusty1s explained above, it turns out you don't need to use the HeteroConv wrapper if you use the same GNN layers for the different edge types, and it's a lot more lines of code, but here it is:

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels):
        super().__init__()

        num_layers = 2
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers):
            # The last layer maps to `out_channels`, all others to `hidden_channels`:
            channels = out_channels if i == num_layers - 1 else hidden_channels
            conv = HeteroConv({
                edge_type: SAGEConv((-1, -1), channels)
                for edge_type in metadata[1]
            })
            self.convs.append(conv)

        # Batch normalisation is only applied to intermediate layers, whose
        # outputs have `hidden_channels` features:
        self.batchnorm_dict = torch.nn.ModuleDict()
        for node_type in metadata[0]:
            self.batchnorm_dict[node_type] = BatchNorm(hidden_channels)

    def forward(self, x_dict, edge_index_dict, p_dropout=0.0):
        # Dropout on the input features:
        x_dict = {key: F.dropout(x, p=p_dropout, training=self.training) for key, x in x_dict.items()}
        for i in range(len(self.convs) - 1):
            x_dict = self.convs[i](x_dict, edge_index_dict)
            # Batch normalisation:
            x_dict = {key: self.batchnorm_dict[key](x) for key, x in x_dict.items()}
            # Activation function:
            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
            # Dropout:
            x_dict = {key: F.dropout(x, p=p_dropout, training=self.training) for key, x in x_dict.items()}
        return self.convs[-1](x_dict, edge_index_dict)

@anniekmyatt (Contributor)

Yeah, if you're a bit tight for time I would start as simple as possible and get it to a state where it trains, and then you can add complexity later if needed. :-)

Thanks for sharing your x_dict. I know this is already on your todo list, but you'll have to convert the x_dict you currently have (with the one-hot encoded nodes) to a dictionary of torch.nn.Embedding modules (one for each node type). I found it useful to use a torch.nn.ModuleDict() to hold the embeddings for the different node types, and there is an example here of how to initialise them (the example is for a ModuleList() but you'll see the analogy).

Also note that you can access the embd.weight parameters directly (assuming embd = torch.nn.Embedding(num_nodes, embedding_dim)), rather than having to call embd(input), where input would be something like torch.arange(num_nodes). I hope this helps, but let me know if you have any questions.
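A tiny sketch of the two access patterns (the values are illustrative):

import torch

num_nodes, embedding_dim = 10687, 64  # e.g. for the 'disease' node type

emb = torch.nn.Embedding(num_nodes, embedding_dim)

# Equivalent ways to obtain the full feature matrix:
x1 = emb.weight                    # access the parameters directly
x2 = emb(torch.arange(num_nodes))  # explicit lookup of all node indices
assert torch.equal(x1, x2)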

@rusty1s closed this as completed May 5, 2023
@csinnewcastle

Hi, I'm getting zero edges for some edge types. Sorry for asking too many questions, but when I use LinkNeighborLoader, I get empty values for some edge types. See below:

sampled_data = next(iter(train_loader))
print(sampled_data)
(protein, protein_has_go_annotation, go)={
  edge_index=[2, 0],
  e_id=[0]
},
(go, go_is_subtype_of_go, go)={
  edge_index=[2, 0],
  e_id=[0]
},
(gene, gene_expressed_in, tissue)={
  edge_index=[2, 0],
  e_id=[0]
},
(side_effect, side_effect_same_as, phenotype)={
  edge_index=[2, 0],
  e_id=[0]
},
(protein, protein_has_signature, signature)={
  edge_index=[2, 0],
  e_id=[0]
},
(protein, protein_expressed_in_tissue, tissue)={
  edge_index=[2, 0],
  e_id=[0]
},

@rusty1s (Member) commented May 7, 2023

This would indicate that these edge types are not reachable from your set of seed links within the specified number of hops.

@csinnewcastle commented May 7, 2023 via email

@rusty1s (Member) commented May 8, 2023

How is your LinkNeighborLoader defined, and what does the metadata of your graph look like?

@csinnewcastle commented May 8, 2023 via email

@rusty1s (Member) commented May 9, 2023

So it looks like a two-hop sampling in LinkNeighborLoader gets you the following edge types:

  • Sampling from drug: (protein, rev_HasTarget, drug) and (disorder, rev_HasIndication, drug), etc in first hop, (gene, rev_EncodedBy, protein), etc in second hop
  • Sampling from disorder: (phenotype, rev_disorder_has_phenotype, disorder) and (drug, drug_has_contraindication, disorder), etc in first hop, (disorder, disorder_has_phenotype, phenotype), etc in second hop.

As such, two-hop sampling is not able to reach information such as (protein, protein_in_pathway, pathway) or (go, go_is_subtype_of_go, go), which may lead to this confusion. I think it is working as expected. You should see more edge types being used when increasing the number of hops to sample, e.g., num_neighbors=[20, 10, 5].

@csinnewcastle

I think something is wrong with my model. The model predicts negative values and values greater than one, whereas the positive and negative labels are restricted to zero or one.
Here are the ground-truth labels:
tensor([1., 1., 1.,  ..., 0., 0., 0.], dtype=torch.float64)

and here are the predicted values:

tensor([ 1.8630e+00, -3.2289e+00,  9.1313e-01,  1.2432e+00,  8.8727e-02,
        -2.2828e+00,  2.3988e+00,  1.8867e+00,  4.1952e+00,  2.9706e+00,
         ...,
        -3.3883e+00, -3.8524e+00, -2.8383e+00, -3.5256e+00], grad_fn=<...>)

I expected the predictions to be only zero or one?

Thanks

@rusty1s (Member) commented May 12, 2023

This depends on what your model output looks like. If you return logits, the range of outputs is (-inf, inf), and you can squash them to (0, 1) via a sigmoid.

@csinnewcastle commented May 15, 2023 via email

@rusty1s (Member) commented May 15, 2023

In this case, it should be a sigmoid after the final prediction. You can also use sigmoid as an intermediate activation, but usually something like ReLU performs better.
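For example (a sketch, using the model signature from earlier in this thread):

logits = model(x_dict, edge_index_dict, edge_label_index)  # unbounded scores
probs = torch.sigmoid(logits)       # squashed to (0, 1)
labels = (probs > 0.5).float()      # hard 0/1 predictions, if needed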

@csinnewcastle commented May 15, 2023 via email

@ashkspark commented May 22, 2023

When we perform negative sampling (Medium blog) in "Defining Edge-level Training Splits" and "Defining Mini-batch Loaders", I believe we don't create any negative homogeneous links (e.g., users <--> users or movies <--> movies). In other words, all negative samples are of type (users--rates--movies or movies--re-rates--users). Just want to confirm.

Related to the above: is the discussion here relevant? Could you explain where I can use that?

@rusty1s (Member) commented May 23, 2023

Yes, negative edges are just sampled for the supervision loss, which operates on the single edge type (user, rates, movie). We do not modify the underlying graph.

@ashkspark

Thanks @rusty1s! Kind of related to this: could you explain the difference between defining mini-batch loaders (here) versus performing negative sampling "for every training epoch" (here)? I am putting the code snippets for both below:

# In the first hop, we sample at most 20 neighbors.
# In the second hop, we sample at most 10 neighbors.
# In addition, during training, we want to sample negative edges on-the-fly with
# a ratio of 2:1.
# We can make use of the `loader.LinkNeighborLoader` from PyG:
from torch_geometric.loader import LinkNeighborLoader

# Define seed edges:
edge_label_index = train_data["user", "rates", "movie"].edge_label_index
edge_label = train_data["user", "rates", "movie"].edge_label
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 10],
    neg_sampling_ratio=2.0,
    edge_label_index=(("user", "rates", "movie"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

Versus

from torch_geometric.utils import negative_sampling
    

def train_link_predictor(
    model, train_data, val_data, optimizer, criterion, n_epochs=100
):

    for epoch in range(1, n_epochs + 1):

        model.train()
        optimizer.zero_grad()
        z = model.encode(train_data.x, train_data.edge_index)

        # sampling training negatives for every training epoch
        neg_edge_index = negative_sampling(
            edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
            num_neg_samples=train_data.edge_label_index.size(1), method='sparse')

        edge_label_index = torch.cat(
            [train_data.edge_label_index, neg_edge_index],
            dim=-1,
        )
        edge_label = torch.cat([
            train_data.edge_label,
            train_data.edge_label.new_zeros(neg_edge_index.size(1))
        ], dim=0)

        out = model.decode(z, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()

        val_auc = eval_link_predictor(model, val_data)

        if epoch % 10 == 0:
            print(f"Epoch: {epoch:03d}, Train Loss: {loss:.3f}, Val AUC: {val_auc:.3f}")

    return model

@rusty1s (Member) commented May 24, 2023

The difference is whether you want to pre-compute negative samples or not. Usually, you want to pre-compute them for validation and testing, such that your results are comparable across different epochs/runs. During training, it is usually better to sample negatives on the fly to avoid overfitting on a specific set of negatives. That's why we usually use both RandomLinkSplit(add_negative_train_samples=False) (to get negatives for validation/testing) and on-the-fly negative sampling for training. Hope this helps.
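As a sketch of this combination (the edge type and values follow the MovieLens example above):

import torch_geometric.transforms as T
from torch_geometric.loader import LinkNeighborLoader

# Fixed negatives for validation/testing, none pre-computed for training:
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,
    edge_types=('user', 'rates', 'movie'),
)(data)

# Fresh negatives are drawn for every training mini-batch:
train_loader = LinkNeighborLoader(
    train_data,
    num_neighbors=[20, 10],
    neg_sampling_ratio=2.0,
    edge_label_index=(('user', 'rates', 'movie'),
                      train_data['user', 'rates', 'movie'].edge_label_index),
    batch_size=128,
    shuffle=True,
)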

@ashkspark commented May 24, 2023

Thanks @rusty1s! I fully understand RandomLinkSplit(add_negative_train_samples=False). My question is: how does LinkNeighborLoader sample negative edges on-the-fly for each subgraph/batch? Because we first run LinkNeighborLoader to obtain train_loader, and only later loop through the epochs during model training.

@rusty1s (Member) commented May 25, 2023

There is an option in LinkNeighborLoader to sample negative edges. It will then give you an edge_label where 1 means positive edge and 0 means negative edge. Internally, it randomly draws source and destination nodes and starts neighborhood sampling for these nodes as well.

@ashkspark

Thanks @rusty1s! Sorry for the confusion. I am aware of the negative sampling option of LinkNeighborLoader. I mean, why is it on-the-fly? Unless the negative-sampling seed is determined at epoch runtime, it is not on-the-fly?

@rusty1s (Member) commented May 25, 2023

We randomly sample a new set of negative edges in every mini-batch. This is what I mean by on-the-fly.

@ashkspark

Hi @rusty1s ! From discussion (#3958 (comment)), I am deciding between GENConv vs GINConv. Any suggestion? They seem highly similar to me based on their aggregation formulations.

@rusty1s (Member) commented May 29, 2023

Yes, they are both very similar and should perform very similarly. I cannot give a good rule of thumb on which to use; at best, just try out both :)

@csinnewcastle

Hello great PyTorch Geometric team :)
I want to feed my model the whole graph, not mini-batch subgraphs, so I want to add negative sampling to the training, validation, and test datasets during RandomLinkSplit. I used this transform to split the edges into train, val, and test data:

import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    disjoint_train_ratio=0.3,  # TODO
    neg_sampling_ratio=2.0,  # TODO
    add_negative_train_samples=True,
    edge_types=("drug", "HasIndication", "disorder"),
    rev_edge_types=("disorder", "rev_HasIndication", "drug"),
)
train_data, val_data, test_data = transform(data)

However, negative samples are added to the training data, but not to val and not to test?
What mistake did I make??
BIG THANKS

@csinnewcastle

Also, I have another issue: when I check train_data['drug', 'HasIndication', 'disorder'].edge_label, I get this:

tensor([2, 2, 2, ..., 0, 0, 0])

@csinnewcastle

This transform assigns one and zero labels to the edge labels even though I set the add_negative_train_samples parameter to False:

import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    disjoint_train_ratio=0.3,  # TODO
    neg_sampling_ratio=1.0,  # TODO
    add_negative_train_samples=False,
    edge_types=("drug", "HasIndication", "disorder"),
    rev_edge_types=("disorder", "rev_HasIndication", "drug"),
)
train_data, val_data, test_data = transform(data)

@csinnewcastle

I think there is a big mistake in RandomLinkSplit:
RandomLinkSplit does not only split the dataset into subsets, it also adds edge labels to train, val, and test, even when the input data does not contain any edge labels at all, neither negative nor positive:
train: tensor([1., 1., 1., ..., 1., 1., 1.])
val: tensor([1., 1., 1., ..., 0., 0., 0.])
test: tensor([1., 1., 1., ..., 0., 0., 0.])
The labeling may not be correct, which may be why my loss values go up and down.

@ashkspark

Also, I have another issue: when I check train_data['drug', 'HasIndication', 'disorder'].edge_label

I got this: tensor([2, 2, 2, ..., 0, 0, 0])

This was fixed in 2.2 and later. Simply upgrade your PyG version.

@csinnewcastle commented Jun 9, 2023 via email

@csinnewcastle

Hi, thanks!
I found the mistake. Actually, it is my mistake: I should have set the neg_sampling_ratio parameter to zero. I had set it to 1, so edge labels were added to train, val, and test, and then the mini-batch loader added labels on top, which resulted in edge label values of 2:

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    disjoint_train_ratio=0.3,  # TODO
    neg_sampling_ratio=1.0,  # TODO
    add_negative_train_samples=False,
    edge_types=("drug", "HasIndication", "disorder"),
    rev_edge_types=("disorder", "rev_HasIndication", "drug"),
)
train_data, val_data, test_data = transform(data)
