Core Set Query Strategy

In [2]:
import numpy as np

from skactiveml.base import SingleAnnotatorPoolQueryStrategy
from skactiveml.utils import MISSING_LABEL, simple_batch
from sklearn.metrics import pairwise_distances

Implementing the CoreSet Class

In [4]:
class CoreSet(SingleAnnotatorPoolQueryStrategy):
    """Core-set query strategy using greedy k-center selection.

    Each step selects the sample whose distance to its nearest already
    selected center is largest. Centers are tracked in `selected_samples`
    as row indices into `X`; the labels in `y` are only used (through the
    base class) to determine which samples are candidates, they are not
    turned into centers.

    Parameters
    ----------
    missing_label : scalar or string or np.nan or None
        Value that represents a missing label; forwarded to the base class.
    random_state : int or np.random.RandomState or None
        Random state; forwarded to the base class.
    method : str, default='greedy'
        Selection method. Only 'greedy' is implemented.

    Attributes
    ----------
    min_distances : np.ndarray of shape (n_samples, 1) or None
        Distance of every sample to its nearest selected center, or None
        while no center has been selected.
    selected_samples : list of int
        Row indices of `X` selected so far. The list persists across
        `query` calls, so `X` is expected to keep its row order between
        calls on the same instance.
    X : np.ndarray or None
        Sample pool used by the most recent selection.
    """

    def __init__(
        self, missing_label=MISSING_LABEL, random_state=None, method='greedy'
    ):
        super().__init__(
            missing_label=missing_label, random_state=random_state
        )

        self.method = method
        # Column vector with the distance of each sample to its nearest
        # cluster center; None until the first center exists.
        self.min_distances = None
        self.selected_samples = []
        # Sample pool the distances refer to; set by the selection routine.
        self.X = None

    def query(
        self,
        X,
        y,
        candidates=None,
        batch_size=1,
        return_utilities=False,
    ):
        """Select the indices of the samples to be labeled next.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Sample pool.
        y : array-like of shape (n_samples,)
            Labels of `X`, with missing labels marked by `missing_label`.
        candidates : None or array-like of indices, default=None
            Forwarded to the base class to determine the candidate
            samples. Candidates that cannot be mapped to rows of `X` are
            not supported.
        batch_size : int, default=1
            Number of samples to select.
        return_utilities : bool, default=False
            If True, additionally return the utilities of each step.

        Returns
        -------
        query_indices : np.ndarray of shape (batch_size,)
            Row indices of `X` selected in this call, in selection order.
        utilities : np.ndarray of shape (batch_size, n_samples)
            Only returned if `return_utilities` is True. Row `b` holds,
            for every selectable sample, its distance to the nearest
            center before the `b`-th selection (zeros while no center
            exists) and NaN for samples that were not selectable.

        Raises
        ------
        ValueError
            If `method` is not 'greedy', if the candidates cannot be
            mapped to rows of `X`, or if fewer than `batch_size`
            selectable samples remain.
        """
        X, y, candidates, batch_size, return_utilities = self._validate_data(
            X, y, candidates, batch_size, return_utilities, reset=True
        )

        _, mapping = self._transform_candidates(candidates, X, y)

        if self.method != 'greedy':
            raise ValueError(
                f"Unknown method {self.method!r}; only 'greedy' is "
                "implemented."
            )
        if mapping is None:
            # Without a mapping the returned indices could not refer to the
            # candidates, so refuse instead of silently selecting from `X`.
            raise ValueError(
                "`candidates` must be None or an array of indices into `X`."
            )

        new_samples, utilities = self._k_center_greedy(
            X, batch_size, candidate_indices=mapping
        )
        query_indices = np.asarray(new_samples, dtype=int)

        if return_utilities:
            return query_indices, utilities
        return query_indices

    def kCenterGreedy(self, X, batch_size, candidate_indices=None):
        """Greedily select `batch_size` new centers from `X`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Sample pool.
        batch_size : int
            Number of centers to add.
        candidate_indices : None or array-like of int, default=None
            Row indices of `X` that may be selected. None allows every
            row that has not been selected yet.

        Returns
        -------
        new_samples : list of int
            Row indices selected in this call, in selection order.
        """
        new_samples, _ = self._k_center_greedy(
            X, batch_size, candidate_indices=candidate_indices
        )
        return new_samples

    def _k_center_greedy(self, X, batch_size, candidate_indices=None):
        """Run the greedy selection; return the new indices and utilities."""
        X = np.asarray(X)
        n_samples = X.shape[0]
        self.X = X

        # Rebuild the cache from the known centers so that it always matches
        # the pool passed to this call.
        self.min_distances = None
        if len(self.selected_samples) > 0:
            self.update_distances(self.selected_samples, False)

        if candidate_indices is None:
            eligible = np.ones(n_samples, dtype=bool)
        else:
            eligible = np.zeros(n_samples, dtype=bool)
            eligible[np.asarray(candidate_indices, dtype=int)] = True
        if len(self.selected_samples) > 0:
            # A sample must never be selected twice.
            eligible[self.selected_samples] = False

        n_eligible = int(eligible.sum())
        if batch_size > n_eligible:
            raise ValueError(
                f"Cannot select {batch_size} samples; only {n_eligible} "
                "selectable samples remain."
            )

        # Prefer the random state validated by the base class; fall back to
        # NumPy's global generator when this method is called before `query`.
        rng = getattr(self, 'random_state_', None)
        choose = np.random.choice if rng is None else rng.choice

        new_samples = []
        utilities = np.full((batch_size, n_samples), np.nan)

        for step in range(batch_size):
            eligible_idx = np.flatnonzero(eligible)

            if self.min_distances is None:
                # No center yet: all samples are equally useful.
                utilities[step, eligible_idx] = 0.0
                idx = int(choose(eligible_idx))
            else:
                scores = self.min_distances.ravel()[eligible_idx]
                utilities[step, eligible_idx] = scores
                idx = int(eligible_idx[np.argmax(scores)])

            self.update_distances([idx], only_new=True)

            new_samples.append(idx)
            self.selected_samples.append(idx)
            eligible[idx] = False

        return new_samples, utilities

    def update_distances(self, cluster_centers, only_new):
        """Update `min_distances` with the distances to the given centers.

        Parameters
        ----------
        cluster_centers : sequence of int
            Row indices of the pool that act as centers.
        only_new : bool
            If True, centers already listed in `selected_samples` are
            skipped.

        Raises
        ------
        RuntimeError
            If no sample pool has been set yet.
        """
        if only_new:
            cluster_centers = [c for c in cluster_centers
                               if c not in self.selected_samples]

        # Nothing to update; an empty center set would make the distance
        # computation fail.
        if len(cluster_centers) == 0:
            return

        if self.X is None:
            raise RuntimeError(
                "No sample pool available; call `query` or `kCenterGreedy` "
                "first."
            )

        cluster_center_feature = self.X[cluster_centers]
        dist_matrix = pairwise_distances(self.X, cluster_center_feature)
        dist = np.min(dist_matrix, axis=1).reshape(-1, 1)

        if self.min_distances is None:
            self.min_distances = dist
        else:
            self.min_distances = np.minimum(self.min_distances, dist)
        
