# Pickling cuML Models for Persistence

This notebook demonstrates how to pickle both single-GPU and multi-GPU cuML models for persistence.

In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

## Single GPU Model Pickling

All single-GPU estimators can be pickled. The following example creates a synthetic dataset, trains a model on it, and pickles the trained model for storage. Trained single-GPU models can also be used for distributed inference on a Dask cluster, which the `Distributed Model Pickling` section below covers.

In [2]:
from cuml.datasets import make_blobs

X, y = make_blobs(n_samples=50,
                  n_features=10,
                  centers=5,
                  cluster_std=0.4,
                  random_state=0)

In [3]:
from cuml.cluster import KMeans

model = KMeans(n_clusters=5)

model.fit(X)

KMeans(handle=<cuml.raft.common.handle.Handle object at 0x7f5b25331f70>, n_clusters=5, max_iter=300, tol=0.0001, verbose=4, random_state=1, init='scalable-k-means++', n_init=1, oversampling_factor=2.0, max_samples_per_batch=32768, output_type='input')

In [4]:
import pickle

# Use a context manager so the file is flushed and closed after writing.
with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(model, f)

In [5]:
with open("kmeans_model.pkl", "rb") as f:
    model = pickle.load(f)

In [6]:
model.cluster_centers_

array([[-5.768463  ,  2.3276033 , -3.7457774 , -1.8541754 , -5.169583  ,
         7.667088  ,  2.7118316 ,  8.495609  ,  1.7038484 ,  1.188427  ],
       [ 4.6476874 ,  8.37788   , -9.070581  ,  9.459332  ,  8.450422  ,
        -1.0210547 ,  3.3920872 , -7.8629856 , -0.7527663 ,  0.4838412 ],
       [-2.9414442 ,  4.6401706 , -4.5027533 ,  2.2855108 ,  1.644645  ,
        -2.4937892 , -5.2241607 , -1.5499198 , -8.063638  ,  2.816936  ],
       [-4.271077  ,  5.5611653 , -5.6640916 , -1.8229512 , -9.2925    ,
         0.73028314,  4.4586773 , -2.8876226 , -5.1257744 ,  9.694357  ],
       [ 5.5837426 , -4.1515303 ,  4.369667  , -3.0020502 ,  3.6388965 ,
        -4.341912  , -3.3187115 ,  6.503671  , -6.865036  , -1.0266498 ]],
      dtype=float32)
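The save and load steps above can be wrapped in small helpers and checked with a round trip. The following is a minimal sketch, not part of cuML: `save_model` and `load_model` are hypothetical helper names, and a `types.SimpleNamespace` stands in for a fitted estimator so that the example runs without a GPU. With a real model, the fitted `KMeans` object would be passed in its place.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace


def save_model(model, path):
    """Serialize `model` to `path`; the file is closed even if pickling fails."""
    with open(path, "wb") as f:
        pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_model(path):
    """Deserialize an object previously written by `save_model`."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Stand-in for a fitted estimator (assumption: a real cuML model goes here).
original = SimpleNamespace(cluster_centers_=[[0.0, 1.0], [2.0, 3.0]])

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "stand_in_model.pkl")
    save_model(original, model_path)
    restored = load_model(model_path)

# The restored object is a separate instance carrying the same learned state.
round_trip_ok = restored.cluster_centers_ == original.cluster_centers_
```

Only unpickle files from a trusted source, since `pickle.load` can execute arbitrary code while deserializing.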

## Distributed Model Pickling

The distributed estimator wrappers in the `cuml.dask` module are not intended to be pickled directly. Instead, the Dask cuML estimators provide a method, `get_combined_model()`, which returns the trained single-GPU model for pickling. The combined model can be used for inference on a single GPU, and the `ParallelPostFit` wrapper from the [Dask-ML](https://ml.dask.org/meta-estimators.html) library can be used to perform distributed inference on a Dask cluster.
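As a rough sketch of the `ParallelPostFit` route, the hypothetical helper below wraps an already-fitted single-GPU model and predicts on a Dask collection block by block. It assumes that `dask-ml` is installed and that `X` is a Dask array such as the one produced by `cuml.dask.datasets.make_blobs`; the helper name `distributed_predict` is illustrative, and the behavior with GPU-backed blocks should be verified against the installed versions.

```python
def distributed_predict(fitted_model, X):
    """Predict on a Dask collection with a fitted single-GPU model.

    Assumptions: `dask-ml` is installed, `fitted_model` is already trained,
    and `X` is a Dask array whose blocks the model can consume.
    """
    # Imported lazily so that defining this helper does not require dask-ml.
    from dask_ml.wrappers import ParallelPostFit

    # ParallelPostFit does not refit the estimator; it applies the wrapped
    # model's predict method to each block of X.
    wrapper = ParallelPostFit(estimator=fitted_model)

    # The result is a lazy Dask collection; call .compute() to materialize it.
    return wrapper.predict(X)
```

With a combined model and distributed data in hand, this could be called as `distributed_predict(single_gpu_model, X).compute()`.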

In [7]:
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster()
client = Client(cluster)
client

Client: Scheduler: tcp://127.0.0.1:36550, Dashboard: http://127.0.0.1:8787/status
Cluster: Workers: 1, Cores: 1, Memory: 270.37 GB


In [8]:
from cuml.dask.datasets import make_blobs

# Number of Dask workers currently connected to the scheduler.
n_workers = len(client.scheduler_info()["workers"])

X, y = make_blobs(n_samples=5000, 
                  n_features=30,
                  centers=5, 
                  cluster_std=0.4, 
                  random_state=0,
                  n_parts=n_workers*5)

X = X.persist()
y = y.persist()

In [9]:
from cuml.dask.cluster import KMeans

dist_model = KMeans(n_clusters=5)

In [10]:
dist_model.fit(X)

<cuml.dask.cluster.kmeans.KMeans at 0x7f5af41b9a50>

In [11]:
import pickle

# Combine the distributed model into a single-GPU model before pickling.
single_gpu_model = dist_model.get_combined_model()
with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(single_gpu_model, f)

In [12]:
with open("kmeans_model.pkl", "rb") as f:
    single_gpu_model = pickle.load(f)

In [13]:
single_gpu_model.cluster_centers_

array([[ 4.809874  ,  8.42267   , -9.239025  ,  9.379146  ,  8.499881  ,
        -1.0592816 ,  3.343786  , -7.8026123 , -0.5946332 ,  0.26447606,
         5.5073943 , -4.1069803 ,  4.2890778 , -2.817205  ,  3.6150143 ,
        -4.161299  , -3.6209633 ,  6.2185297 , -6.946046  , -1.0828303 ,
        -5.8267717 ,  2.225877  , -3.860121  , -1.6974074 , -5.313417  ,
         7.5795803 ,  2.9187467 ,  8.540427  ,  1.5523205 ,  1.0841805 ],
       [-2.8941858 ,  4.4741898 , -4.4475656 ,  2.3820984 ,  1.747883  ,
        -2.5046256 , -5.2083306 , -1.6937683 , -8.134756  ,  2.6468291 ,
        -4.316363  ,  5.5655394 , -5.7321987 , -1.7384956 , -9.344655  ,
         0.708466  ,  4.435841  , -2.9009    , -4.9486394 ,  9.695301  ,
         8.366521  , -6.247453  , -6.3494744 ,  1.9546975 ,  4.1576157 ,
        -9.167905  ,  4.6070666 ,  8.788584  ,  6.864422  ,  2.2319877 ],
       [-4.6657147 , -9.558955  ,  6.657228  ,  4.44013   ,  2.1730292 ,
         2.590404  ,  0.58000994,  6.2550364 , -8