Skip to content

Commit be7976f

Browse files
[fix] Fix the wrong formula
1 parent 279a24f commit be7976f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/pytorch_cpp_wrapper_base.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ PyTorchCppWrapperBase::get_entropy(at::Tensor input_tensor)
117117
input_tensor.to(torch::kCUDA);
118118
// Calculate the entropy at each pixel
119119
at::Tensor log_p = torch::log_softmax(input_tensor, /*dim=*/1);//at::argmax(input_tensor, 1).to(torch::kCPU).to(at::kByte);
120-
at::Tensor p = torch::log_softmax(input_tensor, /*dim=*/1);
120+
at::Tensor p = torch::softmax(input_tensor, /*dim=*/1);
121121

122-
at::Tensor entropy = torch::sum(p * log_p, /*dim=*/1);
122+
at::Tensor entropy = -torch::sum(p * log_p, /*dim=*/1);
123123

124124
return entropy;
125125
}

0 commit comments

Comments
 (0)