From aef0eb6d44faa4359997546c8be9bd0c59c4ea36 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 5 Oct 2022 17:08:37 +1100 Subject: [PATCH 1/2] fix size of multi source confusion matrix --- extension/src/experiments/model/index.ts | 2 +- extension/src/plots/model/collect.ts | 129 +++++++++++++------ extension/src/test/suite/plots/index.test.ts | 39 +++--- 3 files changed, 112 insertions(+), 58 deletions(-) diff --git a/extension/src/experiments/model/index.ts b/extension/src/experiments/model/index.ts index d1d91d0eab..11faeeaaeb 100644 --- a/extension/src/experiments/model/index.ts +++ b/extension/src/experiments/model/index.ts @@ -39,7 +39,7 @@ import { sum } from '../../util/math' export type StarredExperiments = Record -type SelectedExperimentWithColor = Experiment & { +export type SelectedExperimentWithColor = Experiment & { displayColor: Color selected: true } diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 1440a00539..d53a34f199 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -40,6 +40,7 @@ import { MultiSourceEncoding, unmergeConcatenatedFields } from '../multiSource/collect' +import { StrokeDashEncoding } from '../multiSource/constants' type CheckpointPlotAccumulator = { iterations: Record @@ -502,49 +503,99 @@ export const collectTemplates = (data: PlotsOutput): TemplateAccumulator => { } const updateDatapoints = ( - datapoints: unknown[], + path: string, + revisionData: RevisionData, + selectedRevisions: string[], key: string, fields: string[] ): unknown[] => - datapoints.map(data => { - const obj = data as Record - return { - ...obj, - [key]: mergeFields(fields.map(field => obj[field] as string)) + selectedRevisions + .flatMap(revision => + revisionData?.[revision]?.[path].map(data => { + const obj = data as Record + return { + ...obj, + [key]: mergeFields(fields.map(field => obj[field] as string)) + } + }) + ) + .filter(Boolean) + +const transformMultiSourceMultiViewRevisionData = ( + path: string, + selectedRevisions: string[], + revisionData: RevisionData, + fields: string[], + domain: string[] +) => { + fields.unshift('rev') + + const revisions: string[] = [] + for (const revision of selectedRevisions) { + for (const entry of domain) { + revisions.push([revision, entry].join('::')) } - }) + } + + return { + datapoints: updateDatapoints( + path, + revisionData, + selectedRevisions, + 'rev', + fields + ), + revisions + } +} -const stringifyDatapoints = ( - datapoints: unknown[], - field: string | undefined, - isMultiView: boolean -): string => { +const transformRevisionData = ( + path: string, + selectedRevisions: string[], + revisionData: RevisionData, + isMultiView: boolean, + multiSourceEncodingUpdate: { strokeDash: StrokeDashEncoding } +): { revisions: string[]; datapoints: unknown[] } => { + const field = multiSourceEncodingUpdate.strokeDash?.field if (!field || (!isMultiView && !isConcatenatedField(field))) { - return JSON.stringify(datapoints) + return { + datapoints: selectedRevisions + .flatMap(revision => revisionData?.[revision]?.[path]) + .filter(Boolean), + revisions: selectedRevisions + } } const fields = unmergeConcatenatedFields(field) if (isMultiView) { - fields.unshift('rev') - return JSON.stringify(updateDatapoints(datapoints, 'rev', fields)) + return transformMultiSourceMultiViewRevisionData( + path, + selectedRevisions, + revisionData, + fields, + multiSourceEncodingUpdate.strokeDash.scale.domain + ) } - return JSON.stringify(updateDatapoints(datapoints, field, fields)) + return { + datapoints: updateDatapoints( + path, + revisionData, + selectedRevisions, + field, + fields + ), + revisions: selectedRevisions + } } const fillTemplate = ( template: string, - datapoints: unknown[], - field?: string + datapoints: unknown[] ): TopLevelSpec => { - const isMultiView = isMultiViewPlot(JSON.parse(template)) - return JSON.parse( - template.replace( - '""', - stringifyDatapoints(datapoints, field, isMultiView) - ) + template.replace('""', JSON.stringify(datapoints)) ) as TopLevelSpec } @@ -562,30 +613,26 @@ const collectTemplateGroup = ( const template = templates[path] if (template) { - const datapoints = selectedRevisions - .flatMap(revision => revisionData?.[revision]?.[path]) - .filter(Boolean) - + const isMultiView = isMultiViewPlot(JSON.parse(template)) const multiSourceEncodingUpdate = multiSourceEncoding[path] || {} - - const content = extendVegaSpec( - fillTemplate( - template, - datapoints, - multiSourceEncodingUpdate.strokeDash?.field - ), - size, - { - ...multiSourceEncodingUpdate, - color: revisionColors - } + const { datapoints, revisions } = transformRevisionData( + path, + selectedRevisions, + revisionData, + isMultiView, + multiSourceEncodingUpdate ) + const content = extendVegaSpec(fillTemplate(template, datapoints), size, { + ...multiSourceEncodingUpdate, + color: revisionColors + }) + acc.push({ content, id: path, multiView: isMultiViewPlot(content), - revisions: selectedRevisions, + revisions, type: PlotsType.VEGA }) } diff --git a/extension/src/test/suite/plots/index.test.ts b/extension/src/test/suite/plots/index.test.ts index f34f03cbb0..22d3a52147 100644 --- a/extension/src/test/suite/plots/index.test.ts +++ b/extension/src/test/suite/plots/index.test.ts @@ -37,6 +37,7 @@ import { MessageFromWebviewType } from '../../../webview/contract' import { reorderObjectList } from '../../../util/array' import * as Telemetry from '../../../telemetry' import { EventName } from '../../../telemetry/constants' +import { SelectedExperimentWithColor } from '../../../experiments/model' suite('Plots Test Suite', () => { const disposable = Disposable.fn() @@ -761,10 +762,13 @@ suite('Plots Test Suite', () => { }).timeout(WEBVIEW_TEST_TIMEOUT) it('should send the correct data to the webview for flexible plots', async () => { - const { plots, messageSpy, mockPlotsDiff } = await buildPlots( - disposable, - multiSourcePlotsDiffFixture - ) + const { plots, messageSpy, mockPlotsDiff, experiments } = + await buildPlots(disposable, multiSourcePlotsDiffFixture) + + stub(experiments, 'getSelectedRevisions').returns([ + { label: 'workspace' }, + { label: 'main' } + ] as SelectedExperimentWithColor[]) const webview = await plots.showWebview() await webview.isReady() @@ -798,17 +802,6 @@ suite('Plots Test Suite', () => { multiViewSection.entries.map(({ id }: { id: string }) => id) ).to.deep.equal(['dvc.yaml::Confusion-Matrix']) - const [confusionMatrix] = multiViewSection.entries - - const confusionMatrixDatapoints = - ( - confusionMatrix.content.data as { - values: { rev: string }[] - } - )?.values || [] - - expect(confusionMatrixDatapoints.length).to.be.greaterThan(0) - const expectedRevisions = [ `main::${join('evaluation', 'test', 'plots', 'confusion_matrix.json')}`, `workspace::${join( @@ -829,7 +822,21 @@ suite('Plots Test Suite', () => { 'plots', 'confusion_matrix.json' )}` - ] + ].sort() + + const [confusionMatrix] = multiViewSection.entries + + const confusionMatrixDatapoints = + ( + confusionMatrix.content.data as { + values: { rev: string }[] + } + )?.values || [] + + expect(confusionMatrixDatapoints.length).to.be.greaterThan(0) + + expect(confusionMatrix.revisions?.length).to.equal(4) + expect(confusionMatrix.revisions?.sort()).to.deep.equal(expectedRevisions) for (const entry of confusionMatrixDatapoints) { expect(expectedRevisions).to.include(entry.rev) From 00cf8054681116a987d5ac54278b263e3aa83e0f Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 5 Oct 2022 18:14:15 +1100 Subject: [PATCH 2/2] refactor transform revision data --- extension/src/plots/model/collect.ts | 52 +++++++++++++--------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index d53a34f199..e4f42f0fe8 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -521,32 +521,17 @@ const updateDatapoints = ( ) .filter(Boolean) -const transformMultiSourceMultiViewRevisionData = ( - path: string, +const updateRevisions = ( selectedRevisions: string[], - revisionData: RevisionData, - fields: string[], domain: string[] -) => { - fields.unshift('rev') - +): string[] => { const revisions: string[] = [] for (const revision of selectedRevisions) { for (const entry of domain) { - revisions.push([revision, entry].join('::')) + revisions.push(mergeFields([revision, entry])) } } - - return { - datapoints: updateDatapoints( - path, - revisionData, - selectedRevisions, - 'rev', - fields - ), - revisions - } + return revisions } const transformRevisionData = ( @@ -557,7 +542,12 @@ const transformRevisionData = ( multiSourceEncodingUpdate: { strokeDash: StrokeDashEncoding } ): { revisions: string[]; datapoints: unknown[] } => { const field = multiSourceEncodingUpdate.strokeDash?.field - if (!field || (!isMultiView && !isConcatenatedField(field))) { + const isMultiSource = !!field + + const transformNeeded = + isMultiSource && (isMultiView || isConcatenatedField(field)) + + if (!transformNeeded) { return { datapoints: selectedRevisions .flatMap(revision => revisionData?.[revision]?.[path]) @@ -567,15 +557,21 @@ const transformRevisionData = ( } const fields = unmergeConcatenatedFields(field) - if (isMultiView) { - return transformMultiSourceMultiViewRevisionData( - path, - selectedRevisions, - revisionData, - fields, - multiSourceEncodingUpdate.strokeDash.scale.domain - ) + fields.unshift('rev') + return { + datapoints: updateDatapoints( + path, + revisionData, + selectedRevisions, + 'rev', + fields + ), + revisions: updateRevisions( + selectedRevisions, + multiSourceEncodingUpdate.strokeDash.scale.domain + ) + } } return {