
Kernel density estimation #4545

Merged: 25 commits into rapidsai:branch-22.04 on Mar 8, 2022

Conversation

@RAMitchell (Contributor) commented on Feb 1, 2022

Implements kernel density estimation using a brute-force approach, in contrast to sklearn's kd-tree/ball-tree implementations. A minimal sketch of the brute-force idea appears after the todo list below.

Todo:

  • Implement sample method
  • Sample weights
  • Evaluate which metrics are missing
  • Tests for sample
  • Docstrings
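
For intuition, here is a minimal NumPy sketch of the brute-force computation (illustrative only, not the cuml kernel): for a Gaussian kernel with bandwidth h, the log-density at a query point x is a logsumexp over all training points of -||x - x_i||^2 / (2h^2) plus a normalization term, so the whole estimate reduces to a pairwise-distance computation.

import numpy as np
from scipy.special import logsumexp

def brute_force_kde_log_density(X_train, X_query, bandwidth=1.0):
    # Illustrative sketch of brute-force Gaussian KDE, not cuml's actual code.
    n, d = X_train.shape
    # all pairwise squared distances, shape (n_query, n_train)
    sq_dists = ((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1)
    # log of the Gaussian kernel normalization constant
    log_norm = -0.5 * d * np.log(2.0 * np.pi * bandwidth ** 2)
    # sum kernel contributions in log space for numerical stability,
    # then average over the n training points
    return logsumexp(-0.5 * sq_dists / bandwidth ** 2, axis=1) + log_norm - np.log(n)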

@github-actions bot added the Cython / Python label on Feb 1, 2022
@RAMitchell (Contributor, Author) commented:

Benchmarks: [plots: kernel density time vs. number of rows, for 10, 50, and 100 features; produced by the script below]

import time
import numpy as np
import pandas as pd

from cuml.neighbors import KernelDensity as cuKernelDensity
from sklearn.neighbors import KernelDensity as sklKernelDensity
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

sns.set()

rows_all = np.arange(100, 15000, 600)
# rows_all = np.arange(100, 500, 300)
cols_all = [10, 50, 100]

iterations = 3

rs = np.random.RandomState(2)
estimators = {"sklearn": sklKernelDensity(), "cuml": cuKernelDensity()}
df = pd.DataFrame()

use_cache = True

if not use_cache:
    for n_rows in tqdm(rows_all):
        for n_cols in cols_all:
            X = rs.normal(size=(n_rows, n_cols))
            for name, alg in estimators.items():
                # warmup run to exclude one-time initialization overhead
                # (KernelDensity is unsupervised: fit takes X only, and
                # score_samples does not accept a y argument)
                alg.fit(X[0:10])
                alg.score_samples(X[0:10])
                for i in range(iterations):
                    start = time.perf_counter()
                    alg.fit(X)
                    log_prob = alg.score_samples(X)
                    time_taken = time.perf_counter() - start

                    # DataFrame.append was removed in pandas 2.0; use concat
                    df = pd.concat(
                        [
                            df,
                            pd.DataFrame(
                                [
                                    {
                                        "Algorithm": name,
                                        "n_rows": n_rows,
                                        "n_cols": n_cols,
                                        "Time": time_taken,
                                        "Iteration": i,
                                    }
                                ]
                            ),
                        ],
                        ignore_index=True,
                    )

if use_cache:
    df = pd.read_pickle("kernel_density.pkl")
else:
    df.to_pickle("kernel_density.pkl")
int_cols = ["n_rows", "n_cols", "Iteration"]
df[int_cols] = df[int_cols].astype(int)

for col in cols_all:
    sns.lineplot(x="n_rows", y="Time", hue="Algorithm", data=df[df['n_cols']==col])
    plt.yscale("log")
    plt.xticks(rotation=45)
    plt.title(
        "Kernel density time (gaussian kernel, {} features, float64)".format(
            col
        )
    )
    plt.savefig("kernel_density_time_{}_cols.png".format(col))
    plt.clf()
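
For reference, a minimal end-to-end usage of the estimator benchmarked above; the kernel and bandwidth constructor arguments are assumed to follow sklearn's KernelDensity, which this class mirrors:

import numpy as np
from cuml.neighbors import KernelDensity

# fit a Gaussian KDE on random data (sklearn-style arguments assumed)
X = np.random.RandomState(0).normal(size=(1000, 10))
kde = KernelDensity(kernel="gaussian", bandwidth=1.0).fit(X)
log_density = kde.score_samples(X[:5])      # log p(x) at the first 5 points
new_points = kde.sample(3, random_state=0)  # draw 3 points from the fitted model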

@caryr35 caryr35 added this to PR-WIP in v22.04 Release via automation Feb 1, 2022
@cjnolet (Member) commented on Feb 8, 2022

rerun tests

@RAMitchell (Contributor, Author) commented on Feb 11, 2022

Below is an experiment that builds a generative model over MNIST using PCA and kernel density estimation. Sklearn takes over two hours in the KDE step, even after dimensionality reduction of the images; cuml takes just over a minute. The example is adapted from sklearn's user guide (https://scikit-learn.org/stable/auto_examples/neighbors/plot_digits_kde_sampling.html#sphx-glr-auto-examples-neighbors-plot-digits-kde-sampling-py), upgraded from the small digits dataset to the much larger MNIST.

[plots: stacked PCA/KDE timing bar charts for cuml and sklearn, and a grid of digits sampled from the cuml model]

import numpy as np
import matplotlib.pyplot as plt
import cuml
import time
import cupy as cp
from cuml.neighbors import KernelDensity
from sklearn.neighbors import KernelDensity as sklKernelDensity
from sklearn.model_selection import GridSearchCV
from cuml.decomposition import PCA
from sklearn.decomposition import PCA as sklPCA
from sklearn.datasets import fetch_openml
import seaborn as sns
import pandas as pd
sns.set()

cp.cuda.runtime.setDevice(7)
cuml.common.memory_utils.set_global_output_type('numpy')


def run_experiment(mnist, pca, kde, name):
    # project the 784-dimensional MNIST data to a lower dimension
    start = time.perf_counter()
    data = pca.fit_transform(mnist.data)
    pca_time = time.perf_counter()-start

    # use grid search cross-validation to optimize the bandwidth
    start = time.perf_counter()
    params = {"bandwidth": np.logspace(-1, 3, 50)}
    grid = GridSearchCV(kde, params)
    grid.fit(data)
    kde_time = time.perf_counter()-start

    print("best bandwidth: {0}".format(grid.best_estimator_.bandwidth))

    # use the best estimator to compute the kernel density estimate
    kde = grid.best_estimator_

    # sample 44 new points from the data
    new_data = kde.sample(44, random_state=0)
    new_data = pca.inverse_transform(new_data)

    # turn data into a 4x11 grid
    new_data = new_data.reshape((4, 11, -1))
    real_data = mnist.data[:44].reshape((4, 11, -1))

    # plot real digits and resampled digits
    fig, ax = plt.subplots(9, 11, subplot_kw=dict(xticks=[], yticks=[]))
    for j in range(11):
        ax[4, j].set_visible(False)
        for i in range(4):
            im = ax[i, j].imshow(
                real_data[i, j].reshape((28, 28)), cmap=plt.cm.binary, interpolation="nearest"
            )
            # MNIST pixels range 0-255 (the original digits example used 0-16)
            im.set_clim(0, 255)
            im = ax[i + 5, j].imshow(
                new_data[i, j].reshape((28, 28)), cmap=plt.cm.binary, interpolation="nearest"
            )
            im.set_clim(0, 255)

    ax[0, 5].set_title("Selection from the input data")
    ax[5, 5].set_title('"New" digits drawn from the kernel density model')

    plt.savefig("sampled_digits_{}.png".format(name))
    return pd.DataFrame([{"Name": name, "PCA Time": pca_time, "KDE Time": kde_time}])


mnist = fetch_openml(name='mnist_784', as_frame=False)
# mnist.data = mnist.data[1:100]  # uncomment to subsample for a quick test
df = run_experiment(mnist, PCA(
    n_components=24, whiten=False), KernelDensity(), "cuml")
# DataFrame.append was removed in pandas 2.0; use concat
df = pd.concat([df, run_experiment(mnist, sklPCA(n_components=24,
               whiten=False), sklKernelDensity(), "sklearn")])
df.set_index('Name').plot(kind='bar', stacked=True)
plt.savefig("kde_example_bar.png")

@RAMitchell RAMitchell marked this pull request as ready for review February 11, 2022 15:50
@RAMitchell RAMitchell requested a review from a team as a code owner February 11, 2022 15:50
@github-actions github-actions bot removed the CUDA/C++ label Feb 11, 2022
python/cuml/neighbors/kernel_density.py: four review threads (outdated, resolved)

@pytest.mark.xfail(
    reason="cuml's pairwise_distances does "
           "not process metric_params as expected")
A reviewer (Contributor) commented:

Is the only use-case for metric_params {'p': X} for the minkowski distance? If so, it can be a special case in the KDE code, or you could open an issue to have this behavior implemented (so that this xfail is tracked).

@RAMitchell (Contributor, Author) replied:

I addressed this by turning the dict input into a scalar and passing it to pairwise_distances, ignoring the dict key. Ideally pairwise_distances would accept a dict like sklearn's and do some error checking, but I don't want to deal with that here.
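
A hedged sketch of that dict-to-scalar handling (the function name is illustrative, not the actual code in this PR):

def scalar_from_metric_params(metric_params):
    # Illustrative sketch, not the cuml implementation: pull the single
    # scalar out of metric_params, ignoring its key, so it can be forwarded
    # to pairwise_distances (e.g. as the minkowski `p` argument).
    if not metric_params:
        return None
    if len(metric_params) > 1:
        raise ValueError("expected at most one entry in metric_params")
    (value,) = metric_params.values()
    return value

p = scalar_from_metric_params({"p": 3})  # -> 3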

python/cuml/test/test_kernel_density.py: review thread (outdated, resolved)
v22.04 Release automation moved this from PR-WIP to PR-Needs review Feb 19, 2022
@cjnolet (Member) commented on Feb 19, 2022

Linking in rapidsai/raft#518

@cjnolet (Member) left a review:

Looks good. @lowener did a great job reviewing already, and I only found a couple of very minor things.

python/cuml/neighbors/kernel_density.py and python/cuml/test/test_kernel_density.py: additional review threads (resolved)
@cjnolet added the improvement and non-breaking labels on Mar 7, 2022
@codecov-commenter commented

Codecov Report

❗ No coverage uploaded for pull request base (branch-22.04@9e0d458). The diff coverage is n/a.

@@               Coverage Diff               @@
##             branch-22.04    #4545   +/-   ##
===============================================
  Coverage                ?   83.85%           
===============================================
  Files                   ?      251           
  Lines                   ?    20273           
  Branches                ?        0           
===============================================
  Hits                    ?    16999           
  Misses                  ?     3274           
  Partials                ?        0           
Flag       Coverage Δ
dask       44.94% <0.00%> (?)
non-dask   77.02% <0.00%> (?)

Flags with carried-forward coverage won't be shown.

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 9e0d458...0d5a621.

@lowener (Contributor) left a review:

LGTM

@cjnolet (Member) left a review:

LGTM

v22.04 Release automation moved this from PR-Needs review to PR-Reviewer approved Mar 8, 2022
@cjnolet (Member) commented on Mar 8, 2022

@gpucibot merge

@rapids-bot rapids-bot bot merged commit b430b2e into rapidsai:branch-22.04 Mar 8, 2022
v22.04 Release automation moved this from PR-Reviewer approved to Done Mar 8, 2022
vimarsh6739 pushed a commit to vimarsh6739/cuml that referenced this pull request Oct 9, 2023
Using a brute force approach compared to sklearn's kd/ball tree. 

Todo:
- [x] Implement sample method
- [x] Sample weights
- [x] Evaluate which metrics are missing
- [x] Tests for sample
- [x] Docstrings

Authors:
  - Rory Mitchell (https://github.com/RAMitchell)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#4545