Implement Graph Attention Embedding module
emalgorithm committed Dec 4, 2020
1 parent 1e7046d commit e6da864
Showing 1 changed file with 6 additions and 6 deletions.
examples/tgn.py: 12 changes (6 additions, 6 deletions)
@@ -6,7 +6,7 @@
 from sklearn.metrics import average_precision_score, roc_auc_score
 
 from torch_geometric.datasets import JODIEDataset
-from torch_geometric.nn import TGN, SAGEConv
+from torch_geometric.nn import TGN, TransformerConv
 from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -43,16 +43,16 @@ def __call__(self, n_id, t):
     val_ratio=0.15, test_ratio=0.15)
 
 
-class GraphEmbedding(torch.nn.Module):
+class GraphAttentionEmbedding(torch.nn.Module):
     def __init__(self, in_channels, out_channels):
-        super(GraphEmbedding, self).__init__()
-        self.conv = SAGEConv(in_channels, out_channels)
+        super(GraphAttentionEmbedding, self).__init__()
+        self.conv = TransformerConv(in_channels, out_channels)
 
     def forward(self, x, adj_t):
         if adj_t.nnz() > 0:
             x = self.conv((x, x[:adj_t.size(0)]), adj_t)
         else:
-            x = self.conv.lin_r(x)
+            x = self.conv.lin_skip(x)
         return x
 
 
@@ -76,7 +76,7 @@ def forward(self, z_src, z_dst):
 model = TGN(data.num_nodes, raw_msg_dim, memory_dim, time_dim,
             message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
             aggregator_module=LastAggregator()).to(device)
-gnn = GraphEmbedding(in_channels=memory_dim, out_channels=100).to(device)
+gnn = GraphAttentionEmbedding(in_channels=memory_dim, out_channels=100).to(device)
 link_pred = LinkPredictor(in_channels=100).to(device)
 
 optimizer = torch.optim.Adam(
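
For context, here is a minimal sketch of how the reworked module could be exercised on its own. It assumes the post-commit GraphAttentionEmbedding from the diff above is in scope; the tensor sizes and the random torch_sparse.SparseTensor adjacency are illustrative assumptions, not values from the commit.

import torch
from torch_sparse import SparseTensor

# Sketch only: assumes GraphAttentionEmbedding from the diff above is in
# scope. The sizes below are illustrative, not taken from examples/tgn.py.
memory_dim, num_src, num_dst = 100, 32, 16

x = torch.randn(num_src, memory_dim)  # memory vectors for all sampled nodes

# Transposed adjacency in PyG's convention: rows index the num_dst target
# nodes, columns index the num_src source nodes.
row = torch.randint(0, num_dst, (64,))
col = torch.randint(0, num_src, (64,))
adj_t = SparseTensor(row=row, col=col, sparse_sizes=(num_dst, num_src))

gnn = GraphAttentionEmbedding(in_channels=memory_dim, out_channels=100)
z = gnn(x, adj_t)  # attention-based embeddings for the num_dst target nodes
assert z.size() == (num_dst, 100)

# With no edges, forward() takes the lin_skip branch, so every node is
# still embedded through the conv's skip transform.
empty = SparseTensor(row=torch.empty(0, dtype=torch.long),
                     col=torch.empty(0, dtype=torch.long),
                     sparse_sizes=(num_dst, num_src))
assert gnn(x, empty).size() == (num_src, 100)

The empty-adjacency branch also explains the rename from lin_r to lin_skip: lin_r is SAGEConv's root-weight layer, while TransformerConv exposes the equivalent skip-connection weight as lin_skip.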
