From 9ff4c3c77d24d60b716ff71f4739c4435ede2faa Mon Sep 17 00:00:00 2001 From: Riley Jones Date: Fri, 28 Apr 2023 20:37:00 +0000 Subject: [PATCH 1/6] add common selector to determine which runs should be shown --- .../webapp/metrics/views/main_view/BUILD | 4 + .../views/main_view/common_selectors.ts | 144 ++++++++++++++++++ 2 files changed, 148 insertions(+) diff --git a/tensorboard/webapp/metrics/views/main_view/BUILD b/tensorboard/webapp/metrics/views/main_view/BUILD index 41b28403c7..9d0759b498 100644 --- a/tensorboard/webapp/metrics/views/main_view/BUILD +++ b/tensorboard/webapp/metrics/views/main_view/BUILD @@ -81,11 +81,15 @@ tf_ts_library( deps = [ "//tensorboard/webapp:app_state", "//tensorboard/webapp:selectors", + "//tensorboard/webapp/hparams:types", + "//tensorboard/webapp/hparams/_redux:hparams_selectors", "//tensorboard/webapp/metrics:utils", "//tensorboard/webapp/metrics/data_source", "//tensorboard/webapp/metrics/store", "//tensorboard/webapp/metrics/views:types", "//tensorboard/webapp/metrics/views:utils", + "//tensorboard/webapp/runs/views/runs_table:types", + "//tensorboard/webapp/util:matcher", "//tensorboard/webapp/util:types", "@npm//@ngrx/store", ], diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index a8967720af..fc6a39dc18 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -18,8 +18,26 @@ import { getCurrentRouteRunSelection, getMetricsHideEmptyCards, getMetricsTagMetadata, + getExperimentIdsFromRoute, + getExperimentIdToExperimentAliasMap, + getExperimentNames, + getRunColorMap, + getRunSelectorRegexFilter, + getRunsFromExperimentIds, } from '../../../selectors'; import {DeepReadonly} from '../../../util/types'; +import { + getHparamFilterMapFromExperimentIds, + getMetricFilterMapFromExperimentIds, +} from '../../../hparams/_redux/hparams_selectors'; +import { + DiscreteFilter, + DiscreteHparamValue, + DomainType, + IntervalFilter, +} from '../../../hparams/types'; +import {RunTableItem} from '../../../runs/views/runs_table/types'; +import {matchRunToRegex} from '../../../util/matcher'; import {isSingleRunPlugin, PluginType} from '../../data_source'; import {getNonEmptyCardIdsWithMetadata, TagMetadata} from '../../store'; import {compareTagNames} from '../../utils'; @@ -86,6 +104,132 @@ export const getSortedRenderableCardIdsWithMetadata = createSelector< }); }); +export function getRenderableRuns(experimentIds: string[]) { + return createSelector( + getRunsFromExperimentIds(experimentIds), + getExperimentNames(experimentIds), + getCurrentRouteRunSelection, + getRunColorMap, + getExperimentIdToExperimentAliasMap, + (runs, experimentNames, selectionMap, colorMap, experimentIdToAlias) => { + return runs.map((run) => { + const hparamMap: RunTableItem['hparams'] = new Map(); + (run.hparams || []).forEach((hparam) => { + hparamMap.set(hparam.name, hparam.value); + }); + const metricMap: RunTableItem['metrics'] = new Map(); + (run.metrics || []).forEach((metric) => { + metricMap.set(metric.tag, metric.value); + }); + return { + run, + experimentName: experimentNames[run.experimentId] || '', + experimentAlias: experimentIdToAlias[run.experimentId], + selected: Boolean(selectionMap && selectionMap.get(run.id)), + runColor: colorMap[run.id], + hparams: hparamMap, + metrics: metricMap, + }; + }); + } + ); +} + +function filterRunItemsByRegex(runItems: RunTableItem[], regexString: string) { + if (!regexString) { + return runItems; + } + + // DO_NOT_SUBMIT + // const shouldIncludeExperimentName = this.columns.includes( + // RunsTableColumn.EXPERIMENT_NAME + // ); + const shouldIncludeExperimentName = false; + return runItems.filter((item) => { + return matchRunToRegex( + { + runName: item.run.name, + experimentAlias: item.experimentAlias, + }, + regexString, + shouldIncludeExperimentName + ); + }); +} + +function matchFilter( + filter: DiscreteFilter | IntervalFilter, + value: number | DiscreteHparamValue | undefined +): boolean { + if (value === undefined) { + return filter.includeUndefined; + } + if (filter.type === DomainType.DISCRETE) { + // (upcast to work around bad TypeScript libdefs) + const values: Readonly> = + filter.filterValues; + return values.includes(value); + } else if (filter.type === DomainType.INTERVAL) { + // Auto-added to unblock TS5.0 migration + // @ts-ignore(go/ts50upgrade): Operator '<=' cannot be applied to types + // 'number' and 'string | number | boolean'. + // Auto-added to unblock TS5.0 migration + // @ts-ignore(go/ts50upgrade): Operator '<=' cannot be applied to types + // 'string | number | boolean' and 'number'. + return filter.filterLowerValue <= value && value <= filter.filterUpperValue; + } + return false; +} + +function filterRunItemsByHparamAndMetricFilter( + runItems: RunTableItem[], + hparamFilters: Map, + metricFilters: Map +) { + return runItems.filter(({hparams, metrics}) => { + const hparamMatches = [...hparamFilters.entries()].every( + ([hparamName, filter]) => { + const value = hparams.get(hparamName); + return matchFilter(filter, value); + } + ); + + return ( + hparamMatches && + [...metricFilters.entries()].every(([metricTag, filter]) => { + const value = metrics.get(metricTag); + return matchFilter(filter, value); + }) + ); + }); +} + +export function getFilteredRenderableRuns(experimentIds: string[]) { + return createSelector( + getRunSelectorRegexFilter, + getRenderableRuns(experimentIds), + getHparamFilterMapFromExperimentIds(experimentIds), + getMetricFilterMapFromExperimentIds(experimentIds), + (regexFilter, runItems, hparamFilters, metricFilters) => { + const regexFilteredItems = filterRunItemsByRegex(runItems, regexFilter); + + return filterRunItemsByHparamAndMetricFilter( + regexFilteredItems, + hparamFilters, + metricFilters + ); + } + ); +} + +export const getFilteredRenderableRunsFromRoute = createSelector( + (state) => state, + getExperimentIdsFromRoute, + (state, experimentIds) => { + return getFilteredRenderableRuns(experimentIds || [])(state); + } +); + export const TEST_ONLY = { getRenderableCardIdsWithMetadata, getScalarTagsForRunSelection, From bdbb4751c725bc51343f9d05c8464c488774a79d Mon Sep 17 00:00:00 2001 From: Riley Jones Date: Thu, 4 May 2023 00:20:12 +0000 Subject: [PATCH 2/6] moar tests --- .../webapp/metrics/views/main_view/BUILD | 5 + .../views/main_view/common_selectors.ts | 209 ++++---- .../views/main_view/common_selectors_test.ts | 471 +++++++++++++++++- 3 files changed, 585 insertions(+), 100 deletions(-) diff --git a/tensorboard/webapp/metrics/views/main_view/BUILD b/tensorboard/webapp/metrics/views/main_view/BUILD index 9d0759b498..0a50296375 100644 --- a/tensorboard/webapp/metrics/views/main_view/BUILD +++ b/tensorboard/webapp/metrics/views/main_view/BUILD @@ -81,6 +81,7 @@ tf_ts_library( deps = [ "//tensorboard/webapp:app_state", "//tensorboard/webapp:selectors", + "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/hparams:types", "//tensorboard/webapp/hparams/_redux:hparams_selectors", "//tensorboard/webapp/metrics:utils", @@ -189,6 +190,8 @@ tf_ts_library( "//tensorboard/webapp/app_routing:testing", "//tensorboard/webapp/app_routing/store:testing", "//tensorboard/webapp/customization", + "//tensorboard/webapp/experiments/store:testing", + "//tensorboard/webapp/hparams:types", "//tensorboard/webapp/metrics:test_lib", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics/actions", @@ -196,8 +199,10 @@ tf_ts_library( "//tensorboard/webapp/metrics/store", "//tensorboard/webapp/metrics/views:types", "//tensorboard/webapp/metrics/views/card_renderer", + "//tensorboard/webapp/runs/store:selectors", "//tensorboard/webapp/runs/store:testing", "//tensorboard/webapp/runs/store:types", + "//tensorboard/webapp/runs/views/runs_table:types", "//tensorboard/webapp/settings", "//tensorboard/webapp/testing:dom", "//tensorboard/webapp/testing:mat_icon", diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index fc6a39dc18..f6b4caa6df 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -23,6 +23,7 @@ import { getExperimentNames, getRunColorMap, getRunSelectorRegexFilter, + getRouteKind, getRunsFromExperimentIds, } from '../../../selectors'; import {DeepReadonly} from '../../../util/types'; @@ -42,6 +43,7 @@ import {isSingleRunPlugin, PluginType} from '../../data_source'; import {getNonEmptyCardIdsWithMetadata, TagMetadata} from '../../store'; import {compareTagNames} from '../../utils'; import {CardIdWithMetadata} from '../metrics_view_types'; +import {RouteKind} from '../../../app_routing/types'; export const getScalarTagsForRunSelection = createSelector( getMetricsTagMetadata, @@ -104,116 +106,124 @@ export const getSortedRenderableCardIdsWithMetadata = createSelector< }); }); -export function getRenderableRuns(experimentIds: string[]) { - return createSelector( - getRunsFromExperimentIds(experimentIds), - getExperimentNames(experimentIds), - getCurrentRouteRunSelection, - getRunColorMap, - getExperimentIdToExperimentAliasMap, - (runs, experimentNames, selectionMap, colorMap, experimentIdToAlias) => { - return runs.map((run) => { - const hparamMap: RunTableItem['hparams'] = new Map(); - (run.hparams || []).forEach((hparam) => { - hparamMap.set(hparam.name, hparam.value); - }); - const metricMap: RunTableItem['metrics'] = new Map(); - (run.metrics || []).forEach((metric) => { - metricMap.set(metric.tag, metric.value); +const utils = { + getRenderableRuns(experimentIds: string[]) { + return createSelector( + getRunsFromExperimentIds(experimentIds), + getExperimentNames(experimentIds), + getCurrentRouteRunSelection, + getRunColorMap, + getExperimentIdToExperimentAliasMap, + (runs, experimentNames, selectionMap, colorMap, experimentIdToAlias) => { + return runs.map((run) => { + const hparamMap: RunTableItem['hparams'] = new Map(); + (run.hparams || []).forEach((hparam) => { + hparamMap.set(hparam.name, hparam.value); + }); + const metricMap: RunTableItem['metrics'] = new Map(); + (run.metrics || []).forEach((metric) => { + metricMap.set(metric.tag, metric.value); + }); + return { + run, + experimentName: experimentNames[run.experimentId] || '', + experimentAlias: experimentIdToAlias[run.experimentId], + selected: Boolean(selectionMap && selectionMap.get(run.id)), + runColor: colorMap[run.id], + hparams: hparamMap, + metrics: metricMap, + }; }); - return { - run, - experimentName: experimentNames[run.experimentId] || '', - experimentAlias: experimentIdToAlias[run.experimentId], - selected: Boolean(selectionMap && selectionMap.get(run.id)), - runColor: colorMap[run.id], - hparams: hparamMap, - metrics: metricMap, - }; - }); - } - ); -} - -function filterRunItemsByRegex(runItems: RunTableItem[], regexString: string) { - if (!regexString) { - return runItems; - } - - // DO_NOT_SUBMIT - // const shouldIncludeExperimentName = this.columns.includes( - // RunsTableColumn.EXPERIMENT_NAME - // ); - const shouldIncludeExperimentName = false; - return runItems.filter((item) => { - return matchRunToRegex( - { - runName: item.run.name, - experimentAlias: item.experimentAlias, - }, - regexString, - shouldIncludeExperimentName - ); - }); -} - -function matchFilter( - filter: DiscreteFilter | IntervalFilter, - value: number | DiscreteHparamValue | undefined -): boolean { - if (value === undefined) { - return filter.includeUndefined; - } - if (filter.type === DomainType.DISCRETE) { - // (upcast to work around bad TypeScript libdefs) - const values: Readonly> = - filter.filterValues; - return values.includes(value); - } else if (filter.type === DomainType.INTERVAL) { - // Auto-added to unblock TS5.0 migration - // @ts-ignore(go/ts50upgrade): Operator '<=' cannot be applied to types - // 'number' and 'string | number | boolean'. - // Auto-added to unblock TS5.0 migration - // @ts-ignore(go/ts50upgrade): Operator '<=' cannot be applied to types - // 'string | number | boolean' and 'number'. - return filter.filterLowerValue <= value && value <= filter.filterUpperValue; - } - return false; -} - -function filterRunItemsByHparamAndMetricFilter( - runItems: RunTableItem[], - hparamFilters: Map, - metricFilters: Map -) { - return runItems.filter(({hparams, metrics}) => { - const hparamMatches = [...hparamFilters.entries()].every( - ([hparamName, filter]) => { - const value = hparams.get(hparamName); - return matchFilter(filter, value); } ); + }, + + filterRunItemsByRegex( + runItems: RunTableItem[], + regexString: string, + shouldIncludeExperimentName: boolean + ) { + if (!regexString) { + return runItems; + } - return ( - hparamMatches && - [...metricFilters.entries()].every(([metricTag, filter]) => { - const value = metrics.get(metricTag); - return matchFilter(filter, value); - }) - ); - }); -} + return runItems.filter((item) => { + return matchRunToRegex( + { + runName: item.run.name, + experimentAlias: item.experimentAlias, + }, + regexString, + shouldIncludeExperimentName + ); + }); + }, + + matchFilter( + filter: DiscreteFilter | IntervalFilter, + value: number | DiscreteHparamValue | undefined + ): boolean { + if (value === undefined) { + return filter.includeUndefined; + } + if (filter.type === DomainType.DISCRETE) { + // (upcast to work around bad TypeScript libdefs) + const values: Readonly> = + filter.filterValues; + return values.includes(value); + } else if (filter.type === DomainType.INTERVAL) { + // Auto-added to unblock TS5.0 migration + // @ts-ignore(go/ts50upgrade): Operator '<=' cannot be applied to types + // 'number' and 'string | number | boolean'. + // Auto-added to unblock TS5.0 migration + // @ts-ignore(go/ts50upgrade): Operator '<=' cannot be applied to types + // 'string | number | boolean' and 'number'. + return ( + filter.filterLowerValue <= value && value <= filter.filterUpperValue + ); + } + return false; + }, + + filterRunItemsByHparamAndMetricFilter( + runItems: RunTableItem[], + hparamFilters: Map, + metricFilters: Map + ) { + return runItems.filter(({hparams, metrics}) => { + const hparamMatches = [...hparamFilters.entries()].every( + ([hparamName, filter]) => { + const value = hparams.get(hparamName); + return utils.matchFilter(filter, value); + } + ); + + return ( + hparamMatches && + [...metricFilters.entries()].every(([metricTag, filter]) => { + const value = metrics.get(metricTag); + return utils.matchFilter(filter, value); + }) + ); + }); + }, +}; export function getFilteredRenderableRuns(experimentIds: string[]) { return createSelector( getRunSelectorRegexFilter, - getRenderableRuns(experimentIds), + utils.getRenderableRuns(experimentIds), getHparamFilterMapFromExperimentIds(experimentIds), getMetricFilterMapFromExperimentIds(experimentIds), - (regexFilter, runItems, hparamFilters, metricFilters) => { - const regexFilteredItems = filterRunItemsByRegex(runItems, regexFilter); + getRouteKind, + (regexFilter, runItems, hparamFilters, metricFilters, routeKind) => { + const regexFilteredItems = utils.filterRunItemsByRegex( + runItems, + regexFilter, + routeKind === RouteKind.COMPARE_EXPERIMENT + ); - return filterRunItemsByHparamAndMetricFilter( + return utils.filterRunItemsByHparamAndMetricFilter( regexFilteredItems, hparamFilters, metricFilters @@ -233,4 +243,5 @@ export const getFilteredRenderableRunsFromRoute = createSelector( export const TEST_ONLY = { getRenderableCardIdsWithMetadata, getScalarTagsForRunSelection, + utils, }; diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index 46cef42f6f..61de01323d 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -18,11 +18,16 @@ import { buildStateFromAppRoutingState, } from '../../../app_routing/store/testing'; import {buildRoute} from '../../../app_routing/testing'; -import {Run} from '../../../runs/store/runs_types'; +import {buildExperiment} from '../../../experiments/store/testing'; +import {IntervalFilter, DiscreteFilter} from '../../../hparams/types'; +import {DomainType, Run} from '../../../runs/store/runs_types'; import { + buildRun, buildRunsState, buildStateFromRunsState, } from '../../../runs/store/testing'; +import {RunTableItem} from '../../../runs/views/runs_table/types'; +import {buildMockState} from '../../../testing/utils'; import { appStateFromMetricsState, buildMetricsSettingsState, @@ -35,6 +40,15 @@ describe('common selectors', () => { let runIds: Record; let runIdToExpId: Record; let runMetadata: Record; + + let runTableItems: RunTableItem[]; + + let run1: Run; + let run2: Run; + let run3: Run; + let run4: Run; + let state: ReturnType; + beforeEach(() => { runIds = {defaultExperimentId: ['run1', 'run2', 'run3']}; runIdToExpId = { @@ -65,6 +79,114 @@ describe('common selectors', () => { metrics: null, }, }; + + runTableItems = [ + { + run: { + id: 'run1-id', + name: 'run1', + startTime: 0, + hparams: [ + { + name: 'accurracy', + value: 1, + }, + ], + metrics: null, + }, + experimentAlias: { + aliasNumber: 1, + aliasText: 'exp1', + }, + experimentName: 'experiment1', + selected: true, + hparams: new Map([['lr', 5]]), + metrics: new Map([['foo', 1]]), + }, + { + run: { + id: 'run2-id', + name: 'run2', + startTime: 0, + hparams: [ + { + name: 'accurracy', + value: 1, + }, + ], + metrics: null, + }, + experimentAlias: { + aliasNumber: 1, + aliasText: 'exp1', + }, + experimentName: 'experiment1', + selected: true, + hparams: new Map([['lr', 3]]), + metrics: new Map([['foo', 2]]), + }, + { + run: { + id: 'run1-id', + name: 'run1', + startTime: 0, + hparams: [ + { + name: 'accurracy', + value: 1, + }, + ], + metrics: null, + }, + experimentAlias: { + aliasNumber: 1, + aliasText: 'exp2', + }, + experimentName: 'experiment2', + selected: true, + hparams: new Map([['lr', 1]]), + metrics: new Map([['foo', 3]]), + }, + ]; + + run1 = buildRun({name: 'run 1'}); + run2 = buildRun({id: '2', name: 'run 2'}); + run3 = buildRun({id: '3', name: 'run 3'}); + run4 = buildRun({id: '4', name: 'run 4'}); + state = buildMockState({ + runs: { + data: { + regexFilter: '', + runIds: { + exp1: ['run1', 'run2'], + exp2: ['run2', 'run3', 'run4'], + }, + runMetadata: { + run1, + run2, + run3, + run4, + }, + } as any, + ui: {} as any, + }, + experiments: { + data: { + experimentMap: { + exp1: buildExperiment({name: 'experiment1', id: 'exp1'}), + exp2: buildExperiment({name: 'experiment2', id: 'exp2'}), + }, + }, + }, + app_routing: { + activeRoute: { + routeKind: RouteKind.EXPERIMENT, + params: { + experimentIds: 'foo:exp1,bar:exp2', + }, + }, + } as any, + }); }); describe('getScalarTagsForRunSelection', () => { @@ -430,4 +552,351 @@ describe('common selectors', () => { ]); }); }); + + describe('matchFilter', () => { + it('respects includeUndefined when value is undefined', () => { + expect( + selectors.TEST_ONLY.utils.matchFilter( + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 1, + filterUpperValue: 5, + includeUndefined: true, + }, + undefined + ) + ).toBeTrue(); + + expect( + selectors.TEST_ONLY.utils.matchFilter( + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 1, + filterUpperValue: 5, + includeUndefined: false, + }, + undefined + ) + ).toBeFalse(); + }); + + it('returns values including value when filter type is DISCRETE', () => { + const filter: DiscreteFilter = { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [], + filterValues: ['afoo', 'foob', 'foo', 'fo'], + }; + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 'foo')).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 'bar')).toBeFalse(); + }); + + it('checks if value is within bounds when filter type is INTERVAL', () => { + const filter: IntervalFilter = { + type: DomainType.INTERVAL, + includeUndefined: true, + minValue: 0, + maxValue: 10, + filterLowerValue: 1, + filterUpperValue: 5, + }; + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 0)).toBeFalse(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 1)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 3)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 5)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 6)).toBeFalse(); + + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 'foo')).toBeFalse(); + }); + + it('returns false if filter type is neither DISCRETE nor INTERVAL', () => { + expect( + selectors.TEST_ONLY.utils.matchFilter({} as DiscreteFilter, 'foo') + ).toBeFalse(); + }); + }); + + describe('filterRunItemsByRegex', () => { + it('returns all runs when no regex is provided', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + '', + false + ) + ).toEqual(runTableItems); + }); + + it('only returns runs matching regex', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'run', + false + ) + ).toEqual(runTableItems); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'exp', + false + ) + ).toEqual([]); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'exp', + true + ) + ).toEqual(runTableItems); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'run1', + false + ) + ).toEqual([runTableItems[0], runTableItems[2]]); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'run2', + false + ) + ).toEqual([runTableItems[1]]); + }); + }); + + describe('getRenderableRuns', () => { + it('returns all runs associated with experiment', () => { + const exp1Result = selectors.TEST_ONLY.utils.getRenderableRuns(['exp1'])( + state + ); + expect(exp1Result.length).toEqual(2); + expect(exp1Result[0].run).toEqual({...run1, experimentId: 'exp1'}); + expect(exp1Result[1].run).toEqual({...run2, experimentId: 'exp1'}); + + const exp2Result = selectors.TEST_ONLY.utils.getRenderableRuns(['exp2'])( + state + ); + expect(exp2Result.length).toEqual(3); + expect(exp2Result[0].run).toEqual({...run2, experimentId: 'exp2'}); + expect(exp2Result[1].run).toEqual({...run3, experimentId: 'exp2'}); + expect(exp2Result[2].run).toEqual({...run4, experimentId: 'exp2'}); + }); + + it('returns two runs when a run is associated with multiple experiments', () => { + const result = selectors.TEST_ONLY.utils.getRenderableRuns([ + 'exp1', + 'exp2', + ])(state); + expect(result.length).toEqual(5); + expect(result[0].run).toEqual({...run1, experimentId: 'exp1'}); + expect(result[1].run).toEqual({...run2, experimentId: 'exp1'}); + expect(result[2].run).toEqual({...run2, experimentId: 'exp2'}); + expect(result[3].run).toEqual({...run3, experimentId: 'exp2'}); + expect(result[4].run).toEqual({...run4, experimentId: 'exp2'}); + }); + + it('returns empty list when no experiments are provided', () => { + expect(selectors.TEST_ONLY.utils.getRenderableRuns([])(state)).toEqual( + [] + ); + }); + }); + + describe('filterRunItemsByHparamAndMetricFilter', () => { + it('filters by hparams using discrete filters', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'lr', + { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [1, 3, 5], + filterValues: [1], + }, + ], + ]), + new Map() + ).length + ).toEqual(1); + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'lr', + { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [1, 3, 5], + filterValues: [1, 5], + }, + ], + ]), + new Map() + ).length + ).toEqual(2); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'who knows', + { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [1, 3, 5], + filterValues: [1, 5], + }, + ], + ]), + new Map() + ).length + ).toEqual(0); + }); + + it('filters by hparams using interval filters', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'lr', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 5, + includeUndefined: true, + }, + ], + ]), + new Map() + ).length + ).toEqual(2); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'who knows', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 5, + includeUndefined: false, + }, + ], + ]), + new Map() + ).length + ).toEqual(0); + }); + + it('filters by metrics using interval filters', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map(), + new Map([ + [ + 'foo', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 3, + includeUndefined: false, + }, + ], + ]) + ).length + ).toEqual(2); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map(), + new Map([ + [ + 'bar', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 3, + includeUndefined: false, + }, + ], + ]) + ).length + ).toEqual(0); + }); + }); + + describe('getFilteredRenderableRuns', () => { + it('does not use experiment alias when route is not compare', () => { + state.runs!.data.regexFilter = 'foo'; + const result = selectors.getFilteredRenderableRuns(['exp1'])(state); + expect(result).toEqual([]); + }); + + it('uses experiment alias when route is compare', () => { + state.runs!.data.regexFilter = 'foo'; + state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; + const result = selectors.getFilteredRenderableRuns(['exp1'])(state); + expect(result.length).toEqual(2); + expect(result[0].run.name).toEqual('run 1'); + expect(result[1].run.name).toEqual('run 2'); + }); + + it('filters runs by hparam and metrics', () => { + const spy = spyOn( + selectors.TEST_ONLY.utils, + 'filterRunItemsByHparamAndMetricFilter' + ).and.callThrough(); + const results = selectors.getFilteredRenderableRuns(['exp1'])(state); + expect(spy).toHaveBeenCalledOnceWith(results, new Map(), new Map()); + }); + + it('returns empty list when no experiments are provided', () => { + expect(selectors.getFilteredRenderableRuns([])(state)).toEqual([]); + }); + }); + + describe('getFilteredRenderableRunsFromRoute', () => { + it('calls getFilteredRenderableRuns with experiment ids from the route when in compare view', () => { + state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; + const result = selectors.getFilteredRenderableRunsFromRoute(state); + expect(result).toEqual( + selectors.getFilteredRenderableRuns(['exp1', 'exp2'])(state) + ); + }); + + it('calls getFilteredRenderableRuns with experiment ids from the route when in single experiment view', () => { + const result = selectors.getFilteredRenderableRunsFromRoute(state); + expect(result).toEqual( + selectors.getFilteredRenderableRuns(['defaultExperimentId'])(state) + ); + }); + }); }); From 67362e66d74eafe2d82f77a245528936b10045cd Mon Sep 17 00:00:00 2001 From: Riley Jones Date: Fri, 5 May 2023 17:08:30 +0000 Subject: [PATCH 3/6] address pr comments --- .../webapp/metrics/views/main_view/BUILD | 1 + .../views/main_view/common_selectors.ts | 30 +++--- .../views/main_view/common_selectors_test.ts | 91 +++++++++++++++---- 3 files changed, 90 insertions(+), 32 deletions(-) diff --git a/tensorboard/webapp/metrics/views/main_view/BUILD b/tensorboard/webapp/metrics/views/main_view/BUILD index 0a50296375..c81e66906a 100644 --- a/tensorboard/webapp/metrics/views/main_view/BUILD +++ b/tensorboard/webapp/metrics/views/main_view/BUILD @@ -90,6 +90,7 @@ tf_ts_library( "//tensorboard/webapp/metrics/views:types", "//tensorboard/webapp/metrics/views:utils", "//tensorboard/webapp/runs/views/runs_table:types", + "//tensorboard/webapp/runs:types", "//tensorboard/webapp/util:matcher", "//tensorboard/webapp/util:types", "@npm//@ngrx/store", diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index f6b4caa6df..954953e69a 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -44,6 +44,7 @@ import {getNonEmptyCardIdsWithMetadata, TagMetadata} from '../../store'; import {compareTagNames} from '../../utils'; import {CardIdWithMetadata} from '../metrics_view_types'; import {RouteKind} from '../../../app_routing/types'; +import {Run} from '../../../runs/types'; export const getScalarTagsForRunSelection = createSelector( getMetricsTagMetadata, @@ -114,7 +115,13 @@ const utils = { getCurrentRouteRunSelection, getRunColorMap, getExperimentIdToExperimentAliasMap, - (runs, experimentNames, selectionMap, colorMap, experimentIdToAlias) => { + ( + runs, + experimentNames, + selectionMap, + colorMap, + experimentIdToAlias + ): Array => { return runs.map((run) => { const hparamMap: RunTableItem['hparams'] = new Map(); (run.hparams || []).forEach((hparam) => { @@ -142,7 +149,7 @@ const utils = { runItems: RunTableItem[], regexString: string, shouldIncludeExperimentName: boolean - ) { + ): RunTableItem[] { if (!regexString) { return runItems; } @@ -172,14 +179,10 @@ const utils = { filter.filterValues; return values.includes(value); } else if (filter.type === DomainType.INTERVAL) { - // Auto-added to unblock TS5.0 migration - // @ts-ignore(go/ts50upgrade): Operator '<=' cannot be applied to types - // 'number' and 'string | number | boolean'. - // Auto-added to unblock TS5.0 migration - // @ts-ignore(go/ts50upgrade): Operator '<=' cannot be applied to types - // 'string | number | boolean' and 'number'. return ( - filter.filterLowerValue <= value && value <= filter.filterUpperValue + typeof value === 'number' && + filter.filterLowerValue <= value && + value <= filter.filterUpperValue ); } return false; @@ -198,13 +201,14 @@ const utils = { } ); - return ( - hparamMatches && - [...metricFilters.entries()].every(([metricTag, filter]) => { + const metricMatches = [...metricFilters.entries()].every( + ([metricTag, filter]) => { const value = metrics.get(metricTag); return utils.matchFilter(filter, value); - }) + } ); + + return hparamMatches && metricMatches; }); }, }; diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index 61de01323d..d35d3173e6 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -86,12 +86,7 @@ describe('common selectors', () => { id: 'run1-id', name: 'run1', startTime: 0, - hparams: [ - { - name: 'accurracy', - value: 1, - }, - ], + hparams: null, metrics: null, }, experimentAlias: { @@ -108,12 +103,7 @@ describe('common selectors', () => { id: 'run2-id', name: 'run2', startTime: 0, - hparams: [ - { - name: 'accurracy', - value: 1, - }, - ], + hparams: null, metrics: null, }, experimentAlias: { @@ -130,12 +120,7 @@ describe('common selectors', () => { id: 'run1-id', name: 'run1', startTime: 0, - hparams: [ - { - name: 'accurracy', - value: 1, - }, - ], + hparams: null, metrics: null, }, experimentAlias: { @@ -584,7 +569,7 @@ describe('common selectors', () => { ).toBeFalse(); }); - it('returns values including value when filter type is DISCRETE', () => { + it('returns values including value when filter type is DISCRETE and values are strings', () => { const filter: DiscreteFilter = { type: DomainType.DISCRETE, includeUndefined: false, @@ -595,6 +580,40 @@ describe('common selectors', () => { expect(selectors.TEST_ONLY.utils.matchFilter(filter, 'bar')).toBeFalse(); }); + it('returns values including value when filter type is DISCRETE and values are numbers', () => { + const filter: DiscreteFilter = { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [], + filterValues: [0, 1, 2, 3, 4], + }; + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 0)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 2)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 5)).toBeFalse(); + }); + + it('returns values including value when filter type is DISCRETE and values are booleans', () => { + const filter: DiscreteFilter = { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [], + filterValues: [true, false], + }; + expect(selectors.TEST_ONLY.utils.matchFilter(filter, true)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, false)).toBeTrue(); + expect( + selectors.TEST_ONLY.utils.matchFilter( + { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [], + filterValues: [false], + }, + false + ) + ).toBeTrue(); + }); + it('checks if value is within bounds when filter type is INTERVAL', () => { const filter: IntervalFilter = { type: DomainType.INTERVAL, @@ -851,6 +870,40 @@ describe('common selectors', () => { ).length ).toEqual(0); }); + + it('filters by both hparams and metrics', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'lr', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 5, + includeUndefined: true, + }, + ], + ]), + new Map([ + [ + 'foo', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 3, + includeUndefined: false, + }, + ], + ]) + ).length + ).toEqual(1); + }); }); describe('getFilteredRenderableRuns', () => { From e80290c84ff216fe42f8aa3c88f7ff8ebd13b277 Mon Sep 17 00:00:00 2001 From: Riley Jones Date: Fri, 5 May 2023 17:43:42 +0000 Subject: [PATCH 4/6] start exporting `getRenderableRuns` --- .../views/main_view/common_selectors.ts | 76 +++++++++---------- .../views/main_view/common_selectors_test.ts | 17 +---- 2 files changed, 42 insertions(+), 51 deletions(-) diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index 954953e69a..fb132ab13d 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -108,43 +108,6 @@ export const getSortedRenderableCardIdsWithMetadata = createSelector< }); const utils = { - getRenderableRuns(experimentIds: string[]) { - return createSelector( - getRunsFromExperimentIds(experimentIds), - getExperimentNames(experimentIds), - getCurrentRouteRunSelection, - getRunColorMap, - getExperimentIdToExperimentAliasMap, - ( - runs, - experimentNames, - selectionMap, - colorMap, - experimentIdToAlias - ): Array => { - return runs.map((run) => { - const hparamMap: RunTableItem['hparams'] = new Map(); - (run.hparams || []).forEach((hparam) => { - hparamMap.set(hparam.name, hparam.value); - }); - const metricMap: RunTableItem['metrics'] = new Map(); - (run.metrics || []).forEach((metric) => { - metricMap.set(metric.tag, metric.value); - }); - return { - run, - experimentName: experimentNames[run.experimentId] || '', - experimentAlias: experimentIdToAlias[run.experimentId], - selected: Boolean(selectionMap && selectionMap.get(run.id)), - runColor: colorMap[run.id], - hparams: hparamMap, - metrics: metricMap, - }; - }); - } - ); - }, - filterRunItemsByRegex( runItems: RunTableItem[], regexString: string, @@ -213,10 +176,47 @@ const utils = { }, }; +export function getRenderableRuns(experimentIds: string[]) { + return createSelector( + getRunsFromExperimentIds(experimentIds), + getExperimentNames(experimentIds), + getCurrentRouteRunSelection, + getRunColorMap, + getExperimentIdToExperimentAliasMap, + ( + runs, + experimentNames, + selectionMap, + colorMap, + experimentIdToAlias + ): Array => { + return runs.map((run) => { + const hparamMap: RunTableItem['hparams'] = new Map(); + (run.hparams || []).forEach((hparam) => { + hparamMap.set(hparam.name, hparam.value); + }); + const metricMap: RunTableItem['metrics'] = new Map(); + (run.metrics || []).forEach((metric) => { + metricMap.set(metric.tag, metric.value); + }); + return { + run, + experimentName: experimentNames[run.experimentId] || '', + experimentAlias: experimentIdToAlias[run.experimentId], + selected: Boolean(selectionMap && selectionMap.get(run.id)), + runColor: colorMap[run.id], + hparams: hparamMap, + metrics: metricMap, + }; + }); + } + ); +} + export function getFilteredRenderableRuns(experimentIds: string[]) { return createSelector( getRunSelectorRegexFilter, - utils.getRenderableRuns(experimentIds), + getRenderableRuns(experimentIds), getHparamFilterMapFromExperimentIds(experimentIds), getMetricFilterMapFromExperimentIds(experimentIds), getRouteKind, diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index d35d3173e6..700e1a3cc5 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -695,16 +695,12 @@ describe('common selectors', () => { describe('getRenderableRuns', () => { it('returns all runs associated with experiment', () => { - const exp1Result = selectors.TEST_ONLY.utils.getRenderableRuns(['exp1'])( - state - ); + const exp1Result = selectors.getRenderableRuns(['exp1'])(state); expect(exp1Result.length).toEqual(2); expect(exp1Result[0].run).toEqual({...run1, experimentId: 'exp1'}); expect(exp1Result[1].run).toEqual({...run2, experimentId: 'exp1'}); - const exp2Result = selectors.TEST_ONLY.utils.getRenderableRuns(['exp2'])( - state - ); + const exp2Result = selectors.getRenderableRuns(['exp2'])(state); expect(exp2Result.length).toEqual(3); expect(exp2Result[0].run).toEqual({...run2, experimentId: 'exp2'}); expect(exp2Result[1].run).toEqual({...run3, experimentId: 'exp2'}); @@ -712,10 +708,7 @@ describe('common selectors', () => { }); it('returns two runs when a run is associated with multiple experiments', () => { - const result = selectors.TEST_ONLY.utils.getRenderableRuns([ - 'exp1', - 'exp2', - ])(state); + const result = selectors.getRenderableRuns(['exp1', 'exp2'])(state); expect(result.length).toEqual(5); expect(result[0].run).toEqual({...run1, experimentId: 'exp1'}); expect(result[1].run).toEqual({...run2, experimentId: 'exp1'}); @@ -725,9 +718,7 @@ describe('common selectors', () => { }); it('returns empty list when no experiments are provided', () => { - expect(selectors.TEST_ONLY.utils.getRenderableRuns([])(state)).toEqual( - [] - ); + expect(selectors.getRenderableRuns([])(state)).toEqual([]); }); }); From 5046c51c20629a07f161578012953ac899e654b0 Mon Sep 17 00:00:00 2001 From: Riley Jones Date: Fri, 5 May 2023 17:50:02 +0000 Subject: [PATCH 5/6] lint build file --- tensorboard/webapp/metrics/views/main_view/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorboard/webapp/metrics/views/main_view/BUILD b/tensorboard/webapp/metrics/views/main_view/BUILD index c81e66906a..5dffd6600b 100644 --- a/tensorboard/webapp/metrics/views/main_view/BUILD +++ b/tensorboard/webapp/metrics/views/main_view/BUILD @@ -89,8 +89,8 @@ tf_ts_library( "//tensorboard/webapp/metrics/store", "//tensorboard/webapp/metrics/views:types", "//tensorboard/webapp/metrics/views:utils", - "//tensorboard/webapp/runs/views/runs_table:types", "//tensorboard/webapp/runs:types", + "//tensorboard/webapp/runs/views/runs_table:types", "//tensorboard/webapp/util:matcher", "//tensorboard/webapp/util:types", "@npm//@ngrx/store", From 9e31b34373b5179f35f10c1c2acbc76a95134153 Mon Sep 17 00:00:00 2001 From: Riley Jones Date: Mon, 8 May 2023 18:08:26 +0000 Subject: [PATCH 6/6] make all runids unique --- .../webapp/metrics/views/main_view/common_selectors_test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index 700e1a3cc5..26f85f4fa6 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -117,7 +117,7 @@ describe('common selectors', () => { }, { run: { - id: 'run1-id', + id: 'run3-id', name: 'run1', startTime: 0, hparams: null, @@ -678,7 +678,7 @@ describe('common selectors', () => { expect( selectors.TEST_ONLY.utils.filterRunItemsByRegex( runTableItems, - 'run1', + 'run[13]', false ) ).toEqual([runTableItems[0], runTableItems[2]]);