
Using a DataFrame with a scrambled index and color in px.scatter_matrix() yields mismatched/null colors #4788

@carschandler

When a DataFrame with a scrambled (and optionally sparse) index, like the ones produced by sklearn.model_selection.train_test_split(), is passed to a px function together with the color argument, process_args_into_dataframe() splits the DataFrame into a dict of Series. The column corresponding to the color argument has its index reset via to_unindexed_series(), but the other columns keep their original indices, and the Series are re-joined on their index labels when the DataFrame is rebuilt. For sparse, shuffled indices, the reset labels of the color column are joined with the original, sparse labels of the other columns: labels that happen to match are joined (though not necessarily to the right rows), and the ones that don't match end up as NaN. For non-sparse, shuffled indices, the result is arguably worse: every label matches, but to the wrong values, so nothing indicates that something went wrong other than the mismatched colors. In a simple test this doesn't seem to affect px.scatter() (it doesn't appear to return any df_output at all), but I suspect other px functions are affected as well.
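The misalignment can be reproduced with pandas alone. The sketch below is not plotly's code: it assumes that to_unindexed_series() behaves like Series.reset_index(drop=True) (a simplification of an internal plotly.express helper), and rebuild_with_reset_color is a hypothetical stand-in for the relevant step of process_args_into_dataframe(), where only the color column is reset before the DataFrame is reassembled.

```python
import numpy as np
import pandas as pd


def rebuild_with_reset_color(df: pd.DataFrame, color: str) -> pd.DataFrame:
    # Hypothetical stand-in for one step of process_args_into_dataframe():
    # only the color column has its index reset (assumed equivalent to what
    # to_unindexed_series() does); every other column keeps its original labels.
    columns = {
        name: df[name].reset_index(drop=True) if name == color else df[name]
        for name in df.columns
    }
    # Building a DataFrame from a dict of Series aligns rows by index label.
    return pd.DataFrame(columns)


n = 10
# v_color equals v1 in every row, so any misalignment is easy to spot.
df = pd.DataFrame({"v1": np.arange(n, dtype=float), "v_color": np.arange(n, dtype=float)})

# Shuffled, non-sparse index (all labels 0..9 present): every label matches,
# but each color lands on the wrong row and nothing becomes NaN.
full = rebuild_with_reset_color(df.iloc[::-1], "v_color")
print((full["v1"] != full["v_color"]).sum(), full.isna().sum().sum())  # 10 0

# Shuffled, sparse index (5 of the 10 labels): labels that don't line up
# produce NaN ("null") colors and extra rows.
sparse = rebuild_with_reset_color(df.iloc[[7, 3, 9, 1, 5]], "v_color")
print(len(sparse), sparse["v_color"].isna().sum())  # 8 3
```

In the sparse case the two surviving label matches (1 and 3) are also paired with the wrong colors, which matches the "joined, but not necessarily to the right rows" behaviour described above.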

I'm not entirely sure what these calls to to_unindexed_series() are for, since the indices would be re-joined correctly if they were not reset. Commenting out this code:

df_output[col_name] = to_unindexed_series(
    df_input[argument], col_name
)
fixes the issue for my use case, but I don't know what other problems this might introduce. I can try removing all calls to the function and running the tests to see whether that causes any regressions.
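For comparison, here is the same reassembly with no index reset at all, which is the effect of commenting out the to_unindexed_series() call. This is a pandas-only sketch; rebuild_without_reset is a hypothetical name. Because every Series keeps its original labels, the DataFrame constructor re-joins them correctly for both shuffled and sparse indices.

```python
import numpy as np
import pandas as pd


def rebuild_without_reset(df: pd.DataFrame) -> pd.DataFrame:
    # Hypothetical sketch: split into a dict of Series *without* resetting any
    # index, then reassemble. Every Series keeps its original labels.
    columns = {name: df[name] for name in df.columns}
    # All Series share identical labels, so alignment puts each value back
    # on its own row.
    return pd.DataFrame(columns)


n = 10
# v_color equals v1 in every row, so any misalignment is easy to spot.
df = pd.DataFrame({"v1": np.arange(n, dtype=float), "v_color": np.arange(n, dtype=float)})

shuffled = rebuild_without_reset(df.iloc[::-1])           # shuffled, non-sparse
sparse = rebuild_without_reset(df.iloc[[7, 3, 9, 1, 5]])  # shuffled, sparse

for out in (shuffled, sparse):
    # Row count is preserved, there are no NaNs, and every color matches its row.
    print(len(out), out.isna().sum().sum(), (out["v1"] == out["v_color"]).all())
# 10 0 True
# 5 0 True
```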

Re-joining the DataFrame on the original indices and never resetting them seems like the best option to me, but I'm not sure whether other parts of the code rely on certain columns having ordered, non-sparse indices. Maybe a single reset_index() on the entire DataFrame up front would be best, so that we don't have to worry about some columns being reset while others aren't?
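The "reset everything up front" alternative can be sketched the same way. Here rebuild_with_upfront_reset (a hypothetical name, not a plotly function) calls DataFrame.reset_index(drop=True) on the whole frame before splitting it, so the later reset of the color column no longer changes anything relative to the other columns. This illustrates the idea in pandas only; it is not a patch to plotly.express.

```python
import numpy as np
import pandas as pd


def rebuild_with_upfront_reset(df: pd.DataFrame, color: str) -> pd.DataFrame:
    # Hypothetical sketch of "reset_index on the entire DataFrame up front".
    df = df.reset_index(drop=True)  # every column now shares labels 0..n-1
    columns = {
        # Resetting the color column again is now harmless: its labels are
        # already 0..n-1, exactly like every other column's.
        name: df[name].reset_index(drop=True) if name == color else df[name]
        for name in df.columns
    }
    return pd.DataFrame(columns)


n = 10
# v_color equals v1 in every row, so any misalignment is easy to spot.
df = pd.DataFrame({"v1": np.arange(n, dtype=float), "v_color": np.arange(n, dtype=float)})

shuffled = rebuild_with_upfront_reset(df.iloc[::-1], "v_color")
sparse = rebuild_with_upfront_reset(df.iloc[[7, 3, 9, 1, 5]], "v_color")

for out in (shuffled, sparse):
    # Row count is preserved, there are no NaNs, and every color matches its row.
    print(len(out), out.isna().sum().sum(), (out["v1"] == out["v_color"]).all())
# 10 0 True
# 5 0 True
```

One trade-off: drop=True discards the original index labels; calling reset_index() without drop=True would keep them as a regular column if anything downstream needs them. Applying df.reset_index(drop=True) before calling a px function should also work as a user-side workaround, assuming the misalignment comes only from the color column's index reset described above.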

Here's a minimal example:

import numpy as np
import pandas as pd
import plotly.express as px

n_rows_total = 100
n_rows_to_sample = 50

df = pd.DataFrame(
    {
        "v1": np.random.randn(n_rows_total),
        "v2": np.random.randn(n_rows_total),
        "v3": np.random.randn(n_rows_total),
        "v_color": np.random.randint(low=0, high=3, size=n_rows_total).astype(float),
    }
)

# Correct result
px.scatter_matrix(df, color="v_color").show()

# Shuffled, non-sparse index: colors (and color values) are attached to the wrong data points
px.scatter_matrix(df.sample(n_rows_total), color="v_color").show()

# Shuffled, sparse index (50 of 100 rows): some color values become "null"
px.scatter_matrix(df.sample(n_rows_to_sample), color="v_color").show()

Metadata

Assignees: No one assigned
Labels: P3, backlog, bug (something broken)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
