diff --git a/tensorboard/webapp/metrics/store/BUILD b/tensorboard/webapp/metrics/store/BUILD index 372039fffb..8db5c763f9 100644 --- a/tensorboard/webapp/metrics/store/BUILD +++ b/tensorboard/webapp/metrics/store/BUILD @@ -28,6 +28,7 @@ tf_ts_library( "//tensorboard/webapp/types", "//tensorboard/webapp/util:dom", "//tensorboard/webapp/util:lang", + "//tensorboard/webapp/util:memoize", "//tensorboard/webapp/util:ngrx", "//tensorboard/webapp/util:types", "//tensorboard/webapp/widgets/card_fob:types", diff --git a/tensorboard/webapp/metrics/store/metrics_selectors.ts b/tensorboard/webapp/metrics/store/metrics_selectors.ts index eb83cfddb1..ffb705aa76 100644 --- a/tensorboard/webapp/metrics/store/metrics_selectors.ts +++ b/tensorboard/webapp/metrics/store/metrics_selectors.ts @@ -48,6 +48,7 @@ import { } from './metrics_types'; import {ColumnHeader, DataTableMode} from '../../widgets/data_table/types'; import {Extent} from '../../widgets/line_chart_v2/lib/public_types'; +import {memoize} from '../../util/memoize'; const selectMetricsState = createFeatureSelector(METRICS_FEATURE_KEY); @@ -405,20 +406,6 @@ export const getMetricsStepMinMax = createSelector( } ); -export const getSingleSelectionHeaders = createSelector( - selectMetricsState, - (state: MetricsState): ColumnHeader[] => { - return state.singleSelectionHeaders; - } -); - -export const getRangeSelectionHeaders = createSelector( - selectMetricsState, - (state: MetricsState): ColumnHeader[] => { - return state.rangeSelectionHeaders; - } -); - /** * Returns value of the linked time set by user. When linked time selection is never * set, it returns the default value which is derived from the timeseries data @@ -628,3 +615,30 @@ export const getMetricsCardTimeSelection = createSelector( ); } ); + +export const getSingleSelectionHeaders = createSelector( + selectMetricsState, + (state: MetricsState): ColumnHeader[] => { + return state.singleSelectionHeaders; + } +); + +export const getRangeSelectionHeaders = createSelector( + selectMetricsState, + (state: MetricsState): ColumnHeader[] => { + return state.rangeSelectionHeaders; + } +); + +export const getColumnHeadersForCard = memoize((cardId: string) => { + return createSelector( + (state) => state, + getSingleSelectionHeaders, + getRangeSelectionHeaders, + (state, singleSelectionHeaders, rangeSelectionHeaders) => { + return getMetricsCardRangeSelectionEnabled(state, cardId) + ? rangeSelectionHeaders + : singleSelectionHeaders; + } + ); +}); diff --git a/tensorboard/webapp/metrics/store/metrics_selectors_test.ts b/tensorboard/webapp/metrics/store/metrics_selectors_test.ts index 254f3e5806..0307d7f8bd 100644 --- a/tensorboard/webapp/metrics/store/metrics_selectors_test.ts +++ b/tensorboard/webapp/metrics/store/metrics_selectors_test.ts @@ -26,7 +26,11 @@ import { createTimeSeriesData, } from '../testing'; import {HistogramMode, TooltipSort, XAxisType} from '../types'; -import {DataTableMode} from '../../widgets/data_table/types'; +import { + ColumnHeader, + ColumnHeaderType, + DataTableMode, +} from '../../widgets/data_table/types'; import * as selectors from './metrics_selectors'; import {CardFeatureOverride, MetricsState} from './metrics_types'; @@ -1533,4 +1537,170 @@ describe('metrics selectors', () => { ); }); }); + + describe('getSingleSelectionHeaders', () => { + it('returns all single selection headers', () => { + const state = appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders: [ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ], + }) + ); + expect(selectors.getSingleSelectionHeaders(state)).toEqual([ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ]); + }); + }); + + describe('getRangeSelectionHeaders', () => { + it('returns all range selection headers', () => { + const state = appStateFromMetricsState( + buildMetricsState({ + rangeSelectionHeaders: [ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ], + }) + ); + expect(selectors.getRangeSelectionHeaders(state)).toEqual([ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ]); + }); + }); + + describe('getColumnHeadersForCard', () => { + let singleSelectionHeaders: ColumnHeader[]; + let rangeSelectionHeaders: ColumnHeader[]; + + beforeEach(() => { + singleSelectionHeaders = [ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ]; + rangeSelectionHeaders = [ + { + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }, + ]; + }); + + it('returns single selection headers when card range selection is disabled', () => { + expect( + selectors.getColumnHeadersForCard('card1')( + appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + }) + ) + ) + ).toEqual(singleSelectionHeaders); + expect( + selectors.getColumnHeadersForCard('card1')( + appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + cardStateMap: { + card1: { + rangeSelectionOverride: + CardFeatureOverride.OVERRIDE_AS_DISABLED, + }, + }, + }) + ) + ) + ).toEqual(singleSelectionHeaders); + }); + + it('returns range selection headers when card range selection is enabled', () => { + expect( + selectors.getColumnHeadersForCard('card1')( + appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + cardStateMap: { + card1: { + rangeSelectionOverride: + CardFeatureOverride.OVERRIDE_AS_ENABLED, + }, + }, + }) + ) + ) + ).toEqual(rangeSelectionHeaders); + }); + + it('returns range selection headers when global range selection is enabled', () => { + expect( + selectors.getColumnHeadersForCard('card1')( + appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + rangeSelectionEnabled: true, + }) + ) + ) + ).toEqual(rangeSelectionHeaders); + }); + }); }); diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts index 77a33002bf..848865fd5f 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts @@ -57,6 +57,7 @@ import { getRun, getRunColorMap, getCurrentRouteRunSelection, + getColumnHeadersForCard, } from '../../../selectors'; import {DataLoadState} from '../../../types/data'; import { @@ -89,8 +90,6 @@ import { getMetricsScalarSmoothing, getMetricsTooltipSort, getMetricsXAxisType, - getRangeSelectionHeaders, - getSingleSelectionHeaders, RunToSeries, } from '../../store'; import {CardId, CardMetadata, HeaderEditInfo, XAxisType} from '../../types'; @@ -461,18 +460,8 @@ export class ScalarCardContainer implements CardRenderer, OnInit, OnDestroy { this.cardId ); - this.columnHeaders$ = combineLatest([ - this.stepOrLinkedTimeSelection$, - this.store.select(getSingleSelectionHeaders), - this.store.select(getRangeSelectionHeaders), - ]).pipe( - map(([timeSelection, singleSelectionHeaders, rangeSelectionHeaders]) => { - if (!timeSelection || timeSelection.end === null) { - return singleSelectionHeaders; - } else { - return rangeSelectionHeaders; - } - }) + this.columnHeaders$ = this.store.select( + getColumnHeadersForCard(this.cardId) ); this.chartMetadataMap$ = partitionedSeries$.pipe( diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts index e6c264f21d..c133e43cdc 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts @@ -2668,6 +2668,8 @@ describe('scalar card', () => { describe('getTimeSelectionTableData', () => { beforeEach(() => { store.overrideSelector(getMetricsLinkedTimeEnabled, true); + // These tests now rely on the selector getColumnHeadersForCard which in turn + // relies on these selectors. store.overrideSelector(getSingleSelectionHeaders, [ { type: ColumnHeaderType.RUN, @@ -2880,6 +2882,7 @@ describe('scalar card', () => { runToSeries ); store.overrideSelector(getMetricsRangeSelectionEnabled, true); + store.overrideSelector(getMetricsCardRangeSelectionEnabled, true); store.overrideSelector( selectors.getCurrentRouteRunSelection, new Map([ @@ -2961,6 +2964,7 @@ describe('scalar card', () => { runToSeries ); store.overrideSelector(getMetricsRangeSelectionEnabled, true); + store.overrideSelector(getMetricsCardRangeSelectionEnabled, true); store.overrideSelector( selectors.getCurrentRouteRunSelection, new Map([['run1', true]]) @@ -3026,6 +3030,7 @@ describe('scalar card', () => { runToSeries ); store.overrideSelector(getMetricsRangeSelectionEnabled, true); + store.overrideSelector(getMetricsCardRangeSelectionEnabled, true); store.overrideSelector( selectors.getCurrentRouteRunSelection, new Map([['run1', true]]) @@ -3141,6 +3146,7 @@ describe('scalar card', () => { runToSeries ); store.overrideSelector(getMetricsRangeSelectionEnabled, true); + store.overrideSelector(getMetricsCardRangeSelectionEnabled, true); store.overrideSelector( selectors.getCurrentRouteRunSelection, new Map([