Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion extension/src/experiments/model/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import { sum } from '../../util/math'

export type StarredExperiments = Record<string, boolean | undefined>

type SelectedExperimentWithColor = Experiment & {
export type SelectedExperimentWithColor = Experiment & {
displayColor: Color
selected: true
}
Expand Down
127 changes: 85 additions & 42 deletions extension/src/plots/model/collect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import {
MultiSourceEncoding,
unmergeConcatenatedFields
} from '../multiSource/collect'
import { StrokeDashEncoding } from '../multiSource/constants'

type CheckpointPlotAccumulator = {
iterations: Record<string, number>
Expand Down Expand Up @@ -502,49 +503,95 @@ export const collectTemplates = (data: PlotsOutput): TemplateAccumulator => {
}

const updateDatapoints = (
datapoints: unknown[],
path: string,
revisionData: RevisionData,
selectedRevisions: string[],
key: string,
fields: string[]
): unknown[] =>
datapoints.map(data => {
const obj = data as Record<string, unknown>
return {
...obj,
[key]: mergeFields(fields.map(field => obj[field] as string))
selectedRevisions
.flatMap(revision =>
revisionData?.[revision]?.[path].map(data => {
const obj = data as Record<string, unknown>
return {
...obj,
[key]: mergeFields(fields.map(field => obj[field] as string))
}
})
)
.filter(Boolean)

const updateRevisions = (
selectedRevisions: string[],
domain: string[]
): string[] => {
const revisions: string[] = []
for (const revision of selectedRevisions) {
for (const entry of domain) {
revisions.push(mergeFields([revision, entry]))
}
})
}
return revisions
}

const stringifyDatapoints = (
datapoints: unknown[],
field: string | undefined,
isMultiView: boolean
): string => {
if (!field || (!isMultiView && !isConcatenatedField(field))) {
return JSON.stringify(datapoints)
const transformRevisionData = (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[F] This change is good because we are not looping over the datapoints twice anymore.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function transformRevisionData has 39 lines of code (exceeds 30 allowed). Consider refactoring.

path: string,
selectedRevisions: string[],
revisionData: RevisionData,
isMultiView: boolean,
multiSourceEncodingUpdate: { strokeDash: StrokeDashEncoding }
): { revisions: string[]; datapoints: unknown[] } => {
const field = multiSourceEncodingUpdate.strokeDash?.field
const isMultiSource = !!field

const transformNeeded =
isMultiSource && (isMultiView || isConcatenatedField(field))

if (!transformNeeded) {
return {
datapoints: selectedRevisions
.flatMap(revision => revisionData?.[revision]?.[path])
.filter(Boolean),
revisions: selectedRevisions
}
}

const fields = unmergeConcatenatedFields(field)

if (isMultiView) {
fields.unshift('rev')
return JSON.stringify(updateDatapoints(datapoints, 'rev', fields))
return {
datapoints: updateDatapoints(
path,
revisionData,
selectedRevisions,
'rev',
fields
),
revisions: updateRevisions(
selectedRevisions,
multiSourceEncodingUpdate.strokeDash.scale.domain
)
}
}

return JSON.stringify(updateDatapoints(datapoints, field, fields))
return {
datapoints: updateDatapoints(
path,
revisionData,
selectedRevisions,
field,
fields
),
revisions: selectedRevisions
}
}

const fillTemplate = (
template: string,
datapoints: unknown[],
field?: string
datapoints: unknown[]
): TopLevelSpec => {
const isMultiView = isMultiViewPlot(JSON.parse(template))

return JSON.parse(
template.replace(
'"<DVC_METRIC_DATA>"',
stringifyDatapoints(datapoints, field, isMultiView)
)
template.replace('"<DVC_METRIC_DATA>"', JSON.stringify(datapoints))
) as TopLevelSpec
}

Expand All @@ -562,30 +609,26 @@ const collectTemplateGroup = (
const template = templates[path]

if (template) {
const datapoints = selectedRevisions
.flatMap(revision => revisionData?.[revision]?.[path])
.filter(Boolean)

const isMultiView = isMultiViewPlot(JSON.parse(template))
const multiSourceEncodingUpdate = multiSourceEncoding[path] || {}

const content = extendVegaSpec(
fillTemplate(
template,
datapoints,
multiSourceEncodingUpdate.strokeDash?.field
),
size,
{
...multiSourceEncodingUpdate,
color: revisionColors
}
const { datapoints, revisions } = transformRevisionData(
path,
selectedRevisions,
revisionData,
isMultiView,
multiSourceEncodingUpdate
)

const content = extendVegaSpec(fillTemplate(template, datapoints), size, {
...multiSourceEncodingUpdate,
color: revisionColors
})

acc.push({
content,
id: path,
multiView: isMultiViewPlot(content),
revisions: selectedRevisions,
revisions,
type: PlotsType.VEGA
})
}
Expand Down
39 changes: 23 additions & 16 deletions extension/src/test/suite/plots/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import { MessageFromWebviewType } from '../../../webview/contract'
import { reorderObjectList } from '../../../util/array'
import * as Telemetry from '../../../telemetry'
import { EventName } from '../../../telemetry/constants'
import { SelectedExperimentWithColor } from '../../../experiments/model'

suite('Plots Test Suite', () => {
const disposable = Disposable.fn()
Expand Down Expand Up @@ -761,10 +762,13 @@ suite('Plots Test Suite', () => {
}).timeout(WEBVIEW_TEST_TIMEOUT)

it('should send the correct data to the webview for flexible plots', async () => {
const { plots, messageSpy, mockPlotsDiff } = await buildPlots(
disposable,
multiSourcePlotsDiffFixture
)
const { plots, messageSpy, mockPlotsDiff, experiments } =
await buildPlots(disposable, multiSourcePlotsDiffFixture)

stub(experiments, 'getSelectedRevisions').returns([
{ label: 'workspace' },
{ label: 'main' }
] as SelectedExperimentWithColor[])

const webview = await plots.showWebview()
await webview.isReady()
Expand Down Expand Up @@ -798,17 +802,6 @@ suite('Plots Test Suite', () => {
multiViewSection.entries.map(({ id }: { id: string }) => id)
).to.deep.equal(['dvc.yaml::Confusion-Matrix'])

const [confusionMatrix] = multiViewSection.entries

const confusionMatrixDatapoints =
(
confusionMatrix.content.data as {
values: { rev: string }[]
}
)?.values || []

expect(confusionMatrixDatapoints.length).to.be.greaterThan(0)

const expectedRevisions = [
`main::${join('evaluation', 'test', 'plots', 'confusion_matrix.json')}`,
`workspace::${join(
Expand All @@ -829,7 +822,21 @@ suite('Plots Test Suite', () => {
'plots',
'confusion_matrix.json'
)}`
]
].sort()

const [confusionMatrix] = multiViewSection.entries

const confusionMatrixDatapoints =
(
confusionMatrix.content.data as {
values: { rev: string }[]
}
)?.values || []

expect(confusionMatrixDatapoints.length).to.be.greaterThan(0)

expect(confusionMatrix.revisions?.length).to.equal(4)
expect(confusionMatrix.revisions?.sort()).to.deep.equal(expectedRevisions)

for (const entry of confusionMatrixDatapoints) {
expect(expectedRevisions).to.include(entry.rev)
Expand Down