Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/cpu/diag_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ torch::Tensor non_diag_mask_cpu(torch::Tensor row, torch::Tensor col, int64_t M,
auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();

auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask = torch::zeros({E + num_diag}, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>();

int64_t r, c;
Expand Down
4 changes: 2 additions & 2 deletions csrc/cpu/metis_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
vwgt = optional_node_weight.value().data_ptr<int64_t>();

int64_t objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part = torch::empty({nvtxs}, rowptr.options());
auto part_data = part.data_ptr<int64_t>();

if (recursive) {
Expand Down Expand Up @@ -99,7 +99,7 @@ mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,

mtmetis_pid_type nparts = num_parts;
mtmetis_wgt_type objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part = torch::empty({nvtxs}, rowptr.options());
mtmetis_pid_type *part_data = (mtmetis_pid_type *)part.data_ptr<int64_t>();

double *opts = mtmetis_init_options();
Expand Down
6 changes: 3 additions & 3 deletions csrc/cpu/relabel_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
std::unordered_map<int64_t, int64_t> n_id_map;
std::unordered_map<int64_t, int64_t>::iterator it;

auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
auto out_rowptr = torch::empty({idx.numel() + 1}, rowptr.options());
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();

out_rowptr_data[0] = 0;
Expand All @@ -76,12 +76,12 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
out_rowptr_data[i + 1] = offset;
}

auto out_col = torch::empty(offset, col.options());
auto out_col = torch::empty({offset}, col.options());
auto out_col_data = out_col.data_ptr<int64_t>();

torch::optional<torch::Tensor> out_value = torch::nullopt;
if (optional_value.has_value()) {
out_value = torch::empty(offset, optional_value.value().options());
out_value = torch::empty({offset}, optional_value.value().options());

AT_DISPATCH_ALL_TYPES(optional_value.value().scalar_type(), "relabel", [&] {
auto value_data = optional_value.value().data_ptr<scalar_t>();
Expand Down
6 changes: 3 additions & 3 deletions csrc/cpu/sample_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
auto col_data = col.data_ptr<int64_t>();
auto idx_data = idx.data_ptr<int64_t>();

auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
auto out_rowptr = torch::empty({idx.numel() + 1}, rowptr.options());
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
out_rowptr_data[0] = 0;

Expand Down Expand Up @@ -117,9 +117,9 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
auto out_n_id = torch::from_blob(n_ids.data(), {N}, col.options()).clone();

int64_t E = out_rowptr_data[idx.numel()];
auto out_col = torch::empty(E, col.options());
auto out_col = torch::empty({E}, col.options());
auto out_col_data = out_col.data_ptr<int64_t>();
auto out_e_id = torch::empty(E, col.options());
auto out_e_id = torch::empty({E}, col.options());
auto out_e_id_data = out_e_id.data_ptr<int64_t>();

i = 0;
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/spmm_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);

auto out = torch::zeros(row.numel(), grad.options());
auto out = torch::zeros({row.numel()}, grad.options());

auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
Expand Down
4 changes: 2 additions & 2 deletions csrc/cpu/spspmm_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,

if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options());
torch::ones({colA.numel()}, optional_valueB.value().options());

if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options());
torch::ones({colB.numel()}, optional_valueA.value().options());

auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value())
Expand Down
4 changes: 2 additions & 2 deletions csrc/cpu/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
return torch::multinomial(weight.value(), num_samples, replace);

if (replace) {
const auto out = torch::empty(num_samples, at::kLong);
const auto out = torch::empty({num_samples}, at::kLong);
auto *out_data = out.data_ptr<int64_t>();
for (int64_t i = 0; i < num_samples; i++) {
out_data[i] = uniform_randint(population);
Expand All @@ -72,7 +72,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
// Sample without replacement via Robert Floyd algorithm:
// https://www.nowherenearithaca.com/2013/05/
// robert-floyds-tiny-and-beautiful.html
const auto out = torch::empty(num_samples, at::kLong);
const auto out = torch::empty({num_samples}, at::kLong);
auto *out_data = out.data_ptr<int64_t>();
std::unordered_set<int64_t> samples;
for (int64_t i = population - num_samples; i < population; i++) {
Expand Down
4 changes: 2 additions & 2 deletions csrc/cuda/convert_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
CHECK_CUDA(ind);
cudaSetDevice(ind.get_device());

auto out = torch::empty(M + 1, ind.options());
auto out = torch::empty({M + 1}, ind.options());

if (ind.numel() == 0)
return out.zero_();
Expand Down Expand Up @@ -57,7 +57,7 @@ torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
CHECK_CUDA(ptr);
cudaSetDevice(ptr.get_device());

auto out = torch::empty(E, ptr.options());
auto out = torch::empty({E}, ptr.options());
auto ptr_data = ptr.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/diag_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();

auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask = torch::zeros({E + num_diag}, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>();

if (E == 0)
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/spmm_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);

auto out = torch::zeros(row.numel(), grad.options());
auto out = torch::zeros({row.numel()}, grad.options());

auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
Expand Down
10 changes: 5 additions & 5 deletions csrc/cuda/spspmm_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,

if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options());
torch::ones({colA.numel()}, optional_valueB.value().options());

if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options());
torch::ones({colB.numel()}, optional_valueA.value().options());

auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value())
Expand Down Expand Up @@ -108,19 +108,19 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
cudaMalloc(&buffer, bufferSize);

// Step 3: Compute CSR row pointer.
rowptrC = torch::empty(M + 1, rowptrA.options());
rowptrC = torch::empty({M + 1}, rowptrA.options());
auto rowptrC_data = rowptrC.data_ptr<int>();
cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
colA_data, descr, colB.numel(), rowptrB_data,
colB_data, descr, 0, NULL, NULL, descr, rowptrC_data,
nnzTotalDevHostPtr, info, buffer);

// Step 4: Compute CSR entries.
colC = torch::empty(nnzC, rowptrC.options());
colC = torch::empty({nnzC}, rowptrC.options());
auto colC_data = colC.data_ptr<int>();

if (optional_valueA.has_value())
optional_valueC = torch::empty(nnzC, optional_valueA.value().options());
optional_valueC = torch::empty({nnzC}, optional_valueA.value().options());

scalar_t *valA_data = NULL, *valB_data = NULL, *valC_data = NULL;
if (optional_valueA.has_value()) {
Expand Down