Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions extension/src/plots/paths/model.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import { join } from 'path'
import { PathsModel } from './model'
import { PathType } from './collect'
import plotsDiffFixture from '../../test/fixtures/plotsDiff/output'
import { buildMockMemento } from '../../test/util'
import { TemplatePlotGroup } from '../webview/contract'

describe('PathsModel', () => {
const mockDvcRoot = 'test'

const logsAcc = join('logs', 'acc.tsv')
const logsLoss = join('logs', 'loss.tsv')
const plotsAcc = join('plots', 'acc.png')

it('should return the expected columns when given the default output fixture', () => {
const comparisonType = new Set([PathType.COMPARISON])
const singleType = new Set([PathType.TEMPLATE_SINGLE])
const multiType = new Set([PathType.TEMPLATE_MULTI])

const model = new PathsModel(mockDvcRoot, buildMockMemento())
model.transformAndSet(plotsDiffFixture)
expect(model.getTerminalNodes()).toStrictEqual([
{
hasChildren: false,
label: 'acc.png',
parentPath: 'plots',
path: plotsAcc,
selected: true,
type: comparisonType
},
{
hasChildren: false,
label: 'heatmap.png',
parentPath: 'plots',
path: join('plots', 'heatmap.png'),
selected: true,
type: comparisonType
},
{
hasChildren: false,
label: 'loss.png',
parentPath: 'plots',
path: join('plots', 'loss.png'),
selected: true,
type: comparisonType
},
{
hasChildren: false,
label: 'loss.tsv',
parentPath: 'logs',
path: logsLoss,
selected: true,
type: singleType
},
{
hasChildren: false,
label: 'acc.tsv',
parentPath: 'logs',
path: logsAcc,
selected: true,
type: singleType
},
{
hasChildren: false,
label: 'predictions.json',
parentPath: undefined,
path: 'predictions.json',
selected: true,
type: multiType
}
])
})

const multiViewGroup = {
group: TemplatePlotGroup.MULTI_VIEW,
paths: ['predictions.json']
}
const originalSingleViewGroup = {
group: TemplatePlotGroup.SINGLE_VIEW,
paths: [logsLoss, logsAcc]
}

const logsAccGroup = {
group: TemplatePlotGroup.SINGLE_VIEW,
paths: [logsAcc]
}

const logsLossGroup = {
group: TemplatePlotGroup.SINGLE_VIEW,
paths: [logsLoss]
}

const originalTemplateOrder = [originalSingleViewGroup, multiViewGroup]

it('should retain the order of template paths when they are unselected', () => {
const model = new PathsModel(mockDvcRoot, buildMockMemento())
model.transformAndSet(plotsDiffFixture)

expect(model.getTemplateOrder()).toStrictEqual(originalTemplateOrder)

model.toggleStatus(logsAcc)

const newOrder = model.getTemplateOrder()

expect(newOrder).toStrictEqual([logsLossGroup, multiViewGroup])

model.toggleStatus(logsAcc)

expect(model.getTemplateOrder()).toStrictEqual(originalTemplateOrder)
})

it('should move unselected plots to the end when a reordering occurs', () => {
const model = new PathsModel(mockDvcRoot, buildMockMemento())
model.transformAndSet(plotsDiffFixture)

expect(model.getTemplateOrder()).toStrictEqual(originalTemplateOrder)

model.toggleStatus(logsAcc)

const newOrder = model.getTemplateOrder()

expect(newOrder).toStrictEqual([logsLossGroup, multiViewGroup])

model.setTemplateOrder([multiViewGroup, logsLossGroup])

model.toggleStatus(logsAcc)

expect(model.getTemplateOrder()).toStrictEqual([
multiViewGroup,
{
group: TemplatePlotGroup.SINGLE_VIEW,
paths: [logsLoss, logsAcc]
}
])
})

it('should merge template plots groups when a path is unselected', () => {
const model = new PathsModel(mockDvcRoot, buildMockMemento())
model.transformAndSet(plotsDiffFixture)

model.setTemplateOrder([logsLossGroup, logsAccGroup, multiViewGroup])

model.toggleStatus('predictions.json')

expect(model.getTemplateOrder()).toStrictEqual([originalSingleViewGroup])
})
})
23 changes: 16 additions & 7 deletions extension/src/plots/paths/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ export class PathsModel extends PathSelectionModel<PlotPath> {
}

public setTemplateOrder(templateOrder?: TemplateOrder) {
const filter = (type: PathType, plotPath: PlotPath) =>
!!plotPath.type?.has(type)

this.templateOrder = collectTemplateOrder(
this.getPathsByType(PathType.TEMPLATE_SINGLE),
this.getPathsByType(PathType.TEMPLATE_MULTI),
this.getPathsByType(PathType.TEMPLATE_SINGLE, filter),
this.getPathsByType(PathType.TEMPLATE_MULTI, filter),
templateOrder || this.templateOrder
)

Expand All @@ -51,7 +54,11 @@ export class PathsModel extends PathSelectionModel<PlotPath> {
}

public getTemplateOrder(): TemplateOrder {
return this.templateOrder
return collectTemplateOrder(
this.getPathsByType(PathType.TEMPLATE_SINGLE),
this.getPathsByType(PathType.TEMPLATE_MULTI),
this.templateOrder
)
}

public getComparisonPaths() {
Expand All @@ -62,11 +69,13 @@ export class PathsModel extends PathSelectionModel<PlotPath> {
return this.data.length > 0
}

private getPathsByType(type: PathType) {
private getPathsByType(
type: PathType,
filter = (type: PathType, plotPath: PlotPath) =>
!!(plotPath.type?.has(type) && this.status[plotPath.path])
) {
return this.data
.filter(
plotPath => plotPath.type?.has(type) && this.status[plotPath.path]
)
.filter(plotPath => filter(type, plotPath))
.map(({ path }) => path)
}
}