[BUG] different outputs for PCA on CPU vs. GPU #5473

Open
stephanie-fu opened this issue Jun 20, 2023 · 9 comments
Labels
? - Needs Triage Need team to review and classify bug Something isn't working


@stephanie-fu

stephanie-fu commented Jun 20, 2023

Describe the bug
Using cuml.PCA with set_global_device_type set to 'CPU' and to 'GPU' produces different results (the set_global_device_type('CPU') output matches the output of sklearn's PCA).

Steps/Code to reproduce bug

from cuml.common.device_selection import set_global_device_type
import matplotlib.pyplot as plt
import torch

from sklearn.decomposition import PCA as skPCA
from cuml import PCA as cuPCA

data = torch.randn((10000, 2)).cuda()
data[:, 0] *= 10

plt.scatter(data[:, 0].cpu(), data[:, 1].cpu())
plt.gca().set_xlim(-50, 50)
plt.gca().set_ylim(-50, 50)
plt.show()

# 1: cuML on CPU
set_global_device_type('CPU')
dimred = cuPCA(n_components=2, output_type='numpy')
data_dimred = dimred.fit_transform(data)
plt.scatter(data_dimred[:, 0], data_dimred[:, 1])
plt.gca().set_xlim(-50, 50)
plt.gca().set_ylim(-50, 50)
plt.show()

# 2: cuML on GPU
set_global_device_type('GPU')
dimred = cuPCA(n_components=2, output_type='numpy')
data_dimred = dimred.fit_transform(data)
plt.scatter(data_dimred[:, 0], data_dimred[:, 1])
plt.gca().set_xlim(-50, 50)
plt.gca().set_ylim(-50, 50)
plt.show()

# 3: sklearn on CPU
dimred = skPCA(n_components=2)
data_dimred = dimred.fit_transform(data.cpu())
plt.scatter(data_dimred[:, 0], data_dimred[:, 1])
plt.gca().set_xlim(-50, 50)
plt.gca().set_ylim(-50, 50)
plt.show()

Expected behavior
All 3 examples above are expected to have the same output.

Environment details:

  • Linux Distro/Architecture: Ubuntu 20.04.2
  • GPU Model/Driver: NVIDIA RTX A6000
  • CUDA: 12.1
  • Method of cuDF & cuML install: conda
asttokens                 2.2.1                    pypi_0    pypi
astunparse                1.6.3                    pypi_0    pypi
async-timeout             4.0.2                    pypi_0    pypi
attrs                     23.1.0                   pypi_0    pypi
backcall                  0.2.0                    pypi_0    pypi
beautifulsoup4            4.12.2                   pypi_0    pypi
bleach                    6.0.0                    pypi_0    pypi
blessed                   1.20.0                   pypi_0    pypi
ca-certificates           2023.01.10           h06a4308_0
cachetools                5.3.1                    pypi_0    pypi
certifi                   2023.5.7                 pypi_0    pypi
cffi                      1.15.1                   pypi_0    pypi
charset-normalizer        3.1.0                    pypi_0    pypi
click                     8.1.3                    pypi_0    pypi
cloudpickle               2.2.1                    pypi_0    pypi
cmake                     3.26.3                   pypi_0    pypi
comm                      0.1.3                    pypi_0    pypi
contourpy                 1.0.7                    pypi_0    pypi
croniter                  1.3.15                   pypi_0    pypi
cubinlinker-cu11          0.3.0.post1              pypi_0    pypi
cuda-python               11.8.2                   pypi_0    pypi
cudf-cu11                 23.6.0                   pypi_0    pypi
cuml-cu11                 23.6.0                   pypi_0    pypi
cupy-cuda12x              12.1.0                   pypi_0    pypi
cycler                    0.11.0                   pypi_0    pypi
cython                    0.29.35                  pypi_0    pypi
dask                      2023.3.2                 pypi_0    pypi
dask-cuda                 23.6.0                   pypi_0    pypi
dask-cudf-cu11            23.6.0                   pypi_0    pypi
dateutils                 0.6.12                   pypi_0    pypi
debugpy                   1.6.7                    pypi_0    pypi
decorator                 5.1.1                    pypi_0    pypi
deepdiff                  6.3.0                    pypi_0    pypi
defusedxml                0.7.1                    pypi_0    pypi
distributed               2023.3.2.1               pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
einops                    0.6.1                    pypi_0    pypi
exceptiongroup            1.1.1                    pypi_0    pypi
executing                 1.2.0                    pypi_0    pypi
fast-pytorch-kmeans       0.2.0.1                  pypi_0    pypi
fastapi                   0.88.0                   pypi_0    pypi
fastjsonschema            2.17.1                   pypi_0    pypi
fastrlock                 0.8.1                    pypi_0    pypi
filelock                  3.12.0                   pypi_0    pypi
flatbuffers               23.5.26                  pypi_0    pypi
fonttools                 4.39.4                   pypi_0    pypi
fqdn                      1.5.1                    pypi_0    pypi
frozenlist                1.3.3                    pypi_0    pypi
fsspec                    2023.5.0                 pypi_0    pypi
gast
inquirer                  3.1.3                    pypi_0    pypi
ipykernel                 6.23.1                   pypi_0    pypi
ipython                   8.13.2                   pypi_0    pypi
ipython-genutils          0.2.0                    pypi_0    pypi
ipywidgets                8.0.6                    pypi_0    pypi
isoduration               20.11.0                  pypi_0    pypi
itsdangerous              2.1.2                    pypi_0    pypi
jax                       0.4.11                   pypi_0    pypi
jedi                      0.18.2                   pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
joblib                    1.2.0                    pypi_0    pypi
jsonpointer               2.3                      pypi_0    pypi
jsonschema                4.17.3                   pypi_0    pypi
jupyter                   1.0.0                    pypi_0    pypi
jupyter-client            8.2.0                    pypi_0    pypi
jupyter-console           6.6.3                    pypi_0    pypi
jupyter-core              5.3.0                    pypi_0    pypi
jupyter-events            0.6.3                    pypi_0    pypi
jupyter-server            2.6.0                    pypi_0    pypi
jupyter-server-terminals  0.4.4                    pypi_0    pypi
jupyterlab-pygments       0.2.2                    pypi_0    pypi
jupyterlab-widgets        3.0.7                    pypi_0    pypi
keras                     2.12.0                   pypi_0    pypi
kiwisolver                1.4.4                    pypi_0    pypi
kmeans-pytorch            0.3                      pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1
libclang                  16.0.0                   pypi_0    pypi
libffi                    3.4.4                h6a678d5_0
libgcc-ng                 11.2.0               h1234567_1
libgomp                   11.2.0               h1234567_1
libstdcxx-ng              11.2.0               h1234567_1
lightning                 2.0.3                    pypi_0    pypi
lightning-cloud           0.5.36                   pypi_0    pypi
lightning-utilities       0.8.0                    pypi_0    pypi
lit                       16.0.5                   pypi_0    pypi
llvmlite                  0.40.1rc1                pypi_0    pypi
locket                    1.0.0                    pypi_0    pypi
markdown                  3.4.3                    pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                2.1.2                    pypi_0    pypi
matplotlib                3.7.1                    pypi_0    pypi
matplotlib-inline         0.1.6                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mistune                   2.0.5                    pypi_0    pypi
ml-dtypes                 0.1.0                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
msgpack                   1.0.5                    pypi_0    pypi
multidict                 6.0.4                    pypi_0    pypi
nbclassic                 1.0.0                    pypi_0    pypi
nbclient                  0.8.0                    pypi_0
nvidia-cuda-runtime-cu11  11.7.99                  pypi_0    pypi
nvidia-cudnn-cu11         8.5.0.96                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-curand-cu11        10.2.10.91               pypi_0    pypi
nvidia-cusolver-cu11      11.4.0.1                 pypi_0    pypi
nvidia-cusparse-cu11      11.7.4.91                pypi_0    pypi
nvidia-nccl-cu11          2.14.3                   pypi_0    pypi
nvidia-nvtx-cu11          11.7.91                  pypi_0    pypi
nvtx                      0.2.5                    pypi_0    pypi
oauthlib                  3.2.2                    pypi_0    pypi
openssl                   1.1.1t               h7f8727e_0
opt-einsum                3.3.0                    pypi_0    pypi
ordered-set               4.1.0                    pypi_0    pypi
overrides                 7.3.1                    pypi_0    pypi
packaging                 23.1                     pypi_0    pypi
pandas                    1.5.3                    pypi_0    pypi
pandocfilters             1.5.0                    pypi_0    pypi
parso                     0.8.3                    pypi_0    pypi
partd                     1.4.0                    pypi_0    pypi
pathtools                 0.1.2                    pypi_0    pypi
pexpect                   4.8.0                    pypi_0    pypi
pickleshare               0.7.5                    pypi_0    pypi
pillow                    9.5.0                    pypi_0    pypi
pip                       23.0.1           py39h06a4308_0
platformdirs              3.5.1                    pypi_0    pypi
prometheus-client         0.17.0                   pypi_0    pypi
prompt-toolkit            3.0.38                   pypi_0    pypi
protobuf                  4.21.12                  pypi_0    pypi
psutil                    5.9.5                    pypi_0    pypi
ptxcompiler-cu11          0.7.0.post1              pypi_0    pypi
ptyprocess                0.7.0                    pypi_0    pypi
pure-eval                 0.2.2                    pypi_0    pypi
pyarrow                   11.0.0                   pypi_0    pypi
pyasn1                    0.5.0                    pypi_0    pypi
pyasn1-modules            0.3.0                    pypi_0    pypi
pycparser                 2.21                     pypi_0    pypi
pydantic                  1.10.9                   pypi_0    pypi
pygments                  2.15.1                   pypi_0    pypi
pyjwt                     2.7.0                    pypi_0    pypi
pylibraft-cu11            23.6.0                   pypi_0    pypi
pynvml                    11.4.1                   pypi_0    pypi
pyparsing                 3.0.9                    pypi_0    pypi
pyrsistent                0.19.3                   pypi_0    pypi
python                    3.9.16               h7a1cb2a_2
python-dateutil           2.8.2                    pypi_0    pypi
python-editor             1.0.4                    pypi_0    pypi
python-json-logger        2.0.7                    pypi_0    pypi
python-multipart          0.0.6                    pypi_0    pypi
pytorch-lightning         2.0.2                    pypi_0    pypi
pytorch-triton
rfc3986-validator         0.1.1                    pypi_0    pypi
rich                      13.4.2                   pypi_0    pypi
rmm-cu11                  23.6.0                   pypi_0    pypi
rsa                       4.9                      pypi_0    pypi
scikit-learn              1.2.2                    pypi_0    pypi
scipy                     1.10.1                   pypi_0    pypi
send2trash                1.8.2                    pypi_0    pypi
sentry-sdk                1.25.0                   pypi_0    pypi
setproctitle              1.3.2                    pypi_0    pypi
setuptools                67.8.0           py39h06a4308_0
six                       1.16.0                   pypi_0    pypi
sklearn                   0.0.post5                pypi_0    pypi
smmap                     5.0.0                    pypi_0    pypi
sniffio                   1.3.0                    pypi_0    pypi
sortedcontainers          2.4.0                    pypi_0    pypi
soupsieve                 2.4.1                    pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0
stack-data                0.6.2                    pypi_0    pypi
starlette                 0.22.0                   pypi_0    pypi
starsessions              1.3.0                    pypi_0    pypi
sympy                     1.12                     pypi_0    pypi
tblib                     1.7.0                    pypi_0    pypi
tensorboard               2.12.3                   pypi_0    pypi
tensorboard-data-server   0.7.0                    pypi_0    pypi
tensorflow                2.12.0                   pypi_0    pypi
tensorflow-estimator      2.12.0                   pypi_0    pypi
tensorflow-io-gcs-filesystem 0.32.0                   pypi_0    pypi
termcolor                 2.3.0                    pypi_0    pypi
terminado                 0.17.1                   pypi_0    pypi
threadpoolctl             3.1.0                    pypi_0    pypi
tinycss2                  1.2.1                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0
toolz                     0.12.0                   pypi_0    pypi
torch                     2.1.0.dev20230619+cu121          pypi_0    pypi
torchaudio                2.1.0.dev20230619+cu121          pypi_0    pypi
torchmetrics              0.11.4                   pypi_0    pypi
torchvision               0.16.0.dev20230619+cu121          pypi_0    pypi
tornado                   6.3.2                    pypi_0    pypi
tqdm                      4.65.0                   pypi_0    pypi
traitlets                 5.9.0                    pypi_0    pypi
treelite                  3.2.0                    pypi_0    pypi
treelite-runtime          3.2.0                    pypi_0    pypi
triton                    2.0.0                    pypi_0    pypi
typing-extensions         4.6.2                    pypi_0    pypi
tzdata                    2023c                h04d1e81_0
ucx-py-cu11               0.32.0                   pypi_0    pypi
uri-template              1.2.0                    pypi_0    pypi
urllib3                   1.26.16                  pypi_0    pypi
uvicorn                   0.22.0                   pypi_0    pypi
wandb                     0.15.3                   pypi_0    pypi
wcwidth                   0.2.6                    pypi_0    pypi
webcolors                 1.13                     pypi_0    pypi
webencodings              0.5.1                    pypi_0    pypi
websocket-client          1.5.2                    pypi_0    pypi
websockets                11.0.3                   pypi_0    pypi
werkzeug                  2.3.4                    pypi_0    pypi
wheel                     0.38.4           py39h06a4308_0
widgetsnbextension        4.0.7                    pypi_0    pypi
wrapt                     1.14.1                   pypi_0    pypi
xz                        5.4.2                h5eee18b_0
yarl                      1.9.2                    pypi_0    pypi
zict                      3.0.0                    pypi_0    pypi
zipp                      3.15.0                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_0
@stephanie-fu stephanie-fu added ? - Needs Triage Need team to review and classify bug Something isn't working labels Jun 20, 2023
@lowener
Contributor

lowener commented Jun 20, 2023

@stephanie-fu This is related to #4560.
The signs are flipped, but the result is still valid.
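To make the sign-flip argument concrete: PCA via SVD writes the centered data as X_c = U S V^T, and negating a column of U together with the matching row of V^T leaves that product unchanged, so both signs are equally valid solutions; the transformed data is only reflected along that component. A small NumPy sketch on synthetic data (an illustration, not cuML's code):

```python
import numpy as np

rng = np.random.default_rng(0)
# Anisotropic 2-D data similar to the report: std 10 along x, std 1 along y.
X = rng.standard_normal((1000, 2)) * np.array([10.0, 1.0])
Xc = X - X.mean(axis=0)

# Thin SVD: Xc = U @ diag(S) @ Vt; the rows of Vt are the principal axes.
U, S, Vt = np.linalg.svd(Xc, full_matrices=False)

# Negate the first singular vector pair (column 0 of U and row 0 of Vt).
U_f, Vt_f = U.copy(), Vt.copy()
U_f[:, 0] *= -1
Vt_f[0, :] *= -1

# Both factorizations reconstruct exactly the same centered data...
recon = U @ np.diag(S) @ Vt
recon_flipped = U_f @ np.diag(S) @ Vt_f

# ...and the projections differ only by the sign of component 0.
proj = Xc @ Vt.T
proj_flipped = Xc @ Vt_f.T
```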

@lowener lowener closed this as completed Jun 20, 2023
@antortjim

@lowener Could you please explain how the results are still valid? I would still expect the output plot to be the same in all three cases. Thank you!

@lowener
Contributor

lowener commented Jun 22, 2023

The issue that you opened on UMAP is not related to this.
The problem in this PCA issue comes from this function https://github.com/rapidsai/cuml/blob/branch-23.08/cpp/src/tsvd/tsvd.cuh#L137, which is used in both PCA and TSVD.
Sign flipping doesn't change the correctness of the results. Some columns of U and V can be positive or negative, and that won't affect the quality of the projection: negating a column of U together with the matching singular vector in V leaves U S V^T unchanged and only reflects the projected data along that component. PCA describes how much each variable contributes to each component, and that contribution can be read equally well with either sign.
Here is a more in-depth answer. https://stats.stackexchange.com/a/88882
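Libraries usually pick one sign deterministically so that results are reproducible. The sketch below shows one such convention, making the largest-magnitude entry of each principal axis positive, similar in spirit to scikit-learn's `sklearn.utils.extmath.svd_flip`. The helper name `flip_signs` is hypothetical, and this is not claimed to be what the linked cuML function does; it only shows why two SVDs that differ by signs become identical after normalization.

```python
import numpy as np

def flip_signs(U, Vt):
    """Make the largest-magnitude entry of each principal axis (row of Vt) positive.

    The matching columns of U are negated too, so U @ diag(S) @ Vt is unchanged.
    """
    idx = np.argmax(np.abs(Vt), axis=1)               # position of max |value| per row
    signs = np.sign(Vt[np.arange(Vt.shape[0]), idx])  # +1 or -1 per component
    signs[signs == 0] = 1.0                           # guard against all-zero rows
    return U * signs, Vt * signs[:, None]

rng = np.random.default_rng(1)
X = rng.standard_normal((500, 3)) @ rng.standard_normal((3, 3))
Xc = X - X.mean(axis=0)
U, S, Vt = np.linalg.svd(Xc, full_matrices=False)

# A second, equally valid SVD that differs only by the sign of component 1.
U2, Vt2 = U.copy(), Vt.copy()
U2[:, 1] *= -1
Vt2[1, :] *= -1

# After normalization both decompositions agree exactly.
U_a, Vt_a = flip_signs(U, Vt)
U_b, Vt_b = flip_signs(U2, Vt2)
```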

@stephanie-fu
Author

stephanie-fu commented Jun 22, 2023

Thank you for the clarification. I am still a little confused about the output of the example above - I would expect all 3 code examples to look like an ellipse, but am getting something from the GPU implementation that visually looks different from a sign flip (in fact, it looks like the CPU output is a flipped version of the data, which seems valid). Is this expected output?

[Image: scatter plots of the PCA outputs]

@antortjim


Thank you for your reply. Yes, I suspected the root causes were different, so the solutions or interpretations will differ too. I linked this issue there because both PCA and UMAP are extremely widely used algorithms that many people will try to accelerate with cuML, and in doing so they will run into results that are difficult to interpret. More documentation would probably help in both cases.

I agree with @stephanie-fu that if a sign flip is the only explanation for the PCA difference, it still does not explain the loss of structure in the "PCA on GPU" plot. The GPU output is clearly not maximizing the variance along one axis. In other words, regardless of any sign flip, the expected output would have most of the variance along one axis and only a little along the other, instead of the round blob we observe. How can we explain it?

Thank you again!
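This argument can be checked numerically: for the reproducer's data (standard deviation 10 along the first axis and 1 along the second), a correct PCA should put roughly 100 / (100 + 1), i.e. about 99%, of the variance in the first component, and negating a component cannot change that share. A NumPy sketch, where `variance_share` is a hypothetical helper, not a cuML or scikit-learn API:

```python
import numpy as np

def variance_share(Z):
    """Fraction of the total variance carried by each column of a projected dataset."""
    var = Z.var(axis=0)
    return var / var.sum()

rng = np.random.default_rng(0)
X = rng.standard_normal((10_000, 2))
X[:, 0] *= 10                      # same scaling as the reproducer

Xc = X - X.mean(axis=0)
_, _, Vt = np.linalg.svd(Xc, full_matrices=False)
Z = Xc @ Vt.T                      # a correct PCA projection, components sorted by variance

ratio = variance_share(Z)          # roughly [0.99, 0.01] for this data

# A sign flip leaves the shares unchanged, so signs alone cannot turn
# the elongated ellipse into a round blob.
ratio_flipped = variance_share(Z * np.array([-1.0, 1.0]))
```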

@lowener lowener reopened this Jun 22, 2023
@lowener
Contributor

lowener commented Jun 22, 2023

I jumped to the sign-flipping conclusion too quickly. The issue seems to be our compatibility with torch Tensors.
Converting the data to a CuPy array fixes the problem you're seeing and can be used as a workaround:

import cupy
import torch
data = torch.randn((10000, 2)).cuda()
data[:, 0] *= 10
data = cupy.array(data)
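If torch tensors show up in many places, the workaround can be wrapped in a small helper. This is a hypothetical sketch, not part of cuML: it assumes CuPy is installed whenever CUDA tensors are passed, and relies on PyTorch CUDA tensors exposing `__cuda_array_interface__`, which `cupy.asarray` can consume without a copy (unlike `cupy.array` above, which copies).

```python
import numpy as np

def as_cuml_input(x):
    """Hypothetical helper: convert torch tensors to arrays before passing them to cuML.

    CUDA tensors become CuPy arrays, CPU tensors become NumPy arrays,
    and anything else is returned unchanged.
    """
    try:
        import torch
    except ImportError:            # torch not installed: nothing to convert
        return x
    if not isinstance(x, torch.Tensor):
        return x
    x = x.detach()                 # tensors that require grad refuse array export
    if x.is_cuda:
        import cupy
        return cupy.asarray(x)     # shares device memory; use cupy.array(x) for a copy
    return x.cpu().numpy()

# Usage with the reproducer (requires a GPU, torch and cuML):
#   dimred = cuPCA(n_components=2, output_type='numpy')
#   data_dimred = dimred.fit_transform(as_cuml_input(data))

# Non-tensor inputs pass through untouched.
host = np.zeros((4, 2))
passthrough = as_cuml_input(host)
```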

@antortjim

I can confirm that converting the data to a CuPy array solves the problem here!
This workaround is fully compatible with @stephanie-fu's code, except that the calls to .cpu() need to be replaced by .get().

@antortjim

I.e., this code:

from cuml.common.device_selection import set_global_device_type
import matplotlib.pyplot as plt
import torch
import cupy
from sklearn.decomposition import PCA as skPCA
from cuml import PCA as cuPCA

data = torch.randn((10000, 2)).cuda()
data[:, 0] *= 10

data = cupy.array(data)

plt.scatter(data[:, 0].get(), data[:, 1].get())
plt.gca().set_xlim(-50, 50)
plt.gca().set_ylim(-50, 50)
plt.show()

# 1: cuML on CPU
set_global_device_type('CPU')
dimred = cuPCA(n_components=2, output_type='numpy')
data_dimred = dimred.fit_transform(data)
plt.scatter(data_dimred[:, 0], data_dimred[:, 1])
plt.gca().set_xlim(-50, 50)
plt.gca().set_ylim(-50, 50)
plt.show()

# 2: cuML on GPU
set_global_device_type('GPU')
dimred = cuPCA(n_components=2, output_type='numpy')
data_dimred = dimred.fit_transform(data)
plt.scatter(data_dimred[:, 0], data_dimred[:, 1])
plt.gca().set_xlim(-50, 50)
plt.gca().set_ylim(-50, 50)
plt.show()

# 3: sklearn on CPU
dimred = skPCA(n_components=2)
data_dimred = dimred.fit_transform(data.get())
plt.scatter(data_dimred[:, 0], data_dimred[:, 1])
plt.gca().set_xlim(-50, 50)
plt.gca().set_ylim(-50, 50)
plt.show()
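Rather than comparing plots by eye, the three outputs can be compared numerically up to a per-component sign. A minimal NumPy sketch, assuming each example's `data_dimred` is saved under its own name (for instance hypothetical `out_cpu`, `out_gpu` and `out_sk`) instead of being overwritten; the helper name and the loose tolerances are assumptions, chosen because the input is float32:

```python
import numpy as np

def allclose_up_to_sign(a, b, rtol=1e-4, atol=1e-4):
    """True if every column of `a` matches the same column of `b` or its negation.

    Expects 2-D arrays of shape (n_samples, n_components).
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        return False
    same = np.isclose(a, b, rtol=rtol, atol=atol).all(axis=0)      # per-column match
    flipped = np.isclose(a, -b, rtol=rtol, atol=atol).all(axis=0)  # per-column negated match
    return bool(np.all(same | flipped))

rng = np.random.default_rng(0)
Z = rng.standard_normal((100, 2))
Z_flipped = Z * np.array([1.0, -1.0])   # same result with component 1 negated
Z_other = rng.standard_normal((100, 2))  # genuinely different result
```

For example, `allclose_up_to_sign(out_gpu, out_sk)` should hold once the workaround is applied. This only accounts for sign flips; if two components had nearly equal variance they could also rotate into each other, which is not the case for this data (variances of about 100 and 1).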

@stephanie-fu
Author

Thank you for the workaround!
