Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Aug 29, 2022
1 parent 2c8cf79 commit 0148a67
Showing 1 changed file with 98 additions and 23 deletions.
121 changes: 98 additions & 23 deletions benchmark/hetero/rgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import time
from collections import defaultdict

import dgl
import pyg_lib
import torch

from torch_geometric import seed_everything
Expand All @@ -11,18 +13,22 @@
SEED = 12345

parser = argparse.ArgumentParser()
parser.add_argument('--num_node_types', type=int, default=10)
parser.add_argument('--num_edge_types', type=int, default=100)
parser.add_argument('--avg_num_nodes', type=int, default=100)
parser.add_argument('--num_node_types', type=int, default=30)
parser.add_argument('--num_edge_types', type=int, default=60)
parser.add_argument('--avg_num_nodes', type=int, default=400)
parser.add_argument('--avg_degree', type=int, default=10)
parser.add_argument('--channels', type=int, default=64)
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--warmup', type=str, default=10)
parser.add_argument('--iterations', type=int, default=100)


class SequentialRGCN(MessagePassing):
def __init__(self, num_edge_types, channels):
super().__init__(aggr='sum')
torch.manual_seed(SEED)
self.weight = torch.randn(num_edge_types, channels, channels)
weight = torch.randn(num_edge_types, channels, channels)
self.register_buffer('weight', weight)

def forward(self, x_dict, edge_index_dict):
outs_dict = defaultdict(list)
Expand All @@ -33,8 +39,8 @@ def forward(self, x_dict, edge_index_dict):
outs_dict[dst].append(out)

out_dict = {}
for key in x_dict.keys():
out_dict[key] = torch.stack(outs_dict[key], dim=0).sum(dim=0)
for key, outs in outs_dict.items():
out_dict[key] = torch.stack(outs, dim=0).sum(dim=0)

return out_dict

Expand All @@ -44,7 +50,8 @@ class VerticalRGCN(MessagePassing):
def __init__(self, num_edge_types, channels):
super().__init__(aggr='sum')
torch.manual_seed(SEED)
self.weight = torch.randn(num_edge_types, channels, channels)
weight = torch.randn(num_edge_types, channels, channels)
self.register_buffer('weight', weight)

def forward(self, x, edge_index, edge_type):
edge_index = edge_index.clone()
Expand All @@ -68,7 +75,8 @@ class HorizontalRGCN(MessagePassing):
def __init__(self, num_edge_types, channels):
super().__init__(aggr='sum')
torch.manual_seed(SEED)
self.weight = torch.randn(num_edge_types, channels, channels)
weight = torch.randn(num_edge_types, channels, channels)
self.register_buffer('weight', weight)

def forward(self, x, edge_index, edge_type):
edge_index = edge_index.clone()
Expand All @@ -86,6 +94,34 @@ def forward(self, x, edge_index, edge_type):
return out


class DGLTypedRGCN(MessagePassing):
def __init__(self, num_edge_types, channels):
super().__init__(aggr='sum')
torch.manual_seed(SEED)
weight = torch.randn(num_edge_types, channels, channels)
self.register_buffer('weight', weight)

def forward(self, x, edge_index, edge_sizes):
return self.propagate(edge_index, x=x, edge_sizes=edge_sizes)

def message(self, x_j, edge_sizes):
return dgl.ops.segment_mm(x_j, self.weight, edge_sizes)


class CutlassTypedRGCN(MessagePassing):
def __init__(self, num_edge_types, channels):
super().__init__(aggr='sum')
torch.manual_seed(SEED)
weight = torch.randn(num_edge_types, channels, channels)
self.register_buffer('weight', weight)

def forward(self, x, edge_index, edge_type_ptr):
return self.propagate(edge_index, x=x, edge_type_ptr=edge_type_ptr)

def message(self, x_j, edge_type_ptr):
return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight)


if __name__ == '__main__':
args = parser.parse_args()
print(args)
Expand All @@ -98,31 +134,70 @@ def forward(self, x, edge_index, edge_type):
avg_num_nodes=args.avg_num_nodes,
avg_degree=args.avg_degree,
)
hetero_data = dataset[0]
hetero_data = dataset[0].to(args.device)
for node_type in hetero_data.node_types:
store = hetero_data[node_type]
store.x = torch.randn(store.num_nodes, args.channels)
x = torch.randn(store.num_nodes, args.channels, device=args.device)
store.x = x
homo_data = hetero_data.to_homogeneous()
print(homo_data)
edge_sizes = homo_data.edge_type.bincount().cpu()
edge_type_ptr = torch.ops.torch_sparse.ind2ptr(
homo_data.edge_type,
args.num_edge_types,
)

conv = SequentialRGCN(args.num_edge_types, args.channels)
t = time.perf_counter()
out_dict = conv(hetero_data.x_dict, hetero_data.edge_index_dict)
conv = SequentialRGCN(args.num_edge_types, args.channels).to(args.device)
for i in range(args.warmup + args.iterations):
if i == args.warmup:
torch.cuda.synchronize()
t = time.perf_counter()
out_dict = conv(hetero_data.x_dict, hetero_data.edge_index_dict)
torch.cuda.synchronize()
t = time.perf_counter() - t
out1 = torch.cat([out for out in out_dict.values()])
print(f'{conv}: {t}')
print(out1[:5, :5])

conv = VerticalRGCN(args.num_edge_types, args.channels)
t = time.perf_counter()
out2 = conv(homo_data.x, homo_data.edge_index, homo_data.edge_type)
conv = VerticalRGCN(args.num_edge_types, args.channels).to(args.device)
for i in range(args.warmup + args.iterations):
if i == args.warmup:
torch.cuda.synchronize()
t = time.perf_counter()
out2 = conv(homo_data.x, homo_data.edge_index, homo_data.edge_type)
torch.cuda.synchronize()
t = time.perf_counter() - t
print(f'{conv}: {t}')
print(out2[:5, :5])

conv = HorizontalRGCN(args.num_edge_types, args.channels)
t = time.perf_counter()
out3 = conv(homo_data.x, homo_data.edge_index, homo_data.edge_type)
conv = HorizontalRGCN(args.num_edge_types, args.channels).to(args.device)
for i in range(args.warmup + args.iterations):
if i == args.warmup:
torch.cuda.synchronize()
t = time.perf_counter()
out3 = conv(homo_data.x, homo_data.edge_index, homo_data.edge_type)
torch.cuda.synchronize()
t = time.perf_counter() - t
print(f'{conv}: {t}')
print(out3[:5, :5])

conv = DGLTypedRGCN(args.num_edge_types, args.channels).to(args.device)
for i in range(args.warmup + args.iterations):
if i == args.warmup:
torch.cuda.synchronize()
t = time.perf_counter()
out4 = conv(homo_data.x, homo_data.edge_index, edge_sizes)
torch.cuda.synchronize()
t = time.perf_counter() - t
print(f'{conv}: {t}')

conv = CutlassTypedRGCN(args.num_edge_types, args.channels).to(args.device)
for i in range(args.warmup + args.iterations):
if i == args.warmup:
torch.cuda.synchronize()
t = time.perf_counter()
out5 = conv(homo_data.x, homo_data.edge_index, edge_type_ptr)
torch.cuda.synchronize()
t = time.perf_counter() - t
print(f'{conv}: {t}')

assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out2, out3, atol=1e-4)
assert torch.allclose(out3, out4, atol=1e-4)
assert torch.allclose(out4, out5, atol=1e-4)

0 comments on commit 0148a67

Please sign in to comment.