diff --git a/extension/src/experiments/model/index.ts b/extension/src/experiments/model/index.ts index d1d91d0eab..11faeeaaeb 100644 --- a/extension/src/experiments/model/index.ts +++ b/extension/src/experiments/model/index.ts @@ -39,7 +39,7 @@ import { sum } from '../../util/math' export type StarredExperiments = Record -type SelectedExperimentWithColor = Experiment & { +export type SelectedExperimentWithColor = Experiment & { displayColor: Color selected: true } diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 1440a00539..e4f42f0fe8 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -40,6 +40,7 @@ import { MultiSourceEncoding, unmergeConcatenatedFields } from '../multiSource/collect' +import { StrokeDashEncoding } from '../multiSource/constants' type CheckpointPlotAccumulator = { iterations: Record @@ -502,49 +503,95 @@ export const collectTemplates = (data: PlotsOutput): TemplateAccumulator => { } const updateDatapoints = ( - datapoints: unknown[], + path: string, + revisionData: RevisionData, + selectedRevisions: string[], key: string, fields: string[] ): unknown[] => - datapoints.map(data => { - const obj = data as Record - return { - ...obj, - [key]: mergeFields(fields.map(field => obj[field] as string)) + selectedRevisions + .flatMap(revision => + revisionData?.[revision]?.[path].map(data => { + const obj = data as Record + return { + ...obj, + [key]: mergeFields(fields.map(field => obj[field] as string)) + } + }) + ) + .filter(Boolean) + +const updateRevisions = ( + selectedRevisions: string[], + domain: string[] +): string[] => { + const revisions: string[] = [] + for (const revision of selectedRevisions) { + for (const entry of domain) { + revisions.push(mergeFields([revision, entry])) } - }) + } + return revisions +} -const stringifyDatapoints = ( - datapoints: unknown[], - field: string | undefined, - isMultiView: boolean -): string => { - if (!field || (!isMultiView && !isConcatenatedField(field))) { - return JSON.stringify(datapoints) +const transformRevisionData = ( + path: string, + selectedRevisions: string[], + revisionData: RevisionData, + isMultiView: boolean, + multiSourceEncodingUpdate: { strokeDash: StrokeDashEncoding } +): { revisions: string[]; datapoints: unknown[] } => { + const field = multiSourceEncodingUpdate.strokeDash?.field + const isMultiSource = !!field + + const transformNeeded = + isMultiSource && (isMultiView || isConcatenatedField(field)) + + if (!transformNeeded) { + return { + datapoints: selectedRevisions + .flatMap(revision => revisionData?.[revision]?.[path]) + .filter(Boolean), + revisions: selectedRevisions + } } const fields = unmergeConcatenatedFields(field) - if (isMultiView) { fields.unshift('rev') - return JSON.stringify(updateDatapoints(datapoints, 'rev', fields)) + return { + datapoints: updateDatapoints( + path, + revisionData, + selectedRevisions, + 'rev', + fields + ), + revisions: updateRevisions( + selectedRevisions, + multiSourceEncodingUpdate.strokeDash.scale.domain + ) + } } - return JSON.stringify(updateDatapoints(datapoints, field, fields)) + return { + datapoints: updateDatapoints( + path, + revisionData, + selectedRevisions, + field, + fields + ), + revisions: selectedRevisions + } } const fillTemplate = ( template: string, - datapoints: unknown[], - field?: string + datapoints: unknown[] ): TopLevelSpec => { - const isMultiView = isMultiViewPlot(JSON.parse(template)) - return JSON.parse( - template.replace( - '""', - stringifyDatapoints(datapoints, field, isMultiView) - ) + template.replace('""', JSON.stringify(datapoints)) ) as TopLevelSpec } @@ -562,30 +609,26 @@ const collectTemplateGroup = ( const template = templates[path] if (template) { - const datapoints = selectedRevisions - .flatMap(revision => revisionData?.[revision]?.[path]) - .filter(Boolean) - + const isMultiView = isMultiViewPlot(JSON.parse(template)) const multiSourceEncodingUpdate = multiSourceEncoding[path] || {} - - const content = extendVegaSpec( - fillTemplate( - template, - datapoints, - multiSourceEncodingUpdate.strokeDash?.field - ), - size, - { - ...multiSourceEncodingUpdate, - color: revisionColors - } + const { datapoints, revisions } = transformRevisionData( + path, + selectedRevisions, + revisionData, + isMultiView, + multiSourceEncodingUpdate ) + const content = extendVegaSpec(fillTemplate(template, datapoints), size, { + ...multiSourceEncodingUpdate, + color: revisionColors + }) + acc.push({ content, id: path, multiView: isMultiViewPlot(content), - revisions: selectedRevisions, + revisions, type: PlotsType.VEGA }) } diff --git a/extension/src/test/suite/plots/index.test.ts b/extension/src/test/suite/plots/index.test.ts index f34f03cbb0..22d3a52147 100644 --- a/extension/src/test/suite/plots/index.test.ts +++ b/extension/src/test/suite/plots/index.test.ts @@ -37,6 +37,7 @@ import { MessageFromWebviewType } from '../../../webview/contract' import { reorderObjectList } from '../../../util/array' import * as Telemetry from '../../../telemetry' import { EventName } from '../../../telemetry/constants' +import { SelectedExperimentWithColor } from '../../../experiments/model' suite('Plots Test Suite', () => { const disposable = Disposable.fn() @@ -761,10 +762,13 @@ suite('Plots Test Suite', () => { }).timeout(WEBVIEW_TEST_TIMEOUT) it('should send the correct data to the webview for flexible plots', async () => { - const { plots, messageSpy, mockPlotsDiff } = await buildPlots( - disposable, - multiSourcePlotsDiffFixture - ) + const { plots, messageSpy, mockPlotsDiff, experiments } = + await buildPlots(disposable, multiSourcePlotsDiffFixture) + + stub(experiments, 'getSelectedRevisions').returns([ + { label: 'workspace' }, + { label: 'main' } + ] as SelectedExperimentWithColor[]) const webview = await plots.showWebview() await webview.isReady() @@ -798,17 +802,6 @@ suite('Plots Test Suite', () => { multiViewSection.entries.map(({ id }: { id: string }) => id) ).to.deep.equal(['dvc.yaml::Confusion-Matrix']) - const [confusionMatrix] = multiViewSection.entries - - const confusionMatrixDatapoints = - ( - confusionMatrix.content.data as { - values: { rev: string }[] - } - )?.values || [] - - expect(confusionMatrixDatapoints.length).to.be.greaterThan(0) - const expectedRevisions = [ `main::${join('evaluation', 'test', 'plots', 'confusion_matrix.json')}`, `workspace::${join( @@ -829,7 +822,21 @@ suite('Plots Test Suite', () => { 'plots', 'confusion_matrix.json' )}` - ] + ].sort() + + const [confusionMatrix] = multiViewSection.entries + + const confusionMatrixDatapoints = + ( + confusionMatrix.content.data as { + values: { rev: string }[] + } + )?.values || [] + + expect(confusionMatrixDatapoints.length).to.be.greaterThan(0) + + expect(confusionMatrix.revisions?.length).to.equal(4) + expect(confusionMatrix.revisions?.sort()).to.deep.equal(expectedRevisions) for (const entry of confusionMatrixDatapoints) { expect(expectedRevisions).to.include(entry.rev)