Skip to content

Commit

Permalink
消除冗余的变量
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Mar 11, 2024
1 parent 975dcde commit 72cfc15
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions source/layer/details/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,33 +57,30 @@ void ConvolutionLayer::InitIm2ColWeight() {
CHECK(kernel->channels() == kernel_c);
}

std::vector<arma::fmat> kernel_matrix_arr(kernel_count);
kernel_matrix_arr_.resize(kernel_count);
for (uint32_t k = 0; k < kernel_count; ++k) {
arma::fmat kernel_matrix_c;
if (Is1x1KernelNoPadding(kernel_h, kernel_w)) {
kernel_matrix_c = arma::fmat(row_len * kernel_c, 1);
} else {
kernel_matrix_c = arma::fmat(1, row_len * kernel_c);
const std::shared_ptr<Tensor<float>>& kernel = this->weights_.at(k);
for (uint32_t ic = 0; ic < kernel->channels(); ++ic) {
memcpy(kernel_matrix_c.memptr() + row_len * ic, kernel->matrix_raw_ptr(ic),
row_len * sizeof(float));
}
}
const std::shared_ptr<Tensor<float>>& kernel = this->weights_.at(k);
for (uint32_t ic = 0; ic < kernel->channels(); ++ic) {
memcpy(kernel_matrix_c.memptr() + row_len * ic, kernel->matrix_raw_ptr(ic),
row_len * sizeof(float));
}
kernel_matrix_arr.at(k) = kernel_matrix_c;
}

if (!kernel_matrix_arr.empty()) {
if (groups_ == 1) {
CHECK(kernel_matrix_arr.size() == kernel_count / groups_)
<< "The number of kernel matrix and kernel_count_group do not match";
} else {
CHECK(kernel_matrix_arr.size() == kernel_count)
<< "The number of kernel matrix and kernel_count do not match";
}
kernel_matrix_arr_.at(k) = kernel_matrix_c;
}

this->kernel_matrix_arr_ = std::move(kernel_matrix_arr);
if (groups_ == 1) {
CHECK(kernel_matrix_arr_.size() == kernel_count / groups_)
<< "The number of kernel matrix and kernel_count_group do not match";
} else {
CHECK(kernel_matrix_arr_.size() == kernel_count)
<< "The number of kernel matrix and kernel_count do not match";
}
}

void ConvolutionLayer::ComputeOutput(sftensor input, sftensor output_tensor, uint32_t kernel_h,
Expand Down

0 comments on commit 72cfc15

Please sign in to comment.