Skip to content

Commit

Permalink
Replace std::map with phmap::flat_hash_map
Browse files Browse the repository at this point in the history
  • Loading branch information
DamianSzwichtenberg committed Nov 28, 2022
1 parent ef064b1 commit 13c6526
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include <map>
#include <numeric>
#include <tuple>
#include <type_traits>
Expand All @@ -9,6 +8,7 @@
#include <ATen/Parallel.h>
#include <torch/library.h>

#include "parallel_hashmap/phmap.h"
#include "pyg_lib/csrc/config.h"
#include "pyg_lib/csrc/utils/convert.h"

Expand Down Expand Up @@ -178,7 +178,7 @@ void grouped_matmul_out_kernel_mkl_impl(const std::vector<at::Tensor> input,
std::vector<at::Tensor> out) {
// matrix_params<M, N, K>
using matrix_params = std::tuple<int, int, int>;
std::map<matrix_params, std::vector<size_t>> groups;
phmap::flat_hash_map<matrix_params, std::vector<size_t>> groups;
for (size_t i = 0; i < input.size(); ++i) {
const matrix_params mp = {input[i].size(0), other[i].size(-1),
input[i].size(-1)};
Expand Down Expand Up @@ -305,7 +305,7 @@ void segment_matmul_out_kernel_mkl_impl(const at::Tensor& input,
const int n = other.size(-1);
const int k = input.size(-1);
const int nk = n * k;
std::map<int, std::vector<size_t>> groups;
phmap::flat_hash_map<int, std::vector<size_t>> groups;
std::vector<offset_params> offsets = {{0, 0, 0}};
offsets.reserve(sizes.size() + 1);
for (size_t i = 0; i < sizes.size(); ++i) {
Expand Down

0 comments on commit 13c6526

Please sign in to comment.