From 21ffef33268efaea34e1743cd81e276c69928e7a Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 22 Apr 2022 15:51:03 +0000 Subject: [PATCH 1/5] version up --- CMakeLists.txt | 2 +- conda/pytorch-sparse/meta.yaml | 2 +- setup.py | 2 +- torch_sparse/__init__.py | 43 +++++++++++++++++----------------- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0759e9a1..070a4127 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.0) project(torchsparse) set(CMAKE_CXX_STANDARD 14) -set(TORCHSPARSE_VERSION 0.6.13) +set(TORCHSPARSE_VERSION 0.7.0) option(WITH_CUDA "Enable CUDA support" OFF) option(WITH_PYTHON "Link to Python when building" ON) diff --git a/conda/pytorch-sparse/meta.yaml b/conda/pytorch-sparse/meta.yaml index 06d337a1..84b4b0e5 100644 --- a/conda/pytorch-sparse/meta.yaml +++ b/conda/pytorch-sparse/meta.yaml @@ -1,6 +1,6 @@ package: name: pytorch-sparse - version: 0.6.13 + version: 0.7.0 source: path: ../.. diff --git a/setup.py b/setup.py index ae69f75c..1fe1e782 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension, CUDAExtension) -__version__ = '0.6.13' +__version__ = '0.7.0' URL = 'https://github.com/rusty1s/pytorch_sparse' WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None diff --git a/torch_sparse/__init__.py b/torch_sparse/__init__.py index bf3d939b..d5d2a3d9 100644 --- a/torch_sparse/__init__.py +++ b/torch_sparse/__init__.py @@ -3,7 +3,7 @@ import torch -__version__ = '0.6.13' +__version__ = '0.7.0' for library in [ '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw', @@ -37,35 +37,34 @@ f'{major}.{minor}. Please reinstall the torch_sparse that ' f'matches your PyTorch install.') -from .storage import SparseStorage # noqa -from .tensor import SparseTensor # noqa -from .transpose import t # noqa -from .narrow import narrow, __narrow_diag__ # noqa -from .select import select # noqa +from .add import add, add_, add_nnz, add_nnz_ # noqa +from .bandwidth import reverse_cuthill_mckee # noqa +from .cat import cat # noqa +from .coalesce import coalesce # noqa +from .convert import to_scipy # noqa +from .convert import from_scipy, from_torch_sparse, to_torch_sparse +from .diag import fill_diag, get_diag, remove_diag, set_diag # noqa +from .eye import eye # noqa from .index_select import index_select, index_select_nnz # noqa from .masked_select import masked_select, masked_select_nnz # noqa -from .permute import permute # noqa -from .diag import remove_diag, set_diag, fill_diag, get_diag # noqa -from .add import add, add_, add_nnz, add_nnz_ # noqa -from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa -from .reduce import sum, mean, min, max # noqa from .matmul import matmul # noqa -from .cat import cat # noqa -from .rw import random_walk # noqa from .metis import partition # noqa -from .bandwidth import reverse_cuthill_mckee # noqa -from .saint import saint_subgraph # noqa +from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa +from .narrow import __narrow_diag__, narrow # noqa from .padding import padded_index, padded_index_select # noqa +from .permute import permute # noqa +from .reduce import max, mean, min, sum # noqa +from .rw import random_walk # noqa +from .saint import saint_subgraph # noqa from .sample import sample, sample_adj # noqa - -from .convert import to_torch_sparse, from_torch_sparse # noqa -from .convert import to_scipy, from_scipy # noqa -from .coalesce import coalesce # noqa -from .transpose import transpose # noqa -from .eye import eye # noqa +from .select import select # noqa +from .spadd import spadd # noqa from .spmm import spmm # noqa from .spspmm import spspmm # noqa -from .spadd import spadd # noqa +from .storage import SparseStorage # noqa +from .tensor import SparseTensor # noqa +from .transpose import t # noqa +from .transpose import transpose # noqa __all__ = [ 'SparseStorage', From 8b8c1214bce306d52b473ecd4ce17b53c6fef9e0 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 22 Apr 2022 16:03:57 +0000 Subject: [PATCH 2/5] formatting --- csrc/cpu/neighbor_sample_cpu.cpp | 85 +++++++++++++------------------- 1 file changed, 35 insertions(+), 50 deletions(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index a5aba4de..6b8efa6e 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -114,34 +114,32 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row, from_vector(cols), from_vector(edges)); } -bool satisfy_time_constraint(const c10::Dict &node_time_dict, - const std::string &src_node_type, - const int64_t &dst_time, - const int64_t &sampled_node) { +bool satisfy_time_constraint( + const c10::Dict &node_time_dict, + const std::string &src_node_type, const int64_t &dst_time, + const int64_t &sampled_node) { // whether src -> dst obeys the time constraint try { - const auto *src_time = node_time_dict.at(src_node_type).data_ptr(); + const auto *src_time = node_time_dict[src_node_type].data_ptr(); return dst_time < src_time[sampled_node]; - } - catch (int err) { + } catch (int err) { // if the node type does not have timestamp, fall back to normal sampling return true; } } - template tuple, c10::Dict, c10::Dict, c10::Dict> hetero_sample(const vector &node_types, - const vector &edge_types, - const c10::Dict &colptr_dict, - const c10::Dict &row_dict, - const c10::Dict &input_node_dict, - const c10::Dict> &num_neighbors_dict, - const int64_t num_hops, - const c10::Dict &node_time_dict) { - //bool temporal = (!node_time_dict.empty()); + const vector &edge_types, + const c10::Dict &colptr_dict, + const c10::Dict &row_dict, + const c10::Dict &input_node_dict, + const c10::Dict> &num_neighbors_dict, + const int64_t num_hops, + const c10::Dict &node_time_dict) { + // bool temporal = (!node_time_dict.empty()); // Create a mapping to convert single string relations to edge type triplets: unordered_map to_edge_type; @@ -220,7 +218,7 @@ hetero_sample(const vector &node_types, const auto &begin = slice_dict.at(dst_node_type).first; const auto &end = slice_dict.at(dst_node_type).second; - if (begin == end){ + if (begin == end) { continue; } // for temporal sampling, sampled src node cannot have timestamp greater @@ -370,22 +368,17 @@ hetero_sample(const vector &node_types, template tuple, c10::Dict, c10::Dict, c10::Dict> -hetero_sample_random(const vector &node_types, - const vector &edge_types, - const c10::Dict &colptr_dict, - const c10::Dict &row_dict, - const c10::Dict &input_node_dict, - const c10::Dict> &num_neighbors_dict, - const int64_t num_hops) { +hetero_sample_random( + const vector &node_types, const vector &edge_types, + const c10::Dict &colptr_dict, + const c10::Dict &row_dict, + const c10::Dict &input_node_dict, + const c10::Dict> &num_neighbors_dict, + const int64_t num_hops) { c10::Dict empty_dict; - return hetero_sample(node_types, - edge_types, - colptr_dict, - row_dict, - input_node_dict, - num_neighbors_dict, - num_hops, - empty_dict); + return hetero_sample( + node_types, edge_types, colptr_dict, row_dict, input_node_dict, + num_neighbors_dict, num_hops, empty_dict); } } // namespace @@ -418,24 +411,20 @@ hetero_neighbor_sample_cpu( const int64_t num_hops, const bool replace, const bool directed) { if (replace && directed) { - return hetero_sample_random( - node_types, edge_types, colptr_dict, - row_dict, input_node_dict, - num_neighbors_dict, num_hops); + return hetero_sample_random(node_types, edge_types, colptr_dict, + row_dict, input_node_dict, + num_neighbors_dict, num_hops); } else if (replace && !directed) { return hetero_sample_random( - node_types, edge_types, colptr_dict, - row_dict, input_node_dict, + node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, num_hops); } else if (!replace && directed) { return hetero_sample_random( - node_types, edge_types, colptr_dict, - row_dict, input_node_dict, + node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, num_hops); } else { return hetero_sample_random( - node_types, edge_types, colptr_dict, - row_dict, input_node_dict, + node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, num_hops); } } @@ -453,23 +442,19 @@ hetero_neighbor_temporal_sample_cpu( if (replace && directed) { return hetero_sample( - node_types, edge_types, colptr_dict, - row_dict, input_node_dict, + node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, num_hops, node_time_dict); } else if (replace && !directed) { return hetero_sample( - node_types, edge_types, colptr_dict, - row_dict, input_node_dict, + node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, num_hops, node_time_dict); } else if (!replace && directed) { return hetero_sample( - node_types, edge_types, colptr_dict, - row_dict, input_node_dict, + node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, num_hops, node_time_dict); } else { return hetero_sample( - node_types, edge_types, colptr_dict, - row_dict, input_node_dict, + node_types, edge_types, colptr_dict, row_dict, input_node_dict, num_neighbors_dict, num_hops, node_time_dict); } } From 31ce7d038c052b139a56f6aabe252e7105925e8f Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 22 Apr 2022 16:57:06 +0000 Subject: [PATCH 3/5] fix --- csrc/cpu/neighbor_sample_cpu.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 6b8efa6e..905bd676 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -116,12 +116,12 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row, bool satisfy_time_constraint( const c10::Dict &node_time_dict, - const std::string &src_node_type, const int64_t &dst_time, - const int64_t &sampled_node) { + const node_t &src_node_type, const int64_t &dst_time, + const int64_t &src_node) { // whether src -> dst obeys the time constraint try { - const auto *src_time = node_time_dict[src_node_type].data_ptr(); - return dst_time < src_time[sampled_node]; + auto src_time = node_time_dict.at(src_node_type).data_ptr(); + return dst_time < src_time[src_node]; } catch (int err) { // if the node type does not have timestamp, fall back to normal sampling return true; @@ -139,8 +139,6 @@ hetero_sample(const vector &node_types, const c10::Dict> &num_neighbors_dict, const int64_t num_hops, const c10::Dict &node_time_dict) { - // bool temporal = (!node_time_dict.empty()); - // Create a mapping to convert single string relations to edge type triplets: unordered_map to_edge_type; for (const auto &k : edge_types) @@ -172,11 +170,12 @@ hetero_sample(const vector &node_types, const torch::Tensor &input_node = kv.value(); const auto *input_node_data = input_node.data_ptr(); // dummy value. will be reset to root time if is_temporal==true - auto *node_time_data = input_node.data_ptr(); + int64_t *node_time_data; // root_time[i] stores the timestamp of the computation tree root // of the node samples[i] if (temporal) { - node_time_data = node_time_dict.at(node_type).data_ptr(); + torch::Tensor &node_time = node_time_dict.at(node_type); + node_time_data = node_time.data_ptr(); } auto &samples = samples_dict.at(node_type); From 24f67ad452152131c7487f21b25ecc4fcbd531b9 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 22 Apr 2022 17:00:15 +0000 Subject: [PATCH 4/5] reset --- setup.cfg | 5 +++++ torch_sparse/__init__.py | 41 ++++++++++++++++++++-------------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/setup.cfg b/setup.cfg index dafda378..40fedbd7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,3 +17,8 @@ test = pytest [tool:pytest] addopts = --capture=no + +[isort] +multi_line_output=3 +include_trailing_comma = True +skip=.gitignore,__init__.py diff --git a/torch_sparse/__init__.py b/torch_sparse/__init__.py index d5d2a3d9..687e295f 100644 --- a/torch_sparse/__init__.py +++ b/torch_sparse/__init__.py @@ -37,34 +37,35 @@ f'{major}.{minor}. Please reinstall the torch_sparse that ' f'matches your PyTorch install.') -from .add import add, add_, add_nnz, add_nnz_ # noqa -from .bandwidth import reverse_cuthill_mckee # noqa -from .cat import cat # noqa -from .coalesce import coalesce # noqa -from .convert import to_scipy # noqa -from .convert import from_scipy, from_torch_sparse, to_torch_sparse -from .diag import fill_diag, get_diag, remove_diag, set_diag # noqa -from .eye import eye # noqa +from .storage import SparseStorage # noqa +from .tensor import SparseTensor # noqa +from .transpose import t # noqa +from .narrow import narrow, __narrow_diag__ # noqa +from .select import select # noqa from .index_select import index_select, index_select_nnz # noqa from .masked_select import masked_select, masked_select_nnz # noqa -from .matmul import matmul # noqa -from .metis import partition # noqa -from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa -from .narrow import __narrow_diag__, narrow # noqa -from .padding import padded_index, padded_index_select # noqa from .permute import permute # noqa -from .reduce import max, mean, min, sum # noqa +from .diag import remove_diag, set_diag, fill_diag, get_diag # noqa +from .add import add, add_, add_nnz, add_nnz_ # noqa +from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa +from .reduce import sum, mean, min, max # noqa +from .matmul import matmul # noqa +from .cat import cat # noqa from .rw import random_walk # noqa +from .metis import partition # noqa +from .bandwidth import reverse_cuthill_mckee # noqa from .saint import saint_subgraph # noqa +from .padding import padded_index, padded_index_select # noqa from .sample import sample, sample_adj # noqa -from .select import select # noqa -from .spadd import spadd # noqa + +from .convert import to_torch_sparse, from_torch_sparse # noqa +from .convert import to_scipy, from_scipy # noqa +from .coalesce import coalesce # noqa +from .transpose import transpose # noqa +from .eye import eye # noqa from .spmm import spmm # noqa from .spspmm import spspmm # noqa -from .storage import SparseStorage # noqa -from .tensor import SparseTensor # noqa -from .transpose import t # noqa -from .transpose import transpose # noqa +from .spadd import spadd # noqa __all__ = [ 'SparseStorage', From a418e03c4cc40f87b120a199dfbbfea228f3a47d Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 22 Apr 2022 18:04:11 +0000 Subject: [PATCH 5/5] revert --- csrc/cpu/neighbor_sample_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpu/neighbor_sample_cpu.cpp b/csrc/cpu/neighbor_sample_cpu.cpp index 905bd676..e18a1f91 100644 --- a/csrc/cpu/neighbor_sample_cpu.cpp +++ b/csrc/cpu/neighbor_sample_cpu.cpp @@ -174,7 +174,7 @@ hetero_sample(const vector &node_types, // root_time[i] stores the timestamp of the computation tree root // of the node samples[i] if (temporal) { - torch::Tensor &node_time = node_time_dict.at(node_type); + torch::Tensor node_time = node_time_dict.at(node_type); node_time_data = node_time.data_ptr(); }