Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for ARM64 #3059

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions make/compiler_flags
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ endif

## Set OS specific library filename extensions
ifeq ($(OS),Windows_NT)
WINARM64 := $(shell echo | $(CXX) -E -dM - | findstr __aarch64__)
LIBRARY_SUFFIX ?= .dll
STR_SEARCH ?= findstr
endif

ifeq ($(OS),Darwin)
LIBRARY_SUFFIX ?= .dylib
STR_SEARCH ?= grep
endif

ifeq ($(OS),Linux)
LIBRARY_SUFFIX ?= .so
STR_SEARCH ?= grep
endif

## Set default compiler
Expand All @@ -42,6 +44,11 @@ ifeq (default,$(origin CXX))
endif
endif

ARM64_CHECK := $(shell echo | $(CXX) -E -dM - | $(STR_SEARCH) __aarch64__)
ifneq ($(ARM64_CHECK),)
ARM64 = true
endif

# Detect compiler type
# - CXX_TYPE: {gcc, clang, mingw32-gcc, other}
# - CXX_MAJOR: major version of CXX
Expand Down Expand Up @@ -164,7 +171,7 @@ ifeq ($(OS),Windows_NT)

make/ucrt:
pound := \#
UCRT_STRING := $(shell echo '$(pound)include <windows.h>' | $(CXX) -E -dM - | findstr _UCRT)
UCRT_STRING := $(shell echo '$(pound)include <windows.h>' | $(CXX) -E -dM - | $(STR_SEARCH) _UCRT)
ifneq (,$(UCRT_STRING))
IS_UCRT ?= true
else
Expand Down Expand Up @@ -211,6 +218,10 @@ endif
## makes reentrant version lgamma_r available from cmath
CXXFLAGS_OS += -D_REENTRANT

ifeq ($(ARM64), true)
CXXFLAGS_OS += -ffp-contract=off
endif

## silence warnings occurring due to the TBB and Eigen libraries
CXXFLAGS_WARNINGS += -Wno-ignored-attributes

Expand Down Expand Up @@ -275,7 +286,7 @@ endif
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" -Wl,--disable-new-dtags

# Windows LLVM/Clang does not support -rpath, but it is not needed on Windows anyway
ifeq ($(WINARM64),)
ifneq ($(OS), Windows_NT)
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_LIB)"
endif

Expand All @@ -299,7 +310,7 @@ CXXFLAGS_TBB ?= -I $(TBB)/include
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_BIN_ABSOLUTE_PATH)" $(LDFLAGS_FLTO_FLTO) $(LDFLAGS_OPTIM_TBB)

# Windows LLVM/Clang does not support -rpath, but it is not needed on Windows anyway
ifeq ($(WINARM64),)
ifneq ($(OS), Windows_NT)
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_BIN_ABSOLUTE_PATH)"
endif
LDLIBS_TBB ?= -ltbb
Expand Down
9 changes: 5 additions & 4 deletions make/libraries
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,11 @@ ifeq (Windows_NT, $(OS))
TBB_CXXFLAGS += -D_UCRT
endif
# TBB does not have assembly code for Windows ARM64, so we need to use GCC builtins
ifneq ($(WINARM64),)
TBB_CXXFLAGS += -DTBB_USE_GCC_BUILTINS
CXXFLAGS_TBB += -DTBB_USE_GCC_BUILTINS
endif
ifeq ($(ARM64),true)
TBB_CXXFLAGS += -DTBB_USE_GCC_BUILTINS
CXXFLAGS_TBB += -DTBB_USE_GCC_BUILTINS
WINARM64 = true
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
endif
SH_CHECK := $(shell command -v sh 2>/dev/null)
ifdef SH_CHECK
WINDOWS_HAS_SH ?= true
Expand Down
6 changes: 6 additions & 0 deletions stan/math/prim/fun/inv_sqrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,14 @@ inline auto inv_sqrt(const Container& x) {
template <typename Container, require_not_var_matrix_t<Container>* = nullptr,
          require_container_st<std::is_arithmetic, Container>* = nullptr>
inline auto inv_sqrt(const Container& x) {
  // Overload for containers with arithmetic scalar type: the implementation
  // is selected at preprocessing time depending on the target architecture.
#ifndef __aarch64__
  // Everywhere except ARM64: evaluate via Eigen's array rsqrt() through
  // apply_vector_unary.
  return apply_vector_unary<Container>::apply(
      x, [](const auto& values) { return values.array().rsqrt(); });
#else
  // ARM64: Eigen 3.4.0 has precision issues with its vectorised rsqrt on
  // this architecture, so go through apply_scalar_unary with inv_sqrt_fun
  // instead. The issue is resolved on Eigen's master branch, so this branch
  // can be dropped once the next Eigen release is adopted.
  return apply_scalar_unary<inv_sqrt_fun, Container>::apply(x);
#endif
}

} // namespace math
Expand Down
6 changes: 4 additions & 2 deletions test/unit/math/fwd/core/std_numeric_limits_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ TEST(AgradFwdNumericLimits, All_Fvar) {
EXPECT_FALSE(std::numeric_limits<fvar<double> >::traps);
EXPECT_FALSE(std::numeric_limits<fvar<fvar<double> > >::traps);

EXPECT_FALSE(std::numeric_limits<fvar<double> >::tinyness_before);
EXPECT_FALSE(std::numeric_limits<fvar<fvar<double> > >::tinyness_before);
EXPECT_EQ(std::numeric_limits<fvar<double> >::tinyness_before,
std::numeric_limits<double>::tinyness_before);
EXPECT_EQ(std::numeric_limits<fvar<fvar<double> > >::tinyness_before,
std::numeric_limits<double>::tinyness_before);

EXPECT_TRUE(std::numeric_limits<fvar<double> >::round_style);
EXPECT_TRUE(std::numeric_limits<fvar<fvar<double> > >::round_style);
Expand Down
6 changes: 4 additions & 2 deletions test/unit/math/mix/core/std_numeric_limits_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ TEST(AgradMixNumericLimits, All_Fvar) {
EXPECT_FALSE(std::numeric_limits<fvar<var> >::traps);
EXPECT_FALSE(std::numeric_limits<fvar<fvar<var> > >::traps);

EXPECT_FALSE(std::numeric_limits<fvar<var> >::tinyness_before);
EXPECT_FALSE(std::numeric_limits<fvar<fvar<var> > >::tinyness_before);
EXPECT_EQ(std::numeric_limits<fvar<var> >::tinyness_before,
std::numeric_limits<double>::tinyness_before);
EXPECT_EQ(std::numeric_limits<fvar<fvar<var> > >::tinyness_before,
std::numeric_limits<double>::tinyness_before);

EXPECT_TRUE(std::numeric_limits<fvar<var> >::round_style);
EXPECT_TRUE(std::numeric_limits<fvar<fvar<var> > >::round_style);
Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/prim/fun/offset_multiplier_transform_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ TEST(prob_transform, offset_multiplier_constrain_matrix) {
EXPECT_FLOAT_EQ(result(i), stan::math::offset_multiplier_constrain(
x(i), offsetd, sigma(i), lp1));
}
EXPECT_EQ(lp0, lp1);
EXPECT_FLOAT_EQ(lp0, lp1);
auto x_free = stan::math::offset_multiplier_free(result, offsetd, sigma);
for (size_t i = 0; i < x.size(); ++i) {
EXPECT_FLOAT_EQ(x.coeff(i), x_free.coeff(i));
Expand All @@ -211,7 +211,7 @@ TEST(prob_transform, offset_multiplier_constrain_matrix) {
EXPECT_FLOAT_EQ(result(i), stan::math::offset_multiplier_constrain(
x(i), offset(i), sigma(i), lp1));
}
EXPECT_EQ(lp0, lp1);
EXPECT_FLOAT_EQ(lp0, lp1);
auto x_free = stan::math::offset_multiplier_free(result, offset, sigma);
for (size_t i = 0; i < x.size(); ++i) {
EXPECT_FLOAT_EQ(x.coeff(i), x_free.coeff(i));
Expand Down
7 changes: 5 additions & 2 deletions test/unit/math/prim/prob/neg_binomial_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,11 @@ TEST(ProbDistributionsNegBinomial, chiSquareGoodnessFitTest3) {

double chi = 0;

for (int j = 0; j < K; j++)
chi += ((bin[j] - expect[j]) * (bin[j] - expect[j]) / expect[j]);
for (int j = 0; j < K; j++) {
if (expect[j] != 0) {
chi += ((bin[j] - expect[j]) * (bin[j] - expect[j]) / expect[j]);
}
}

EXPECT_LT(chi, boost::math::quantile(boost::math::complement(mydist, 1e-6)));
}
Expand Down
4 changes: 2 additions & 2 deletions test/unit/math/test_ad.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1991,7 +1991,7 @@ void expect_common_unary_vectorized(const F& f) {
for (double x1 : args)
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
auto int_args = internal::common_int_args();
for (int x1 : args)
for (int x1 : int_args)
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
}

Expand Down Expand Up @@ -2022,7 +2022,7 @@ void expect_common_unary_vectorized(const F& f) {
for (double x1 : args)
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
auto int_args = internal::common_int_args();
for (int x1 : args)
for (int x1 : int_args)
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
for (auto x1 : common_complex())
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
Expand Down