From 0148a67297c9327c7e4db25cc33464a0cafde968 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 29 Aug 2022 11:57:47 +0000 Subject: [PATCH] update --- benchmark/hetero/rgcn.py | 121 +++++++++++++++++++++++++++++++-------- 1 file changed, 98 insertions(+), 23 deletions(-) diff --git a/benchmark/hetero/rgcn.py b/benchmark/hetero/rgcn.py index 6d80ab2f09f8..338b6e778352 100644 --- a/benchmark/hetero/rgcn.py +++ b/benchmark/hetero/rgcn.py @@ -2,6 +2,8 @@ import time from collections import defaultdict +import dgl +import pyg_lib import torch from torch_geometric import seed_everything @@ -11,18 +13,22 @@ SEED = 12345 parser = argparse.ArgumentParser() -parser.add_argument('--num_node_types', type=int, default=10) -parser.add_argument('--num_edge_types', type=int, default=100) -parser.add_argument('--avg_num_nodes', type=int, default=100) +parser.add_argument('--num_node_types', type=int, default=30) +parser.add_argument('--num_edge_types', type=int, default=60) +parser.add_argument('--avg_num_nodes', type=int, default=400) parser.add_argument('--avg_degree', type=int, default=10) parser.add_argument('--channels', type=int, default=64) +parser.add_argument('--device', type=str, default='cpu') +parser.add_argument('--warmup', type=str, default=10) +parser.add_argument('--iterations', type=int, default=100) class SequentialRGCN(MessagePassing): def __init__(self, num_edge_types, channels): super().__init__(aggr='sum') torch.manual_seed(SEED) - self.weight = torch.randn(num_edge_types, channels, channels) + weight = torch.randn(num_edge_types, channels, channels) + self.register_buffer('weight', weight) def forward(self, x_dict, edge_index_dict): outs_dict = defaultdict(list) @@ -33,8 +39,8 @@ def forward(self, x_dict, edge_index_dict): outs_dict[dst].append(out) out_dict = {} - for key in x_dict.keys(): - out_dict[key] = torch.stack(outs_dict[key], dim=0).sum(dim=0) + for key, outs in outs_dict.items(): + out_dict[key] = torch.stack(outs, dim=0).sum(dim=0) return out_dict @@ -44,7 +50,8 @@ class VerticalRGCN(MessagePassing): def __init__(self, num_edge_types, channels): super().__init__(aggr='sum') torch.manual_seed(SEED) - self.weight = torch.randn(num_edge_types, channels, channels) + weight = torch.randn(num_edge_types, channels, channels) + self.register_buffer('weight', weight) def forward(self, x, edge_index, edge_type): edge_index = edge_index.clone() @@ -68,7 +75,8 @@ class HorizontalRGCN(MessagePassing): def __init__(self, num_edge_types, channels): super().__init__(aggr='sum') torch.manual_seed(SEED) - self.weight = torch.randn(num_edge_types, channels, channels) + weight = torch.randn(num_edge_types, channels, channels) + self.register_buffer('weight', weight) def forward(self, x, edge_index, edge_type): edge_index = edge_index.clone() @@ -86,6 +94,34 @@ def forward(self, x, edge_index, edge_type): return out +class DGLTypedRGCN(MessagePassing): + def __init__(self, num_edge_types, channels): + super().__init__(aggr='sum') + torch.manual_seed(SEED) + weight = torch.randn(num_edge_types, channels, channels) + self.register_buffer('weight', weight) + + def forward(self, x, edge_index, edge_sizes): + return self.propagate(edge_index, x=x, edge_sizes=edge_sizes) + + def message(self, x_j, edge_sizes): + return dgl.ops.segment_mm(x_j, self.weight, edge_sizes) + + +class CutlassTypedRGCN(MessagePassing): + def __init__(self, num_edge_types, channels): + super().__init__(aggr='sum') + torch.manual_seed(SEED) + weight = torch.randn(num_edge_types, channels, channels) + self.register_buffer('weight', weight) + + def forward(self, x, edge_index, edge_type_ptr): + return self.propagate(edge_index, x=x, edge_type_ptr=edge_type_ptr) + + def message(self, x_j, edge_type_ptr): + return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight) + + if __name__ == '__main__': args = parser.parse_args() print(args) @@ -98,31 +134,70 @@ def forward(self, x, edge_index, edge_type): avg_num_nodes=args.avg_num_nodes, avg_degree=args.avg_degree, ) - hetero_data = dataset[0] + hetero_data = dataset[0].to(args.device) for node_type in hetero_data.node_types: store = hetero_data[node_type] - store.x = torch.randn(store.num_nodes, args.channels) + x = torch.randn(store.num_nodes, args.channels, device=args.device) + store.x = x homo_data = hetero_data.to_homogeneous() - print(homo_data) + edge_sizes = homo_data.edge_type.bincount().cpu() + edge_type_ptr = torch.ops.torch_sparse.ind2ptr( + homo_data.edge_type, + args.num_edge_types, + ) - conv = SequentialRGCN(args.num_edge_types, args.channels) - t = time.perf_counter() - out_dict = conv(hetero_data.x_dict, hetero_data.edge_index_dict) + conv = SequentialRGCN(args.num_edge_types, args.channels).to(args.device) + for i in range(args.warmup + args.iterations): + if i == args.warmup: + torch.cuda.synchronize() + t = time.perf_counter() + out_dict = conv(hetero_data.x_dict, hetero_data.edge_index_dict) + torch.cuda.synchronize() t = time.perf_counter() - t out1 = torch.cat([out for out in out_dict.values()]) print(f'{conv}: {t}') - print(out1[:5, :5]) - conv = VerticalRGCN(args.num_edge_types, args.channels) - t = time.perf_counter() - out2 = conv(homo_data.x, homo_data.edge_index, homo_data.edge_type) + conv = VerticalRGCN(args.num_edge_types, args.channels).to(args.device) + for i in range(args.warmup + args.iterations): + if i == args.warmup: + torch.cuda.synchronize() + t = time.perf_counter() + out2 = conv(homo_data.x, homo_data.edge_index, homo_data.edge_type) + torch.cuda.synchronize() t = time.perf_counter() - t print(f'{conv}: {t}') - print(out2[:5, :5]) - conv = HorizontalRGCN(args.num_edge_types, args.channels) - t = time.perf_counter() - out3 = conv(homo_data.x, homo_data.edge_index, homo_data.edge_type) + conv = HorizontalRGCN(args.num_edge_types, args.channels).to(args.device) + for i in range(args.warmup + args.iterations): + if i == args.warmup: + torch.cuda.synchronize() + t = time.perf_counter() + out3 = conv(homo_data.x, homo_data.edge_index, homo_data.edge_type) + torch.cuda.synchronize() t = time.perf_counter() - t print(f'{conv}: {t}') - print(out3[:5, :5]) + + conv = DGLTypedRGCN(args.num_edge_types, args.channels).to(args.device) + for i in range(args.warmup + args.iterations): + if i == args.warmup: + torch.cuda.synchronize() + t = time.perf_counter() + out4 = conv(homo_data.x, homo_data.edge_index, edge_sizes) + torch.cuda.synchronize() + t = time.perf_counter() - t + print(f'{conv}: {t}') + + conv = CutlassTypedRGCN(args.num_edge_types, args.channels).to(args.device) + for i in range(args.warmup + args.iterations): + if i == args.warmup: + torch.cuda.synchronize() + t = time.perf_counter() + out5 = conv(homo_data.x, homo_data.edge_index, edge_type_ptr) + torch.cuda.synchronize() + t = time.perf_counter() - t + print(f'{conv}: {t}') + + assert torch.allclose(out1, out2, atol=1e-4) + assert torch.allclose(out2, out3, atol=1e-4) + assert torch.allclose(out3, out4, atol=1e-4) + assert torch.allclose(out4, out5, atol=1e-4)