From 8e6c11476eb5ce896289bded981ba7754e011e0e Mon Sep 17 00:00:00 2001 From: charavelg Date: Tue, 29 Jul 2025 11:35:46 +0200 Subject: [PATCH] Fixes #447: forward sklearn estimator attributes --- tslearn/svm/svm.py | 24 +++++++++++++++++++++--- tslearn/tests/test_svm.py | 24 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/tslearn/svm/svm.py b/tslearn/svm/svm.py index 7fcfd4eeb..bfd5433c8 100644 --- a/tslearn/svm/svm.py +++ b/tslearn/svm/svm.py @@ -13,6 +13,27 @@ class TimeSeriesSVMMixin: + + @property + def support_(self): + check_is_fitted(self, ['svm_estimator_', '_X_fit']) + return getattr(self, "svm_estimator_").support_ + + @property + def dual_coef_(self): + check_is_fitted(self, ['svm_estimator_', '_X_fit']) + return getattr(self, "svm_estimator_").dual_coef_ + + @property + def coef_(self): + check_is_fitted(self, ['svm_estimator_', '_X_fit']) + return getattr(self, "svm_estimator_").coef_ + + @property + def intercept_(self): + check_is_fitted(self, ['svm_estimator_', '_X_fit']) + return getattr(self, "svm_estimator_").intercept_ + def _preprocess_sklearn(self, X, y=None, fit_time=False): force_all_finite = self.kernel not in VARIABLE_LENGTH_METRICS if y is None: @@ -446,9 +467,6 @@ class TimeSeriesSVR(TimeSeriesSVMMixin, RegressorMixin, intercept_ : array, shape = [1] Constants in decision function. - sample_weight : array-like, shape = [n_samples] - Individual weights for each sample - svm_estimator_ : sklearn.svm.SVR The underlying sklearn estimator diff --git a/tslearn/tests/test_svm.py b/tslearn/tests/test_svm.py index d4983277f..1cc696356 100644 --- a/tslearn/tests/test_svm.py +++ b/tslearn/tests/test_svm.py @@ -1,11 +1,16 @@ import numpy as np +import pytest + +from sklearn.exceptions import NotFittedError + from tslearn.metrics import cdist_gak from tslearn.svm import TimeSeriesSVC, TimeSeriesSVR __author__ = 'Romain Tavenard romain.tavenard[at]univ-rennes2.fr' + def test_gamma_value_svm(): n, sz, d = 5, 10, 3 rng = np.random.RandomState(0) @@ -22,3 +27,22 @@ def test_gamma_value_svm(): cdist_mat = cdist_gak(time_series, sigma=np.sqrt(gamma / 2.)) np.testing.assert_allclose(sklearn_X, cdist_mat) + +def test_attributes(): + n, sz, d = 5, 10, 3 + rng = np.random.RandomState(0) + time_series = rng.randn(n, sz, d) + labels = rng.randint(low=0, high=2, size=n) + + for ModelClass in [TimeSeriesSVC, TimeSeriesSVR]: + linear_model = ModelClass(kernel="linear") + + for attr in ['coef_', 'support_', 'support_vectors_', + 'dual_coef_', 'coef_', 'intercept_']: + with pytest.raises(NotFittedError): + getattr(linear_model, attr) + + linear_model.fit(time_series, labels) + for attr in ['coef_', 'support_', 'support_vectors_', + 'dual_coef_', 'coef_', 'intercept_']: + assert hasattr(linear_model, attr)