Skip to content

Commit

Permalink
support new topk API in benchmark scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 6, 2019
1 parent 7198db0 commit e00adba
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion benchmark/kernel/sag_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def forward(self, data):
xs += [global_mean_pool(x, batch)]
if i % 2 == 0 and i < len(self.convs) - 1:
pool = self.pools[i // 2]
x, edge_index, _, batch, _ = pool(x, edge_index, batch=batch)
x, edge_index, _, batch, _, _ = pool(x, edge_index,
batch=batch)
x = self.jump(xs)
x = F.relu(self.lin1(x))
x = F.dropout(x, p=0.5, training=self.training)
Expand Down
3 changes: 2 additions & 1 deletion benchmark/kernel/top_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def forward(self, data):
xs += [global_mean_pool(x, batch)]
if i % 2 == 0 and i < len(self.convs) - 1:
pool = self.pools[i // 2]
x, edge_index, _, batch, _ = pool(x, edge_index, batch=batch)
x, edge_index, _, batch, _, _ = pool(x, edge_index,
batch=batch)
x = self.jump(xs)
x = F.relu(self.lin1(x))
x = F.dropout(x, p=0.5, training=self.training)
Expand Down

0 comments on commit e00adba

Please sign in to comment.