In [1]:
from sklearn.base import BaseEstimator, TransformerMixin
import numpy as np
from sklearn.utils import check_random_state
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from scipy.sparse import csr_matrix
import warnings
warnings.filterwarnings('ignore')  # we still have warnings because of sparse arrays but will be fixed
from sklearn.pipeline import Pipeline, make_pipeline

We create two ConstrainedDataset classes: one that makes a copy of X and one that keeps only a view of X

In [2]:
class ConstrainedDatasetView():
    """Constraints ``c`` stored as indices into an array ``X`` kept by reference.

    ``X`` is never copied: the instance, and every slice obtained from it
    through ``__getitem__``, hold the very same ``X`` object that was given
    to the constructor.

    Parameters
    ----------
    X : array-like with a ``shape`` attribute
        The points. Only ``X.shape[1]`` and indexing by ``c`` are used here.
    c : indexable with a length
        The constraints, used as indices into ``X``.

    Attributes
    ----------
    shape : tuple
        ``(len(c), X.shape[1])``.
    """

    def __init__(self, X, c):
        self.c = c
        self.X = X  # stored by reference, no copy
        self.shape = (len(c), X.shape[1])

    def __getitem__(self, item):
        """Return a new dataset restricted to the constraints ``c[item]``."""
        # Only the constraints are sliced; X is handed over untouched so the
        # slice keeps pointing at the same underlying points.
        c_sliced = self.c[item]
        return ConstrainedDatasetView(self.X, c_sliced)

    def __len__(self):
        """Return the number of constraints."""
        # len() must return an int; returning the whole ``shape`` tuple made
        # ``len(dataset)`` raise a TypeError.
        return self.shape[0]

    def __str__(self):
        return self.asarray().__str__()

    def __repr__(self):
        return self.asarray().__repr__()

    def asarray(self):
        """Materialize the constraints by indexing ``X`` with ``c``.

        For a 2-D numpy ``X`` and an integer ``c`` of shape ``(n, 2)`` the
        result has shape ``(n, 2, X.shape[1])``.
        """
        return self.X[self.c]
    
class ConstrainedDatasetCopy():
    """Constraints ``c`` stored as indices into a private copy of ``X``.

    The constructor copies ``X``. Because ``__getitem__`` builds a new
    instance through the constructor, every slice copies ``X`` again.

    Parameters
    ----------
    X : array-like with ``shape`` and ``copy()``
        The points; ``X.copy()`` is what the instance stores.
    c : 2-D integer array
        The constraints. ``asarray`` reads columns 0 and 1 as indices
        into ``X``.

    Attributes
    ----------
    shape : tuple
        ``(len(c), X.shape[1])``.
    """

    def __init__(self, X, c):
        self.c = c
        self.X = X.copy()  # deliberate copy: this class never shares X
        self.shape = (len(c), X.shape[1])

    def __getitem__(self, item):
        """Return a new dataset restricted to the constraints ``c[item]``."""
        # Only the constraints are sliced; the full X is passed to the
        # constructor, which copies it once more for the new instance.
        c_sliced = self.c[item]
        return ConstrainedDatasetCopy(self.X, c_sliced)

    def __len__(self):
        """Return the number of constraints."""
        # len() must return an int; returning the whole ``shape`` tuple made
        # ``len(dataset)`` raise a TypeError.
        return self.shape[0]

    def __str__(self):
        return self.asarray().__str__()

    def __repr__(self):
        return self.asarray().__repr__()

    def asarray(self):
        """Stack the two points of every constraint along a new axis 1.

        For a 2-D numpy ``X`` the result has shape ``(len(c), 2, X.shape[1])``.
        """
        first_points = self.X[self.c[:, 0].ravel()]
        second_points = self.X[self.c[:, 1].ravel()]
        return np.stack([first_points, second_points], axis=1)



Let's create a metric learner that could be problematic: one whose operations on X are performed directly on the view. If we multithread, as in a cross-validation, we might expect several threads to try to access the data at the same time.

In [3]:
class ProblematicMetricLearner(BaseEstimator, TransformerMixin):
    """Toy metric learner that works directly on ``constrained_dataset.X``.

    It exists to show whether an estimator ends up operating on the same
    memory as the notebook-level array ``X``: ``fit`` prints the answer.

    Parameters
    ----------
    return_embedding : bool, default=True
        If True, ``transform`` returns the embedded points. Otherwise it
        returns the matrix of pairwise Euclidean distances between them.
    """

    def __init__(self, return_embedding=True):
        self.A = None  # never read inside this class; kept for compatibility
        self.return_embedding = return_embedding

    def fit(self, constrained_dataset, y):
        """Compute ``self.metric`` and report whether memory is shared with ``X``.

        Parameters
        ----------
        constrained_dataset : object exposing ``X`` and ``c``
            ``X`` must be a numpy array (its ``__array_interface__`` is read).
        y : array
            One label per constraint, compared against 0 and 1.

        Returns
        -------
        self
        """
        # The split constraints are not used by this toy learner; the call is
        # kept so that malformed inputs still fail here as they did before.
        self.prepare_input(constrained_dataset, y)
        points = constrained_dataset.X  # no copy: same object the dataset holds
        self.metric = points.T.dot(points)
        # NOTE: ``X`` below is the notebook-level array, not an argument. The
        # comparison checks whether both arrays start at the same address.
        if constrained_dataset.X.__array_interface__['data'][0] == X.__array_interface__['data'][0]:
            shared = 'is'
        else:
            shared = 'is not'
        print('Memory {} shared with the initial object X.'.format(shared))
        # Estimators are expected to return themselves from fit so that calls
        # can be chained; the original returned None.
        return self

    def fit_transform(self, constrained_dataset, y):
        """Fit on the dataset, then return ``transform`` of the same dataset."""
        self.fit(constrained_dataset, y)
        return self.transform(constrained_dataset)

    def predict(self, constrained_dataset):
        """Alias of ``decision_function``."""
        return self.decision_function(constrained_dataset)

    def decision_function(self, constrained_dataset):
        """Return the squared embedded distance of every constraint.

        Entry ``k`` is the squared Euclidean distance between the embedded
        points indexed by ``c[k, 0]`` and ``c[k, 1]``.
        """
        # Always use the embedding itself. Going through transform() would
        # return a distance matrix when return_embedding is False, and the
        # distances below would then be taken between rows of that matrix.
        X_embedded = self._embed(constrained_dataset)
        pairs = constrained_dataset.c
        # Differences are formed only for the requested pairs instead of
        # building the full (n_points, n_points, n_features) tensor first.
        differences = X_embedded[pairs[:, 0]] - X_embedded[pairs[:, 1]]
        return np.sum(differences ** 2, axis=1)

    def transform(self, constrained_dataset):
        """Embed the points, or return their pairwise distance matrix.

        The choice is driven by ``self.return_embedding``.
        """
        X_embedded = self._embed(constrained_dataset)
        if self.return_embedding:
            return X_embedded
        return np.sqrt(np.sum((X_embedded[:, None] - X_embedded)**2, axis=2))

    def _embed(self, constrained_dataset):
        """Multiply all points of the dataset by the fitted ``self.metric``."""
        return constrained_dataset.X.dot(self.metric)

    @staticmethod
    def prepare_input(X, y):
        """Split the constraint indices by label.

        Parameters
        ----------
        X : object exposing ``c``
            ``c`` is indexed with boolean masks and then by column.
        y : array
            Labels compared element-wise against 0 and 1.

        Returns
        -------
        list
            ``[a, b, c, d]``: columns 0 and 1 of the constraints whose label
            is 0, followed by columns 0 and 1 of those whose label is 1.
        """
        label_zero = X.c[y == 0]
        label_one = X.c[y == 1]
        return [label_zero[:, 0], label_zero[:, 1],
                label_one[:, 0], label_one[:, 1]]
        

Let's create some data

In [4]:
# Seed NumPy's global generator so that a fresh "Restart & Run All" draws the
# same points, constraints and labels every time.
np.random.seed(0)
X = np.random.randn(1000, 10)  # 1000 points with 10 features each
c = np.random.randint(0, 1000, (20000, 2))  # 20000 pairs of row indices into X
y = np.random.randint(0, 2, 20000)  # one binary label per pair
# Same X and c wrapped twice: once by reference, once with a private copy of X.
view_dataset = ConstrainedDatasetView(X, c)
copy_dataset = ConstrainedDatasetCopy(X, c)
pml = ProblematicMetricLearner(return_embedding=True)

We see that ``view_dataset.X`` shares its memory with ``X`` whereas ``copy_dataset.X`` does not.

In [5]:
# Address of the first byte of the original array's buffer.
original_address = X.__array_interface__['data'][0]
# Print, for each wrapper, whether its X starts at that same address.
for dataset in (view_dataset, copy_dataset):
    print(dataset.X.__array_interface__['data'][0] == original_address)

True
False


As expected, in a cross-validation without multithreading, the slices of ``view_dataset`` each keep a view of X, and the algorithm only works on that view. For ``copy_dataset``, a copy is made at each slice and the algorithm works on those copies rather than on the original X.

In [6]:
# Sequential run (n_jobs=1): each fit prints whether it shares memory with X.
sequential_options = dict(scoring='roc_auc', n_jobs=1)
print('Cross val on the view object:')
view_scores_sequential = cross_val_score(pml, view_dataset, y, **sequential_options)
print('Cross val on the copied object:')
copy_scores_sequential = cross_val_score(pml, copy_dataset, y, **sequential_options)
# Last expression of the cell: the scores obtained on the copied dataset.
copy_scores_sequential

Cross val on the view object:
Memory is shared with the initial object X.
Memory is shared with the initial object X.
Memory is shared with the initial object X.
Cross val on the copied object:
Memory is not shared with the initial object X.
Memory is not shared with the initial object X.
Memory is not shared with the initial object X.


array([ 0.49628755,  0.50579987,  0.52288073])

What will happen then if we multithread in the ``view_dataset`` case? (Let's also check the case of ``copy_dataset``, but since it is a copy from the start, we already know the algorithms will work on copies of the original X.)

In [7]:
# Parallel run (n_jobs=4): each fit prints whether it shares memory with X.
parallel_options = dict(scoring='roc_auc', n_jobs=4)
print('Cross val on the view object:')
view_scores_parallel = cross_val_score(pml, view_dataset, y, **parallel_options)
print('Cross val on the copied object:')
copy_scores_parallel = cross_val_score(pml, copy_dataset, y, **parallel_options)
# Last expression of the cell: the scores obtained on the copied dataset.
copy_scores_parallel

Cross val on the view object:
Memory is not shared with the initial object X.
Memory is not shared with the initial object X.
Memory is not shared with the initial object X.
Cross val on the copied object:
Memory is not shared with the initial object X.
Memory is not shared with the initial object X.
Memory is not shared with the initial object X.


array([ 0.49628755,  0.50579987,  0.52288073])

We can notice that in the ``view_dataset`` case a copy appears to be created somewhere, which is good news because it means there is no bug due to concurrent data access. It is probably because of this: https://pythonhosted.org/joblib/parallel.html#working-with-numerical-data-in-shared-memory-memmaping
```
The arguments passed as input to the Parallel call are serialized and reallocated in the memory of each worker process.
```