diff --git a/tensorboard/webapp/metrics/views/main_view/BUILD b/tensorboard/webapp/metrics/views/main_view/BUILD index 41b28403c7..5dffd6600b 100644 --- a/tensorboard/webapp/metrics/views/main_view/BUILD +++ b/tensorboard/webapp/metrics/views/main_view/BUILD @@ -81,11 +81,17 @@ tf_ts_library( deps = [ "//tensorboard/webapp:app_state", "//tensorboard/webapp:selectors", + "//tensorboard/webapp/app_routing:types", + "//tensorboard/webapp/hparams:types", + "//tensorboard/webapp/hparams/_redux:hparams_selectors", "//tensorboard/webapp/metrics:utils", "//tensorboard/webapp/metrics/data_source", "//tensorboard/webapp/metrics/store", "//tensorboard/webapp/metrics/views:types", "//tensorboard/webapp/metrics/views:utils", + "//tensorboard/webapp/runs:types", + "//tensorboard/webapp/runs/views/runs_table:types", + "//tensorboard/webapp/util:matcher", "//tensorboard/webapp/util:types", "@npm//@ngrx/store", ], @@ -185,6 +191,8 @@ tf_ts_library( "//tensorboard/webapp/app_routing:testing", "//tensorboard/webapp/app_routing/store:testing", "//tensorboard/webapp/customization", + "//tensorboard/webapp/experiments/store:testing", + "//tensorboard/webapp/hparams:types", "//tensorboard/webapp/metrics:test_lib", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics/actions", @@ -192,8 +200,10 @@ tf_ts_library( "//tensorboard/webapp/metrics/store", "//tensorboard/webapp/metrics/views:types", "//tensorboard/webapp/metrics/views/card_renderer", + "//tensorboard/webapp/runs/store:selectors", "//tensorboard/webapp/runs/store:testing", "//tensorboard/webapp/runs/store:types", + "//tensorboard/webapp/runs/views/runs_table:types", "//tensorboard/webapp/settings", "//tensorboard/webapp/testing:dom", "//tensorboard/webapp/testing:mat_icon", diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index a8967720af..fb132ab13d 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -18,12 +18,33 @@ import { getCurrentRouteRunSelection, getMetricsHideEmptyCards, getMetricsTagMetadata, + getExperimentIdsFromRoute, + getExperimentIdToExperimentAliasMap, + getExperimentNames, + getRunColorMap, + getRunSelectorRegexFilter, + getRouteKind, + getRunsFromExperimentIds, } from '../../../selectors'; import {DeepReadonly} from '../../../util/types'; +import { + getHparamFilterMapFromExperimentIds, + getMetricFilterMapFromExperimentIds, +} from '../../../hparams/_redux/hparams_selectors'; +import { + DiscreteFilter, + DiscreteHparamValue, + DomainType, + IntervalFilter, +} from '../../../hparams/types'; +import {RunTableItem} from '../../../runs/views/runs_table/types'; +import {matchRunToRegex} from '../../../util/matcher'; import {isSingleRunPlugin, PluginType} from '../../data_source'; import {getNonEmptyCardIdsWithMetadata, TagMetadata} from '../../store'; import {compareTagNames} from '../../utils'; import {CardIdWithMetadata} from '../metrics_view_types'; +import {RouteKind} from '../../../app_routing/types'; +import {Run} from '../../../runs/types'; export const getScalarTagsForRunSelection = createSelector( getMetricsTagMetadata, @@ -86,7 +107,145 @@ export const getSortedRenderableCardIdsWithMetadata = createSelector< }); }); +const utils = { + filterRunItemsByRegex( + runItems: RunTableItem[], + regexString: string, + shouldIncludeExperimentName: boolean + ): RunTableItem[] { + if (!regexString) { + return runItems; + } + + return runItems.filter((item) => { + return matchRunToRegex( + { + runName: item.run.name, + experimentAlias: item.experimentAlias, + }, + regexString, + shouldIncludeExperimentName + ); + }); + }, + + matchFilter( + filter: DiscreteFilter | IntervalFilter, + value: number | DiscreteHparamValue | undefined + ): boolean { + if (value === undefined) { + return filter.includeUndefined; + } + if (filter.type === DomainType.DISCRETE) { + // (upcast to work around bad TypeScript libdefs) + const values: Readonly> = + filter.filterValues; + return values.includes(value); + } else if (filter.type === DomainType.INTERVAL) { + return ( + typeof value === 'number' && + filter.filterLowerValue <= value && + value <= filter.filterUpperValue + ); + } + return false; + }, + + filterRunItemsByHparamAndMetricFilter( + runItems: RunTableItem[], + hparamFilters: Map, + metricFilters: Map + ) { + return runItems.filter(({hparams, metrics}) => { + const hparamMatches = [...hparamFilters.entries()].every( + ([hparamName, filter]) => { + const value = hparams.get(hparamName); + return utils.matchFilter(filter, value); + } + ); + + const metricMatches = [...metricFilters.entries()].every( + ([metricTag, filter]) => { + const value = metrics.get(metricTag); + return utils.matchFilter(filter, value); + } + ); + + return hparamMatches && metricMatches; + }); + }, +}; + +export function getRenderableRuns(experimentIds: string[]) { + return createSelector( + getRunsFromExperimentIds(experimentIds), + getExperimentNames(experimentIds), + getCurrentRouteRunSelection, + getRunColorMap, + getExperimentIdToExperimentAliasMap, + ( + runs, + experimentNames, + selectionMap, + colorMap, + experimentIdToAlias + ): Array => { + return runs.map((run) => { + const hparamMap: RunTableItem['hparams'] = new Map(); + (run.hparams || []).forEach((hparam) => { + hparamMap.set(hparam.name, hparam.value); + }); + const metricMap: RunTableItem['metrics'] = new Map(); + (run.metrics || []).forEach((metric) => { + metricMap.set(metric.tag, metric.value); + }); + return { + run, + experimentName: experimentNames[run.experimentId] || '', + experimentAlias: experimentIdToAlias[run.experimentId], + selected: Boolean(selectionMap && selectionMap.get(run.id)), + runColor: colorMap[run.id], + hparams: hparamMap, + metrics: metricMap, + }; + }); + } + ); +} + +export function getFilteredRenderableRuns(experimentIds: string[]) { + return createSelector( + getRunSelectorRegexFilter, + getRenderableRuns(experimentIds), + getHparamFilterMapFromExperimentIds(experimentIds), + getMetricFilterMapFromExperimentIds(experimentIds), + getRouteKind, + (regexFilter, runItems, hparamFilters, metricFilters, routeKind) => { + const regexFilteredItems = utils.filterRunItemsByRegex( + runItems, + regexFilter, + routeKind === RouteKind.COMPARE_EXPERIMENT + ); + + return utils.filterRunItemsByHparamAndMetricFilter( + regexFilteredItems, + hparamFilters, + metricFilters + ); + } + ); +} + +export const getFilteredRenderableRunsFromRoute = createSelector( + (state) => state, + getExperimentIdsFromRoute, + (state, experimentIds) => { + return getFilteredRenderableRuns(experimentIds || [])(state); + } +); + export const TEST_ONLY = { getRenderableCardIdsWithMetadata, getScalarTagsForRunSelection, + utils, }; diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index 46cef42f6f..26f85f4fa6 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -18,11 +18,16 @@ import { buildStateFromAppRoutingState, } from '../../../app_routing/store/testing'; import {buildRoute} from '../../../app_routing/testing'; -import {Run} from '../../../runs/store/runs_types'; +import {buildExperiment} from '../../../experiments/store/testing'; +import {IntervalFilter, DiscreteFilter} from '../../../hparams/types'; +import {DomainType, Run} from '../../../runs/store/runs_types'; import { + buildRun, buildRunsState, buildStateFromRunsState, } from '../../../runs/store/testing'; +import {RunTableItem} from '../../../runs/views/runs_table/types'; +import {buildMockState} from '../../../testing/utils'; import { appStateFromMetricsState, buildMetricsSettingsState, @@ -35,6 +40,15 @@ describe('common selectors', () => { let runIds: Record; let runIdToExpId: Record; let runMetadata: Record; + + let runTableItems: RunTableItem[]; + + let run1: Run; + let run2: Run; + let run3: Run; + let run4: Run; + let state: ReturnType; + beforeEach(() => { runIds = {defaultExperimentId: ['run1', 'run2', 'run3']}; runIdToExpId = { @@ -65,6 +79,99 @@ describe('common selectors', () => { metrics: null, }, }; + + runTableItems = [ + { + run: { + id: 'run1-id', + name: 'run1', + startTime: 0, + hparams: null, + metrics: null, + }, + experimentAlias: { + aliasNumber: 1, + aliasText: 'exp1', + }, + experimentName: 'experiment1', + selected: true, + hparams: new Map([['lr', 5]]), + metrics: new Map([['foo', 1]]), + }, + { + run: { + id: 'run2-id', + name: 'run2', + startTime: 0, + hparams: null, + metrics: null, + }, + experimentAlias: { + aliasNumber: 1, + aliasText: 'exp1', + }, + experimentName: 'experiment1', + selected: true, + hparams: new Map([['lr', 3]]), + metrics: new Map([['foo', 2]]), + }, + { + run: { + id: 'run3-id', + name: 'run1', + startTime: 0, + hparams: null, + metrics: null, + }, + experimentAlias: { + aliasNumber: 1, + aliasText: 'exp2', + }, + experimentName: 'experiment2', + selected: true, + hparams: new Map([['lr', 1]]), + metrics: new Map([['foo', 3]]), + }, + ]; + + run1 = buildRun({name: 'run 1'}); + run2 = buildRun({id: '2', name: 'run 2'}); + run3 = buildRun({id: '3', name: 'run 3'}); + run4 = buildRun({id: '4', name: 'run 4'}); + state = buildMockState({ + runs: { + data: { + regexFilter: '', + runIds: { + exp1: ['run1', 'run2'], + exp2: ['run2', 'run3', 'run4'], + }, + runMetadata: { + run1, + run2, + run3, + run4, + }, + } as any, + ui: {} as any, + }, + experiments: { + data: { + experimentMap: { + exp1: buildExperiment({name: 'experiment1', id: 'exp1'}), + exp2: buildExperiment({name: 'experiment2', id: 'exp2'}), + }, + }, + }, + app_routing: { + activeRoute: { + routeKind: RouteKind.EXPERIMENT, + params: { + experimentIds: 'foo:exp1,bar:exp2', + }, + }, + } as any, + }); }); describe('getScalarTagsForRunSelection', () => { @@ -430,4 +537,410 @@ describe('common selectors', () => { ]); }); }); + + describe('matchFilter', () => { + it('respects includeUndefined when value is undefined', () => { + expect( + selectors.TEST_ONLY.utils.matchFilter( + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 1, + filterUpperValue: 5, + includeUndefined: true, + }, + undefined + ) + ).toBeTrue(); + + expect( + selectors.TEST_ONLY.utils.matchFilter( + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 1, + filterUpperValue: 5, + includeUndefined: false, + }, + undefined + ) + ).toBeFalse(); + }); + + it('returns values including value when filter type is DISCRETE and values are strings', () => { + const filter: DiscreteFilter = { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [], + filterValues: ['afoo', 'foob', 'foo', 'fo'], + }; + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 'foo')).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 'bar')).toBeFalse(); + }); + + it('returns values including value when filter type is DISCRETE and values are numbers', () => { + const filter: DiscreteFilter = { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [], + filterValues: [0, 1, 2, 3, 4], + }; + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 0)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 2)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 5)).toBeFalse(); + }); + + it('returns values including value when filter type is DISCRETE and values are booleans', () => { + const filter: DiscreteFilter = { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [], + filterValues: [true, false], + }; + expect(selectors.TEST_ONLY.utils.matchFilter(filter, true)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, false)).toBeTrue(); + expect( + selectors.TEST_ONLY.utils.matchFilter( + { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [], + filterValues: [false], + }, + false + ) + ).toBeTrue(); + }); + + it('checks if value is within bounds when filter type is INTERVAL', () => { + const filter: IntervalFilter = { + type: DomainType.INTERVAL, + includeUndefined: true, + minValue: 0, + maxValue: 10, + filterLowerValue: 1, + filterUpperValue: 5, + }; + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 0)).toBeFalse(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 1)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 3)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 5)).toBeTrue(); + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 6)).toBeFalse(); + + expect(selectors.TEST_ONLY.utils.matchFilter(filter, 'foo')).toBeFalse(); + }); + + it('returns false if filter type is neither DISCRETE nor INTERVAL', () => { + expect( + selectors.TEST_ONLY.utils.matchFilter({} as DiscreteFilter, 'foo') + ).toBeFalse(); + }); + }); + + describe('filterRunItemsByRegex', () => { + it('returns all runs when no regex is provided', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + '', + false + ) + ).toEqual(runTableItems); + }); + + it('only returns runs matching regex', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'run', + false + ) + ).toEqual(runTableItems); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'exp', + false + ) + ).toEqual([]); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'exp', + true + ) + ).toEqual(runTableItems); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'run[13]', + false + ) + ).toEqual([runTableItems[0], runTableItems[2]]); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByRegex( + runTableItems, + 'run2', + false + ) + ).toEqual([runTableItems[1]]); + }); + }); + + describe('getRenderableRuns', () => { + it('returns all runs associated with experiment', () => { + const exp1Result = selectors.getRenderableRuns(['exp1'])(state); + expect(exp1Result.length).toEqual(2); + expect(exp1Result[0].run).toEqual({...run1, experimentId: 'exp1'}); + expect(exp1Result[1].run).toEqual({...run2, experimentId: 'exp1'}); + + const exp2Result = selectors.getRenderableRuns(['exp2'])(state); + expect(exp2Result.length).toEqual(3); + expect(exp2Result[0].run).toEqual({...run2, experimentId: 'exp2'}); + expect(exp2Result[1].run).toEqual({...run3, experimentId: 'exp2'}); + expect(exp2Result[2].run).toEqual({...run4, experimentId: 'exp2'}); + }); + + it('returns two runs when a run is associated with multiple experiments', () => { + const result = selectors.getRenderableRuns(['exp1', 'exp2'])(state); + expect(result.length).toEqual(5); + expect(result[0].run).toEqual({...run1, experimentId: 'exp1'}); + expect(result[1].run).toEqual({...run2, experimentId: 'exp1'}); + expect(result[2].run).toEqual({...run2, experimentId: 'exp2'}); + expect(result[3].run).toEqual({...run3, experimentId: 'exp2'}); + expect(result[4].run).toEqual({...run4, experimentId: 'exp2'}); + }); + + it('returns empty list when no experiments are provided', () => { + expect(selectors.getRenderableRuns([])(state)).toEqual([]); + }); + }); + + describe('filterRunItemsByHparamAndMetricFilter', () => { + it('filters by hparams using discrete filters', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'lr', + { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [1, 3, 5], + filterValues: [1], + }, + ], + ]), + new Map() + ).length + ).toEqual(1); + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'lr', + { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [1, 3, 5], + filterValues: [1, 5], + }, + ], + ]), + new Map() + ).length + ).toEqual(2); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'who knows', + { + type: DomainType.DISCRETE, + includeUndefined: false, + possibleValues: [1, 3, 5], + filterValues: [1, 5], + }, + ], + ]), + new Map() + ).length + ).toEqual(0); + }); + + it('filters by hparams using interval filters', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'lr', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 5, + includeUndefined: true, + }, + ], + ]), + new Map() + ).length + ).toEqual(2); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'who knows', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 5, + includeUndefined: false, + }, + ], + ]), + new Map() + ).length + ).toEqual(0); + }); + + it('filters by metrics using interval filters', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map(), + new Map([ + [ + 'foo', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 3, + includeUndefined: false, + }, + ], + ]) + ).length + ).toEqual(2); + + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map(), + new Map([ + [ + 'bar', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 3, + includeUndefined: false, + }, + ], + ]) + ).length + ).toEqual(0); + }); + + it('filters by both hparams and metrics', () => { + expect( + selectors.TEST_ONLY.utils.filterRunItemsByHparamAndMetricFilter( + runTableItems, + new Map([ + [ + 'lr', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 5, + includeUndefined: true, + }, + ], + ]), + new Map([ + [ + 'foo', + { + type: DomainType.INTERVAL, + minValue: 0, + maxValue: 10, + filterLowerValue: 2, + filterUpperValue: 3, + includeUndefined: false, + }, + ], + ]) + ).length + ).toEqual(1); + }); + }); + + describe('getFilteredRenderableRuns', () => { + it('does not use experiment alias when route is not compare', () => { + state.runs!.data.regexFilter = 'foo'; + const result = selectors.getFilteredRenderableRuns(['exp1'])(state); + expect(result).toEqual([]); + }); + + it('uses experiment alias when route is compare', () => { + state.runs!.data.regexFilter = 'foo'; + state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; + const result = selectors.getFilteredRenderableRuns(['exp1'])(state); + expect(result.length).toEqual(2); + expect(result[0].run.name).toEqual('run 1'); + expect(result[1].run.name).toEqual('run 2'); + }); + + it('filters runs by hparam and metrics', () => { + const spy = spyOn( + selectors.TEST_ONLY.utils, + 'filterRunItemsByHparamAndMetricFilter' + ).and.callThrough(); + const results = selectors.getFilteredRenderableRuns(['exp1'])(state); + expect(spy).toHaveBeenCalledOnceWith(results, new Map(), new Map()); + }); + + it('returns empty list when no experiments are provided', () => { + expect(selectors.getFilteredRenderableRuns([])(state)).toEqual([]); + }); + }); + + describe('getFilteredRenderableRunsFromRoute', () => { + it('calls getFilteredRenderableRuns with experiment ids from the route when in compare view', () => { + state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; + const result = selectors.getFilteredRenderableRunsFromRoute(state); + expect(result).toEqual( + selectors.getFilteredRenderableRuns(['exp1', 'exp2'])(state) + ); + }); + + it('calls getFilteredRenderableRuns with experiment ids from the route when in single experiment view', () => { + const result = selectors.getFilteredRenderableRunsFromRoute(state); + expect(result).toEqual( + selectors.getFilteredRenderableRuns(['defaultExperimentId'])(state) + ); + }); + }); });