
Kernel ridge regression #4492

Merged: 19 commits into rapidsai:branch-22.04 on Feb 9, 2022
Conversation

@RAMitchell (Contributor) commented Jan 18, 2022:

Sklearn reference implementation: https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09b/sklearn/kernel_ridge.py#L16

I've tried to avoid touching the C++/CUDA layer so far. Pairwise kernels are implemented with a numba kernel for now. I've also used cupy's lapack wrapper to access cuSolver.
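
For orientation, here is a minimal sketch (my illustration, not the PR's actual code) of the fit/predict path this describes for a linear kernel: build the Gram matrix, add the ridge term, and solve the system on the GPU. The function names are hypothetical, and cp.linalg.solve stands in for the lower-level cupy lapack wrapper mentioned above.

import cupy as cp

def fit_kernel_ridge(X, y, alpha=1.0):
    # Linear kernel: K = X @ X.T is the (n_samples, n_samples) Gram matrix
    K = X @ X.T
    # Ridge regularization: solve (K + alpha * I) c = y for the dual coefs
    K += alpha * cp.eye(K.shape[0], dtype=K.dtype)
    # cupy dispatches the dense solve to cuSolver on the GPU
    return cp.linalg.solve(K, y)

def predict_kernel_ridge(X_train, dual_coef, X_new):
    # f(x) = k(x, X_train) @ c, i.e. kernel evaluations against training data
    return (X_new @ X_train.T) @ dual_coef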

The implementation of pairwise_kernels here can be reused to implement kernel PCA very easily, as sketched below.
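
As a hedged illustration of that reuse (again my sketch, not part of this PR), kernel PCA is just the kernel matrix from pairwise_kernels, double-centering, and a symmetric eigendecomposition; every name below other than pairwise_kernels is hypothetical.

import cupy as cp
from cuml.metrics import pairwise_kernels

def kernel_pca_transform(X, n_components=2, metric="rbf"):
    K = cp.asarray(pairwise_kernels(X, metric=metric))
    n = K.shape[0]
    # Double-center the kernel matrix: K_c = (I - 1/n) K (I - 1/n)
    one_n = cp.full((n, n), 1.0 / n, dtype=K.dtype)
    K_c = K - one_n @ K - K @ one_n + one_n @ K @ one_n
    # eigh returns eigenvalues in ascending order
    eigvals, eigvecs = cp.linalg.eigh(K_c)
    # Keep the leading components, largest eigenvalue first
    eigvals = eigvals[::-1][:n_components]
    eigvecs = eigvecs[:, ::-1][:, :n_components]
    # Projection of the training points: eigvec_i * sqrt(lambda_i)
    return eigvecs * cp.sqrt(cp.maximum(eigvals, 0))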

Todo:

  • Single target fit/predict
  • Standard kernels implemented
  • Support custom kernels
  • Support sample weights
  • Support CSR X matrix. Maybe too difficult for this PR.
  • Multi-target fit/predict
  • Change .py files to .pyx and move them to the correct places.
  • Benchmarking on reasonably large files
  • Tests take less than 20s
  • Ensure correct handling of input/output array types (I think I need to be using CumlArray and maybe some decorators)
  • Documentation

@github-actions github-actions bot added the Cython / Python (Cython or Python issue) label Jan 18, 2022
@caryr35 caryr35 added this to PR-WIP in v22.02 Release via automation Jan 18, 2022
@RAMitchell (Contributor, Author) commented:

Benchmarks:
(plots attached: kernel_ridge_time.png, kernel_ridge_mse.png)

At 14800 rows, cuml takes 0.8071575206238777s, sklearn takes 10.78322389847599s, speedup is 13.359503718854398.

import time
import numpy as np
import pandas as pd

from cuml import KernelRidge as cuKernelRidge
from sklearn.kernel_ridge import KernelRidge as sklKernelRidge
from sklearn.metrics import mean_squared_error as mse
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

sns.set()

rows_all = np.arange(100, 15000, 300)
# rows_all = np.arange(100, 500, 300)
cols_all = [100]

iterations = 5

rs = np.random.RandomState(2)
estimators = {"sklearn": sklKernelRidge(), "cuml": cuKernelRidge()}
records = []  # one dict per timed run

use_cache = False

if not use_cache:
    for n_rows in tqdm(rows_all):
        for n_cols in cols_all:
            X = rs.normal(size=(n_rows, n_cols))
            y = rs.normal(size=n_rows)
            for name, alg in estimators.items():
                # warmup
                alg.fit(X[0:10], y[0:10])
                for i in range(iterations):
                    start = time.perf_counter()
                    alg.fit(X, y)
                    pred = alg.predict(X)
                    time_taken = time.perf_counter() - start

                    if "cupy" in str(type(pred)):
                        pred = pred.get()
                    df = df.append(
                        {
                            "Algorithm": name,
                            "n_rows": n_rows,
                            "n_cols": n_cols,
                            "MSE": mse(y, pred),
                            "Time": time_taken,
                            "Iteration": i,
                        },
                        ignore_index=True,
                    )

if use_cache:
    df = pd.read_pickle("kernel_rr.pkl")
else:
    df = pd.DataFrame(records)
    df.to_pickle("kernel_rr.pkl")
int_cols = ["n_rows", "n_cols", "Iteration"]
df[int_cols] = df[int_cols].astype(int)

sns.lineplot(x="n_rows", y="Time", hue="Algorithm", data=df)
plt.yscale("log")
plt.xticks(rotation=45)
plt.title(
    "Kernel ridge regression time (linear kernel, {} features, float64)".format(
        cols_all[-1]
    )
)
plt.savefig("kernel_ridge_time.png")
plt.clf()
sns.barplot(x="n_rows", y="MSE", hue="Algorithm", data=df)
plt.xticks(rotation=45)
plt.title(
    "Kernel ridge regression MSE (linear kernel, {} features, float64)".format(
        cols_all[-1]
    )
)
plt.savefig("kernel_ridge_mse.png")

sklearn_largest_time = df[
    (df["n_rows"] == df["n_rows"].max()) & (df["Algorithm"] == "sklearn")
]["Time"].mean()
cuml_largest_time = df[
    (df["n_rows"] == df["n_rows"].max()) & (df["Algorithm"] == "cuml")
]["Time"].mean()

print(
    "At {} rows, cuml takes {}s, sklearn takes {}s, speedup is {}.".format(
        df["n_rows"].max(), cuml_largest_time, sklearn_largest_time, sklearn_largest_time/cuml_largest_time 
    )
)

@RAMitchell RAMitchell changed the base branch from branch-22.02 to branch-22.04 January 24, 2022 13:58
@RAMitchell RAMitchell marked this pull request as ready for review January 26, 2022 14:27
@RAMitchell RAMitchell requested a review from a team as a code owner January 26, 2022 14:27
@RAMitchell (Contributor, Author) commented:

This should be on the 22.04 board, not 22.02.

@RAMitchell RAMitchell changed the title [WIP] Kernel ridge regression Kernel ridge regression Jan 26, 2022
@cjnolet cjnolet removed this from PR-WIP in v22.02 Release Jan 26, 2022
@cjnolet cjnolet added this to PR-WIP in v22.04 Release via automation Jan 26, 2022
@cjnolet cjnolet self-requested a review January 28, 2022 18:36
@cjnolet (Member) left a comment:

I'm happy to see more kernel-based methods in cuml. This is a really nice port from scikit-learn and I'm thinking the new API for building custom kernels might even be useful for pairwise distances in general (maybe w/ an option to turn symmetry on and off).

Review thread on python/cuml/metrics/pairwise_kernels.py (outdated):
pairwise_kernels(X, Y, metric='linear')

@cuda.jit(device=True)
def custom_rbf_kernel(x, y, gamma=None):
Member:

I very much like the ability to quickly build custom kernels. Have you done any profiling / benchmarking of this against the cuml.metrics.pairwise_distances API? I'm mostly curious to know the gap between the two, and whether there's a perf hit for the different memory access patterns.

@RAMitchell (Contributor, Author):

It would be interesting to see the difference, but for now it doesn't really matter, as the matrix inversion dominates computation time. It could be 5 times slower than the cuda version and we wouldn't see any real difference in end-to-end time.

The bigger disadvantage of this approach for me has been jit compile time. It's in the range of a few hundred ms, which I think is reasonable.

Member:

No worries. My question isn't about this algorithm in particular. It's been on our todo list for quite a while to see how performant it would be to let users implement custom pairwise distance measures in Numba.

return (X, Y)


@given(kernel_arg_strategy(), array_strategy())
Member:

Love the use of hypothesis here. I'm hoping we will start using it more in cuml.
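
For readers new to hypothesis, here is a minimal sketch of a property-based test in this style; the strategy below is a simplified stand-in for the PR's kernel_arg_strategy/array_strategy, and the property checked (Gram matrices are symmetric positive semi-definite) is just an example.

import numpy as np
from hypothesis import given, settings, strategies as st
from hypothesis.extra.numpy import arrays

@st.composite
def array_strategy(draw):
    # Draw small random float64 matrices of varying shape
    rows = draw(st.integers(min_value=2, max_value=20))
    cols = draw(st.integers(min_value=1, max_value=5))
    return draw(arrays(np.float64, (rows, cols), elements=st.floats(-10, 10)))

@given(array_strategy())
@settings(deadline=None)  # jit compilation can exceed the default deadline
def test_linear_kernel_is_psd(X):
    # A linear-kernel Gram matrix is symmetric positive semi-definite
    K = X @ X.T
    assert np.allclose(K, K.T)
    assert np.linalg.eigvalsh(K).min() >= -1e-6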

v22.04 Release automation moved this from PR-WIP to PR-Needs review Jan 31, 2022
@RAMitchell (Contributor, Author) commented:

Have addressed review comments. I changed the kernel implementations to build off existing primitives more, which had a couple of side effects. The jit compilation overhead went away for most of the kernels, taking the overall test time from 30s down to 10s. The estimator also became much less accurate for float32 inputs, because I was previously able to force intermediate calculations to double precision. Accordingly, the tolerance for float32 tests has been loosened significantly.

The cosine kernel still uses the custom kernel path, as implementing it the sklearn way is just very inaccurate and caused me to fail some tests. Chi^2 kernels also still use the custom kernel path, as I can't immediately see how to get them from existing primitives.

I might benchmark the custom versions against the newer versions later if I get time, but this is more a matter of curiosity.
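
As a sketch of what "building off existing primitives" can look like for the rbf kernel (my illustration; the PR's actual code and the exact set of supported metric names may differ), the kernel can be derived from a squared-euclidean pairwise distance instead of a hand-written numba kernel. It also shows why float32 accuracy dropped: intermediates stay in the input dtype rather than being forced to double precision.

import cupy as cp
from cuml.metrics import pairwise_distances

def rbf_kernel_from_primitives(X, Y=None, gamma=None):
    if gamma is None:
        gamma = 1.0 / X.shape[1]
    # One fused distance primitive replaces the per-element custom kernel;
    # the intermediate distances keep X's dtype (float32 stays float32)
    D2 = pairwise_distances(X, Y, metric="sqeuclidean")
    return cp.exp(-gamma * cp.asarray(D2))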

@RAMitchell (Contributor, Author) commented:

Benchmarks comparing custom kernel performance against the implementation using primitives. The custom-kernel implementation falls off considerably at higher dimensions due to poor memory access patterns, though it is still faster than sklearn.

(plot attached: custom_kernels.png)

import cupy as cp
import numpy as np
from numba import cuda
from cuml.metrics import pairwise_kernels
from sklearn.metrics.pairwise import pairwise_kernels as skl_pairwise_kernels
import math
import time
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
sns.set()


records = []  # one dict per timed configuration
for col in tqdm(range(10, 110, 10)):
    rs = np.random.RandomState(259)
    X = rs.normal(size=(20000, col))
    X_device = cp.array(X)

    # warmup (trigger jit compilation before timing)
    K = pairwise_kernels(X_device[0:10], metric='rbf')
    start = time.perf_counter()
    K = pairwise_kernels(X_device, metric='rbf')
    cp.cuda.runtime.deviceSynchronize()
    standard_time = time.perf_counter() - start
    records.append(
        {"Algorithm": 'rbf', "n_rows": X.shape[0], "n_cols": X.shape[1], "Time": standard_time})

    @cuda.jit(device=True)
    def custom_rbf_kernel(x, y, gamma=None):
        if gamma is None:
            gamma = 1.0 / len(x)
        sum = 0.0
        for i in range(len(x)):
            sum += (x[i] - y[i]) ** 2
        return math.exp(-gamma * sum)

    start = time.perf_counter()
    K = skl_pairwise_kernels(X, metric='rbf')  # CPU baseline; no device sync needed
    skl_time = time.perf_counter() - start
    records.append(
        {"Algorithm": 'rbf_skl', "n_rows": X.shape[0], "n_cols": X.shape[1], "Time": skl_time})
    # warmup
    K = pairwise_kernels(X_device[0:10], metric=custom_rbf_kernel)
    start = time.perf_counter()
    K = pairwise_kernels(X_device, metric=custom_rbf_kernel)
    cp.cuda.runtime.deviceSynchronize()
    custom_time = time.perf_counter() - start
    records.append({"Algorithm": 'rbf_custom',
                    "n_rows": X.shape[0], "n_cols": X.shape[1], "Time": custom_time})

df = pd.DataFrame(records)
print(df)
sns.lineplot(x='n_cols', y='Time', hue='Algorithm', data=df)
plt.yscale('log')
plt.title('Pairwise kernel time 20,000 rows, varying cols')
plt.savefig("custom_kernels.png")

@@ -0,0 +1,291 @@
#
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
Member:

Just noticed this: we should remove 2019 since this is a new file.

z += x[i]*y[i]
x_norm += x[i] * x[i]
y_norm += y[i] * y[i]
return z / math.sqrt(x_norm * y_norm)
@cjnolet (Member) commented Feb 3, 2022:

I believe this is how pairwise_distances computes the cosine as well, except that it returns the distance form, 2 - [a.dot(b) / (sqrt(x_l2_norm) * sqrt(y_l2_norm))], and uses sqrt(a) * sqrt(b) = sqrt(a * b). It looks like you are doing this as well. Are you saying there's a numerical issue that might be causing incorrect values?

@RAMitchell (Contributor, Author) commented Feb 3, 2022:

The sklearn version here (https://github.com/scikit-learn/scikit-learn/blob/9f85c9d44965b764f40169ef2917e5f7a798684f/sklearn/metrics/pairwise.py#L1265), when ported using cupy and cuml's normalize function, seemed numerically unstable to me. This is why I kept the custom kernel version. I can look into it further if necessary.

Member:

I'm just wondering if converting the cosine distance from cuml.metrics.pairwise_distances back to a similarity might help eliminate the jit overhead from this one as well. If not, we can always look into it further in the future. Thanks for changing the other ones!

@RAMitchell (Contributor, Author):

Ah, I see. I have changed that one to use cosine distance too.
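
A minimal sketch of that change, assuming cuml follows the usual convention cosine_distance = 1 - cosine_similarity; the helper name is hypothetical.

import cupy as cp
from cuml.metrics import pairwise_distances

def cosine_kernel_from_distance(X, Y=None):
    # Recover the similarity kernel from the distance primitive, avoiding a
    # separately jit-compiled custom kernel
    D = pairwise_distances(X, Y, metric="cosine")
    return 1.0 - cp.asarray(D)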

@cjnolet cjnolet added the 3 - Ready for Review (Ready for review by team), improvement (Improvement / enhancement to an existing function), and non-breaking (Non-breaking change) labels Feb 3, 2022
@cjnolet (Member) left a comment:

Changes LGTM

v22.04 Release automation moved this from PR-Needs review to PR-Reviewer approved Feb 4, 2022
@cjnolet (Member) commented Feb 8, 2022:

rerun tests

1 similar comment: @cjnolet commented "rerun tests" again on Feb 8, 2022.

@codecov-commenter commented:

Codecov Report

❗ No coverage uploaded for pull request base (branch-22.04@5b676a1).
The diff coverage is n/a.


@@               Coverage Diff               @@
##             branch-22.04    #4492   +/-   ##
===============================================
  Coverage                ?   85.74%           
===============================================
  Files                   ?      239           
  Lines                   ?    19588           
  Branches                ?        0           
===============================================
  Hits                    ?    16796           
  Misses                  ?     2792           
  Partials                ?        0           
Flag       Coverage Δ
dask       46.20% <0.00%> (?)
non-dask   78.74% <0.00%> (?)

Flags with carried forward coverage won't be shown.


Continue to review full report at Codecov.

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5b676a1...c9955f0.

@cjnolet (Member) commented Feb 9, 2022:

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 9921c61 into rapidsai:branch-22.04 Feb 9, 2022
v22.04 Release automation moved this from PR-Reviewer approved to Done Feb 9, 2022
vimarsh6739 pushed a commit to vimarsh6739/cuml that referenced this pull request Oct 9, 2023
Sklearn reference implementation: https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09b/sklearn/kernel_ridge.py#L16

I've tried to avoid touching the C++/CUDA layer so far. Pairwise kernels are implemented with a numba kernel for now. I've also used cupy's lapack wrapper to access cuSolver.

The implementation of `pairwise_kernels` here can be reused to very easily implement kernel PCA.

Todo:
- [x] Single target fit/predict
- [x] Standard kernels implemented
- [x] Support custom kernels
- [x] Support sample weights
- [ ] ~~Support CSR X matrix. Maybe too difficult for this PR.~~
- [x] Multi-target fit/predict
- [x] Change .py files to .pyx and move them to the correct places.
- [x] Benchmarking on reasonably large files
- [x] Tests take less than 20s
- [x] Ensure correct handling of input/output array types (I think I need to be using CumlArray and maybe some decorators)
- [x] Documentation

Authors:
  - Rory Mitchell (https://github.com/RAMitchell)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Micka (https://github.com/lowener)

URL: rapidsai#4492
Labels
3 - Ready for Review (Ready for review by team), Cython / Python (Cython or Python issue), improvement (Improvement / enhancement to an existing function), non-breaking (Non-breaking change)
Projects
No open projects
Development
Successfully merging this pull request may close these issues: none yet.

4 participants