diff --git a/tensorboard/webapp/metrics/views/card_renderer/BUILD b/tensorboard/webapp/metrics/views/card_renderer/BUILD index 80f503266a..30f3cbf377 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/BUILD +++ b/tensorboard/webapp/metrics/views/card_renderer/BUILD @@ -312,6 +312,7 @@ tf_ng_module( "//tensorboard/webapp/metrics/store", "//tensorboard/webapp/metrics/views:types", "//tensorboard/webapp/metrics/views:utils", + "//tensorboard/webapp/metrics/views/main_view:common_selectors", "//tensorboard/webapp/runs/store:types", "//tensorboard/webapp/types", "//tensorboard/webapp/types:ui", @@ -379,6 +380,7 @@ tf_ts_library( "//tensorboard/webapp/metrics/actions", "//tensorboard/webapp/metrics/data_source", "//tensorboard/webapp/metrics/store", + "//tensorboard/webapp/metrics/views/main_view:common_selectors", "//tensorboard/webapp/runs/store:testing", "//tensorboard/webapp/runs/store:types", "//tensorboard/webapp/testing:mat_icon", diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts index cda18d1627..fc35200901 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts @@ -44,7 +44,6 @@ import { import { getCardPinnedState, getCardStateMap, - getCurrentRouteRunSelection, getDarkModeEnabled, getExperimentIdForRunId, getExperimentIdToExperimentAliasMap, @@ -57,6 +56,7 @@ import { getMetricsCardRangeSelectionEnabled, getRun, getRunColorMap, + getCurrentRouteRunSelection, } from '../../../selectors'; import {DataLoadState} from '../../../types/data'; import { @@ -94,6 +94,7 @@ import { RunToSeries, } from '../../store'; import {CardId, CardMetadata, HeaderEditInfo, XAxisType} from '../../types'; +import {getFilteredRenderableRunsIdsFromRoute} from '../main_view/common_selectors'; import {CardRenderer} from '../metrics_view_types'; import {getTagDisplayName} from '../utils'; import {DataDownloadDialogContainer} from './data_download_dialog_container'; @@ -115,8 +116,6 @@ import { TimeSelectionView, } from './utils'; -const DEFAULT_MIN = -Infinity; -const DEFAULT_MAX = Infinity; type ScalarCardMetadata = CardMetadata & { plugin: PluginType.SCALARS; }; @@ -502,6 +501,7 @@ export class ScalarCardContainer implements CardRenderer, OnInit, OnDestroy { }), combineLatestWith( this.store.select(getCurrentRouteRunSelection), + this.store.select(getFilteredRenderableRunsIdsFromRoute), this.store.select(getRunColorMap), this.store.select(getMetricsScalarSmoothing) ), @@ -510,55 +510,67 @@ export class ScalarCardContainer implements CardRenderer, OnInit, OnDestroy { // debounce by a microtask to emit only single change for the runs // store change. debounceTime(0), - map(([namedPartitionedSeries, runSelectionMap, colorMap, smoothing]) => { - const metadataMap: ScalarCardSeriesMetadataMap = {}; - const shouldSmooth = smoothing > 0; + map( + ([ + namedPartitionedSeries, + runSelectionMap, + renderableRuns, + colorMap, + smoothing, + ]) => { + const metadataMap: ScalarCardSeriesMetadataMap = {}; + const shouldSmooth = smoothing > 0; + + for (const partitioned of namedPartitionedSeries) { + const { + seriesId, + runId, + displayName, + alias, + partitionIndex, + partitionSize, + } = partitioned; + + metadataMap[seriesId] = { + type: SeriesType.ORIGINAL, + id: seriesId, + alias, + displayName: + partitionSize > 1 + ? `${displayName}: ${partitionIndex}` + : displayName, + visible: Boolean( + runSelectionMap && + runSelectionMap.get(runId) && + renderableRuns.has(runId) + ), + color: colorMap[runId] ?? '#fff', + aux: false, + opacity: 1, + }; + } + + if (!shouldSmooth) { + return metadataMap; + } + + for (const [id, metadata] of Object.entries(metadataMap)) { + const smoothedSeriesId = getSmoothedSeriesId(id); + metadataMap[smoothedSeriesId] = { + ...metadata, + id: smoothedSeriesId, + type: SeriesType.DERIVED, + aux: false, + originalSeriesId: id, + }; - for (const partitioned of namedPartitionedSeries) { - const { - seriesId, - runId, - displayName, - alias, - partitionIndex, - partitionSize, - } = partitioned; - - metadataMap[seriesId] = { - type: SeriesType.ORIGINAL, - id: seriesId, - alias, - displayName: - partitionSize > 1 - ? `${displayName}: ${partitionIndex}` - : displayName, - visible: Boolean(runSelectionMap && runSelectionMap.get(runId)), - color: colorMap[runId] ?? '#fff', - aux: false, - opacity: 1, - }; - } + metadata.aux = true; + metadata.opacity = 0.25; + } - if (!shouldSmooth) { return metadataMap; } - - for (const [id, metadata] of Object.entries(metadataMap)) { - const smoothedSeriesId = getSmoothedSeriesId(id); - metadataMap[smoothedSeriesId] = { - ...metadata, - id: smoothedSeriesId, - type: SeriesType.DERIVED, - aux: false, - originalSeriesId: id, - }; - - metadata.aux = true; - metadata.opacity = 0.25; - } - - return metadataMap; - }), + ), startWith({} as ScalarCardSeriesMetadataMap) ); diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts index d3bd65a5ce..1b678e1dcf 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts @@ -36,7 +36,7 @@ import {MatProgressSpinnerModule} from '@angular/material/progress-spinner'; import {By} from '@angular/platform-browser'; import {NoopAnimationsModule} from '@angular/platform-browser/animations'; import {Action, Store} from '@ngrx/store'; -import {MockStore, provideMockStore} from '@ngrx/store/testing'; +import {MockStore} from '@ngrx/store/testing'; import {Observable, of, ReplaySubject} from 'rxjs'; import {State} from '../../../app_state'; import {ExperimentAlias} from '../../../experiments/types'; @@ -99,8 +99,6 @@ import { getSingleSelectionHeaders, } from '../../store'; import { - appStateFromMetricsState, - buildMetricsState, buildScalarStepData, provideMockCardRunToSeriesData, } from '../../testing'; @@ -120,6 +118,8 @@ import { } from './scalar_card_types'; import {VisLinkedTimeSelectionWarningModule} from './vis_linked_time_selection_warning_module'; import {Extent} from '../../../widgets/line_chart_v2/lib/public_types'; +import {provideMockTbStore} from '../../../testing/utils'; +import * as commonSelectors from '../main_view/common_selectors'; @Component({ selector: 'line-chart', @@ -343,11 +343,7 @@ describe('scalar card', () => { TestableDataDownload, TestableLineChart, ], - providers: [ - provideMockStore({ - initialState: appStateFromMetricsState(buildMetricsState()), - }), - ], + providers: [provideMockTbStore()], schemas: [NO_ERRORS_SCHEMA], }).compileComponents(); @@ -482,6 +478,10 @@ describe('scalar card', () => { selectors.getCurrentRouteRunSelection, new Map([['run1', true]]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1']) + ); store.overrideSelector(selectors.getMetricsXAxisType, XAxisType.STEP); selectSpy .withArgs(selectors.getRun, {runId: 'run1'}) @@ -757,6 +757,10 @@ describe('scalar card', () => { selectors.getCurrentRouteRunSelection, new Map([['run1', true]]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1']) + ); }); const expectedPoints = { @@ -2412,6 +2416,7 @@ describe('scalar card', () => { end: null, }); store.refreshState(); + tick(); fixture.detectChanges(); testController.stopDrag(); @@ -2434,6 +2439,7 @@ describe('scalar card', () => { end: null, }); store.refreshState(); + tick(); fixture.detectChanges(); testController.stopDrag(); @@ -2816,6 +2822,10 @@ describe('scalar card', () => { ['run2', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 2}, @@ -2881,6 +2891,10 @@ describe('scalar card', () => { ['run2', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 1}, @@ -2955,6 +2969,10 @@ describe('scalar card', () => { selectors.getCurrentRouteRunSelection, new Map([['run1', true]]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1']) + ); store.overrideSelector(selectors.getMetricsScalarSmoothing, 0.3); @@ -3016,6 +3034,10 @@ describe('scalar card', () => { selectors.getCurrentRouteRunSelection, new Map([['run1', true]]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 2}, @@ -3079,6 +3101,10 @@ describe('scalar card', () => { ['run2', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 18}, @@ -3126,6 +3152,10 @@ describe('scalar card', () => { ['run2', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 2}, @@ -3174,6 +3204,10 @@ describe('scalar card', () => { ['run2', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 100}, @@ -3220,6 +3254,10 @@ describe('scalar card', () => { ['run2', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 1}, end: null, @@ -3265,6 +3303,10 @@ describe('scalar card', () => { ['run2', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2']) + ); store.overrideSelector(selectors.getExperimentIdToExperimentAliasMap, { eid1: {aliasText: 'test alias 1', aliasNumber: 100}, eid2: {aliasText: 'test alias 2', aliasNumber: 200}, @@ -3319,6 +3361,10 @@ describe('scalar card', () => { selectors.getCurrentRouteRunSelection, new Map([['run1', true]]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 20}, end: null, @@ -3362,6 +3408,10 @@ describe('scalar card', () => { ['run3', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2', 'run3']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 1}, @@ -3407,6 +3457,10 @@ describe('scalar card', () => { ['run3', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2', 'run3']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 1}, @@ -3460,6 +3514,10 @@ describe('scalar card', () => { ['run7', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2', 'run3', 'run4', 'run5', 'run6', 'run7']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 1}, @@ -3518,6 +3576,11 @@ describe('scalar card', () => { ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2', 'run3', 'run4', 'run5', 'run6', 'run7']) + ); + store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 1}, end: null, @@ -3644,6 +3707,10 @@ describe('scalar card', () => { ['run4', true], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2', 'run3', 'run4']) + ); store.overrideSelector(getMetricsLinkedTimeSelection, { start: {step: 2}, @@ -3669,6 +3736,10 @@ describe('scalar card', () => { ['run4', false], ]) ); + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIdsFromRoute, + new Set(['run1', 'run2']) + ); const fixture = createComponent('card1'); fixture.detectChanges(); const component = fixture.debugElement.query( @@ -3824,6 +3895,7 @@ describe('scalar card', () => { end: null, }); store.refreshState(); + tick(); fixture.detectChanges(); // One start fob @@ -3846,6 +3918,7 @@ describe('scalar card', () => { }); store.overrideSelector(getMetricsCardRangeSelectionEnabled, true); store.refreshState(); + tick(); fixture.detectChanges(); // One start fob, one end fob @@ -3921,6 +3994,7 @@ describe('scalar card', () => { end: null, }); store.refreshState(); + tick(); fixture.detectChanges(); testController.stopDrag(); @@ -3981,6 +4055,7 @@ describe('scalar card', () => { end: null, }); store.refreshState(); + tick(); fixture.detectChanges(); testController.stopDrag(); diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index fb132ab13d..d67a53d8f5 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -37,14 +37,16 @@ import { DomainType, IntervalFilter, } from '../../../hparams/types'; -import {RunTableItem} from '../../../runs/views/runs_table/types'; +import { + RunTableItem, + RunTableExperimentItem, +} from '../../../runs/views/runs_table/types'; import {matchRunToRegex} from '../../../util/matcher'; import {isSingleRunPlugin, PluginType} from '../../data_source'; import {getNonEmptyCardIdsWithMetadata, TagMetadata} from '../../store'; import {compareTagNames} from '../../utils'; import {CardIdWithMetadata} from '../metrics_view_types'; import {RouteKind} from '../../../app_routing/types'; -import {Run} from '../../../runs/types'; export const getScalarTagsForRunSelection = createSelector( getMetricsTagMetadata, @@ -176,7 +178,7 @@ const utils = { }, }; -export function getRenderableRuns(experimentIds: string[]) { +function getRenderableRuns(experimentIds: string[]) { return createSelector( getRunsFromExperimentIds(experimentIds), getExperimentNames(experimentIds), @@ -189,7 +191,7 @@ export function getRenderableRuns(experimentIds: string[]) { selectionMap, colorMap, experimentIdToAlias - ): Array => { + ): Array => { return runs.map((run) => { const hparamMap: RunTableItem['hparams'] = new Map(); (run.hparams || []).forEach((hparam) => { @@ -213,7 +215,7 @@ export function getRenderableRuns(experimentIds: string[]) { ); } -export function getFilteredRenderableRuns(experimentIds: string[]) { +function getFilteredRenderableRuns(experimentIds: string[]) { return createSelector( getRunSelectorRegexFilter, getRenderableRuns(experimentIds), @@ -244,6 +246,18 @@ export const getFilteredRenderableRunsFromRoute = createSelector( } ); +export const getFilteredRenderableRunsIdsFromRoute = createSelector( + getFilteredRenderableRunsFromRoute, + (filteredRenderableRuns) => { + return new Set(filteredRenderableRuns.map(({run: {id}}) => id)); + } +); + +export const factories = { + getRenderableRuns, + getFilteredRenderableRuns, +}; + export const TEST_ONLY = { getRenderableCardIdsWithMetadata, getScalarTagsForRunSelection, diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index 26f85f4fa6..f227603ad2 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -695,12 +695,12 @@ describe('common selectors', () => { describe('getRenderableRuns', () => { it('returns all runs associated with experiment', () => { - const exp1Result = selectors.getRenderableRuns(['exp1'])(state); + const exp1Result = selectors.factories.getRenderableRuns(['exp1'])(state); expect(exp1Result.length).toEqual(2); expect(exp1Result[0].run).toEqual({...run1, experimentId: 'exp1'}); expect(exp1Result[1].run).toEqual({...run2, experimentId: 'exp1'}); - const exp2Result = selectors.getRenderableRuns(['exp2'])(state); + const exp2Result = selectors.factories.getRenderableRuns(['exp2'])(state); expect(exp2Result.length).toEqual(3); expect(exp2Result[0].run).toEqual({...run2, experimentId: 'exp2'}); expect(exp2Result[1].run).toEqual({...run3, experimentId: 'exp2'}); @@ -708,7 +708,9 @@ describe('common selectors', () => { }); it('returns two runs when a run is associated with multiple experiments', () => { - const result = selectors.getRenderableRuns(['exp1', 'exp2'])(state); + const result = selectors.factories.getRenderableRuns(['exp1', 'exp2'])( + state + ); expect(result.length).toEqual(5); expect(result[0].run).toEqual({...run1, experimentId: 'exp1'}); expect(result[1].run).toEqual({...run2, experimentId: 'exp1'}); @@ -718,7 +720,7 @@ describe('common selectors', () => { }); it('returns empty list when no experiments are provided', () => { - expect(selectors.getRenderableRuns([])(state)).toEqual([]); + expect(selectors.factories.getRenderableRuns([])(state)).toEqual([]); }); }); @@ -900,14 +902,18 @@ describe('common selectors', () => { describe('getFilteredRenderableRuns', () => { it('does not use experiment alias when route is not compare', () => { state.runs!.data.regexFilter = 'foo'; - const result = selectors.getFilteredRenderableRuns(['exp1'])(state); + const result = selectors.factories.getFilteredRenderableRuns(['exp1'])( + state + ); expect(result).toEqual([]); }); it('uses experiment alias when route is compare', () => { state.runs!.data.regexFilter = 'foo'; state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; - const result = selectors.getFilteredRenderableRuns(['exp1'])(state); + const result = selectors.factories.getFilteredRenderableRuns(['exp1'])( + state + ); expect(result.length).toEqual(2); expect(result[0].run.name).toEqual('run 1'); expect(result[1].run.name).toEqual('run 2'); @@ -918,12 +924,16 @@ describe('common selectors', () => { selectors.TEST_ONLY.utils, 'filterRunItemsByHparamAndMetricFilter' ).and.callThrough(); - const results = selectors.getFilteredRenderableRuns(['exp1'])(state); + const results = selectors.factories.getFilteredRenderableRuns(['exp1'])( + state + ); expect(spy).toHaveBeenCalledOnceWith(results, new Map(), new Map()); }); it('returns empty list when no experiments are provided', () => { - expect(selectors.getFilteredRenderableRuns([])(state)).toEqual([]); + expect(selectors.factories.getFilteredRenderableRuns([])(state)).toEqual( + [] + ); }); }); @@ -932,15 +942,30 @@ describe('common selectors', () => { state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; const result = selectors.getFilteredRenderableRunsFromRoute(state); expect(result).toEqual( - selectors.getFilteredRenderableRuns(['exp1', 'exp2'])(state) + selectors.factories.getFilteredRenderableRuns(['exp1', 'exp2'])(state) ); }); it('calls getFilteredRenderableRuns with experiment ids from the route when in single experiment view', () => { const result = selectors.getFilteredRenderableRunsFromRoute(state); expect(result).toEqual( - selectors.getFilteredRenderableRuns(['defaultExperimentId'])(state) + selectors.factories.getFilteredRenderableRuns(['defaultExperimentId'])( + state + ) ); }); }); + + describe('getFilteredRenderableRunsIdsFromRoute', () => { + it('returns a set of run ids from the route when in compare view', () => { + state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; + const result = selectors.getFilteredRenderableRunsIdsFromRoute(state); + expect(result).toEqual(new Set(['1', '2', '3', '4'])); + }); + + it('returns a set of run ids from the route when in single experiment view', () => { + const result = selectors.getFilteredRenderableRunsIdsFromRoute(state); + expect(result).toEqual(new Set()); + }); + }); }); diff --git a/tensorboard/webapp/runs/views/runs_table/types.ts b/tensorboard/webapp/runs/views/runs_table/types.ts index 0e2753cc24..0cbe0eb113 100644 --- a/tensorboard/webapp/runs/views/runs_table/types.ts +++ b/tensorboard/webapp/runs/views/runs_table/types.ts @@ -44,3 +44,8 @@ export interface RunTableItem { hparams: Map; metrics: Map; } + +export interface RunTableExperimentItem extends RunTableItem { + run: Run & {experimentId: string}; + runColor: string; +}