# Pickling cuML Models for Persistence

This notebook demonstrates simple pickling of both single-GPU and multi-GPU cuML models for persistence.

In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

## Single GPU Model Pickling

All single-GPU estimators are pickleable. The following example demonstrates the creation of a synthetic dataset, training, and pickling of the resulting model for storage. Trained single-GPU models can also be used for distributed inference on a Dask cluster, as the `Distributed Model Pickling` section below demonstrates.

In [2]:
from cuml.datasets import make_blobs

X, y = make_blobs(n_samples=50,
                  n_features=10,
                  centers=5,
                  cluster_std=0.4,
                  random_state=0)

In [3]:
from cuml.cluster import KMeans

model = KMeans(n_clusters=5)

model.fit(X)

KMeans()

In [4]:
import pickle

# Use a context manager so the file handle is closed after writing
with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(model, f)

In [5]:
# Reload the fitted model from disk
with open("kmeans_model.pkl", "rb") as f:
    model = pickle.load(f)

In [6]:
model.cluster_centers_

array([[-5.7684636,  2.3276033, -3.7457771, -1.8541754, -5.1695833,
         7.667088 ,  2.7118318,  8.495609 ,  1.7038484,  1.1884269],
       [ 4.6476874,  8.37788  , -9.070581 ,  9.459332 ,  8.450422 ,
        -1.0210547,  3.3920872, -7.8629856, -0.7527663,  0.4838412],
       [-2.9414442,  4.6401706, -4.5027537,  2.2855108,  1.644645 ,
        -2.4937892, -5.2241607, -1.5499196, -8.063638 ,  2.816936 ],
       [-4.2710767,  5.561165 , -5.6640916, -1.8229512, -9.2925   ,
         0.730283 ,  4.4586773, -2.8876224, -5.125775 ,  9.694357 ],
       [ 5.5837417, -4.1515303,  4.369667 , -3.00205  ,  3.6388965,
        -4.341912 , -3.318711 ,  6.503671 , -6.865036 , -1.0266498]],
      dtype=float32)
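Running the cells above requires a GPU with cuML installed. To see why the round trip preserves a fitted model's learned state, here is a minimal sketch that uses only the Python standard library. `ToyCentroidModel` is a hypothetical stand-in for a fitted estimator such as `KMeans`; it is not part of cuML. The sketch writes the model to a temporary file, reloads it, and checks that the centroids and predictions are unchanged.

```python
import os
import pickle
import tempfile


class ToyCentroidModel:
    """Hypothetical stand-in for a fitted estimator (not a cuML class).

    It stores learned centroids and assigns each point to the nearest one,
    which is enough to show that learned state survives a pickle round trip.
    """

    def __init__(self, cluster_centers_):
        self.cluster_centers_ = cluster_centers_

    def predict(self, X):
        labels = []
        for point in X:
            # Squared Euclidean distance to every centroid; keep the closest.
            dists = [
                sum((p - c) ** 2 for p, c in zip(point, center))
                for center in self.cluster_centers_
            ]
            labels.append(dists.index(min(dists)))
        return labels


model = ToyCentroidModel([[0.0, 0.0], [10.0, 10.0]])
X = [[0.5, -0.2], [9.7, 10.4], [1.0, 1.0]]

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "toy_model.pkl")
    # Context managers close the file handles after writing and reading.
    with open(path, "wb") as f:
        pickle.dump(model, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

# The restored object carries the same learned state and gives the same labels.
print(restored.cluster_centers_)
print(restored.predict(X))  # [0, 1, 0]
```

Pickle stores the object's attributes (here, `cluster_centers_`) together with a reference to its class, so the class must be importable when the file is loaded; for cuML models, that means cuML must be installed in the loading environment.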

## Distributed Model Pickling

The distributed estimator wrappers inside the `cuml.dask` module are not intended to be pickled directly. Instead, the Dask cuML estimators provide a `get_combined_model()` method, which returns the trained single-GPU model for pickling. The combined model can be used for inference on a single GPU, and the `ParallelPostFit` wrapper from the [Dask-ML](https://ml.dask.org/meta-estimators.html) library can be used to perform distributed inference on a Dask cluster.
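The inference pattern described above can be illustrated without a GPU or a cluster. The snippet below is a simplified sketch under stated assumptions: threads from `concurrent.futures` stand in for Dask workers, a hypothetical `ToyThresholdModel` stands in for the combined single-GPU model, and plain Python lists stand in for the partitions of a Dask array. It serializes the model once and has each task deserialize its own copy before predicting on its partition. It does not reproduce the internals of Dask or `ParallelPostFit`.

```python
import pickle
from concurrent.futures import ThreadPoolExecutor


class ToyThresholdModel:
    """Hypothetical stand-in for a combined single-GPU model (not a cuML class)."""

    def __init__(self, threshold):
        self.threshold = threshold

    def predict(self, rows):
        return [int(x > self.threshold) for x in rows]


def predict_partition(model_bytes, partition):
    # Each simulated worker deserializes its own copy of the model,
    # then predicts only on the partition it was given.
    local_model = pickle.loads(model_bytes)
    return local_model.predict(partition)


combined_model = ToyThresholdModel(threshold=5)
model_bytes = pickle.dumps(combined_model)  # serialize once, reuse for every task

# Hypothetical partitioned data, standing in for the blocks of a Dask array.
partitions = [[1, 7, 3], [9, 2], [6, 6, 0]]

with ThreadPoolExecutor(max_workers=3) as pool:
    # pool.map returns results in the same order as the input partitions.
    results = list(
        pool.map(predict_partition, [model_bytes] * len(partitions), partitions)
    )

# Concatenate per-partition predictions in partition order.
predictions = [label for part in results for label in part]
print(predictions)  # [0, 1, 0, 1, 0, 1, 1, 0]
```

In the actual workflow, a usage along the lines of `ParallelPostFit(estimator=single_gpu_model).predict(X)` (with `ParallelPostFit` imported from `dask_ml.wrappers`) applies the wrapped model's `predict` to each block of a Dask array. The sketch only shows why a picklable single-GPU model is what makes that block-by-block inference possible.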

In [7]:
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster()
client = Client(cluster)
client

| Client | Cluster |
|---|---|
| Scheduler: tcp://127.0.0.1:39764 | Workers: 1 |
| Dashboard: http://127.0.0.1:8787/status | Cores: 1 |
|  | Memory: 251.80 GiB |


In [8]:
from cuml.dask.datasets import make_blobs

n_workers = len(client.scheduler_info()["workers"].keys())

X, y = make_blobs(n_samples=5000, 
                  n_features=30,
                  centers=5, 
                  cluster_std=0.4, 
                  random_state=0,
                  n_parts=n_workers*5)

X = X.persist()
y = y.persist()

In [9]:
from cuml.dask.cluster import KMeans

dist_model = KMeans(n_clusters=5)

In [10]:
dist_model.fit(X)

<cuml.dask.cluster.kmeans.KMeans at 0x7f9be2153f10>

In [11]:
import pickle

single_gpu_model = dist_model.get_combined_model()
# Pickle the combined single-GPU model; the with block closes the file
with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(single_gpu_model, f)

In [12]:
# Reload the combined model for single-GPU inference
with open("kmeans_model.pkl", "rb") as f:
    single_gpu_model = pickle.load(f)

In [13]:
single_gpu_model.cluster_centers_

array([[ 4.8098736 ,  8.422669  , -9.239022  ,  9.379145  ,  8.49988   ,
        -1.0592818 ,  3.3437862 , -7.8026123 , -0.59463334,  0.264476  ,
         5.5073934 , -4.1069803 ,  4.2890778 , -2.8172047 ,  3.6150157 ,
        -4.161299  , -3.6209629 ,  6.218531  , -6.946048  , -1.0828304 ,
        -5.8267703 ,  2.2258766 , -3.8601215 , -1.6974078 , -5.3134165 ,
         7.5795784 ,  2.9187477 ,  8.540424  ,  1.5523201 ,  1.0841808 ],
       [-2.894185  ,  4.4741883 , -4.4475665 ,  2.3820996 ,  1.7478832 ,
        -2.504625  , -5.208329  , -1.6937685 , -8.134755  ,  2.6468294 ,
        -4.3163667 ,  5.565539  , -5.732199  , -1.7384957 , -9.344654  ,
         0.7084657 ,  4.43584   , -2.900899  , -4.9486413 ,  9.695299  ,
         8.366522  , -6.247453  , -6.349474  ,  1.9546974 ,  4.1576147 ,
        -9.167908  ,  4.607068  ,  8.788586  ,  6.8644233 ,  2.231987  ],
       [-4.6657143 , -9.558956  ,  6.6572294 ,  4.4401298 ,  2.1730304 ,
         2.5904038 ,  0.58000994,  6.255034  , -8