diff --git a/tests/test_survival_svm.py b/tests/test_survival_svm.py index 16efb80a..f372a68f 100644 --- a/tests/test_survival_svm.py +++ b/tests/test_survival_svm.py @@ -530,21 +530,20 @@ def test_fit_and_predict_regression_rbf(make_whas500): @staticmethod @pytest.mark.slow() - @pytest.mark.filterwarnings("ignore:Optimization did not converge") - def test_fit_and_predict_hybrid_rbf(make_whas500): + def test_fit_and_predict_hybrid_polynomial(make_whas500): whas500 = make_whas500(to_numeric=True) ssvm = FastKernelSurvivalSVM( - optimizer="rbtree", rank_ratio=0.5, kernel="rbf", + optimizer="rbtree", rank_ratio=0.5, kernel="poly", coef0=1e-4, degree=2, max_iter=50, fit_intercept=True, random_state=0 ) ssvm.fit(whas500.x, whas500.y) assert not ssvm._get_tags()["pairwise"] - assert abs(5.0289145697617164 - ssvm.intercept_) <= 0.04 + assert round(abs(6.105417583975533 - ssvm.intercept_), 5) == 0 pred = ssvm.predict(whas500.x) rmse = np.sqrt(mean_squared_error(whas500.y['lenfol'], pred)) - assert abs(880.20361811281487 - rmse) <= 75 + assert round(abs(754.7810877051903 - rmse), 5) == 0 @staticmethod @pytest.mark.slow()