Skip to content

Commit

Permalink
feat(ui): add reference multiclass model quality (#69)
Browse files Browse the repository at this point in the history
* feat(ui): add reference multiclass model quality

* feat(ui): fix background style of multiclass table
  • Loading branch information
dvalleri authored Jul 2, 2024
1 parent f062c1d commit b505692
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 100 deletions.
6 changes: 4 additions & 2 deletions ui/src/components/charts/confusion-matrix-chart/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ echarts.use([
VisualMapComponent,
]);

function ConfusionMatrix({ dataset, labelClass, colors }) {
function ConfusionMatrix({
dataset, labelClass, colors, height = '20rem',
}) {
if (!dataset) return false;

const handleOnChartReady = (echart) => {
Expand All @@ -31,7 +33,7 @@ function ConfusionMatrix({ dataset, labelClass, colors }) {
echarts={echarts}
onChartReady={handleOnChartReady}
option={confusionMatrixOptions(dataset, labelClass, colors)}
style={{ height: '20rem' }}
style={{ height }}
/>
)}
size="small"
Expand Down
2 changes: 1 addition & 1 deletion ui/src/components/charts/confusion-matrix-chart/options.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export default function confusionMatrixOptions(dataset, labelClass, colors) {
...commonChartOptions.xAxisOptions.categoryType(labelClass.xAxisLabel),
...commonChartOptions.gridOptions.heatmapChart(),
...commonChartOptions.commonOptions.heatmapChart(),
...commonChartOptions.visualMapOptions.heatmapChart(dataMax, colors),
...commonChartOptions.visualMapOptions.heatmapChart(dataMax, colors, '250rem'),
series: {
...commonChartOptions.seriesOptions.heatmapChart(heatmapData),
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ const useGetPredictions = () => {

const probabilityValidTypes = {
[ModelTypeEnum.BINARY_CLASSIFICATION]: ['float', 'double'],
[ModelTypeEnum.MULTI_CLASSIFICATION]: ['float', 'double'],
[ModelTypeEnum.REGRESSION]: [],
[ModelTypeEnum.MULTI_CLASSIFICATION]: ['float', 'double', 'string'],
[ModelTypeEnum.REGRESSION]: ['float', 'double'],
};
const useGetProbabilities = () => {
const { useFormbitStepOne, useFormbit } = useModalContext();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { CHART_COLOR } from '@Helpers/common-chart-options';
import { numberFormatter } from '@Src/constants';

export default [
{
title: '',
key: 'className',
dataIndex: 'className',
render: (label) => <div className="font-[var(--coo-font-weight-bold)]">{label}</div>,
},
{
title: 'Reference Precision',
key: 'precision',
dataIndex: 'precision',
align: 'right',
width: '10rem',
onCell: () => ({ style: { background: CHART_COLOR.REFERENCE_LIGHT } }),
render: (precision) => numberFormatter().format(precision),
},
{
title: 'Reference Recall',
key: 'recall',
dataIndex: 'recall',
align: 'right',
width: '10rem',
onCell: () => ({ style: { background: CHART_COLOR.REFERENCE_LIGHT } }),
render: (recall) => numberFormatter().format(recall),
},
{
title: 'Reference F1-Score',
key: 'fMeasure',
dataIndex: 'fMeasure',
align: 'right',
width: '10rem',
onCell: () => ({ style: { background: CHART_COLOR.REFERENCE_LIGHT } }),
render: (fMeasure) => numberFormatter().format(fMeasure),
},
{
title: 'Reference True Positive Rate',
key: 'truePositiveRate',
dataIndex: 'truePositiveRate',
align: 'right',
width: '10rem',
onCell: () => ({ style: { background: CHART_COLOR.REFERENCE_LIGHT } }),
render: (truePositiveRate) => numberFormatter().format(truePositiveRate),
},
{
title: 'Reference False Positive Rate',
key: 'falsePositiveRate',
dataIndex: 'falsePositiveRate',
align: 'right',
width: '10rem',
onCell: () => ({ style: { background: CHART_COLOR.REFERENCE_LIGHT } }),
render: (falsePositiveRate) => numberFormatter().format(falsePositiveRate),
},
];
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import { useGetReferenceModelQualityQueryWithPolling } from '@Src/store/state/models/polling-hook';
import { DataTable } from '@radicalbit/radicalbit-design-system';
import columns from './columns.jsx';

function ClassTableMetrics() {
const { data } = useGetReferenceModelQualityQueryWithPolling();
const classMetrics = data?.modelQuality.classMetrics.map(({
className, metrics: {
precision, falsePositiveRate, recall, truePositiveRate, fMeasure,
},
}) => ({
className,
precision,
falsePositiveRate,
truePositiveRate,
recall,
fMeasure,
})) ?? [];

return (
<DataTable
columns={columns}
dataSource={classMetrics}
modifier="m-4"
pagination={false}
rowKey={({ label }) => label}
scroll={{ y: '32rem' }}
size="small"
/>
);
}

export default ClassTableMetrics;

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,92 +1,30 @@
import JobStatus from '@Components/JobStatus';
import ConfusionMatrix from '@Components/charts/confusion-matrix-chart';
import { MODEL_QUALITY_FIELD } from '@Container/models/Details/constants';
import { JOB_STATUS } from '@Src/constants';
import { CHART_COLOR } from '@Helpers/common-chart-options';
import { JOB_STATUS, numberFormatter } from '@Src/constants';
import { useGetReferenceModelQualityQueryWithPolling } from '@State/models/polling-hook';
import {
Board, DataTable, SectionTitle, Spinner,
Board,
SectionTitle,
Spinner,
} from '@radicalbit/radicalbit-design-system';
import { memo } from 'react';
import { CHART_COLOR } from '@Helpers/common-chart-options';
import columns from './columns';
import ClassTableMetrics from './class-table-metrics';

function MultiClassificationModelQualityMetrics() {
const { data, isLoading } = useGetReferenceModelQualityQueryWithPolling();

const jobStatus = data?.jobStatus;

if (jobStatus === JOB_STATUS.SUCCEEDED) {
const leftTableData = data ? [
{ label: MODEL_QUALITY_FIELD.ACCURACY, value: data.modelQuality.accuracy },
{ label: MODEL_QUALITY_FIELD.PRECISION, value: data.modelQuality.precision },
{ label: MODEL_QUALITY_FIELD.RECALL, value: data.modelQuality.recall },
{ label: MODEL_QUALITY_FIELD.F1, value: data.modelQuality.f1 },
] : [];

const centerTableData = data ? [
{ label: 'False positive rate', value: data.modelQuality.falsePositiveRate },
{ label: 'True positive rate', value: data.modelQuality.truePositiveRate },
] : [];

const rightTableData = data ? [
{ label: MODEL_QUALITY_FIELD.AREA_UNDER_ROC, value: data.modelQuality.areaUnderRoc },
{ label: MODEL_QUALITY_FIELD.AREA_UNDER_PR, value: data.modelQuality.areaUnderPr },
] : [];

const confusionMatrixLabel = {
xAxisLabel: ['Predicted: 1', 'Predicted: 0'],
yAxisLabel: ['Actual: 0', 'Actual: 1'],
};

const confusionMatrixData = [
[data.modelQuality.truePositiveCount, data.modelQuality.falsePositiveCount],
[data.modelQuality.falseNegativeCount, data.modelQuality.trueNegativeCount],
];

return (
<Spinner spinning={isLoading}>
<div className="flex flex-col gap-4 py-4">
<Board
header={<SectionTitle size="small" title="Performance metrics" />}
main={(
<div className="flex flew-row gap-4">
<DataTable
columns={columns}
dataSource={leftTableData}
modifier="basis-1/3"
pagination={false}
rowKey={({ label }) => label}
size="small"
/>

<DataTable
columns={columns}
dataSource={centerTableData}
modifier="basis-1/3"
pagination={false}
rowKey={({ label }) => label}
size="small"
/>

<DataTable
columns={columns}
dataSource={rightTableData}
modifier="basis-1/3"
pagination={false}
rowKey={({ label }) => label}
size="small"
/>
</div>
)}
size="small"
type="secondary"
/>

<ConfusionMatrix
colors={[CHART_COLOR.WHITE, CHART_COLOR.REFERENCE]}
dataset={confusionMatrixData}
labelClass={confusionMatrixLabel}
/>

<GlobalMetrics />

<ClassTableMetrics />

</div>
</Spinner>
);
Expand All @@ -95,4 +33,110 @@ function MultiClassificationModelQualityMetrics() {
return (<JobStatus jobStatus={jobStatus} />);
}

function GlobalMetrics() {
const { data } = useGetReferenceModelQualityQueryWithPolling();
const labels = data?.modelQuality.classes ?? [];
const confusionMatrixData = data?.modelQuality.globalMetrics.confusionMatrix ?? [];

const confusionMatrixLabel = {
xAxisLabel: labels,
yAxisLabel: labels.toReversed(),
};

return (
<div className="flex flex-row gap-4">
<div className="flex flex-col gap-4 basis-1/6">
<AccuracyCounter />

<F1ScoreCounter />

<ClassCounter />
</div>

<div className="w-full">
<ConfusionMatrix
colors={[CHART_COLOR.WHITE, CHART_COLOR.REFERENCE]}
dataset={confusionMatrixData}
height="36rem"
labelClass={confusionMatrixLabel}
/>
</div>
</div>
);
}

function AccuracyCounter() {
const { data } = useGetReferenceModelQualityQueryWithPolling();
const accuracy = data?.modelQuality.globalMetrics.accuracy ?? 0;
const accuracyFormatted = numberFormatter().format(accuracy);

return (
<Board
header={<SectionTitle size="small" title="Accuracy" />}
main={(
<div className="flex flex-col h-full items-center justify-center gap-4">

{/* FIXME: inline style */}
<div className="font-bold text-6xl" style={{ fontFamily: 'var(--coo-header-font)' }}>
{accuracyFormatted}
</div>

</div>
)}
modifier="h-full shadow"
size="small"
type="secondary"
/>
);
}

function F1ScoreCounter() {
const { data } = useGetReferenceModelQualityQueryWithPolling();
const f1Score = data?.modelQuality.globalMetrics.f1 ?? 0;
const f1ScoreFormatted = numberFormatter().format(f1Score);

return (
<Board
header={<SectionTitle size="small" title="F1 Score" />}
main={(
<div className="flex flex-col h-full items-center justify-center gap-4">

{/* FIXME: inline style */}
<div className="font-bold text-6xl" style={{ fontFamily: 'var(--coo-header-font)' }}>
{f1ScoreFormatted}
</div>

</div>
)}
modifier="h-full shadow"
size="small"
type="secondary"
/>
);
}

function ClassCounter() {
const { data } = useGetReferenceModelQualityQueryWithPolling();
const classes = data?.modelQuality.classes ?? [];

return (
<Board
header={<SectionTitle size="small" title="Classes" />}
main={(
<div className="flex flex-col h-full items-center justify-center gap-4">

{/* FIXME: inline style */}
<div className="font-bold text-6xl" style={{ fontFamily: 'var(--coo-header-font)' }}>
{classes.length}
</div>

</div>
)}
modifier="h-full shadow"
size="small"
type="secondary"
/>
);
}

export default memo(MultiClassificationModelQualityMetrics);
12 changes: 7 additions & 5 deletions ui/src/helpers/common-chart-options.js
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,21 @@ const lineGridOptions = () => ({

const heatmapGridOptions = () => ({
grid: {
bottom: 80,
bottom: 24,
top: 0,
left: 64,
right: 0,
right: 60,
},
});

const heatmapVisualMapOptions = (dataMax, colors) => {
const heatmapVisualMapOptions = (dataMax, colors, itemHeight) => {
const options = {
visualMap: {
calculable: true,
orient: 'horizontal',
left: 'center',
orient: 'vertical',
right: 'right',
top: 'center',
itemHeight,
},
};

Expand Down

0 comments on commit b505692

Please sign in to comment.