SCML: Add warm_start parameter #345

Open · wants to merge 7 commits into master
Changes from 5 commits
53 changes: 36 additions & 17 deletions metric_learn/scml.py
@@ -23,7 +23,8 @@ class _BaseSCML(MahalanobisMixin):

def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10,
verbose=False, preprocessor=None, random_state=None):
verbose=False, preprocessor=None, random_state=None,
warm_start=False):
self.beta = beta
self.basis = basis
self.n_basis = n_basis
@@ -34,6 +35,7 @@ def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
self.verbose = verbose
self.preprocessor = preprocessor
self.random_state = random_state
self.warm_start = warm_start
super(_BaseSCML, self).__init__(preprocessor)

def _fit(self, triplets, basis=None, n_basis=None):
@@ -74,13 +76,14 @@ def _fit(self, triplets, basis=None, n_basis=None):

n_triplets = triplets.shape[0]

# weight vector
w = np.zeros((1, n_basis))
# avarage obj gradient wrt weights
avg_grad_w = np.zeros((1, n_basis))
if not self.warm_start or not hasattr(self, "w_"):
# weight vector
self.w_ = np.zeros((1, n_basis))
# average obj gradient wrt weights
self.avg_grad_w_ = np.zeros((1, n_basis))
# l2 norm in time of all obj gradients wrt weights
self.ada_grad_w_ = np.zeros((1, n_basis))

# l2 norm in time of all obj gradients wrt weights
ada_grad_w = np.zeros((1, n_basis))
# slack for not dividing by zero
delta = 0.001

@@ -93,27 +96,28 @@

idx = rand_int[iter]

slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
slack_val = 1 + np.matmul(dist_diff[idx, :], self.w_.T)
slack_mask = np.squeeze(slack_val > 0, axis=1)

grad_w = np.sum(dist_diff[idx[slack_mask], :],
axis=0, keepdims=True)/self.batch_size
avg_grad_w = (iter * avg_grad_w + grad_w) / (iter+1)
self.avg_grad_w_ = (iter * self.avg_grad_w_ + grad_w) / (iter+1)

ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
self.ada_grad_w_ = np.sqrt(np.square(self.ada_grad_w_) + np.square(grad_w))

scale_f = -(iter+1) / (self.gamma * (delta + ada_grad_w))
scale_f = -(iter+1) / (self.gamma * (delta + self.ada_grad_w_))

# proximal operator with negative trimming equivalent
w = scale_f * np.minimum(avg_grad_w + self.beta, 0)
self.w_ = scale_f * np.minimum(self.avg_grad_w_ + self.beta, 0)

if (iter + 1) % self.output_iter == 0:
# regularization part of obj function
obj1 = np.sum(w)*self.beta
obj1 = np.sum(self.w_)*self.beta

# Every triplet distance difference in the space given by L
# plus a slack of one
slack_val = 1 + np.matmul(dist_diff, w.T)
slack_val = 1 + np.matmul(dist_diff, self.w_.T)
# Mask of places with positive slack
slack_mask = slack_val > 0

@@ -129,7 +133,7 @@
# update the best
if obj < best_obj:
best_obj = obj
best_w = w
best_w = self.w_

if self.verbose:
print("max iteration reached.")
@@ -355,6 +359,13 @@ class SCML(_BaseSCML, _TripletsClassifierMixin):
random_state : int or numpy.RandomState or None, optional (default=None)
A pseudo random number generator object or a seed for it if int.

warm_start : bool, default=False
When set to True, reuse the solution of the previous call to fit as
initialization, otherwise, just erase the previous solution.
Repeatedly calling fit when warm_start is True can result in a different
solution than when calling fit a single time because of the way the data
is shuffled.

Attributes
----------
components_ : `numpy.ndarray`, shape=(n_features, n_features)
@@ -465,6 +476,13 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
random_state : int or numpy.RandomState or None, optional (default=None)
A pseudo random number generator object or a seed for it if int.

warm_start : bool, default=False
When set to True, reuse the solution of the previous call to fit as
initialization, otherwise, just erase the previous solution.
Repeatedly calling fit when warm_start is True can result in a different
solution than when calling fit a single time because of the way the data
is shuffled.

Attributes
----------
components_ : `numpy.ndarray`, shape=(n_features, n_features)
@@ -506,13 +524,14 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
def __init__(self, k_genuine=3, k_impostor=10, beta=1e-5, basis='lda',
n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500,
batch_size=10, verbose=False, preprocessor=None,
random_state=None):
random_state=None, warm_start=False):
self.k_genuine = k_genuine
self.k_impostor = k_impostor
_BaseSCML.__init__(self, beta=beta, basis=basis, n_basis=n_basis,
max_iter=max_iter, output_iter=output_iter,
batch_size=batch_size, verbose=verbose,
preprocessor=preprocessor, random_state=random_state)
preprocessor=preprocessor, random_state=random_state,
warm_start=warm_start)

def fit(self, X, y):
"""Create constraints from labels and learn the SCML model.
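To make the intended usage concrete, here is a minimal sketch of the warm-start workflow this PR adds (hypothetical usage, assuming this branch is installed; the parameter values mirror the new test below):

```python
from sklearn.datasets import load_iris
from metric_learn import SCML_Supervised

X, y = load_iris(return_X_y=True)

# warm_start=True is accepted even on the very first fit: w_, avg_grad_w_
# and ada_grad_w_ do not exist yet, so _fit falls back to the usual
# zero initialization.
scml = SCML_Supervised(basis='lda', n_basis=85, random_state=42,
                       warm_start=True)
scml.fit(X, y)

# A second fit reuses the stored w_, avg_grad_w_ and ada_grad_w_ instead
# of erasing them, so optimization resumes from the previous solution.
scml.fit(X, y)
```

The `if not self.warm_start or not hasattr(self, "w_")` guard in `_fit` is what makes both calls valid: it re-initializes the optimizer state only when warm starting is disabled or no previous solution exists.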
16 changes: 16 additions & 0 deletions test/metric_learn_test.py
@@ -323,6 +323,22 @@ def test_large_output_iter(self):
scml.fit(triplets)
assert msg == raised_error.value.args[0]

@pytest.mark.parametrize("basis", ("lda", "triplet_diffs"))
def test_warm_start(self, basis):
X, y = load_iris(return_X_y=True)
# Should work with warm_start=True even on the first fit
scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
random_state=42, warm_start=True)
scml.fit(X, y)
# Re-fitting should continue from previous fit
before = class_separation(scml.transform(X), y)
scml.fit(X, y)
# We used the exact same dataset, so re-fitting can lead to overfitting
after = class_separation(scml.transform(X), y)
if basis == "lda":
assert before > after # For lda, class separation improved with re-fit
else:
assert before < after # For triplet_diffs, it got worse

class TestLSML(MetricTestCase):
def test_iris(self):