[BUG] different outputs for UMAP on CPU vs. GPU #5474

Open
antortjim opened this issue Jun 21, 2023 · 9 comments
Labels
? - Needs Triage Need team to review and classify bug Something isn't working


antortjim commented Jun 21, 2023

Describe the bug
Similar to #5473, the results of the GPU UMAP implementation provided by cuml don't match what I expect from the CPU implementation (umap-learn). In particular, if

  1. hash_input=True (recommended when comparing cpu vs gpu implementations, as explained in [QST] Relationship between UMAP.embedding_ and reductions returned by UMAP.transform() #5188),
  2. the dataset contains samples from the training set, and
  3. the dataset also contains samples outside the training set,

then the output plot is completely scrambled.

However, if only new data is provided (simulated here by the test data), the output looks fine.

Steps/Code to reproduce bug

import os.path
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.keras.datasets import mnist # 60k datapoints

from umap import UMAP
from cuml.manifold.umap import UMAP as cuUMAP

FIGURES_PATH="umap_figures"
os.makedirs(FIGURES_PATH, exist_ok=True)


def load_tf_mnist():
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    # We'll just use the training data for simplicity
    images = np.reshape(train_images, (len(train_images), -1))
    test_images = np.reshape(test_images, (len(test_images), -1))
   

    return (images, train_labels), (test_images, test_labels)

def plot_umap(fig, ax, embedding, labels, title, xlim=[-15, 15], ylim=[-10, 10]):
    im=ax.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='Spectral', s=5)
    # ax.set_aspect('equal', 'datalim')
    ax.set_title(title, fontsize=24)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    
    n_labels=len(set(labels))
    if n_labels!=1:
        fig.colorbar(im, boundaries=np.arange(n_labels+1)-0.5).set_ticks(np.arange(n_labels))

(data, labels), (test_data, test_labels) = load_tf_mnist()
LABELS=[0, 1, 2, 3, 4]

test_data = data[np.concatenate([
    np.where(labels == i)[0][3000:3100]
    for i in LABELS
])]

data = data[np.concatenate([
    np.where(labels == i)[0][0:3000]
    for i in LABELS
])]
test_labels = labels[np.concatenate([
    np.where(labels == i)[0][3000:3100]
    for i in LABELS
])]

labels = labels[np.concatenate([
    np.where(labels == i)[0][0:3000]
    for i in LABELS
])]

all_labels=np.concatenate([labels, test_labels])
all_data=np.concatenate([data, test_data])


def benchmark_umaps(device="gpu"):
    fig, axs = plt.subplots(1, 4, figsize=(9, 3), sharey=True)

    if device == "gpu":
        reductor = cuUMAP(n_components=2, n_neighbors=60, min_dist=0.0, random_state=42, hash_input=True).fit(data)
    elif device == "cpu":
        reductor = UMAP(n_components=2, n_neighbors=60, min_dist=0.0, random_state=42).fit(data)
        
    plot_umap(fig, axs[0], reductor.transform(data), labels, title=f"{device}, Train")
    plot_umap(fig, axs[1], reductor.transform(test_data), test_labels, title=f"{device}, Test")
    embedding=reductor.transform(all_data)
    
    plot_umap(fig, axs[2], embedding, all_labels, title=f"{device}, All")    
    test_data_when_processed_with_train=embedding[(-100*len(LABELS)):, :]
    
    plot_umap(fig, axs[3], test_data_when_processed_with_train, [0 for _ in range(len(test_data_when_processed_with_train))] , title=f"{device}, Test only")
    return fig


gpu_fig=benchmark_umaps(device="gpu")
cpu_fig=benchmark_umaps(device="cpu")

cpu_fig.savefig(os.path.join(FIGURES_PATH, "cpu.png"))
gpu_fig.savefig(os.path.join(FIGURES_PATH, "gpu.png"))

[Figures: gpu.png and cpu.png, each with four panels: Train, Test, All, Test only]

Expected behavior
I expected the gpu, All plot to show the same clusters seen in gpu, Train and gpu, Test. In fact, it should be identical to overlaying both in a single plot, as happens with the cpu plot. Instead, the output is scrambled.
This does not depend on the value of hash_input (which only affects the output of gpu, Train).

A closer look at where the Test data points are projected when processed together with the Train data points (the Test only plots) shows that their positions differ from those obtained when they are processed alone. That is, the 4th plot in each row should be identical or very similar to the 2nd one (just in a single color). It is not, especially for the gpu implementation.
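
The comparison between the 2nd and 4th panels can also be made numerically rather than by eye. A minimal sketch, assuming both embeddings hold the same test points in the same row order; `row_displacements` is a hypothetical helper, not part of cuML or umap-learn, and the coordinates below are toy values rather than real output:

```python
import math


def row_displacements(emb_a, emb_b):
    """Euclidean distance between matching rows of two embeddings."""
    if len(emb_a) != len(emb_b):
        raise ValueError("embeddings must have the same number of rows")
    return [math.dist(p, q) for p, q in zip(emb_a, emb_b)]


# Toy stand-ins for reductor.transform(test_data) and for the last rows of
# reductor.transform(all_data). Real (n, 2) arrays could be passed after
# converting them with .tolist().
test_alone = [(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)]
test_with_train = [(0.0, 0.0), (1.0, 1.0), (5.0, 4.0)]

shifts = row_displacements(test_alone, test_with_train)
print(max(shifts))  # largest movement of any single test point
```

A large maximum relative to the spread of the embedding would indicate that the test points land somewhere else when transformed together with the training data.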

Environment details (please complete the following information):

  • Environment location: Bare metal
  • Linux Distro/Architecture: Ubuntu 22.04, CPU = AMD Ryzen Threadripper PRO 5975WX (32 cores, 64 threads)
  • GPU Model/Driver: NVIDIA RTX A6000/PCIe/SSE2 and driver 525.105.17
  • CUDA: 11.8
  • Method of cuDF & cuML install: pip inside mamba
  • Python 3.10.11
mamba --version
mamba 1.4.1
conda 23.1.0
# taken from https://docs.rapids.ai/install
pip install cudf-cu11 cuml-cu11 --extra-index-url=https://pypi.nvidia.com
@antortjim antortjim added ? - Needs Triage Need team to review and classify bug Something isn't working labels Jun 21, 2023
Author

antortjim commented Jun 22, 2023

Inspired by #5473, I tried two things:

  1. Casting my data to np.float32, to rule out numerical precision artifacts
  2. Converting my data to cupy arrays before passing it to UMAP
import os.path
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.keras.datasets import mnist # 60k datapoints
import cupy
from umap import UMAP
from cuml.manifold.umap import UMAP as cuUMAP

FIGURES_PATH="umap_figures"
os.makedirs(FIGURES_PATH, exist_ok=True)


def load_tf_mnist():
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    # We'll just use the training data for simplicity
    images = np.reshape(train_images, (len(train_images), -1))
    test_images = np.reshape(test_images, (len(test_images), -1))

    images=images.astype(np.float32)
    test_images=test_images.astype(np.float32)
    

    return (images, train_labels), (test_images, test_labels)

def plot_umap(fig, ax, embedding, labels, title, xlim=[-15, 15], ylim=[-10, 10]):
    im=ax.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='Spectral', s=5)
    # ax.set_aspect('equal', 'datalim')
    ax.set_title(title, fontsize=24)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    
    n_labels=len(set(labels))
    if n_labels!=1:
        fig.colorbar(im, boundaries=np.arange(n_labels+1)-0.5).set_ticks(np.arange(n_labels))

(data, labels), (test_data, test_labels) = load_tf_mnist()
LABELS=[0, 1, 2, 3, 4]

test_data = data[np.concatenate([
    np.where(labels == i)[0][3000:3100]
    for i in LABELS
])]

data = data[np.concatenate([
    np.where(labels == i)[0][0:3000]
    for i in LABELS
])]
test_labels = labels[np.concatenate([
    np.where(labels == i)[0][3000:3100]
    for i in LABELS
])]

labels = labels[np.concatenate([
    np.where(labels == i)[0][0:3000]
    for i in LABELS
])]

all_labels=np.concatenate([labels, test_labels])
all_data=np.concatenate([data, test_data])


def benchmark_umaps(data, test_data, all_data, device="gpu"):
    fig, axs = plt.subplots(1, 4, figsize=(9, 3), sharey=True)

    if device == "gpu":
        data = cupy.array(data)
        test_data = cupy.array(test_data)
        all_data = cupy.array(all_data)        
        reductor = cuUMAP(n_components=2, n_neighbors=60, min_dist=0.0, random_state=42, hash_input=True).fit(data)
    
    elif device == "cpu":
        reductor = UMAP(n_components=2, n_neighbors=60, min_dist=0.0, random_state=42).fit(data)

    embedding1=reductor.transform(data)
    embedding2=reductor.transform(test_data)
    embedding3=reductor.transform(all_data)
    embedding4=embedding3[(-100*len(LABELS)):, :]

    if device == "gpu":
        embedding1=embedding1.get()
        embedding2=embedding2.get()
        embedding3=embedding3.get()
        embedding4=embedding4.get()
        

    plot_umap(fig, axs[0], embedding1, labels, title=f"{device}, Train")
    plot_umap(fig, axs[1], embedding2, test_labels, title=f"{device}, Test")
    plot_umap(fig, axs[2], embedding3, all_labels, title=f"{device}, All")    
    plot_umap(fig, axs[3], embedding4, [0 for _ in range(len(embedding4))] , title=f"{device}, Test only")
    return fig


gpu_fig=benchmark_umaps(data, test_data, all_data, device="gpu")
cpu_fig=benchmark_umaps(data, test_data, all_data, device="cpu")

gpu_fig.savefig(os.path.join(FIGURES_PATH, "gpu.png"))
cpu_fig.savefig(os.path.join(FIGURES_PATH, "cpu.png"))

Unfortunately, I still have the same issues.

@viclafargue
Contributor

cuML and umap-learn may not use the exact same default values for some of the hyperparameters. Could you double check that your code is running the same parameters for both, especially for things like n_epochs or learning_rate? This could also be the result of a genuine issue in the code though. Regarding the dissimilarity of CPU vs GPU clusters, it is expected as cuML does not reproduce the umap-learn output.
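
One way to rule out mismatched defaults is to pass every relevant hyperparameter explicitly to both constructors. A sketch under stated assumptions: `SHARED_PARAMS` and `make_reducer` are hypothetical names, and the n_epochs and learning_rate values are illustrative only, not recommended settings. The other values are the ones used in the reproduction script above.

```python
# Hyperparameters passed identically to both implementations.
# n_epochs and learning_rate are illustrative values (assumptions).
SHARED_PARAMS = dict(
    n_components=2,
    n_neighbors=60,
    min_dist=0.0,
    random_state=42,
    n_epochs=500,
    learning_rate=1.0,
)


def make_reducer(device, **overrides):
    """Build an unfitted UMAP object for "gpu" (cuML) or "cpu" (umap-learn)."""
    params = {**SHARED_PARAMS, **overrides}
    if device == "gpu":
        # Imported lazily so the CPU path works without cuML installed.
        from cuml.manifold.umap import UMAP as cuUMAP
        return cuUMAP(hash_input=True, **params)
    if device == "cpu":
        from umap import UMAP
        return UMAP(**params)
    raise ValueError(f"unknown device: {device!r}")
```

With this, `make_reducer("gpu").fit(data)` and `make_reducer("cpu").fit(data)` differ only in the implementation, not in the settings.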

Author

antortjim commented Jun 22, 2023

@viclafargue thank you so much for your input. Indeed, increasing n_epochs radically improves the output!
This is true regardless of whether hash_input is set to True (first row) or False (second row).
[Two images: embeddings obtained with different n_epochs values]

The next obvious question is how to choose n_epochs so that it is high enough.
It would also be nice to know how many epochs actually run when n_epochs is None (the default). I checked the n_epochs attribute of the UMAP object, but it is None. From the plots, the effective value seems close to 100, since the results for None and 100 are very similar.
The training dataset has 15000 rows and 784 columns, and according to the docs (https://docs.rapids.ai/api/cuml/nightly/api/#umap):

n_epochs: int (optional, default None)
The number of training epochs to be used in optimizing the low dimensional embedding. Larger values result in more accurate embeddings. If None is specified a value will be selected based on the size of the input dataset (200 for large datasets, 500 for small).

But it's not clear what counts as a small or a large dataset. This is not explained in the CPU implementation docs either (https://umap-learn.readthedocs.io/en/latest/api.html).
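
The quoted docs give the two default values but not the cut-off between them. The sketch below only encodes that documented 500/200 rule. The 10,000-sample threshold is an assumption: it appears to be the value used in the umap-learn source, cuML may differ, and neither is confirmed in this thread. `default_fit_epochs` is a hypothetical helper.

```python
def default_fit_epochs(n_samples, small_dataset_threshold=10_000):
    """Documented default: 500 epochs for small datasets, 200 for large.

    small_dataset_threshold is an ASSUMPTION, not taken from the docs;
    check the source of the installed version before relying on it.
    """
    return 500 if n_samples <= small_dataset_threshold else 200


# The training set in this issue has 15000 rows.
print(default_fit_epochs(15_000))
```

Setting n_epochs explicitly avoids depending on this rule at all.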

PS: Explicitly using cupy.array and np.float32 is not needed.
PPS: Indeed, I don't expect the CPU and GPU plots to be identical, just as I don't expect two GPU plots to be identical if random_state is not set to the same value for both. However, I think they should show more or less the same overall trend.

@antortjim
Author

Dear @viclafargue I am still very interested in this. Is there some logic to how much n_epochs should be set to? Or is it only trial and error? Thanks!

@viclafargue
Contributor

I don't think there is a solid rule for the best value to use. If you want to know more about this, it might be interesting to read Leland McInnes's papers. In general terms, though, the goal of dimensionality reduction is to obtain a representative output, so adjusting hyperparameters comes down to testing the trustworthiness of the embeddings for different values. cuml.metrics.trustworthiness.trustworthiness offers a GPU-accelerated calculation of this value.
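
Testing different values can be organised as a small sweep. This is a sketch, not a recommended procedure: `sweep_n_epochs`, `cuml_fit_embed` and `cuml_score` are hypothetical helpers, and the UMAP settings are copied from the reproduction script. The fitting and scoring functions are passed in, so the same loop works with umap-learn or any other scorer.

```python
def sweep_n_epochs(fit_embed, score, X, candidates):
    """Fit one embedding per candidate n_epochs and score each of them."""
    return {n: score(X, fit_embed(X, n)) for n in candidates}


def cuml_fit_embed(X, n_epochs):
    # Imported lazily so the sweep itself has no GPU dependency.
    from cuml.manifold.umap import UMAP as cuUMAP
    reducer = cuUMAP(n_components=2, n_neighbors=60, min_dist=0.0,
                     random_state=42, n_epochs=n_epochs)
    return reducer.fit_transform(X)


def cuml_score(X, embedding):
    from cuml.metrics.trustworthiness import trustworthiness
    return trustworthiness(X, embedding)
```

For example, `sweep_n_epochs(cuml_fit_embed, cuml_score, data, [100, 200, 500, 1000])` returns a dict mapping each n_epochs value to its trustworthiness.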

cc @cjnolet who might have more info about this

@antortjim
Author

Wouldn't it make sense to use the trustworthiness metric as a measure of "convergence", akin to the loss function during backpropagation, so that cuml UMAP runs until either the number of epochs provided is exhausted OR the trustworthiness reaches a high enough value (i.e. early stopping)?
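
The idea can be approximated from the outside, without changes to cuML, by refitting with growing epoch budgets and stopping at the first one that reaches a target score. This is only a sketch of the proposal: it refits from scratch for every budget, so it is not early stopping inside the optimizer, and `first_n_epochs_reaching` is a hypothetical helper.

```python
def first_n_epochs_reaching(fit_embed, score, X, schedule, target):
    """Return (n_epochs, score) for the first budget whose score >= target.

    Falls back to the best-scoring budget if the target is never reached,
    and returns None for an empty schedule.
    """
    best = None
    for n_epochs in schedule:
        current = score(X, fit_embed(X, n_epochs))
        if best is None or current > best[1]:
            best = (n_epochs, current)
        if current >= target:
            return n_epochs, current
    return best
```

Here `fit_embed(X, n_epochs)` is assumed to return an embedding and `score(X, embedding)` a number such as the trustworthiness; both are supplied by the caller.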

@antortjim
Author

Also, I am not sure what trustworthiness value is adequate. Would 0.8, for example, be high enough? I see very significant differences in output quality between trustworthiness values of 0.8 and 0.9, i.e. I could not tell from the trustworthiness alone that the quality is so different.

Member

cjnolet commented Jul 11, 2023

@antortjim in general, the relationship that the trustworthiness score has to UMAP's objective is similar to the one a score like categorical accuracy has to categorical cross-entropy. Accuracy, for example, is a general score for classifiers, while categorical cross-entropy is specific to an algorithm's objective. We can use accuracy to measure the quality of results for any categorical classifier, and we can use trustworthiness to measure the quality of any manifold learning algorithm. Trustworthiness measures the degree to which the neighborhoods in the output space preserve the local neighborhood structure of the input space (by literally comparing knn results). In general, you'll expect to see reasonably well-preserved results around 0.91 and above, while a trustworthiness of 0.8 and below might preserve some neighborhoods while others could still look like random noise.
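
To make the "comparing knn results" part concrete, here is a minimal pure-Python sketch of the score as it is commonly defined (the definition documented for scikit-learn's trustworthiness). It is meant for toy inputs only and is not the cuML implementation. Every point that is among the k nearest neighbors in the embedding but not in the input space is penalised by how far down the input-space ranking it sits.

```python
import math


def _neighbors_by_distance(points, i):
    """Indices of all other points, nearest to points[i] first."""
    others = (j for j in range(len(points)) if j != i)
    return sorted(others, key=lambda j: math.dist(points[i], points[j]))


def trustworthiness(X, X_embedded, n_neighbors=5):
    n, k = len(X), n_neighbors
    if not k < n / 2:
        raise ValueError("n_neighbors must be less than n_samples / 2")
    penalty = 0
    for i in range(n):
        # Rank (1 = nearest) of every other point in the input space.
        input_rank = {j: r for r, j in
                      enumerate(_neighbors_by_distance(X, i), start=1)}
        # Embedded-space neighbors that were not input-space neighbors
        # contribute (input rank - k); true neighbors contribute 0.
        for j in _neighbors_by_distance(X_embedded, i)[:k]:
            penalty += max(0, input_rank[j] - k)
    return 1.0 - 2.0 / (n * k * (2 * n - 3 * k - 1)) * penalty


# Ten points on a line: a faithful embedding versus an interleaved one.
X = [(float(i),) for i in range(10)]
scrambled = [(float(v),) for v in (0, 5, 1, 6, 2, 7, 3, 8, 4, 9)]
print(trustworthiness(X, X, n_neighbors=2))
print(trustworthiness(X, scrambled, n_neighbors=2))
```

The faithful embedding scores exactly 1.0 and the interleaved one scores lower, because points end up next to others that were far away in the input.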

All that being said, there is a known bug in our Laplacian eigenmaps (aka spectral) initialization solver, and the resulting embedding quality can sometimes (but not always) be improved by increasing the number of epochs. Please note that this is in addition to, and not mutually exclusive of, what Victor pointed out: the initialization and parallelization of the algorithms still cause slightly different behaviors, which can produce slightly different results with the same parameter settings.

One thing to try might be to set init="random" to ignore differences that are specific to the spectral initialization. Another common technique that could improve the quality with spectral initialization is to run the points through a PCA before calling UMAP.
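
Both suggestions can be combined in one helper. A sketch only: `fit_gpu_umap` is a hypothetical function, and the choice of how many PCA components to keep is left to the caller because no value is given in this thread.

```python
def fit_gpu_umap(X, init="random", pca_components=None, **umap_kwargs):
    """Fit cuML UMAP, optionally after a PCA step.

    Returns (pca, reducer); pca is None when no PCA step was requested.
    """
    if init not in ("random", "spectral"):
        raise ValueError(f"init must be 'random' or 'spectral', got {init!r}")

    # Imported lazily so the argument check above runs without a GPU.
    from cuml.decomposition import PCA
    from cuml.manifold.umap import UMAP as cuUMAP

    pca = None
    if pca_components is not None:
        pca = PCA(n_components=pca_components)
        X = pca.fit_transform(X)

    reducer = cuUMAP(init=init, **umap_kwargs).fit(X)
    return pca, reducer
```

When a PCA step is used, new data has to go through `pca.transform` before `reducer.transform`, otherwise the reducer receives inputs of the wrong dimensionality.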

@antortjim
Author

Thank you @cjnolet, and sorry for getting back here late. My take-home message, then, is that trustworthiness cannot be used as a measure of convergence but rather as a measure of correctness, and that to ensure convergence I need to pass a high enough n_epochs.
