Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.0)
project(torchsparse)
set(CMAKE_CXX_STANDARD 14)
set(TORCHSPARSE_VERSION 0.6.13)
set(TORCHSPARSE_VERSION 0.7.0)

option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON)
Expand Down
2 changes: 1 addition & 1 deletion conda/pytorch-sparse/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package:
name: pytorch-sparse
version: 0.6.13
version: 0.7.0

source:
path: ../..
Expand Down
92 changes: 38 additions & 54 deletions csrc/cpu/neighbor_sample_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,35 +114,31 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
from_vector<int64_t>(cols), from_vector<int64_t>(edges));
}

bool satisfy_time_constraint(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const std::string &src_node_type,
const int64_t &dst_time,
const int64_t &sampled_node) {
bool satisfy_time_constraint(
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const node_t &src_node_type, const int64_t &dst_time,
const int64_t &src_node) {
// whether src -> dst obeys the time constraint
try {
const auto *src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
return dst_time < src_time[sampled_node];
}
catch (int err) {
auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
return dst_time < src_time[src_node];
} catch (int err) {
// if the node type does not have timestamp, fall back to normal sampling
return true;
}
}


template <bool replace, bool directed, bool temporal>
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const vector<node_t> &node_types,
const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops,
const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
//bool temporal = (!node_time_dict.empty());

const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops,
const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
// Create a mapping to convert single string relations to edge type triplets:
unordered_map<rel_t, edge_t> to_edge_type;
for (const auto &k : edge_types)
Expand Down Expand Up @@ -174,11 +170,12 @@ hetero_sample(const vector<node_t> &node_types,
const torch::Tensor &input_node = kv.value();
const auto *input_node_data = input_node.data_ptr<int64_t>();
// dummy value. will be reset to root time if is_temporal==true
auto *node_time_data = input_node.data_ptr<int64_t>();
int64_t *node_time_data;
// root_time[i] stores the timestamp of the computation tree root
// of the node samples[i]
if (temporal) {
node_time_data = node_time_dict.at(node_type).data_ptr<int64_t>();
torch::Tensor node_time = node_time_dict.at(node_type);
node_time_data = node_time.data_ptr<int64_t>();
}

auto &samples = samples_dict.at(node_type);
Expand Down Expand Up @@ -220,7 +217,7 @@ hetero_sample(const vector<node_t> &node_types,

const auto &begin = slice_dict.at(dst_node_type).first;
const auto &end = slice_dict.at(dst_node_type).second;
if (begin == end){
if (begin == end) {
continue;
}
// for temporal sampling, sampled src node cannot have timestamp greater
Expand Down Expand Up @@ -370,22 +367,17 @@ hetero_sample(const vector<node_t> &node_types,
template <bool replace, bool directed>
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample_random(const vector<node_t> &node_types,
const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops) {
hetero_sample_random(
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops) {
c10::Dict<node_t, torch::Tensor> empty_dict;
return hetero_sample<replace, directed, false>(node_types,
edge_types,
colptr_dict,
row_dict,
input_node_dict,
num_neighbors_dict,
num_hops,
empty_dict);
return hetero_sample<replace, directed, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, empty_dict);
}

} // namespace
Expand Down Expand Up @@ -418,24 +410,20 @@ hetero_neighbor_sample_cpu(
const int64_t num_hops, const bool replace, const bool directed) {

if (replace && directed) {
return hetero_sample_random<true, true>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
return hetero_sample_random<true, true>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
} else if (replace && !directed) {
return hetero_sample_random<true, false>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops);
} else if (!replace && directed) {
return hetero_sample_random<false, true>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops);
} else {
return hetero_sample_random<false, false>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops);
}
}
Expand All @@ -453,23 +441,19 @@ hetero_neighbor_temporal_sample_cpu(

if (replace && directed) {
return hetero_sample<true, true, true>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict);
} else if (replace && !directed) {
return hetero_sample<true, false, true>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict);
} else if (!replace && directed) {
return hetero_sample<false, true, true>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict);
} else {
return hetero_sample<false, false, true>(
node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, num_hops, node_time_dict);
}
}
5 changes: 5 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ test = pytest

[tool:pytest]
addopts = --capture=no

[isort]
multi_line_output=3
include_trailing_comma = True
skip=.gitignore,__init__.py
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension)

__version__ = '0.6.13'
__version__ = '0.7.0'
URL = 'https://github.com/rusty1s/pytorch_sparse'

WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
Expand Down
2 changes: 1 addition & 1 deletion torch_sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

__version__ = '0.6.13'
__version__ = '0.7.0'

for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
Expand Down