TFMA analyze_raw_data function support with MultiClassConfusionMatrixPlot #162
Comments
The official documentation shows that multi-class classification metrics are supported by TFMA. Also, Thank you!
@singhniraj08 thanks for getting back! I do see that If it is supported:
Thanks for bringing this to our attention. This looks like a bug in some code that is attempting to construct pyarrow.RecordBatches in a special way. I suspect this was intended to limit the supported types to avoid problems downstream, but it seems that it is too strict right now. It is possible that this may lead to more cryptic errors downstream, but at least for the toy DataFrame you've provided, I was able to get things working by just directly calling
I'll look into supporting this case more generically, but I figured I'd share this workaround now to unblock you.
System information
provided in TensorFlow Model Analysis): No
pip install tensorflow-model-analysis
Describe the problem
I am currently trying to get `tfma.analyze_raw_data` to work with `MultiClassConfusionMatrixPlot`, which has multiple prediction values per record. Is this not supported? I will be happy to provide any further details or run any further tests.
Details
Currently, `tfma.analyze_raw_data` does not seem to work with metrics for multi-class classification tasks (e.g., `tfma.metrics.MultiClassConfusionMatrixPlot`). However, I do not see this limitation documented anywhere.
The prediction column for a multi-class classification task will be a Series whose values are lists or arrays (e.g., `pd.DataFrame({'predictions': [[0.2, .3, .5]], 'label': [1]})`).
The `tfma.analyze_raw_data` function uses `tfx_bsl.arrow.DataFrameToRecordBatch` to convert a pandas DataFrame to an Arrow RecordBatch. The problem, however, is that it encodes columns with a dtype of `object` as `pyarrow.Binary`. Since a column that has lists or arrays as values has a dtype of `object`, these columns are encoded as `pyarrow.Binary` instead of the relevant pyarrow list-like type.
Source code / logs
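The mismatch described under Details can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions: the helper names (`pandas_like_dtype`, `arrow_type_from_dtype`, `arrow_type_from_values`) and the string type names are hypothetical stand-ins, not the tfx_bsl or pyarrow API. It assumes, as reported above, that every `object` column is mapped to binary, and contrasts that with choosing a type by inspecting the values.

```python
# Toy model only: hypothetical helpers, no pandas / pyarrow / tfx_bsl calls.
# Type names are plain strings that mimic Arrow's naming.

# Dtype-keyed lookup, as the issue describes: every object column -> binary.
DTYPE_TO_ARROW = {"int64": "int64", "float64": "double", "object": "binary"}


def pandas_like_dtype(values):
    """Roughly mimic pandas: anything that is not purely numeric is 'object'."""
    if all(isinstance(v, int) for v in values):
        return "int64"
    if all(isinstance(v, (int, float)) for v in values):
        return "float64"
    return "object"


def arrow_type_from_dtype(values):
    """Pick the target type from the column dtype alone."""
    return DTYPE_TO_ARROW[pandas_like_dtype(values)]


def arrow_type_from_values(values):
    """Pick the target type by looking at the first value (non-empty column)."""
    first = values[0]
    if isinstance(first, (list, tuple)):
        # Recurse into the nested values to find the element type.
        return "list<" + arrow_type_from_values(list(first)) + ">"
    if isinstance(first, (bytes, str)):
        return "binary"
    return DTYPE_TO_ARROW[pandas_like_dtype(values)]


# Same shape as the example DataFrame in the issue.
columns = {"predictions": [[0.2, 0.3, 0.5]], "label": [1]}

by_dtype = {name: arrow_type_from_dtype(col) for name, col in columns.items()}
by_value = {name: arrow_type_from_values(col) for name, col in columns.items()}

print(by_dtype)  # predictions is treated as binary
print(by_value)  # predictions is treated as a list of doubles
```

In this model the `label` column gets the same type either way; only the list-valued `predictions` column differs, which matches the behavior reported above.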
Error
Temporary fix
If I change/patch `tfx_bsl.arrow.DataFrameToRecordBatch` as follows, it seems to work, but I doubt this is a solution.