From 132762e37a1d1b300427cdf49e022031a0715e03 Mon Sep 17 00:00:00 2001 From: Stephan Lee Date: Fri, 24 Jul 2020 15:01:33 -0700 Subject: [PATCH 1/3] migrate: projector The plugin is fully function: `bazel run tensorboard/plugins/polymer3/vz_projector:standalone` works end to end. --- package.json | 10 +- .../polymer3/tf_projector_plugin/BUILD | 7 +- .../tf-projector-plugin.html | 4 +- .../projector/polymer3/vz_projector/BUILD | 121 +- .../polymer3/vz_projector/analyticsLogger.ts | 90 +- .../polymer3/vz_projector/bh_tsne.ts | 1045 +++++------ .../polymer3/vz_projector/bundle.html | 49 - .../projector/polymer3/vz_projector/bundle.ts | 27 + .../vz_projector/data-provider-demo.ts | 240 +-- .../vz_projector/data-provider-proto.ts | 192 +- .../vz_projector/data-provider-server.ts | 269 ++- .../polymer3/vz_projector/data-provider.ts | 865 +++++---- .../projector/polymer3/vz_projector/data.ts | 1364 +++++++------- .../polymer3/vz_projector/external.d.ts | 32 - .../projector/polymer3/vz_projector/heap.ts | 247 ++- .../projector/polymer3/vz_projector/knn.ts | 460 +++-- .../projector/polymer3/vz_projector/label.ts | 251 ++- .../polymer3/vz_projector/logging.ts | 168 +- .../vz_projector/projectorEventContext.ts | 69 +- .../projectorScatterPlotAdapter.ts | 1409 +++++++------- .../polymer3/vz_projector/renderContext.ts | 94 +- .../polymer3/vz_projector/scatterPlot.ts | 1425 +++++++------- .../scatterPlotRectangleSelector.ts | 175 +- .../vz_projector/scatterPlotVisualizer.ts | 64 +- .../scatterPlotVisualizer3DLabels.ts | 601 +++--- .../scatterPlotVisualizerCanvasLabels.ts | 325 ++-- .../scatterPlotVisualizerPolylines.ts | 247 ++- .../scatterPlotVisualizerSprites.ts | 630 +++---- .../projector/polymer3/vz_projector/sptree.ts | 280 ++- .../{standalone.html => standalone_lib.html} | 11 +- .../polymer3/vz_projector/styles.html | 184 -- .../projector/polymer3/vz_projector/styles.ts | 182 ++ .../polymer3/vz_projector/test/BUILD | 1 + .../projector/polymer3/vz_projector/util.ts | 467 +++-- .../projector/polymer3/vz_projector/vector.ts | 505 +++-- ...projector-app.html => vz-projector-app.ts} | 83 +- .../vz-projector-bookmark-panel.html | 220 --- .../vz-projector-bookmark-panel.html.ts | 211 +++ .../vz-projector-bookmark-panel.ts | 489 +++-- ...shboard.html => vz-projector-dashboard.ts} | 89 +- .../vz_projector/vz-projector-data-panel.html | 677 ------- .../vz-projector-data-panel.html.ts | 652 +++++++ .../vz_projector/vz-projector-data-panel.ts | 1438 +++++++-------- .../vz_projector/vz-projector-input.html | 71 - .../vz_projector/vz-projector-input.ts | 220 ++- .../vz-projector-inspector-panel.html | 351 ---- .../vz-projector-inspector-panel.html.ts | 333 ++++ .../vz-projector-inspector-panel.ts | 951 +++++----- .../vz_projector/vz-projector-legend.html | 84 - .../vz_projector/vz-projector-legend.ts | 193 +- .../vz-projector-metadata-card.html | 104 -- .../vz-projector-metadata-card.ts | 188 +- ...=> vz-projector-projections-panel.html.ts} | 34 +- .../vz-projector-projections-panel.ts | 1384 +++++++------- .../vz_projector/vz-projector-util.ts | 37 - ...vz-projector.html => vz-projector.html.ts} | 51 +- .../polymer3/vz_projector/vz-projector.ts | 1320 +++++++------ yarn.lock | 1631 +++++++++++------ 58 files changed, 11128 insertions(+), 11793 deletions(-) delete mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/bundle.html create mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/bundle.ts rename tensorboard/plugins/projector/polymer3/vz_projector/{standalone.html => standalone_lib.html} (78%) delete mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/styles.html create mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/styles.ts rename tensorboard/plugins/projector/polymer3/vz_projector/{vz-projector-app.html => vz-projector-app.ts} (59%) delete mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-bookmark-panel.html create mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-bookmark-panel.html.ts rename tensorboard/plugins/projector/polymer3/vz_projector/{vz-projector-dashboard.html => vz-projector-dashboard.ts} (61%) delete mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.html create mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.html.ts delete mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-input.html delete mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.html create mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.html.ts delete mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-legend.html delete mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-metadata-card.html rename tensorboard/plugins/projector/polymer3/vz_projector/{vz-projector-projections-panel.html => vz-projector-projections-panel.html.ts} (93%) delete mode 100644 tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-util.ts rename tensorboard/plugins/projector/polymer3/vz_projector/{vz-projector.html => vz-projector.html.ts} (84%) diff --git a/package.json b/package.json index a5d0e79df8..50675b7fb5 100644 --- a/package.json +++ b/package.json @@ -88,10 +88,10 @@ "@polymer/paper-checkbox": "^3.1.0", "@polymer/paper-dialog": "^3.0.1", "@polymer/paper-dialog-scrollable": "^3.0.1", - "@polymer/paper-dropdown-menu": "^3.0.1", + "@polymer/paper-dropdown-menu": "^3.1.0", "@polymer/paper-header-panel": "^3.0.1", "@polymer/paper-icon-button": "^3.0.2", - "@polymer/paper-input": "^3.1.0", + "@polymer/paper-input": "^3.2.1", "@polymer/paper-item": "^3.0.1", "@polymer/paper-listbox": "^3.0.1", "@polymer/paper-material": "^3.0.1", @@ -100,16 +100,22 @@ "@polymer/paper-spinner": "^3.0.2", "@polymer/paper-styles": "^3.0.1", "@polymer/paper-tabs": "^3.1.0", + "@polymer/paper-toast": "^3.0.1", "@polymer/paper-toggle-button": "^3.0.1", "@polymer/paper-toolbar": "^3.0.1", + "@polymer/paper-tooltip": "^3.0.1", "@polymer/polymer": "^3.4.1", "d3": "5.7.0", "lodash": "^4.17.19", "monaco-editor-core": "^0.20.0", "monaco-languages": "^1.10.0", "plottable": "^3.9.0", + "numericjs": "^1.2.6", "requirejs": "^2.3.6", "rxjs": "^6.5.5", + "three": "~0.108.0", + "umap-js": "^1.3.2", + "weblas": "^0.9.1", "zone.js": "^0.10.2" } } diff --git a/tensorboard/plugins/projector/polymer3/tf_projector_plugin/BUILD b/tensorboard/plugins/projector/polymer3/tf_projector_plugin/BUILD index 36e7a608b4..c1d781de2a 100644 --- a/tensorboard/plugins/projector/polymer3/tf_projector_plugin/BUILD +++ b/tensorboard/plugins/projector/polymer3/tf_projector_plugin/BUILD @@ -18,12 +18,13 @@ tf_web_library( tensorboard_html_binary( name = "projector_binary", - compile = True, + compile = False, input_path = "/tf-projector/tf-projector-plugin.html", js_path = "/projector_binary.js", output_path = "/tf-projector/projector_binary.html", deps = [ ":tf_projector_plugin", + "//tensorboard/plugins/projector/polymer3/vz_projector:standalone_lib", ], ) @@ -34,9 +35,7 @@ tf_web_library( ], path = "/tf-projector", deps = [ - "//tensorboard/components:security", - "//tensorboard/components/tf_imports:polymer", - "//tensorboard/plugins/projector/vz_projector", + "//tensorboard/plugins/projector/polymer3/vz_projector:standalone_lib", "@com_google_fonts_roboto", ], ) diff --git a/tensorboard/plugins/projector/polymer3/tf_projector_plugin/tf-projector-plugin.html b/tensorboard/plugins/projector/polymer3/tf_projector_plugin/tf-projector-plugin.html index b15056bc5a..d82bf34640 100644 --- a/tensorboard/plugins/projector/polymer3/tf_projector_plugin/tf-projector-plugin.html +++ b/tensorboard/plugins/projector/polymer3/tf_projector_plugin/tf-projector-plugin.html @@ -15,10 +15,7 @@ limitations under the License. --> - - - + diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/BUILD b/tensorboard/plugins/projector/polymer3/vz_projector/BUILD index 9baf12ddde..3d97a328cf 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/BUILD +++ b/tensorboard/plugins/projector/polymer3/vz_projector/BUILD @@ -1,3 +1,4 @@ +load("//tensorboard/defs:defs.bzl", "tf_js_binary", "tf_ts_library") load("//tensorboard/defs:web.bzl", "tf_web_library") load("//tensorboard/defs:vulcanize.bzl", "tensorboard_html_binary") @@ -5,11 +6,11 @@ package(default_visibility = ["//tensorboard:internal"]) licenses(["notice"]) # Apache 2.0 -tf_web_library( +tf_ts_library( name = "vz_projector", srcs = [ "analyticsLogger.ts", - "bundle.html", + "bundle.ts", "data.ts", "data-provider.ts", "data-provider-demo.ts", @@ -29,104 +30,110 @@ tf_web_library( "scatterPlotVisualizerCanvasLabels.ts", "scatterPlotVisualizerPolylines.ts", "scatterPlotVisualizerSprites.ts", - "styles.html", + "styles.ts", "umap.d.ts", "util.ts", "vector.ts", - "vz-projector.html", + "vz-projector.html.ts", "vz-projector.ts", - "vz-projector-app.html", - "vz-projector-bookmark-panel.html", + "vz-projector-app.ts", + "vz-projector-bookmark-panel.html.ts", "vz-projector-bookmark-panel.ts", - "vz-projector-dashboard.html", - "vz-projector-data-panel.html", + "vz-projector-dashboard.ts", + "vz-projector-data-panel.html.ts", "vz-projector-data-panel.ts", - "vz-projector-input.html", "vz-projector-input.ts", - "vz-projector-inspector-panel.html", + "vz-projector-inspector-panel.html.ts", "vz-projector-inspector-panel.ts", - "vz-projector-legend.html", "vz-projector-legend.ts", - "vz-projector-metadata-card.html", "vz-projector-metadata-card.ts", - "vz-projector-projections-panel.html", + "vz-projector-projections-panel.html.ts", "vz-projector-projections-panel.ts", - "vz-projector-util.ts", ], - path = "/vz-projector", + strict_checks = False, deps = [ ":bh_tsne", ":heap", ":sptree", - "//tensorboard/components/tf_backend", - "//tensorboard/components/tf_dashboard_common", - "//tensorboard/components/tf_imports:d3", - "//tensorboard/components/tf_imports:numericjs", - "//tensorboard/components/tf_imports:polymer", - "//tensorboard/components/tf_imports:threejs", - "//tensorboard/components/tf_imports:umap-js", - "//tensorboard/components/tf_imports:weblas", - "//tensorboard/components/tf_tensorboard:registry", - "@org_polymer_iron_collapse", - "@org_polymer_iron_icons", - "@org_polymer_paper_button", - "@org_polymer_paper_checkbox", - "@org_polymer_paper_dialog", - "@org_polymer_paper_dialog_scrollable", - "@org_polymer_paper_dropdown_menu", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_input", - "@org_polymer_paper_item", - "@org_polymer_paper_listbox", - "@org_polymer_paper_slider", - "@org_polymer_paper_spinner", - "@org_polymer_paper_styles", - "@org_polymer_paper_toast", - "@org_polymer_paper_toggle_button", - "@org_polymer_paper_tooltip", + "//tensorboard/components_polymer3/polymer:register_style_dom_module", + "@npm//@polymer/decorators", + "@npm//@polymer/iron-collapse", + "@npm//@polymer/iron-icons", + "@npm//@polymer/iron-iconset-svg", + "@npm//@polymer/paper-button", + "@npm//@polymer/paper-checkbox", + "@npm//@polymer/paper-dialog", + "@npm//@polymer/paper-dialog-scrollable", + "@npm//@polymer/paper-dropdown-menu", + "@npm//@polymer/paper-icon-button", + "@npm//@polymer/paper-input", + "@npm//@polymer/paper-item", + "@npm//@polymer/paper-listbox", + "@npm//@polymer/paper-slider", + "@npm//@polymer/paper-spinner", + "@npm//@polymer/paper-styles", + "@npm//@polymer/paper-toast", + "@npm//@polymer/paper-toggle-button", + "@npm//@polymer/paper-tooltip", + "@npm//@polymer/polymer", + "@npm//d3", + "@npm//numericjs", + "@npm//three", + "@npm//umap-js", + "@npm//weblas", ], ) -tf_web_library( +tf_ts_library( name = "heap", srcs = ["heap.ts"], - path = "/vz-projector", + strict_checks = False, ) -tf_web_library( +tf_ts_library( name = "sptree", srcs = ["sptree.ts"], - path = "/vz-projector", + strict_checks = False, ) -tf_web_library( +tf_ts_library( name = "bh_tsne", srcs = ["bh_tsne.ts"], - path = "/vz-projector", + strict_checks = False, deps = [":sptree"], ) +tf_js_binary( + name = "standalone_bundle", + compile = 1, + entry_point = "bundle.ts", + deps = [ + ":vz_projector", + ], +) + ################# Standalone development ################# tf_web_library( name = "standalone_lib", srcs = [ - "standalone.html", + "standalone_lib.html", "standalone_projector_config.json", + ":standalone_bundle.js", ], path = "/", deps = [ - ":vz_projector", - "//tensorboard/components/tf_imports:polymer", - "@org_polymer_iron_icons", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_tooltip", + "@com_google_fonts_roboto", ], ) tensorboard_html_binary( - name = "devserver", - input_path = "/standalone.html", - output_path = "/index.html", - deps = [":standalone_lib"], + name = "standalone", + compile = False, + input_path = "/standalone_lib.html", + js_path = "/standalone_bundle.js", + output_path = "/standalone.html", + deps = [ + ":standalone_lib", + ], ) diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/analyticsLogger.ts b/tensorboard/plugins/projector/polymer3/vz_projector/analyticsLogger.ts index 8c9a904ed0..a6a533db23 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/analyticsLogger.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/analyticsLogger.ts @@ -12,56 +12,52 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - export class AnalyticsLogger { - private eventLogging: boolean; - private pageViewLogging: boolean; +import {ProjectionType} from './data'; - /** - * Constructs an event logger using Google Analytics. It assumes there is a - * Google Analytics script added to the page elsewhere. If there is no such - * script, the logger acts as a no-op. - * - * @param pageViewLogging Whether to log page views. - * @param eventLogging Whether to log user interaction. - */ - constructor(pageViewLogging: boolean, eventLogging: boolean) { - if (typeof ga === 'undefined' || ga == null) { - this.eventLogging = false; - this.pageViewLogging = false; - return; - } - this.eventLogging = eventLogging; - this.pageViewLogging = pageViewLogging; +export class AnalyticsLogger { + private eventLogging: boolean; + private pageViewLogging: boolean; + /** + * Constructs an event logger using Google Analytics. It assumes there is a + * Google Analytics script added to the page elsewhere. If there is no such + * script, the logger acts as a no-op. + * + * @param pageViewLogging Whether to log page views. + * @param eventLogging Whether to log user interaction. + */ + constructor(pageViewLogging: boolean, eventLogging: boolean) { + if (typeof ga === 'undefined' || ga == null) { + this.eventLogging = false; + this.pageViewLogging = false; + return; } - - logPageView(pageTitle: string) { - if (this.pageViewLogging) { - // Always send a page view. - ga('send', {hitType: 'pageview', page: `/v/${pageTitle}`}); - } + this.eventLogging = eventLogging; + this.pageViewLogging = pageViewLogging; + } + logPageView(pageTitle: string) { + if (this.pageViewLogging) { + // Always send a page view. + ga('send', {hitType: 'pageview', page: `/v/${pageTitle}`}); } - - logProjectionChanged(projection: ProjectionType) { - if (this.eventLogging) { - ga('send', { - hitType: 'event', - eventCategory: 'Projection', - eventAction: 'click', - eventLabel: projection, - }); - } + } + logProjectionChanged(projection: ProjectionType) { + if (this.eventLogging) { + ga('send', { + hitType: 'event', + eventCategory: 'Projection', + eventAction: 'click', + eventLabel: projection, + }); } - - logWebGLDisabled() { - if (this.eventLogging) { - ga('send', { - hitType: 'event', - eventCategory: 'Error', - eventAction: 'PageLoad', - eventLabel: 'WebGL_disabled', - }); - } + } + logWebGLDisabled() { + if (this.eventLogging) { + ga('send', { + hitType: 'event', + eventCategory: 'Error', + eventAction: 'PageLoad', + eventLabel: 'WebGL_disabled', + }); } } -} // namespace vz_projector +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/bh_tsne.ts b/tensorboard/plugins/projector/polymer3/vz_projector/bh_tsne.ts index a34946f7d1..ab7ed2bdc2 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/bh_tsne.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/bh_tsne.ts @@ -12,593 +12,558 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - /** - * This is a fork of the Karpathy's TSNE.js (original license below). - * This fork implements Barnes-Hut approximation and runs in O(NlogN) - * time, as opposed to the Karpathy's O(N^2) version. - * - * @author smilkov@google.com (Daniel Smilkov) - */ - - /** - * @license - * The MIT License (MIT) - * Copyright (c) 2015 Andrej Karpathy - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - - type AugmSPNode = SPNode & {numCells: number; yCell: number[]; rCell: number}; - - /** - * Barnes-hut approximation level. Higher means more approximation and faster - * results. Recommended value mentioned in the paper is 0.8. - */ - const THETA = 0.8; - - const MIN_POSSIBLE_PROB = 1e-9; - - // Variables used for memorizing the second random number since running - // gaussRandom() generates two random numbers at the cost of 1 atomic - // computation. This optimization results in 2X speed-up of the generator. - let return_v = false; - let v_val = 0.0; - - /** Returns the square euclidean distance between two vectors. */ - export function dist2(a: number[], b: number[]): number { - if (a.length !== b.length) { - throw new Error('Vectors a and b must be of same length'); - } - - let result = 0; - for (let i = 0; i < a.length; ++i) { - let diff = a[i] - b[i]; - result += diff * diff; - } - return result; - } - - /** Returns the square euclidean distance between two 2D points. */ - export function dist2_2D(a: number[], b: number[]): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - return dX * dX + dY * dY; +/** + * This is a fork of the Karpathy's TSNE.js (original license below). + * This fork implements Barnes-Hut approximation and runs in O(NlogN) + * time, as opposed to the Karpathy's O(N^2) version. + * + * @author smilkov@google.com (Daniel Smilkov) + */ + +/** + * @license + * The MIT License (MIT) + * Copyright (c) 2015 Andrej Karpathy + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +import {SPNode, SPTree} from './sptree'; + +type AugmSPNode = SPNode & { + numCells: number; + yCell: number[]; + rCell: number; +}; +/** + * Barnes-hut approximation level. Higher means more approximation and faster + * results. Recommended value mentioned in the paper is 0.8. + */ +const THETA = 0.8; +const MIN_POSSIBLE_PROB = 1e-9; +// Variables used for memorizing the second random number since running +// gaussRandom() generates two random numbers at the cost of 1 atomic +// computation. This optimization results in 2X speed-up of the generator. +let return_v = false; +let v_val = 0; +/** Returns the square euclidean distance between two vectors. */ +export function dist2(a: number[], b: number[]): number { + if (a.length !== b.length) { + throw new Error('Vectors a and b must be of same length'); } - - /** Returns the square euclidean distance between two 3D points. */ - export function dist2_3D(a: number[], b: number[]): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - let dZ = a[2] - b[2]; - return dX * dX + dY * dY + dZ * dZ; + let result = 0; + for (let i = 0; i < a.length; ++i) { + let diff = a[i] - b[i]; + result += diff * diff; } - - function gaussRandom(rng: () => number): number { - if (return_v) { - return_v = false; - return v_val; - } - let u = 2 * rng() - 1; - let v = 2 * rng() - 1; - let r = u * u + v * v; - if (r === 0 || r > 1) { - return gaussRandom(rng); - } - let c = Math.sqrt((-2 * Math.log(r)) / r); - v_val = v * c; // cache this for next function call for efficiency - return_v = true; - return u * c; + return result; +} +/** Returns the square euclidean distance between two 2D points. */ +export function dist2_2D(a: number[], b: number[]): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + return dX * dX + dY * dY; +} +/** Returns the square euclidean distance between two 3D points. */ +export function dist2_3D(a: number[], b: number[]): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + let dZ = a[2] - b[2]; + return dX * dX + dY * dY + dZ * dZ; +} +function gaussRandom(rng: () => number): number { + if (return_v) { + return_v = false; + return v_val; } - - // return random normal number - function randn(rng: () => number, mu: number, std: number) { - return mu + gaussRandom(rng) * std; + let u = 2 * rng() - 1; + let v = 2 * rng() - 1; + let r = u * u + v * v; + if (r === 0 || r > 1) { + return gaussRandom(rng); } - - // utilitity that creates contiguous vector of zeros of size n - function zeros(n: number): Float64Array { - return new Float64Array(n); + let c = Math.sqrt((-2 * Math.log(r)) / r); + v_val = v * c; // cache this for next function call for efficiency + return_v = true; + return u * c; +} +// return random normal number +function randn(rng: () => number, mu: number, std: number) { + return mu + gaussRandom(rng) * std; +} +// utilitity that creates contiguous vector of zeros of size n +function zeros(n: number): Float64Array { + return new Float64Array(n); +} +// utility that returns a matrix filled with random numbers +// generated by the provided generator. +function randnMatrix(n: number, d: number, rng: () => number) { + let nd = n * d; + let x = zeros(nd); + for (let i = 0; i < nd; ++i) { + x[i] = randn(rng, 0, 0.0001); } - - // utility that returns a matrix filled with random numbers - // generated by the provided generator. - function randnMatrix(n: number, d: number, rng: () => number) { - let nd = n * d; - let x = zeros(nd); - for (let i = 0; i < nd; ++i) { - x[i] = randn(rng, 0.0, 1e-4); - } - return x; + return x; +} +// utility that returns a matrix filled with the provided value. +function arrayofs(n: number, d: number, val: number) { + let x: number[][] = []; + for (let i = 0; i < n; ++i) { + x.push(d === 3 ? [val, val, val] : [val, val]); } - - // utility that returns a matrix filled with the provided value. - function arrayofs(n: number, d: number, val: number) { - let x: number[][] = []; - for (let i = 0; i < n; ++i) { - x.push(d === 3 ? [val, val, val] : [val, val]); - } - return x; - } - - // compute (p_{i|j} + p_{j|i})/(2n) - function nearest2P( - nearest: {index: number; dist: number}[][], - perplexity: number, - tol: number - ) { - let N = nearest.length; - let Htarget = Math.log(perplexity); // target entropy of distribution - let P = zeros(N * N); // temporary probability matrix - let K = nearest[0].length; - let pRow: number[] = new Array(K); // pij[]. - - for (let i = 0; i < N; ++i) { - let neighbors = nearest[i]; - let betaMin = -Infinity; - let betaMax = Infinity; - let beta = 1; // initial value of precision - let maxTries = 50; - - // perform binary search to find a suitable precision beta - // so that the entropy of the distribution is appropriate - let numTries = 0; - while (true) { - // compute entropy and kernel row with beta precision - let psum = 0.0; - for (let k = 0; k < neighbors.length; ++k) { - let neighbor = neighbors[k]; - let pij = i === neighbor.index ? 0 : Math.exp(-neighbor.dist * beta); - pij = Math.max(pij, MIN_POSSIBLE_PROB); - pRow[k] = pij; - psum += pij; - } - // normalize p and compute entropy - let Hhere = 0.0; - for (let k = 0; k < pRow.length; ++k) { - pRow[k] /= psum; - let pij = pRow[k]; - if (pij > 1e-7) { - Hhere -= pij * Math.log(pij); - } + return x; +} +// compute (p_{i|j} + p_{j|i})/(2n) +function nearest2P( + nearest: { + index: number; + dist: number; + }[][], + perplexity: number, + tol: number +) { + let N = nearest.length; + let Htarget = Math.log(perplexity); // target entropy of distribution + let P = zeros(N * N); // temporary probability matrix + let K = nearest[0].length; + let pRow: number[] = new Array(K); // pij[]. + for (let i = 0; i < N; ++i) { + let neighbors = nearest[i]; + let betaMin = -Infinity; + let betaMax = Infinity; + let beta = 1; // initial value of precision + let maxTries = 50; + // perform binary search to find a suitable precision beta + // so that the entropy of the distribution is appropriate + let numTries = 0; + while (true) { + // compute entropy and kernel row with beta precision + let psum = 0; + for (let k = 0; k < neighbors.length; ++k) { + let neighbor = neighbors[k]; + let pij = i === neighbor.index ? 0 : Math.exp(-neighbor.dist * beta); + pij = Math.max(pij, MIN_POSSIBLE_PROB); + pRow[k] = pij; + psum += pij; + } + // normalize p and compute entropy + let Hhere = 0; + for (let k = 0; k < pRow.length; ++k) { + pRow[k] /= psum; + let pij = pRow[k]; + if (pij > 1e-7) { + Hhere -= pij * Math.log(pij); } - - // adjust beta based on result - if (Hhere > Htarget) { - // entropy was too high (distribution too diffuse) - // so we need to increase the precision for more peaky distribution - betaMin = beta; // move up the bounds - if (betaMax === Infinity) { - beta = beta * 2; - } else { - beta = (beta + betaMax) / 2; - } + } + // adjust beta based on result + if (Hhere > Htarget) { + // entropy was too high (distribution too diffuse) + // so we need to increase the precision for more peaky distribution + betaMin = beta; // move up the bounds + if (betaMax === Infinity) { + beta = beta * 2; } else { - // converse case. make distrubtion less peaky - betaMax = beta; - if (betaMin === -Infinity) { - beta = beta / 2; - } else { - beta = (beta + betaMin) / 2; - } + beta = (beta + betaMax) / 2; } - numTries++; - // stopping conditions: too many tries or got a good precision - if (numTries >= maxTries || Math.abs(Hhere - Htarget) < tol) { - break; + } else { + // converse case. make distrubtion less peaky + betaMax = beta; + if (betaMin === -Infinity) { + beta = beta / 2; + } else { + beta = (beta + betaMin) / 2; } } - - // copy over the final prow to P at row i - for (let k = 0; k < pRow.length; ++k) { - let pij = pRow[k]; - let j = neighbors[k].index; - P[i * N + j] = pij; - } - } // end loop over examples i - - // symmetrize P and normalize it to sum to 1 over all ij - let N2 = N * 2; - for (let i = 0; i < N; ++i) { - for (let j = i + 1; j < N; ++j) { - let i_j = i * N + j; - let j_i = j * N + i; - let value = (P[i_j] + P[j_i]) / N2; - P[i_j] = value; - P[j_i] = value; + numTries++; + // stopping conditions: too many tries or got a good precision + if (numTries >= maxTries || Math.abs(Hhere - Htarget) < tol) { + break; } } - return P; - } - - // helper function - function sign(x: number) { - return x > 0 ? 1 : x < 0 ? -1 : 0; + // copy over the final prow to P at row i + for (let k = 0; k < pRow.length; ++k) { + let pij = pRow[k]; + let j = neighbors[k].index; + P[i * N + j] = pij; + } + } // end loop over examples i + // symmetrize P and normalize it to sum to 1 over all ij + let N2 = N * 2; + for (let i = 0; i < N; ++i) { + for (let j = i + 1; j < N; ++j) { + let i_j = i * N + j; + let j_i = j * N + i; + let value = (P[i_j] + P[j_i]) / N2; + P[i_j] = value; + P[j_i] = value; + } } - - function computeForce_2d( + return P; +} +// helper function +function sign(x: number) { + return x > 0 ? 1 : x < 0 ? -1 : 0; +} +function computeForce_2d( + force: number[], + mult: number, + pointA: number[], + pointB: number[] +) { + force[0] += mult * (pointA[0] - pointB[0]); + force[1] += mult * (pointA[1] - pointB[1]); +} +function computeForce_3d( + force: number[], + mult: number, + pointA: number[], + pointB: number[] +) { + force[0] += mult * (pointA[0] - pointB[0]); + force[1] += mult * (pointA[1] - pointB[1]); + force[2] += mult * (pointA[2] - pointB[2]); +} +export interface TSNEOptions { + /** How many dimensions. */ + dim: number; + /** Roughly how many neighbors each point influences. */ + perplexity?: number; + /** Learning rate. */ + epsilon?: number; + /** A random number generator. */ + rng?: () => number; +} +export class TSNE { + private perplexity: number; + private epsilon: number; + private superviseFactor: number; + private unlabeledClass: string; + private labels: string[]; + private labelCounts: { + [key: string]: number; + }; + /** Random generator */ + private rng: () => number; + private iter = 0; + private Y: Float64Array; + private N: number; + private P: Float64Array; + private gains: number[][]; + private ystep: number[][]; + private nearest: { + index: number; + dist: number; + }[][]; + private dim: number; + private dist2: (a: number[], b: number[]) => number; + private computeForce: ( force: number[], mult: number, pointA: number[], pointB: number[] - ) { - force[0] += mult * (pointA[0] - pointB[0]); - force[1] += mult * (pointA[1] - pointB[1]); + ) => void; + constructor(opt: TSNEOptions) { + opt = opt || {dim: 2}; + this.perplexity = opt.perplexity || 30; + this.epsilon = opt.epsilon || 10; + this.rng = opt.rng || Math.random; + this.dim = opt.dim; + if (opt.dim === 2) { + this.dist2 = dist2_2D; + this.computeForce = computeForce_2d; + } else if (opt.dim === 3) { + this.dist2 = dist2_3D; + this.computeForce = computeForce_3d; + } else { + throw new Error('Only 2D and 3D is supported'); + } } - - function computeForce_3d( - force: number[], - mult: number, - pointA: number[], - pointB: number[] + // this function takes a fattened distance matrix and creates + // matrix P from them. + // D is assumed to be provided as an array of size N^2. + initDataDist( + nearest: { + index: number; + dist: number; + }[][] ) { - force[0] += mult * (pointA[0] - pointB[0]); - force[1] += mult * (pointA[1] - pointB[1]); - force[2] += mult * (pointA[2] - pointB[2]); + let N = nearest.length; + this.nearest = nearest; + this.P = nearest2P(nearest, this.perplexity, 0.0001); + this.N = N; + this.initSolution(); // refresh this } - - export interface TSNEOptions { - /** How many dimensions. */ - dim: number; - /** Roughly how many neighbors each point influences. */ - perplexity?: number; - /** Learning rate. */ - epsilon?: number; - /** A random number generator. */ - rng?: () => number; + // (re)initializes the solution to random + initSolution() { + // generate random solution to t-SNE + this.Y = randnMatrix(this.N, this.dim, this.rng); // the solution + this.gains = arrayofs(this.N, this.dim, 1); // step gains + // to accelerate progress in unchanging directions + this.ystep = arrayofs(this.N, this.dim, 0); // momentum accumulator + this.iter = 0; } - - export class TSNE { - private perplexity: number; - private epsilon: number; - private superviseFactor: number; - private unlabeledClass: string; - private labels: string[]; - private labelCounts: {[key: string]: number}; - /** Random generator */ - private rng: () => number; - private iter = 0; - private Y: Float64Array; - private N: number; - private P: Float64Array; - private gains: number[][]; - private ystep: number[][]; - private nearest: {index: number; dist: number}[][]; - private dim: number; - private dist2: (a: number[], b: number[]) => number; - private computeForce: ( - force: number[], - mult: number, - pointA: number[], - pointB: number[] - ) => void; - - constructor(opt: TSNEOptions) { - opt = opt || {dim: 2}; - this.perplexity = opt.perplexity || 30; - this.epsilon = opt.epsilon || 10; - this.rng = opt.rng || Math.random; - this.dim = opt.dim; - if (opt.dim === 2) { - this.dist2 = dist2_2D; - this.computeForce = computeForce_2d; - } else if (opt.dim === 3) { - this.dist2 = dist2_3D; - this.computeForce = computeForce_3d; - } else { - throw new Error('Only 2D and 3D is supported'); + getDim() { + return this.dim; + } + // return pointer to current solution + getSolution() { + return this.Y; + } + // For each point, randomly offset point within a 5% hypersphere centered + // around it, whilst remaining in the assumed t-SNE plot hypersphere + perturb() { + let N = this.N; + let maxArea = 0; + let ymean = this.dim === 3 ? [0, 0, 0] : [0, 0]; + // Determine radius of t-SNE hypersphere, assumed zero mean and normalized + // dimensions. Here area is proportional to pi*radius^2, to skip root calc. + for (let i = 0; i < N; ++i) { + let area = 0; + for (let d = 0; d < this.dim; ++d) { + area += Math.pow(this.Y[i * this.dim + d], 2); + } + if (area > maxArea) { + maxArea = area; } } - - // this function takes a fattened distance matrix and creates - // matrix P from them. - // D is assumed to be provided as an array of size N^2. - initDataDist(nearest: {index: number; dist: number}[][]) { - let N = nearest.length; - this.nearest = nearest; - this.P = nearest2P(nearest, this.perplexity, 1e-4); - this.N = N; - this.initSolution(); // refresh this - } - - // (re)initializes the solution to random - initSolution() { - // generate random solution to t-SNE - this.Y = randnMatrix(this.N, this.dim, this.rng); // the solution - this.gains = arrayofs(this.N, this.dim, 1.0); // step gains - // to accelerate progress in unchanging directions - this.ystep = arrayofs(this.N, this.dim, 0.0); // momentum accumulator - this.iter = 0; - } - - getDim() { - return this.dim; - } - - // return pointer to current solution - getSolution() { - return this.Y; - } - - // For each point, randomly offset point within a 5% hypersphere centered - // around it, whilst remaining in the assumed t-SNE plot hypersphere - perturb() { - let N = this.N; - let maxArea = 0; - let ymean = this.dim === 3 ? [0, 0, 0] : [0, 0]; - - // Determine radius of t-SNE hypersphere, assumed zero mean and normalized - // dimensions. Here area is proportional to pi*radius^2, to skip root calc. - for (let i = 0; i < N; ++i) { + let maxRadius = Math.pow(maxArea, 0.5); + for (let i = 0; i < N; ++i) { + let diff = new Array(this.dim); + // Find a perturbation of point that fits inside t-SNE hypersphere + while (true) { let area = 0; - for (let d = 0; d < this.dim; ++d) { - area += Math.pow(this.Y[i * this.dim + d], 2); + diff[d] = 0.1 * maxRadius * (Math.random() - 0.5); + area += Math.pow(this.Y[i * this.dim + d] + diff[d], 2); } - - if (area > maxArea) { - maxArea = area; + if (area < maxArea) { + break; } } - - let maxRadius = Math.pow(maxArea, 0.5); - - for (let i = 0; i < N; ++i) { - let diff = new Array(this.dim); - - // Find a perturbation of point that fits inside t-SNE hypersphere - while (true) { - let area = 0; - - for (let d = 0; d < this.dim; ++d) { - diff[d] = 0.1 * maxRadius * (Math.random() - 0.5); - area += Math.pow(this.Y[i * this.dim + d] + diff[d], 2); - } - - if (area < maxArea) { - break; - } - } - - // Apply offset to point - for (let d = 0; d < this.dim; ++d) { - this.Y[i * this.dim + d] += diff[d]; - ymean[d] += this.Y[i * this.dim + d]; - } - } - - // reproject Y to be zero mean - for (let i = 0; i < N; ++i) { - for (let d = 0; d < this.dim; ++d) { - this.Y[i * this.dim + d] -= ymean[d] / N; - } + // Apply offset to point + for (let d = 0; d < this.dim; ++d) { + this.Y[i * this.dim + d] += diff[d]; + ymean[d] += this.Y[i * this.dim + d]; } } - - // perform a single step of optimization to improve the embedding - step() { - this.iter += 1; - let N = this.N; - - let grad = this.costGrad(this.Y); // evaluate gradient - - // perform gradient step - let ymean = this.dim === 3 ? [0, 0, 0] : [0, 0]; - for (let i = 0; i < N; ++i) { - for (let d = 0; d < this.dim; ++d) { - let gid = grad[i][d]; - let sid = this.ystep[i][d]; - let gainid = this.gains[i][d]; - - // compute gain update - let newgain = sign(gid) === sign(sid) ? gainid * 0.8 : gainid + 0.2; - if (newgain < 0.01) { - newgain = 0.01; // clamp - } - this.gains[i][d] = newgain; // store for next turn - - // compute momentum step direction - let momval = this.iter < 250 ? 0.5 : 0.8; - let newsid = momval * sid - this.epsilon * newgain * grad[i][d]; - this.ystep[i][d] = newsid; // remember the step we took - - // step! - let i_d = i * this.dim + d; - this.Y[i_d] += newsid; - ymean[d] += this.Y[i_d]; // accumulate mean so that we - // can center later - } + // reproject Y to be zero mean + for (let i = 0; i < N; ++i) { + for (let d = 0; d < this.dim; ++d) { + this.Y[i * this.dim + d] -= ymean[d] / N; } - - // reproject Y to be zero mean - for (let i = 0; i < N; ++i) { - for (let d = 0; d < this.dim; ++d) { - this.Y[i * this.dim + d] -= ymean[d] / N; + } + } + // perform a single step of optimization to improve the embedding + step() { + this.iter += 1; + let N = this.N; + let grad = this.costGrad(this.Y); // evaluate gradient + // perform gradient step + let ymean = this.dim === 3 ? [0, 0, 0] : [0, 0]; + for (let i = 0; i < N; ++i) { + for (let d = 0; d < this.dim; ++d) { + let gid = grad[i][d]; + let sid = this.ystep[i][d]; + let gainid = this.gains[i][d]; + // compute gain update + let newgain = sign(gid) === sign(sid) ? gainid * 0.8 : gainid + 0.2; + if (newgain < 0.01) { + newgain = 0.01; // clamp } + this.gains[i][d] = newgain; // store for next turn + // compute momentum step direction + let momval = this.iter < 250 ? 0.5 : 0.8; + let newsid = momval * sid - this.epsilon * newgain * grad[i][d]; + this.ystep[i][d] = newsid; // remember the step we took + // step! + let i_d = i * this.dim + d; + this.Y[i_d] += newsid; + ymean[d] += this.Y[i_d]; // accumulate mean so that we + // can center later } } - - setSupervision(superviseLabels: string[], superviseInput?: string) { - if (superviseLabels != null) { - this.labels = superviseLabels; - this.labelCounts = {}; - let uniqueEntries = Array.from(new Set(superviseLabels)); - uniqueEntries.forEach((l) => (this.labelCounts[l] = 0)); - superviseLabels.forEach((l) => (this.labelCounts[l] += 1)); - } - if (superviseInput != null) { - this.unlabeledClass = superviseInput; + // reproject Y to be zero mean + for (let i = 0; i < N; ++i) { + for (let d = 0; d < this.dim; ++d) { + this.Y[i * this.dim + d] -= ymean[d] / N; } } - - setSuperviseFactor(superviseFactor: number) { - if (superviseFactor != null) { - this.superviseFactor = superviseFactor; + } + setSupervision(superviseLabels: string[], superviseInput?: string) { + if (superviseLabels != null) { + this.labels = superviseLabels; + this.labelCounts = {}; + let uniqueEntries = Array.from(new Set(superviseLabels)); + uniqueEntries.forEach((l) => (this.labelCounts[l] = 0)); + superviseLabels.forEach((l) => (this.labelCounts[l] += 1)); + } + if (superviseInput != null) { + this.unlabeledClass = superviseInput; + } + } + setSuperviseFactor(superviseFactor: number) { + if (superviseFactor != null) { + this.superviseFactor = superviseFactor; + } + } + // return cost and gradient, given an arrangement + costGrad(Y: Float64Array): number[][] { + let N = this.N; + let P = this.P; + // Trick that helps with local optima. + let alpha = this.iter < 100 ? 4 : 1; + let superviseFactor = this.superviseFactor / 100; // set in range [0, 1] + let unlabeledClass = this.unlabeledClass; + let labels = this.labels; + let labelCounts = this.labelCounts; + let supervised = + superviseFactor != null && + superviseFactor > 0 && + labels != null && + labelCounts != null; + let unlabeledCount = + supervised && unlabeledClass != null && unlabeledClass !== '' + ? labelCounts[unlabeledClass] + : 0; + // Make data for the SP tree. + let points: number[][] = new Array(N); // (x, y)[] + for (let i = 0; i < N; ++i) { + let iTimesD = i * this.dim; + let row = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + row[d] = Y[iTimesD + d]; } + points[i] = row; } - - // return cost and gradient, given an arrangement - costGrad(Y: Float64Array): number[][] { - let N = this.N; - let P = this.P; - - // Trick that helps with local optima. - let alpha = this.iter < 100 ? 4 : 1; - - let superviseFactor = this.superviseFactor / 100; // set in range [0, 1] - let unlabeledClass = this.unlabeledClass; - let labels = this.labels; - let labelCounts = this.labelCounts; - let supervised = - superviseFactor != null && - superviseFactor > 0 && - labels != null && - labelCounts != null; - let unlabeledCount = - supervised && unlabeledClass != null && unlabeledClass !== '' - ? labelCounts[unlabeledClass] - : 0; - - // Make data for the SP tree. - let points: number[][] = new Array(N); // (x, y)[] - for (let i = 0; i < N; ++i) { - let iTimesD = i * this.dim; - let row = new Array(this.dim); - for (let d = 0; d < this.dim; ++d) { - row[d] = Y[iTimesD + d]; - } - points[i] = row; + // Make a tree. + let tree = new SPTree(points); + let root = tree.root as AugmSPNode; + // Annotate the tree. + let annotateTree = ( + node: AugmSPNode + ): { + numCells: number; + yCell: number[]; + } => { + let numCells = 1; + if (node.children == null) { + // Update the current node and tell the parent. + node.numCells = numCells; + node.yCell = node.point; + return {numCells, yCell: node.yCell}; } - - // Make a tree. - let tree = new SPTree(points); - let root = tree.root as AugmSPNode; - // Annotate the tree. - - let annotateTree = ( - node: AugmSPNode - ): {numCells: number; yCell: number[]} => { - let numCells = 1; - if (node.children == null) { - // Update the current node and tell the parent. - node.numCells = numCells; - node.yCell = node.point; - return {numCells, yCell: node.yCell}; + // node.point is a 2 or 3-dim number[], so slice() makes a copy. + let yCell = node.point.slice(); + for (let i = 0; i < node.children.length; ++i) { + let child = node.children[i]; + if (child == null) { + continue; } - // node.point is a 2 or 3-dim number[], so slice() makes a copy. - let yCell = node.point.slice(); - for (let i = 0; i < node.children.length; ++i) { - let child = node.children[i]; - if (child == null) { - continue; - } - let result = annotateTree(child as AugmSPNode); - numCells += result.numCells; - for (let d = 0; d < this.dim; ++d) { - yCell[d] += result.yCell[d]; - } + let result = annotateTree(child as AugmSPNode); + numCells += result.numCells; + for (let d = 0; d < this.dim; ++d) { + yCell[d] += result.yCell[d]; } - // Update the node and tell the parent. - node.numCells = numCells; - node.yCell = yCell.map((v) => v / numCells); - return {numCells, yCell}; - }; - - // Augment the tree with more info. - annotateTree(root); - tree.visit((node: AugmSPNode, low: number[], high: number[]) => { - node.rCell = high[0] - low[0]; - return false; - }); - // compute current Q distribution, unnormalized first - let grad: number[][] = []; - let Z = 0; - let sum_pij = 0; - let forces: [number[], number[]][] = new Array(N); - for (let i = 0; i < N; ++i) { - let pointI = points[i]; + } + // Update the node and tell the parent. + node.numCells = numCells; + node.yCell = yCell.map((v) => v / numCells); + return {numCells, yCell}; + }; + // Augment the tree with more info. + annotateTree(root); + tree.visit((node: AugmSPNode, low: number[], high: number[]) => { + node.rCell = high[0] - low[0]; + return false; + }); + // compute current Q distribution, unnormalized first + let grad: number[][] = []; + let Z = 0; + let sum_pij = 0; + let forces: [number[], number[]][] = new Array(N); + for (let i = 0; i < N; ++i) { + let pointI = points[i]; + if (supervised) { + var sameCount = labelCounts[labels[i]]; + var otherCount = N - sameCount - unlabeledCount; + } + // Compute the positive forces for the i-th node. + let Fpos = this.dim === 3 ? [0, 0, 0] : [0, 0]; + let neighbors = this.nearest[i]; + for (let k = 0; k < neighbors.length; ++k) { + let j = neighbors[k].index; + let pij = P[i * N + j]; + // apply semi-supervised prior probabilities if (supervised) { - var sameCount = labelCounts[labels[i]]; - var otherCount = N - sameCount - unlabeledCount; - } - // Compute the positive forces for the i-th node. - let Fpos = this.dim === 3 ? [0, 0, 0] : [0, 0]; - let neighbors = this.nearest[i]; - for (let k = 0; k < neighbors.length; ++k) { - let j = neighbors[k].index; - let pij = P[i * N + j]; - // apply semi-supervised prior probabilities - if (supervised) { - if (labels[i] === unlabeledClass || labels[j] === unlabeledClass) { - pij *= 1 / N; - } else if (labels[i] !== labels[j]) { - pij *= Math.max(1 / N - superviseFactor / otherCount, 1e-7); - } else if (labels[i] === labels[j]) { - pij *= Math.min(1 / N + superviseFactor / sameCount, 1 - 1e-7); - } - sum_pij += pij; + if (labels[i] === unlabeledClass || labels[j] === unlabeledClass) { + pij *= 1 / N; + } else if (labels[i] !== labels[j]) { + pij *= Math.max(1 / N - superviseFactor / otherCount, 1e-7); + } else if (labels[i] === labels[j]) { + pij *= Math.min(1 / N + superviseFactor / sameCount, 1 - 1e-7); } - let pointJ = points[j]; - let squaredDistItoJ = this.dist2(pointI, pointJ); - let premult = pij / (1 + squaredDistItoJ); - this.computeForce(Fpos, premult, pointI, pointJ); + sum_pij += pij; } - // Compute the negative forces for the i-th node. - let FnegZ = this.dim === 3 ? [0, 0, 0] : [0, 0]; - tree.visit((node: AugmSPNode) => { - let squaredDistToCell = this.dist2(pointI, node.yCell); - // Squared distance from point i to cell. - if ( - node.children == null || - (squaredDistToCell > 0 && - node.rCell / Math.sqrt(squaredDistToCell) < THETA) - ) { - let qijZ = 1 / (1 + squaredDistToCell); - let dZ = node.numCells * qijZ; - Z += dZ; - dZ *= qijZ; - this.computeForce(FnegZ, dZ, pointI, node.yCell); - return true; - } - // Cell is too close to approximate. - let squaredDistToPoint = this.dist2(pointI, node.point); - let qijZ = 1 / (1 + squaredDistToPoint); - Z += qijZ; - qijZ *= qijZ; - this.computeForce(FnegZ, qijZ, pointI, node.point); - return false; - }, true); - forces[i] = [Fpos, FnegZ]; + let pointJ = points[j]; + let squaredDistItoJ = this.dist2(pointI, pointJ); + let premult = pij / (1 + squaredDistItoJ); + this.computeForce(Fpos, premult, pointI, pointJ); } - // Normalize the negative forces and compute the gradient. - let A = 4 * alpha; - if (supervised) { - A /= sum_pij; - } - const B = 4 / Z; - for (let i = 0; i < N; ++i) { - let [FPos, FNegZ] = forces[i]; - let gsum = new Array(this.dim); - for (let d = 0; d < this.dim; ++d) { - gsum[d] = A * FPos[d] - B * FNegZ[d]; + // Compute the negative forces for the i-th node. + let FnegZ = this.dim === 3 ? [0, 0, 0] : [0, 0]; + tree.visit((node: AugmSPNode) => { + let squaredDistToCell = this.dist2(pointI, node.yCell); + // Squared distance from point i to cell. + if ( + node.children == null || + (squaredDistToCell > 0 && + node.rCell / Math.sqrt(squaredDistToCell) < THETA) + ) { + let qijZ = 1 / (1 + squaredDistToCell); + let dZ = node.numCells * qijZ; + Z += dZ; + dZ *= qijZ; + this.computeForce(FnegZ, dZ, pointI, node.yCell); + return true; } - grad.push(gsum); + // Cell is too close to approximate. + let squaredDistToPoint = this.dist2(pointI, node.point); + let qijZ = 1 / (1 + squaredDistToPoint); + Z += qijZ; + qijZ *= qijZ; + this.computeForce(FnegZ, qijZ, pointI, node.point); + return false; + }, true); + forces[i] = [Fpos, FnegZ]; + } + // Normalize the negative forces and compute the gradient. + let A = 4 * alpha; + if (supervised) { + A /= sum_pij; + } + const B = 4 / Z; + for (let i = 0; i < N; ++i) { + let [FPos, FNegZ] = forces[i]; + let gsum = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + gsum[d] = A * FPos[d] - B * FNegZ[d]; } - return grad; + grad.push(gsum); } + return grad; } -} // namespace vz_projector +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/bundle.html b/tensorboard/plugins/projector/polymer3/vz_projector/bundle.html deleted file mode 100644 index cf2b73284e..0000000000 --- a/tensorboard/plugins/projector/polymer3/vz_projector/bundle.html +++ /dev/null @@ -1,49 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/bundle.ts b/tensorboard/plugins/projector/polymer3/vz_projector/bundle.ts new file mode 100644 index 0000000000..66ba5f4104 --- /dev/null +++ b/tensorboard/plugins/projector/polymer3/vz_projector/bundle.ts @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import '@polymer/iron-iconset-svg'; +import './styles'; +import './vz-projector-app'; +import './vz-projector-bookmark-panel'; +import './vz-projector-dashboard'; +import './vz-projector-data-panel'; +import './vz-projector-inspector-panel'; +import './vz-projector-input'; +import './vz-projector-legend'; +import './vz-projector-projections-panel'; +import './vz-projector-metadata-card'; +import './vz-projector'; diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-demo.ts b/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-demo.ts index b4d9857e99..105f70002e 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-demo.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-demo.ts @@ -12,135 +12,135 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - const BYTES_EXTENSION = '.bytes'; +import {SpriteAndMetadataInfo, State, DataSet} from './data'; +import { + DataProvider, + EmbeddingInfo, + parseTensors, + ProjectorConfig, + retrieveSpriteAndMetadataInfo, + retrieveTensorAsBytes, + TENSORS_MSG_ID, +} from './data-provider'; +import * as logging from './logging'; - /** Data provider that loads data from a demo folder. */ - export class DemoDataProvider implements DataProvider { - private projectorConfigPath: string; - private projectorConfig: ProjectorConfig; +const BYTES_EXTENSION = '.bytes'; - constructor(projectorConfigPath: string) { - this.projectorConfigPath = projectorConfigPath; - } - - private getEmbeddingInfo(tensorName: string): EmbeddingInfo { - let embeddings = this.projectorConfig.embeddings; - for (let i = 0; i < embeddings.length; i++) { - let embedding = embeddings[i]; - if (embedding.tensorName === tensorName) { - return embedding; - } +/** Data provider that loads data from a demo folder. */ +export class DemoDataProvider implements DataProvider { + private projectorConfigPath: string; + private projectorConfig: ProjectorConfig; + constructor(projectorConfigPath: string) { + this.projectorConfigPath = projectorConfigPath; + } + private getEmbeddingInfo(tensorName: string): EmbeddingInfo { + let embeddings = this.projectorConfig.embeddings; + for (let i = 0; i < embeddings.length; i++) { + let embedding = embeddings[i]; + if (embedding.tensorName === tensorName) { + return embedding; } - return null; } - - retrieveRuns(callback: (runs: string[]) => void): void { - callback(['Demo']); - } - - retrieveProjectorConfig( - run: string, - callback: (d: ProjectorConfig) => void - ): void { - const msgId = logging.setModalMessage('Fetching projector config...'); - - const xhr = new XMLHttpRequest(); - xhr.open('GET', this.projectorConfigPath); - xhr.onerror = (err) => { - let errorMessage = err.message; - // If the error is a valid XMLHttpResponse, it's possible this is a - // cross-origin error. - if (xhr.responseText != null) { - errorMessage = - 'Cannot fetch projector config, possibly a ' + - 'Cross-Origin request error.'; - } - logging.setErrorMessage(errorMessage, 'fetching projector config'); - }; - xhr.onload = () => { - const projectorConfig = JSON.parse(xhr.responseText) as ProjectorConfig; - logging.setModalMessage(null, msgId); - this.projectorConfig = projectorConfig; - callback(projectorConfig); - }; - xhr.send(); - } - - retrieveTensor( - run: string, - tensorName: string, - callback: (ds: DataSet) => void - ) { - let embedding = this.getEmbeddingInfo(tensorName); - let url = `${embedding.tensorPath}`; - if ( - embedding.tensorPath.substr(-1 * BYTES_EXTENSION.length) === - BYTES_EXTENSION - ) { - retrieveTensorAsBytes( - this, - this.getEmbeddingInfo(tensorName), - run, - tensorName, - url, - callback - ); - } else { - logging.setModalMessage('Fetching tensors...', TENSORS_MSG_ID); - const request = new XMLHttpRequest(); - request.open('GET', url); - request.responseType = 'arraybuffer'; - - request.onerror = () => { - logging.setErrorMessage(request.responseText, 'fetching tensors'); - }; - request.onload = () => { - parseTensors(request.response).then((points) => { - callback(new DataSet(points)); - }); - }; - request.send(); + return null; + } + retrieveRuns(callback: (runs: string[]) => void): void { + callback(['Demo']); + } + retrieveProjectorConfig( + run: string, + callback: (d: ProjectorConfig) => void + ): void { + const msgId = logging.setModalMessage('Fetching projector config...'); + const xhr = new XMLHttpRequest(); + xhr.open('GET', this.projectorConfigPath); + xhr.onerror = (err: any) => { + let errorMessage = err.message; + // If the error is a valid XMLHttpResponse, it's possible this is a + // cross-origin error. + if (xhr.responseText != null) { + errorMessage = + 'Cannot fetch projector config, possibly a ' + + 'Cross-Origin request error.'; } - } - - retrieveSpriteAndMetadata( - run: string, - tensorName: string, - callback: (r: SpriteAndMetadataInfo) => void + logging.setErrorMessage(errorMessage, 'fetching projector config'); + }; + xhr.onload = () => { + const projectorConfig = JSON.parse(xhr.responseText) as ProjectorConfig; + logging.setModalMessage(null, msgId); + this.projectorConfig = projectorConfig; + callback(projectorConfig); + }; + xhr.send(); + } + retrieveTensor( + run: string, + tensorName: string, + callback: (ds: DataSet) => void + ) { + let embedding = this.getEmbeddingInfo(tensorName); + let url = `${embedding.tensorPath}`; + if ( + embedding.tensorPath.substr(-1 * BYTES_EXTENSION.length) === + BYTES_EXTENSION ) { - let embedding = this.getEmbeddingInfo(tensorName); - let spriteImagePath = null; - if (embedding.sprite && embedding.sprite.imagePath) { - spriteImagePath = embedding.sprite.imagePath; - } - retrieveSpriteAndMetadataInfo( - embedding.metadataPath, - spriteImagePath, - embedding.sprite, + retrieveTensorAsBytes( + this, + this.getEmbeddingInfo(tensorName), + run, + tensorName, + url, callback ); - } - - getBookmarks( - run: string, - tensorName: string, - callback: (r: State[]) => void - ) { - let embedding = this.getEmbeddingInfo(tensorName); - let msgId = logging.setModalMessage('Fetching bookmarks...'); - - const xhr = new XMLHttpRequest(); - xhr.open('GET', embedding.bookmarksPath); - xhr.onerror = (err) => { - logging.setErrorMessage(xhr.responseText); + } else { + logging.setModalMessage('Fetching tensors...', TENSORS_MSG_ID); + const request = new XMLHttpRequest(); + request.open('GET', url); + request.responseType = 'arraybuffer'; + request.onerror = () => { + logging.setErrorMessage(request.responseText, 'fetching tensors'); }; - xhr.onload = () => { - const bookmarks = JSON.parse(xhr.responseText) as State[]; - logging.setModalMessage(null, msgId); - callback(bookmarks); + request.onload = () => { + parseTensors(request.response).then((points) => { + callback(new DataSet(points)); + }); }; - xhr.send(); + request.send(); } } -} // namespace vz_projector + retrieveSpriteAndMetadata( + run: string, + tensorName: string, + callback: (r: SpriteAndMetadataInfo) => void + ) { + let embedding = this.getEmbeddingInfo(tensorName); + let spriteImagePath = null; + if (embedding.sprite && embedding.sprite.imagePath) { + spriteImagePath = embedding.sprite.imagePath; + } + retrieveSpriteAndMetadataInfo( + embedding.metadataPath, + spriteImagePath, + embedding.sprite, + callback + ); + } + getBookmarks( + run: string, + tensorName: string, + callback: (r: State[]) => void + ) { + let embedding = this.getEmbeddingInfo(tensorName); + let msgId = logging.setModalMessage('Fetching bookmarks...'); + const xhr = new XMLHttpRequest(); + xhr.open('GET', embedding.bookmarksPath); + xhr.onerror = (err) => { + logging.setErrorMessage(xhr.responseText); + }; + xhr.onload = () => { + const bookmarks = JSON.parse(xhr.responseText) as State[]; + logging.setModalMessage(null, msgId); + callback(bookmarks); + }; + xhr.send(); + } +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-proto.ts b/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-proto.ts index fcbb19980e..075d460e4f 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-proto.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-proto.ts @@ -12,107 +12,105 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - export class ProtoDataProvider implements DataProvider { - private dataProto: DataProto; +import { + DataSet, + DataProto, + State, + DataPoint, + SpriteAndMetadataInfo, + PointMetadata, +} from './data'; +import {analyzeMetadata, DataProvider, ProjectorConfig} from './data-provider'; - constructor(dataProto: DataProto) { - this.dataProto = dataProto; - } - - retrieveRuns(callback: (runs: string[]) => void): void { - callback(['proto']); - } - - retrieveProjectorConfig( - run: string, - callback: (d: ProjectorConfig) => void - ) { - callback({ - modelCheckpointPath: 'proto', - embeddings: [ - { - tensorName: 'proto', - tensorShape: this.dataProto.shape, - metadataPath: 'proto', - }, - ], - }); - } - - retrieveTensor( - run: string, - tensorName: string, - callback: (ds: DataSet) => void - ) { - callback(this.flatArrayToDataset(this.dataProto.tensor)); - } - - retrieveSpriteAndMetadata( - run: string, - tensorName: string, - callback: (r: SpriteAndMetadataInfo) => void - ): void { - let columnNames = this.dataProto.metadata.columns.map((c) => c.name); - let n = this.dataProto.shape[0]; - let pointsMetadata: PointMetadata[] = new Array(n); - this.dataProto.metadata.columns.forEach((c) => { - let values = c.numericValues || c.stringValues; - for (let i = 0; i < n; i++) { - pointsMetadata[i] = pointsMetadata[i] || {}; - pointsMetadata[i][c.name] = values[i]; - } - }); - let spritesPromise: Promise = Promise.resolve(null); - if (this.dataProto.metadata.sprite != null) { - spritesPromise = new Promise((resolve, reject) => { - const image = new Image(); - image.onload = () => resolve(image); - image.onerror = () => reject('Failed converting base64 to an image'); - image.src = this.dataProto.metadata.sprite.imageBase64; - }); +export class ProtoDataProvider implements DataProvider { + private dataProto: DataProto; + constructor(dataProto: DataProto) { + this.dataProto = dataProto; + } + retrieveRuns(callback: (runs: string[]) => void): void { + callback(['proto']); + } + retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void) { + callback({ + modelCheckpointPath: 'proto', + embeddings: [ + { + tensorName: 'proto', + tensorShape: this.dataProto.shape, + metadataPath: 'proto', + }, + ], + }); + } + retrieveTensor( + run: string, + tensorName: string, + callback: (ds: DataSet) => void + ) { + callback(this.flatArrayToDataset(this.dataProto.tensor)); + } + retrieveSpriteAndMetadata( + run: string, + tensorName: string, + callback: (r: SpriteAndMetadataInfo) => void + ): void { + let columnNames = this.dataProto.metadata.columns.map((c) => c.name); + let n = this.dataProto.shape[0]; + let pointsMetadata: PointMetadata[] = new Array(n); + this.dataProto.metadata.columns.forEach((c) => { + let values = c.numericValues || c.stringValues; + for (let i = 0; i < n; i++) { + pointsMetadata[i] = pointsMetadata[i] || {}; + pointsMetadata[i][c.name] = values[i]; } - spritesPromise.then((image) => { - const result: SpriteAndMetadataInfo = { - stats: analyzeMetadata(columnNames, pointsMetadata), - pointsInfo: pointsMetadata, - }; - if (image != null) { - result.spriteImage = image; - result.spriteMetadata = { - singleImageDim: this.dataProto.metadata.sprite.singleImageDim, - imagePath: 'proto', - }; - } - callback(result); + }); + let spritesPromise: Promise = Promise.resolve(null); + if (this.dataProto.metadata.sprite != null) { + spritesPromise = new Promise((resolve, reject) => { + const image = new Image(); + image.onload = () => resolve(image); + image.onerror = () => reject('Failed converting base64 to an image'); + image.src = this.dataProto.metadata.sprite.imageBase64; }); } - - getBookmarks( - run: string, - tensorName: string, - callback: (r: State[]) => void - ): void { - return callback([]); - } - - private flatArrayToDataset(tensor: number[]): DataSet { - let points: DataPoint[] = []; - let n = this.dataProto.shape[0]; - let d = this.dataProto.shape[1]; - if (n * d !== tensor.length) { - throw 'The shape doesn\'t match the length of the flattened array'; - } - for (let i = 0; i < n; i++) { - let offset = i * d; - points.push({ - vector: new Float32Array(tensor.slice(offset, offset + d)), - metadata: {}, - projections: null, - index: i, - }); + spritesPromise.then((image) => { + const result: SpriteAndMetadataInfo = { + stats: analyzeMetadata(columnNames, pointsMetadata), + pointsInfo: pointsMetadata, + }; + if (image != null) { + result.spriteImage = image; + result.spriteMetadata = { + singleImageDim: this.dataProto.metadata.sprite.singleImageDim, + imagePath: 'proto', + }; } - return new DataSet(points); + callback(result); + }); + } + getBookmarks( + run: string, + tensorName: string, + callback: (r: State[]) => void + ): void { + return callback([]); + } + private flatArrayToDataset(tensor: number[]): DataSet { + let points: DataPoint[] = []; + let n = this.dataProto.shape[0]; + let d = this.dataProto.shape[1]; + if (n * d !== tensor.length) { + throw "The shape doesn't match the length of the flattened array"; + } + for (let i = 0; i < n; i++) { + let offset = i * d; + points.push({ + vector: new Float32Array(tensor.slice(offset, offset + d)), + metadata: {}, + projections: null, + index: i, + }); } + return new DataSet(points); } -} // namespace vz_projector +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-server.ts b/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-server.ts index d074b5eb26..b7503ee743 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-server.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/data-provider-server.ts @@ -12,145 +12,142 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - // Limit for the number of data points we receive from the server. - export const LIMIT_NUM_POINTS = 100000; +import {DataSet, State, SpriteAndMetadataInfo} from './data'; +import { + retrieveSpriteAndMetadataInfo, + retrieveTensorAsBytes, + EmbeddingInfo, + DataProvider, + ProjectorConfig, +} from './data-provider'; +import * as logging from './logging'; - /** - * Data provider that loads data provided by a python server (usually backed - * by a checkpoint file). - */ - export class ServerDataProvider implements DataProvider { - private routePrefix: string; - private runProjectorConfigCache: {[run: string]: ProjectorConfig} = {}; - - constructor(routePrefix: string) { - this.routePrefix = routePrefix; - } - - private getEmbeddingInfo( - run: string, - tensorName: string, - callback: (e: EmbeddingInfo) => void - ): void { - this.retrieveProjectorConfig(run, (config) => { - const embeddings = config.embeddings; - for (let i = 0; i < embeddings.length; i++) { - const embedding = embeddings[i]; - if (embedding.tensorName === tensorName) { - callback(embedding); - return; - } +export const LIMIT_NUM_POINTS = 100000; +/** + * Data provider that loads data provided by a python server (usually backed + * by a checkpoint file). + */ +export class ServerDataProvider implements DataProvider { + private routePrefix: string; + private runProjectorConfigCache: { + [run: string]: ProjectorConfig; + } = {}; + constructor(routePrefix: string) { + this.routePrefix = routePrefix; + } + private getEmbeddingInfo( + run: string, + tensorName: string, + callback: (e: EmbeddingInfo) => void + ): void { + this.retrieveProjectorConfig(run, (config) => { + const embeddings = config.embeddings; + for (let i = 0; i < embeddings.length; i++) { + const embedding = embeddings[i]; + if (embedding.tensorName === tensorName) { + callback(embedding); + return; } - callback(null); - }); - } - - retrieveRuns(callback: (runs: string[]) => void): void { - const msgId = logging.setModalMessage('Fetching runs...'); - - const xhr = new XMLHttpRequest(); - xhr.open('GET', `${this.routePrefix}/runs`); - xhr.onerror = (err) => { - logging.setErrorMessage(xhr.responseText, 'fetching runs'); - }; - xhr.onload = () => { - const runs = JSON.parse(xhr.responseText); - logging.setModalMessage(null, msgId); - callback(runs); - }; - xhr.send(); - } - - retrieveProjectorConfig( - run: string, - callback: (d: ProjectorConfig) => void - ): void { - if (run in this.runProjectorConfigCache) { - callback(this.runProjectorConfigCache[run]); - return; } - - const msgId = logging.setModalMessage('Fetching projector config...'); - - const xhr = new XMLHttpRequest(); - xhr.open('GET', `${this.routePrefix}/info?run=${run}`); - xhr.onerror = (err) => { - logging.setErrorMessage(xhr.responseText, 'fetching projector config'); - }; - xhr.onload = () => { - const config = JSON.parse(xhr.responseText) as ProjectorConfig; - logging.setModalMessage(null, msgId); - this.runProjectorConfigCache[run] = config; - callback(config); - }; - xhr.send(); - } - - retrieveTensor( - run: string, - tensorName: string, - callback: (ds: DataSet) => void - ) { - this.getEmbeddingInfo(run, tensorName, (embedding) => { - retrieveTensorAsBytes( - this, - embedding, - run, - tensorName, - `${this.routePrefix}/tensor?run=${run}&name=${tensorName}` + - `&num_rows=${LIMIT_NUM_POINTS}`, - callback - ); - }); - } - - retrieveSpriteAndMetadata( - run: string, - tensorName: string, - callback: (r: SpriteAndMetadataInfo) => void - ) { - this.getEmbeddingInfo(run, tensorName, (embedding) => { - let metadataPath = null; - if (embedding.metadataPath) { - metadataPath = - `${this.routePrefix}/metadata?` + - `run=${run}&name=${tensorName}&num_rows=${LIMIT_NUM_POINTS}`; - } - let spriteImagePath = null; - if (embedding.sprite && embedding.sprite.imagePath) { - spriteImagePath = `${this.routePrefix}/sprite_image?run=${run}&name=${tensorName}`; - } - retrieveSpriteAndMetadataInfo( - metadataPath, - spriteImagePath, - embedding.sprite, - callback - ); - }); + callback(null); + }); + } + retrieveRuns(callback: (runs: string[]) => void): void { + const msgId = logging.setModalMessage('Fetching runs...'); + const xhr = new XMLHttpRequest(); + xhr.open('GET', `${this.routePrefix}/runs`); + xhr.onerror = (err) => { + logging.setErrorMessage(xhr.responseText, 'fetching runs'); + }; + xhr.onload = () => { + const runs = JSON.parse(xhr.responseText); + logging.setModalMessage(null, msgId); + callback(runs); + }; + xhr.send(); + } + retrieveProjectorConfig( + run: string, + callback: (d: ProjectorConfig) => void + ): void { + if (run in this.runProjectorConfigCache) { + callback(this.runProjectorConfigCache[run]); + return; } - - getBookmarks( - run: string, - tensorName: string, - callback: (r: State[]) => void - ) { - const msgId = logging.setModalMessage('Fetching bookmarks...'); - - const xhr = new XMLHttpRequest(); - xhr.open( - 'GET', - `${this.routePrefix}/bookmarks?run=${run}&name=${tensorName}` + const msgId = logging.setModalMessage('Fetching projector config...'); + const xhr = new XMLHttpRequest(); + xhr.open('GET', `${this.routePrefix}/info?run=${run}`); + xhr.onerror = (err) => { + logging.setErrorMessage(xhr.responseText, 'fetching projector config'); + }; + xhr.onload = () => { + const config = JSON.parse(xhr.responseText) as ProjectorConfig; + logging.setModalMessage(null, msgId); + this.runProjectorConfigCache[run] = config; + callback(config); + }; + xhr.send(); + } + retrieveTensor( + run: string, + tensorName: string, + callback: (ds: DataSet) => void + ) { + this.getEmbeddingInfo(run, tensorName, (embedding) => { + retrieveTensorAsBytes( + this, + embedding, + run, + tensorName, + `${this.routePrefix}/tensor?run=${run}&name=${tensorName}` + + `&num_rows=${LIMIT_NUM_POINTS}`, + callback ); - xhr.onerror = (err) => { - logging.setErrorMessage(xhr.responseText, 'fetching bookmarks'); - }; - xhr.onload = () => { - logging.setModalMessage(null, msgId); - const bookmarks = JSON.parse(xhr.responseText); - callback(bookmarks); - }; - xhr.send(); - } + }); + } + retrieveSpriteAndMetadata( + run: string, + tensorName: string, + callback: (r: SpriteAndMetadataInfo) => void + ) { + this.getEmbeddingInfo(run, tensorName, (embedding) => { + let metadataPath = null; + if (embedding.metadataPath) { + metadataPath = + `${this.routePrefix}/metadata?` + + `run=${run}&name=${tensorName}&num_rows=${LIMIT_NUM_POINTS}`; + } + let spriteImagePath = null; + if (embedding.sprite && embedding.sprite.imagePath) { + spriteImagePath = `${this.routePrefix}/sprite_image?run=${run}&name=${tensorName}`; + } + retrieveSpriteAndMetadataInfo( + metadataPath, + spriteImagePath, + embedding.sprite, + callback + ); + }); + } + getBookmarks( + run: string, + tensorName: string, + callback: (r: State[]) => void + ) { + const msgId = logging.setModalMessage('Fetching bookmarks...'); + const xhr = new XMLHttpRequest(); + xhr.open( + 'GET', + `${this.routePrefix}/bookmarks?run=${run}&name=${tensorName}` + ); + xhr.onerror = (err) => { + logging.setErrorMessage(xhr.responseText, 'fetching bookmarks'); + }; + xhr.onload = () => { + logging.setModalMessage(null, msgId); + const bookmarks = JSON.parse(xhr.responseText); + callback(bookmarks); + }; + xhr.send(); } -} // namespace vz_projector +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/data-provider.ts b/tensorboard/plugins/projector/polymer3/vz_projector/data-provider.ts index 64404385a0..9e9ca06356 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/data-provider.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/data-provider.ts @@ -12,483 +12,452 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - /** Maximum number of colors supported in the color map. */ - const NUM_COLORS_COLOR_MAP = 50; - const MAX_SPRITE_IMAGE_SIZE_PX = 8192; - - export const METADATA_MSG_ID = 'metadata'; - export const TENSORS_MSG_ID = 'tensors'; - - /** Matches the json format of `projector_config.proto` */ - export interface SpriteMetadata { - imagePath: string; - singleImageDim: [number, number]; - } - - /** Matches the json format of `projector_config.proto` */ - export interface EmbeddingInfo { - /** Name of the tensor. */ - tensorName: string; - /** The shape of the tensor. */ - tensorShape: [number, number]; - /** - * The path to the tensors TSV file. If empty, it is assumed that the tensor - * is stored in the checkpoint file. - */ - tensorPath?: string; - /** The path to the metadata file associated with the tensor. */ - metadataPath?: string; - /** The path to the bookmarks file associated with the tensor. */ - bookmarksPath?: string; - sprite?: SpriteMetadata; - } - +import { + ColumnStats, + DataPoint, + DataSet, + PointMetadata, + SpriteAndMetadataInfo, + State, +} from './data'; +import * as logging from './logging'; +import * as util from './util'; + +const NUM_COLORS_COLOR_MAP = 50; +const MAX_SPRITE_IMAGE_SIZE_PX = 8192; +export const METADATA_MSG_ID = 'metadata'; +export const TENSORS_MSG_ID = 'tensors'; +/** Matches the json format of `projector_config.proto` */ +export interface SpriteMetadata { + imagePath: string; + singleImageDim: [number, number]; +} +/** Matches the json format of `projector_config.proto` */ +export interface EmbeddingInfo { + /** Name of the tensor. */ + tensorName: string; + /** The shape of the tensor. */ + tensorShape: [number, number]; /** - * Matches the json format of `projector_config.proto` - * This should be kept in sync with the code in vz-projector-data-panel which - * holds a template for users to build a projector config JSON object from the - * projector UI. + * The path to the tensors TSV file. If empty, it is assumed that the tensor + * is stored in the checkpoint file. */ - export interface ProjectorConfig { - embeddings: EmbeddingInfo[]; - modelCheckpointPath?: string; - } - - export type ServingMode = 'demo' | 'server' | 'proto'; - - /** Interface between the data storage and the UI. */ - export interface DataProvider { - /** Returns a list of run names that have embedding config files. */ - retrieveRuns(callback: (runs: string[]) => void): void; - - /** - * Returns the projector configuration: number of tensors, their shapes, - * and their associated metadata files. - */ - retrieveProjectorConfig( - run: string, - callback: (d: ProjectorConfig) => void - ): void; - - /** Fetches and returns the tensor with the specified name. */ - retrieveTensor( - run: string, - tensorName: string, - callback: (ds: DataSet) => void - ); - - /** - * Fetches the metadata for the specified tensor. - */ - retrieveSpriteAndMetadata( - run: string, - tensorName: string, - callback: (r: SpriteAndMetadataInfo) => void - ): void; - - getBookmarks( - run: string, - tensorName: string, - callback: (r: State[]) => void - ): void; - } - - export function retrieveTensorAsBytes( - dp: DataProvider, - embedding: EmbeddingInfo, + tensorPath?: string; + /** The path to the metadata file associated with the tensor. */ + metadataPath?: string; + /** The path to the bookmarks file associated with the tensor. */ + bookmarksPath?: string; + sprite?: SpriteMetadata; +} +/** + * Matches the json format of `projector_config.proto` + * This should be kept in sync with the code in vz-projector-data-panel which + * holds a template for users to build a projector config JSON object from the + * projector UI. + */ +export interface ProjectorConfig { + embeddings: EmbeddingInfo[]; + modelCheckpointPath?: string; +} +export type ServingMode = 'demo' | 'server' | 'proto'; +/** Interface between the data storage and the UI. */ +export interface DataProvider { + /** Returns a list of run names that have embedding config files. */ + retrieveRuns(callback: (runs: string[]) => void): void; + /** + * Returns the projector configuration: number of tensors, their shapes, + * and their associated metadata files. + */ + retrieveProjectorConfig( + run: string, + callback: (d: ProjectorConfig) => void + ): void; + /** Fetches and returns the tensor with the specified name. */ + retrieveTensor( run: string, tensorName: string, - tensorsPath: string, callback: (ds: DataSet) => void - ) { - // Get the tensor. - logging.setModalMessage('Fetching tensor values...', TENSORS_MSG_ID); - let xhr = new XMLHttpRequest(); - xhr.open('GET', tensorsPath); - xhr.responseType = 'arraybuffer'; - xhr.onprogress = (ev) => { - if (ev.lengthComputable) { - let percent = ((ev.loaded * 100) / ev.total).toFixed(1); - logging.setModalMessage( - 'Fetching tensor values: ' + percent + '%', - TENSORS_MSG_ID - ); + ); + /** + * Fetches the metadata for the specified tensor. + */ + retrieveSpriteAndMetadata( + run: string, + tensorName: string, + callback: (r: SpriteAndMetadataInfo) => void + ): void; + getBookmarks( + run: string, + tensorName: string, + callback: (r: State[]) => void + ): void; +} +export function retrieveTensorAsBytes( + dp: DataProvider, + embedding: EmbeddingInfo, + run: string, + tensorName: string, + tensorsPath: string, + callback: (ds: DataSet) => void +) { + // Get the tensor. + logging.setModalMessage('Fetching tensor values...', TENSORS_MSG_ID); + let xhr = new XMLHttpRequest(); + xhr.open('GET', tensorsPath); + xhr.responseType = 'arraybuffer'; + xhr.onprogress = (ev) => { + if (ev.lengthComputable) { + let percent = ((ev.loaded * 100) / ev.total).toFixed(1); + logging.setModalMessage( + 'Fetching tensor values: ' + percent + '%', + TENSORS_MSG_ID + ); + } + }; + xhr.onload = () => { + if (xhr.status !== 200) { + let msg = String.fromCharCode.apply(null, new Uint8Array(xhr.response)); + logging.setErrorMessage(msg, 'fetching tensors'); + return; + } + let data: Float32Array; + try { + data = new Float32Array(xhr.response); + } catch (e) { + logging.setErrorMessage(e, 'parsing tensor bytes'); + return; + } + let dim = embedding.tensorShape[1]; + let N = data.length / dim; + if (embedding.tensorShape[0] > N) { + logging.setWarningMessage( + `Showing the first ${N.toLocaleString()}` + + ` of ${embedding.tensorShape[0].toLocaleString()} data points` + ); + } + parseTensorsFromFloat32Array(data, dim).then((dataPoints) => { + callback(new DataSet(dataPoints)); + }); + }; + xhr.send(); +} +export function parseRawTensors( + content: ArrayBuffer, + callback: (ds: DataSet) => void +) { + parseTensors(content).then((data) => { + callback(new DataSet(data)); + }); +} +export function parseRawMetadata( + contents: ArrayBuffer, + callback: (r: SpriteAndMetadataInfo) => void +) { + parseMetadata(contents).then((result) => callback(result)); +} +/** + * Parse an ArrayBuffer in a streaming fashion line by line (or custom delim). + * Can handle very large files. + * + * @param content The array buffer. + * @param callback The callback called on each line. + * @param chunkSize The size of each read chunk, defaults to ~1MB. (optional) + * @param delim The delimiter used to split a line, defaults to '\n'. (optional) + * @returns A promise for when it is finished. + */ +function streamParse( + content: ArrayBuffer, + callback: (line: string) => void, + chunkSize = 1000000, + delim = '\n' +): Promise { + return new Promise((resolve, reject) => { + let offset = 0; + let bufferSize = content.byteLength - 1; + let data = ''; + function readHandler(str) { + offset += chunkSize; + let parts = str.split(delim); + let first = data + parts[0]; + if (parts.length === 1) { + data = first; + readChunk(offset, chunkSize); + return; } - }; - xhr.onload = () => { - if (xhr.status !== 200) { - let msg = String.fromCharCode.apply(null, new Uint8Array(xhr.response)); - logging.setErrorMessage(msg, 'fetching tensors'); + data = parts[parts.length - 1]; + callback(first); + for (let i = 1; i < parts.length - 1; i++) { + callback(parts[i]); + } + if (offset >= bufferSize) { + if (data) { + callback(data); + } + resolve(); return; } - let data: Float32Array; - try { - data = new Float32Array(xhr.response); - } catch (e) { - logging.setErrorMessage(e, 'parsing tensor bytes'); + readChunk(offset, chunkSize); + } + function readChunk(offset: number, size: number) { + const contentChunk = content.slice(offset, offset + size); + const blob = new Blob([contentChunk]); + const file = new FileReader(); + file.onload = (e: any) => readHandler(e.target.result); + file.readAsText(blob); + } + readChunk(offset, chunkSize); + }); +} +/** Parses a tsv text file. */ +export function parseTensors( + content: ArrayBuffer, + valueDelim = '\t' +): Promise { + logging.setModalMessage('Parsing tensors...', TENSORS_MSG_ID); + return new Promise((resolve, reject) => { + const data: DataPoint[] = []; + let numDim: number; + streamParse(content, (line: string) => { + line = line.trim(); + if (line === '') { return; } - - let dim = embedding.tensorShape[1]; - let N = data.length / dim; - if (embedding.tensorShape[0] > N) { - logging.setWarningMessage( - `Showing the first ${N.toLocaleString()}` + - ` of ${embedding.tensorShape[0].toLocaleString()} data points` - ); + const row = line.split(valueDelim); + const dataPoint: DataPoint = { + metadata: {}, + vector: null, + index: data.length, + projections: null, + }; + // If the first label is not a number, take it as the label. + if (isNaN(row[0] as any) || numDim === row.length - 1) { + dataPoint.metadata['label'] = row[0]; + dataPoint.vector = new Float32Array(row.slice(1).map(Number)); + } else { + dataPoint.vector = new Float32Array(row.map(Number)); } - parseTensorsFromFloat32Array(data, dim).then((dataPoints) => { - callback(new DataSet(dataPoints)); - }); - }; - xhr.send(); - } - - export function parseRawTensors( - content: ArrayBuffer, - callback: (ds: DataSet) => void - ) { - parseTensors(content).then((data) => { - callback(new DataSet(data)); - }); - } - - export function parseRawMetadata( - contents: ArrayBuffer, - callback: (r: SpriteAndMetadataInfo) => void - ) { - parseMetadata(contents).then((result) => callback(result)); - } - - /** - * Parse an ArrayBuffer in a streaming fashion line by line (or custom delim). - * Can handle very large files. - * - * @param content The array buffer. - * @param callback The callback called on each line. - * @param chunkSize The size of each read chunk, defaults to ~1MB. (optional) - * @param delim The delimiter used to split a line, defaults to '\n'. (optional) - * @returns A promise for when it is finished. - */ - function streamParse( - content: ArrayBuffer, - callback: (line: string) => void, - chunkSize = 1000000, - delim = '\n' - ): Promise { - return new Promise((resolve, reject) => { - let offset = 0; - let bufferSize = content.byteLength - 1; - let data = ''; - - function readHandler(str) { - offset += chunkSize; - let parts = str.split(delim); - let first = data + parts[0]; - if (parts.length === 1) { - data = first; - readChunk(offset, chunkSize); - return; - } - data = parts[parts.length - 1]; - callback(first); - for (let i = 1; i < parts.length - 1; i++) { - callback(parts[i]); - } - if (offset >= bufferSize) { - if (data) { - callback(data); - } - resolve(); - return; - } - readChunk(offset, chunkSize); + data.push(dataPoint); + if (numDim == null) { + numDim = dataPoint.vector.length; } - - function readChunk(offset: number, size: number) { - const contentChunk = content.slice(offset, offset + size); - - const blob = new Blob([contentChunk]); - const file = new FileReader(); - file.onload = (e: any) => readHandler(e.target.result); - file.readAsText(blob); + if (numDim !== dataPoint.vector.length) { + logging.setModalMessage( + 'Parsing failed. Vector dimensions do not match' + ); + throw Error('Parsing failed'); } - - readChunk(offset, chunkSize); + if (numDim <= 1) { + logging.setModalMessage( + 'Parsing failed. Found a vector with only one dimension?' + ); + throw Error('Parsing failed'); + } + }).then(() => { + logging.setModalMessage(null, TENSORS_MSG_ID); + resolve(data); }); - } - - /** Parses a tsv text file. */ - export function parseTensors( - content: ArrayBuffer, - valueDelim = '\t' - ): Promise { - logging.setModalMessage('Parsing tensors...', TENSORS_MSG_ID); - - return new Promise((resolve, reject) => { - const data: DataPoint[] = []; - let numDim: number; - - streamParse(content, (line: string) => { - line = line.trim(); - if (line === '') { - return; + }); +} +/** Parses a tsv text file. */ +export function parseTensorsFromFloat32Array( + data: Float32Array, + dim: number +): Promise { + return util + .runAsyncTask( + 'Parsing tensors...', + () => { + const N = data.length / dim; + const dataPoints: DataPoint[] = []; + let offset = 0; + for (let i = 0; i < N; ++i) { + dataPoints.push({ + metadata: {}, + vector: data.subarray(offset, offset + dim), + index: i, + projections: null, + }); + offset += dim; } - const row = line.split(valueDelim); - const dataPoint: DataPoint = { - metadata: {}, - vector: null, - index: data.length, - projections: null, - }; - // If the first label is not a number, take it as the label. - if (isNaN(row[0] as any) || numDim === row.length - 1) { - dataPoint.metadata['label'] = row[0]; - dataPoint.vector = new Float32Array(row.slice(1).map(Number)); + return dataPoints; + }, + TENSORS_MSG_ID + ) + .then((dataPoints) => { + logging.setModalMessage(null, TENSORS_MSG_ID); + return dataPoints; + }); +} +export function analyzeMetadata( + columnNames, + pointsMetadata: PointMetadata[] +): ColumnStats[] { + const columnStats: ColumnStats[] = columnNames.map((name) => { + return { + name: name, + isNumeric: true, + tooManyUniqueValues: false, + min: Number.POSITIVE_INFINITY, + max: Number.NEGATIVE_INFINITY, + }; + }); + const mapOfValues: [ + { + [value: string]: number; + } + ] = columnNames.map(() => new Object()); + pointsMetadata.forEach((metadata) => { + columnNames.forEach((name: string, colIndex: number) => { + const stats = columnStats[colIndex]; + const map = mapOfValues[colIndex]; + const value = metadata[name]; + // Skip missing values. + if (value == null) { + return; + } + if (!stats.tooManyUniqueValues) { + if (value in map) { + map[value]++; } else { - dataPoint.vector = new Float32Array(row.map(Number)); - } - data.push(dataPoint); - if (numDim == null) { - numDim = dataPoint.vector.length; - } - if (numDim !== dataPoint.vector.length) { - logging.setModalMessage( - 'Parsing failed. Vector dimensions do not match' - ); - throw Error('Parsing failed'); + map[value] = 1; } - if (numDim <= 1) { - logging.setModalMessage( - 'Parsing failed. Found a vector with only one dimension?' - ); - throw Error('Parsing failed'); + if (Object.keys(map).length > NUM_COLORS_COLOR_MAP) { + stats.tooManyUniqueValues = true; } - }).then(() => { - logging.setModalMessage(null, TENSORS_MSG_ID); - resolve(data); - }); + } + if (isNaN(value as any)) { + stats.isNumeric = false; + } else { + metadata[name] = +value; + stats.min = Math.min(stats.min, +value); + stats.max = Math.max(stats.max, +value); + } }); - } - - /** Parses a tsv text file. */ - export function parseTensorsFromFloat32Array( - data: Float32Array, - dim: number - ): Promise { - return util - .runAsyncTask( - 'Parsing tensors...', - () => { - const N = data.length / dim; - const dataPoints: DataPoint[] = []; - let offset = 0; - for (let i = 0; i < N; ++i) { - dataPoints.push({ - metadata: {}, - vector: data.subarray(offset, offset + dim), - index: i, - projections: null, - }); - offset += dim; - } - return dataPoints; - }, - TENSORS_MSG_ID - ) - .then((dataPoints) => { - logging.setModalMessage(null, TENSORS_MSG_ID); - return dataPoints; - }); - } - - export function analyzeMetadata( - columnNames, - pointsMetadata: PointMetadata[] - ): ColumnStats[] { - const columnStats: ColumnStats[] = columnNames.map((name) => { - return { - name: name, - isNumeric: true, - tooManyUniqueValues: false, - min: Number.POSITIVE_INFINITY, - max: Number.NEGATIVE_INFINITY, - }; + }); + columnStats.forEach((stats, colIndex) => { + stats.uniqueEntries = Object.keys(mapOfValues[colIndex]).map((label) => { + return {label, count: mapOfValues[colIndex][label]}; }); - - const mapOfValues: [{[value: string]: number}] = columnNames.map( - () => new Object() - ); - - pointsMetadata.forEach((metadata) => { - columnNames.forEach((name: string, colIndex: number) => { - const stats = columnStats[colIndex]; - const map = mapOfValues[colIndex]; - const value = metadata[name]; - - // Skip missing values. - if (value == null) { + }); + return columnStats; +} +export function parseMetadata( + content: ArrayBuffer +): Promise { + logging.setModalMessage('Parsing metadata...', METADATA_MSG_ID); + return new Promise((resolve, reject) => { + let pointsMetadata: PointMetadata[] = []; + let hasHeader = false; + let lineNumber = 0; + let columnNames = ['label']; + streamParse(content, (line: string) => { + if (line.trim().length === 0) { + return; + } + if (lineNumber === 0) { + hasHeader = line.indexOf('\t') >= 0; + // If the first row doesn't contain metadata keys, we assume that the + // values are labels. + if (hasHeader) { + columnNames = line.split('\t'); + lineNumber++; return; } - - if (!stats.tooManyUniqueValues) { - if (value in map) { - map[value]++; - } else { - map[value] = 1; - } - if (Object.keys(map).length > NUM_COLORS_COLOR_MAP) { - stats.tooManyUniqueValues = true; - } - } - if (isNaN(value as any)) { - stats.isNumeric = false; - } else { - metadata[name] = +value; - stats.min = Math.min(stats.min, +value); - stats.max = Math.max(stats.max, +value); - } + } + lineNumber++; + let rowValues = line.split('\t'); + let metadata: PointMetadata = {}; + pointsMetadata.push(metadata); + columnNames.forEach((name: string, colIndex: number) => { + let value = rowValues[colIndex]; + // Normalize missing values. + value = value === '' ? null : value; + metadata[name] = value; }); - }); - columnStats.forEach((stats, colIndex) => { - stats.uniqueEntries = Object.keys(mapOfValues[colIndex]).map((label) => { - return {label, count: mapOfValues[colIndex][label]}; + }).then(() => { + logging.setModalMessage(null, METADATA_MSG_ID); + resolve({ + stats: analyzeMetadata(columnNames, pointsMetadata), + pointsInfo: pointsMetadata, }); }); - return columnStats; - } - - export function parseMetadata( - content: ArrayBuffer - ): Promise { - logging.setModalMessage('Parsing metadata...', METADATA_MSG_ID); - - return new Promise((resolve, reject) => { - let pointsMetadata: PointMetadata[] = []; - let hasHeader = false; - let lineNumber = 0; - let columnNames = ['label']; - streamParse(content, (line: string) => { - if (line.trim().length === 0) { - return; - } - if (lineNumber === 0) { - hasHeader = line.indexOf('\t') >= 0; - - // If the first row doesn't contain metadata keys, we assume that the - // values are labels. - if (hasHeader) { - columnNames = line.split('\t'); - lineNumber++; - return; + }); +} +export function fetchImage(url: string): Promise { + return new Promise((resolve, reject) => { + let image = new Image(); + image.onload = () => resolve(image); + image.onerror = (err) => reject(err); + image.crossOrigin = ''; + image.src = url; + }); +} +export function retrieveSpriteAndMetadataInfo( + metadataPath: string, + spriteImagePath: string, + spriteMetadata: SpriteMetadata, + callback: (r: SpriteAndMetadataInfo) => void +) { + let metadataPromise: Promise = Promise.resolve({}); + if (metadataPath) { + metadataPromise = new Promise((resolve, reject) => { + logging.setModalMessage('Fetching metadata...', METADATA_MSG_ID); + const request = new XMLHttpRequest(); + request.open('GET', metadataPath); + request.responseType = 'arraybuffer'; + request.onreadystatechange = () => { + if (request.readyState === 4) { + if (request.status === 200) { + // The metadata was successfully retrieved. Parse it. + resolve(parseMetadata(request.response)); + } else { + // The response contains the error message, but we must convert it + // to a string. + const errorReader = new FileReader(); + errorReader.onload = () => { + logging.setErrorMessage( + errorReader.result as string, + 'fetching metadata' + ); + reject(); + }; + errorReader.readAsText(new Blob([request.response])); } } - - lineNumber++; - - let rowValues = line.split('\t'); - let metadata: PointMetadata = {}; - pointsMetadata.push(metadata); - columnNames.forEach((name: string, colIndex: number) => { - let value = rowValues[colIndex]; - // Normalize missing values. - value = value === '' ? null : value; - metadata[name] = value; - }); - }).then(() => { - logging.setModalMessage(null, METADATA_MSG_ID); - resolve({ - stats: analyzeMetadata(columnNames, pointsMetadata), - pointsInfo: pointsMetadata, - }); - }); + }; + request.send(null); }); } - - export function fetchImage(url: string): Promise { - return new Promise((resolve, reject) => { - let image = new Image(); - image.onload = () => resolve(image); - image.onerror = (err) => reject(err); - image.crossOrigin = ''; - image.src = url; - }); + let spriteMsgId = null; + let spritesPromise: Promise = null; + if (spriteImagePath) { + spriteMsgId = logging.setModalMessage('Fetching sprite image...'); + spritesPromise = fetchImage(spriteImagePath); } - - export function retrieveSpriteAndMetadataInfo( - metadataPath: string, - spriteImagePath: string, - spriteMetadata: SpriteMetadata, - callback: (r: SpriteAndMetadataInfo) => void - ) { - let metadataPromise: Promise = Promise.resolve({}); - if (metadataPath) { - metadataPromise = new Promise( - (resolve, reject) => { - logging.setModalMessage('Fetching metadata...', METADATA_MSG_ID); - - const request = new XMLHttpRequest(); - request.open('GET', metadataPath); - request.responseType = 'arraybuffer'; - - request.onreadystatechange = () => { - if (request.readyState === 4) { - if (request.status === 200) { - // The metadata was successfully retrieved. Parse it. - resolve(parseMetadata(request.response)); - } else { - // The response contains the error message, but we must convert it - // to a string. - const errorReader = new FileReader(); - errorReader.onload = () => { - logging.setErrorMessage( - errorReader.result, - 'fetching metadata' - ); - reject(); - }; - errorReader.readAsText(new Blob([request.response])); - } - } - }; - request.send(null); - } - ); + // Fetch the metadata and the image in parallel. + Promise.all([metadataPromise, spritesPromise]).then((values) => { + if (spriteMsgId) { + logging.setModalMessage(null, spriteMsgId); } - let spriteMsgId = null; - let spritesPromise: Promise = null; - if (spriteImagePath) { - spriteMsgId = logging.setModalMessage('Fetching sprite image...'); - spritesPromise = fetchImage(spriteImagePath); - } - - // Fetch the metadata and the image in parallel. - Promise.all([metadataPromise, spritesPromise]).then((values) => { - if (spriteMsgId) { - logging.setModalMessage(null, spriteMsgId); - } - const [metadata, spriteImage] = values; - - if ( - spriteImage && - (spriteImage.height > MAX_SPRITE_IMAGE_SIZE_PX || - spriteImage.width > MAX_SPRITE_IMAGE_SIZE_PX) - ) { - logging.setModalMessage( - `Error: Sprite image of dimensions ${spriteImage.width}px x ` + - `${spriteImage.height}px exceeds maximum dimensions ` + - `${MAX_SPRITE_IMAGE_SIZE_PX}px x ${MAX_SPRITE_IMAGE_SIZE_PX}px` - ); - } else { - metadata.spriteImage = spriteImage; - metadata.spriteMetadata = spriteMetadata; - try { - callback(metadata); - } catch (e) { - logging.setModalMessage(String(e)); - } + const [metadata, spriteImage] = values; + if ( + spriteImage && + (spriteImage.height > MAX_SPRITE_IMAGE_SIZE_PX || + spriteImage.width > MAX_SPRITE_IMAGE_SIZE_PX) + ) { + logging.setModalMessage( + `Error: Sprite image of dimensions ${spriteImage.width}px x ` + + `${spriteImage.height}px exceeds maximum dimensions ` + + `${MAX_SPRITE_IMAGE_SIZE_PX}px x ${MAX_SPRITE_IMAGE_SIZE_PX}px` + ); + } else { + metadata.spriteImage = spriteImage; + metadata.spriteMetadata = spriteMetadata; + try { + callback(metadata); + } catch (e) { + logging.setModalMessage(String(e)); } - }); - } -} // namespace vz_projector + } + }); +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/data.ts b/tensorboard/plugins/projector/polymer3/vz_projector/data.ts index c407c44566..9f7c7d9faa 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/data.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/data.ts @@ -12,761 +12,709 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - export type DistanceFunction = (a: vector.Vector, b: vector.Vector) => number; - export type ProjectionComponents3D = [string, string, string]; - - export interface PointMetadata { - [key: string]: number | string; - } - - export interface DataProto { - shape: [number, number]; - tensor: number[]; - metadata: { - columns: Array<{ - name: string; - stringValues: string[]; - numericValues: number[]; - }>; - sprite: {imageBase64: string; singleImageDim: [number, number]}; +import numeric from 'numericjs'; + +import {TSNE} from './bh_tsne'; +import {SpriteMetadata} from './data-provider'; +import {CameraDef} from './scatterPlot'; +import * as knn from './knn'; +import * as vector from './vector'; +import * as logging from './logging'; +import * as util from './util'; + +export type DistanceFunction = (a: vector.Vector, b: vector.Vector) => number; +export type ProjectionComponents3D = [string, string, string]; + +export interface PointMetadata { + [key: string]: number | string; +} + +export interface DataProto { + shape: [number, number]; + tensor: number[]; + metadata: { + columns: Array<{ + name: string; + stringValues: string[]; + numericValues: number[]; + }>; + sprite: { + imageBase64: string; + singleImageDim: [number, number]; }; + }; +} + +/** Statistics for a metadata column. */ +export interface ColumnStats { + name: string; + isNumeric: boolean; + tooManyUniqueValues: boolean; + uniqueEntries?: Array<{ + label: string; + count: number; + }>; + min: number; + max: number; +} +export interface SpriteAndMetadataInfo { + stats?: ColumnStats[]; + pointsInfo?: PointMetadata[]; + spriteImage?: HTMLImageElement; + spriteMetadata?: SpriteMetadata; +} + +/** A single collection of points which make up a sequence through space. */ +export interface Sequence { + /** Indices into the DataPoints array in the Data object. */ + pointIndices: number[]; +} +export interface DataPoint { + /** The point in the original space. */ + vector: Float32Array; + /* + * Metadata for each point. Each metadata is a set of key/value pairs + * where the value can be a string or a number. + */ + metadata: PointMetadata; + /** index of the sequence, used for highlighting on click */ + sequenceIndex?: number; + /** index in the original data source */ + index: number; + /** This is where the calculated projections space are cached */ + projections: { + [key: string]: number; + }; +} +const IS_FIREFOX = navigator.userAgent.toLowerCase().indexOf('firefox') >= 0; +/** Controls whether nearest neighbors computation is done on the GPU or CPU. */ +const KNN_GPU_ENABLED = util.hasWebGLSupport() && !IS_FIREFOX; +export const TSNE_SAMPLE_SIZE = 10000; +export const UMAP_SAMPLE_SIZE = 5000; +export const PCA_SAMPLE_SIZE = 50000; +/** Number of dimensions to sample when doing approximate PCA. */ +export const PCA_SAMPLE_DIM = 200; +/** Number of pca components to compute. */ +const NUM_PCA_COMPONENTS = 10; +/** Id of message box used for umap optimization progress bar. */ +const UMAP_MSG_ID = 'umap-optimization'; +/** + * Reserved metadata attributes used for sequence information + * NOTE: Use "__seq_next__" as "__next__" is deprecated. + */ +const SEQUENCE_METADATA_ATTRS = ['__next__', '__seq_next__']; +function getSequenceNextPointIndex( + pointMetadata: PointMetadata +): number | null { + let sequenceAttr = null; + for (let metadataAttr of SEQUENCE_METADATA_ATTRS) { + if (metadataAttr in pointMetadata && pointMetadata[metadataAttr] !== '') { + sequenceAttr = pointMetadata[metadataAttr]; + break; + } } - - /** Statistics for a metadata column. */ - export interface ColumnStats { - name: string; - isNumeric: boolean; - tooManyUniqueValues: boolean; - uniqueEntries?: Array<{label: string; count: number}>; - min: number; - max: number; - } - - export interface SpriteAndMetadataInfo { - stats?: ColumnStats[]; - pointsInfo?: PointMetadata[]; - spriteImage?: HTMLImageElement; - spriteMetadata?: SpriteMetadata; - } - - /** A single collection of points which make up a sequence through space. */ - export interface Sequence { - /** Indices into the DataPoints array in the Data object. */ - pointIndices: number[]; - } - - export interface DataPoint { - /** The point in the original space. */ - vector: Float32Array; - - /* - * Metadata for each point. Each metadata is a set of key/value pairs - * where the value can be a string or a number. - */ - metadata: PointMetadata; - - /** index of the sequence, used for highlighting on click */ - sequenceIndex?: number; - - /** index in the original data source */ - index: number; - - /** This is where the calculated projections space are cached */ - projections: {[key: string]: number}; + if (sequenceAttr == null) { + return null; } - - const IS_FIREFOX = navigator.userAgent.toLowerCase().indexOf('firefox') >= 0; - /** Controls whether nearest neighbors computation is done on the GPU or CPU. */ - const KNN_GPU_ENABLED = util.hasWebGLSupport() && !IS_FIREFOX; - - export const TSNE_SAMPLE_SIZE = 10000; - export const UMAP_SAMPLE_SIZE = 5000; - export const PCA_SAMPLE_SIZE = 50000; - /** Number of dimensions to sample when doing approximate PCA. */ - export const PCA_SAMPLE_DIM = 200; - /** Number of pca components to compute. */ - const NUM_PCA_COMPONENTS = 10; - - /** Id of message box used for umap optimization progress bar. */ - const UMAP_MSG_ID = 'umap-optimization'; - + return +sequenceAttr; +} +/** + * Dataset contains a DataPoints array that should be treated as immutable. This + * acts as a working subset of the original data, with cached properties + * from computationally expensive operations. Because creating a subset + * requires normalizing and shifting the vector space, we make a copy of the + * data so we can still always create new subsets based on the original data. + */ +export class DataSet { + points: DataPoint[]; + sequences: Sequence[]; + shuffledDataIndices: number[] = []; /** - * Reserved metadata attributes used for sequence information - * NOTE: Use "__seq_next__" as "__next__" is deprecated. + * This keeps a list of all current projections so you can easily test to see + * if it's been calculated already. */ - const SEQUENCE_METADATA_ATTRS = ['__next__', '__seq_next__']; - - function getSequenceNextPointIndex( - pointMetadata: PointMetadata - ): number | null { - let sequenceAttr = null; - for (let metadataAttr of SEQUENCE_METADATA_ATTRS) { - if (metadataAttr in pointMetadata && pointMetadata[metadataAttr] !== '') { - sequenceAttr = pointMetadata[metadataAttr]; - break; + projections: { + [projection: string]: boolean; + } = {}; + nearest: knn.NearestEntry[][]; + spriteAndMetadataInfo: SpriteAndMetadataInfo; + fracVariancesExplained: number[]; + tSNEIteration: number = 0; + tSNEShouldPause = false; + tSNEShouldStop = true; + superviseFactor: number; + superviseLabels: string[]; + superviseInput: string = ''; + dim: [number, number] = [0, 0]; + hasTSNERun: boolean = false; + private tsne: TSNE; + hasUmapRun = false; + private umap: UMAP; + /** Creates a new Dataset */ + constructor( + points: DataPoint[], + spriteAndMetadataInfo?: SpriteAndMetadataInfo + ) { + this.points = points; + this.shuffledDataIndices = util.shuffle(util.range(this.points.length)); + this.sequences = this.computeSequences(points); + this.dim = [this.points.length, this.points[0].vector.length]; + this.spriteAndMetadataInfo = spriteAndMetadataInfo; + } + private computeSequences(points: DataPoint[]) { + // Keep a list of indices seen so we don't compute sequences for a given + // point twice. + let indicesSeen = new Int8Array(points.length); + // Compute sequences. + let indexToSequence: { + [index: number]: Sequence; + } = {}; + let sequences: Sequence[] = []; + for (let i = 0; i < points.length; i++) { + if (indicesSeen[i]) { + continue; + } + indicesSeen[i] = 1; + // Ignore points without a sequence attribute. + let next = getSequenceNextPointIndex(points[i].metadata); + if (next == null) { + continue; + } + if (next in indexToSequence) { + let existingSequence = indexToSequence[next]; + // Pushing at the beginning of the array. + existingSequence.pointIndices.unshift(i); + indexToSequence[i] = existingSequence; + continue; + } + // The current point is pointing to a new/unseen sequence. + let newSequence: Sequence = {pointIndices: []}; + indexToSequence[i] = newSequence; + sequences.push(newSequence); + let currentIndex = i; + while (points[currentIndex]) { + newSequence.pointIndices.push(currentIndex); + let next = getSequenceNextPointIndex(points[currentIndex].metadata); + if (next != null) { + indicesSeen[next] = 1; + currentIndex = next; + } else { + currentIndex = -1; + } } } - if (sequenceAttr == null) { - return null; + return sequences; + } + projectionCanBeRendered(projection: ProjectionType): boolean { + if (projection !== 'tsne') { + return true; } - return +sequenceAttr; + return this.tSNEIteration > 0; } - /** - * Dataset contains a DataPoints array that should be treated as immutable. This - * acts as a working subset of the original data, with cached properties - * from computationally expensive operations. Because creating a subset - * requires normalizing and shifting the vector space, we make a copy of the - * data so we can still always create new subsets based on the original data. + * Returns a new subset dataset by copying out data. We make a copy because + * we have to modify the vectors by normalizing them. + * + * @param subset Array of indices of points that we want in the subset. + * + * @return A subset of the original dataset. */ - export class DataSet { - points: DataPoint[]; - sequences: Sequence[]; - - shuffledDataIndices: number[] = []; - - /** - * This keeps a list of all current projections so you can easily test to see - * if it's been calculated already. - */ - projections: {[projection: string]: boolean} = {}; - nearest: knn.NearestEntry[][]; - spriteAndMetadataInfo: SpriteAndMetadataInfo; - fracVariancesExplained: number[]; - - tSNEIteration: number = 0; - tSNEShouldPause = false; - tSNEShouldStop = true; - superviseFactor: number; - superviseLabels: string[]; - superviseInput: string = ''; - dim: [number, number] = [0, 0]; - hasTSNERun: boolean = false; - private tsne: TSNE; - - hasUmapRun = false; - private umap: UMAP; - - /** Creates a new Dataset */ - constructor( - points: DataPoint[], - spriteAndMetadataInfo?: SpriteAndMetadataInfo - ) { - this.points = points; - this.shuffledDataIndices = util.shuffle(util.range(this.points.length)); - this.sequences = this.computeSequences(points); - this.dim = [this.points.length, this.points[0].vector.length]; - this.spriteAndMetadataInfo = spriteAndMetadataInfo; - } - - private computeSequences(points: DataPoint[]) { - // Keep a list of indices seen so we don't compute sequences for a given - // point twice. - let indicesSeen = new Int8Array(points.length); - // Compute sequences. - let indexToSequence: {[index: number]: Sequence} = {}; - let sequences: Sequence[] = []; - for (let i = 0; i < points.length; i++) { - if (indicesSeen[i]) { - continue; - } - indicesSeen[i] = 1; - - // Ignore points without a sequence attribute. - let next = getSequenceNextPointIndex(points[i].metadata); - if (next == null) { - continue; - } - if (next in indexToSequence) { - let existingSequence = indexToSequence[next]; - // Pushing at the beginning of the array. - existingSequence.pointIndices.unshift(i); - indexToSequence[i] = existingSequence; - continue; - } - // The current point is pointing to a new/unseen sequence. - let newSequence: Sequence = {pointIndices: []}; - indexToSequence[i] = newSequence; - sequences.push(newSequence); - let currentIndex = i; - while (points[currentIndex]) { - newSequence.pointIndices.push(currentIndex); - let next = getSequenceNextPointIndex(points[currentIndex].metadata); - if (next != null) { - indicesSeen[next] = 1; - currentIndex = next; - } else { - currentIndex = -1; - } - } - } - return sequences; + getSubset(subset?: number[]): DataSet { + const pointsSubset = + subset != null && subset.length > 0 + ? subset.map((i) => this.points[i]) + : this.points; + let points = pointsSubset.map((dp) => { + return { + metadata: dp.metadata, + index: dp.index, + vector: dp.vector.slice(), + projections: {} as { + [key: string]: number; + }, + }; + }); + return new DataSet(points, this.spriteAndMetadataInfo); + } + /** + * Computes the centroid, shifts all points to that centroid, + * then makes them all unit norm. + */ + normalize() { + // Compute the centroid of all data points. + let centroid = vector.centroid(this.points, (a) => a.vector); + if (centroid == null) { + throw Error('centroid should not be null'); } - - projectionCanBeRendered(projection: ProjectionType): boolean { - if (projection !== 'tsne') { - return true; + // Shift all points by the centroid and make them unit norm. + for (let id = 0; id < this.points.length; ++id) { + let dataPoint = this.points[id]; + dataPoint.vector = vector.sub(dataPoint.vector, centroid); + if (vector.norm2(dataPoint.vector) > 0) { + // If we take the unit norm of a vector of all 0s, we get a vector of + // all NaNs. We prevent that with a guard. + vector.unit(dataPoint.vector); } - return this.tSNEIteration > 0; } - - /** - * Returns a new subset dataset by copying out data. We make a copy because - * we have to modify the vectors by normalizing them. - * - * @param subset Array of indices of points that we want in the subset. - * - * @return A subset of the original dataset. - */ - getSubset(subset?: number[]): DataSet { - const pointsSubset = - subset != null && subset.length > 0 - ? subset.map((i) => this.points[i]) - : this.points; - let points = pointsSubset.map((dp) => { - return { - metadata: dp.metadata, - index: dp.index, - vector: dp.vector.slice(), - projections: {} as {[key: string]: number}, - }; - }); - return new DataSet(points, this.spriteAndMetadataInfo); + } + /** Projects the dataset onto a given vector and caches the result. */ + projectLinear(dir: vector.Vector, label: string) { + this.projections[label] = true; + this.points.forEach((dataPoint) => { + dataPoint.projections[label] = vector.dot(dataPoint.vector, dir); + }); + } + /** Projects the dataset along the top 10 principal components. */ + projectPCA(): Promise { + if (this.projections['pca-0'] != null) { + return Promise.resolve(null); } - - /** - * Computes the centroid, shifts all points to that centroid, - * then makes them all unit norm. - */ - normalize() { - // Compute the centroid of all data points. - let centroid = vector.centroid(this.points, (a) => a.vector); - if (centroid == null) { - throw Error('centroid should not be null'); + return util.runAsyncTask('Computing PCA...', () => { + // Approximate pca vectors by sampling the dimensions. + let dim = this.points[0].vector.length; + let vectors = this.shuffledDataIndices.map((i) => this.points[i].vector); + if (dim > PCA_SAMPLE_DIM) { + vectors = vector.projectRandom(vectors, PCA_SAMPLE_DIM); } - // Shift all points by the centroid and make them unit norm. - for (let id = 0; id < this.points.length; ++id) { - let dataPoint = this.points[id]; - dataPoint.vector = vector.sub(dataPoint.vector, centroid); - if (vector.norm2(dataPoint.vector) > 0) { - // If we take the unit norm of a vector of all 0s, we get a vector of - // all NaNs. We prevent that with a guard. - vector.unit(dataPoint.vector); - } + const sampledVectors = vectors.slice(0, PCA_SAMPLE_SIZE); + const {dot, transpose, svd: numericSvd} = numeric; + // numeric dynamically generates `numeric.div` and Closure compiler has + // incorrectly compiles `numeric.div` property accessor. We use below + // signature to prevent Closure from mangling and guessing. + const div = numeric['div']; + const scalar = dot(transpose(sampledVectors), sampledVectors); + const sigma = div(scalar, sampledVectors.length); + const svd = numericSvd(sigma); + const variances: number[] = svd.S; + let totalVariance = 0; + for (let i = 0; i < variances.length; ++i) { + totalVariance += variances[i]; } - } - - /** Projects the dataset onto a given vector and caches the result. */ - projectLinear(dir: vector.Vector, label: string) { - this.projections[label] = true; - this.points.forEach((dataPoint) => { - dataPoint.projections[label] = vector.dot(dataPoint.vector, dir); - }); - } - - /** Projects the dataset along the top 10 principal components. */ - projectPCA(): Promise { - if (this.projections['pca-0'] != null) { - return Promise.resolve(null); + for (let i = 0; i < variances.length; ++i) { + variances[i] /= totalVariance; } - return util.runAsyncTask('Computing PCA...', () => { - // Approximate pca vectors by sampling the dimensions. - let dim = this.points[0].vector.length; - let vectors = this.shuffledDataIndices.map( - (i) => this.points[i].vector - ); - if (dim > PCA_SAMPLE_DIM) { - vectors = vector.projectRandom(vectors, PCA_SAMPLE_DIM); - } - const sampledVectors = vectors.slice(0, PCA_SAMPLE_SIZE); - const {dot, transpose, svd: numericSvd} = numeric; - // numeric dynamically generates `numeric.div` and Closure compiler has - // incorrectly compiles `numeric.div` property accessor. We use below - // signature to prevent Closure from mangling and guessing. - const div = numeric['div']; - - const scalar = dot(transpose(sampledVectors), sampledVectors); - const sigma = div(scalar, sampledVectors.length); - const svd = numericSvd(sigma); - - const variances: number[] = svd.S; - let totalVariance = 0; - for (let i = 0; i < variances.length; ++i) { - totalVariance += variances[i]; - } - for (let i = 0; i < variances.length; ++i) { - variances[i] /= totalVariance; - } - this.fracVariancesExplained = variances; - let U: number[][] = svd.U; - let pcaVectors = vectors.map((vector) => { - let newV = new Float32Array(NUM_PCA_COMPONENTS); - for (let newDim = 0; newDim < NUM_PCA_COMPONENTS; newDim++) { - let dot = 0; - for (let oldDim = 0; oldDim < vector.length; oldDim++) { - dot += vector[oldDim] * U[oldDim][newDim]; - } - newV[newDim] = dot; - } - return newV; - }); - for (let d = 0; d < NUM_PCA_COMPONENTS; d++) { - let label = 'pca-' + d; - this.projections[label] = true; - for (let i = 0; i < pcaVectors.length; i++) { - let pointIndex = this.shuffledDataIndices[i]; - this.points[pointIndex].projections[label] = pcaVectors[i][d]; + this.fracVariancesExplained = variances; + let U: number[][] = svd.U; + let pcaVectors = vectors.map((vector) => { + let newV = new Float32Array(NUM_PCA_COMPONENTS); + for (let newDim = 0; newDim < NUM_PCA_COMPONENTS; newDim++) { + let dot = 0; + for (let oldDim = 0; oldDim < vector.length; oldDim++) { + dot += vector[oldDim] * U[oldDim][newDim]; } + newV[newDim] = dot; } + return newV; }); - } - - /** Runs tsne on the data. */ - projectTSNE( - perplexity: number, - learningRate: number, - tsneDim: number, - stepCallback: (iter: number) => void - ) { - this.hasTSNERun = true; - let k = Math.floor(3 * perplexity); - let opt = {epsilon: learningRate, perplexity: perplexity, dim: tsneDim}; - this.tsne = new TSNE(opt); - this.tsne.setSupervision(this.superviseLabels, this.superviseInput); - this.tsne.setSuperviseFactor(this.superviseFactor); - this.tSNEShouldPause = false; - this.tSNEShouldStop = false; - this.tSNEIteration = 0; - - let sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE); - let step = () => { - if (this.tSNEShouldStop) { - this.projections['tsne'] = false; - stepCallback(null); - this.tsne = null; - this.hasTSNERun = false; - return; - } - - if (!this.tSNEShouldPause) { - this.tsne.step(); - let result = this.tsne.getSolution(); - sampledIndices.forEach((index, i) => { - let dataPoint = this.points[index]; - - dataPoint.projections['tsne-0'] = result[i * tsneDim + 0]; - dataPoint.projections['tsne-1'] = result[i * tsneDim + 1]; - if (tsneDim === 3) { - dataPoint.projections['tsne-2'] = result[i * tsneDim + 2]; - } - }); - this.projections['tsne'] = true; - this.tSNEIteration++; - stepCallback(this.tSNEIteration); + for (let d = 0; d < NUM_PCA_COMPONENTS; d++) { + let label = 'pca-' + d; + this.projections[label] = true; + for (let i = 0; i < pcaVectors.length; i++) { + let pointIndex = this.shuffledDataIndices[i]; + this.points[pointIndex].projections[label] = pcaVectors[i][d]; } - requestAnimationFrame(step); - }; - - const sampledData = sampledIndices.map((i) => this.points[i]); - const knnComputation = this.computeKnn(sampledData, k); - - knnComputation.then((nearest) => { - util - .runAsyncTask('Initializing T-SNE...', () => { - this.tsne.initDataDist(nearest); - }) - .then(step); - }); - } - - /** Runs UMAP on the data. */ - async projectUmap( - nComponents: number, - nNeighbors: number, - stepCallback: (iter: number) => void - ) { - this.hasUmapRun = true; - this.umap = new UMAP({nComponents, nNeighbors}); - - let currentEpoch = 0; - const epochStepSize = 10; - const sampledIndices = this.shuffledDataIndices.slice( - 0, - UMAP_SAMPLE_SIZE - ); - - const sampledData = sampledIndices.map((i) => this.points[i]); - // TODO: Switch to a Float32-based UMAP internal - const X = sampledData.map((x) => Array.from(x.vector)); - - const nearest = await this.computeKnn(sampledData, nNeighbors); - - const nEpochs = await util.runAsyncTask( - 'Initializing UMAP...', - () => { - const knnIndices = nearest.map((row) => - row.map((entry) => entry.index) - ); - const knnDistances = nearest.map((row) => - row.map((entry) => entry.dist) - ); - - // Initialize UMAP and return the number of epochs. - this.umap.setPrecomputedKNN(knnIndices, knnDistances); - return this.umap.initializeFit(X); - }, - UMAP_MSG_ID - ); - - // Now, iterate through all epoch batches of the UMAP optimization, updating - // the modal window with the progress rather than animating each step since - // the UMAP animation is not nearly as informative as t-SNE. - return new Promise((resolve, reject) => { - const step = () => { - // Compute a batch of epochs since we don't want to update the UI - // on every epoch. - const epochsBatch = Math.min(epochStepSize, nEpochs - currentEpoch); - for (let i = 0; i < epochsBatch; i++) { - currentEpoch = this.umap.step(); - } - const progressMsg = `Optimizing UMAP (epoch ${currentEpoch} of ${nEpochs})`; - - // Wrap the logic in a util.runAsyncTask in order to correctly update - // the modal with the progress of the optimization. - util - .runAsyncTask( - progressMsg, - () => { - if (currentEpoch < nEpochs) { - requestAnimationFrame(step); - } else { - const result = this.umap.getEmbedding(); - sampledIndices.forEach((index, i) => { - const dataPoint = this.points[index]; - - dataPoint.projections['umap-0'] = result[i][0]; - dataPoint.projections['umap-1'] = result[i][1]; - if (nComponents === 3) { - dataPoint.projections['umap-2'] = result[i][2]; - } - }); - this.projections['umap'] = true; - - logging.setModalMessage(null, UMAP_MSG_ID); - this.hasUmapRun = true; - stepCallback(currentEpoch); - resolve(); - } - }, - UMAP_MSG_ID, - 0 - ) - .catch((error) => { - logging.setModalMessage(null, UMAP_MSG_ID); - reject(error); - }); - }; - - requestAnimationFrame(step); - }); - } - - /** Computes KNN to provide to the UMAP and t-SNE algorithms. */ - private async computeKnn( - data: DataPoint[], - nNeighbors: number - ): Promise { - // Handle the case where we've previously found the nearest neighbors. - const previouslyComputedNNeighbors = - this.nearest && this.nearest.length ? this.nearest[0].length : 0; - if (this.nearest != null && previouslyComputedNNeighbors >= nNeighbors) { - return Promise.resolve( - this.nearest.map((neighbors) => neighbors.slice(0, nNeighbors)) - ); - } else { - const result = await (KNN_GPU_ENABLED - ? knn.findKNNGPUCosine(data, nNeighbors, (d) => d.vector) - : knn.findKNN( - data, - nNeighbors, - (d) => d.vector, - (a, b) => vector.cosDistNorm(a, b) - )); - this.nearest = result; - return Promise.resolve(result); } - } - - /* Perturb TSNE and update dataset point coordinates. */ - perturbTsne() { - if (this.hasTSNERun && this.tsne) { - this.tsne.perturb(); - let tsneDim = this.tsne.getDim(); + }); + } + /** Runs tsne on the data. */ + projectTSNE( + perplexity: number, + learningRate: number, + tsneDim: number, + stepCallback: (iter: number) => void + ) { + this.hasTSNERun = true; + let k = Math.floor(3 * perplexity); + let opt = {epsilon: learningRate, perplexity: perplexity, dim: tsneDim}; + this.tsne = new TSNE(opt); + this.tsne.setSupervision(this.superviseLabels, this.superviseInput); + this.tsne.setSuperviseFactor(this.superviseFactor); + this.tSNEShouldPause = false; + this.tSNEShouldStop = false; + this.tSNEIteration = 0; + let sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE); + let step = () => { + if (this.tSNEShouldStop) { + this.projections['tsne'] = false; + stepCallback(null); + this.tsne = null; + this.hasTSNERun = false; + return; + } + if (!this.tSNEShouldPause) { + this.tsne.step(); let result = this.tsne.getSolution(); - let sampledIndices = this.shuffledDataIndices.slice( - 0, - TSNE_SAMPLE_SIZE - ); - sampledIndices.forEach((index, i) => { let dataPoint = this.points[index]; - dataPoint.projections['tsne-0'] = result[i * tsneDim + 0]; dataPoint.projections['tsne-1'] = result[i * tsneDim + 1]; if (tsneDim === 3) { dataPoint.projections['tsne-2'] = result[i * tsneDim + 2]; } }); + this.projections['tsne'] = true; + this.tSNEIteration++; + stepCallback(this.tSNEIteration); } - } - - setSupervision(superviseColumn: string, superviseInput?: string) { - if (superviseColumn != null) { - this.superviseLabels = this.shuffledDataIndices - .slice(0, TSNE_SAMPLE_SIZE) - .map((index) => - this.points[index].metadata[superviseColumn] !== undefined - ? String(this.points[index].metadata[superviseColumn]) - : `Unknown #${index}` - ); - } - if (superviseInput != null) { - this.superviseInput = superviseInput; - } - if (this.tsne) { - this.tsne.setSupervision(this.superviseLabels, this.superviseInput); - } - } - - setSuperviseFactor(superviseFactor: number) { - if (superviseFactor != null) { - this.superviseFactor = superviseFactor; - if (this.tsne) { - this.tsne.setSuperviseFactor(superviseFactor); + requestAnimationFrame(step); + }; + const sampledData = sampledIndices.map((i) => this.points[i]); + const knnComputation = this.computeKnn(sampledData, k); + knnComputation.then((nearest) => { + util + .runAsyncTask('Initializing T-SNE...', () => { + this.tsne.initDataDist(nearest); + }) + .then(step); + }); + } + /** Runs UMAP on the data. */ + async projectUmap( + nComponents: number, + nNeighbors: number, + stepCallback: (iter: number) => void + ) { + this.hasUmapRun = true; + this.umap = new UMAP({nComponents, nNeighbors}); + let currentEpoch = 0; + const epochStepSize = 10; + const sampledIndices = this.shuffledDataIndices.slice(0, UMAP_SAMPLE_SIZE); + const sampledData = sampledIndices.map((i) => this.points[i]); + // TODO: Switch to a Float32-based UMAP internal + const X = sampledData.map((x) => Array.from(x.vector)); + const nearest = await this.computeKnn(sampledData, nNeighbors); + const nEpochs = await util.runAsyncTask( + 'Initializing UMAP...', + () => { + const knnIndices = nearest.map((row) => + row.map((entry) => entry.index) + ); + const knnDistances = nearest.map((row) => + row.map((entry) => entry.dist) + ); + // Initialize UMAP and return the number of epochs. + this.umap.setPrecomputedKNN(knnIndices, knnDistances); + return this.umap.initializeFit(X); + }, + UMAP_MSG_ID + ); + // Now, iterate through all epoch batches of the UMAP optimization, updating + // the modal window with the progress rather than animating each step since + // the UMAP animation is not nearly as informative as t-SNE. + return new Promise((resolve, reject) => { + const step = () => { + // Compute a batch of epochs since we don't want to update the UI + // on every epoch. + const epochsBatch = Math.min(epochStepSize, nEpochs - currentEpoch); + for (let i = 0; i < epochsBatch; i++) { + currentEpoch = this.umap.step(); } - } + const progressMsg = `Optimizing UMAP (epoch ${currentEpoch} of ${nEpochs})`; + // Wrap the logic in a util.runAsyncTask in order to correctly update + // the modal with the progress of the optimization. + util + .runAsyncTask( + progressMsg, + () => { + if (currentEpoch < nEpochs) { + requestAnimationFrame(step); + } else { + const result = this.umap.getEmbedding(); + sampledIndices.forEach((index, i) => { + const dataPoint = this.points[index]; + dataPoint.projections['umap-0'] = result[i][0]; + dataPoint.projections['umap-1'] = result[i][1]; + if (nComponents === 3) { + dataPoint.projections['umap-2'] = result[i][2]; + } + }); + this.projections['umap'] = true; + logging.setModalMessage(null, UMAP_MSG_ID); + this.hasUmapRun = true; + stepCallback(currentEpoch); + resolve(); + } + }, + UMAP_MSG_ID, + 0 + ) + .catch((error) => { + logging.setModalMessage(null, UMAP_MSG_ID); + reject(error); + }); + }; + requestAnimationFrame(step); + }); + } + /** Computes KNN to provide to the UMAP and t-SNE algorithms. */ + private async computeKnn( + data: DataPoint[], + nNeighbors: number + ): Promise { + // Handle the case where we've previously found the nearest neighbors. + const previouslyComputedNNeighbors = + this.nearest && this.nearest.length ? this.nearest[0].length : 0; + if (this.nearest != null && previouslyComputedNNeighbors >= nNeighbors) { + return Promise.resolve( + this.nearest.map((neighbors) => neighbors.slice(0, nNeighbors)) + ); + } else { + const result = await (KNN_GPU_ENABLED + ? knn.findKNNGPUCosine(data, nNeighbors, (d) => d.vector) + : knn.findKNN( + data, + nNeighbors, + (d) => d.vector, + (a, b) => vector.cosDistNorm(a, b) + )); + this.nearest = result; + return Promise.resolve(result); } - - /** - * Merges metadata to the dataset and returns whether it succeeded. - */ - mergeMetadata(metadata: SpriteAndMetadataInfo): boolean { - if (metadata.pointsInfo.length !== this.points.length) { - let errorMessage = - `Number of tensors (${this.points.length}) do not` + - ` match the number of lines in metadata` + - ` (${metadata.pointsInfo.length}).`; - - if ( - metadata.stats.length === 1 && - this.points.length + 1 === metadata.pointsInfo.length - ) { - // If there is only one column of metadata and the number of points is - // exactly one less than the number of metadata lines, this is due to an - // unnecessary header line in the metadata and we can show a meaningful - // error. - logging.setErrorMessage( - errorMessage + - ' Single column metadata should not have a header ' + - 'row.', - 'merging metadata' - ); - return false; - } else if ( - metadata.stats.length > 1 && - this.points.length - 1 === metadata.pointsInfo.length - ) { - // If there are multiple columns of metadata and the number of points is - // exactly one greater than the number of lines in the metadata, this - // means there is a missing metadata header. - logging.setErrorMessage( - errorMessage + - ' Multi-column metadata should have a header ' + - 'row with column labels.', - 'merging metadata' - ); - return false; + } + /* Perturb TSNE and update dataset point coordinates. */ + perturbTsne() { + if (this.hasTSNERun && this.tsne) { + this.tsne.perturb(); + let tsneDim = this.tsne.getDim(); + let result = this.tsne.getSolution(); + let sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE); + sampledIndices.forEach((index, i) => { + let dataPoint = this.points[index]; + dataPoint.projections['tsne-0'] = result[i * tsneDim + 0]; + dataPoint.projections['tsne-1'] = result[i * tsneDim + 1]; + if (tsneDim === 3) { + dataPoint.projections['tsne-2'] = result[i * tsneDim + 2]; } - - logging.setWarningMessage(errorMessage); - } - this.spriteAndMetadataInfo = metadata; - metadata.pointsInfo - .slice(0, this.points.length) - .forEach((m, i) => (this.points[i].metadata = m)); - return true; + }); } - - stopTSNE() { - this.tSNEShouldStop = true; + } + setSupervision(superviseColumn: string, superviseInput?: string) { + if (superviseColumn != null) { + this.superviseLabels = this.shuffledDataIndices + .slice(0, TSNE_SAMPLE_SIZE) + .map((index) => + this.points[index].metadata[superviseColumn] !== undefined + ? String(this.points[index].metadata[superviseColumn]) + : `Unknown #${index}` + ); } - - /** - * Finds the nearest neighbors of the query point using a - * user-specified distance metric. - */ - findNeighbors( - pointIndex: number, - distFunc: DistanceFunction, - numNN: number - ): knn.NearestEntry[] { - // Find the nearest neighbors of a particular point. - let neighbors = knn.findKNNofPoint( - this.points, - pointIndex, - numNN, - (d) => d.vector, - distFunc - ); - // TODO(@dsmilkov): Figure out why we slice. - let result = neighbors.slice(0, numNN); - return result; + if (superviseInput != null) { + this.superviseInput = superviseInput; } - - /** - * Search the dataset based on a metadata field. - */ - query(query: string, inRegexMode: boolean, fieldName: string): number[] { - let predicate = util.getSearchPredicate(query, inRegexMode, fieldName); - let matches: number[] = []; - this.points.forEach((point, id) => { - if (predicate(point)) { - matches.push(id); - } - }); - return matches; + if (this.tsne) { + this.tsne.setSupervision(this.superviseLabels, this.superviseInput); } } - - export type ProjectionType = 'tsne' | 'umap' | 'pca' | 'custom'; - - export class Projection { - constructor( - public projectionType: ProjectionType, - public projectionComponents: ProjectionComponents3D, - public dimensionality: number, - public dataSet: DataSet - ) {} + setSuperviseFactor(superviseFactor: number) { + if (superviseFactor != null) { + this.superviseFactor = superviseFactor; + if (this.tsne) { + this.tsne.setSuperviseFactor(superviseFactor); + } + } } - - export interface ColorOption { - name: string; - desc?: string; - map?: (value: string | number) => string; - /** List of items for the color map. Defined only for categorical map. */ - items?: {label: string; count: number}[]; - /** Threshold values and their colors. Defined for gradient color map. */ - thresholds?: {value: number; color: string}[]; - isSeparator?: boolean; - tooManyUniqueValues?: boolean; + /** + * Merges metadata to the dataset and returns whether it succeeded. + */ + mergeMetadata(metadata: SpriteAndMetadataInfo): boolean { + if (metadata.pointsInfo.length !== this.points.length) { + let errorMessage = + `Number of tensors (${this.points.length}) do not` + + ` match the number of lines in metadata` + + ` (${metadata.pointsInfo.length}).`; + if ( + metadata.stats.length === 1 && + this.points.length + 1 === metadata.pointsInfo.length + ) { + // If there is only one column of metadata and the number of points is + // exactly one less than the number of metadata lines, this is due to an + // unnecessary header line in the metadata and we can show a meaningful + // error. + logging.setErrorMessage( + errorMessage + + ' Single column metadata should not have a header ' + + 'row.', + 'merging metadata' + ); + return false; + } else if ( + metadata.stats.length > 1 && + this.points.length - 1 === metadata.pointsInfo.length + ) { + // If there are multiple columns of metadata and the number of points is + // exactly one greater than the number of lines in the metadata, this + // means there is a missing metadata header. + logging.setErrorMessage( + errorMessage + + ' Multi-column metadata should have a header ' + + 'row with column labels.', + 'merging metadata' + ); + return false; + } + logging.setWarningMessage(errorMessage); + } + this.spriteAndMetadataInfo = metadata; + metadata.pointsInfo + .slice(0, this.points.length) + .forEach((m, i) => (this.points[i].metadata = m)); + return true; + } + stopTSNE() { + this.tSNEShouldStop = true; } - /** - * An interface that holds all the data for serializing the current state of - * the world. + * Finds the nearest neighbors of the query point using a + * user-specified distance metric. */ - export class State { - /** A label identifying this state. */ - label: string = ''; - - /** Whether this State is selected in the bookmarks pane. */ - isSelected: boolean = false; - - /** The selected projection tab. */ - selectedProjection: ProjectionType; - - /** Dimensions of the DataSet. */ - dataSetDimensions: [number, number]; - - /** t-SNE parameters */ - tSNEIteration: number = 0; - tSNEPerplexity: number = 0; - tSNELearningRate: number = 0; - tSNEis3d: boolean = true; - - /** UMAP parameters */ - umapIs3d: boolean = true; - umapNeighbors: number = 15; - - /** PCA projection component dimensions */ - pcaComponentDimensions: number[] = []; - - /** Custom projection parameters */ - customSelectedSearchByMetadataOption: string; - customXLeftText: string; - customXLeftRegex: boolean; - customXRightText: string; - customXRightRegex: boolean; - customYUpText: string; - customYUpRegex: boolean; - customYDownText: string; - customYDownRegex: boolean; - - /** The computed projections of the tensors. */ - projections: Array<{[key: string]: number}> = []; - - /** Filtered dataset indices. */ - filteredPoints: number[]; - - /** The indices of selected points. */ - selectedPoints: number[] = []; - - /** Camera state (2d/3d, position, target, zoom, etc). */ - cameraDef: CameraDef; - - /** Color by option. */ - selectedColorOptionName: string; - forceCategoricalColoring: boolean; - - /** Label by option. */ - selectedLabelOption: string; + findNeighbors( + pointIndex: number, + distFunc: DistanceFunction, + numNN: number + ): knn.NearestEntry[] { + // Find the nearest neighbors of a particular point. + let neighbors = knn.findKNNofPoint( + this.points, + pointIndex, + numNN, + (d) => d.vector, + distFunc + ); + // TODO(@dsmilkov): Figure out why we slice. + let result = neighbors.slice(0, numNN); + return result; } - - export function getProjectionComponents( - projection: ProjectionType, - components: (number | string)[] - ): ProjectionComponents3D { - if (components.length > 3) { - throw new RangeError('components length must be <= 3'); - } - const projectionComponents: [string, string, string] = [null, null, null]; - const prefix = projection === 'custom' ? 'linear' : projection; - for (let i = 0; i < components.length; ++i) { - if (components[i] == null) { - continue; + /** + * Search the dataset based on a metadata field. + */ + query(query: string, inRegexMode: boolean, fieldName: string): number[] { + let predicate = util.getSearchPredicate(query, inRegexMode, fieldName); + let matches: number[] = []; + this.points.forEach((point, id) => { + if (predicate(point)) { + matches.push(id); } - projectionComponents[i] = prefix + '-' + components[i]; - } - return projectionComponents; + }); + return matches; } - - export function stateGetAccessorDimensions( - state: State - ): Array { - let dimensions: Array; - switch (state.selectedProjection) { - case 'pca': - dimensions = state.pcaComponentDimensions.slice(); - break; - case 'tsne': - dimensions = [0, 1]; - if (state.tSNEis3d) { - dimensions.push(2); - } - break; - case 'umap': - dimensions = [0, 1]; - if (state.umapIs3d) { - dimensions.push(2); - } - break; - case 'custom': - dimensions = ['x', 'y']; - break; - default: - throw new Error('Unexpected fallthrough'); +} +export type ProjectionType = 'tsne' | 'umap' | 'pca' | 'custom'; +export class Projection { + constructor( + public projectionType: ProjectionType, + public projectionComponents: ProjectionComponents3D, + public dimensionality: number, + public dataSet: DataSet + ) {} +} +export interface ColorOption { + name: string; + desc?: string; + map?: (value: string | number) => string; + /** List of items for the color map. Defined only for categorical map. */ + items?: { + label: string; + count: number; + }[]; + /** Threshold values and their colors. Defined for gradient color map. */ + thresholds?: { + value: number; + color: string; + }[]; + isSeparator?: boolean; + tooManyUniqueValues?: boolean; +} +/** + * An interface that holds all the data for serializing the current state of + * the world. + */ +export class State { + /** A label identifying this state. */ + label: string = ''; + /** Whether this State is selected in the bookmarks pane. */ + isSelected: boolean = false; + /** The selected projection tab. */ + selectedProjection: ProjectionType; + /** Dimensions of the DataSet. */ + dataSetDimensions: [number, number]; + /** t-SNE parameters */ + tSNEIteration: number = 0; + tSNEPerplexity: number = 0; + tSNELearningRate: number = 0; + tSNEis3d: boolean = true; + /** UMAP parameters */ + umapIs3d: boolean = true; + umapNeighbors: number = 15; + /** PCA projection component dimensions */ + pcaComponentDimensions: number[] = []; + /** Custom projection parameters */ + customSelectedSearchByMetadataOption: string; + customXLeftText: string; + customXLeftRegex: boolean; + customXRightText: string; + customXRightRegex: boolean; + customYUpText: string; + customYUpRegex: boolean; + customYDownText: string; + customYDownRegex: boolean; + /** The computed projections of the tensors. */ + projections: Array<{ + [key: string]: number; + }> = []; + /** Filtered dataset indices. */ + filteredPoints: number[]; + /** The indices of selected points. */ + selectedPoints: number[] = []; + /** Camera state (2d/3d, position, target, zoom, etc). */ + cameraDef: CameraDef; + /** Color by option. */ + selectedColorOptionName: string; + forceCategoricalColoring: boolean; + /** Label by option. */ + selectedLabelOption: string; +} +export function getProjectionComponents( + projection: ProjectionType, + components: (number | string)[] +): ProjectionComponents3D { + if (components.length > 3) { + throw new RangeError('components length must be <= 3'); + } + const projectionComponents: [string, string, string] = [null, null, null]; + const prefix = projection === 'custom' ? 'linear' : projection; + for (let i = 0; i < components.length; ++i) { + if (components[i] == null) { + continue; } - return dimensions; + projectionComponents[i] = prefix + '-' + components[i]; + } + return projectionComponents; +} +export function stateGetAccessorDimensions( + state: State +): Array { + let dimensions: Array; + switch (state.selectedProjection) { + case 'pca': + dimensions = state.pcaComponentDimensions.slice(); + break; + case 'tsne': + dimensions = [0, 1]; + if (state.tSNEis3d) { + dimensions.push(2); + } + break; + case 'umap': + dimensions = [0, 1]; + if (state.umapIs3d) { + dimensions.push(2); + } + break; + case 'custom': + dimensions = ['x', 'y']; + break; + default: + throw new Error('Unexpected fallthrough'); } -} // namespace vz_projector + return dimensions; +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/external.d.ts b/tensorboard/plugins/projector/polymer3/vz_projector/external.d.ts index 5cfc0b380a..f7d5d7989a 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/external.d.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/external.d.ts @@ -12,38 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -// TODO(@dsmilkov): Split into weblas.d.ts and numeric.d.ts and write -// typings for numeric. -interface Tensor { - new (size: [number, number], data: Float32Array); - transfer(): Float32Array; - delete(): void; -} - -interface Weblas { - sgemm( - M: number, - N: number, - K: number, - alpha: number, - A: Float32Array, - B: Float32Array, - beta: number, - C: Float32Array - ): Float32Array; - pipeline: { - Tensor: Tensor; - sgemm(alpha: number, A: Tensor, B: Tensor, beta: number, C: Tensor): Tensor; - }; - util: { - transpose(M: number, N: number, data: Float32Array): Tensor; - }; -} - -declare let numeric: any; -declare let weblas: Weblas; - interface AnalyticsEventType { hitType: string; page?: string; diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/heap.ts b/tensorboard/plugins/projector/polymer3/vz_projector/heap.ts index 04812b57b5..9347331e36 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/heap.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/heap.ts @@ -12,151 +12,130 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - /** Min key heap. */ - export type HeapItem = { - key: number; - value: T; - }; - /** - * Min-heap data structure. Provides O(1) for peek, returning the smallest key. - */ - // TODO(@jart): Rename to Heap and use Comparator. - export class MinHeap { - private arr: HeapItem[] = []; - - /** Push an element with the provided key. */ - push(key: number, value: T): void { - this.arr.push({key, value}); - this.bubbleUp(this.arr.length - 1); - } - - /** Pop the element with the smallest key. */ - pop(): HeapItem { - if (this.arr.length === 0) { - throw new Error('pop() called on empty binary heap'); - } - let item = this.arr[0]; - let last = this.arr.length - 1; - this.arr[0] = this.arr[last]; - this.arr.pop(); - if (last > 0) { - this.bubbleDown(0); - } - return item; - } - - /** Returns, but doesn't remove the element with the smallest key */ - peek(): HeapItem { - return this.arr[0]; +export type HeapItem = { + key: number; + value: T; +}; +/** + * Min-heap data structure. Provides O(1) for peek, returning the smallest key. + */ +// TODO(@jart): Rename to Heap and use Comparator. +export class MinHeap { + private arr: HeapItem[] = []; + /** Push an element with the provided key. */ + push(key: number, value: T): void { + this.arr.push({key, value}); + this.bubbleUp(this.arr.length - 1); + } + /** Pop the element with the smallest key. */ + pop(): HeapItem { + if (this.arr.length === 0) { + throw new Error('pop() called on empty binary heap'); } - - /** - * Pops the element with the smallest key and at the same time - * adds the newly provided element. This is faster than calling - * pop() and push() separately. - */ - popPush(key: number, value: T): HeapItem { - if (this.arr.length === 0) { - throw new Error('pop() called on empty binary heap'); - } - let item = this.arr[0]; - this.arr[0] = {key, value}; - if (this.arr.length > 0) { - this.bubbleDown(0); - } - return item; + let item = this.arr[0]; + let last = this.arr.length - 1; + this.arr[0] = this.arr[last]; + this.arr.pop(); + if (last > 0) { + this.bubbleDown(0); } - - /** Returns the number of elements in the heap. */ - size(): number { - return this.arr.length; + return item; + } + /** Returns, but doesn't remove the element with the smallest key */ + peek(): HeapItem { + return this.arr[0]; + } + /** + * Pops the element with the smallest key and at the same time + * adds the newly provided element. This is faster than calling + * pop() and push() separately. + */ + popPush(key: number, value: T): HeapItem { + if (this.arr.length === 0) { + throw new Error('pop() called on empty binary heap'); } - - /** Returns all the items in the heap. */ - items(): HeapItem[] { - return this.arr; + let item = this.arr[0]; + this.arr[0] = {key, value}; + if (this.arr.length > 0) { + this.bubbleDown(0); } - - private swap(a: number, b: number) { - let temp = this.arr[a]; - this.arr[a] = this.arr[b]; - this.arr[b] = temp; + return item; + } + /** Returns the number of elements in the heap. */ + size(): number { + return this.arr.length; + } + /** Returns all the items in the heap. */ + items(): HeapItem[] { + return this.arr; + } + private swap(a: number, b: number) { + let temp = this.arr[a]; + this.arr[a] = this.arr[b]; + this.arr[b] = temp; + } + private bubbleDown(pos: number) { + let left = (pos << 1) + 1; + let right = left + 1; + let largest = pos; + if (left < this.arr.length && this.arr[left].key < this.arr[largest].key) { + largest = left; } - - private bubbleDown(pos: number) { - let left = (pos << 1) + 1; - let right = left + 1; - let largest = pos; - if ( - left < this.arr.length && - this.arr[left].key < this.arr[largest].key - ) { - largest = left; - } - if ( - right < this.arr.length && - this.arr[right].key < this.arr[largest].key - ) { - largest = right; - } - if (largest !== pos) { - this.swap(largest, pos); - this.bubbleDown(largest); - } + if ( + right < this.arr.length && + this.arr[right].key < this.arr[largest].key + ) { + largest = right; } - - private bubbleUp(pos: number) { - if (pos <= 0) { - return; - } - let parent = (pos - 1) >> 1; - if (this.arr[pos].key < this.arr[parent].key) { - this.swap(pos, parent); - this.bubbleUp(parent); - } + if (largest !== pos) { + this.swap(largest, pos); + this.bubbleDown(largest); } } - - /** List that keeps the K elements with the smallest keys. */ - export class KMin { - private k: number; - private maxHeap = new MinHeap(); - - /** Constructs a new k-min data structure with the provided k. */ - constructor(k: number) { - this.k = k; + private bubbleUp(pos: number) { + if (pos <= 0) { + return; } - - /** Adds an element to the list. */ - add(key: number, value: T) { - if (this.maxHeap.size() < this.k) { - this.maxHeap.push(-key, value); - return; - } - let largest = this.maxHeap.peek(); - // If the new element is smaller, replace the largest with the new element. - if (key < -largest.key) { - this.maxHeap.popPush(-key, value); - } + let parent = (pos - 1) >> 1; + if (this.arr[pos].key < this.arr[parent].key) { + this.swap(pos, parent); + this.bubbleUp(parent); } - - /** Returns the k items with the smallest keys. */ - getMinKItems(): T[] { - let items = this.maxHeap.items(); - items.sort((a, b) => b.key - a.key); - return items.map((a) => a.value); - } - - /** Returns the size of the list. */ - getSize(): number { - return this.maxHeap.size(); + } +} +/** List that keeps the K elements with the smallest keys. */ +export class KMin { + private k: number; + private maxHeap = new MinHeap(); + /** Constructs a new k-min data structure with the provided k. */ + constructor(k: number) { + this.k = k; + } + /** Adds an element to the list. */ + add(key: number, value: T) { + if (this.maxHeap.size() < this.k) { + this.maxHeap.push(-key, value); + return; } - - /** Returns the largest key in the list. */ - getLargestKey(): number { - return this.maxHeap.size() === 0 ? null : -this.maxHeap.peek().key; + let largest = this.maxHeap.peek(); + // If the new element is smaller, replace the largest with the new element. + if (key < -largest.key) { + this.maxHeap.popPush(-key, value); } } -} // namespace vz_projector + /** Returns the k items with the smallest keys. */ + getMinKItems(): T[] { + let items = this.maxHeap.items(); + items.sort((a, b) => b.key - a.key); + return items.map((a) => a.value); + } + /** Returns the size of the list. */ + getSize(): number { + return this.maxHeap.size(); + } + /** Returns the largest key in the list. */ + getLargestKey(): number | null { + return this.maxHeap.size() === 0 ? null : -this.maxHeap.peek().key; + } +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/knn.ts b/tensorboard/plugins/projector/polymer3/vz_projector/knn.ts index 68a7893c92..9e6ffd318f 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/knn.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/knn.ts @@ -12,253 +12,249 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector.knn { - export type NearestEntry = { - index: number; - dist: number; - }; +import weblas from 'weblas/dist/weblas'; - /** - * Optimal size for the height of the matrix when doing computation on the GPU - * using WebGL. This was found experimentally. - * - * This also guarantees that for computing pair-wise distance for up to 10K - * vectors, no more than 40MB will be allocated in the GPU. Without the - * allocation limit, we can freeze the graphics of the whole OS. - */ - const OPTIMAL_GPU_BLOCK_SIZE = 256; - /** Id of message box used for knn gpu progress bar. */ - const KNN_GPU_MSG_ID = 'knn-gpu'; +import {KMin} from './heap'; +import * as vector from './vector'; +import * as logging from './logging'; +import * as util from './util'; - /** - * Returns the K nearest neighbors for each vector where the distance - * computation is done on the GPU (WebGL) using cosine distance. - * - * @param dataPoints List of data points, where each data point holds an - * n-dimensional vector. - * @param k Number of nearest neighbors to find. - * @param accessor A method that returns the vector, given the data point. - */ - export function findKNNGPUCosine( - dataPoints: T[], - k: number, - accessor: (dataPoint: T) => Float32Array - ): Promise { - let N = dataPoints.length; - let dim = accessor(dataPoints[0]).length; - - // The goal is to compute a large matrix multiplication A*A.T where A is of - // size NxD and A.T is its transpose. This results in a NxN matrix which - // could be too big to store on the GPU memory. To avoid memory overflow, we - // compute multiple A*partial_A.T where partial_A is of size BxD (B is much - // smaller than N). This results in storing only NxB size matrices on the GPU - // at a given time. - - // A*A.T will give us NxN matrix holding the cosine distance between every - // pair of points, which we sort using KMin data structure to obtain the - // K nearest neighbors for each point. - let typedArray = vector.toTypedArray(dataPoints, accessor); - let bigMatrix = new weblas.pipeline.Tensor([N, dim], typedArray); - let nearest: NearestEntry[][] = new Array(N); - let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE); - let M = Math.floor(N / numPieces); - let modulo = N % numPieces; - let offset = 0; - let progress = 0; - let progressDiff = 1 / (2 * numPieces); - let piece = 0; - - function step(resolve: (result: NearestEntry[][]) => void) { - let progressMsg = - 'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%'; - util - .runAsyncTask( - progressMsg, - () => { - let B = piece < modulo ? M + 1 : M; - let typedB = new Float32Array(B * dim); - for (let i = 0; i < B; ++i) { - let vector = accessor(dataPoints[offset + i]); - for (let d = 0; d < dim; ++d) { - typedB[i * dim + d] = vector[d]; - } +export type NearestEntry = { + index: number; + dist: number; +}; +/** + * Optimal size for the height of the matrix when doing computation on the GPU + * using WebGL. This was found experimentally. + * + * This also guarantees that for computing pair-wise distance for up to 10K + * vectors, no more than 40MB will be allocated in the GPU. Without the + * allocation limit, we can freeze the graphics of the whole OS. + */ +const OPTIMAL_GPU_BLOCK_SIZE = 256; +/** Id of message box used for knn gpu progress bar. */ +const KNN_GPU_MSG_ID = 'knn-gpu'; +/** + * Returns the K nearest neighbors for each vector where the distance + * computation is done on the GPU (WebGL) using cosine distance. + * + * @param dataPoints List of data points, where each data point holds an + * n-dimensional vector. + * @param k Number of nearest neighbors to find. + * @param accessor A method that returns the vector, given the data point. + */ +export function findKNNGPUCosine( + dataPoints: T[], + k: number, + accessor: (dataPoint: T) => Float32Array +): Promise { + let N = dataPoints.length; + let dim = accessor(dataPoints[0]).length; + // The goal is to compute a large matrix multiplication A*A.T where A is of + // size NxD and A.T is its transpose. This results in a NxN matrix which + // could be too big to store on the GPU memory. To avoid memory overflow, we + // compute multiple A*partial_A.T where partial_A is of size BxD (B is much + // smaller than N). This results in storing only NxB size matrices on the GPU + // at a given time. + // A*A.T will give us NxN matrix holding the cosine distance between every + // pair of points, which we sort using KMin data structure to obtain the + // K nearest neighbors for each point. + let typedArray = vector.toTypedArray(dataPoints, accessor); + let bigMatrix = new weblas.pipeline.Tensor([N, dim], typedArray); + let nearest: NearestEntry[][] = new Array(N); + let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE); + let M = Math.floor(N / numPieces); + let modulo = N % numPieces; + let offset = 0; + let progress = 0; + let progressDiff = 1 / (2 * numPieces); + let piece = 0; + function step(resolve: (result: NearestEntry[][]) => void) { + let progressMsg = + 'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%'; + util + .runAsyncTask( + progressMsg, + () => { + let B = piece < modulo ? M + 1 : M; + let typedB = new Float32Array(B * dim); + for (let i = 0; i < B; ++i) { + let vector = accessor(dataPoints[offset + i]); + for (let d = 0; d < dim; ++d) { + typedB[i * dim + d] = vector[d]; } - let partialMatrix = new weblas.pipeline.Tensor([B, dim], typedB); - // Result is N x B matrix. - let result = weblas.pipeline.sgemm( - 1, - bigMatrix, - partialMatrix, - null, - null - ); - let partial = result.transfer(); - partialMatrix.delete(); - result.delete(); - progress += progressDiff; - for (let i = 0; i < B; i++) { - let kMin = new KMin(k); - let iReal = offset + i; - for (let j = 0; j < N; j++) { - if (j === iReal) { - continue; - } - let cosDist = 1 - partial[j * B + i]; // [j, i]; - kMin.add(cosDist, {index: j, dist: cosDist}); + } + let partialMatrix = new weblas.pipeline.Tensor([B, dim], typedB); + // Result is N x B matrix. + let result = weblas.pipeline.sgemm( + 1, + bigMatrix, + partialMatrix, + null, + null + ); + let partial = result.transfer(); + partialMatrix.delete(); + result.delete(); + progress += progressDiff; + for (let i = 0; i < B; i++) { + let kMin = new KMin(k); + let iReal = offset + i; + for (let j = 0; j < N; j++) { + if (j === iReal) { + continue; } - nearest[iReal] = kMin.getMinKItems(); - } - progress += progressDiff; - offset += B; - piece++; - }, - KNN_GPU_MSG_ID - ) - .then( - () => { - if (piece < numPieces) { - step(resolve); - } else { - logging.setModalMessage(null, KNN_GPU_MSG_ID); - bigMatrix.delete(); - resolve(nearest); + let cosDist = 1 - partial[j * B + i]; // [j, i]; + kMin.add(cosDist, {index: j, dist: cosDist}); } - }, - (error) => { - // GPU failed. Reverting back to CPU. + nearest[iReal] = kMin.getMinKItems(); + } + progress += progressDiff; + offset += B; + piece++; + }, + KNN_GPU_MSG_ID + ) + .then( + () => { + if (piece < numPieces) { + step(resolve); + } else { logging.setModalMessage(null, KNN_GPU_MSG_ID); - let distFunc = (a, b, limit) => vector.cosDistNorm(a, b); - findKNN(dataPoints, k, accessor, distFunc).then((nearest) => { - resolve(nearest); - }); + bigMatrix.delete(); + resolve(nearest); } - ); - } - return new Promise((resolve) => step(resolve)); - } - - /** - * Returns the K nearest neighbors for each vector where the distance - * computation is done on the CPU using a user-specified distance method. - * - * @param dataPoints List of data points, where each data point holds an - * n-dimensional vector. - * @param k Number of nearest neighbors to find. - * @param accessor A method that returns the vector, given the data point. - * @param dist Method that takes two vectors and a limit, and computes the - * distance between two vectors, with the ability to stop early if the - * distance is above the limit. - */ - export function findKNN( - dataPoints: T[], - k: number, - accessor: (dataPoint: T) => Float32Array, - dist: (a: vector.Vector, b: vector.Vector, limit: number) => number - ): Promise { - return util.runAsyncTask( - 'Finding nearest neighbors...', - () => { - let N = dataPoints.length; - let nearest: NearestEntry[][] = new Array(N); - // Find the distances from node i. - let kMin: KMin[] = new Array(N); - for (let i = 0; i < N; i++) { - kMin[i] = new KMin(k); + }, + (error) => { + // GPU failed. Reverting back to CPU. + logging.setModalMessage(null, KNN_GPU_MSG_ID); + let distFunc = (a, b, limit) => vector.cosDistNorm(a, b); + findKNN(dataPoints, k, accessor, distFunc).then((nearest) => { + resolve(nearest); + }); } - for (let i = 0; i < N; i++) { - let a = accessor(dataPoints[i]); - let kMinA = kMin[i]; - for (let j = i + 1; j < N; j++) { - let kMinB = kMin[j]; - let limitI = - kMinA.getSize() === k - ? kMinA.getLargestKey() || Number.MAX_VALUE - : Number.MAX_VALUE; - let limitJ = - kMinB.getSize() === k - ? kMinB.getLargestKey() || Number.MAX_VALUE - : Number.MAX_VALUE; - let limit = Math.max(limitI, limitJ); - let dist2ItoJ = dist(a, accessor(dataPoints[j]), limit); - if (dist2ItoJ >= 0) { - kMinA.add(dist2ItoJ, {index: j, dist: dist2ItoJ}); - kMinB.add(dist2ItoJ, {index: i, dist: dist2ItoJ}); - } + ); + } + return new Promise((resolve) => step(resolve)); +} +/** + * Returns the K nearest neighbors for each vector where the distance + * computation is done on the CPU using a user-specified distance method. + * + * @param dataPoints List of data points, where each data point holds an + * n-dimensional vector. + * @param k Number of nearest neighbors to find. + * @param accessor A method that returns the vector, given the data point. + * @param dist Method that takes two vectors and a limit, and computes the + * distance between two vectors, with the ability to stop early if the + * distance is above the limit. + */ +export function findKNN( + dataPoints: T[], + k: number, + accessor: (dataPoint: T) => Float32Array, + dist: (a: vector.Vector, b: vector.Vector, limit: number) => number +): Promise { + return util.runAsyncTask( + 'Finding nearest neighbors...', + () => { + let N = dataPoints.length; + let nearest: NearestEntry[][] = new Array(N); + // Find the distances from node i. + let kMin: KMin[] = new Array(N); + for (let i = 0; i < N; i++) { + kMin[i] = new KMin(k); + } + for (let i = 0; i < N; i++) { + let a = accessor(dataPoints[i]); + let kMinA = kMin[i]; + for (let j = i + 1; j < N; j++) { + let kMinB = kMin[j]; + let limitI = + kMinA.getSize() === k + ? kMinA.getLargestKey() || Number.MAX_VALUE + : Number.MAX_VALUE; + let limitJ = + kMinB.getSize() === k + ? kMinB.getLargestKey() || Number.MAX_VALUE + : Number.MAX_VALUE; + let limit = Math.max(limitI, limitJ); + let dist2ItoJ = dist(a, accessor(dataPoints[j]), limit); + if (dist2ItoJ >= 0) { + kMinA.add(dist2ItoJ, {index: j, dist: dist2ItoJ}); + kMinB.add(dist2ItoJ, {index: i, dist: dist2ItoJ}); } } - for (let i = 0; i < N; i++) { - nearest[i] = kMin[i].getMinKItems(); - } - return nearest; } - ); - } - - /** Calculates the minimum distance between a search point and a rectangle. */ - function minDist( - point: [number, number], - x1: number, - y1: number, - x2: number, - y2: number - ) { - let x = point[0]; - let y = point[1]; - let dx1 = x - x1; - let dx2 = x - x2; - let dy1 = y - y1; - let dy2 = y - y2; - - if (dx1 * dx2 <= 0) { - // x is between x1 and x2 - if (dy1 * dy2 <= 0) { - // (x,y) is inside the rectangle - return 0; // return 0 as point is in rect + for (let i = 0; i < N; i++) { + nearest[i] = kMin[i].getMinKItems(); } - return Math.min(Math.abs(dy1), Math.abs(dy2)); + return nearest; } + ); +} +/** Calculates the minimum distance between a search point and a rectangle. */ +function minDist( + point: [number, number], + x1: number, + y1: number, + x2: number, + y2: number +) { + let x = point[0]; + let y = point[1]; + let dx1 = x - x1; + let dx2 = x - x2; + let dy1 = y - y1; + let dy2 = y - y2; + if (dx1 * dx2 <= 0) { + // x is between x1 and x2 if (dy1 * dy2 <= 0) { - // y is between y1 and y2 - // We know it is already inside the rectangle - return Math.min(Math.abs(dx1), Math.abs(dx2)); + // (x,y) is inside the rectangle + return 0; // return 0 as point is in rect } - let corner: [number, number]; - if (x > x2) { - // Upper-right vs lower-right. - corner = y > y2 ? [x2, y2] : [x2, y1]; - } else { - // Upper-left vs lower-left. - corner = y > y2 ? [x1, y2] : [x1, y1]; - } - return Math.sqrt(vector.dist22D([x, y], corner)); + return Math.min(Math.abs(dy1), Math.abs(dy2)); } - - /** - * Returns the nearest neighbors of a particular point. - * - * @param dataPoints List of data points. - * @param pointIndex The index of the point we need the nearest neighbors of. - * @param k Number of nearest neighbors to search for. - * @param accessor Method that maps a data point => vector (array of numbers). - * @param distance Method that takes two vectors and returns their distance. - */ - export function findKNNofPoint( - dataPoints: T[], - pointIndex: number, - k: number, - accessor: (dataPoint: T) => Float32Array, - distance: (a: vector.Vector, b: vector.Vector) => number - ) { - let kMin = new KMin(k); - let a = accessor(dataPoints[pointIndex]); - for (let i = 0; i < dataPoints.length; ++i) { - if (i === pointIndex) { - continue; - } - let b = accessor(dataPoints[i]); - let dist = distance(a, b); - kMin.add(dist, {index: i, dist: dist}); + if (dy1 * dy2 <= 0) { + // y is between y1 and y2 + // We know it is already inside the rectangle + return Math.min(Math.abs(dx1), Math.abs(dx2)); + } + let corner: [number, number]; + if (x > x2) { + // Upper-right vs lower-right. + corner = y > y2 ? [x2, y2] : [x2, y1]; + } else { + // Upper-left vs lower-left. + corner = y > y2 ? [x1, y2] : [x1, y1]; + } + return Math.sqrt(vector.dist22D([x, y], corner)); +} +/** + * Returns the nearest neighbors of a particular point. + * + * @param dataPoints List of data points. + * @param pointIndex The index of the point we need the nearest neighbors of. + * @param k Number of nearest neighbors to search for. + * @param accessor Method that maps a data point => vector (array of numbers). + * @param distance Method that takes two vectors and returns their distance. + */ +export function findKNNofPoint( + dataPoints: T[], + pointIndex: number, + k: number, + accessor: (dataPoint: T) => Float32Array, + distance: (a: vector.Vector, b: vector.Vector) => number +) { + let kMin = new KMin(k); + let a = accessor(dataPoints[pointIndex]); + for (let i = 0; i < dataPoints.length; ++i) { + if (i === pointIndex) { + continue; } - return kMin.getMinKItems(); + let b = accessor(dataPoints[i]); + let dist = distance(a, b); + kMin.add(dist, {index: i, dist: dist}); } -} // namespace vz_projector.knn + return kMin.getMinKItems(); +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/label.ts b/tensorboard/plugins/projector/polymer3/vz_projector/label.ts index 707f25c916..9598cec2e1 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/label.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/label.ts @@ -12,154 +12,131 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - export interface BoundingBox { - loX: number; - loY: number; - hiX: number; - hiY: number; - } +export interface BoundingBox { + loX: number; + loY: number; + hiX: number; + hiY: number; +} +/** + * Accelerates label placement by dividing the view into a uniform grid. + * Labels only need to be tested for collision with other labels that overlap + * the same grid cells. This is a fork of {@code amoeba.CollisionGrid}. + */ +export class CollisionGrid { + private numHorizCells: number; + private numVertCells: number; + private grid: BoundingBox[][]; + private bound: BoundingBox; + private cellWidth: number; + private cellHeight: number; /** - * Accelerates label placement by dividing the view into a uniform grid. - * Labels only need to be tested for collision with other labels that overlap - * the same grid cells. This is a fork of {@code amoeba.CollisionGrid}. + * Constructs a new Collision grid. + * + * @param bound The bound of the grid. Labels out of bounds will be rejected. + * @param cellWidth Width of a cell in the grid. + * @param cellHeight Height of a cell in the grid. */ - export class CollisionGrid { - private numHorizCells: number; - private numVertCells: number; - private grid: BoundingBox[][]; - private bound: BoundingBox; - private cellWidth: number; - private cellHeight: number; - + constructor(bound: BoundingBox, cellWidth: number, cellHeight: number) { + /** The bound of the grid. Labels out of bounds will be rejected. */ + this.bound = bound; + /** Width of a cell in the grid. */ + this.cellWidth = cellWidth; + /** Height of a cell in the grid. */ + this.cellHeight = cellHeight; + /** Number of grid cells along the x axis. */ + this.numHorizCells = Math.ceil(this.boundWidth(bound) / cellWidth); + /** Number of grid cells along the y axis. */ + this.numVertCells = Math.ceil(this.boundHeight(bound) / cellHeight); /** - * Constructs a new Collision grid. - * - * @param bound The bound of the grid. Labels out of bounds will be rejected. - * @param cellWidth Width of a cell in the grid. - * @param cellHeight Height of a cell in the grid. + * The 2d grid (stored as a 1d array.) Each cell consists of an array of + * BoundingBoxes for objects that are in the cell. */ - constructor(bound: BoundingBox, cellWidth: number, cellHeight: number) { - /** The bound of the grid. Labels out of bounds will be rejected. */ - this.bound = bound; - - /** Width of a cell in the grid. */ - this.cellWidth = cellWidth; - - /** Height of a cell in the grid. */ - this.cellHeight = cellHeight; - - /** Number of grid cells along the x axis. */ - this.numHorizCells = Math.ceil(this.boundWidth(bound) / cellWidth); - - /** Number of grid cells along the y axis. */ - this.numVertCells = Math.ceil(this.boundHeight(bound) / cellHeight); - - /** - * The 2d grid (stored as a 1d array.) Each cell consists of an array of - * BoundingBoxes for objects that are in the cell. - */ - this.grid = new Array(this.numHorizCells * this.numVertCells); - } - - private boundWidth(bound: BoundingBox) { - return bound.hiX - bound.loX; - } - - private boundHeight(bound: BoundingBox) { - return bound.hiY - bound.loY; - } - - private boundsIntersect(a: BoundingBox, b: BoundingBox) { - return !( - a.loX > b.hiX || - a.loY > b.hiY || - a.hiX < b.loX || - a.hiY < b.loY - ); + this.grid = new Array(this.numHorizCells * this.numVertCells); + } + private boundWidth(bound: BoundingBox) { + return bound.hiX - bound.loX; + } + private boundHeight(bound: BoundingBox) { + return bound.hiY - bound.loY; + } + private boundsIntersect(a: BoundingBox, b: BoundingBox) { + return !(a.loX > b.hiX || a.loY > b.hiY || a.hiX < b.loX || a.hiY < b.loY); + } + /** + * Checks if a given bounding box has any conflicts in the grid and inserts it + * if none are found. + * + * @param bound The bound to insert. + * @param justTest If true, just test if it conflicts, without inserting. + * @return True if the bound was successfully inserted; false if it + * could not be inserted due to a conflict. + */ + insert(bound: BoundingBox, justTest = false): boolean { + // Reject if the label is out of bounds. + if ( + bound.hiX < this.bound.loX || + bound.loX > this.bound.hiX || + bound.hiY < this.bound.loY || + bound.loY > this.bound.hiY + ) { + return false; } - - /** - * Checks if a given bounding box has any conflicts in the grid and inserts it - * if none are found. - * - * @param bound The bound to insert. - * @param justTest If true, just test if it conflicts, without inserting. - * @return True if the bound was successfully inserted; false if it - * could not be inserted due to a conflict. - */ - insert(bound: BoundingBox, justTest = false): boolean { - // Reject if the label is out of bounds. - if ( - bound.hiX < this.bound.loX || - bound.loX > this.bound.hiX || - bound.hiY < this.bound.loY || - bound.loY > this.bound.hiY - ) { - return false; - } - - let minCellX = this.getCellX(bound.loX); - let maxCellX = this.getCellX(bound.hiX); - let minCellY = this.getCellY(bound.loY); - let maxCellY = this.getCellY(bound.hiY); - - // Check all overlapped cells to verify that we can insert. - let baseIdx = minCellY * this.numHorizCells + minCellX; - let idx = baseIdx; - for (let j = minCellY; j <= maxCellY; j++) { - for (let i = minCellX; i <= maxCellX; i++) { - let cell = this.grid[idx++]; - if (cell) { - for (let k = 0; k < cell.length; k++) { - if (this.boundsIntersect(bound, cell[k])) { - return false; - } + let minCellX = this.getCellX(bound.loX); + let maxCellX = this.getCellX(bound.hiX); + let minCellY = this.getCellY(bound.loY); + let maxCellY = this.getCellY(bound.hiY); + // Check all overlapped cells to verify that we can insert. + let baseIdx = minCellY * this.numHorizCells + minCellX; + let idx = baseIdx; + for (let j = minCellY; j <= maxCellY; j++) { + for (let i = minCellX; i <= maxCellX; i++) { + let cell = this.grid[idx++]; + if (cell) { + for (let k = 0; k < cell.length; k++) { + if (this.boundsIntersect(bound, cell[k])) { + return false; } } } - idx += this.numHorizCells - (maxCellX - minCellX + 1); - } - - if (justTest) { - return true; - } - - // Insert into the overlapped cells. - idx = baseIdx; - for (let j = minCellY; j <= maxCellY; j++) { - for (let i = minCellX; i <= maxCellX; i++) { - if (!this.grid[idx]) { - this.grid[idx] = [bound]; - } else { - this.grid[idx].push(bound); - } - idx++; - } - idx += this.numHorizCells - (maxCellX - minCellX + 1); } - return true; + idx += this.numHorizCells - (maxCellX - minCellX + 1); } - - /** - * Returns the x index of the grid cell where the given x coordinate falls. - * - * @param x the coordinate, in world space. - * @return the x index of the cell. - */ - private getCellX(x: number) { - return Math.floor((x - this.bound.loX) / this.cellWidth); + if (justTest) { + return true; } - - /** - * Returns the y index of the grid cell where the given y coordinate falls. - * - * @param y the coordinate, in world space. - * @return the y index of the cell. - */ - private getCellY(y: number) { - return Math.floor((y - this.bound.loY) / this.cellHeight); + // Insert into the overlapped cells. + idx = baseIdx; + for (let j = minCellY; j <= maxCellY; j++) { + for (let i = minCellX; i <= maxCellX; i++) { + if (!this.grid[idx]) { + this.grid[idx] = [bound]; + } else { + this.grid[idx].push(bound); + } + idx++; + } + idx += this.numHorizCells - (maxCellX - minCellX + 1); } + return true; + } + /** + * Returns the x index of the grid cell where the given x coordinate falls. + * + * @param x the coordinate, in world space. + * @return the x index of the cell. + */ + private getCellX(x: number) { + return Math.floor((x - this.bound.loX) / this.cellWidth); + } + /** + * Returns the y index of the grid cell where the given y coordinate falls. + * + * @param y the coordinate, in world space. + * @return the y index of the cell. + */ + private getCellY(y: number) { + return Math.floor((y - this.bound.loY) / this.cellHeight); } -} // namespace vz_projector +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/logging.ts b/tensorboard/plugins/projector/polymer3/vz_projector/logging.ts index 3b2ca43e02..78f9709425 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/logging.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/logging.ts @@ -12,98 +12,88 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector.logging { - /** Duration in ms for showing warning messages to the user */ - const WARNING_DURATION_MS = 10000; - - let dom: HTMLElement = null; - let msgId = 0; - let numActiveMessages = 0; - - export function setDomContainer(domElement: HTMLElement) { - dom = domElement; +const WARNING_DURATION_MS = 10000; +let dom: HTMLElement = null; +let msgId = 0; +let numActiveMessages = 0; +export function setDomContainer(domElement: HTMLElement) { + dom = domElement; +} +/** + * Updates the user message with the provided id. + * + * @param msg The message shown to the user. If null, the message is removed. + * @param id The id of an existing message. If no id is provided, a unique id + * is assigned. + * @param title The title of the notification. + * @param isErrorMsg If true, the message is error and the dialog will have a + * close button. + * @return The id of the message. + */ +export function setModalMessage( + msg: string, + id: string = null, + title = null, + isErrorMsg = false +): string { + if (dom == null) { + console.warn("Can't show modal message before the dom is initialized"); + return; } - - /** - * Updates the user message with the provided id. - * - * @param msg The message shown to the user. If null, the message is removed. - * @param id The id of an existing message. If no id is provided, a unique id - * is assigned. - * @param title The title of the notification. - * @param isErrorMsg If true, the message is error and the dialog will have a - * close button. - * @return The id of the message. - */ - export function setModalMessage( - msg: string, - id: string = null, - title = null, - isErrorMsg = false - ): string { - if (dom == null) { - console.warn("Can't show modal message before the dom is initialized"); - return; - } - if (id == null) { - id = (msgId++).toString(); - } - let dialog = dom.shadowRoot.querySelector('#notification-dialog') as any; - dialog.querySelector('.close-button').style.display = isErrorMsg - ? null - : 'none'; - let spinner = dialog.querySelector('.progress'); - spinner.style.display = isErrorMsg ? 'none' : null; - spinner.active = isErrorMsg ? null : true; - dialog.querySelector('#notification-title').innerHTML = title; - let msgsContainer = dialog.querySelector('#notify-msgs') as HTMLElement; - if (isErrorMsg) { - msgsContainer.innerHTML = ''; - } else { - const errors = msgsContainer.querySelectorAll('.error'); - for (let i = 0; i < errors.length; i++) { - msgsContainer.removeChild(errors[i]); - } - } - let divId = `notify-msg-${id}`; - let msgDiv = dialog.querySelector('#' + divId) as HTMLDivElement; - if (msgDiv == null) { - msgDiv = document.createElement('div'); - msgDiv.className = 'notify-msg ' + (isErrorMsg ? 'error' : ''); - msgDiv.id = divId; - - msgsContainer.insertBefore(msgDiv, msgsContainer.firstChild); - - if (!isErrorMsg) { - numActiveMessages++; - } else { - numActiveMessages = 0; - } + if (id == null) { + id = (msgId++).toString(); + } + let dialog = dom.shadowRoot.querySelector('#notification-dialog') as any; + dialog.querySelector('.close-button').style.display = isErrorMsg + ? null + : 'none'; + let spinner = dialog.querySelector('.progress'); + spinner.style.display = isErrorMsg ? 'none' : null; + spinner.active = isErrorMsg ? null : true; + dialog.querySelector('#notification-title').innerHTML = title; + let msgsContainer = dialog.querySelector('#notify-msgs') as HTMLElement; + if (isErrorMsg) { + msgsContainer.innerHTML = ''; + } else { + const errors = msgsContainer.querySelectorAll('.error'); + for (let i = 0; i < errors.length; i++) { + msgsContainer.removeChild(errors[i]); } - if (msg == null) { - numActiveMessages--; - if (numActiveMessages === 0) { - dialog.close(); - } - msgDiv.remove(); + } + let divId = `notify-msg-${id}`; + let msgDiv = dialog.querySelector('#' + divId) as HTMLDivElement; + if (msgDiv == null) { + msgDiv = document.createElement('div'); + msgDiv.className = 'notify-msg ' + (isErrorMsg ? 'error' : ''); + msgDiv.id = divId; + msgsContainer.insertBefore(msgDiv, msgsContainer.firstChild); + if (!isErrorMsg) { + numActiveMessages++; } else { - msgDiv.innerText = msg; - dialog.open(); + numActiveMessages = 0; } - return id; } - - export function setErrorMessage(errMsg: string, task?: string) { - setModalMessage(errMsg, null, 'Error ' + (task != null ? task : ''), true); - } - - /** - * Shows a warning message to the user for a certain amount of time. - */ - export function setWarningMessage(msg: string): void { - let toast = dom.shadowRoot.querySelector('#toast') as any; - toast.text = msg; - toast.duration = WARNING_DURATION_MS; - toast.open(); + if (msg == null) { + numActiveMessages--; + if (numActiveMessages === 0) { + dialog.close(); + } + msgDiv.remove(); + } else { + msgDiv.innerText = msg; + dialog.open(); } -} // namespace vz_projector.logging + return id; +} +export function setErrorMessage(errMsg: string, task?: string) { + setModalMessage(errMsg, null, 'Error ' + (task != null ? task : ''), true); +} +/** + * Shows a warning message to the user for a certain amount of time. + */ +export function setWarningMessage(msg: string): void { + let toast = dom.shadowRoot.querySelector('#toast') as any; + toast.text = msg; + toast.duration = WARNING_DURATION_MS; + toast.open(); +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/projectorEventContext.ts b/tensorboard/plugins/projector/polymer3/vz_projector/projectorEventContext.ts index cd3c4b72df..09bcef0be4 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/projectorEventContext.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/projectorEventContext.ts @@ -12,35 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - export type HoverListener = (index: number) => void; - export type SelectionChangedListener = ( - selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[] - ) => void; - export type ProjectionChangedListener = (projection: Projection) => void; - export type DistanceMetricChangedListener = ( - distanceMetric: DistanceFunction - ) => void; - export interface ProjectorEventContext { - /** Register a callback to be invoked when the mouse hovers over a point. */ - registerHoverListener(listener: HoverListener); - /** Notify the hover system that a point is under the mouse. */ - notifyHoverOverPoint(pointIndex: number); - /** Registers a callback to be invoked when the selection changes. */ - registerSelectionChangedListener(listener: SelectionChangedListener); - /** - * Notify the selection system that a client has changed the selected point - * set. - */ - notifySelectionChanged(newSelectedPointIndices: number[]); - /** Registers a callback to be invoked when the projection changes. */ - registerProjectionChangedListener(listener: ProjectionChangedListener); - /** Notify listeners that a reprojection occurred. */ - notifyProjectionChanged(projection: Projection); - registerDistanceMetricChangedListener( - listener: DistanceMetricChangedListener - ); - notifyDistanceMetricChanged(distMetric: DistanceFunction); - } -} // namespace vz_projector +import {DistanceFunction, Projection} from './data'; +import * as knn from './knn'; + +export type HoverListener = (index: number) => void; + +export type SelectionChangedListener = ( + selectedPointIndices: number[], + neighborsOfFirstPoint: knn.NearestEntry[] +) => void; + +export type ProjectionChangedListener = (projection: Projection) => void; + +export type DistanceMetricChangedListener = ( + distanceMetric: DistanceFunction +) => void; + +export interface ProjectorEventContext { + /** Register a callback to be invoked when the mouse hovers over a point. */ + registerHoverListener(listener: HoverListener); + /** Notify the hover system that a point is under the mouse. */ + notifyHoverOverPoint(pointIndex: number); + /** Registers a callback to be invoked when the selection changes. */ + registerSelectionChangedListener(listener: SelectionChangedListener); + /** + * Notify the selection system that a client has changed the selected point + * set. + */ + notifySelectionChanged(newSelectedPointIndices: number[]); + /** Registers a callback to be invoked when the projection changes. */ + registerProjectionChangedListener(listener: ProjectionChangedListener); + /** Notify listeners that a reprojection occurred. */ + notifyProjectionChanged(projection: Projection); + registerDistanceMetricChangedListener( + listener: DistanceMetricChangedListener + ); + notifyDistanceMetricChanged(distMetric: DistanceFunction); +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/projectorScatterPlotAdapter.ts b/tensorboard/plugins/projector/polymer3/vz_projector/projectorScatterPlotAdapter.ts index ee22670294..110cacbef4 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/projectorScatterPlotAdapter.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/projectorScatterPlotAdapter.ts @@ -12,360 +12,398 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - const LABEL_FONT_SIZE = 10; - const LABEL_SCALE_DEFAULT = 1.0; - const LABEL_SCALE_LARGE = 2; - const LABEL_FILL_COLOR_SELECTED = 0x000000; - const LABEL_FILL_COLOR_HOVER = 0x000000; - const LABEL_FILL_COLOR_NEIGHBOR = 0x000000; - const LABEL_STROKE_COLOR_SELECTED = 0xffffff; - const LABEL_STROKE_COLOR_HOVER = 0xffffff; - const LABEL_STROKE_COLOR_NEIGHBOR = 0xffffff; - - const POINT_COLOR_UNSELECTED = 0xe3e3e3; - const POINT_COLOR_NO_SELECTION = 0x7575d9; - const POINT_COLOR_SELECTED = 0xfa6666; - const POINT_COLOR_HOVER = 0x760b4f; - - const POINT_SCALE_DEFAULT = 1.0; - const POINT_SCALE_SELECTED = 1.2; - const POINT_SCALE_NEIGHBOR = 1.2; - const POINT_SCALE_HOVER = 1.2; - - const LABELS_3D_COLOR_UNSELECTED = 0xffffff; - const LABELS_3D_COLOR_NO_SELECTION = 0xffffff; - - const SPRITE_IMAGE_COLOR_UNSELECTED = 0xffffff; - const SPRITE_IMAGE_COLOR_NO_SELECTION = 0xffffff; - - const POLYLINE_START_HUE = 60; - const POLYLINE_END_HUE = 360; - const POLYLINE_SATURATION = 1; - const POLYLINE_LIGHTNESS = 0.3; - - const POLYLINE_DEFAULT_OPACITY = 0.2; - const POLYLINE_DEFAULT_LINEWIDTH = 2; - const POLYLINE_SELECTED_OPACITY = 0.9; - const POLYLINE_SELECTED_LINEWIDTH = 3; - const POLYLINE_DESELECTED_OPACITY = 0.05; - - const SCATTER_PLOT_CUBE_LENGTH = 2; - - /** Color scale for nearest neighbors. */ - const NN_COLOR_SCALE = d3 - .scaleLinear() - .domain([1, 0.7, 0.4]) - .range(['hsl(285, 80%, 40%)', 'hsl(0, 80%, 65%)', 'hsl(40, 70%, 60%)']) - .clamp(true); - - /** - * Interprets projector events and assembes the arrays and commands necessary - * to use the ScatterPlot to render the current projected data set. - */ - export class ProjectorScatterPlotAdapter { - public scatterPlot: ScatterPlot; - private projection: Projection; - private hoverPointIndex: number; - private selectedPointIndices: number[]; - private neighborsOfFirstSelectedPoint: knn.NearestEntry[]; - private renderLabelsIn3D: boolean = false; - private labelPointAccessor: string; - private legendPointColorer: (ds: DataSet, index: number) => string; - private distanceMetric: DistanceFunction; - - private spriteVisualizer: ScatterPlotVisualizerSprites; - private labels3DVisualizer: ScatterPlotVisualizer3DLabels; - private canvasLabelsVisualizer: ScatterPlotVisualizerCanvasLabels; - private polylineVisualizer: ScatterPlotVisualizerPolylines; - - constructor( - private scatterPlotContainer: HTMLElement, - projectorEventContext: ProjectorEventContext - ) { - this.scatterPlot = new ScatterPlot( - scatterPlotContainer, - projectorEventContext - ); - projectorEventContext.registerProjectionChangedListener((projection) => { - this.projection = projection; - this.updateScatterPlotWithNewProjection(projection); - }); - projectorEventContext.registerSelectionChangedListener( - (selectedPointIndices, neighbors) => { - this.selectedPointIndices = selectedPointIndices; - this.neighborsOfFirstSelectedPoint = neighbors; - this.updateScatterPlotPositions(); - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); - } - ); - projectorEventContext.registerHoverListener((hoverPointIndex) => { - this.hoverPointIndex = hoverPointIndex; +import * as THREE from 'three'; +import * as d3 from 'd3'; + +import { + DataSet, + DistanceFunction, + Projection, + State, + ProjectionComponents3D, +} from './data'; +import {ProjectorEventContext} from './projectorEventContext'; +import {LabelRenderParams} from './renderContext'; +import {ScatterPlot} from './scatterPlot'; +import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites'; +import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels'; +import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels'; +import {ScatterPlotVisualizerPolylines} from './scatterPlotVisualizerPolylines'; +import * as knn from './knn'; +import * as vector from './vector'; + +const LABEL_FONT_SIZE = 10; +const LABEL_SCALE_DEFAULT = 1; +const LABEL_SCALE_LARGE = 2; +const LABEL_FILL_COLOR_SELECTED = 0; +const LABEL_FILL_COLOR_HOVER = 0; +const LABEL_FILL_COLOR_NEIGHBOR = 0; +const LABEL_STROKE_COLOR_SELECTED = 16777215; +const LABEL_STROKE_COLOR_HOVER = 16777215; +const LABEL_STROKE_COLOR_NEIGHBOR = 16777215; +const POINT_COLOR_UNSELECTED = 14935011; +const POINT_COLOR_NO_SELECTION = 7697881; +const POINT_COLOR_SELECTED = 16410214; +const POINT_COLOR_HOVER = 7736143; +const POINT_SCALE_DEFAULT = 1; +const POINT_SCALE_SELECTED = 1.2; +const POINT_SCALE_NEIGHBOR = 1.2; +const POINT_SCALE_HOVER = 1.2; +const LABELS_3D_COLOR_UNSELECTED = 16777215; +const LABELS_3D_COLOR_NO_SELECTION = 16777215; +const SPRITE_IMAGE_COLOR_UNSELECTED = 16777215; +const SPRITE_IMAGE_COLOR_NO_SELECTION = 16777215; +const POLYLINE_START_HUE = 60; +const POLYLINE_END_HUE = 360; +const POLYLINE_SATURATION = 1; +const POLYLINE_LIGHTNESS = 0.3; +const POLYLINE_DEFAULT_OPACITY = 0.2; +const POLYLINE_DEFAULT_LINEWIDTH = 2; +const POLYLINE_SELECTED_OPACITY = 0.9; +const POLYLINE_SELECTED_LINEWIDTH = 3; +const POLYLINE_DESELECTED_OPACITY = 0.05; +const SCATTER_PLOT_CUBE_LENGTH = 2; +/** Color scale for nearest neighbors. */ +const NN_COLOR_SCALE = d3 + .scaleLinear() + .domain([1, 0.7, 0.4]) + .range(['hsl(285, 80%, 40%)', 'hsl(0, 80%, 65%)', 'hsl(40, 70%, 60%)']) + .clamp(true); +/** + * Interprets projector events and assembes the arrays and commands necessary + * to use the ScatterPlot to render the current projected data set. + */ +export class ProjectorScatterPlotAdapter { + public scatterPlot: ScatterPlot; + private projection: Projection; + private hoverPointIndex: number; + private selectedPointIndices: number[]; + private neighborsOfFirstSelectedPoint: knn.NearestEntry[]; + private renderLabelsIn3D: boolean = false; + private labelPointAccessor: string; + private legendPointColorer: (ds: DataSet, index: number) => string; + private distanceMetric: DistanceFunction; + private spriteVisualizer: ScatterPlotVisualizerSprites; + private labels3DVisualizer: ScatterPlotVisualizer3DLabels; + private canvasLabelsVisualizer: ScatterPlotVisualizerCanvasLabels; + private polylineVisualizer: ScatterPlotVisualizerPolylines; + constructor( + private scatterPlotContainer: HTMLElement, + projectorEventContext: ProjectorEventContext + ) { + this.scatterPlot = new ScatterPlot( + scatterPlotContainer, + projectorEventContext + ); + projectorEventContext.registerProjectionChangedListener((projection) => { + this.projection = projection; + this.updateScatterPlotWithNewProjection(projection); + }); + projectorEventContext.registerSelectionChangedListener( + (selectedPointIndices, neighbors) => { + this.selectedPointIndices = selectedPointIndices; + this.neighborsOfFirstSelectedPoint = neighbors; + this.updateScatterPlotPositions(); this.updateScatterPlotAttributes(); this.scatterPlot.render(); - }); - projectorEventContext.registerDistanceMetricChangedListener( - (distanceMetric) => { - this.distanceMetric = distanceMetric; - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); - } - ); - this.createVisualizers(false); - } - - notifyProjectionPositionsUpdated() { - this.updateScatterPlotPositions(); - this.scatterPlot.render(); - } - - setDataSet(dataSet: DataSet) { - if (this.projection != null) { - // TODO(@charlesnicholson): setDataSet needs to go away, the projection is the - // atomic unit of update. - this.projection.dataSet = dataSet; - } - if (this.polylineVisualizer != null) { - this.polylineVisualizer.setDataSet(dataSet); - } - if (this.labels3DVisualizer != null) { - this.labels3DVisualizer.setLabelStrings( - this.generate3DLabelsArray(dataSet, this.labelPointAccessor) - ); } - if (this.spriteVisualizer == null) { - return; - } - this.spriteVisualizer.clearSpriteAtlas(); - if (dataSet == null || dataSet.spriteAndMetadataInfo == null) { - return; - } - const metadata = dataSet.spriteAndMetadataInfo; - if (metadata.spriteImage == null || metadata.spriteMetadata == null) { - return; - } - const n = dataSet.points.length; - const spriteIndices = new Float32Array(n); - for (let i = 0; i < n; ++i) { - spriteIndices[i] = dataSet.points[i].index; - } - this.spriteVisualizer.setSpriteAtlas( - metadata.spriteImage, - metadata.spriteMetadata.singleImageDim, - spriteIndices - ); - } - - set3DLabelMode(renderLabelsIn3D: boolean) { - this.renderLabelsIn3D = renderLabelsIn3D; - this.createVisualizers(renderLabelsIn3D); + ); + projectorEventContext.registerHoverListener((hoverPointIndex) => { + this.hoverPointIndex = hoverPointIndex; this.updateScatterPlotAttributes(); this.scatterPlot.render(); + }); + projectorEventContext.registerDistanceMetricChangedListener( + (distanceMetric) => { + this.distanceMetric = distanceMetric; + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + } + ); + this.createVisualizers(false); + } + notifyProjectionPositionsUpdated() { + this.updateScatterPlotPositions(); + this.scatterPlot.render(); + } + setDataSet(dataSet: DataSet) { + if (this.projection != null) { + // TODO(@charlesnicholson): setDataSet needs to go away, the projection is the + // atomic unit of update. + this.projection.dataSet = dataSet; } - - setLegendPointColorer( - legendPointColorer: (ds: DataSet, index: number) => string - ) { - this.legendPointColorer = legendPointColorer; + if (this.polylineVisualizer != null) { + this.polylineVisualizer.setDataSet(dataSet); } - - setLabelPointAccessor(labelPointAccessor: string) { - this.labelPointAccessor = labelPointAccessor; - if (this.labels3DVisualizer != null) { - const ds = this.projection == null ? null : this.projection.dataSet; - this.labels3DVisualizer.setLabelStrings( - this.generate3DLabelsArray(ds, labelPointAccessor) - ); - } + if (this.labels3DVisualizer != null) { + this.labels3DVisualizer.setLabelStrings( + this.generate3DLabelsArray(dataSet, this.labelPointAccessor) + ); } - - resize() { - this.scatterPlot.resize(); + if (this.spriteVisualizer == null) { + return; } - - populateBookmarkFromUI(state: State) { - state.cameraDef = this.scatterPlot.getCameraDef(); + this.spriteVisualizer.clearSpriteAtlas(); + if (dataSet == null || dataSet.spriteAndMetadataInfo == null) { + return; } - - restoreUIFromBookmark(state: State) { - this.scatterPlot.setCameraParametersForNextCameraCreation( - state.cameraDef, - false - ); + const metadata = dataSet.spriteAndMetadataInfo; + if (metadata.spriteImage == null || metadata.spriteMetadata == null) { + return; } - - updateScatterPlotPositions() { + const n = dataSet.points.length; + const spriteIndices = new Float32Array(n); + for (let i = 0; i < n; ++i) { + spriteIndices[i] = dataSet.points[i].index; + } + this.spriteVisualizer.setSpriteAtlas( + metadata.spriteImage, + metadata.spriteMetadata.singleImageDim, + spriteIndices + ); + } + set3DLabelMode(renderLabelsIn3D: boolean) { + this.renderLabelsIn3D = renderLabelsIn3D; + this.createVisualizers(renderLabelsIn3D); + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + } + setLegendPointColorer( + legendPointColorer: (ds: DataSet, index: number) => string + ) { + this.legendPointColorer = legendPointColorer; + } + setLabelPointAccessor(labelPointAccessor: string) { + this.labelPointAccessor = labelPointAccessor; + if (this.labels3DVisualizer != null) { const ds = this.projection == null ? null : this.projection.dataSet; - const projectionComponents = - this.projection == null ? null : this.projection.projectionComponents; - const newPositions = this.generatePointPositionArray( - ds, - projectionComponents + this.labels3DVisualizer.setLabelStrings( + this.generate3DLabelsArray(ds, labelPointAccessor) ); - this.scatterPlot.setPointPositions(newPositions); } - - updateScatterPlotAttributes() { - if (this.projection == null) { - return; - } - const dataSet = this.projection.dataSet; - const selectedSet = this.selectedPointIndices; - const hoverIndex = this.hoverPointIndex; - const neighbors = this.neighborsOfFirstSelectedPoint; - const pointColorer = this.legendPointColorer; - - const pointColors = this.generatePointColorArray( - dataSet, - pointColorer, - this.distanceMetric, - selectedSet, - neighbors, - hoverIndex, - this.renderLabelsIn3D, - this.getSpriteImageMode() + } + resize() { + this.scatterPlot.resize(); + } + populateBookmarkFromUI(state: State) { + state.cameraDef = this.scatterPlot.getCameraDef(); + } + restoreUIFromBookmark(state: State) { + this.scatterPlot.setCameraParametersForNextCameraCreation( + state.cameraDef, + false + ); + } + updateScatterPlotPositions() { + const ds = this.projection == null ? null : this.projection.dataSet; + const projectionComponents = + this.projection == null ? null : this.projection.projectionComponents; + const newPositions = this.generatePointPositionArray( + ds, + projectionComponents + ); + this.scatterPlot.setPointPositions(newPositions); + } + updateScatterPlotAttributes() { + if (this.projection == null) { + return; + } + const dataSet = this.projection.dataSet; + const selectedSet = this.selectedPointIndices; + const hoverIndex = this.hoverPointIndex; + const neighbors = this.neighborsOfFirstSelectedPoint; + const pointColorer = this.legendPointColorer; + const pointColors = this.generatePointColorArray( + dataSet, + pointColorer, + this.distanceMetric, + selectedSet, + neighbors, + hoverIndex, + this.renderLabelsIn3D, + this.getSpriteImageMode() + ); + const pointScaleFactors = this.generatePointScaleFactorArray( + dataSet, + selectedSet, + neighbors, + hoverIndex + ); + const labels = this.generateVisibleLabelRenderParams( + dataSet, + selectedSet, + neighbors, + hoverIndex + ); + const polylineColors = this.generateLineSegmentColorMap( + dataSet, + pointColorer + ); + const polylineOpacities = this.generateLineSegmentOpacityArray( + dataSet, + selectedSet + ); + const polylineWidths = this.generateLineSegmentWidthArray( + dataSet, + selectedSet + ); + this.scatterPlot.setPointColors(pointColors); + this.scatterPlot.setPointScaleFactors(pointScaleFactors); + this.scatterPlot.setLabels(labels); + this.scatterPlot.setPolylineColors(polylineColors); + this.scatterPlot.setPolylineOpacities(polylineOpacities); + this.scatterPlot.setPolylineWidths(polylineWidths); + } + render() { + this.scatterPlot.render(); + } + generatePointPositionArray( + ds: DataSet, + projectionComponents: ProjectionComponents3D + ): Float32Array { + if (ds == null) { + return null; + } + const xScaler = d3.scaleLinear(); + const yScaler = d3.scaleLinear(); + let zScaler = null; + { + // Determine max and min of each axis of our data. + const xExtent = d3.extent( + ds.points, + (p, i) => ds.points[i].projections[projectionComponents[0]] ); - const pointScaleFactors = this.generatePointScaleFactorArray( - dataSet, - selectedSet, - neighbors, - hoverIndex + const yExtent = d3.extent( + ds.points, + (p, i) => ds.points[i].projections[projectionComponents[1]] ); - const labels = this.generateVisibleLabelRenderParams( - dataSet, - selectedSet, - neighbors, - hoverIndex + const range = [ + -SCATTER_PLOT_CUBE_LENGTH / 2, + SCATTER_PLOT_CUBE_LENGTH / 2, + ]; + xScaler.domain(xExtent).range(range); + yScaler.domain(yExtent).range(range); + if (projectionComponents[2] != null) { + const zExtent = d3.extent( + ds.points, + (p, i) => ds.points[i].projections[projectionComponents[2]] + ); + zScaler = d3.scaleLinear(); + zScaler.domain(zExtent).range(range); + } + } + const positions = new Float32Array(ds.points.length * 3); + let dst = 0; + ds.points.forEach((d, i) => { + positions[dst++] = xScaler( + ds.points[i].projections[projectionComponents[0]] ); - const polylineColors = this.generateLineSegmentColorMap( - dataSet, - pointColorer + positions[dst++] = yScaler( + ds.points[i].projections[projectionComponents[1]] ); - const polylineOpacities = this.generateLineSegmentOpacityArray( - dataSet, - selectedSet + positions[dst++] = 0; + }); + if (zScaler) { + dst = 2; + ds.points.forEach((d, i) => { + positions[dst] = zScaler( + ds.points[i].projections[projectionComponents[2]] + ); + dst += 3; + }); + } + return positions; + } + generateVisibleLabelRenderParams( + ds: DataSet, + selectedPointIndices: number[], + neighborsOfFirstPoint: knn.NearestEntry[], + hoverPointIndex: number + ): LabelRenderParams { + if (ds == null) { + return null; + } + const selectedPointCount = + selectedPointIndices == null ? 0 : selectedPointIndices.length; + const neighborCount = + neighborsOfFirstPoint == null ? 0 : neighborsOfFirstPoint.length; + const n = + selectedPointCount + neighborCount + (hoverPointIndex != null ? 1 : 0); + const visibleLabels = new Uint32Array(n); + const scale = new Float32Array(n); + const opacityFlags = new Int8Array(n); + const fillColors = new Uint8Array(n * 3); + const strokeColors = new Uint8Array(n * 3); + const labelStrings: string[] = []; + scale.fill(LABEL_SCALE_DEFAULT); + opacityFlags.fill(1); + let dst = 0; + if (hoverPointIndex != null) { + labelStrings.push( + this.getLabelText(ds, hoverPointIndex, this.labelPointAccessor) ); - const polylineWidths = this.generateLineSegmentWidthArray( - dataSet, - selectedSet + visibleLabels[dst] = hoverPointIndex; + scale[dst] = LABEL_SCALE_LARGE; + opacityFlags[dst] = 0; + const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER); + packRgbIntoUint8Array( + fillColors, + dst, + fillRgb[0], + fillRgb[1], + fillRgb[2] ); - - this.scatterPlot.setPointColors(pointColors); - this.scatterPlot.setPointScaleFactors(pointScaleFactors); - this.scatterPlot.setLabels(labels); - this.scatterPlot.setPolylineColors(polylineColors); - this.scatterPlot.setPolylineOpacities(polylineOpacities); - this.scatterPlot.setPolylineWidths(polylineWidths); - } - - render() { - this.scatterPlot.render(); + const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER); + packRgbIntoUint8Array( + strokeColors, + dst, + strokeRgb[0], + strokeRgb[1], + strokeRgb[1] + ); + ++dst; } - - generatePointPositionArray( - ds: DataSet, - projectionComponents: ProjectionComponents3D - ): Float32Array { - if (ds == null) { - return null; - } - - const xScaler = d3.scaleLinear(); - const yScaler = d3.scaleLinear(); - let zScaler = null; - { - // Determine max and min of each axis of our data. - const xExtent = d3.extent( - ds.points, - (p, i) => ds.points[i].projections[projectionComponents[0]] - ); - const yExtent = d3.extent( - ds.points, - (p, i) => ds.points[i].projections[projectionComponents[1]] + // Selected points + { + const n = selectedPointCount; + const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED); + const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED); + for (let i = 0; i < n; ++i) { + const labelIndex = selectedPointIndices[i]; + labelStrings.push( + this.getLabelText(ds, labelIndex, this.labelPointAccessor) ); - - const range = [ - -SCATTER_PLOT_CUBE_LENGTH / 2, - SCATTER_PLOT_CUBE_LENGTH / 2, - ]; - - xScaler.domain(xExtent).range(range); - yScaler.domain(yExtent).range(range); - - if (projectionComponents[2] != null) { - const zExtent = d3.extent( - ds.points, - (p, i) => ds.points[i].projections[projectionComponents[2]] - ); - zScaler = d3.scaleLinear(); - zScaler.domain(zExtent).range(range); - } - } - - const positions = new Float32Array(ds.points.length * 3); - let dst = 0; - - ds.points.forEach((d, i) => { - positions[dst++] = xScaler( - ds.points[i].projections[projectionComponents[0]] + visibleLabels[dst] = labelIndex; + scale[dst] = LABEL_SCALE_LARGE; + opacityFlags[dst] = n === 1 ? 0 : 1; + packRgbIntoUint8Array( + fillColors, + dst, + fillRgb[0], + fillRgb[1], + fillRgb[2] ); - positions[dst++] = yScaler( - ds.points[i].projections[projectionComponents[1]] + packRgbIntoUint8Array( + strokeColors, + dst, + strokeRgb[0], + strokeRgb[1], + strokeRgb[2] ); - positions[dst++] = 0.0; - }); - - if (zScaler) { - dst = 2; - ds.points.forEach((d, i) => { - positions[dst] = zScaler( - ds.points[i].projections[projectionComponents[2]] - ); - dst += 3; - }); + ++dst; } - - return positions; } - - generateVisibleLabelRenderParams( - ds: DataSet, - selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[], - hoverPointIndex: number - ): LabelRenderParams { - if (ds == null) { - return null; - } - - const selectedPointCount = - selectedPointIndices == null ? 0 : selectedPointIndices.length; - const neighborCount = - neighborsOfFirstPoint == null ? 0 : neighborsOfFirstPoint.length; - const n = - selectedPointCount + neighborCount + (hoverPointIndex != null ? 1 : 0); - - const visibleLabels = new Uint32Array(n); - const scale = new Float32Array(n); - const opacityFlags = new Int8Array(n); - const fillColors = new Uint8Array(n * 3); - const strokeColors = new Uint8Array(n * 3); - const labelStrings: string[] = []; - - scale.fill(LABEL_SCALE_DEFAULT); - opacityFlags.fill(1); - - let dst = 0; - - if (hoverPointIndex != null) { + // Neighbors + { + const n = neighborCount; + const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR); + const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR); + for (let i = 0; i < n; ++i) { + const labelIndex = neighborsOfFirstPoint[i].index; labelStrings.push( - this.getLabelText(ds, hoverPointIndex, this.labelPointAccessor) + this.getLabelText(ds, labelIndex, this.labelPointAccessor) ); - visibleLabels[dst] = hoverPointIndex; - scale[dst] = LABEL_SCALE_LARGE; - opacityFlags[dst] = 0; - const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER); + visibleLabels[dst] = labelIndex; packRgbIntoUint8Array( fillColors, dst, @@ -373,456 +411,359 @@ namespace vz_projector { fillRgb[1], fillRgb[2] ); - const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER); packRgbIntoUint8Array( strokeColors, dst, strokeRgb[0], strokeRgb[1], - strokeRgb[1] + strokeRgb[2] ); ++dst; } - - // Selected points - { - const n = selectedPointCount; - const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED); - const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED); - for (let i = 0; i < n; ++i) { - const labelIndex = selectedPointIndices[i]; - labelStrings.push( - this.getLabelText(ds, labelIndex, this.labelPointAccessor) - ); - visibleLabels[dst] = labelIndex; - scale[dst] = LABEL_SCALE_LARGE; - opacityFlags[dst] = n === 1 ? 0 : 1; - packRgbIntoUint8Array( - fillColors, - dst, - fillRgb[0], - fillRgb[1], - fillRgb[2] + } + return new LabelRenderParams( + new Float32Array(visibleLabels), + labelStrings, + scale, + opacityFlags, + LABEL_FONT_SIZE, + fillColors, + strokeColors + ); + } + generatePointScaleFactorArray( + ds: DataSet, + selectedPointIndices: number[], + neighborsOfFirstPoint: knn.NearestEntry[], + hoverPointIndex: number + ): Float32Array { + if (ds == null) { + return new Float32Array(0); + } + const scale = new Float32Array(ds.points.length); + scale.fill(POINT_SCALE_DEFAULT); + const selectedPointCount = + selectedPointIndices == null ? 0 : selectedPointIndices.length; + const neighborCount = + neighborsOfFirstPoint == null ? 0 : neighborsOfFirstPoint.length; + // Scale up all selected points. + { + const n = selectedPointCount; + for (let i = 0; i < n; ++i) { + const p = selectedPointIndices[i]; + scale[p] = POINT_SCALE_SELECTED; + } + } + // Scale up the neighbor points. + { + const n = neighborCount; + for (let i = 0; i < n; ++i) { + const p = neighborsOfFirstPoint[i].index; + scale[p] = POINT_SCALE_NEIGHBOR; + } + } + // Scale up the hover point. + if (hoverPointIndex != null) { + scale[hoverPointIndex] = POINT_SCALE_HOVER; + } + return scale; + } + generateLineSegmentColorMap( + ds: DataSet, + legendPointColorer: (ds: DataSet, index: number) => string + ): { + [polylineIndex: number]: Float32Array; + } { + let polylineColorArrayMap: { + [polylineIndex: number]: Float32Array; + } = {}; + if (ds == null) { + return polylineColorArrayMap; + } + for (let i = 0; i < ds.sequences.length; i++) { + let sequence = ds.sequences[i]; + let colors = new Float32Array(2 * (sequence.pointIndices.length - 1) * 3); + let colorIndex = 0; + if (legendPointColorer) { + for (let j = 0; j < sequence.pointIndices.length - 1; j++) { + const c1 = new THREE.Color( + legendPointColorer(ds, sequence.pointIndices[j]) ); - packRgbIntoUint8Array( - strokeColors, - dst, - strokeRgb[0], - strokeRgb[1], - strokeRgb[2] + const c2 = new THREE.Color( + legendPointColorer(ds, sequence.pointIndices[j + 1]) ); - ++dst; + colors[colorIndex++] = c1.r; + colors[colorIndex++] = c1.g; + colors[colorIndex++] = c1.b; + colors[colorIndex++] = c2.r; + colors[colorIndex++] = c2.g; + colors[colorIndex++] = c2.b; } - } - - // Neighbors - { - const n = neighborCount; - const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR); - const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR); - for (let i = 0; i < n; ++i) { - const labelIndex = neighborsOfFirstPoint[i].index; - labelStrings.push( - this.getLabelText(ds, labelIndex, this.labelPointAccessor) - ); - visibleLabels[dst] = labelIndex; - packRgbIntoUint8Array( - fillColors, - dst, - fillRgb[0], - fillRgb[1], - fillRgb[2] + } else { + for (let j = 0; j < sequence.pointIndices.length - 1; j++) { + const c1 = getDefaultPointInPolylineColor( + j, + sequence.pointIndices.length ); - packRgbIntoUint8Array( - strokeColors, - dst, - strokeRgb[0], - strokeRgb[1], - strokeRgb[2] + const c2 = getDefaultPointInPolylineColor( + j + 1, + sequence.pointIndices.length ); - ++dst; + colors[colorIndex++] = c1.r; + colors[colorIndex++] = c1.g; + colors[colorIndex++] = c1.b; + colors[colorIndex++] = c2.r; + colors[colorIndex++] = c2.g; + colors[colorIndex++] = c2.b; } } - - return new LabelRenderParams( - new Float32Array(visibleLabels), - labelStrings, - scale, - opacityFlags, - LABEL_FONT_SIZE, - fillColors, - strokeColors - ); + polylineColorArrayMap[i] = colors; } - - generatePointScaleFactorArray( - ds: DataSet, - selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[], - hoverPointIndex: number - ): Float32Array { - if (ds == null) { - return new Float32Array(0); - } - - const scale = new Float32Array(ds.points.length); - scale.fill(POINT_SCALE_DEFAULT); - - const selectedPointCount = - selectedPointIndices == null ? 0 : selectedPointIndices.length; - const neighborCount = - neighborsOfFirstPoint == null ? 0 : neighborsOfFirstPoint.length; - - // Scale up all selected points. - { - const n = selectedPointCount; - for (let i = 0; i < n; ++i) { - const p = selectedPointIndices[i]; - scale[p] = POINT_SCALE_SELECTED; - } - } - - // Scale up the neighbor points. - { - const n = neighborCount; - for (let i = 0; i < n; ++i) { - const p = neighborsOfFirstPoint[i].index; - scale[p] = POINT_SCALE_NEIGHBOR; - } - } - - // Scale up the hover point. - if (hoverPointIndex != null) { - scale[hoverPointIndex] = POINT_SCALE_HOVER; - } - - return scale; + return polylineColorArrayMap; + } + generateLineSegmentOpacityArray( + ds: DataSet, + selectedPoints: number[] + ): Float32Array { + if (ds == null) { + return new Float32Array(0); } - - generateLineSegmentColorMap( - ds: DataSet, - legendPointColorer: (ds: DataSet, index: number) => string - ): {[polylineIndex: number]: Float32Array} { - let polylineColorArrayMap: {[polylineIndex: number]: Float32Array} = {}; - if (ds == null) { - return polylineColorArrayMap; - } - - for (let i = 0; i < ds.sequences.length; i++) { - let sequence = ds.sequences[i]; - let colors = new Float32Array( - 2 * (sequence.pointIndices.length - 1) * 3 - ); - let colorIndex = 0; - - if (legendPointColorer) { - for (let j = 0; j < sequence.pointIndices.length - 1; j++) { - const c1 = new THREE.Color( - legendPointColorer(ds, sequence.pointIndices[j]) - ); - const c2 = new THREE.Color( - legendPointColorer(ds, sequence.pointIndices[j + 1]) - ); - colors[colorIndex++] = c1.r; - colors[colorIndex++] = c1.g; - colors[colorIndex++] = c1.b; - colors[colorIndex++] = c2.r; - colors[colorIndex++] = c2.g; - colors[colorIndex++] = c2.b; - } - } else { - for (let j = 0; j < sequence.pointIndices.length - 1; j++) { - const c1 = getDefaultPointInPolylineColor( - j, - sequence.pointIndices.length - ); - const c2 = getDefaultPointInPolylineColor( - j + 1, - sequence.pointIndices.length - ); - colors[colorIndex++] = c1.r; - colors[colorIndex++] = c1.g; - colors[colorIndex++] = c1.b; - colors[colorIndex++] = c2.r; - colors[colorIndex++] = c2.g; - colors[colorIndex++] = c2.b; - } - } - - polylineColorArrayMap[i] = colors; - } - - return polylineColorArrayMap; + const opacities = new Float32Array(ds.sequences.length); + const selectedPointCount = + selectedPoints == null ? 0 : selectedPoints.length; + if (selectedPointCount > 0) { + opacities.fill(POLYLINE_DESELECTED_OPACITY); + const i = ds.points[selectedPoints[0]].sequenceIndex; + opacities[i] = POLYLINE_SELECTED_OPACITY; + } else { + opacities.fill(POLYLINE_DEFAULT_OPACITY); } - - generateLineSegmentOpacityArray( - ds: DataSet, - selectedPoints: number[] - ): Float32Array { - if (ds == null) { - return new Float32Array(0); - } - const opacities = new Float32Array(ds.sequences.length); - const selectedPointCount = - selectedPoints == null ? 0 : selectedPoints.length; - if (selectedPointCount > 0) { - opacities.fill(POLYLINE_DESELECTED_OPACITY); - const i = ds.points[selectedPoints[0]].sequenceIndex; - opacities[i] = POLYLINE_SELECTED_OPACITY; - } else { - opacities.fill(POLYLINE_DEFAULT_OPACITY); - } - return opacities; + return opacities; + } + generateLineSegmentWidthArray( + ds: DataSet, + selectedPoints: number[] + ): Float32Array { + if (ds == null) { + return new Float32Array(0); } - - generateLineSegmentWidthArray( - ds: DataSet, - selectedPoints: number[] - ): Float32Array { - if (ds == null) { - return new Float32Array(0); - } - const widths = new Float32Array(ds.sequences.length); - widths.fill(POLYLINE_DEFAULT_LINEWIDTH); - const selectedPointCount = - selectedPoints == null ? 0 : selectedPoints.length; - if (selectedPointCount > 0) { - const i = ds.points[selectedPoints[0]].sequenceIndex; - widths[i] = POLYLINE_SELECTED_LINEWIDTH; - } - return widths; + const widths = new Float32Array(ds.sequences.length); + widths.fill(POLYLINE_DEFAULT_LINEWIDTH); + const selectedPointCount = + selectedPoints == null ? 0 : selectedPoints.length; + if (selectedPointCount > 0) { + const i = ds.points[selectedPoints[0]].sequenceIndex; + widths[i] = POLYLINE_SELECTED_LINEWIDTH; } - - generatePointColorArray( - ds: DataSet, - legendPointColorer: (ds: DataSet, index: number) => string, - distFunc: DistanceFunction, - selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[], - hoverPointIndex: number, - label3dMode: boolean, - spriteImageMode: boolean - ): Float32Array { - if (ds == null) { - return new Float32Array(0); - } - - const selectedPointCount = - selectedPointIndices == null ? 0 : selectedPointIndices.length; - const neighborCount = - neighborsOfFirstPoint == null ? 0 : neighborsOfFirstPoint.length; - const colors = new Float32Array(ds.points.length * 3); - - let unselectedColor = POINT_COLOR_UNSELECTED; - let noSelectionColor = POINT_COLOR_NO_SELECTION; - - if (label3dMode) { - unselectedColor = LABELS_3D_COLOR_UNSELECTED; - noSelectionColor = LABELS_3D_COLOR_NO_SELECTION; - } - - if (spriteImageMode) { - unselectedColor = SPRITE_IMAGE_COLOR_UNSELECTED; - noSelectionColor = SPRITE_IMAGE_COLOR_NO_SELECTION; - } - - // Give all points the unselected color. - { - const n = ds.points.length; - let dst = 0; - if (selectedPointCount > 0) { - const c = new THREE.Color(unselectedColor); + return widths; + } + generatePointColorArray( + ds: DataSet, + legendPointColorer: (ds: DataSet, index: number) => string, + distFunc: DistanceFunction, + selectedPointIndices: number[], + neighborsOfFirstPoint: knn.NearestEntry[], + hoverPointIndex: number, + label3dMode: boolean, + spriteImageMode: boolean + ): Float32Array { + if (ds == null) { + return new Float32Array(0); + } + const selectedPointCount = + selectedPointIndices == null ? 0 : selectedPointIndices.length; + const neighborCount = + neighborsOfFirstPoint == null ? 0 : neighborsOfFirstPoint.length; + const colors = new Float32Array(ds.points.length * 3); + let unselectedColor = POINT_COLOR_UNSELECTED; + let noSelectionColor = POINT_COLOR_NO_SELECTION; + if (label3dMode) { + unselectedColor = LABELS_3D_COLOR_UNSELECTED; + noSelectionColor = LABELS_3D_COLOR_NO_SELECTION; + } + if (spriteImageMode) { + unselectedColor = SPRITE_IMAGE_COLOR_UNSELECTED; + noSelectionColor = SPRITE_IMAGE_COLOR_NO_SELECTION; + } + // Give all points the unselected color. + { + const n = ds.points.length; + let dst = 0; + if (selectedPointCount > 0) { + const c = new THREE.Color(unselectedColor); + for (let i = 0; i < n; ++i) { + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; + } + } else { + if (legendPointColorer != null) { for (let i = 0; i < n; ++i) { + const c = new THREE.Color(legendPointColorer(ds, i)); colors[dst++] = c.r; colors[dst++] = c.g; colors[dst++] = c.b; } } else { - if (legendPointColorer != null) { - for (let i = 0; i < n; ++i) { - const c = new THREE.Color(legendPointColorer(ds, i)); - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } - } else { - const c = new THREE.Color(noSelectionColor); - for (let i = 0; i < n; ++i) { - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } + const c = new THREE.Color(noSelectionColor); + for (let i = 0; i < n; ++i) { + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; } } } - - // Color the selected points. - { - const n = selectedPointCount; - const c = new THREE.Color(POINT_COLOR_SELECTED); - for (let i = 0; i < n; ++i) { - let dst = selectedPointIndices[i] * 3; - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } - } - - // Color the neighbors. - { - const n = neighborCount; - let minDist = n > 0 ? neighborsOfFirstPoint[0].dist : 0; - for (let i = 0; i < n; ++i) { - const c = new THREE.Color( - dist2color(distFunc, neighborsOfFirstPoint[i].dist, minDist) - ); - let dst = neighborsOfFirstPoint[i].index * 3; - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } - } - - // Color the hover point. - if (hoverPointIndex != null) { - const c = new THREE.Color(POINT_COLOR_HOVER); - let dst = hoverPointIndex * 3; + } + // Color the selected points. + { + const n = selectedPointCount; + const c = new THREE.Color(POINT_COLOR_SELECTED); + for (let i = 0; i < n; ++i) { + let dst = selectedPointIndices[i] * 3; colors[dst++] = c.r; colors[dst++] = c.g; colors[dst++] = c.b; } - - return colors; } - - generate3DLabelsArray(ds: DataSet, accessor: string) { - if (ds == null || accessor == null) { - return null; - } - let labels: string[] = []; - const n = ds.points.length; + // Color the neighbors. + { + const n = neighborCount; + let minDist = n > 0 ? neighborsOfFirstPoint[0].dist : 0; for (let i = 0; i < n; ++i) { - labels.push(this.getLabelText(ds, i, accessor)); - } - return labels; - } - - private getLabelText(ds: DataSet, i: number, accessor: string): string { - return ds.points[i].metadata[accessor] !== undefined - ? String(ds.points[i].metadata[accessor]) - : `Unknown #${i}`; - } - - private updateScatterPlotWithNewProjection(projection: Projection) { - if (projection == null) { - this.createVisualizers(this.renderLabelsIn3D); - this.scatterPlot.render(); - return; - } - this.setDataSet(projection.dataSet); - this.scatterPlot.setDimensions(projection.dimensionality); - if ( - projection.dataSet.projectionCanBeRendered(projection.projectionType) - ) { - this.updateScatterPlotAttributes(); - this.notifyProjectionPositionsUpdated(); - } - this.scatterPlot.setCameraParametersForNextCameraCreation(null, false); - } - - private createVisualizers(inLabels3DMode: boolean) { - const ds = this.projection == null ? null : this.projection.dataSet; - const scatterPlot = this.scatterPlot; - scatterPlot.removeAllVisualizers(); - this.labels3DVisualizer = null; - this.canvasLabelsVisualizer = null; - this.spriteVisualizer = null; - this.polylineVisualizer = null; - if (inLabels3DMode) { - this.labels3DVisualizer = new ScatterPlotVisualizer3DLabels(); - this.labels3DVisualizer.setLabelStrings( - this.generate3DLabelsArray(ds, this.labelPointAccessor) - ); - } else { - this.spriteVisualizer = new ScatterPlotVisualizerSprites(); - scatterPlot.addVisualizer(this.spriteVisualizer); - this.canvasLabelsVisualizer = new ScatterPlotVisualizerCanvasLabels( - this.scatterPlotContainer + const c = new THREE.Color( + dist2color(distFunc, neighborsOfFirstPoint[i].dist, minDist) ); + let dst = neighborsOfFirstPoint[i].index * 3; + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; } - this.polylineVisualizer = new ScatterPlotVisualizerPolylines(); - this.setDataSet(ds); - if (this.spriteVisualizer) { - scatterPlot.addVisualizer(this.spriteVisualizer); - } - if (this.labels3DVisualizer) { - scatterPlot.addVisualizer(this.labels3DVisualizer); - } - if (this.canvasLabelsVisualizer) { - scatterPlot.addVisualizer(this.canvasLabelsVisualizer); - } - scatterPlot.addVisualizer(this.polylineVisualizer); } - - private getSpriteImageMode(): boolean { - if (this.projection == null) { - return false; - } - const ds = this.projection.dataSet; - if (ds == null || ds.spriteAndMetadataInfo == null) { - return false; - } - return ds.spriteAndMetadataInfo.spriteImage != null; + // Color the hover point. + if (hoverPointIndex != null) { + const c = new THREE.Color(POINT_COLOR_HOVER); + let dst = hoverPointIndex * 3; + colors[dst++] = c.r; + colors[dst++] = c.g; + colors[dst++] = c.b; } + return colors; } - - function packRgbIntoUint8Array( - rgbArray: Uint8Array, - labelIndex: number, - r: number, - g: number, - b: number - ) { - rgbArray[labelIndex * 3] = r; - rgbArray[labelIndex * 3 + 1] = g; - rgbArray[labelIndex * 3 + 2] = b; + generate3DLabelsArray(ds: DataSet, accessor: string) { + if (ds == null || accessor == null) { + return null; + } + let labels: string[] = []; + const n = ds.points.length; + for (let i = 0; i < n; ++i) { + labels.push(this.getLabelText(ds, i, accessor)); + } + return labels; } - - function styleRgbFromHexColor(hex: number): [number, number, number] { - const c = new THREE.Color(hex); - return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0]; + private getLabelText(ds: DataSet, i: number, accessor: string): string { + return ds.points[i].metadata[accessor] !== undefined + ? String(ds.points[i].metadata[accessor]) + : `Unknown #${i}`; } - - function getDefaultPointInPolylineColor( - index: number, - totalPoints: number - ): THREE.Color { - let hue = - POLYLINE_START_HUE + - ((POLYLINE_END_HUE - POLYLINE_START_HUE) * index) / totalPoints; - - let rgb = d3.hsl(hue, POLYLINE_SATURATION, POLYLINE_LIGHTNESS).rgb(); - return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255); + private updateScatterPlotWithNewProjection(projection: Projection) { + if (projection == null) { + this.createVisualizers(this.renderLabelsIn3D); + this.scatterPlot.render(); + return; + } + this.setDataSet(projection.dataSet); + this.scatterPlot.setDimensions(projection.dimensionality); + if (projection.dataSet.projectionCanBeRendered(projection.projectionType)) { + this.updateScatterPlotAttributes(); + this.notifyProjectionPositionsUpdated(); + } + this.scatterPlot.setCameraParametersForNextCameraCreation(null, false); } - - /** - * Normalizes the distance so it can be visually encoded with color. - * The normalization depends on the distance metric (cosine vs euclidean). - */ - export function normalizeDist( - distFunc: DistanceFunction, - d: number, - minDist: number - ): number { - return distFunc === vector.dist ? minDist / d : 1 - d; + private createVisualizers(inLabels3DMode: boolean) { + const ds = this.projection == null ? null : this.projection.dataSet; + const scatterPlot = this.scatterPlot; + scatterPlot.removeAllVisualizers(); + this.labels3DVisualizer = null; + this.canvasLabelsVisualizer = null; + this.spriteVisualizer = null; + this.polylineVisualizer = null; + if (inLabels3DMode) { + this.labels3DVisualizer = new ScatterPlotVisualizer3DLabels(); + this.labels3DVisualizer.setLabelStrings( + this.generate3DLabelsArray(ds, this.labelPointAccessor) + ); + } else { + this.spriteVisualizer = new ScatterPlotVisualizerSprites(); + scatterPlot.addVisualizer(this.spriteVisualizer); + this.canvasLabelsVisualizer = new ScatterPlotVisualizerCanvasLabels( + this.scatterPlotContainer + ); + } + this.polylineVisualizer = new ScatterPlotVisualizerPolylines(); + this.setDataSet(ds); + if (this.spriteVisualizer) { + scatterPlot.addVisualizer(this.spriteVisualizer); + } + if (this.labels3DVisualizer) { + scatterPlot.addVisualizer(this.labels3DVisualizer); + } + if (this.canvasLabelsVisualizer) { + scatterPlot.addVisualizer(this.canvasLabelsVisualizer); + } + scatterPlot.addVisualizer(this.polylineVisualizer); } - - /** Normalizes and encodes the provided distance with color. */ - export function dist2color( - distFunc: DistanceFunction, - d: number, - minDist: number - ): string { - return NN_COLOR_SCALE(normalizeDist(distFunc, d, minDist)); + private getSpriteImageMode(): boolean { + if (this.projection == null) { + return false; + } + const ds = this.projection.dataSet; + if (ds == null || ds.spriteAndMetadataInfo == null) { + return false; + } + return ds.spriteAndMetadataInfo.spriteImage != null; } -} // namespace vz_projector +} +function packRgbIntoUint8Array( + rgbArray: Uint8Array, + labelIndex: number, + r: number, + g: number, + b: number +) { + rgbArray[labelIndex * 3] = r; + rgbArray[labelIndex * 3 + 1] = g; + rgbArray[labelIndex * 3 + 2] = b; +} +function styleRgbFromHexColor(hex: number): [number, number, number] { + const c = new THREE.Color(hex); + return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0]; +} +function getDefaultPointInPolylineColor( + index: number, + totalPoints: number +): THREE.Color { + let hue = + POLYLINE_START_HUE + + ((POLYLINE_END_HUE - POLYLINE_START_HUE) * index) / totalPoints; + let rgb = d3.hsl(hue, POLYLINE_SATURATION, POLYLINE_LIGHTNESS).rgb(); + return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255); +} +/** + * Normalizes the distance so it can be visually encoded with color. + * The normalization depends on the distance metric (cosine vs euclidean). + */ +export function normalizeDist( + distFunc: DistanceFunction, + d: number, + minDist: number +): number { + return distFunc === vector.dist ? minDist / d : 1 - d; +} +/** Normalizes and encodes the provided distance with color. */ +export function dist2color( + distFunc: DistanceFunction, + d: number, + minDist: number +): string { + return NN_COLOR_SCALE(normalizeDist(distFunc, d, minDist)); +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/renderContext.ts b/tensorboard/plugins/projector/polymer3/vz_projector/renderContext.ts index 40db686b09..e84904d862 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/renderContext.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/renderContext.ts @@ -12,53 +12,49 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - /** - * LabelRenderParams describes the set of points that should have labels - * rendered next to them. - */ - export class LabelRenderParams { - constructor( - public pointIndices: Float32Array, - public labelStrings: string[], - public scaleFactors: Float32Array, - public useSceneOpacityFlags: Int8Array, - public defaultFontSize: number, - public fillColors: Uint8Array, - public strokeColors: Uint8Array - ) {} - } +import * as THREE from 'three'; - /** Details about the camera projection being used to render the scene. */ - export enum CameraType { - Perspective, - Orthographic, - } - - /** - * RenderContext contains all of the state required to color and render the data - * set. ScatterPlot passes this to every attached visualizer as part of the - * render callback. - * TODO(@charlesnicholson): This should only contain the data that's changed between - * each frame. Data like colors / scale factors / labels should be reapplied - * only when they change. - */ - export class RenderContext { - constructor( - public camera: THREE.Camera, - public cameraType: CameraType, - public cameraTarget: THREE.Vector3, - public screenWidth: number, - public screenHeight: number, - public nearestCameraSpacePointZ: number, - public farthestCameraSpacePointZ: number, - public backgroundColor: number, - public pointColors: Float32Array, - public pointScaleFactors: Float32Array, - public labels: LabelRenderParams, - public polylineColors: {[polylineIndex: number]: Float32Array}, - public polylineOpacities: Float32Array, - public polylineWidths: Float32Array - ) {} - } -} // namespace vz_projector +export class LabelRenderParams { + constructor( + public pointIndices: Float32Array, + public labelStrings: string[], + public scaleFactors: Float32Array, + public useSceneOpacityFlags: Int8Array, + public defaultFontSize: number, + public fillColors: Uint8Array, + public strokeColors: Uint8Array + ) {} +} +/** Details about the camera projection being used to render the scene. */ +export enum CameraType { + Perspective, + Orthographic, +} +/** + * RenderContext contains all of the state required to color and render the data + * set. ScatterPlot passes this to every attached visualizer as part of the + * render callback. + * TODO(@charlesnicholson): This should only contain the data that's changed between + * each frame. Data like colors / scale factors / labels should be reapplied + * only when they change. + */ +export class RenderContext { + constructor( + public camera: THREE.Camera, + public cameraType: CameraType, + public cameraTarget: THREE.Vector3, + public screenWidth: number, + public screenHeight: number, + public nearestCameraSpacePointZ: number, + public farthestCameraSpacePointZ: number, + public backgroundColor: number, + public pointColors: Float32Array, + public pointScaleFactors: Float32Array, + public labels: LabelRenderParams, + public polylineColors: { + [polylineIndex: number]: Float32Array; + }, + public polylineOpacities: Float32Array, + public polylineWidths: Float32Array + ) {} +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlot.ts b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlot.ts index 6879c1232d..6d086446b4 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlot.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlot.ts @@ -12,798 +12,715 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - const BACKGROUND_COLOR = 0xffffff; - - /** - * The length of the cube (diameter of the circumscribing sphere) where all the - * points live. - */ - const CUBE_LENGTH = 2; - const MAX_ZOOM = 5 * CUBE_LENGTH; - const MIN_ZOOM = 0.025 * CUBE_LENGTH; - - // Constants relating to the camera parameters. - const PERSP_CAMERA_FOV_VERTICAL = 70; - const PERSP_CAMERA_NEAR_CLIP_PLANE = 0.01; - const PERSP_CAMERA_FAR_CLIP_PLANE = 100; - const ORTHO_CAMERA_FRUSTUM_HALF_EXTENT = 1.2; - - // Key presses. - const SHIFT_KEY = 16; - const CTRL_KEY = 17; - - const START_CAMERA_POS_3D = new THREE.Vector3(0.45, 0.9, 1.6); - const START_CAMERA_TARGET_3D = new THREE.Vector3(0, 0, 0); - const START_CAMERA_POS_2D = new THREE.Vector3(0, 0, 4); - const START_CAMERA_TARGET_2D = new THREE.Vector3(0, 0, 0); - - const ORBIT_MOUSE_ROTATION_SPEED = 1; - const ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS = 7; - - export type OnCameraMoveListener = ( - cameraPosition: THREE.Vector3, - cameraTarget: THREE.Vector3 - ) => void; - - /** Supported modes of interaction. */ - export enum MouseMode { - AREA_SELECT, - CAMERA_AND_CLICK_SELECT, +import * as THREE from 'three'; +import {OrbitControls} from 'three/examples/jsm/controls/OrbitControls'; + +import * as vector from './vector'; +import * as util from './util'; +import {ProjectorEventContext} from './projectorEventContext'; +import {CameraType, RenderContext, LabelRenderParams} from './renderContext'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import { + ScatterBoundingBox, + ScatterPlotRectangleSelector, +} from './scatterPlotRectangleSelector'; + +const BACKGROUND_COLOR = 16777215; +/** + * The length of the cube (diameter of the circumscribing sphere) where all the + * points live. + */ +const CUBE_LENGTH = 2; +const MAX_ZOOM = 5 * CUBE_LENGTH; +const MIN_ZOOM = 0.025 * CUBE_LENGTH; +// Constants relating to the camera parameters. +const PERSP_CAMERA_FOV_VERTICAL = 70; +const PERSP_CAMERA_NEAR_CLIP_PLANE = 0.01; +const PERSP_CAMERA_FAR_CLIP_PLANE = 100; +const ORTHO_CAMERA_FRUSTUM_HALF_EXTENT = 1.2; +// Key presses. +const SHIFT_KEY = 16; +const CTRL_KEY = 17; +const START_CAMERA_POS_3D = new THREE.Vector3(0.45, 0.9, 1.6); +const START_CAMERA_TARGET_3D = new THREE.Vector3(0, 0, 0); +const START_CAMERA_POS_2D = new THREE.Vector3(0, 0, 4); +const START_CAMERA_TARGET_2D = new THREE.Vector3(0, 0, 0); +const ORBIT_MOUSE_ROTATION_SPEED = 1; +const ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS = 7; +export type OnCameraMoveListener = ( + cameraPosition: THREE.Vector3, + cameraTarget: THREE.Vector3 +) => void; +/** Supported modes of interaction. */ +export enum MouseMode { + AREA_SELECT, + CAMERA_AND_CLICK_SELECT, +} +/** Defines a camera, suitable for serialization. */ +export class CameraDef { + orthographic: boolean = false; + position: vector.Point3D; + target: vector.Point3D; + zoom: number; +} +/** + * Maintains a three.js instantiation and context, + * animation state, and all other logic that's + * independent of how a 3D scatter plot is actually rendered. Also holds an + * array of visualizers and dispatches application events to them. + */ +export class ScatterPlot { + private visualizers: ScatterPlotVisualizer[] = []; + private onCameraMoveListeners: OnCameraMoveListener[] = []; + private height: number; + private width: number; + private mouseMode: MouseMode; + private backgroundColor: number = BACKGROUND_COLOR; + private dimensionality: number = 3; + private renderer: THREE.WebGLRenderer; + private scene: THREE.Scene; + private pickingTexture: THREE.WebGLRenderTarget; + private light: THREE.PointLight; + private cameraDef: CameraDef = null; + private camera: THREE.Camera; + private orbitAnimationOnNextCameraCreation: boolean = false; + private orbitCameraControls: any; + private orbitAnimationId: number; + private worldSpacePointPositions: Float32Array; + private pointColors: Float32Array; + private pointScaleFactors: Float32Array; + private labels: LabelRenderParams; + private polylineColors: { + [polylineIndex: number]: Float32Array; + }; + private polylineOpacities: Float32Array; + private polylineWidths: Float32Array; + private selecting = false; + private nearestPoint: number; + private mouseIsDown = false; + private isDragSequence = false; + private rectangleSelector: ScatterPlotRectangleSelector; + constructor( + private container: HTMLElement, + private projectorEventContext: ProjectorEventContext + ) { + this.getLayoutValues(); + this.scene = new THREE.Scene(); + this.renderer = new THREE.WebGLRenderer({ + alpha: true, + premultipliedAlpha: false, + antialias: false, + }); + this.renderer.setClearColor(BACKGROUND_COLOR, 1); + this.container.appendChild(this.renderer.domElement); + this.light = new THREE.PointLight(16772287, 1, 0); + this.scene.add(this.light); + this.setDimensions(3); + this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality)); + this.renderer.render(this.scene, this.camera); + this.rectangleSelector = new ScatterPlotRectangleSelector( + this.container, + (boundingBox: ScatterBoundingBox) => this.selectBoundingBox(boundingBox) + ); + this.addInteractionListeners(); } - - /** Defines a camera, suitable for serialization. */ - export class CameraDef { - orthographic: boolean = false; - position: vector.Point3D; - target: vector.Point3D; - zoom: number; + private addInteractionListeners() { + this.container.addEventListener('mousemove', this.onMouseMove.bind(this)); + this.container.addEventListener('mousedown', this.onMouseDown.bind(this)); + this.container.addEventListener('mouseup', this.onMouseUp.bind(this)); + this.container.addEventListener('click', this.onClick.bind(this)); + window.addEventListener('keydown', this.onKeyDown.bind(this), false); + window.addEventListener('keyup', this.onKeyUp.bind(this), false); } - - /** - * Maintains a three.js instantiation and context, - * animation state, and all other logic that's - * independent of how a 3D scatter plot is actually rendered. Also holds an - * array of visualizers and dispatches application events to them. - */ - export class ScatterPlot { - private visualizers: ScatterPlotVisualizer[] = []; - - private onCameraMoveListeners: OnCameraMoveListener[] = []; - - private height: number; - private width: number; - - private mouseMode: MouseMode; - private backgroundColor: number = BACKGROUND_COLOR; - - private dimensionality: number = 3; - private renderer: THREE.WebGLRenderer; - - private scene: THREE.Scene; - private pickingTexture: THREE.WebGLRenderTarget; - private light: THREE.PointLight; - - private cameraDef: CameraDef = null; - private camera: THREE.Camera; - private orbitAnimationOnNextCameraCreation: boolean = false; - private orbitCameraControls: any; - private orbitAnimationId: number; - - private worldSpacePointPositions: Float32Array; - private pointColors: Float32Array; - private pointScaleFactors: Float32Array; - private labels: LabelRenderParams; - private polylineColors: {[polylineIndex: number]: Float32Array}; - private polylineOpacities: Float32Array; - private polylineWidths: Float32Array; - - private selecting = false; - private nearestPoint: number; - private mouseIsDown = false; - private isDragSequence = false; - private rectangleSelector: ScatterPlotRectangleSelector; - - constructor( - private container: HTMLElement, - private projectorEventContext: ProjectorEventContext - ) { - this.getLayoutValues(); - - this.scene = new THREE.Scene(); - this.renderer = new THREE.WebGLRenderer({ - alpha: true, - premultipliedAlpha: false, - antialias: false, - }); - this.renderer.setClearColor(BACKGROUND_COLOR, 1); - this.container.appendChild(this.renderer.domElement); - this.light = new THREE.PointLight(0xffecbf, 1, 0); - this.scene.add(this.light); - - this.setDimensions(3); - this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality)); - this.renderer.render(this.scene, this.camera); - - this.rectangleSelector = new ScatterPlotRectangleSelector( - this.container, - (boundingBox: ScatterBoundingBox) => this.selectBoundingBox(boundingBox) + private addCameraControlsEventListeners(cameraControls: any) { + // Start is called when the user stars interacting with + // controls. + cameraControls.addEventListener('start', () => { + this.stopOrbitAnimation(); + this.onCameraMoveListeners.forEach((l) => + l(this.camera.position, cameraControls.target) ); - this.addInteractionListeners(); - } - - private addInteractionListeners() { - this.container.addEventListener('mousemove', this.onMouseMove.bind(this)); - this.container.addEventListener('mousedown', this.onMouseDown.bind(this)); - this.container.addEventListener('mouseup', this.onMouseUp.bind(this)); - this.container.addEventListener('click', this.onClick.bind(this)); - window.addEventListener('keydown', this.onKeyDown.bind(this), false); - window.addEventListener('keyup', this.onKeyUp.bind(this), false); - } - - private addCameraControlsEventListeners(cameraControls: any) { - // Start is called when the user stars interacting with - // controls. - cameraControls.addEventListener('start', () => { - this.stopOrbitAnimation(); - this.onCameraMoveListeners.forEach((l) => - l(this.camera.position, cameraControls.target) - ); - }); - - // Change is called everytime the user interacts with the controls. - cameraControls.addEventListener('change', () => { - this.render(); - }); - - // End is called when the user stops interacting with the - // controls (e.g. on mouse up, after dragging). - cameraControls.addEventListener('end', () => {}); - } - - private makeOrbitControls( - camera: THREE.Camera, - cameraDef: CameraDef, - cameraIs3D: boolean - ) { - if (this.orbitCameraControls != null) { - this.orbitCameraControls.dispose(); - } - const occ = new (THREE as any).OrbitControls( - camera, - this.renderer.domElement + }); + // Change is called everytime the user interacts with the controls. + cameraControls.addEventListener('change', () => { + this.render(); + }); + // End is called when the user stops interacting with the + // controls (e.g. on mouse up, after dragging). + cameraControls.addEventListener('end', () => {}); + } + private makeOrbitControls( + camera: THREE.Camera, + cameraDef: CameraDef, + cameraIs3D: boolean + ) { + if (this.orbitCameraControls != null) { + this.orbitCameraControls.dispose(); + } + const occ = new OrbitControls(camera, this.renderer.domElement) as any; + occ.target0 = new THREE.Vector3( + cameraDef.target[0], + cameraDef.target[1], + cameraDef.target[2] + ); + occ.position0 = new THREE.Vector3().copy(camera.position); + occ.zoom0 = cameraDef.zoom; + occ.enableRotate = cameraIs3D; + occ.autoRotate = false; + occ.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED; + if (cameraIs3D) { + occ.mouseButtons.ORBIT = THREE.MOUSE.LEFT; + occ.mouseButtons.PAN = THREE.MOUSE.RIGHT; + } else { + occ.mouseButtons.ORBIT = null; + occ.mouseButtons.PAN = THREE.MOUSE.LEFT; + } + occ.reset(); + this.camera = camera; + this.orbitCameraControls = occ; + this.addCameraControlsEventListeners(this.orbitCameraControls); + } + private makeCamera3D(cameraDef: CameraDef, w: number, h: number) { + let camera: THREE.PerspectiveCamera; + { + const aspectRatio = w / h; + camera = new THREE.PerspectiveCamera( + PERSP_CAMERA_FOV_VERTICAL, + aspectRatio, + PERSP_CAMERA_NEAR_CLIP_PLANE, + PERSP_CAMERA_FAR_CLIP_PLANE ); - occ.target0 = new THREE.Vector3( - cameraDef.target[0], - cameraDef.target[1], - cameraDef.target[2] + camera.position.set( + cameraDef.position[0], + cameraDef.position[1], + cameraDef.position[2] ); - occ.position0 = new THREE.Vector3().copy(camera.position); - occ.zoom0 = cameraDef.zoom; - occ.enableRotate = cameraIs3D; - occ.autoRotate = false; - occ.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED; - if (cameraIs3D) { - occ.mouseButtons.ORBIT = THREE.MOUSE.LEFT; - occ.mouseButtons.PAN = THREE.MOUSE.RIGHT; - } else { - occ.mouseButtons.ORBIT = null; - occ.mouseButtons.PAN = THREE.MOUSE.LEFT; - } - occ.reset(); - - this.camera = camera; - this.orbitCameraControls = occ; - this.addCameraControlsEventListeners(this.orbitCameraControls); - } - - private makeCamera3D(cameraDef: CameraDef, w: number, h: number) { - let camera: THREE.PerspectiveCamera; - { - const aspectRatio = w / h; - camera = new THREE.PerspectiveCamera( - PERSP_CAMERA_FOV_VERTICAL, - aspectRatio, - PERSP_CAMERA_NEAR_CLIP_PLANE, - PERSP_CAMERA_FAR_CLIP_PLANE - ); - camera.position.set( - cameraDef.position[0], - cameraDef.position[1], - cameraDef.position[2] - ); - const at = new THREE.Vector3( - cameraDef.target[0], - cameraDef.target[1], - cameraDef.target[2] - ); - camera.lookAt(at); - camera.zoom = cameraDef.zoom; - camera.updateProjectionMatrix(); - } - this.camera = camera; - this.makeOrbitControls(camera, cameraDef, true); - } - - private makeCamera2D(cameraDef: CameraDef, w: number, h: number) { - let camera: THREE.OrthographicCamera; - const target = new THREE.Vector3( + const at = new THREE.Vector3( cameraDef.target[0], cameraDef.target[1], cameraDef.target[2] ); - { - const aspectRatio = w / h; - let left = -ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; - let right = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; - let bottom = -ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; - let top = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; - // Scale up the larger of (w, h) to match the aspect ratio. - if (aspectRatio > 1) { - left *= aspectRatio; - right *= aspectRatio; - } else { - top /= aspectRatio; - bottom /= aspectRatio; - } - camera = new THREE.OrthographicCamera( - left, - right, - top, - bottom, - -1000, - 1000 - ); - camera.position.set( - cameraDef.position[0], - cameraDef.position[1], - cameraDef.position[2] - ); - camera.up = new THREE.Vector3(0, 1, 0); - camera.lookAt(target); - camera.zoom = cameraDef.zoom; - camera.updateProjectionMatrix(); - } - this.camera = camera; - this.makeOrbitControls(camera, cameraDef, false); + camera.lookAt(at); + camera.zoom = cameraDef.zoom; + camera.updateProjectionMatrix(); } - - private makeDefaultCameraDef(dimensionality: number): CameraDef { - const def = new CameraDef(); - def.orthographic = dimensionality === 2; - def.zoom = 1.0; - if (def.orthographic) { - def.position = [ - START_CAMERA_POS_2D.x, - START_CAMERA_POS_2D.y, - START_CAMERA_POS_2D.z, - ]; - def.target = [ - START_CAMERA_TARGET_2D.x, - START_CAMERA_TARGET_2D.y, - START_CAMERA_TARGET_2D.z, - ]; - } else { - def.position = [ - START_CAMERA_POS_3D.x, - START_CAMERA_POS_3D.y, - START_CAMERA_POS_3D.z, - ]; - def.target = [ - START_CAMERA_TARGET_3D.x, - START_CAMERA_TARGET_3D.y, - START_CAMERA_TARGET_3D.z, - ]; - } - return def; - } - - /** Recreate the scatter plot camera from a definition structure. */ - recreateCamera(cameraDef: CameraDef) { - if (cameraDef.orthographic) { - this.makeCamera2D(cameraDef, this.width, this.height); + this.camera = camera; + this.makeOrbitControls(camera, cameraDef, true); + } + private makeCamera2D(cameraDef: CameraDef, w: number, h: number) { + let camera: THREE.OrthographicCamera; + const target = new THREE.Vector3( + cameraDef.target[0], + cameraDef.target[1], + cameraDef.target[2] + ); + { + const aspectRatio = w / h; + let left = -ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; + let right = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; + let bottom = -ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; + let top = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; + // Scale up the larger of (w, h) to match the aspect ratio. + if (aspectRatio > 1) { + left *= aspectRatio; + right *= aspectRatio; } else { - this.makeCamera3D(cameraDef, this.width, this.height); - } - this.orbitCameraControls.minDistance = MIN_ZOOM; - this.orbitCameraControls.maxDistance = MAX_ZOOM; - this.orbitCameraControls.update(); - if (this.orbitAnimationOnNextCameraCreation) { - this.startOrbitAnimation(); - } - } - - private onClick(e?: MouseEvent, notify = true) { - if (e && this.selecting) { - return; - } - // Only call event handlers if the click originated from the scatter plot. - if (!this.isDragSequence && notify) { - const selection = this.nearestPoint != null ? [this.nearestPoint] : []; - this.projectorEventContext.notifySelectionChanged(selection); - } - this.isDragSequence = false; - this.render(); - } - - private onMouseDown(e: MouseEvent) { - this.isDragSequence = false; - this.mouseIsDown = true; - if (this.selecting) { - this.orbitCameraControls.enabled = false; - this.rectangleSelector.onMouseDown(e.offsetX, e.offsetY); - this.setNearestPointToMouse(e); - } else if ( - !e.ctrlKey && - this.sceneIs3D() && - this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.RIGHT - ) { - // The user happened to press the ctrl key when the tab was active, - // unpressed the ctrl when the tab was inactive, and now he/she - // is back to the projector tab. - this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT; - this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT; - } else if ( - e.ctrlKey && - this.sceneIs3D() && - this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.LEFT - ) { - // Similarly to the situation above. - this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT; - this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT; - } - } - - /** When we stop dragging/zooming, return to normal behavior. */ - private onMouseUp(e: any) { - if (this.selecting) { - this.orbitCameraControls.enabled = true; - this.rectangleSelector.onMouseUp(); - this.render(); - } - this.mouseIsDown = false; - } - - /** - * When the mouse moves, find the nearest point (if any) and send it to the - * hoverlisteners (usually called from embedding.ts) - */ - private onMouseMove(e: MouseEvent) { - this.isDragSequence = this.mouseIsDown; - // Depending if we're selecting or just navigating, handle accordingly. - if (this.selecting && this.mouseIsDown) { - this.rectangleSelector.onMouseMove(e.offsetX, e.offsetY); - this.render(); - } else if (!this.mouseIsDown) { - this.setNearestPointToMouse(e); - this.projectorEventContext.notifyHoverOverPoint(this.nearestPoint); - } - } - - /** For using ctrl + left click as right click, and for circle select */ - private onKeyDown(e: any) { - // If ctrl is pressed, use left click to orbit - if (e.keyCode === CTRL_KEY && this.sceneIs3D()) { - this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT; - this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT; - } - - // If shift is pressed, start selecting - if (e.keyCode === SHIFT_KEY) { - this.selecting = true; - this.container.style.cursor = 'crosshair'; - } - } - - /** For using ctrl + left click as right click, and for circle select */ - private onKeyUp(e: any) { - if (e.keyCode === CTRL_KEY && this.sceneIs3D()) { - this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT; - this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT; - } - - // If shift is released, stop selecting - if (e.keyCode === SHIFT_KEY) { - this.selecting = this.getMouseMode() === MouseMode.AREA_SELECT; - if (!this.selecting) { - this.container.style.cursor = 'default'; - } - this.render(); - } - } - - /** - * Returns a list of indices of points in a bounding box from the picking - * texture. - * @param boundingBox The bounding box to select from. - */ - private getPointIndicesFromPickingTexture( - boundingBox: ScatterBoundingBox - ): number[] { - if (this.worldSpacePointPositions == null) { - return null; - } - const pointCount = this.worldSpacePointPositions.length / 3; - const dpr = window.devicePixelRatio || 1; - const x = Math.floor(boundingBox.x * dpr); - const y = Math.floor(boundingBox.y * dpr); - const width = Math.floor(boundingBox.width * dpr); - const height = Math.floor(boundingBox.height * dpr); - - // Create buffer for reading all of the pixels from the texture. - let pixelBuffer = new Uint8Array(width * height * 4); - - // Read the pixels from the bounding box. - this.renderer.readRenderTargetPixels( - this.pickingTexture, - x, - this.pickingTexture.height - y, - width, - height, - pixelBuffer + top /= aspectRatio; + bottom /= aspectRatio; + } + camera = new THREE.OrthographicCamera( + left, + right, + top, + bottom, + -1000, + 1000 ); - - // Keep a flat list of each point and whether they are selected or not. This - // approach is more efficient than using an object keyed by the index. - let pointIndicesSelection = new Uint8Array( - this.worldSpacePointPositions.length + camera.position.set( + cameraDef.position[0], + cameraDef.position[1], + cameraDef.position[2] ); - for (let i = 0; i < width * height; i++) { - const id = - (pixelBuffer[i * 4] << 16) | - (pixelBuffer[i * 4 + 1] << 8) | - pixelBuffer[i * 4 + 2]; - if (id !== 0xffffff && id < pointCount) { - pointIndicesSelection[id] = 1; - } - } - let pointIndices: number[] = []; - for (let i = 0; i < pointIndicesSelection.length; i++) { - if (pointIndicesSelection[i] === 1) { - pointIndices.push(i); - } - } - - return pointIndices; + camera.up = new THREE.Vector3(0, 1, 0); + camera.lookAt(target); + camera.zoom = cameraDef.zoom; + camera.updateProjectionMatrix(); } - - private selectBoundingBox(boundingBox: ScatterBoundingBox) { - let pointIndices = this.getPointIndicesFromPickingTexture(boundingBox); - this.projectorEventContext.notifySelectionChanged(pointIndices); - } - - private setNearestPointToMouse(e: MouseEvent) { - if (this.pickingTexture == null) { - this.nearestPoint = null; - return; - } - const boundingBox: ScatterBoundingBox = { - x: e.offsetX, - y: e.offsetY, - width: 1, - height: 1, - }; - const pointIndices = this.getPointIndicesFromPickingTexture(boundingBox); - this.nearestPoint = pointIndices != null ? pointIndices[0] : null; - } - - private getLayoutValues(): vector.Point2D { - this.width = this.container.offsetWidth; - this.height = Math.max(1, this.container.offsetHeight); - return [this.width, this.height]; - } - - private sceneIs3D(): boolean { - return this.dimensionality === 3; - } - - private remove3dAxisFromScene(): THREE.Object3D { - const axes = this.scene.getObjectByName('axes'); - if (axes != null) { - this.scene.remove(axes); - } - return axes; + this.camera = camera; + this.makeOrbitControls(camera, cameraDef, false); + } + private makeDefaultCameraDef(dimensionality: number): CameraDef { + const def = new CameraDef(); + def.orthographic = dimensionality === 2; + def.zoom = 1; + if (def.orthographic) { + def.position = [ + START_CAMERA_POS_2D.x, + START_CAMERA_POS_2D.y, + START_CAMERA_POS_2D.z, + ]; + def.target = [ + START_CAMERA_TARGET_2D.x, + START_CAMERA_TARGET_2D.y, + START_CAMERA_TARGET_2D.z, + ]; + } else { + def.position = [ + START_CAMERA_POS_3D.x, + START_CAMERA_POS_3D.y, + START_CAMERA_POS_3D.z, + ]; + def.target = [ + START_CAMERA_TARGET_3D.x, + START_CAMERA_TARGET_3D.y, + START_CAMERA_TARGET_3D.z, + ]; + } + return def; + } + /** Recreate the scatter plot camera from a definition structure. */ + recreateCamera(cameraDef: CameraDef) { + if (cameraDef.orthographic) { + this.makeCamera2D(cameraDef, this.width, this.height); + } else { + this.makeCamera3D(cameraDef, this.width, this.height); + } + this.orbitCameraControls.minDistance = MIN_ZOOM; + this.orbitCameraControls.maxDistance = MAX_ZOOM; + this.orbitCameraControls.update(); + if (this.orbitAnimationOnNextCameraCreation) { + this.startOrbitAnimation(); } - - private add3dAxis() { - const axes = new (THREE as any).AxesHelper(); - axes.name = 'axes'; - this.scene.add(axes); + } + private onClick(e?: MouseEvent, notify = true) { + if (e && this.selecting) { + return; } - - /** Set 2d vs 3d mode. */ - setDimensions(dimensionality: number) { - if (dimensionality !== 2 && dimensionality !== 3) { - throw new RangeError('dimensionality must be 2 or 3'); - } - this.dimensionality = dimensionality; - - const def = this.cameraDef || this.makeDefaultCameraDef(dimensionality); - this.recreateCamera(def); - - this.remove3dAxisFromScene(); - if (dimensionality === 3) { - this.add3dAxis(); - } + // Only call event handlers if the click originated from the scatter plot. + if (!this.isDragSequence && notify) { + const selection = this.nearestPoint != null ? [this.nearestPoint] : []; + this.projectorEventContext.notifySelectionChanged(selection); } - - /** Gets the current camera information, suitable for serialization. */ - getCameraDef(): CameraDef { - const def = new CameraDef(); - const pos = this.camera.position; - const tgt = this.orbitCameraControls.target; - def.orthographic = !this.sceneIs3D(); - def.position = [pos.x, pos.y, pos.z]; - def.target = [tgt.x, tgt.y, tgt.z]; - def.zoom = (this.camera as any).zoom; - return def; - } - - /** Sets parameters for the next camera recreation. */ - setCameraParametersForNextCameraCreation( - def: CameraDef, - orbitAnimation: boolean + this.isDragSequence = false; + this.render(); + } + private onMouseDown(e: MouseEvent) { + this.isDragSequence = false; + this.mouseIsDown = true; + if (this.selecting) { + this.orbitCameraControls.enabled = false; + this.rectangleSelector.onMouseDown(e.offsetX, e.offsetY); + this.setNearestPointToMouse(e); + } else if ( + !e.ctrlKey && + this.sceneIs3D() && + this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.RIGHT ) { - this.cameraDef = def; - this.orbitAnimationOnNextCameraCreation = orbitAnimation; - } - - /** Gets the current camera position. */ - getCameraPosition(): vector.Point3D { - const currPos = this.camera.position; - return [currPos.x, currPos.y, currPos.z]; + // The user happened to press the ctrl key when the tab was active, + // unpressed the ctrl when the tab was inactive, and now he/she + // is back to the projector tab. + this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT; + this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT; + } else if ( + e.ctrlKey && + this.sceneIs3D() && + this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.LEFT + ) { + // Similarly to the situation above. + this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT; + this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT; } - - /** Gets the current camera target. */ - getCameraTarget(): vector.Point3D { - let currTarget = this.orbitCameraControls.target; - return [currTarget.x, currTarget.y, currTarget.z]; + } + /** When we stop dragging/zooming, return to normal behavior. */ + private onMouseUp(e: any) { + if (this.selecting) { + this.orbitCameraControls.enabled = true; + this.rectangleSelector.onMouseUp(); + this.render(); } - - /** Sets up the camera from given position and target coordinates. */ - setCameraPositionAndTarget( - position: vector.Point3D, - target: vector.Point3D - ) { - this.stopOrbitAnimation(); - this.camera.position.set(position[0], position[1], position[2]); - this.orbitCameraControls.target.set(target[0], target[1], target[2]); - this.orbitCameraControls.update(); + this.mouseIsDown = false; + } + /** + * When the mouse moves, find the nearest point (if any) and send it to the + * hoverlisteners (usually called from embedding.ts) + */ + private onMouseMove(e: MouseEvent) { + this.isDragSequence = this.mouseIsDown; + // Depending if we're selecting or just navigating, handle accordingly. + if (this.selecting && this.mouseIsDown) { + this.rectangleSelector.onMouseMove(e.offsetX, e.offsetY); this.render(); + } else if (!this.mouseIsDown) { + this.setNearestPointToMouse(e); + this.projectorEventContext.notifyHoverOverPoint(this.nearestPoint); } - - /** Starts orbiting the camera around its current lookat target. */ - startOrbitAnimation() { - if (!this.sceneIs3D()) { - return; - } - if (this.orbitAnimationId != null) { - this.stopOrbitAnimation(); + } + /** For using ctrl + left click as right click, and for circle select */ + private onKeyDown(e: any) { + // If ctrl is pressed, use left click to orbit + if (e.keyCode === CTRL_KEY && this.sceneIs3D()) { + this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT; + this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT; + } + // If shift is pressed, start selecting + if (e.keyCode === SHIFT_KEY) { + this.selecting = true; + this.container.style.cursor = 'crosshair'; + } + } + /** For using ctrl + left click as right click, and for circle select */ + private onKeyUp(e: any) { + if (e.keyCode === CTRL_KEY && this.sceneIs3D()) { + this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT; + this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT; + } + // If shift is released, stop selecting + if (e.keyCode === SHIFT_KEY) { + this.selecting = this.getMouseMode() === MouseMode.AREA_SELECT; + if (!this.selecting) { + this.container.style.cursor = 'default'; } - this.orbitCameraControls.autoRotate = true; - this.orbitCameraControls.rotateSpeed = ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS; - this.updateOrbitAnimation(); + this.render(); } - - private updateOrbitAnimation() { - this.orbitCameraControls.update(); - this.orbitAnimationId = requestAnimationFrame(() => - this.updateOrbitAnimation() - ); + } + /** + * Returns a list of indices of points in a bounding box from the picking + * texture. + * @param boundingBox The bounding box to select from. + */ + private getPointIndicesFromPickingTexture( + boundingBox: ScatterBoundingBox + ): number[] { + if (this.worldSpacePointPositions == null) { + return null; + } + const pointCount = this.worldSpacePointPositions.length / 3; + const dpr = window.devicePixelRatio || 1; + const x = Math.floor(boundingBox.x * dpr); + const y = Math.floor(boundingBox.y * dpr); + const width = Math.floor(boundingBox.width * dpr); + const height = Math.floor(boundingBox.height * dpr); + // Create buffer for reading all of the pixels from the texture. + let pixelBuffer = new Uint8Array(width * height * 4); + // Read the pixels from the bounding box. + this.renderer.readRenderTargetPixels( + this.pickingTexture, + x, + this.pickingTexture.height - y, + width, + height, + pixelBuffer + ); + // Keep a flat list of each point and whether they are selected or not. This + // approach is more efficient than using an object keyed by the index. + let pointIndicesSelection = new Uint8Array( + this.worldSpacePointPositions.length + ); + for (let i = 0; i < width * height; i++) { + const id = + (pixelBuffer[i * 4] << 16) | + (pixelBuffer[i * 4 + 1] << 8) | + pixelBuffer[i * 4 + 2]; + if (id !== 16777215 && id < pointCount) { + pointIndicesSelection[id] = 1; + } + } + let pointIndices: number[] = []; + for (let i = 0; i < pointIndicesSelection.length; i++) { + if (pointIndicesSelection[i] === 1) { + pointIndices.push(i); + } + } + return pointIndices; + } + private selectBoundingBox(boundingBox: ScatterBoundingBox) { + let pointIndices = this.getPointIndicesFromPickingTexture(boundingBox); + this.projectorEventContext.notifySelectionChanged(pointIndices); + } + private setNearestPointToMouse(e: MouseEvent) { + if (this.pickingTexture == null) { + this.nearestPoint = null; + return; + } + const boundingBox: ScatterBoundingBox = { + x: e.offsetX, + y: e.offsetY, + width: 1, + height: 1, + }; + const pointIndices = this.getPointIndicesFromPickingTexture(boundingBox); + this.nearestPoint = pointIndices != null ? pointIndices[0] : null; + } + private getLayoutValues(): vector.Point2D { + this.width = this.container.offsetWidth; + this.height = Math.max(1, this.container.offsetHeight); + return [this.width, this.height]; + } + private sceneIs3D(): boolean { + return this.dimensionality === 3; + } + private remove3dAxisFromScene(): THREE.Object3D { + const axes = this.scene.getObjectByName('axes'); + if (axes != null) { + this.scene.remove(axes); } - - /** Stops the orbiting animation on the camera. */ - stopOrbitAnimation() { - this.orbitCameraControls.autoRotate = false; - this.orbitCameraControls.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED; - if (this.orbitAnimationId != null) { - cancelAnimationFrame(this.orbitAnimationId); - this.orbitAnimationId = null; - } + return axes; + } + private add3dAxis() { + const axes = new (THREE as any).AxesHelper(); + axes.name = 'axes'; + this.scene.add(axes); + } + /** Set 2d vs 3d mode. */ + setDimensions(dimensionality: number) { + if (dimensionality !== 2 && dimensionality !== 3) { + throw new RangeError('dimensionality must be 2 or 3'); + } + this.dimensionality = dimensionality; + const def = this.cameraDef || this.makeDefaultCameraDef(dimensionality); + this.recreateCamera(def); + this.remove3dAxisFromScene(); + if (dimensionality === 3) { + this.add3dAxis(); } - - /** Adds a visualizer to the set, will start dispatching events to it */ - addVisualizer(visualizer: ScatterPlotVisualizer) { - if (this.scene) { - visualizer.setScene(this.scene); - } - visualizer.onResize(this.width, this.height); - visualizer.onPointPositionsChanged(this.worldSpacePointPositions); - this.visualizers.push(visualizer); + } + /** Gets the current camera information, suitable for serialization. */ + getCameraDef(): CameraDef { + const def = new CameraDef(); + const pos = this.camera.position; + const tgt = this.orbitCameraControls.target; + def.orthographic = !this.sceneIs3D(); + def.position = [pos.x, pos.y, pos.z]; + def.target = [tgt.x, tgt.y, tgt.z]; + def.zoom = (this.camera as any).zoom; + return def; + } + /** Sets parameters for the next camera recreation. */ + setCameraParametersForNextCameraCreation( + def: CameraDef, + orbitAnimation: boolean + ) { + this.cameraDef = def; + this.orbitAnimationOnNextCameraCreation = orbitAnimation; + } + /** Gets the current camera position. */ + getCameraPosition(): vector.Point3D { + const currPos = this.camera.position; + return [currPos.x, currPos.y, currPos.z]; + } + /** Gets the current camera target. */ + getCameraTarget(): vector.Point3D { + let currTarget = this.orbitCameraControls.target; + return [currTarget.x, currTarget.y, currTarget.z]; + } + /** Sets up the camera from given position and target coordinates. */ + setCameraPositionAndTarget(position: vector.Point3D, target: vector.Point3D) { + this.stopOrbitAnimation(); + this.camera.position.set(position[0], position[1], position[2]); + this.orbitCameraControls.target.set(target[0], target[1], target[2]); + this.orbitCameraControls.update(); + this.render(); + } + /** Starts orbiting the camera around its current lookat target. */ + startOrbitAnimation() { + if (!this.sceneIs3D()) { + return; } - - /** Removes all visualizers attached to this scatter plot. */ - removeAllVisualizers() { - this.visualizers.forEach((v) => v.dispose()); - this.visualizers = []; + if (this.orbitAnimationId != null) { + this.stopOrbitAnimation(); } - - /** Update scatter plot with a new array of packed xyz point positions. */ - setPointPositions(worldSpacePointPositions: Float32Array) { - this.worldSpacePointPositions = worldSpacePointPositions; - this.visualizers.forEach((v) => - v.onPointPositionsChanged(worldSpacePointPositions) - ); + this.orbitCameraControls.autoRotate = true; + this.orbitCameraControls.rotateSpeed = ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS; + this.updateOrbitAnimation(); + } + private updateOrbitAnimation() { + this.orbitCameraControls.update(); + this.orbitAnimationId = requestAnimationFrame(() => + this.updateOrbitAnimation() + ); + } + /** Stops the orbiting animation on the camera. */ + stopOrbitAnimation() { + this.orbitCameraControls.autoRotate = false; + this.orbitCameraControls.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED; + if (this.orbitAnimationId != null) { + cancelAnimationFrame(this.orbitAnimationId); + this.orbitAnimationId = null; } - - render() { - { - const lightPos = this.camera.position.clone(); - lightPos.x += 1; - lightPos.y += 1; - this.light.position.set(lightPos.x, lightPos.y, lightPos.z); - } - - const cameraType = - this.camera instanceof THREE.PerspectiveCamera - ? CameraType.Perspective - : CameraType.Orthographic; - - let cameraSpacePointExtents: [number, number] = [0, 0]; - if (this.worldSpacePointPositions != null) { - cameraSpacePointExtents = util.getNearFarPoints( - this.worldSpacePointPositions, - this.camera.position, - this.orbitCameraControls.target - ); - } - - const rc = new RenderContext( - this.camera, - cameraType, - this.orbitCameraControls.target, - this.width, - this.height, - cameraSpacePointExtents[0], - cameraSpacePointExtents[1], - this.backgroundColor, - this.pointColors, - this.pointScaleFactors, - this.labels, - this.polylineColors, - this.polylineOpacities, - this.polylineWidths + } + /** Adds a visualizer to the set, will start dispatching events to it */ + addVisualizer(visualizer: ScatterPlotVisualizer) { + if (this.scene) { + visualizer.setScene(this.scene); + } + visualizer.onResize(this.width, this.height); + visualizer.onPointPositionsChanged(this.worldSpacePointPositions); + this.visualizers.push(visualizer); + } + /** Removes all visualizers attached to this scatter plot. */ + removeAllVisualizers() { + this.visualizers.forEach((v) => v.dispose()); + this.visualizers = []; + } + /** Update scatter plot with a new array of packed xyz point positions. */ + setPointPositions(worldSpacePointPositions: Float32Array) { + this.worldSpacePointPositions = worldSpacePointPositions; + this.visualizers.forEach((v) => + v.onPointPositionsChanged(worldSpacePointPositions) + ); + } + render() { + { + const lightPos = this.camera.position.clone(); + lightPos.x += 1; + lightPos.y += 1; + this.light.position.set(lightPos.x, lightPos.y, lightPos.z); + } + const cameraType = + this.camera instanceof THREE.PerspectiveCamera + ? CameraType.Perspective + : CameraType.Orthographic; + let cameraSpacePointExtents: [number, number] = [0, 0]; + if (this.worldSpacePointPositions != null) { + cameraSpacePointExtents = util.getNearFarPoints( + this.worldSpacePointPositions, + this.camera.position, + this.orbitCameraControls.target ); - - // Render first pass to picking target. This render fills pickingTexture - // with colors that are actually point ids, so that sampling the texture at - // the mouse's current x,y coordinates will reveal the data point that the - // mouse is over. - this.visualizers.forEach((v) => v.onPickingRender(rc)); - - { - const axes = this.remove3dAxisFromScene(); - // Render to the pickingTexture when existing. - if (this.pickingTexture) { - this.renderer.setRenderTarget(this.pickingTexture); - } else { - this.renderer.setRenderTarget(null); - } - this.renderer.render(this.scene, this.camera); - - // Set the renderTarget back to the default. + } + const rc = new RenderContext( + this.camera, + cameraType, + this.orbitCameraControls.target, + this.width, + this.height, + cameraSpacePointExtents[0], + cameraSpacePointExtents[1], + this.backgroundColor, + this.pointColors, + this.pointScaleFactors, + this.labels, + this.polylineColors, + this.polylineOpacities, + this.polylineWidths + ); + // Render first pass to picking target. This render fills pickingTexture + // with colors that are actually point ids, so that sampling the texture at + // the mouse's current x,y coordinates will reveal the data point that the + // mouse is over. + this.visualizers.forEach((v) => v.onPickingRender(rc)); + { + const axes = this.remove3dAxisFromScene(); + // Render to the pickingTexture when existing. + if (this.pickingTexture) { + this.renderer.setRenderTarget(this.pickingTexture); + } else { this.renderer.setRenderTarget(null); - if (axes != null) { - this.scene.add(axes); - } } - - // Render second pass to color buffer, to be displayed on the canvas. - this.visualizers.forEach((v) => v.onRender(rc)); this.renderer.render(this.scene, this.camera); - } - - setMouseMode(mouseMode: MouseMode) { - this.mouseMode = mouseMode; - if (mouseMode === MouseMode.AREA_SELECT) { - this.selecting = true; - this.container.style.cursor = 'crosshair'; - } else { - this.selecting = false; - this.container.style.cursor = 'default'; + // Set the renderTarget back to the default. + this.renderer.setRenderTarget(null); + if (axes != null) { + this.scene.add(axes); } } - - /** Set the colors for every data point. (RGB triplets) */ - setPointColors(colors: Float32Array) { - this.pointColors = colors; - } - - /** Set the scale factors for every data point. (scalars) */ - setPointScaleFactors(scaleFactors: Float32Array) { - this.pointScaleFactors = scaleFactors; - } - - /** Set the labels to rendered */ - setLabels(labels: LabelRenderParams) { - this.labels = labels; - } - - /** Set the colors for every data polyline. (RGB triplets) */ - setPolylineColors(colors: {[polylineIndex: number]: Float32Array}) { - this.polylineColors = colors; - } - - setPolylineOpacities(opacities: Float32Array) { - this.polylineOpacities = opacities; + // Render second pass to color buffer, to be displayed on the canvas. + this.visualizers.forEach((v) => v.onRender(rc)); + this.renderer.render(this.scene, this.camera); + } + setMouseMode(mouseMode: MouseMode) { + this.mouseMode = mouseMode; + if (mouseMode === MouseMode.AREA_SELECT) { + this.selecting = true; + this.container.style.cursor = 'crosshair'; + } else { + this.selecting = false; + this.container.style.cursor = 'default'; } - - setPolylineWidths(widths: Float32Array) { - this.polylineWidths = widths; + } + /** Set the colors for every data point. (RGB triplets) */ + setPointColors(colors: Float32Array) { + this.pointColors = colors; + } + /** Set the scale factors for every data point. (scalars) */ + setPointScaleFactors(scaleFactors: Float32Array) { + this.pointScaleFactors = scaleFactors; + } + /** Set the labels to rendered */ + setLabels(labels: LabelRenderParams) { + this.labels = labels; + } + /** Set the colors for every data polyline. (RGB triplets) */ + setPolylineColors(colors: {[polylineIndex: number]: Float32Array}) { + this.polylineColors = colors; + } + setPolylineOpacities(opacities: Float32Array) { + this.polylineOpacities = opacities; + } + setPolylineWidths(widths: Float32Array) { + this.polylineWidths = widths; + } + getMouseMode(): MouseMode { + return this.mouseMode; + } + resetZoom() { + this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality)); + this.render(); + } + setDayNightMode(isNight: boolean) { + const canvases = this.container.querySelectorAll('canvas'); + const filterValue = isNight ? 'invert(100%)' : null; + for (let i = 0; i < canvases.length; i++) { + canvases[i].style.filter = filterValue; } - - getMouseMode(): MouseMode { - return this.mouseMode; + } + resize(render = true) { + const [oldW, oldH] = [this.width, this.height]; + const [newW, newH] = this.getLayoutValues(); + if (this.dimensionality === 3) { + const camera = this.camera as THREE.PerspectiveCamera; + camera.aspect = newW / newH; + camera.updateProjectionMatrix(); + } else { + const camera = this.camera as THREE.OrthographicCamera; + // Scale the ortho frustum by however much the window changed. + const scaleW = newW / oldW; + const scaleH = newH / oldH; + const newCamHalfWidth = ((camera.right - camera.left) * scaleW) / 2; + const newCamHalfHeight = ((camera.top - camera.bottom) * scaleH) / 2; + camera.top = newCamHalfHeight; + camera.bottom = -newCamHalfHeight; + camera.left = -newCamHalfWidth; + camera.right = newCamHalfWidth; + camera.updateProjectionMatrix(); + } + // Accouting for retina displays. + const dpr = window.devicePixelRatio || 1; + this.renderer.setPixelRatio(dpr); + this.renderer.setSize(newW, newH); + // the picking texture needs to be exactly the same as the render texture. + { + const renderCanvasSize = new THREE.Vector2(); + // TODO(stephanwlee): Remove casting to any after three.js typing is + // proper. + (this.renderer as any).getSize(renderCanvasSize); + const pixelRatio = this.renderer.getPixelRatio(); + this.pickingTexture = new THREE.WebGLRenderTarget( + renderCanvasSize.width * pixelRatio, + renderCanvasSize.height * pixelRatio + ); + this.pickingTexture.texture.minFilter = THREE.LinearFilter; } - - resetZoom() { - this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality)); + this.visualizers.forEach((v) => v.onResize(newW, newH)); + if (render) { this.render(); } - - setDayNightMode(isNight: boolean) { - const canvases = this.container.querySelectorAll('canvas'); - const filterValue = isNight ? 'invert(100%)' : null; - for (let i = 0; i < canvases.length; i++) { - canvases[i].style.filter = filterValue; - } - } - - resize(render = true) { - const [oldW, oldH] = [this.width, this.height]; - const [newW, newH] = this.getLayoutValues(); - - if (this.dimensionality === 3) { - const camera = this.camera as THREE.PerspectiveCamera; - camera.aspect = newW / newH; - camera.updateProjectionMatrix(); - } else { - const camera = this.camera as THREE.OrthographicCamera; - // Scale the ortho frustum by however much the window changed. - const scaleW = newW / oldW; - const scaleH = newH / oldH; - const newCamHalfWidth = ((camera.right - camera.left) * scaleW) / 2; - const newCamHalfHeight = ((camera.top - camera.bottom) * scaleH) / 2; - camera.top = newCamHalfHeight; - camera.bottom = -newCamHalfHeight; - camera.left = -newCamHalfWidth; - camera.right = newCamHalfWidth; - camera.updateProjectionMatrix(); - } - - // Accouting for retina displays. - const dpr = window.devicePixelRatio || 1; - this.renderer.setPixelRatio(dpr); - this.renderer.setSize(newW, newH); - - // the picking texture needs to be exactly the same as the render texture. - { - const renderCanvasSize = new THREE.Vector2(); - // TODO(stephanwlee): Remove casting to any after three.js typing is - // proper. - (this.renderer as any).getSize(renderCanvasSize); - const pixelRatio = this.renderer.getPixelRatio(); - this.pickingTexture = new THREE.WebGLRenderTarget( - renderCanvasSize.width * pixelRatio, - renderCanvasSize.height * pixelRatio - ); - this.pickingTexture.texture.minFilter = THREE.LinearFilter; - } - - this.visualizers.forEach((v) => v.onResize(newW, newH)); - - if (render) { - this.render(); - } - } - - onCameraMove(listener: OnCameraMoveListener) { - this.onCameraMoveListeners.push(listener); - } - - clickOnPoint(pointIndex: number) { - this.nearestPoint = pointIndex; - this.onClick(null, false); - } } -} // namespace vz_projector + onCameraMove(listener: OnCameraMoveListener) { + this.onCameraMoveListeners.push(listener); + } + clickOnPoint(pointIndex: number) { + this.nearestPoint = pointIndex; + this.onClick(null, false); + } +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotRectangleSelector.ts b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotRectangleSelector.ts index 7dd563333e..1aadffd2ad 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotRectangleSelector.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotRectangleSelector.ts @@ -12,102 +12,89 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - const FILL = '#dddddd'; - const FILL_OPACITY = 0.2; - const STROKE = '#aaaaaa'; - const STROKE_WIDTH = 2; - const STROKE_DASHARRAY = '10 5'; - - export interface ScatterBoundingBox { - // The bounding box (x, y) position refers to the bottom left corner of the - // rect. - x: number; - y: number; - width: number; - height: number; - } +const FILL = '#dddddd'; +const FILL_OPACITY = 0.2; +const STROKE = '#aaaaaa'; +const STROKE_WIDTH = 2; +const STROKE_DASHARRAY = '10 5'; +export interface ScatterBoundingBox { + // The bounding box (x, y) position refers to the bottom left corner of the + // rect. + x: number; + y: number; + width: number; + height: number; +} +/** + * A class that manages and renders a data selection rectangle. + */ +export class ScatterPlotRectangleSelector { + private svgElement: SVGElement; + private rectElement: SVGRectElement; + private isMouseDown: boolean; + private startCoordinates: [number, number]; + private lastBoundingBox: ScatterBoundingBox; + private selectionCallback: (boundingBox: ScatterBoundingBox) => void; /** - * A class that manages and renders a data selection rectangle. + * @param container The container HTML element that the selection SVG rect + * will be a child of. + * @param selectionCallback The callback that accepts a bounding box to be + * called when selection changes. Currently, we only call the callback on + * mouseUp. */ - export class ScatterPlotRectangleSelector { - private svgElement: SVGElement; - private rectElement: SVGRectElement; - - private isMouseDown: boolean; - private startCoordinates: [number, number]; - private lastBoundingBox: ScatterBoundingBox; - - private selectionCallback: (boundingBox: ScatterBoundingBox) => void; - - /** - * @param container The container HTML element that the selection SVG rect - * will be a child of. - * @param selectionCallback The callback that accepts a bounding box to be - * called when selection changes. Currently, we only call the callback on - * mouseUp. - */ - constructor( - container: HTMLElement, - selectionCallback: (boundingBox: ScatterBoundingBox) => void - ) { - this.svgElement = container.querySelector('#selector') as SVGElement; - this.rectElement = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'rect' - ); - this.rectElement.style.stroke = STROKE; - this.rectElement.style.strokeDasharray = STROKE_DASHARRAY; - this.rectElement.style.strokeWidth = '' + STROKE_WIDTH; - this.rectElement.style.fill = FILL; - this.rectElement.style.fillOpacity = '' + FILL_OPACITY; - this.svgElement.appendChild(this.rectElement); - - this.selectionCallback = selectionCallback; - this.isMouseDown = false; - } - - onMouseDown(offsetX: number, offsetY: number) { - this.isMouseDown = true; - this.svgElement.style.display = 'block'; - - this.startCoordinates = [offsetX, offsetY]; - this.lastBoundingBox = { - x: this.startCoordinates[0], - y: this.startCoordinates[1], - width: 1, - height: 1, - }; - } - - onMouseMove(offsetX: number, offsetY: number) { - if (!this.isMouseDown) { - return; - } - - this.lastBoundingBox.x = Math.min(offsetX, this.startCoordinates[0]); - this.lastBoundingBox.y = Math.max(offsetY, this.startCoordinates[1]); - this.lastBoundingBox.width = - Math.max(offsetX, this.startCoordinates[0]) - this.lastBoundingBox.x; - this.lastBoundingBox.height = - this.lastBoundingBox.y - Math.min(offsetY, this.startCoordinates[1]); - - this.rectElement.setAttribute('x', '' + this.lastBoundingBox.x); - this.rectElement.setAttribute( - 'y', - '' + (this.lastBoundingBox.y - this.lastBoundingBox.height) - ); - this.rectElement.setAttribute('width', '' + this.lastBoundingBox.width); - this.rectElement.setAttribute('height', '' + this.lastBoundingBox.height); - } - - onMouseUp() { - this.isMouseDown = false; - this.svgElement.style.display = 'none'; - this.rectElement.setAttribute('width', '0'); - this.rectElement.setAttribute('height', '0'); - this.selectionCallback(this.lastBoundingBox); + constructor( + container: HTMLElement, + selectionCallback: (boundingBox: ScatterBoundingBox) => void + ) { + this.svgElement = container.querySelector('#selector') as SVGElement; + this.rectElement = document.createElementNS( + 'http://www.w3.org/2000/svg', + 'rect' + ); + this.rectElement.style.stroke = STROKE; + this.rectElement.style.strokeDasharray = STROKE_DASHARRAY; + this.rectElement.style.strokeWidth = '' + STROKE_WIDTH; + this.rectElement.style.fill = FILL; + this.rectElement.style.fillOpacity = '' + FILL_OPACITY; + this.svgElement.appendChild(this.rectElement); + this.selectionCallback = selectionCallback; + this.isMouseDown = false; + } + onMouseDown(offsetX: number, offsetY: number) { + this.isMouseDown = true; + this.svgElement.style.display = 'block'; + this.startCoordinates = [offsetX, offsetY]; + this.lastBoundingBox = { + x: this.startCoordinates[0], + y: this.startCoordinates[1], + width: 1, + height: 1, + }; + } + onMouseMove(offsetX: number, offsetY: number) { + if (!this.isMouseDown) { + return; } + this.lastBoundingBox.x = Math.min(offsetX, this.startCoordinates[0]); + this.lastBoundingBox.y = Math.max(offsetY, this.startCoordinates[1]); + this.lastBoundingBox.width = + Math.max(offsetX, this.startCoordinates[0]) - this.lastBoundingBox.x; + this.lastBoundingBox.height = + this.lastBoundingBox.y - Math.min(offsetY, this.startCoordinates[1]); + this.rectElement.setAttribute('x', '' + this.lastBoundingBox.x); + this.rectElement.setAttribute( + 'y', + '' + (this.lastBoundingBox.y - this.lastBoundingBox.height) + ); + this.rectElement.setAttribute('width', '' + this.lastBoundingBox.width); + this.rectElement.setAttribute('height', '' + this.lastBoundingBox.height); + } + onMouseUp() { + this.isMouseDown = false; + this.svgElement.style.display = 'none'; + this.rectElement.setAttribute('width', '0'); + this.rectElement.setAttribute('height', '0'); + this.selectionCallback(this.lastBoundingBox); } -} // namespace vz_projector +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizer.ts b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizer.ts index dc1ea6e8d7..30925ce6c3 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizer.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizer.ts @@ -12,39 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { +import * as THREE from 'three'; + +import {RenderContext} from './renderContext'; + +export interface ScatterPlotVisualizer { + /** Called to initialize the visualizer with the primary scene. */ + setScene(scene: THREE.Scene); + /** + * Called when the main scatter plot tears down the visualizer. Remove all + * objects from the scene, and dispose any heavy resources. + */ + dispose(); + /** + * Called when the positions of the scatter plot points have changed. + */ + onPointPositionsChanged(newWorldSpacePointPositions: Float32Array); + /** + * Called immediately before the main scatter plot performs a picking + * (selection) render. Set up render state for any geometry to use picking IDs + * instead of visual colors. + */ + onPickingRender(renderContext: RenderContext); + /** + * Called immediately before the main scatter plot performs a color (visual) + * render. Set up render state, lights, etc here. + */ + onRender(renderContext: RenderContext); /** - * ScatterPlotVisualizer is an interface used by ScatterPlotContainer - * to manage and aggregate any number of concurrent visualization behaviors. - * To add a new visualization to the 3D scatter plot, create a new class that - * implements this interface and attach it to the ScatterPlotContainer. + * Called when the canvas size changes. */ - export interface ScatterPlotVisualizer { - /** Called to initialize the visualizer with the primary scene. */ - setScene(scene: THREE.Scene); - /** - * Called when the main scatter plot tears down the visualizer. Remove all - * objects from the scene, and dispose any heavy resources. - */ - dispose(); - /** - * Called when the positions of the scatter plot points have changed. - */ - onPointPositionsChanged(newWorldSpacePointPositions: Float32Array); - /** - * Called immediately before the main scatter plot performs a picking - * (selection) render. Set up render state for any geometry to use picking IDs - * instead of visual colors. - */ - onPickingRender(renderContext: RenderContext); - /** - * Called immediately before the main scatter plot performs a color (visual) - * render. Set up render state, lights, etc here. - */ - onRender(renderContext: RenderContext); - /** - * Called when the canvas size changes. - */ - onResize(newWidth: number, newHeight: number); - } -} // namespace vz_projector + onResize(newWidth: number, newHeight: number); +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizer3DLabels.ts b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizer3DLabels.ts index 3406ca9aa3..5f577071a5 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizer3DLabels.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizer3DLabels.ts @@ -12,33 +12,36 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - const FONT_SIZE = 80; - const ONE_OVER_FONT_SIZE = 1 / FONT_SIZE; - const LABEL_SCALE = 2.2; // at 1:1 texel/pixel ratio - const LABEL_COLOR = 'black'; - const LABEL_BACKGROUND = 'white'; - const MAX_CANVAS_DIMENSION = 8192; - const NUM_GLYPHS = 256; - const RGB_ELEMENTS_PER_ENTRY = 3; - const XYZ_ELEMENTS_PER_ENTRY = 3; - const UV_ELEMENTS_PER_ENTRY = 2; - const VERTICES_PER_GLYPH = 2 * 3; // 2 triangles, 3 verts per triangle - - /** - * Each label is made up of triangles (two per letter.) Each vertex, then, is - * the corner of one of these triangles (and thus the corner of a letter - * rectangle.) - * Each has the following attributes: - * posObj: The (x, y) position of the vertex within the label, where the - * bottom center of the word is positioned at (0, 0); - * position: The position of the label in worldspace. - * vUv: The (u, v) coordinates that index into the glyphs sheet (range 0, 1.) - * color: The color of the label (matches the corresponding point's color.) - * wordShown: Boolean. Whether or not the label is visible. - */ - - const VERTEX_SHADER = ` +import * as THREE from 'three'; + +import {RenderContext} from './renderContext'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import * as util from './util'; + +const FONT_SIZE = 80; +const ONE_OVER_FONT_SIZE = 1 / FONT_SIZE; +const LABEL_SCALE = 2.2; // at 1:1 texel/pixel ratio +const LABEL_COLOR = 'black'; +const LABEL_BACKGROUND = 'white'; +const MAX_CANVAS_DIMENSION = 8192; +const NUM_GLYPHS = 256; +const RGB_ELEMENTS_PER_ENTRY = 3; +const XYZ_ELEMENTS_PER_ENTRY = 3; +const UV_ELEMENTS_PER_ENTRY = 2; +const VERTICES_PER_GLYPH = 2 * 3; // 2 triangles, 3 verts per triangle +/** + * Each label is made up of triangles (two per letter.) Each vertex, then, is + * the corner of one of these triangles (and thus the corner of a letter + * rectangle.) + * Each has the following attributes: + * posObj: The (x, y) position of the vertex within the label, where the + * bottom center of the word is positioned at (0, 0); + * position: The position of the label in worldspace. + * vUv: The (u, v) coordinates that index into the glyphs sheet (range 0, 1.) + * color: The color of the label (matches the corresponding point's color.) + * wordShown: Boolean. Whether or not the label is visible. + */ +const VERTEX_SHADER = ` attribute vec2 posObj; attribute vec3 color; varying vec2 vUv; @@ -67,8 +70,7 @@ namespace vz_projector { vec4 mvPosition = modelViewMatrix * (vec4(position, 0) + posRotated); gl_Position = projectionMatrix * mvPosition; }`; - - const FRAGMENT_SHADER = ` +const FRAGMENT_SHADER = ` uniform sampler2D texture; uniform bool picking; varying vec2 vUv; @@ -82,319 +84,268 @@ namespace vz_projector { gl_FragColor = vec4(vColor, 1.0) * fromTexture; } }`; - - type GlyphTexture = { - texture: THREE.Texture; - lengths: Float32Array; - offsets: Float32Array; - }; - - /** - * Renders the text labels as 3d geometry in the world. - */ - export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { - private scene: THREE.Scene; - private labelStrings: string[]; - private geometry: THREE.BufferGeometry; - private worldSpacePointPositions: Float32Array; - private pickingColors: Float32Array; - private renderColors: Float32Array; - private material: THREE.ShaderMaterial; - private uniforms: Object; - private labelsMesh: THREE.Mesh; - private positions: THREE.BufferAttribute; - private totalVertexCount: number; - private labelVertexMap: number[][]; - private glyphTexture: GlyphTexture; - - private createGlyphTexture(): GlyphTexture { - let canvas = document.createElement('canvas'); - canvas.width = MAX_CANVAS_DIMENSION; - canvas.height = FONT_SIZE; - let ctx = canvas.getContext('2d'); - ctx.font = 'bold ' + FONT_SIZE * 0.75 + 'px roboto'; - ctx.textBaseline = 'top'; - ctx.fillStyle = LABEL_BACKGROUND; - ctx.rect(0, 0, canvas.width, canvas.height); - ctx.fill(); - ctx.fillStyle = LABEL_COLOR; - let spaceOffset = ctx.measureText(' ').width; - // For each letter, store length, position at the encoded index. - let glyphLengths = new Float32Array(NUM_GLYPHS); - let glyphOffset = new Float32Array(NUM_GLYPHS); - let leftCoord = 0; - for (let i = 0; i < NUM_GLYPHS; i++) { - let text = ' ' + String.fromCharCode(i); - let textLength = ctx.measureText(text).width; - glyphLengths[i] = textLength - spaceOffset; - glyphOffset[i] = leftCoord; - ctx.fillText(text, leftCoord - spaceOffset, 0); - leftCoord += textLength; - } - const tex = util.createTexture(canvas); - return {texture: tex, lengths: glyphLengths, offsets: glyphOffset}; +type GlyphTexture = { + texture: THREE.Texture; + lengths: Float32Array; + offsets: Float32Array; +}; +/** + * Renders the text labels as 3d geometry in the world. + */ +export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { + private scene: THREE.Scene; + private labelStrings: string[]; + private geometry: THREE.BufferGeometry; + private worldSpacePointPositions: Float32Array; + private pickingColors: Float32Array; + private renderColors: Float32Array; + private material: THREE.ShaderMaterial; + private uniforms: any; + private labelsMesh: THREE.Mesh; + private positions: THREE.BufferAttribute; + private totalVertexCount: number; + private labelVertexMap: number[][]; + private glyphTexture: GlyphTexture; + private createGlyphTexture(): GlyphTexture { + let canvas = document.createElement('canvas'); + canvas.width = MAX_CANVAS_DIMENSION; + canvas.height = FONT_SIZE; + let ctx = canvas.getContext('2d'); + ctx.font = 'bold ' + FONT_SIZE * 0.75 + 'px roboto'; + ctx.textBaseline = 'top'; + ctx.fillStyle = LABEL_BACKGROUND; + ctx.rect(0, 0, canvas.width, canvas.height); + ctx.fill(); + ctx.fillStyle = LABEL_COLOR; + let spaceOffset = ctx.measureText(' ').width; + // For each letter, store length, position at the encoded index. + let glyphLengths = new Float32Array(NUM_GLYPHS); + let glyphOffset = new Float32Array(NUM_GLYPHS); + let leftCoord = 0; + for (let i = 0; i < NUM_GLYPHS; i++) { + let text = ' ' + String.fromCharCode(i); + let textLength = ctx.measureText(text).width; + glyphLengths[i] = textLength - spaceOffset; + glyphOffset[i] = leftCoord; + ctx.fillText(text, leftCoord - spaceOffset, 0); + leftCoord += textLength; } - - private processLabelVerts(pointCount: number) { - let numTotalLetters = 0; - this.labelVertexMap = []; - for (let i = 0; i < pointCount; i++) { - const label = this.labelStrings[i]; - let vertsArray: number[] = []; - for (let j = 0; j < label.length; j++) { - for (let k = 0; k < VERTICES_PER_GLYPH; k++) { - vertsArray.push(numTotalLetters * VERTICES_PER_GLYPH + k); - } - numTotalLetters++; + const tex = util.createTexture(canvas); + return {texture: tex, lengths: glyphLengths, offsets: glyphOffset}; + } + private processLabelVerts(pointCount: number) { + let numTotalLetters = 0; + this.labelVertexMap = []; + for (let i = 0; i < pointCount; i++) { + const label = this.labelStrings[i]; + let vertsArray: number[] = []; + for (let j = 0; j < label.length; j++) { + for (let k = 0; k < VERTICES_PER_GLYPH; k++) { + vertsArray.push(numTotalLetters * VERTICES_PER_GLYPH + k); } - this.labelVertexMap.push(vertsArray); + numTotalLetters++; } - this.totalVertexCount = numTotalLetters * VERTICES_PER_GLYPH; + this.labelVertexMap.push(vertsArray); } - - private createColorBuffers(pointCount: number) { - this.pickingColors = new Float32Array( - this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY - ); - this.renderColors = new Float32Array( - this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY - ); - for (let i = 0; i < pointCount; i++) { - let color = new THREE.Color(i); - this.labelVertexMap[i].forEach((j) => { - this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j] = color.r; - this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = color.g; - this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = color.b; - this.renderColors[RGB_ELEMENTS_PER_ENTRY * j] = 1.0; - this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = 1.0; - this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = 1.0; - }); - } + this.totalVertexCount = numTotalLetters * VERTICES_PER_GLYPH; + } + private createColorBuffers(pointCount: number) { + this.pickingColors = new Float32Array( + this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY + ); + this.renderColors = new Float32Array( + this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY + ); + for (let i = 0; i < pointCount; i++) { + let color = new THREE.Color(i); + this.labelVertexMap[i].forEach((j) => { + this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j] = color.r; + this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = color.g; + this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = color.b; + this.renderColors[RGB_ELEMENTS_PER_ENTRY * j] = 1; + this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = 1; + this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = 1; + }); } - - private createLabels() { - if (this.labelStrings == null || this.worldSpacePointPositions == null) { - return; + } + private createLabels() { + if (this.labelStrings == null || this.worldSpacePointPositions == null) { + return; + } + const pointCount = + this.worldSpacePointPositions.length / XYZ_ELEMENTS_PER_ENTRY; + if (pointCount !== this.labelStrings.length) { + return; + } + this.glyphTexture = this.createGlyphTexture(); + this.uniforms = { + texture: {type: 't'}, + picking: {type: 'bool'}, + }; + this.material = new THREE.ShaderMaterial({ + uniforms: this.uniforms, + transparent: true, + vertexShader: VERTEX_SHADER, + fragmentShader: FRAGMENT_SHADER, + }); + this.processLabelVerts(pointCount); + this.createColorBuffers(pointCount); + let positionArray = new Float32Array( + this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY + ); + this.positions = new THREE.BufferAttribute( + positionArray, + XYZ_ELEMENTS_PER_ENTRY + ); + let posArray = new Float32Array( + this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY + ); + let uvArray = new Float32Array( + this.totalVertexCount * UV_ELEMENTS_PER_ENTRY + ); + let colorsArray = new Float32Array( + this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY + ); + let positionObject = new THREE.BufferAttribute(posArray, 2); + let uv = new THREE.BufferAttribute(uvArray, UV_ELEMENTS_PER_ENTRY); + let colors = new THREE.BufferAttribute(colorsArray, RGB_ELEMENTS_PER_ENTRY); + this.geometry = new THREE.BufferGeometry(); + this.geometry.addAttribute('posObj', positionObject); + this.geometry.addAttribute('position', this.positions); + this.geometry.addAttribute('uv', uv); + this.geometry.addAttribute('color', colors); + let lettersSoFar = 0; + for (let i = 0; i < pointCount; i++) { + const label = this.labelStrings[i]; + let leftOffset = 0; + // Determine length of word in pixels. + for (let j = 0; j < label.length; j++) { + let letterCode = label.charCodeAt(j); + leftOffset += this.glyphTexture.lengths[letterCode]; } - const pointCount = - this.worldSpacePointPositions.length / XYZ_ELEMENTS_PER_ENTRY; - if (pointCount !== this.labelStrings.length) { - return; + leftOffset /= -2; // centers text horizontally around the origin + for (let j = 0; j < label.length; j++) { + let letterCode = label.charCodeAt(j); + let letterWidth = this.glyphTexture.lengths[letterCode]; + let scale = FONT_SIZE; + let right = (leftOffset + letterWidth) / scale; + let left = leftOffset / scale; + let top = FONT_SIZE / scale; + // First triangle + positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, left, 0); + positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, right, 0); + positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 2, left, top); + // Second triangle + positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 3, left, top); + positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, right, 0); + positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 5, right, top); + // Set UVs based on letter. + let uLeft = this.glyphTexture.offsets[letterCode]; + let uRight = this.glyphTexture.offsets[letterCode] + letterWidth; + // Scale so that uvs lie between 0 and 1 on the texture. + uLeft /= MAX_CANVAS_DIMENSION; + uRight /= MAX_CANVAS_DIMENSION; + let vTop = 1; + let vBottom = 0; + uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, uLeft, vTop); + uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, uRight, vTop); + uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 2, uLeft, vBottom); + uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 3, uLeft, vBottom); + uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, uRight, vTop); + uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 5, uRight, vBottom); + lettersSoFar++; + leftOffset += letterWidth; } - this.glyphTexture = this.createGlyphTexture(); - - this.uniforms = { - texture: {type: 't'}, - picking: {type: 'bool'}, - }; - - this.material = new THREE.ShaderMaterial({ - uniforms: this.uniforms, - transparent: true, - vertexShader: VERTEX_SHADER, - fragmentShader: FRAGMENT_SHADER, + } + for (let i = 0; i < pointCount; i++) { + const p = util.vector3FromPackedArray(this.worldSpacePointPositions, i); + this.labelVertexMap[i].forEach((j) => { + this.positions.setXYZ(j, p.x, p.y, p.z); }); - - this.processLabelVerts(pointCount); - this.createColorBuffers(pointCount); - - let positionArray = new Float32Array( - this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY - ); - this.positions = new THREE.BufferAttribute( - positionArray, - XYZ_ELEMENTS_PER_ENTRY - ); - - let posArray = new Float32Array( - this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY - ); - let uvArray = new Float32Array( - this.totalVertexCount * UV_ELEMENTS_PER_ENTRY - ); - let colorsArray = new Float32Array( - this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY - ); - let positionObject = new THREE.BufferAttribute(posArray, 2); - let uv = new THREE.BufferAttribute(uvArray, UV_ELEMENTS_PER_ENTRY); - let colors = new THREE.BufferAttribute( - colorsArray, - RGB_ELEMENTS_PER_ENTRY + } + this.labelsMesh = new THREE.Mesh(this.geometry, this.material); + this.labelsMesh.frustumCulled = false; + this.scene.add(this.labelsMesh); + } + private colorLabels(pointColors: Float32Array) { + if ( + this.labelStrings == null || + this.geometry == null || + pointColors == null + ) { + return; + } + const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; + (colors as any).setArray(this.renderColors); + const n = pointColors.length / XYZ_ELEMENTS_PER_ENTRY; + let src = 0; + for (let i = 0; i < n; ++i) { + const c = new THREE.Color( + pointColors[src], + pointColors[src + 1], + pointColors[src + 2] ); - - this.geometry = new THREE.BufferGeometry(); - this.geometry.addAttribute('posObj', positionObject); - this.geometry.addAttribute('position', this.positions); - this.geometry.addAttribute('uv', uv); - this.geometry.addAttribute('color', colors); - - let lettersSoFar = 0; - for (let i = 0; i < pointCount; i++) { - const label = this.labelStrings[i]; - let leftOffset = 0; - // Determine length of word in pixels. - for (let j = 0; j < label.length; j++) { - let letterCode = label.charCodeAt(j); - leftOffset += this.glyphTexture.lengths[letterCode]; - } - leftOffset /= -2; // centers text horizontally around the origin - for (let j = 0; j < label.length; j++) { - let letterCode = label.charCodeAt(j); - let letterWidth = this.glyphTexture.lengths[letterCode]; - let scale = FONT_SIZE; - let right = (leftOffset + letterWidth) / scale; - let left = leftOffset / scale; - let top = FONT_SIZE / scale; - - // First triangle - positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, left, 0); - positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, right, 0); - positionObject.setXY( - lettersSoFar * VERTICES_PER_GLYPH + 2, - left, - top - ); - - // Second triangle - positionObject.setXY( - lettersSoFar * VERTICES_PER_GLYPH + 3, - left, - top - ); - positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, right, 0); - positionObject.setXY( - lettersSoFar * VERTICES_PER_GLYPH + 5, - right, - top - ); - - // Set UVs based on letter. - let uLeft = this.glyphTexture.offsets[letterCode]; - let uRight = this.glyphTexture.offsets[letterCode] + letterWidth; - // Scale so that uvs lie between 0 and 1 on the texture. - uLeft /= MAX_CANVAS_DIMENSION; - uRight /= MAX_CANVAS_DIMENSION; - let vTop = 1; - let vBottom = 0; - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, uLeft, vTop); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, uRight, vTop); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 2, uLeft, vBottom); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 3, uLeft, vBottom); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, uRight, vTop); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 5, uRight, vBottom); - - lettersSoFar++; - leftOffset += letterWidth; - } - } - - for (let i = 0; i < pointCount; i++) { - const p = util.vector3FromPackedArray(this.worldSpacePointPositions, i); - this.labelVertexMap[i].forEach((j) => { - this.positions.setXYZ(j, p.x, p.y, p.z); - }); + const m = this.labelVertexMap[i].length; + for (let j = 0; j < m; ++j) { + colors.setXYZ(this.labelVertexMap[i][j], c.r, c.g, c.b); } - - this.labelsMesh = new THREE.Mesh(this.geometry, this.material); - this.labelsMesh.frustumCulled = false; - this.scene.add(this.labelsMesh); + src += RGB_ELEMENTS_PER_ENTRY; } - - private colorLabels(pointColors: Float32Array) { - if ( - this.labelStrings == null || - this.geometry == null || - pointColors == null - ) { - return; - } - - const colors = this.geometry.getAttribute( - 'color' - ) as THREE.BufferAttribute; - (colors as any).setArray(this.renderColors); - - const n = pointColors.length / XYZ_ELEMENTS_PER_ENTRY; - let src = 0; - for (let i = 0; i < n; ++i) { - const c = new THREE.Color( - pointColors[src], - pointColors[src + 1], - pointColors[src + 2] - ); - const m = this.labelVertexMap[i].length; - for (let j = 0; j < m; ++j) { - colors.setXYZ(this.labelVertexMap[i][j], c.r, c.g, c.b); - } - src += RGB_ELEMENTS_PER_ENTRY; + colors.needsUpdate = true; + } + setScene(scene: THREE.Scene) { + this.scene = scene; + } + dispose() { + if (this.labelsMesh) { + if (this.scene) { + this.scene.remove(this.labelsMesh); } - colors.needsUpdate = true; + this.labelsMesh = null; } - - setScene(scene: THREE.Scene) { - this.scene = scene; + if (this.geometry) { + this.geometry.dispose(); + this.geometry = null; } - - dispose() { - if (this.labelsMesh) { - if (this.scene) { - this.scene.remove(this.labelsMesh); - } - this.labelsMesh = null; - } - if (this.geometry) { - this.geometry.dispose(); - this.geometry = null; - } - if (this.glyphTexture != null && this.glyphTexture.texture != null) { - this.glyphTexture.texture.dispose(); - this.glyphTexture.texture = null; - } + if (this.glyphTexture != null && this.glyphTexture.texture != null) { + this.glyphTexture.texture.dispose(); + this.glyphTexture.texture = null; } - - onPickingRender(rc: RenderContext) { - if (this.geometry == null) { - this.createLabels(); - } - if (this.geometry == null) { - return; - } - this.material.uniforms.texture.value = this.glyphTexture.texture; - this.material.uniforms.picking.value = true; - const colors = this.geometry.getAttribute( - 'color' - ) as THREE.BufferAttribute; - (colors as any).setArray(this.pickingColors); - colors.needsUpdate = true; + } + onPickingRender(rc: RenderContext) { + if (this.geometry == null) { + this.createLabels(); } - - onRender(rc: RenderContext) { - if (this.geometry == null) { - this.createLabels(); - } - if (this.geometry == null) { - return; - } - this.colorLabels(rc.pointColors); - this.material.uniforms.texture.value = this.glyphTexture.texture; - this.material.uniforms.picking.value = false; - const colors = this.geometry.getAttribute( - 'color' - ) as THREE.BufferAttribute; - (colors as any).setArray(this.renderColors); - colors.needsUpdate = true; + if (this.geometry == null) { + return; } - - onPointPositionsChanged(newPositions: Float32Array) { - this.worldSpacePointPositions = newPositions; - this.dispose(); + this.material.uniforms.texture.value = this.glyphTexture.texture; + this.material.uniforms.picking.value = true; + const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; + (colors as any).setArray(this.pickingColors); + colors.needsUpdate = true; + } + onRender(rc: RenderContext) { + if (this.geometry == null) { + this.createLabels(); } - - setLabelStrings(labelStrings: string[]) { - this.labelStrings = labelStrings; - this.dispose(); + if (this.geometry == null) { + return; } - - onResize(newWidth: number, newHeight: number) {} + this.colorLabels(rc.pointColors); + this.material.uniforms.texture.value = this.glyphTexture.texture; + this.material.uniforms.picking.value = false; + const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; + (colors as any).setArray(this.renderColors); + colors.needsUpdate = true; + } + onPointPositionsChanged(newPositions: Float32Array) { + this.worldSpacePointPositions = newPositions; + this.dispose(); + } + setLabelStrings(labelStrings: string[]) { + this.labelStrings = labelStrings; + this.dispose(); } -} // namespace vz_projector + onResize(newWidth: number, newHeight: number) {} +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerCanvasLabels.ts b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerCanvasLabels.ts index a537d44488..96778c0886 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerCanvasLabels.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerCanvasLabels.ts @@ -12,186 +12,165 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - const MAX_LABELS_ON_SCREEN = 10000; - const LABEL_STROKE_WIDTH = 3; - const LABEL_FILL_WIDTH = 6; - - /** - * Creates and maintains a 2d canvas on top of the GL canvas. All labels, when - * active, are rendered to the 2d canvas as part of the visible render pass. - */ - export class ScatterPlotVisualizerCanvasLabels - implements ScatterPlotVisualizer { - private worldSpacePointPositions: Float32Array; - private gc: CanvasRenderingContext2D; - private canvas: HTMLCanvasElement; - private labelsActive: boolean = true; - - constructor(container: HTMLElement) { - this.canvas = document.createElement('canvas'); - container.appendChild(this.canvas); - - this.gc = this.canvas.getContext('2d'); - this.canvas.style.position = 'absolute'; - this.canvas.style.left = '0'; - this.canvas.style.top = '0'; - this.canvas.style.pointerEvents = 'none'; +import * as d3 from 'd3'; +import * as THREE from 'three'; + +import {CameraType, RenderContext} from './renderContext'; +import {BoundingBox, CollisionGrid} from './label'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import * as util from './util'; + +const MAX_LABELS_ON_SCREEN = 10000; +const LABEL_STROKE_WIDTH = 3; +const LABEL_FILL_WIDTH = 6; +/** + * Creates and maintains a 2d canvas on top of the GL canvas. All labels, when + * active, are rendered to the 2d canvas as part of the visible render pass. + */ +export class ScatterPlotVisualizerCanvasLabels + implements ScatterPlotVisualizer { + private worldSpacePointPositions: Float32Array; + private gc: CanvasRenderingContext2D; + private canvas: HTMLCanvasElement; + private labelsActive: boolean = true; + constructor(container: HTMLElement) { + this.canvas = document.createElement('canvas'); + container.appendChild(this.canvas); + this.gc = this.canvas.getContext('2d'); + this.canvas.style.position = 'absolute'; + this.canvas.style.left = '0'; + this.canvas.style.top = '0'; + this.canvas.style.pointerEvents = 'none'; + } + private removeAllLabels() { + const pixelWidth = this.canvas.width * window.devicePixelRatio; + const pixelHeight = this.canvas.height * window.devicePixelRatio; + this.gc.clearRect(0, 0, pixelWidth, pixelHeight); + } + /** Render all of the non-overlapping visible labels to the canvas. */ + private makeLabels(rc: RenderContext) { + if (rc.labels == null || rc.labels.pointIndices.length === 0) { + return; } - - private removeAllLabels() { - const pixelWidth = this.canvas.width * window.devicePixelRatio; - const pixelHeight = this.canvas.height * window.devicePixelRatio; - this.gc.clearRect(0, 0, pixelWidth, pixelHeight); + if (this.worldSpacePointPositions == null) { + return; } - - /** Render all of the non-overlapping visible labels to the canvas. */ - private makeLabels(rc: RenderContext) { - if (rc.labels == null || rc.labels.pointIndices.length === 0) { - return; - } - if (this.worldSpacePointPositions == null) { - return; - } - - const lrc = rc.labels; - const sceneIs3D: boolean = rc.cameraType === CameraType.Perspective; - const labelHeight = parseInt(this.gc.font, 10); - const dpr = window.devicePixelRatio; - - let grid: CollisionGrid; + const lrc = rc.labels; + const sceneIs3D: boolean = rc.cameraType === CameraType.Perspective; + const labelHeight = parseInt(this.gc.font, 10); + const dpr = window.devicePixelRatio; + let grid: CollisionGrid; + { + const pixw = this.canvas.width * dpr; + const pixh = this.canvas.height * dpr; + const bb: BoundingBox = {loX: 0, hiX: pixw, loY: 0, hiY: pixh}; + grid = new CollisionGrid(bb, pixw / 25, pixh / 50); + } + let opacityMap = d3 + .scalePow() + .exponent(Math.E) + .domain([rc.farthestCameraSpacePointZ, rc.nearestCameraSpacePointZ]) + .range([0.1, 1]); + const camPos = rc.camera.position; + const camToTarget = camPos.clone().sub(rc.cameraTarget); + let camToPoint = new THREE.Vector3(); + this.gc.textBaseline = 'middle'; + this.gc.miterLimit = 2; + // Have extra space between neighboring labels. Don't pack too tightly. + const labelMargin = 2; + // Shift the label to the right of the point circle. + const xShift = 4; + const n = Math.min(MAX_LABELS_ON_SCREEN, lrc.pointIndices.length); + for (let i = 0; i < n; ++i) { + let point: THREE.Vector3; { - const pixw = this.canvas.width * dpr; - const pixh = this.canvas.height * dpr; - const bb: BoundingBox = {loX: 0, hiX: pixw, loY: 0, hiY: pixh}; - grid = new CollisionGrid(bb, pixw / 25, pixh / 50); + const pi = lrc.pointIndices[i]; + point = util.vector3FromPackedArray(this.worldSpacePointPositions, pi); } - - let opacityMap = d3 - .scalePow() - .exponent(Math.E) - .domain([rc.farthestCameraSpacePointZ, rc.nearestCameraSpacePointZ]) - .range([0.1, 1]); - - const camPos = rc.camera.position; - const camToTarget = camPos.clone().sub(rc.cameraTarget); - let camToPoint = new THREE.Vector3(); - - this.gc.textBaseline = 'middle'; - this.gc.miterLimit = 2; - - // Have extra space between neighboring labels. Don't pack too tightly. - const labelMargin = 2; - // Shift the label to the right of the point circle. - const xShift = 4; - - const n = Math.min(MAX_LABELS_ON_SCREEN, lrc.pointIndices.length); - for (let i = 0; i < n; ++i) { - let point: THREE.Vector3; - { - const pi = lrc.pointIndices[i]; - point = util.vector3FromPackedArray( - this.worldSpacePointPositions, - pi - ); - } - - // discard points that are behind the camera - camToPoint.copy(camPos).sub(point); - if (camToTarget.dot(camToPoint) < 0) { - continue; - } - - let [x, y] = util.vector3DToScreenCoords( - rc.camera, - rc.screenWidth, - rc.screenHeight, - point - ); - x += xShift; - - // Computing the width of the font is expensive, - // so we assume width of 1 at first. Then, if the label doesn't - // conflict with other labels, we measure the actual width. - const textBoundingBox: BoundingBox = { - loX: x - labelMargin, - hiX: x + 1 + labelMargin, - loY: y - labelHeight / 2 - labelMargin, - hiY: y + labelHeight / 2 + labelMargin, - }; - - if (grid.insert(textBoundingBox, true)) { - const text = lrc.labelStrings[i]; - const fontSize = lrc.defaultFontSize * lrc.scaleFactors[i] * dpr; - this.gc.font = fontSize + 'px roboto'; - - // Now, check with properly computed width. - textBoundingBox.hiX += this.gc.measureText(text).width - 1; - if (grid.insert(textBoundingBox)) { - let opacity = 1; - if (sceneIs3D && lrc.useSceneOpacityFlags[i] === 1) { - opacity = opacityMap(camToPoint.length()); - } - this.gc.fillStyle = this.styleStringFromPackedRgba( - lrc.fillColors, - i, - opacity - ); - this.gc.strokeStyle = this.styleStringFromPackedRgba( - lrc.strokeColors, - i, - opacity - ); - this.gc.lineWidth = LABEL_STROKE_WIDTH; - this.gc.strokeText(text, x, y); - this.gc.lineWidth = LABEL_FILL_WIDTH; - this.gc.fillText(text, x, y); + // discard points that are behind the camera + camToPoint.copy(camPos).sub(point); + if (camToTarget.dot(camToPoint) < 0) { + continue; + } + let [x, y] = util.vector3DToScreenCoords( + rc.camera, + rc.screenWidth, + rc.screenHeight, + point + ); + x += xShift; + // Computing the width of the font is expensive, + // so we assume width of 1 at first. Then, if the label doesn't + // conflict with other labels, we measure the actual width. + const textBoundingBox: BoundingBox = { + loX: x - labelMargin, + hiX: x + 1 + labelMargin, + loY: y - labelHeight / 2 - labelMargin, + hiY: y + labelHeight / 2 + labelMargin, + }; + if (grid.insert(textBoundingBox, true)) { + const text = lrc.labelStrings[i]; + const fontSize = lrc.defaultFontSize * lrc.scaleFactors[i] * dpr; + this.gc.font = fontSize + 'px roboto'; + // Now, check with properly computed width. + textBoundingBox.hiX += this.gc.measureText(text).width - 1; + if (grid.insert(textBoundingBox)) { + let opacity = 1; + if (sceneIs3D && lrc.useSceneOpacityFlags[i] === 1) { + opacity = opacityMap(camToPoint.length()); } + this.gc.fillStyle = this.styleStringFromPackedRgba( + lrc.fillColors, + i, + opacity + ); + this.gc.strokeStyle = this.styleStringFromPackedRgba( + lrc.strokeColors, + i, + opacity + ); + this.gc.lineWidth = LABEL_STROKE_WIDTH; + this.gc.strokeText(text, x, y); + this.gc.lineWidth = LABEL_FILL_WIDTH; + this.gc.fillText(text, x, y); } } } - - private styleStringFromPackedRgba( - packedRgbaArray: Uint8Array, - colorIndex: number, - opacity: number - ): string { - const offset = colorIndex * 3; - const r = packedRgbaArray[offset]; - const g = packedRgbaArray[offset + 1]; - const b = packedRgbaArray[offset + 2]; - return 'rgba(' + r + ',' + g + ',' + b + ',' + opacity + ')'; - } - - onResize(newWidth: number, newHeight: number) { - let dpr = window.devicePixelRatio; - this.canvas.width = newWidth * dpr; - this.canvas.height = newHeight * dpr; - this.canvas.style.width = newWidth + 'px'; - this.canvas.style.height = newHeight + 'px'; - } - - dispose() { - this.removeAllLabels(); - this.canvas = null; - this.gc = null; - } - - onPointPositionsChanged(newPositions: Float32Array) { - this.worldSpacePointPositions = newPositions; - this.removeAllLabels(); - } - - onRender(rc: RenderContext) { - if (!this.labelsActive) { - return; - } - - this.removeAllLabels(); - this.makeLabels(rc); + } + private styleStringFromPackedRgba( + packedRgbaArray: Uint8Array, + colorIndex: number, + opacity: number + ): string { + const offset = colorIndex * 3; + const r = packedRgbaArray[offset]; + const g = packedRgbaArray[offset + 1]; + const b = packedRgbaArray[offset + 2]; + return 'rgba(' + r + ',' + g + ',' + b + ',' + opacity + ')'; + } + onResize(newWidth: number, newHeight: number) { + let dpr = window.devicePixelRatio; + this.canvas.width = newWidth * dpr; + this.canvas.height = newHeight * dpr; + this.canvas.style.width = newWidth + 'px'; + this.canvas.style.height = newHeight + 'px'; + } + dispose() { + this.removeAllLabels(); + this.canvas = null; + this.gc = null; + } + onPointPositionsChanged(newPositions: Float32Array) { + this.worldSpacePointPositions = newPositions; + this.removeAllLabels(); + } + onRender(rc: RenderContext) { + if (!this.labelsActive) { + return; } - - setScene(scene: THREE.Scene) {} - onPickingRender(renderContext: RenderContext) {} + this.removeAllLabels(); + this.makeLabels(rc); } -} // namespace vz_projector + setScene(scene: THREE.Scene) {} + onPickingRender(renderContext: RenderContext) {} +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerPolylines.ts b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerPolylines.ts index feb3c7b904..106154a205 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerPolylines.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerPolylines.ts @@ -12,142 +12,131 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - const RGB_NUM_ELEMENTS = 3; - const XYZ_NUM_ELEMENTS = 3; - - /** - * Renders polylines that connect multiple points in the dataset. - */ - export class ScatterPlotVisualizerPolylines implements ScatterPlotVisualizer { - private dataSet: DataSet; - private scene: THREE.Scene; - private polylines: THREE.Line[]; - private polylinePositionBuffer: { - [polylineIndex: number]: THREE.BufferAttribute; - } = {}; - private polylineColorBuffer: { - [polylineIndex: number]: THREE.BufferAttribute; - } = {}; - - private updateSequenceIndicesInDataSet(ds: DataSet) { - for (let i = 0; i < ds.sequences.length; i++) { - const sequence = ds.sequences[i]; - for (let j = 0; j < sequence.pointIndices.length - 1; j++) { - ds.points[sequence.pointIndices[j]].sequenceIndex = i; - ds.points[sequence.pointIndices[j + 1]].sequenceIndex = i; - } +import * as THREE from 'three'; + +import {DataSet} from './data'; +import {RenderContext} from './renderContext'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import * as util from './util'; + +const RGB_NUM_ELEMENTS = 3; +const XYZ_NUM_ELEMENTS = 3; +/** + * Renders polylines that connect multiple points in the dataset. + */ +export class ScatterPlotVisualizerPolylines implements ScatterPlotVisualizer { + private dataSet: DataSet; + private scene: THREE.Scene; + private polylines: THREE.Line[]; + private polylinePositionBuffer: { + [polylineIndex: number]: THREE.BufferAttribute; + } = {}; + private polylineColorBuffer: { + [polylineIndex: number]: THREE.BufferAttribute; + } = {}; + private updateSequenceIndicesInDataSet(ds: DataSet) { + for (let i = 0; i < ds.sequences.length; i++) { + const sequence = ds.sequences[i]; + for (let j = 0; j < sequence.pointIndices.length - 1; j++) { + ds.points[sequence.pointIndices[j]].sequenceIndex = i; + ds.points[sequence.pointIndices[j + 1]].sequenceIndex = i; } } - - private createPolylines(scene: THREE.Scene) { - if (!this.dataSet || !this.dataSet.sequences) { - return; - } - - this.updateSequenceIndicesInDataSet(this.dataSet); - this.polylines = []; - - for (let i = 0; i < this.dataSet.sequences.length; i++) { - const geometry = new THREE.BufferGeometry(); - geometry.addAttribute('position', this.polylinePositionBuffer[i]); - geometry.addAttribute('color', this.polylineColorBuffer[i]); - - const material = new THREE.LineBasicMaterial({ - linewidth: 1, // unused default, overwritten by width array. - opacity: 1.0, // unused default, overwritten by opacity array. - transparent: true, - vertexColors: THREE.VertexColors, - }); - - const polyline = new THREE.LineSegments(geometry, material); - polyline.frustumCulled = false; - this.polylines.push(polyline); - scene.add(polyline); - } + } + private createPolylines(scene: THREE.Scene) { + if (!this.dataSet || !this.dataSet.sequences) { + return; } - - dispose() { - if (this.polylines == null) { - return; - } - for (let i = 0; i < this.polylines.length; i++) { - this.scene.remove(this.polylines[i]); - this.polylines[i].geometry.dispose(); - } - this.polylines = null; - this.polylinePositionBuffer = {}; - this.polylineColorBuffer = {}; + this.updateSequenceIndicesInDataSet(this.dataSet); + this.polylines = []; + for (let i = 0; i < this.dataSet.sequences.length; i++) { + const geometry = new THREE.BufferGeometry(); + geometry.addAttribute('position', this.polylinePositionBuffer[i]); + geometry.addAttribute('color', this.polylineColorBuffer[i]); + const material = new THREE.LineBasicMaterial({ + linewidth: 1, + opacity: 1, + transparent: true, + vertexColors: THREE.VertexColors as any, + }); + const polyline = new THREE.LineSegments(geometry, material); + polyline.frustumCulled = false; + this.polylines.push(polyline); + scene.add(polyline); } - - setScene(scene: THREE.Scene) { - this.scene = scene; + } + dispose() { + if (this.polylines == null) { + return; } - - setDataSet(dataSet: DataSet) { - this.dataSet = dataSet; + for (let i = 0; i < this.polylines.length; i++) { + this.scene.remove(this.polylines[i]); + this.polylines[i].geometry.dispose(); } - - onPointPositionsChanged(newPositions: Float32Array) { - if (newPositions == null || this.polylines != null) { - this.dispose(); - } - if (newPositions == null || this.dataSet == null) { - return; - } - // Set up the position buffer arrays for each polyline. - for (let i = 0; i < this.dataSet.sequences.length; i++) { - let sequence = this.dataSet.sequences[i]; - const vertexCount = 2 * (sequence.pointIndices.length - 1); - - let polylines = new Float32Array(vertexCount * XYZ_NUM_ELEMENTS); - this.polylinePositionBuffer[i] = new THREE.BufferAttribute( - polylines, - XYZ_NUM_ELEMENTS - ); - - let colors = new Float32Array(vertexCount * RGB_NUM_ELEMENTS); - this.polylineColorBuffer[i] = new THREE.BufferAttribute( - colors, - RGB_NUM_ELEMENTS - ); - } - for (let i = 0; i < this.dataSet.sequences.length; i++) { - const sequence = this.dataSet.sequences[i]; - let src = 0; - for (let j = 0; j < sequence.pointIndices.length - 1; j++) { - const p1Index = sequence.pointIndices[j]; - const p2Index = sequence.pointIndices[j + 1]; - const p1 = util.vector3FromPackedArray(newPositions, p1Index); - const p2 = util.vector3FromPackedArray(newPositions, p2Index); - this.polylinePositionBuffer[i].setXYZ(src, p1.x, p1.y, p1.z); - this.polylinePositionBuffer[i].setXYZ(src + 1, p2.x, p2.y, p2.z); - src += 2; - } - this.polylinePositionBuffer[i].needsUpdate = true; - } - - if (this.polylines == null) { - this.createPolylines(this.scene); - } + this.polylines = null; + this.polylinePositionBuffer = {}; + this.polylineColorBuffer = {}; + } + setScene(scene: THREE.Scene) { + this.scene = scene; + } + setDataSet(dataSet: DataSet) { + this.dataSet = dataSet; + } + onPointPositionsChanged(newPositions: Float32Array) { + if (newPositions == null || this.polylines != null) { + this.dispose(); } - - onRender(renderContext: RenderContext) { - if (this.polylines == null) { - return; - } - for (let i = 0; i < this.polylines.length; i++) { - this.polylines[i].material.opacity = renderContext.polylineOpacities[i]; - (this.polylines[i].material as THREE.LineBasicMaterial).linewidth = - renderContext.polylineWidths[i]; - (this.polylineColorBuffer[i] as any).setArray( - renderContext.polylineColors[i] - ); - this.polylineColorBuffer[i].needsUpdate = true; + if (newPositions == null || this.dataSet == null) { + return; + } + // Set up the position buffer arrays for each polyline. + for (let i = 0; i < this.dataSet.sequences.length; i++) { + let sequence = this.dataSet.sequences[i]; + const vertexCount = 2 * (sequence.pointIndices.length - 1); + let polylines = new Float32Array(vertexCount * XYZ_NUM_ELEMENTS); + this.polylinePositionBuffer[i] = new THREE.BufferAttribute( + polylines, + XYZ_NUM_ELEMENTS + ); + let colors = new Float32Array(vertexCount * RGB_NUM_ELEMENTS); + this.polylineColorBuffer[i] = new THREE.BufferAttribute( + colors, + RGB_NUM_ELEMENTS + ); + } + for (let i = 0; i < this.dataSet.sequences.length; i++) { + const sequence = this.dataSet.sequences[i]; + let src = 0; + for (let j = 0; j < sequence.pointIndices.length - 1; j++) { + const p1Index = sequence.pointIndices[j]; + const p2Index = sequence.pointIndices[j + 1]; + const p1 = util.vector3FromPackedArray(newPositions, p1Index); + const p2 = util.vector3FromPackedArray(newPositions, p2Index); + this.polylinePositionBuffer[i].setXYZ(src, p1.x, p1.y, p1.z); + this.polylinePositionBuffer[i].setXYZ(src + 1, p2.x, p2.y, p2.z); + src += 2; } + this.polylinePositionBuffer[i].needsUpdate = true; + } + if (this.polylines == null) { + this.createPolylines(this.scene); + } + } + onRender(renderContext: RenderContext) { + if (this.polylines == null) { + return; + } + for (let i = 0; i < this.polylines.length; i++) { + const mat = this.polylines[i].material as THREE.LineBasicMaterial; + mat.opacity = renderContext.polylineOpacities[i]; + mat.linewidth = renderContext.polylineWidths[i]; + (this.polylineColorBuffer[i] as any).setArray( + renderContext.polylineColors[i] + ); + this.polylineColorBuffer[i].needsUpdate = true; } - - onPickingRender(renderContext: RenderContext) {} - onResize(newWidth: number, newHeight: number) {} } -} // namespace vz_projector + onPickingRender(renderContext: RenderContext) {} + onResize(newWidth: number, newHeight: number) {} +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerSprites.ts b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerSprites.ts index dd75f57a4b..cdf372eb9f 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerSprites.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/scatterPlotVisualizerSprites.ts @@ -12,17 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - const NUM_POINTS_FOG_THRESHOLD = 5000; - const MIN_POINT_SIZE = 5.0; - const IMAGE_SIZE = 30; - - // Constants relating to the indices of buffer arrays. - const RGB_NUM_ELEMENTS = 3; - const INDEX_NUM_ELEMENTS = 1; - const XYZ_NUM_ELEMENTS = 3; - - const VERTEX_SHADER = ` +import * as THREE from 'three'; + +import {CameraType, RenderContext} from './renderContext'; +import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; +import * as util from './util'; + +const NUM_POINTS_FOG_THRESHOLD = 5000; +const MIN_POINT_SIZE = 5; +const IMAGE_SIZE = 30; +// Constants relating to the indices of buffer arrays. +const RGB_NUM_ELEMENTS = 3; +const INDEX_NUM_ELEMENTS = 1; +const XYZ_NUM_ELEMENTS = 3; +const VERTEX_SHADER = ` // Index of the specific vertex (passed in as bufferAttribute), and the // variable that will be used to pass it to the fragment shader. attribute float spriteIndex; @@ -75,8 +78,7 @@ namespace vz_projector { gl_PointSize = max(outputPointSize * scaleFactor, ${MIN_POINT_SIZE.toFixed(1)}); }`; - - const FRAGMENT_SHADER_POINT_TEST_CHUNK = ` +const FRAGMENT_SHADER_POINT_TEST_CHUNK = ` bool point_in_unit_circle(vec2 spriteCoord) { vec2 centerToP = spriteCoord - vec2(0.5, 0.5); return dot(centerToP, centerToP) < (0.5 * 0.5); @@ -96,8 +98,7 @@ namespace vz_projector { return true; } `; - - const FRAGMENT_SHADER = ` +const FRAGMENT_SHADER = ` varying vec2 xyIndex; varying vec3 vColor; @@ -125,8 +126,7 @@ namespace vz_projector { } ${THREE.ShaderChunk['fog_fragment']} }`; - - const FRAGMENT_SHADER_PICKING = ` +const FRAGMENT_SHADER_PICKING = ` varying vec2 xyIndex; varying vec3 vColor; uniform bool isImage; @@ -145,338 +145,298 @@ namespace vz_projector { gl_FragColor = vec4(vColor, 1); } }`; - - /** - * Uses GL point sprites to render the dataset. - */ - export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { - private scene: THREE.Scene; - private fog: THREE.Fog; - private texture: THREE.Texture = null; - private standinTextureForPoints: THREE.Texture; - private spritesPerRow: number; - private spritesPerColumn: number; - private spriteDimensions: [number, number]; - private spriteIndexBufferAttribute: THREE.BufferAttribute; - private renderMaterial: THREE.ShaderMaterial; - private pickingMaterial: THREE.ShaderMaterial; - - private points: THREE.Points; - private worldSpacePointPositions: Float32Array; - private pickingColors: Float32Array; - private renderColors: Float32Array; - - constructor() { - this.standinTextureForPoints = util.createTexture( - document.createElement('canvas') +/** + * Uses GL point sprites to render the dataset. + */ +export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { + private scene: THREE.Scene; + private fog: THREE.Fog; + private texture: THREE.Texture = null; + private standinTextureForPoints: THREE.Texture; + private spritesPerRow: number; + private spritesPerColumn: number; + private spriteDimensions: [number, number]; + private spriteIndexBufferAttribute: THREE.BufferAttribute; + private renderMaterial: THREE.ShaderMaterial; + private pickingMaterial: THREE.ShaderMaterial; + private points: THREE.Points; + private worldSpacePointPositions: Float32Array; + private pickingColors: Float32Array; + private renderColors: Float32Array; + constructor() { + this.standinTextureForPoints = util.createTexture( + document.createElement('canvas') + ); + this.renderMaterial = this.createRenderMaterial(false); + this.pickingMaterial = this.createPickingMaterial(false); + } + private createTextureFromSpriteAtlas( + spriteAtlas: HTMLImageElement, + spriteDimensions: [number, number], + spriteIndices: Float32Array + ) { + this.texture = util.createTexture(spriteAtlas); + this.spritesPerRow = spriteAtlas.width / spriteDimensions[0]; + this.spritesPerColumn = spriteAtlas.height / spriteDimensions[1]; + this.spriteDimensions = spriteDimensions; + this.spriteIndexBufferAttribute = new THREE.BufferAttribute( + spriteIndices, + INDEX_NUM_ELEMENTS + ); + if (this.points != null) { + (this.points.geometry as THREE.BufferGeometry).addAttribute( + 'spriteIndex', + this.spriteIndexBufferAttribute ); - this.renderMaterial = this.createRenderMaterial(false); - this.pickingMaterial = this.createPickingMaterial(false); } - - private createTextureFromSpriteAtlas( - spriteAtlas: HTMLImageElement, - spriteDimensions: [number, number], - spriteIndices: Float32Array - ) { - this.texture = util.createTexture(spriteAtlas); - this.spritesPerRow = spriteAtlas.width / spriteDimensions[0]; - this.spritesPerColumn = spriteAtlas.height / spriteDimensions[1]; - this.spriteDimensions = spriteDimensions; - this.spriteIndexBufferAttribute = new THREE.BufferAttribute( - spriteIndices, - INDEX_NUM_ELEMENTS + } + private createUniforms(): any { + return { + texture: {type: 't'}, + spritesPerRow: {type: 'f'}, + spritesPerColumn: {type: 'f'}, + fogColor: {type: 'c'}, + fogNear: {type: 'f'}, + fogFar: {type: 'f'}, + isImage: {type: 'bool'}, + sizeAttenuation: {type: 'bool'}, + pointSize: {type: 'f'}, + }; + } + private createRenderMaterial(haveImage: boolean): THREE.ShaderMaterial { + const uniforms = this.createUniforms(); + return new THREE.ShaderMaterial({ + uniforms: uniforms, + vertexShader: VERTEX_SHADER, + fragmentShader: FRAGMENT_SHADER, + transparent: !haveImage, + depthTest: haveImage, + depthWrite: haveImage, + fog: true, + blending: THREE.MultiplyBlending, + }); + } + private createPickingMaterial(haveImage: boolean): THREE.ShaderMaterial { + const uniforms = this.createUniforms(); + return new THREE.ShaderMaterial({ + uniforms: uniforms, + vertexShader: VERTEX_SHADER, + fragmentShader: FRAGMENT_SHADER_PICKING, + transparent: true, + depthTest: true, + depthWrite: true, + fog: false, + blending: THREE.NormalBlending, + }); + } + /** + * Create points, set their locations and actually instantiate the + * geometry. + */ + private createPointSprites(scene: THREE.Scene, positions: Float32Array) { + const pointCount = + positions != null ? positions.length / XYZ_NUM_ELEMENTS : 0; + const geometry = this.createGeometry(pointCount); + this.fog = new THREE.Fog(16777215); // unused value, gets overwritten. + this.points = new THREE.Points(geometry, this.renderMaterial); + this.points.frustumCulled = false; + if (this.spriteIndexBufferAttribute != null) { + (this.points.geometry as THREE.BufferGeometry).addAttribute( + 'spriteIndex', + this.spriteIndexBufferAttribute ); - - if (this.points != null) { - (this.points.geometry as THREE.BufferGeometry).addAttribute( - 'spriteIndex', - this.spriteIndexBufferAttribute - ); - } - } - - private createUniforms(): any { - return { - texture: {type: 't'}, - spritesPerRow: {type: 'f'}, - spritesPerColumn: {type: 'f'}, - fogColor: {type: 'c'}, - fogNear: {type: 'f'}, - fogFar: {type: 'f'}, - isImage: {type: 'bool'}, - sizeAttenuation: {type: 'bool'}, - pointSize: {type: 'f'}, - }; - } - - private createRenderMaterial(haveImage: boolean): THREE.ShaderMaterial { - const uniforms = this.createUniforms(); - return new THREE.ShaderMaterial({ - uniforms: uniforms, - vertexShader: VERTEX_SHADER, - fragmentShader: FRAGMENT_SHADER, - transparent: !haveImage, - depthTest: haveImage, - depthWrite: haveImage, - fog: true, - blending: THREE.MultiplyBlending, - }); - } - - private createPickingMaterial(haveImage: boolean): THREE.ShaderMaterial { - const uniforms = this.createUniforms(); - return new THREE.ShaderMaterial({ - uniforms: uniforms, - vertexShader: VERTEX_SHADER, - fragmentShader: FRAGMENT_SHADER_PICKING, - transparent: true, - depthTest: true, - depthWrite: true, - fog: false, - blending: THREE.NormalBlending, - }); } - - /** - * Create points, set their locations and actually instantiate the - * geometry. - */ - private createPointSprites(scene: THREE.Scene, positions: Float32Array) { - const pointCount = - positions != null ? positions.length / XYZ_NUM_ELEMENTS : 0; - const geometry = this.createGeometry(pointCount); - - this.fog = new THREE.Fog(0xffffff); // unused value, gets overwritten. - - this.points = new THREE.Points(geometry, this.renderMaterial); - this.points.frustumCulled = false; - if (this.spriteIndexBufferAttribute != null) { - (this.points.geometry as THREE.BufferGeometry).addAttribute( - 'spriteIndex', - this.spriteIndexBufferAttribute - ); - } - scene.add(this.points); + scene.add(this.points); + } + private calculatePointSize(sceneIs3D: boolean): number { + if (this.texture != null) { + return sceneIs3D ? IMAGE_SIZE : this.spriteDimensions[0]; } - - private calculatePointSize(sceneIs3D: boolean): number { - if (this.texture != null) { - return sceneIs3D ? IMAGE_SIZE : this.spriteDimensions[0]; + const n = + this.worldSpacePointPositions != null + ? this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS + : 1; + const SCALE = 200; + const LOG_BASE = 8; + const DIVISOR = 1.5; + // Scale point size inverse-logarithmically to the number of points. + const pointSize = SCALE / Math.log(n) / Math.log(LOG_BASE); + return sceneIs3D ? pointSize : pointSize / DIVISOR; + } + /** + * Set up buffer attributes to be used for the points/images. + */ + private createGeometry(pointCount: number): THREE.BufferGeometry { + const n = pointCount; + // Fill pickingColors with each point's unique id as its color. + this.pickingColors = new Float32Array(n * RGB_NUM_ELEMENTS); + { + let dst = 0; + for (let i = 0; i < n; i++) { + const c = new THREE.Color(i); + this.pickingColors[dst++] = c.r; + this.pickingColors[dst++] = c.g; + this.pickingColors[dst++] = c.b; } - const n = - this.worldSpacePointPositions != null - ? this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS - : 1; - const SCALE = 200; - const LOG_BASE = 8; - const DIVISOR = 1.5; - // Scale point size inverse-logarithmically to the number of points. - const pointSize = SCALE / Math.log(n) / Math.log(LOG_BASE); - return sceneIs3D ? pointSize : pointSize / DIVISOR; } - - /** - * Set up buffer attributes to be used for the points/images. - */ - private createGeometry(pointCount: number): THREE.BufferGeometry { - const n = pointCount; - - // Fill pickingColors with each point's unique id as its color. - this.pickingColors = new Float32Array(n * RGB_NUM_ELEMENTS); - { - let dst = 0; - for (let i = 0; i < n; i++) { - const c = new THREE.Color(i); - this.pickingColors[dst++] = c.r; - this.pickingColors[dst++] = c.g; - this.pickingColors[dst++] = c.b; - } - } - - const geometry = new THREE.BufferGeometry(); - geometry.addAttribute( - 'position', - new THREE.BufferAttribute(undefined, XYZ_NUM_ELEMENTS) - ); - geometry.addAttribute( - 'color', - new THREE.BufferAttribute(undefined, RGB_NUM_ELEMENTS) - ); - geometry.addAttribute( - 'scaleFactor', - new THREE.BufferAttribute(undefined, INDEX_NUM_ELEMENTS) - ); - return geometry; + const geometry = new THREE.BufferGeometry(); + geometry.addAttribute( + 'position', + new THREE.BufferAttribute(undefined, XYZ_NUM_ELEMENTS) + ); + geometry.addAttribute( + 'color', + new THREE.BufferAttribute(undefined, RGB_NUM_ELEMENTS) + ); + geometry.addAttribute( + 'scaleFactor', + new THREE.BufferAttribute(undefined, INDEX_NUM_ELEMENTS) + ); + return geometry; + } + private setFogDistances( + sceneIs3D: boolean, + nearestPointZ: number, + farthestPointZ: number + ) { + if (sceneIs3D) { + const n = this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS; + this.fog.near = nearestPointZ; + // If there are fewer points we want less fog. We do this + // by making the "far" value (that is, the distance from the camera to the + // far edge of the fog) proportional to the number of points. + let multiplier = + 2 - Math.min(n, NUM_POINTS_FOG_THRESHOLD) / NUM_POINTS_FOG_THRESHOLD; + this.fog.far = farthestPointZ * multiplier; + } else { + this.fog.near = Infinity; + this.fog.far = Infinity; } - - private setFogDistances( - sceneIs3D: boolean, - nearestPointZ: number, - farthestPointZ: number - ) { - if (sceneIs3D) { - const n = this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS; - this.fog.near = nearestPointZ; - // If there are fewer points we want less fog. We do this - // by making the "far" value (that is, the distance from the camera to the - // far edge of the fog) proportional to the number of points. - let multiplier = - 2 - Math.min(n, NUM_POINTS_FOG_THRESHOLD) / NUM_POINTS_FOG_THRESHOLD; - this.fog.far = farthestPointZ * multiplier; - } else { - this.fog.near = Infinity; - this.fog.far = Infinity; - } + } + dispose() { + this.disposeGeometry(); + this.disposeTextureAtlas(); + } + private disposeGeometry() { + if (this.points != null) { + this.scene.remove(this.points); + this.points.geometry.dispose(); + this.points = null; + this.worldSpacePointPositions = null; } - - dispose() { - this.disposeGeometry(); - this.disposeTextureAtlas(); + } + private disposeTextureAtlas() { + if (this.texture != null) { + this.texture.dispose(); } - - private disposeGeometry() { - if (this.points != null) { - this.scene.remove(this.points); - this.points.geometry.dispose(); - this.points = null; - this.worldSpacePointPositions = null; - } + this.texture = null; + this.renderMaterial = null; + this.pickingMaterial = null; + } + setScene(scene: THREE.Scene) { + this.scene = scene; + } + setSpriteAtlas( + spriteImage: HTMLImageElement, + spriteDimensions: [number, number], + spriteIndices: Float32Array + ) { + this.disposeTextureAtlas(); + this.createTextureFromSpriteAtlas( + spriteImage, + spriteDimensions, + spriteIndices + ); + this.renderMaterial = this.createRenderMaterial(true); + this.pickingMaterial = this.createPickingMaterial(true); + } + clearSpriteAtlas() { + this.disposeTextureAtlas(); + this.renderMaterial = this.createRenderMaterial(false); + this.pickingMaterial = this.createPickingMaterial(false); + } + onPointPositionsChanged(newPositions: Float32Array) { + if (newPositions == null || newPositions.length === 0) { + this.dispose(); + return; } - - private disposeTextureAtlas() { - if (this.texture != null) { - this.texture.dispose(); + if (this.points != null) { + if (this.worldSpacePointPositions.length !== newPositions.length) { + this.disposeGeometry(); } - this.texture = null; - this.renderMaterial = null; - this.pickingMaterial = null; } - - setScene(scene: THREE.Scene) { - this.scene = scene; + this.worldSpacePointPositions = newPositions; + if (this.points == null) { + this.createPointSprites(this.scene, newPositions); } - - setSpriteAtlas( - spriteImage: HTMLImageElement, - spriteDimensions: [number, number], - spriteIndices: Float32Array - ) { - this.disposeTextureAtlas(); - this.createTextureFromSpriteAtlas( - spriteImage, - spriteDimensions, - spriteIndices - ); - this.renderMaterial = this.createRenderMaterial(true); - this.pickingMaterial = this.createPickingMaterial(true); - } - - clearSpriteAtlas() { - this.disposeTextureAtlas(); - this.renderMaterial = this.createRenderMaterial(false); - this.pickingMaterial = this.createPickingMaterial(false); - } - - onPointPositionsChanged(newPositions: Float32Array) { - if (newPositions == null || newPositions.length === 0) { - this.dispose(); - return; - } - if (this.points != null) { - if (this.worldSpacePointPositions.length !== newPositions.length) { - this.disposeGeometry(); - } - } - - this.worldSpacePointPositions = newPositions; - - if (this.points == null) { - this.createPointSprites(this.scene, newPositions); - } - - const positions = (this.points - .geometry as THREE.BufferGeometry).getAttribute( - 'position' - ) as THREE.BufferAttribute; - - (positions as any).setArray(newPositions); - positions.needsUpdate = true; - } - - onPickingRender(rc: RenderContext) { - if (this.points == null) { - return; - } - - const sceneIs3D: boolean = rc.cameraType === CameraType.Perspective; - - this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerRow; - this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerColumn; - this.pickingMaterial.uniforms.sizeAttenuation.value = sceneIs3D; - this.pickingMaterial.uniforms.pointSize.value = this.calculatePointSize( - sceneIs3D - ); - this.points.material = this.pickingMaterial; - - let colors = (this.points.geometry as THREE.BufferGeometry).getAttribute( - 'color' - ) as THREE.BufferAttribute; - (colors as any).setArray(this.pickingColors); - colors.needsUpdate = true; - - let scaleFactors = (this.points - .geometry as THREE.BufferGeometry).getAttribute( - 'scaleFactor' - ) as THREE.BufferAttribute; - (scaleFactors as any).setArray(rc.pointScaleFactors); - scaleFactors.needsUpdate = true; + const positions = (this.points + .geometry as THREE.BufferGeometry).getAttribute( + 'position' + ) as THREE.BufferAttribute; + (positions as any).setArray(newPositions); + positions.needsUpdate = true; + } + onPickingRender(rc: RenderContext) { + if (this.points == null) { + return; } - - onRender(rc: RenderContext) { - if (!this.points) { - return; - } - const sceneIs3D: boolean = rc.camera instanceof THREE.PerspectiveCamera; - - this.setFogDistances( - sceneIs3D, - rc.nearestCameraSpacePointZ, - rc.farthestCameraSpacePointZ - ); - - this.scene.fog = this.fog; - this.scene.fog.color = new THREE.Color(rc.backgroundColor); - - this.renderMaterial.uniforms.fogColor.value = this.scene.fog.color; - this.renderMaterial.uniforms.fogNear.value = this.fog.near; - this.renderMaterial.uniforms.fogFar.value = this.fog.far; - this.renderMaterial.uniforms.spritesPerRow.value = this.spritesPerRow; - this.renderMaterial.uniforms.spritesPerColumn.value = this.spritesPerColumn; - this.renderMaterial.uniforms.isImage.value = this.texture != null; - this.renderMaterial.uniforms.texture.value = - this.texture != null ? this.texture : this.standinTextureForPoints; - this.renderMaterial.uniforms.sizeAttenuation.value = sceneIs3D; - this.renderMaterial.uniforms.pointSize.value = this.calculatePointSize( - sceneIs3D - ); - this.points.material = this.renderMaterial; - - let colors = (this.points.geometry as THREE.BufferGeometry).getAttribute( - 'color' - ) as THREE.BufferAttribute; - this.renderColors = rc.pointColors; - (colors as any).setArray(this.renderColors); - colors.needsUpdate = true; - - let scaleFactors = (this.points - .geometry as THREE.BufferGeometry).getAttribute( - 'scaleFactor' - ) as THREE.BufferAttribute; - (scaleFactors as any).setArray(rc.pointScaleFactors); - scaleFactors.needsUpdate = true; + const sceneIs3D: boolean = rc.cameraType === CameraType.Perspective; + this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerRow; + this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerColumn; + this.pickingMaterial.uniforms.sizeAttenuation.value = sceneIs3D; + this.pickingMaterial.uniforms.pointSize.value = this.calculatePointSize( + sceneIs3D + ); + this.points.material = this.pickingMaterial; + let colors = (this.points.geometry as THREE.BufferGeometry).getAttribute( + 'color' + ) as THREE.BufferAttribute; + (colors as any).setArray(this.pickingColors); + colors.needsUpdate = true; + let scaleFactors = (this.points + .geometry as THREE.BufferGeometry).getAttribute( + 'scaleFactor' + ) as THREE.BufferAttribute; + (scaleFactors as any).setArray(rc.pointScaleFactors); + scaleFactors.needsUpdate = true; + } + onRender(rc: RenderContext) { + if (!this.points) { + return; } - - onResize(newWidth: number, newHeight: number) {} + const sceneIs3D: boolean = rc.camera instanceof THREE.PerspectiveCamera; + this.setFogDistances( + sceneIs3D, + rc.nearestCameraSpacePointZ, + rc.farthestCameraSpacePointZ + ); + this.scene.fog = this.fog; + this.scene.fog.color = new THREE.Color(rc.backgroundColor); + this.renderMaterial.uniforms.fogColor.value = this.scene.fog.color; + this.renderMaterial.uniforms.fogNear.value = this.fog.near; + this.renderMaterial.uniforms.fogFar.value = this.fog.far; + this.renderMaterial.uniforms.spritesPerRow.value = this.spritesPerRow; + this.renderMaterial.uniforms.spritesPerColumn.value = this.spritesPerColumn; + this.renderMaterial.uniforms.isImage.value = this.texture != null; + this.renderMaterial.uniforms.texture.value = + this.texture != null ? this.texture : this.standinTextureForPoints; + this.renderMaterial.uniforms.sizeAttenuation.value = sceneIs3D; + this.renderMaterial.uniforms.pointSize.value = this.calculatePointSize( + sceneIs3D + ); + this.points.material = this.renderMaterial; + let colors = (this.points.geometry as THREE.BufferGeometry).getAttribute( + 'color' + ) as THREE.BufferAttribute; + this.renderColors = rc.pointColors; + (colors as any).setArray(this.renderColors); + colors.needsUpdate = true; + let scaleFactors = (this.points + .geometry as THREE.BufferGeometry).getAttribute( + 'scaleFactor' + ) as THREE.BufferAttribute; + (scaleFactors as any).setArray(rc.pointScaleFactors); + scaleFactors.needsUpdate = true; } -} // namespace vz_projector + onResize(newWidth: number, newHeight: number) {} +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/sptree.ts b/tensorboard/plugins/projector/polymer3/vz_projector/sptree.ts index ffae5f4c82..57b015d0b2 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/sptree.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/sptree.ts @@ -12,167 +12,153 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - /** N-dimensional point. Usually 2D or 3D. */ - export type Point = number[]; - - export interface BBox { - center: Point; - halfDim: number; - } - - /** A node in a space-partitioning tree. */ - export interface SPNode { - /** The children of this node. */ - children?: SPNode[]; - /** The bounding box of the region this node occupies. */ - box: BBox; - /** One or more points this node has. */ - point: Point; - } +export type Point = number[]; +export interface BBox { + center: Point; + halfDim: number; +} +/** A node in a space-partitioning tree. */ +export interface SPNode { + /** The children of this node. */ + children?: SPNode[]; + /** The bounding box of the region this node occupies. */ + box: BBox; + /** One or more points this node has. */ + point: Point; +} +/** + * A Space-partitioning tree (https://en.wikipedia.org/wiki/Space_partitioning) + * that recursively divides the space into regions of equal sizes. This data + * structure can act both as a Quad tree and an Octree when the data is 2 or + * 3 dimensional respectively. One usage is in t-SNE in order to do Barnes-Hut + * approximation. + */ +export class SPTree { + root: SPNode; + private masks: number[]; + private dim: number; /** - * A Space-partitioning tree (https://en.wikipedia.org/wiki/Space_partitioning) - * that recursively divides the space into regions of equal sizes. This data - * structure can act both as a Quad tree and an Octree when the data is 2 or - * 3 dimensional respectively. One usage is in t-SNE in order to do Barnes-Hut - * approximation. + * Constructs a new tree with the provided data. + * + * @param data List of n-dimensional data points. + * @param capacity Number of data points to store in a single node. */ - export class SPTree { - root: SPNode; - - private masks: number[]; - private dim: number; - - /** - * Constructs a new tree with the provided data. - * - * @param data List of n-dimensional data points. - * @param capacity Number of data points to store in a single node. - */ - constructor(data: Point[]) { - if (data.length < 1) { - throw new Error('There should be at least 1 data point'); - } - // Make a bounding box based on the extent of the data. - this.dim = data[0].length; - // Each node has 2^d children, where d is the dimension of the space. - // Binary masks (e.g. 000, 001, ... 111 in 3D) are used to determine in - // which child (e.g. quadron in 2D) the new point is going to be assigned. - // For more details, see the insert() method and its comments. - this.masks = new Array(Math.pow(2, this.dim)); - for (let d = 0; d < this.masks.length; ++d) { - this.masks[d] = 1 << d; - } - let min: Point = new Array(this.dim); - fillArray(min, Number.POSITIVE_INFINITY); - let max: Point = new Array(this.dim); - fillArray(max, Number.NEGATIVE_INFINITY); - - for (let i = 0; i < data.length; ++i) { - // For each dim get the min and max. - // E.g. For 2-D, get the x_min, x_max, y_min, y_max. - for (let d = 0; d < this.dim; ++d) { - min[d] = Math.min(min[d], data[i][d]); - max[d] = Math.max(max[d], data[i][d]); - } - } - // Create a bounding box with the center of the largest span. - let center: Point = new Array(this.dim); - let halfDim = 0; + constructor(data: Point[]) { + if (data.length < 1) { + throw new Error('There should be at least 1 data point'); + } + // Make a bounding box based on the extent of the data. + this.dim = data[0].length; + // Each node has 2^d children, where d is the dimension of the space. + // Binary masks (e.g. 000, 001, ... 111 in 3D) are used to determine in + // which child (e.g. quadron in 2D) the new point is going to be assigned. + // For more details, see the insert() method and its comments. + this.masks = new Array(Math.pow(2, this.dim)); + for (let d = 0; d < this.masks.length; ++d) { + this.masks[d] = 1 << d; + } + let min: Point = new Array(this.dim); + fillArray(min, Number.POSITIVE_INFINITY); + let max: Point = new Array(this.dim); + fillArray(max, Number.NEGATIVE_INFINITY); + for (let i = 0; i < data.length; ++i) { + // For each dim get the min and max. + // E.g. For 2-D, get the x_min, x_max, y_min, y_max. for (let d = 0; d < this.dim; ++d) { - let span = max[d] - min[d]; - center[d] = min[d] + span / 2; - halfDim = Math.max(halfDim, span / 2); - } - this.root = {box: {center: center, halfDim: halfDim}, point: data[0]}; - for (let i = 1; i < data.length; ++i) { - this.insert(this.root, data[i]); + min[d] = Math.min(min[d], data[i][d]); + max[d] = Math.max(max[d], data[i][d]); } } - - /** - * Visits every node in the tree. Each node can store 1 or more points, - * depending on the node capacity provided in the constructor. - * - * @param accessor Method that takes the currently visited node, and the - * low and high point of the region that this node occupies. E.g. in 2D, - * the low and high points will be the lower-left corner and the upper-right - * corner. - */ - visit( - accessor: (node: SPNode, lowPoint: Point, highPoint: Point) => boolean, - noBox = false - ) { - this.visitNode(this.root, accessor, noBox); + // Create a bounding box with the center of the largest span. + let center: Point = new Array(this.dim); + let halfDim = 0; + for (let d = 0; d < this.dim; ++d) { + let span = max[d] - min[d]; + center[d] = min[d] + span / 2; + halfDim = Math.max(halfDim, span / 2); } - - private visitNode( - node: SPNode, - accessor: (node: SPNode, lowPoint?: Point, highPoint?: Point) => boolean, - noBox: boolean - ) { - let skipChildren: boolean; - if (noBox) { - skipChildren = accessor(node); - } else { - let lowPoint = new Array(this.dim); - let highPoint = new Array(this.dim); - for (let d = 0; d < this.dim; ++d) { - lowPoint[d] = node.box.center[d] - node.box.halfDim; - highPoint[d] = node.box.center[d] + node.box.halfDim; - } - skipChildren = accessor(node, lowPoint, highPoint); - } - if (!node.children || skipChildren) { - return; - } - for (let i = 0; i < node.children.length; ++i) { - let child = node.children[i]; - if (child) { - this.visitNode(child, accessor, noBox); - } - } + this.root = {box: {center: center, halfDim: halfDim}, point: data[0]}; + for (let i = 1; i < data.length; ++i) { + this.insert(this.root, data[i]); } - - private insert(node: SPNode, p: Point) { - // Subdivide and then add the point to whichever node will accept it. - if (node.children == null) { - node.children = new Array(this.masks.length); - } - - // Decide which child will get the new point by constructing a D-bits binary - // signature (D=3 for 3D) where the k-th bit is 1 if the point's k-th - // coordinate is greater than the node's k-th coordinate, 0 otherwise. - // Then the binary signature in decimal system gives us the index of the - // child where the new point should be. - let index = 0; + } + /** + * Visits every node in the tree. Each node can store 1 or more points, + * depending on the node capacity provided in the constructor. + * + * @param accessor Method that takes the currently visited node, and the + * low and high point of the region that this node occupies. E.g. in 2D, + * the low and high points will be the lower-left corner and the upper-right + * corner. + */ + visit( + accessor: (node: SPNode, lowPoint: Point, highPoint: Point) => boolean, + noBox = false + ) { + this.visitNode(this.root, accessor, noBox); + } + private visitNode( + node: SPNode, + accessor: (node: SPNode, lowPoint?: Point, highPoint?: Point) => boolean, + noBox: boolean + ) { + let skipChildren: boolean; + if (noBox) { + skipChildren = accessor(node); + } else { + let lowPoint = new Array(this.dim); + let highPoint = new Array(this.dim); for (let d = 0; d < this.dim; ++d) { - if (p[d] > node.box.center[d]) { - index |= this.masks[d]; - } + lowPoint[d] = node.box.center[d] - node.box.halfDim; + highPoint[d] = node.box.center[d] + node.box.halfDim; } - if (node.children[index] == null) { - this.makeChild(node, index, p); - } else { - this.insert(node.children[index], p); + skipChildren = accessor(node, lowPoint, highPoint); + } + if (!node.children || skipChildren) { + return; + } + for (let i = 0; i < node.children.length; ++i) { + let child = node.children[i]; + if (child) { + this.visitNode(child, accessor, noBox); } } - - private makeChild(node: SPNode, index: number, p: Point): void { - let oldC = node.box.center; - let h = node.box.halfDim / 2; - let newC: Point = new Array(this.dim); - for (let d = 0; d < this.dim; ++d) { - newC[d] = index & (1 << d) ? oldC[d] + h : oldC[d] - h; + } + private insert(node: SPNode, p: Point) { + // Subdivide and then add the point to whichever node will accept it. + if (node.children == null) { + node.children = new Array(this.masks.length); + } + // Decide which child will get the new point by constructing a D-bits binary + // signature (D=3 for 3D) where the k-th bit is 1 if the point's k-th + // coordinate is greater than the node's k-th coordinate, 0 otherwise. + // Then the binary signature in decimal system gives us the index of the + // child where the new point should be. + let index = 0; + for (let d = 0; d < this.dim; ++d) { + if (p[d] > node.box.center[d]) { + index |= this.masks[d]; } - node.children[index] = {box: {center: newC, halfDim: h}, point: p}; + } + if (node.children[index] == null) { + this.makeChild(node, index, p); + } else { + this.insert(node.children[index], p); } } - - function fillArray(arr: T[], value: T): void { - for (let i = 0; i < arr.length; ++i) { - arr[i] = value; + private makeChild(node: SPNode, index: number, p: Point): void { + let oldC = node.box.center; + let h = node.box.halfDim / 2; + let newC: Point = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + newC[d] = index & (1 << d) ? oldC[d] + h : oldC[d] - h; } + node.children[index] = {box: {center: newC, halfDim: h}, point: p}; + } +} +function fillArray(arr: T[], value: T): void { + for (let i = 0; i < arr.length; ++i) { + arr[i] = value; } -} // namespace vz_projector +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/standalone.html b/tensorboard/plugins/projector/polymer3/vz_projector/standalone_lib.html similarity index 78% rename from tensorboard/plugins/projector/polymer3/vz_projector/standalone.html rename to tensorboard/plugins/projector/polymer3/vz_projector/standalone_lib.html index 77cc2bd5b0..137dd76679 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/standalone.html +++ b/tensorboard/plugins/projector/polymer3/vz_projector/standalone_lib.html @@ -28,14 +28,10 @@ - - - - - + + Embedding projector - visualization of high-dimensional data - - - - - - diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/styles.ts b/tensorboard/plugins/projector/polymer3/vz_projector/styles.ts new file mode 100644 index 0000000000..0b31f50d43 --- /dev/null +++ b/tensorboard/plugins/projector/polymer3/vz_projector/styles.ts @@ -0,0 +1,182 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {registerStyleDomModule} from '../../../../components_polymer3/polymer/register_style_dom_module'; + +registerStyleDomModule({ + moduleName: 'vz-projector-styles', + styleContent: ` + :host { + --paper-input-container-label: { + font-size: 14px; + } + --paper-input-container-input: { + font-size: 14px; + } + /* TODO: Figure out why this doesn't work */ + --paper-dropdown-menu-input: { + font-size: 14px; + } + } + + paper-button { + background: #e3e3e3; + margin-left: 0; + text-transform: none; + } + + paper-dropdown-menu paper-item { + font-size: 13px; + } + + paper-tooltip { + max-width: 200px; + --paper-tooltip: { + font-size: 12px; + } + } + + paper-checkbox { + --paper-checkbox-checked-color: #880e4f; + } + + paper-toggle-button { + --paper-toggle-button-checked-bar-color: #880e4f; + --paper-toggle-button-checked-button-color: #880e4f; + --paper-toggle-button-checked-ink-color: #880e4f; + } + + paper-icon-button { + border-radius: 50%; + } + + paper-icon-button[active] { + color: white; + background-color: #880e4f; + } + + .slider { + display: flex; + align-items: center; + margin-bottom: 10px; + justify-content: space-between; + } + + .slider span { + width: 35px; + text-align: right; + } + + .slider label { + align-items: center; + display: flex; + } + + .help-icon { + height: 15px; + left: 2px; + min-width: 15px; + min-height: 15px; + margin: 0; + padding: 0; + top: -2px; + width: 15px; + } + + .ink-panel { + display: flex; + flex-direction: column; + font-size: 14px; + } + + .ink-panel h4 { + border-bottom: 1px solid #ddd; + font-size: 14px; + font-weight: 500; + margin: 0; + margin-bottom: 10px; + padding-bottom: 5px; + } + + .ink-panel-header { + border-bottom: 1px solid rgba(0, 0, 0, 0.1); + border-top: 1px solid rgba(0, 0, 0, 0.1); + height: 50px; + } + + .ink-panel-content { + display: none; + height: 100%; + } + + .ink-panel-content.active { + display: block; + } + + .ink-panel-content h3 { + font-weight: 500; + font-size: 14px; + margin-top: 20px; + margin-bottom: 5px; + text-transform: uppercase; + } + + .ink-panel-header h3 { + font-weight: 500; + font-size: 14px; + margin: 0; + padding: 0 24px; + text-transform: uppercase; + } + + /* - Tabs */ + .ink-tab-group { + align-items: center; + box-sizing: border-box; + display: flex; + height: 100%; + justify-content: space-around; + } + + .ink-tab-group .projection-tab { + color: rgba(0, 0, 0, 0.5); + cursor: pointer; + font-weight: 300; + line-height: 49px; + padding: 0 12px; + text-align: center; + text-transform: uppercase; + } + + .ink-tab-group .projection-tab:hover { + color: black; + } + + .ink-tab-group .projection-tab.active { + border-bottom: 2px solid black; + color: black; + font-weight: 500; + } + + h4 { + margin: 30px 0 10px 0; + } + + .dismiss-dialog-note { + margin-top: 25px; + font-size: 11px; + text-align: right; + } + `, +}); diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/test/BUILD b/tensorboard/plugins/projector/polymer3/vz_projector/test/BUILD index 84dd2ff73d..6162489914 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/test/BUILD +++ b/tensorboard/plugins/projector/polymer3/vz_projector/test/BUILD @@ -10,6 +10,7 @@ licenses(["notice"]) # Apache 2.0 tf_web_test( name = "test", src = "/vz-projector/test/tests.html", + tags = ["manual"], web_library = ":test_web_library", ) diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/util.ts b/tensorboard/plugins/projector/polymer3/vz_projector/util.ts index 8fb8203994..7c35c26eb1 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/util.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/util.ts @@ -12,259 +12,244 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector.util { - /** - * Delay for running expensive tasks, in milliseconds. - * The duration was empirically found so that it leaves enough time for the - * browser to update its UI state before starting an expensive UI-blocking task. - */ - const TASK_DELAY_MS = 200; - - /** Shuffles the array in-place in O(n) time using Fisher-Yates algorithm. */ - export function shuffle(array: T[]): T[] { - let m = array.length; - let t: T; - let i: number; - - // While there remain elements to shuffle. - while (m) { - // Pick a remaining element - i = Math.floor(Math.random() * m--); - // And swap it with the current element. - t = array[m]; - array[m] = array[i]; - array[i] = t; - } - return array; +import * as THREE from 'three'; +import weblas from 'weblas/dist/weblas'; + +import {DataPoint} from './data'; +import * as vector from './vector'; +import * as logging from './logging'; + +const TASK_DELAY_MS = 200; +/** Shuffles the array in-place in O(n) time using Fisher-Yates algorithm. */ +export function shuffle(array: T[]): T[] { + let m = array.length; + let t: T; + let i: number; + // While there remain elements to shuffle. + while (m) { + // Pick a remaining element + i = Math.floor(Math.random() * m--); + // And swap it with the current element. + t = array[m]; + array[m] = array[i]; + array[i] = t; } - - export function range(count: number): number[] { - const rangeOutput: number[] = []; - for (let i = 0; i < count; i++) { - rangeOutput.push(i); - } - return rangeOutput; + return array; +} +export function range(count: number): number[] { + const rangeOutput: number[] = []; + for (let i = 0; i < count; i++) { + rangeOutput.push(i); } - - export function classed( - element: HTMLElement, - className: string, - enabled: boolean - ) { - const classNames = element.className.split(' '); - if (enabled) { - if (className in classNames) { - return; - } else { - classNames.push(className); - } + return rangeOutput; +} +export function classed( + element: HTMLElement, + className: string, + enabled: boolean +) { + const classNames = element.className.split(' '); + if (enabled) { + if (className in classNames) { + return; } else { - const index = classNames.indexOf(className); - if (index === -1) { - return; - } - classNames.splice(index, 1); + classNames.push(className); } - element.className = classNames.join(' '); - } - - /** Projects a 3d point into screen space */ - export function vector3DToScreenCoords( - cam: THREE.Camera, - w: number, - h: number, - v: THREE.Vector3 - ): vector.Point2D { - let dpr = window.devicePixelRatio; - let pv = new THREE.Vector3().copy(v).project(cam); - - // The screen-space origin is at the middle of the screen, with +y up. - let coords: vector.Point2D = [ - ((pv.x + 1) / 2) * w * dpr, - -(((pv.y - 1) / 2) * h) * dpr, - ]; - return coords; - } - - /** Loads 3 contiguous elements from a packed xyz array into a Vector3. */ - export function vector3FromPackedArray( - a: Float32Array, - pointIndex: number - ): THREE.Vector3 { - const offset = pointIndex * 3; - return new THREE.Vector3(a[offset], a[offset + 1], a[offset + 2]); - } - - /** - * Gets the camera-space z coordinates of the nearest and farthest points. - * Ignores points that are behind the camera. - */ - export function getNearFarPoints( - worldSpacePoints: Float32Array, - cameraPos: THREE.Vector3, - cameraTarget: THREE.Vector3 - ): [number, number] { - let shortestDist: number = Infinity; - let furthestDist: number = 0; - const camToTarget = new THREE.Vector3().copy(cameraTarget).sub(cameraPos); - const camPlaneNormal = new THREE.Vector3().copy(camToTarget).normalize(); - const n = worldSpacePoints.length / 3; - let src = 0; - let p = new THREE.Vector3(); - let camToPoint = new THREE.Vector3(); - for (let i = 0; i < n; i++) { - p.x = worldSpacePoints[src]; - p.y = worldSpacePoints[src + 1]; - p.z = worldSpacePoints[src + 2]; - src += 3; - - camToPoint.copy(p).sub(cameraPos); - const dist = camPlaneNormal.dot(camToPoint); - if (dist < 0) { - continue; - } - furthestDist = dist > furthestDist ? dist : furthestDist; - shortestDist = dist < shortestDist ? dist : shortestDist; + } else { + const index = classNames.indexOf(className); + if (index === -1) { + return; } - return [shortestDist, furthestDist]; + classNames.splice(index, 1); } - - /** - * Generate a texture for the points/images and sets some initial params - */ - export function createTexture( - image: HTMLImageElement | HTMLCanvasElement - ): THREE.Texture { - let tex = new THREE.Texture(image); - tex.needsUpdate = true; - // Used if the texture isn't a power of 2. - tex.minFilter = THREE.LinearFilter; - tex.generateMipmaps = false; - tex.flipY = false; - return tex; - } - - /** - * Assert that the condition is satisfied; if not, log user-specified message - * to the console. - */ - export function assert(condition: boolean, message?: string) { - if (!condition) { - message = message || 'Assertion failed'; - throw new Error(message); + element.className = classNames.join(' '); +} +/** Projects a 3d point into screen space */ +export function vector3DToScreenCoords( + cam: THREE.Camera, + w: number, + h: number, + v: THREE.Vector3 +): vector.Point2D { + let dpr = window.devicePixelRatio; + let pv = new THREE.Vector3().copy(v).project(cam); + // The screen-space origin is at the middle of the screen, with +y up. + let coords: vector.Point2D = [ + ((pv.x + 1) / 2) * w * dpr, + -(((pv.y - 1) / 2) * h) * dpr, + ]; + return coords; +} +/** Loads 3 contiguous elements from a packed xyz array into a Vector3. */ +export function vector3FromPackedArray( + a: Float32Array, + pointIndex: number +): THREE.Vector3 { + const offset = pointIndex * 3; + return new THREE.Vector3(a[offset], a[offset + 1], a[offset + 2]); +} +/** + * Gets the camera-space z coordinates of the nearest and farthest points. + * Ignores points that are behind the camera. + */ +export function getNearFarPoints( + worldSpacePoints: Float32Array, + cameraPos: THREE.Vector3, + cameraTarget: THREE.Vector3 +): [number, number] { + let shortestDist: number = Infinity; + let furthestDist: number = 0; + const camToTarget = new THREE.Vector3().copy(cameraTarget).sub(cameraPos); + const camPlaneNormal = new THREE.Vector3().copy(camToTarget).normalize(); + const n = worldSpacePoints.length / 3; + let src = 0; + let p = new THREE.Vector3(); + let camToPoint = new THREE.Vector3(); + for (let i = 0; i < n; i++) { + p.x = worldSpacePoints[src]; + p.y = worldSpacePoints[src + 1]; + p.z = worldSpacePoints[src + 2]; + src += 3; + camToPoint.copy(p).sub(cameraPos); + const dist = camPlaneNormal.dot(camToPoint); + if (dist < 0) { + continue; } + furthestDist = dist > furthestDist ? dist : furthestDist; + shortestDist = dist < shortestDist ? dist : shortestDist; } - - export type SearchPredicate = (p: DataPoint) => boolean; - - export function getSearchPredicate( - query: string, - inRegexMode: boolean, - fieldName: string - ): SearchPredicate { - let predicate: SearchPredicate; - if (inRegexMode) { - let regExp = new RegExp(query, 'i'); - predicate = (p) => regExp.test(p.metadata[fieldName].toString()); - } else { - // Doing a case insensitive substring match. - query = query.toLowerCase(); - predicate = (p) => { - let label = p.metadata[fieldName].toString().toLowerCase(); - return label.indexOf(query) >= 0; - }; - } - return predicate; - } - - /** - * Runs an expensive task asynchronously with some delay - * so that it doesn't block the UI thread immediately. - * - * @param message The message to display to the user. - * @param task The expensive task to run. - * @param msgId Optional. ID of an existing message. If provided, will overwrite - * an existing message and won't automatically clear the message when the - * task is done. - * @return The value returned by the task. - */ - export function runAsyncTask( - message: string, - task: () => T, - msgId: string = null, - taskDelay = TASK_DELAY_MS - ): Promise { - let autoClear = msgId == null; - msgId = logging.setModalMessage(message, msgId); - return new Promise((resolve, reject) => { - setTimeout(() => { - try { - let result = task(); - // Clearing the old message. - if (autoClear) { - logging.setModalMessage(null, msgId); - } - resolve(result); - } catch (ex) { - reject(ex); - } - return true; - }, taskDelay); - }); + return [shortestDist, furthestDist]; +} +/** + * Generate a texture for the points/images and sets some initial params + */ +export function createTexture( + image: HTMLImageElement | HTMLCanvasElement +): THREE.Texture { + let tex = new THREE.Texture(image); + tex.needsUpdate = true; + // Used if the texture isn't a power of 2. + tex.minFilter = THREE.LinearFilter; + tex.generateMipmaps = false; + tex.flipY = false; + return tex; +} +/** + * Assert that the condition is satisfied; if not, log user-specified message + * to the console. + */ +export function assert(condition: boolean, message?: string) { + if (!condition) { + message = message || 'Assertion failed'; + throw new Error(message); } - - /** - * Parses the URL for query parameters, e.g. ?foo=1&bar=2 will return - * {'foo': '1', 'bar': '2'}. - * @param url The URL to parse. - * @return A map of queryParam key to its value. - */ - export function getURLParams(url: string): {[key: string]: string} { - if (!url) { - return {}; - } - - let queryString = url.indexOf('?') !== -1 ? url.split('?')[1] : url; - if (queryString.indexOf('#')) { - queryString = queryString.split('#')[0]; - } - - const queryEntries = queryString.split('&'); - let queryParams: {[key: string]: string} = {}; - for (let i = 0; i < queryEntries.length; i++) { - let queryEntryComponents = queryEntries[i].split('='); - queryParams[queryEntryComponents[0].toLowerCase()] = decodeURIComponent( - queryEntryComponents[1] - ); - } - return queryParams; +} +export type SearchPredicate = (p: DataPoint) => boolean; +export function getSearchPredicate( + query: string, + inRegexMode: boolean, + fieldName: string +): SearchPredicate { + let predicate: SearchPredicate; + if (inRegexMode) { + let regExp = new RegExp(query, 'i'); + predicate = (p) => regExp.test(p.metadata[fieldName].toString()); + } else { + // Doing a case insensitive substring match. + query = query.toLowerCase(); + predicate = (p) => { + let label = p.metadata[fieldName].toString().toLowerCase(); + return label.indexOf(query) >= 0; + }; } - - /** List of substrings that auto generated tensors have in their name. */ - const SUBSTR_GEN_TENSORS = ['/Adagrad']; - - /** Returns true if the tensor was automatically generated by TF API calls. */ - export function tensorIsGenerated(tensorName: string): boolean { - for (let i = 0; i < SUBSTR_GEN_TENSORS.length; i++) { - if (tensorName.indexOf(SUBSTR_GEN_TENSORS[i]) >= 0) { - return true; + return predicate; +} +/** + * Runs an expensive task asynchronously with some delay + * so that it doesn't block the UI thread immediately. + * + * @param message The message to display to the user. + * @param task The expensive task to run. + * @param msgId Optional. ID of an existing message. If provided, will overwrite + * an existing message and won't automatically clear the message when the + * task is done. + * @return The value returned by the task. + */ +export function runAsyncTask( + message: string, + task: () => T, + msgId: string = null, + taskDelay = TASK_DELAY_MS +): Promise { + let autoClear = msgId == null; + msgId = logging.setModalMessage(message, msgId); + return new Promise((resolve, reject) => { + setTimeout(() => { + try { + let result = task(); + // Clearing the old message. + if (autoClear) { + logging.setModalMessage(null, msgId); + } + resolve(result); + } catch (ex) { + reject(ex); } - } - return false; + return true; + }, taskDelay); + }); +} +/** + * Parses the URL for query parameters, e.g. ?foo=1&bar=2 will return + * {'foo': '1', 'bar': '2'}. + * @param url The URL to parse. + * @return A map of queryParam key to its value. + */ +export function getURLParams( + url: string +): { + [key: string]: string; +} { + if (!url) { + return {}; } - - export function xor(cond1: boolean, cond2: boolean): boolean { - return (cond1 || cond2) && !(cond1 && cond2); + let queryString = url.indexOf('?') !== -1 ? url.split('?')[1] : url; + if (queryString.indexOf('#')) { + queryString = queryString.split('#')[0]; } - - /** Checks to see if the browser supports webgl. */ - export function hasWebGLSupport(): boolean { - try { - let c = document.createElement('canvas'); - let gl = c.getContext('webgl') || c.getContext('experimental-webgl'); - return gl != null && typeof weblas !== 'undefined'; - } catch (e) { - return false; + const queryEntries = queryString.split('&'); + let queryParams: { + [key: string]: string; + } = {}; + for (let i = 0; i < queryEntries.length; i++) { + let queryEntryComponents = queryEntries[i].split('='); + queryParams[queryEntryComponents[0].toLowerCase()] = decodeURIComponent( + queryEntryComponents[1] + ); + } + return queryParams; +} +/** List of substrings that auto generated tensors have in their name. */ +const SUBSTR_GEN_TENSORS = ['/Adagrad']; +/** Returns true if the tensor was automatically generated by TF API calls. */ +export function tensorIsGenerated(tensorName: string): boolean { + for (let i = 0; i < SUBSTR_GEN_TENSORS.length; i++) { + if (tensorName.indexOf(SUBSTR_GEN_TENSORS[i]) >= 0) { + return true; } } -} // namespace vz_projector.util + return false; +} +export function xor(cond1: boolean, cond2: boolean): boolean { + return (cond1 || cond2) && !(cond1 && cond2); +} +/** Checks to see if the browser supports webgl. */ +export function hasWebGLSupport(): boolean { + try { + let c = document.createElement('canvas'); + let gl = c.getContext('webgl') || c.getContext('experimental-webgl'); + return gl != null && typeof weblas !== 'undefined'; + } catch (e) { + return false; + } +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vector.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vector.ts index c66d519680..37cad98b30 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vector.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vector.ts @@ -12,277 +12,238 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector.vector { - /** - * @fileoverview Useful vector utilities. - */ - - export type Vector = Float32Array | number[]; - export type Point2D = [number, number]; - export type Point3D = [number, number, number]; - - /** Returns the dot product of two vectors. */ - export function dot(a: Vector, b: Vector): number { - util.assert( - a.length === b.length, - 'Vectors a and b must be of same length' - ); - let result = 0; - for (let i = 0; i < a.length; ++i) { - result += a[i] * b[i]; - } - return result; - } - - /** Sums all the elements in the vector */ - export function sum(a: Vector): number { - let result = 0; - for (let i = 0; i < a.length; ++i) { - result += a[i]; - } - return result; - } - - /** Returns the sum of two vectors, i.e. a + b */ - export function add(a: Vector, b: Vector): Float32Array { - util.assert( - a.length === b.length, - 'Vectors a and b must be of same length' - ); - let result = new Float32Array(a.length); - for (let i = 0; i < a.length; ++i) { - result[i] = a[i] + b[i]; - } - return result; - } - - /** Subtracts vector b from vector a, i.e. returns a - b */ - export function sub(a: Vector, b: Vector): Float32Array { - util.assert( - a.length === b.length, - 'Vectors a and b must be of same length' - ); - let result = new Float32Array(a.length); - for (let i = 0; i < a.length; ++i) { - result[i] = a[i] - b[i]; - } - return result; - } - - /** Returns the square norm of the vector */ - export function norm2(a: Vector): number { - let result = 0; - for (let i = 0; i < a.length; ++i) { - result += a[i] * a[i]; - } - return result; - } - - /** Returns the euclidean distance between two vectors. */ - export function dist(a: Vector, b: Vector): number { - return Math.sqrt(dist2(a, b)); - } - - /** Returns the square euclidean distance between two vectors. */ - export function dist2(a: Vector, b: Vector): number { - util.assert( - a.length === b.length, - 'Vectors a and b must be of same length' - ); - let result = 0; - for (let i = 0; i < a.length; ++i) { - let diff = a[i] - b[i]; - result += diff * diff; - } - return result; - } - - /** Returns the square euclidean distance between two 2D points. */ - export function dist2_2D(a: Vector, b: Vector): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - return dX * dX + dY * dY; - } - - /** Returns the square euclidean distance between two 3D points. */ - export function dist2_3D(a: Vector, b: Vector): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - let dZ = a[2] - b[2]; - return dX * dX + dY * dY + dZ * dZ; - } - - /** Returns the euclidean distance between 2 3D points. */ - export function dist_3D(a: Vector, b: Vector): number { - return Math.sqrt(dist2_3D(a, b)); - } - - /** - * Returns the square euclidean distance between two vectors, with an early - * exit (returns -1) if the distance is >= to the provided limit. - */ - export function dist2WithLimit(a: Vector, b: Vector, limit: number): number { - util.assert( - a.length === b.length, - 'Vectors a and b must be of same length' - ); - let result = 0; - for (let i = 0; i < a.length; ++i) { - let diff = a[i] - b[i]; - result += diff * diff; - if (result >= limit) { - return -1; - } - } - return result; - } - - /** Returns the square euclidean distance between two 2D points. */ - export function dist22D(a: Point2D, b: Point2D): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - return dX * dX + dY * dY; - } - - /** Modifies the vector in-place to have unit norm. */ - export function unit(a: Vector): void { - let norm = Math.sqrt(norm2(a)); - util.assert(norm >= 0, 'Norm of the vector must be > 0'); - for (let i = 0; i < a.length; ++i) { - a[i] /= norm; - } - } - - /** - * Projects the vectors to a lower dimension - * - * @param vectors Array of vectors to be projected. - * @param newDim The resulting dimension of the vectors. - */ - export function projectRandom( - vectors: Float32Array[], - newDim: number - ): Float32Array[] { - let dim = vectors[0].length; - let N = vectors.length; - let newVectors: Float32Array[] = new Array(N); +import * as d3 from 'd3'; + +import * as util from './util'; + +export type Vector = Float32Array | number[]; +export type Point2D = [number, number]; +export type Point3D = [number, number, number]; +/** Returns the dot product of two vectors. */ +export function dot(a: Vector, b: Vector): number { + util.assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = 0; + for (let i = 0; i < a.length; ++i) { + result += a[i] * b[i]; + } + return result; +} +/** Sums all the elements in the vector */ +export function sum(a: Vector): number { + let result = 0; + for (let i = 0; i < a.length; ++i) { + result += a[i]; + } + return result; +} +/** Returns the sum of two vectors, i.e. a + b */ +export function add(a: Vector, b: Vector): Float32Array { + util.assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = new Float32Array(a.length); + for (let i = 0; i < a.length; ++i) { + result[i] = a[i] + b[i]; + } + return result; +} +/** Subtracts vector b from vector a, i.e. returns a - b */ +export function sub(a: Vector, b: Vector): Float32Array { + util.assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = new Float32Array(a.length); + for (let i = 0; i < a.length; ++i) { + result[i] = a[i] - b[i]; + } + return result; +} +/** Returns the square norm of the vector */ +export function norm2(a: Vector): number { + let result = 0; + for (let i = 0; i < a.length; ++i) { + result += a[i] * a[i]; + } + return result; +} +/** Returns the euclidean distance between two vectors. */ +export function dist(a: Vector, b: Vector): number { + return Math.sqrt(dist2(a, b)); +} +/** Returns the square euclidean distance between two vectors. */ +export function dist2(a: Vector, b: Vector): number { + util.assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = 0; + for (let i = 0; i < a.length; ++i) { + let diff = a[i] - b[i]; + result += diff * diff; + } + return result; +} +/** Returns the square euclidean distance between two 2D points. */ +export function dist2_2D(a: Vector, b: Vector): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + return dX * dX + dY * dY; +} +/** Returns the square euclidean distance between two 3D points. */ +export function dist2_3D(a: Vector, b: Vector): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + let dZ = a[2] - b[2]; + return dX * dX + dY * dY + dZ * dZ; +} +/** Returns the euclidean distance between 2 3D points. */ +export function dist_3D(a: Vector, b: Vector): number { + return Math.sqrt(dist2_3D(a, b)); +} +/** + * Returns the square euclidean distance between two vectors, with an early + * exit (returns -1) if the distance is >= to the provided limit. + */ +export function dist2WithLimit(a: Vector, b: Vector, limit: number): number { + util.assert(a.length === b.length, 'Vectors a and b must be of same length'); + let result = 0; + for (let i = 0; i < a.length; ++i) { + let diff = a[i] - b[i]; + result += diff * diff; + if (result >= limit) { + return -1; + } + } + return result; +} +/** Returns the square euclidean distance between two 2D points. */ +export function dist22D(a: Point2D, b: Point2D): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + return dX * dX + dY * dY; +} +/** Modifies the vector in-place to have unit norm. */ +export function unit(a: Vector): void { + let norm = Math.sqrt(norm2(a)); + util.assert(norm >= 0, 'Norm of the vector must be > 0'); + for (let i = 0; i < a.length; ++i) { + a[i] /= norm; + } +} +/** + * Projects the vectors to a lower dimension + * + * @param vectors Array of vectors to be projected. + * @param newDim The resulting dimension of the vectors. + */ +export function projectRandom( + vectors: Float32Array[], + newDim: number +): Float32Array[] { + let dim = vectors[0].length; + let N = vectors.length; + let newVectors: Float32Array[] = new Array(N); + for (let i = 0; i < N; ++i) { + newVectors[i] = new Float32Array(newDim); + } + // Make nDim projections. + for (let k = 0; k < newDim; ++k) { + let randomVector = rn(dim); for (let i = 0; i < N; ++i) { - newVectors[i] = new Float32Array(newDim); - } - // Make nDim projections. - for (let k = 0; k < newDim; ++k) { - let randomVector = rn(dim); - for (let i = 0; i < N; ++i) { - newVectors[i][k] = dot(vectors[i], randomVector); - } - } - return newVectors; - } - - /** - * Projects a vector onto a 2D plane specified by the two direction vectors. - */ - export function project2d(a: Vector, dir1: Vector, dir2: Vector): Point2D { - return [dot(a, dir1), dot(a, dir2)]; - } - - /** - * Computes the centroid of the data points. If the provided data points are not - * vectors, an accessor function needs to be provided. - */ - export function centroid( - dataPoints: T[], - accessor?: (a: T) => Vector - ): Vector { - if (dataPoints.length === 0) { - return null; - } - if (accessor == null) { - accessor = (a: T) => a; - } - util.assert(dataPoints.length >= 0, '`vectors` must be of length >= 1'); - let centroid = new Float32Array(accessor(dataPoints[0]).length); - for (let i = 0; i < dataPoints.length; ++i) { - let dataPoint = dataPoints[i]; - let vector = accessor(dataPoint); - for (let j = 0; j < centroid.length; ++j) { - centroid[j] += vector[j]; - } - } + newVectors[i][k] = dot(vectors[i], randomVector); + } + } + return newVectors; +} +/** + * Projects a vector onto a 2D plane specified by the two direction vectors. + */ +export function project2d(a: Vector, dir1: Vector, dir2: Vector): Point2D { + return [dot(a, dir1), dot(a, dir2)]; +} +/** + * Computes the centroid of the data points. If the provided data points are not + * vectors, an accessor function needs to be provided. + */ +export function centroid( + dataPoints: T[], + accessor?: (a: T) => Vector +): Vector { + if (dataPoints.length === 0) { + return null; + } + if (accessor == null) { + accessor = (a: T) => a; + } + util.assert(dataPoints.length >= 0, '`vectors` must be of length >= 1'); + let centroid = new Float32Array(accessor(dataPoints[0]).length); + for (let i = 0; i < dataPoints.length; ++i) { + let dataPoint = dataPoints[i]; + let vector = accessor(dataPoint); for (let j = 0; j < centroid.length; ++j) { - centroid[j] /= dataPoints.length; - } - return centroid; - } - - /** - * Generates a vector of the specified size where each component is drawn from - * a random (0, 1) gaussian distribution. - */ - export function rn(size: number): Float32Array { - const normal = d3.randomNormal(); - let result = new Float32Array(size); - for (let i = 0; i < size; ++i) { - result[i] = normal(); - } - return result; - } - - /** - * Returns the cosine distance ([0, 2]) between two vectors - * that have been normalized to unit norm. - */ - export function cosDistNorm(a: Vector, b: Vector): number { - return 1 - dot(a, b); - } - - /** - * Returns the cosine distance ([0, 2]) between two vectors. - */ - export function cosDist(a: Vector, b: Vector): number { - return 1 - cosSim(a, b); - } - - /** Returns the cosine similarity ([-1, 1]) between two vectors. */ - export function cosSim(a: Vector, b: Vector): number { - return dot(a, b) / Math.sqrt(norm2(a) * norm2(b)); - } - - /** - * Converts list of vectors (matrix) into a 1-dimensional - * typed array with row-first order. - */ - export function toTypedArray( - dataPoints: T[], - accessor: (dataPoint: T) => Float32Array - ): Float32Array { - let N = dataPoints.length; - let dim = accessor(dataPoints[0]).length; - let result = new Float32Array(N * dim); - for (let i = 0; i < N; ++i) { - let vector = accessor(dataPoints[i]); - for (let d = 0; d < dim; ++d) { - result[i * dim + d] = vector[d]; - } - } - return result; - } - - /** - * Transposes an RxC matrix represented as a flat typed array - * into a CxR matrix, again represented as a flat typed array. - */ - export function transposeTypedArray( - r: number, - c: number, - typedArray: Float32Array - ) { - let result = new Float32Array(r * c); - for (let i = 0; i < r; ++i) { - for (let j = 0; j < c; ++j) { - result[j * r + i] = typedArray[i * c + j]; - } - } - return result; - } -} // namespace vz_projector.vector + centroid[j] += vector[j]; + } + } + for (let j = 0; j < centroid.length; ++j) { + centroid[j] /= dataPoints.length; + } + return centroid; +} +/** + * Generates a vector of the specified size where each component is drawn from + * a random (0, 1) gaussian distribution. + */ +export function rn(size: number): Float32Array { + const normal = d3.randomNormal(); + let result = new Float32Array(size); + for (let i = 0; i < size; ++i) { + result[i] = normal(); + } + return result; +} +/** + * Returns the cosine distance ([0, 2]) between two vectors + * that have been normalized to unit norm. + */ +export function cosDistNorm(a: Vector, b: Vector): number { + return 1 - dot(a, b); +} +/** + * Returns the cosine distance ([0, 2]) between two vectors. + */ +export function cosDist(a: Vector, b: Vector): number { + return 1 - cosSim(a, b); +} +/** Returns the cosine similarity ([-1, 1]) between two vectors. */ +export function cosSim(a: Vector, b: Vector): number { + return dot(a, b) / Math.sqrt(norm2(a) * norm2(b)); +} +/** + * Converts list of vectors (matrix) into a 1-dimensional + * typed array with row-first order. + */ +export function toTypedArray( + dataPoints: T[], + accessor: (dataPoint: T) => Float32Array +): Float32Array { + let N = dataPoints.length; + let dim = accessor(dataPoints[0]).length; + let result = new Float32Array(N * dim); + for (let i = 0; i < N; ++i) { + let vector = accessor(dataPoints[i]); + for (let d = 0; d < dim; ++d) { + result[i * dim + d] = vector[d]; + } + } + return result; +} +/** + * Transposes an RxC matrix represented as a flat typed array + * into a CxR matrix, again represented as a flat typed array. + */ +export function transposeTypedArray( + r: number, + c: number, + typedArray: Float32Array +) { + let result = new Float32Array(r * c); + for (let i = 0; i < r; ++i) { + for (let j = 0; j < c; ++j) { + result[j * r + i] = typedArray[i * c + j]; + } + } + return result; +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-app.html b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-app.ts similarity index 59% rename from tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-app.html rename to tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-app.ts index 264ff518dc..039352bcbf 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-app.html +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-app.ts @@ -1,6 +1,4 @@ - +==============================================================================*/ - - - - - +import {PolymerElement, html} from '@polymer/polymer'; +import {customElement, property} from '@polymer/decorators'; +import '@polymer/paper-icon-button'; +import '@polymer/paper-tooltip'; - - - - diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-bookmark-panel.html.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-bookmark-panel.html.ts new file mode 100644 index 0000000000..b7bc6ebf16 --- /dev/null +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-bookmark-panel.html.ts @@ -0,0 +1,211 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {html} from '@polymer/polymer'; + +import './styles'; + +export const template = html` + + + + +
+ + + +
+ + + + +
+ + + + +
+
+
+
+`; diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-bookmark-panel.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-bookmark-panel.ts index f0e304e8e7..089486aebf 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-bookmark-panel.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-bookmark-panel.ts @@ -12,273 +12,252 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - // tslint:disable-next-line - export let BookmarkPanelPolymer = PolymerElement({ - is: 'vz-projector-bookmark-panel', - properties: { - savedStates: Object, - // Keep a separate polymer property because the savedStates doesn't change - // when adding and removing states. - hasStates: {type: Boolean, value: false}, - selectedState: Number, - }, - }); - - export class BookmarkPanel extends BookmarkPanelPolymer { - private projector: Projector; - - // A list containing all of the saved states. - private savedStates: State[]; - private hasStates = false; - private selectedState: number; - private ignoreNextProjectionEvent: boolean; - - private expandLessButton: HTMLButtonElement; - private expandMoreButton: HTMLButtonElement; - - ready() { - super.ready(); - this.savedStates = []; - this.setupUploadButton(); - this.ignoreNextProjectionEvent = false; - this.expandLessButton = this.$$('#expand-less') as HTMLButtonElement; - this.expandMoreButton = this.$$('#expand-more') as HTMLButtonElement; - } - - initialize( - projector: Projector, - projectorEventContext: ProjectorEventContext - ) { - this.projector = projector; - projectorEventContext.registerProjectionChangedListener(() => { - if (this.ignoreNextProjectionEvent) { - this.ignoreNextProjectionEvent = false; - } else { - this.clearStateSelection(); - } - }); - } - - setSelectedTensor( - run: string, - tensorInfo: EmbeddingInfo, - dataProvider: DataProvider - ) { - // Clear any existing bookmarks. - this.addStates(null); - if (tensorInfo && tensorInfo.bookmarksPath) { - // Get any bookmarks that may come when the projector starts up. - dataProvider.getBookmarks(run, tensorInfo.tensorName, (bookmarks) => { - this.addStates(bookmarks); - this._expandMore(); - }); +import {PolymerElement} from '@polymer/polymer'; +import {LegacyElementMixin} from '@polymer/polymer/lib/legacy/legacy-element-mixin'; +import {customElement, property} from '@polymer/decorators'; + +import '@polymer/iron-collapse'; +import '@polymer/paper-icon-button'; +import '@polymer/paper-tooltip'; + +import {Projector} from './vz-projector'; +import {template} from './vz-projector-bookmark-panel.html'; +import {State} from './data'; +import {ProjectorEventContext} from './projectorEventContext'; +import {DataProvider, EmbeddingInfo} from './data-provider'; +import * as logging from './logging'; + +@customElement('vz-projector-bookmark-panel') +export class BookmarkPanel extends LegacyElementMixin(PolymerElement) { + static readonly template = template; + + @property({type: Object}) + savedStates: Array; + // Keep a separate polymer property because the savedStates doesn't change + // when adding and removing states. + @property({type: Boolean}) + hasStates: boolean = false; + @property({type: Number}) + selectedState: number; + + private projector: Projector; + private ignoreNextProjectionEvent: boolean; + private expandLessButton: HTMLButtonElement; + private expandMoreButton: HTMLButtonElement; + + ready() { + super.ready(); + this.savedStates = []; + this.setupUploadButton(); + this.ignoreNextProjectionEvent = false; + this.expandLessButton = this.$$('#expand-less') as HTMLButtonElement; + this.expandMoreButton = this.$$('#expand-more') as HTMLButtonElement; + } + initialize( + projector: Projector, + projectorEventContext: ProjectorEventContext + ) { + this.projector = projector; + projectorEventContext.registerProjectionChangedListener(() => { + if (this.ignoreNextProjectionEvent) { + this.ignoreNextProjectionEvent = false; } else { - this._expandLess(); - } - } - - /** Handles a click on show bookmarks tray button. */ - _expandMore() { - this.$.panel.show(); - this.expandMoreButton.style.display = 'none'; - this.expandLessButton.style.display = ''; - } - - /** Handles a click on hide bookmarks tray button. */ - _expandLess() { - this.$.panel.hide(); - this.expandMoreButton.style.display = ''; - this.expandLessButton.style.display = 'none'; - } - - /** Handles a click on the add bookmark button. */ - _addBookmark() { - let currentState = this.projector.getCurrentState(); - currentState.label = 'State ' + this.savedStates.length; - currentState.isSelected = true; - - this.selectedState = this.savedStates.length; - - for (let i = 0; i < this.savedStates.length; i++) { - this.savedStates[i].isSelected = false; - // We have to call notifyPath so that polymer knows this element was - // updated. - this.notifyPath('savedStates.' + i + '.isSelected', false, false); + this.clearStateSelection(); } - - this.push('savedStates', currentState as any); - this.updateHasStates(); - } - - /** Handles a click on the download bookmarks button. */ - _downloadFile() { - let serializedState = this.serializeAllSavedStates(); - let blob = new Blob([serializedState], {type: 'text/plain'}); - let textFile = window.URL.createObjectURL(blob); - - // Force a download. - let a = document.createElement('a'); - document.body.appendChild(a); - a.style.display = 'none'; - a.href = textFile; - (a as any).download = 'state'; - a.click(); - - document.body.removeChild(a); - window.URL.revokeObjectURL(textFile); - } - - /** Handles a click on the upload bookmarks button. */ - _uploadFile() { - let fileInput = this.$$('#state-file'); - (fileInput as HTMLInputElement).click(); + }); + } + setSelectedTensor( + run: string, + tensorInfo: EmbeddingInfo, + dataProvider: DataProvider + ) { + // Clear any existing bookmarks. + this.addStates(null); + if (tensorInfo && tensorInfo.bookmarksPath) { + // Get any bookmarks that may come when the projector starts up. + dataProvider.getBookmarks(run, tensorInfo.tensorName, (bookmarks) => { + this.addStates(bookmarks); + this._expandMore(); + }); + } else { + this._expandLess(); } - - private setupUploadButton() { - // Show and setup the load view button. - const fileInput = this.$$('#state-file') as HTMLInputElement; - fileInput.onchange = () => { - const file: File = fileInput.files[0]; - // Clear out the value of the file chooser. This ensures that if the user - // selects the same file, we'll re-read it. - fileInput.value = ''; - const fileReader = new FileReader(); - fileReader.onload = (evt) => { - const str: string = fileReader.result; - const savedStates = JSON.parse(str); - - // Verify the bookmarks match. - if (this.savedStatesValid(savedStates)) { - this.addStates(savedStates); - this.loadSavedState(0); - } else { - logging.setWarningMessage( - `Unable to load bookmarks: wrong dataset, expected dataset ` + - `with shape (${savedStates[0].dataSetDimensions}).` - ); - } - }; - fileReader.readAsText(file); - }; + } + /** Handles a click on show bookmarks tray button. */ + _expandMore() { + (this.$.panel as any).show(); + this.expandMoreButton.style.display = 'none'; + this.expandLessButton.style.display = ''; + } + /** Handles a click on hide bookmarks tray button. */ + _expandLess() { + (this.$.panel as any).hide(); + this.expandMoreButton.style.display = ''; + this.expandLessButton.style.display = 'none'; + } + /** Handles a click on the add bookmark button. */ + _addBookmark() { + let currentState = this.projector.getCurrentState(); + currentState.label = 'State ' + this.savedStates.length; + currentState.isSelected = true; + this.selectedState = this.savedStates.length; + for (let i = 0; i < this.savedStates.length; i++) { + this.savedStates[i].isSelected = false; + // We have to call notifyPath so that polymer knows this element was + // updated. + this.notifyPath('savedStates.' + i + '.isSelected', false); } - - addStates(savedStates?: State[]) { - if (savedStates == null) { - this.savedStates = []; - } else { - for (let i = 0; i < savedStates.length; i++) { - savedStates[i].isSelected = false; - this.push('savedStates', savedStates[i] as any); + this.push('savedStates', currentState as any); + this.updateHasStates(); + } + /** Handles a click on the download bookmarks button. */ + _downloadFile() { + let serializedState = this.serializeAllSavedStates(); + let blob = new Blob([serializedState], {type: 'text/plain'}); + let textFile = window.URL.createObjectURL(blob); + // Force a download. + let a = document.createElement('a'); + document.body.appendChild(a); + a.style.display = 'none'; + a.href = textFile; + (a as any).download = 'state'; + a.click(); + document.body.removeChild(a); + window.URL.revokeObjectURL(textFile); + } + /** Handles a click on the upload bookmarks button. */ + _uploadFile() { + let fileInput = this.$$('#state-file'); + (fileInput as HTMLInputElement).click(); + } + private setupUploadButton() { + // Show and setup the load view button. + const fileInput = this.$$('#state-file') as HTMLInputElement; + fileInput.onchange = () => { + const file: File = fileInput.files[0]; + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + fileInput.value = ''; + const fileReader = new FileReader(); + fileReader.onload = (evt) => { + const str: string = fileReader.result as string; + const savedStates = JSON.parse(str); + // Verify the bookmarks match. + if (this.savedStatesValid(savedStates)) { + this.addStates(savedStates); + this.loadSavedState(0); + } else { + logging.setWarningMessage( + `Unable to load bookmarks: wrong dataset, expected dataset ` + + `with shape (${savedStates[0].dataSetDimensions}).` + ); } - } - this.updateHasStates(); - } - - /** Deselects any selected state selection. */ - clearStateSelection() { - for (let i = 0; i < this.savedStates.length; i++) { - this.setSelectionState(i, false); + }; + fileReader.readAsText(file); + }; + } + addStates(savedStates?: State[]) { + if (savedStates == null) { + this.savedStates = []; + } else { + for (let i = 0; i < savedStates.length; i++) { + savedStates[i].isSelected = false; + this.push('savedStates', savedStates[i] as any); } } - - /** Handles a radio button click on a saved state. */ - _radioButtonHandler(evt: Event) { - const index = this.getParentDataIndex(evt); - this.loadSavedState(index); - this.setSelectionState(index, true); + this.updateHasStates(); + } + /** Deselects any selected state selection. */ + clearStateSelection() { + for (let i = 0; i < this.savedStates.length; i++) { + this.setSelectionState(i, false); } - - loadSavedState(index: number) { - for (let i = 0; i < this.savedStates.length; i++) { - if (this.savedStates[i].isSelected) { - this.setSelectionState(i, false); - } else if (index === i) { - this.setSelectionState(i, true); - this.ignoreNextProjectionEvent = true; - this.projector.loadState(this.savedStates[i]); - } + } + /** Handles a radio button click on a saved state. */ + _radioButtonHandler(evt: Event) { + const index = this.getParentDataIndex(evt); + this.loadSavedState(index); + this.setSelectionState(index, true); + } + loadSavedState(index: number) { + for (let i = 0; i < this.savedStates.length; i++) { + if (this.savedStates[i].isSelected) { + this.setSelectionState(i, false); + } else if (index === i) { + this.setSelectionState(i, true); + this.ignoreNextProjectionEvent = true; + this.projector.loadState(this.savedStates[i]); } } - - private setSelectionState(stateIndex: number, selected: boolean) { - this.savedStates[stateIndex].isSelected = selected; - const path = 'savedStates.' + stateIndex + '.isSelected'; - this.notifyPath(path, selected, false); - } - - /** - * Crawls up the DOM to find an ancestor with a data-index attribute. This is - * used to match events to their bookmark index. - */ - private getParentDataIndex(evt: Event) { - for (let i = 0; i < (evt as any).path.length; i++) { - let dataIndex = (evt as any).path[i].getAttribute('data-index'); - if (dataIndex != null) { - return +dataIndex; - } + } + private setSelectionState(stateIndex: number, selected: boolean) { + this.savedStates[stateIndex].isSelected = selected; + const path = 'savedStates.' + stateIndex + '.isSelected'; + this.notifyPath(path, selected); + } + /** + * Crawls up the DOM to find an ancestor with a data-index attribute. This is + * used to match events to their bookmark index. + */ + private getParentDataIndex(evt: Event) { + for (let i = 0; i < (evt as any).path.length; i++) { + let dataIndex = (evt as any).path[i].getAttribute('data-index'); + if (dataIndex != null) { + return +dataIndex; } - return -1; - } - - /** Handles a clear button click on a bookmark. */ - _clearButtonHandler(evt: Event) { - let index = this.getParentDataIndex(evt); - this.splice('savedStates', index, 1); - this.updateHasStates(); } - - /** Handles a label change event on a bookmark. */ - _labelChange(evt: Event) { - let index = this.getParentDataIndex(evt); - this.savedStates[index].label = (evt.target as any).value; - } - - /** - * Used to determine whether to select the radio button for a given bookmark. - */ - _isSelectedState(index: number) { - return index === this.selectedState; - } - _isNotSelectedState(index: number) { - return index !== this.selectedState; - } - - /** - * Gets all of the saved states as a serialized string. - */ - serializeAllSavedStates(): string { - return JSON.stringify(this.savedStates); - } - - /** - * Loads all of the serialized states and shows them in the list of - * viewable states. - */ - loadSavedStates(serializedStates: string) { - this.savedStates = JSON.parse(serializedStates); - this.updateHasStates(); - } - - /** - * Updates the hasState polymer property. - */ - private updateHasStates() { - this.hasStates = this.savedStates.length !== 0; - } - - /** Sanity checks a State array to ensure it matches the current dataset. */ - private savedStatesValid(states: State[]): boolean { - for (let i = 0; i < states.length; i++) { - if ( - states[i].dataSetDimensions[0] !== this.projector.dataSet.dim[0] || - states[i].dataSetDimensions[1] !== this.projector.dataSet.dim[1] - ) { - return false; - } + return -1; + } + /** Handles a clear button click on a bookmark. */ + _clearButtonHandler(evt: Event) { + let index = this.getParentDataIndex(evt); + this.splice('savedStates', index, 1); + this.updateHasStates(); + } + /** Handles a label change event on a bookmark. */ + _labelChange(evt: Event) { + let index = this.getParentDataIndex(evt); + this.savedStates[index].label = (evt.target as any).value; + } + /** + * Used to determine whether to select the radio button for a given bookmark. + */ + _isSelectedState(index: number) { + return index === this.selectedState; + } + _isNotSelectedState(index: number) { + return index !== this.selectedState; + } + /** + * Gets all of the saved states as a serialized string. + */ + serializeAllSavedStates(): string { + return JSON.stringify(this.savedStates); + } + /** + * Loads all of the serialized states and shows them in the list of + * viewable states. + */ + loadSavedStates(serializedStates: string) { + this.savedStates = JSON.parse(serializedStates); + this.updateHasStates(); + } + /** + * Updates the hasState polymer property. + */ + private updateHasStates() { + this.hasStates = this.savedStates.length !== 0; + } + /** Sanity checks a State array to ensure it matches the current dataset. */ + private savedStatesValid(states: State[]): boolean { + for (let i = 0; i < states.length; i++) { + if ( + states[i].dataSetDimensions[0] !== this.projector.dataSet.dim[0] || + states[i].dataSetDimensions[1] !== this.projector.dataSet.dim[1] + ) { + return false; } - return true; } + return true; } - customElements.define(BookmarkPanel.prototype.is, BookmarkPanel); -} // namespace vz_projector +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-dashboard.html b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-dashboard.ts similarity index 61% rename from tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-dashboard.html rename to tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-dashboard.ts index 043bfe39de..afce7d87f7 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-dashboard.html +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-dashboard.ts @@ -1,6 +1,4 @@ - +==============================================================================*/ - - - - - +import {PolymerElement, html} from '@polymer/polymer'; +import {customElement, property} from '@polymer/decorators'; - - - - + `; + @property({type: Boolean}) + dataNotFound: boolean; + @property({ + type: String, + }) + _routePrefix: string = '.'; + @property({type: Boolean}) + _initialized: boolean; + reload() { + // Do not reload the embedding projector. Reloading could take a long time. + } + attached() { + if (this._initialized) { + return; + } + let xhr = new XMLHttpRequest(); + xhr.open('GET', this._routePrefix + '/runs'); + xhr.onload = () => { + // Set this to true so we only initialize once. + this._initialized = true; + let runs = JSON.parse(xhr.responseText); + this.set('dataNotFound', runs.length === 0); + }; + xhr.onerror = () => { + this.set('dataNotFound', false); + }; + xhr.send(); + } +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.html b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.html deleted file mode 100644 index caf8a73119..0000000000 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.html +++ /dev/null @@ -1,677 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.html.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.html.ts new file mode 100644 index 0000000000..e9c9065c93 --- /dev/null +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.html.ts @@ -0,0 +1,652 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {html} from '@polymer/polymer'; + +import './styles'; + +export const template = html` + + +
DATA
+
+ + + + + +
+
+ + + + + + + + + + + + +
+
+ Use categorical coloring + + + For metadata fields that have many unique values we use a gradient + color map by default. This checkbox allows you to force categorical + coloring by a given metadata field. + +
+ +
+ +
+ + + Load data from your computer + + Load + + + + Publish your embedding visualization and data + + Publish + + + + Download the metadata with applied modifications + + Download + + + + + Label selected metadata + + Label + +
+
+ +

Load data from your computer

+ +
+
+
+ Step 1: Load a TSV file of + vectors. +
+
+
+
+ Example of 3 vectors with dimension 4: +
+ 0.1 0.2 + 0.5 0.9
+ 0.2 0.1 + 5.0 0.2
+ 0.4 0.1 + 7.0 0.8 +
+
+
+ Choose file + +
+
+
+
+
+
+ Step 2 (optional): + Load a TSV file of metadata. +
+
+
+
+ Example of 3 data points and 2 columns.
+ Note: If there is more than one column, the first row will be + parsed as column labels. +
+ Pokémon Species
+ Wartortle Turtle
+ Venusaur Seed
+ Charmeleon Flame +
+
+
+ Choose file + +
+
+
+
+
Click outside to dismiss.
+
+ +

Publish your embedding visualization and data

+ +
+

+ If you'd like to share your visualization with the world, follow + these simple steps. See + this tutorial + for more. +

+

Step 1: Make data public

+

+ Host tensors, metadata, sprite image, and bookmarks TSV files + publicly on the web. +

+

+ One option is using a + github gist. If you choose this approach, make sure to link directly to the + raw file. +

+
+
+

Step 2: Projector config

+
+ Optional: +
+ Metadata +
+
+ Sprite +
+
+ Bookmarks +
+
+
+ +
+

+ Step 3: Host projector config +

+ After you have hosted the projector config JSON file you built + above, paste the URL to the config below. +
+ + + +
+
Click outside to dismiss.
+
+
+ + Sphereize data + + + The data is normalized by shifting each point by the centroid and making + it unit norm. + + +
+ + + + + + + + + +
Checkpoint:
Metadata:
+
+
+`; diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.ts index 3338020ffd..7318124c70 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-data-panel.ts @@ -12,810 +12,746 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - export let DataPanelPolymer = PolymerElement({ - is: 'vz-projector-data-panel', - properties: { - selectedTensor: {type: String, observer: '_selectedTensorChanged'}, - selectedRun: String, - selectedColorOptionName: { - type: String, - notify: true, - observer: '_selectedColorOptionNameChanged', - }, - selectedLabelOption: { - type: String, - notify: true, - observer: '_selectedLabelOptionChanged', - }, - normalizeData: Boolean, - showForceCategoricalColorsCheckbox: Boolean, - metadataEditorInput: {type: String}, - metadataEditorInputLabel: {type: String, value: 'Tag selection as'}, - metadataEditorInputChange: {type: Object}, - metadataEditorColumn: {type: String}, - metadataEditorColumnChange: {type: Object}, - metadataEditorButtonClicked: {type: Object}, - metadataEditorButtonDisabled: {type: Boolean}, - downloadMetadataClicked: {type: Boolean}, - superviseInput: {type: String}, - superviseInputTyping: {type: Object}, - superviseInputChange: {type: Object}, - superviseInputLabel: {type: String, value: 'Ignored label'}, - superviseColumn: {type: String}, - superviseColumnChanged: {type: Object}, - showSuperviseSettings: {type: Boolean, value: false}, - }, - observers: ['_generateUiForNewCheckpointForRun(selectedRun)'], - }); - - export class DataPanel extends DataPanelPolymer { - selectedLabelOption: string; - selectedColorOptionName: string; - showForceCategoricalColorsCheckbox: boolean; - showSuperviseSettings: boolean; - - private normalizeData: boolean; - private labelOptions: string[]; - private colorOptions: ColorOption[]; - forceCategoricalColoring: boolean = false; - - private metadataEditorInput: string; - private metadataEditorInputLabel: string; - private metadataEditorButtonDisabled: boolean; - private superviseInput: string; - private superviseInputLabel: string; - private superviseInputSelected: string; - private superviseColumn: string; - - private selectedPointIndices: number[]; - private neighborsOfFirstPoint: knn.NearestEntry[]; - private selectedTensor: string; - private selectedRun: string; - private dataProvider: DataProvider; - private tensorNames: {name: string; shape: number[]}[]; - private runNames: string[]; - private projector: Projector; - private projectorConfig: ProjectorConfig; - private colorLegendRenderInfo: ColorLegendRenderInfo; - private spriteAndMetadata: SpriteAndMetadataInfo; - private metadataFile: string; - - ready() { - super.ready(); - this.normalizeData = true; - this.superviseInputSelected = ''; - } - - initialize(projector: Projector, dp: DataProvider) { - this.projector = projector; - this.dataProvider = dp; - this.setupUploadButtons(); - - // Tell the projector whenever the data normalization changes. - // Unknown why, but the polymer checkbox button stops working as soon as - // you do d3.select() on it. - this.$$('#normalize-data-checkbox').addEventListener('change', () => { - this.projector.setNormalizeData(this.normalizeData); - }); - - let forceCategoricalColoringCheckbox = this.$$( - '#force-categorical-checkbox' +import {PolymerElement} from '@polymer/polymer'; +import {LegacyElementMixin} from '@polymer/polymer/lib/legacy/legacy-element-mixin'; +import {customElement, property, observe} from '@polymer/decorators'; + +import * as d3 from 'd3'; + +import '@polymer/paper-button'; +import '@polymer/paper-checkbox'; +import '@polymer/paper-dialog-scrollable'; +import '@polymer/paper-dialog'; +import '@polymer/paper-dropdown-menu/paper-dropdown-menu'; +import '@polymer/paper-icon-button'; +import '@polymer/paper-input/paper-input'; +import '@polymer/paper-item'; +import '@polymer/paper-listbox'; +import '@polymer/paper-tooltip'; + +import {Projector} from './vz-projector'; +import {template} from './vz-projector-data-panel.html'; +import { + ColorLegendThreshold, + ColorLegendRenderInfo, +} from './vz-projector-legend'; +import { + ColumnStats, + ColorOption, + SpriteAndMetadataInfo, + Projection, +} from './data'; +import { + DataProvider, + EmbeddingInfo, + ProjectorConfig, + parseRawMetadata, + parseRawTensors, +} from './data-provider'; +import * as knn from './knn'; +import * as util from './util'; + +@customElement('vz-projector-data-panel') +export class DataPanel extends LegacyElementMixin(PolymerElement) { + static readonly template = template; + + @property({type: String}) + selectedTensor: string; + @property({type: String}) + selectedRun: string; + @property({type: String, notify: true}) + selectedColorOptionName: string; + @property({type: String, notify: true}) + selectedLabelOption: string; + @property({type: Boolean}) + normalizeData: boolean; + @property({type: Boolean}) + showForceCategoricalColorsCheckbox: boolean; + @property({type: String}) + metadataEditorInput: string; + @property({type: String}) + metadataEditorInputLabel: string = 'Tag selection as'; + @property({type: String}) + metadataEditorColumn: string; + @property({type: Boolean}) + metadataEditorButtonDisabled: boolean; + @property({type: String}) + superviseInput: string; + @property({type: String}) + superviseInputLabel: string = 'Ignored label'; + @property({type: String}) + superviseColumn: string; + @property({type: Boolean}) + showSuperviseSettings: boolean = false; + + private labelOptions: string[]; + private colorOptions: ColorOption[]; + forceCategoricalColoring: boolean = false; + private superviseInputSelected: string; + private selectedPointIndices: number[]; + private neighborsOfFirstPoint: knn.NearestEntry[]; + private dataProvider: DataProvider; + private tensorNames: { + name: string; + shape: number[]; + }[]; + private runNames: string[]; + private projector: Projector; + private projectorConfig: ProjectorConfig; + private colorLegendRenderInfo: ColorLegendRenderInfo; + private spriteAndMetadata: SpriteAndMetadataInfo; + private metadataFile: string; + private metadataFields: string[]; + + ready() { + super.ready(); + this.normalizeData = true; + this.superviseInputSelected = ''; + } + initialize(projector: Projector, dp: DataProvider) { + this.projector = projector; + this.dataProvider = dp; + this.setupUploadButtons(); + // Tell the projector whenever the data normalization changes. + // Unknown why, but the polymer checkbox button stops working as soon as + // you do d3.select() on it. + this.$$('#normalize-data-checkbox').addEventListener('change', () => { + this.projector.setNormalizeData(this.normalizeData); + }); + let forceCategoricalColoringCheckbox = this.$$( + '#force-categorical-checkbox' + ); + forceCategoricalColoringCheckbox.addEventListener('change', () => { + this.setForceCategoricalColoring( + (forceCategoricalColoringCheckbox as HTMLInputElement).checked ); - forceCategoricalColoringCheckbox.addEventListener('change', () => { - this.setForceCategoricalColoring( - (forceCategoricalColoringCheckbox as HTMLInputElement).checked - ); - }); - - // Get all the runs. - this.dataProvider.retrieveRuns((runs) => { - this.runNames = runs; - // Choose the first run by default. - if (this.runNames.length > 0) { - if (this.selectedRun != runs[0]) { - // This set operation will automatically trigger the observer. - this.selectedRun = runs[0]; - } else { - // Explicitly load the projector config. We explicitly load because - // the run name stays the same, which means that the observer won't - // actually be triggered by setting the selected run. - this._generateUiForNewCheckpointForRun(this.selectedRun); - } + }); + // Get all the runs. + this.dataProvider.retrieveRuns((runs) => { + this.runNames = runs; + // Choose the first run by default. + if (this.runNames.length > 0) { + if (this.selectedRun != runs[0]) { + // This set operation will automatically trigger the observer. + this.selectedRun = runs[0]; + } else { + // Explicitly load the projector config. We explicitly load because + // the run name stays the same, which means that the observer won't + // actually be triggered by setting the selected run. + this._generateUiForNewCheckpointForRun(this.selectedRun); } - }); - } - - setForceCategoricalColoring(forceCategoricalColoring: boolean) { - this.forceCategoricalColoring = forceCategoricalColoring; - (this.$$( - '#force-categorical-checkbox' - ) as HTMLInputElement).checked = this.forceCategoricalColoring; - - this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile); - - // The selected color option name doesn't change when we switch to using - // categorical coloring for stats with too many unique values, so we - // manually call this polymer observer so that we update the UI. - this._selectedColorOptionNameChanged(); + } + }); + } + setForceCategoricalColoring(forceCategoricalColoring: boolean) { + this.forceCategoricalColoring = forceCategoricalColoring; + (this.$$( + '#force-categorical-checkbox' + ) as HTMLInputElement).checked = this.forceCategoricalColoring; + this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile); + // The selected color option name doesn't change when we switch to using + // categorical coloring for stats with too many unique values, so we + // manually call this polymer observer so that we update the UI. + this._selectedColorOptionNameChanged(); + } + getSeparatorClass(isSeparator: boolean): string { + return isSeparator ? 'separator' : null; + } + metadataChanged( + spriteAndMetadata: SpriteAndMetadataInfo, + metadataFile?: string + ) { + this.spriteAndMetadata = spriteAndMetadata; + if (metadataFile != null) { + this.metadataFile = metadataFile; + } + this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile); + if ( + this.selectedColorOptionName == null || + this.colorOptions.filter((c) => c.name === this.selectedColorOptionName) + .length === 0 + ) { + this.selectedColorOptionName = this.colorOptions[0].name; } - - getSeparatorClass(isSeparator: boolean): string { - return isSeparator ? 'separator' : null; + let labelIndex = -1; + this.metadataFields = spriteAndMetadata.stats.map((stats, i) => { + if (!stats.isNumeric && labelIndex === -1) { + labelIndex = i; + } + return stats.name; + }); + if ( + this.metadataEditorColumn == null || + this.metadataFields.filter((name) => name === this.metadataEditorColumn) + .length === 0 + ) { + // Make the default label the first non-numeric column. + this.metadataEditorColumn = this.metadataFields[Math.max(0, labelIndex)]; } - - metadataChanged( - spriteAndMetadata: SpriteAndMetadataInfo, - metadataFile?: string + if ( + this.superviseColumn == null || + this.metadataFields.filter((name) => name === this.superviseColumn) + .length === 0 ) { - this.spriteAndMetadata = spriteAndMetadata; - if (metadataFile != null) { - this.metadataFile = metadataFile; - } - - this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile); - if ( - this.selectedColorOptionName == null || - this.colorOptions.filter((c) => c.name === this.selectedColorOptionName) - .length === 0 - ) { - this.selectedColorOptionName = this.colorOptions[0].name; - } - - let labelIndex = -1; - this.metadataFields = spriteAndMetadata.stats.map((stats, i) => { - if (!stats.isNumeric && labelIndex === -1) { - labelIndex = i; - } - return stats.name; - }); - - if ( - this.metadataEditorColumn == null || - this.metadataFields.filter((name) => name === this.metadataEditorColumn) - .length === 0 - ) { - // Make the default label the first non-numeric column. - this.metadataEditorColumn = this.metadataFields[ - Math.max(0, labelIndex) - ]; - } - - if ( - this.superviseColumn == null || - this.metadataFields.filter((name) => name === this.superviseColumn) - .length === 0 - ) { - // Make the default supervise class the first non-numeric column. - this.superviseColumn = this.metadataFields[Math.max(0, labelIndex)]; - this.superviseInput = ''; - } - this.superviseInputChange(); + // Make the default supervise class the first non-numeric column. + this.superviseColumn = this.metadataFields[Math.max(0, labelIndex)]; + this.superviseInput = ''; } - - projectionChanged(projection: Projection) { - if (projection) { - switch (projection.projectionType) { - case 'tsne': - this.set('showSuperviseSettings', true); - break; - - default: - this.set('showSuperviseSettings', false); - } + this.superviseInputChange(); + } + projectionChanged(projection: Projection) { + if (projection) { + switch (projection.projectionType) { + case 'tsne': + this.set('showSuperviseSettings', true); + break; + default: + this.set('showSuperviseSettings', false); } } - - onProjectorSelectionChanged( - selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[] - ) { - this.selectedPointIndices = selectedPointIndices; - this.neighborsOfFirstPoint = neighborsOfFirstPoint; - this.metadataEditorInputChange(); + } + onProjectorSelectionChanged( + selectedPointIndices: number[], + neighborsOfFirstPoint: knn.NearestEntry[] + ) { + this.selectedPointIndices = selectedPointIndices; + this.neighborsOfFirstPoint = neighborsOfFirstPoint; + this.metadataEditorInputChange(); + } + private addWordBreaks(longString: string): string { + if (longString == null) { + return ''; } - - private addWordBreaks(longString: string): string { - if (longString == null) { - return ''; + return longString.replace(/([\/=-_,])/g, '$1'); + } + private updateMetadataUI(columnStats: ColumnStats[], metadataFile: string) { + const metadataFileElement = this.$$('#metadata-file') as HTMLSpanElement; + metadataFileElement.innerHTML = this.addWordBreaks(metadataFile); + metadataFileElement.title = metadataFile; + // Label by options. + let labelIndex = -1; + this.labelOptions = columnStats.map((stats, i) => { + // Make the default label by the first non-numeric column. + if (!stats.isNumeric && labelIndex === -1) { + labelIndex = i; } - return longString.replace(/([\/=-_,])/g, '$1'); + return stats.name; + }); + if ( + this.selectedLabelOption == null || + this.labelOptions.filter((name) => name === this.selectedLabelOption) + .length === 0 + ) { + this.selectedLabelOption = this.labelOptions[Math.max(0, labelIndex)]; } - - private updateMetadataUI(columnStats: ColumnStats[], metadataFile: string) { - const metadataFileElement = this.$$('#metadata-file') as HTMLSpanElement; - metadataFileElement.innerHTML = this.addWordBreaks(metadataFile); - metadataFileElement.title = metadataFile; - - // Label by options. - let labelIndex = -1; - this.labelOptions = columnStats.map((stats, i) => { - // Make the default label by the first non-numeric column. - if (!stats.isNumeric && labelIndex === -1) { - labelIndex = i; + if ( + this.metadataEditorColumn == null || + this.labelOptions.filter((name) => name === this.metadataEditorColumn) + .length === 0 + ) { + this.metadataEditorColumn = this.labelOptions[Math.max(0, labelIndex)]; + } + // Color by options. + const standardColorOption: ColorOption[] = [{name: 'No color map'}]; + const metadataColorOption: ColorOption[] = columnStats + .filter((stats) => { + return !stats.tooManyUniqueValues || stats.isNumeric; + }) + .map((stats) => { + let map; + let items: { + label: string; + count: number; + }[]; + let thresholds: ColorLegendThreshold[]; + let isCategorical = + this.forceCategoricalColoring || !stats.tooManyUniqueValues; + let desc; + if (isCategorical) { + const scale = d3.scaleOrdinal(d3.schemeCategory10); + let range = scale.range(); + // Re-order the range. + let newRange = range.map((color, i) => { + let index = (i * 3) % range.length; + return range[index]; + }); + items = stats.uniqueEntries; + scale.range(newRange).domain(items.map((x) => x.label)); + map = scale; + const len = stats.uniqueEntries.length; + desc = + `${len} ${len > range.length ? ' non-unique' : ''} ` + `colors`; + } else { + thresholds = [ + {color: '#ffffdd', value: stats.min}, + {color: '#1f2d86', value: stats.max}, + ]; + map = d3 + .scaleLinear() + .domain(thresholds.map((t) => t.value)) + .range(thresholds.map((t) => t.color)); + desc = 'gradient'; } - return stats.name; + return { + name: stats.name, + desc: desc, + map: map, + items: items, + thresholds: thresholds, + tooManyUniqueValues: stats.tooManyUniqueValues, + }; }); - - if ( - this.selectedLabelOption == null || - this.labelOptions.filter((name) => name === this.selectedLabelOption) - .length === 0 - ) { - this.selectedLabelOption = this.labelOptions[Math.max(0, labelIndex)]; - } - - if ( - this.metadataEditorColumn == null || - this.labelOptions.filter((name) => name === this.metadataEditorColumn) - .length === 0 - ) { - this.metadataEditorColumn = this.labelOptions[Math.max(0, labelIndex)]; - } - - // Color by options. - const standardColorOption: ColorOption[] = [ - {name: 'No color map'}, - // TODO(@dsmilkov): Implement this. - // {name: 'Distance of neighbors', - // desc: 'How far is each point from its neighbors'} - ]; - const metadataColorOption: ColorOption[] = columnStats - .filter((stats) => { - return !stats.tooManyUniqueValues || stats.isNumeric; - }) - .map((stats) => { - let map; - let items: {label: string; count: number}[]; - let thresholds: ColorLegendThreshold[]; - let isCategorical = - this.forceCategoricalColoring || !stats.tooManyUniqueValues; - let desc; - - if (isCategorical) { - const scale = d3.scaleOrdinal(d3.schemeCategory10); - let range = scale.range(); - // Re-order the range. - let newRange = range.map((color, i) => { - let index = (i * 3) % range.length; - return range[index]; - }); - items = stats.uniqueEntries; - scale.range(newRange).domain(items.map((x) => x.label)); - map = scale; - const len = stats.uniqueEntries.length; - desc = - `${len} ${len > range.length ? ' non-unique' : ''} ` + `colors`; - } else { - thresholds = [ - {color: '#ffffdd', value: stats.min}, - {color: '#1f2d86', value: stats.max}, - ]; - map = d3 - .scaleLinear() - .domain(thresholds.map((t) => t.value)) - .range(thresholds.map((t) => t.color)); - desc = 'gradient'; - } - return { - name: stats.name, - desc: desc, - map: map, - items: items, - thresholds: thresholds, - tooManyUniqueValues: stats.tooManyUniqueValues, - }; - }); - - if (metadataColorOption.length > 0) { - // Add a separator line between built-in color maps - // and those based on metadata columns. - standardColorOption.push({name: 'Metadata', isSeparator: true}); - } - this.colorOptions = standardColorOption.concat(metadataColorOption); + if (metadataColorOption.length > 0) { + // Add a separator line between built-in color maps + // and those based on metadata columns. + standardColorOption.push({name: 'Metadata', isSeparator: true}); } - - private metadataEditorContext(enabled: boolean) { - this.metadataEditorButtonDisabled = !enabled; - if (this.projector) { - this.projector.metadataEditorContext( - enabled, - this.metadataEditorColumn - ); - } + this.colorOptions = standardColorOption.concat(metadataColorOption); + } + private metadataEditorContext(enabled: boolean) { + this.metadataEditorButtonDisabled = !enabled; + if (this.projector) { + this.projector.metadataEditorContext(enabled, this.metadataEditorColumn); } - - private metadataEditorInputChange() { - let col = this.metadataEditorColumn; - let value = this.metadataEditorInput; - let selectionSize = - this.selectedPointIndices.length + this.neighborsOfFirstPoint.length; - if (selectionSize > 0) { - if (value != null && value.trim() !== '') { - if ( - this.spriteAndMetadata.stats.filter((s) => s.name === col)[0] - .isNumeric && - isNaN(+value) - ) { - this.metadataEditorInputLabel = `Label must be numeric`; - this.metadataEditorContext(false); + } + private metadataEditorInputChange() { + let col = this.metadataEditorColumn; + let value = this.metadataEditorInput; + let selectionSize = + this.selectedPointIndices.length + this.neighborsOfFirstPoint.length; + if (selectionSize > 0) { + if (value != null && value.trim() !== '') { + if ( + this.spriteAndMetadata.stats.filter((s) => s.name === col)[0] + .isNumeric && + isNaN(+value) + ) { + this.metadataEditorInputLabel = `Label must be numeric`; + this.metadataEditorContext(false); + } else { + let numMatches = this.projector.dataSet.points.filter( + (p) => p.metadata[col].toString() === value.trim() + ).length; + if (numMatches === 0) { + this.metadataEditorInputLabel = `Tag ${selectionSize} with new label`; } else { - let numMatches = this.projector.dataSet.points.filter( - (p) => p.metadata[col].toString() === value.trim() - ).length; - - if (numMatches === 0) { - this.metadataEditorInputLabel = `Tag ${selectionSize} with new label`; - } else { - this.metadataEditorInputLabel = `Tag ${selectionSize} points as`; - } - this.metadataEditorContext(true); + this.metadataEditorInputLabel = `Tag ${selectionSize} points as`; } - } else { - this.metadataEditorInputLabel = 'Tag selection as'; - this.metadataEditorContext(false); + this.metadataEditorContext(true); } } else { + this.metadataEditorInputLabel = 'Tag selection as'; this.metadataEditorContext(false); - - if (value != null && value.trim() !== '') { - this.metadataEditorInputLabel = 'Select points to tag'; - } else { - this.metadataEditorInputLabel = 'Tag selection as'; - } } - } - - private metadataEditorInputKeydown(e) { - // Check if 'Enter' was pressed - if (e.keyCode === 13) { - this.metadataEditorButtonClicked(); + } else { + this.metadataEditorContext(false); + if (value != null && value.trim() !== '') { + this.metadataEditorInputLabel = 'Select points to tag'; + } else { + this.metadataEditorInputLabel = 'Tag selection as'; } - e.stopPropagation(); - } - - private metadataEditorColumnChange() { - this.metadataEditorInputChange(); } - - private metadataEditorButtonClicked() { - if (!this.metadataEditorButtonDisabled) { - let value = this.metadataEditorInput.trim(); - let selectionSize = - this.selectedPointIndices.length + this.neighborsOfFirstPoint.length; - this.projector.metadataEdit(this.metadataEditorColumn, value); - this.projector.metadataEditorContext(true, this.metadataEditorColumn); - this.metadataEditorInputLabel = `${selectionSize} labeled as '${value}'`; - } + } + private metadataEditorInputKeydown(e) { + // Check if 'Enter' was pressed + if (e.keyCode === 13) { + this.metadataEditorButtonClicked(); } - - private downloadMetadataClicked() { - if ( - this.projector && - this.projector.dataSet && - this.projector.dataSet.spriteAndMetadataInfo - ) { - let tsvFile = this.projector.dataSet.spriteAndMetadataInfo.stats - .map((s) => s.name) - .join('\t'); - - this.projector.dataSet.spriteAndMetadataInfo.pointsInfo.forEach((p) => { - let vals = []; - - for (const column in p) { - vals.push(p[column]); - } - tsvFile += '\n' + vals.join('\t'); - }); - - const textBlob = new Blob([tsvFile], {type: 'text/plain'}); - this.$.downloadMetadataLink.download = 'metadata-edited.tsv'; - this.$.downloadMetadataLink.href = window.URL.createObjectURL(textBlob); - this.$.downloadMetadataLink.click(); - } + e.stopPropagation(); + } + private metadataEditorColumnChange() { + this.metadataEditorInputChange(); + } + private metadataEditorButtonClicked() { + if (!this.metadataEditorButtonDisabled) { + let value = this.metadataEditorInput.trim(); + let selectionSize = + this.selectedPointIndices.length + this.neighborsOfFirstPoint.length; + this.projector.metadataEdit(this.metadataEditorColumn, value); + this.projector.metadataEditorContext(true, this.metadataEditorColumn); + this.metadataEditorInputLabel = `${selectionSize} labeled as '${value}'`; } - - private superviseInputTyping() { - let value = this.superviseInput.trim(); - if (value == null || value.trim() === '') { - if (this.superviseInputSelected === '') { - this.superviseInputLabel = 'No ignored label'; - } else { - this.superviseInputLabel = `Supervising without '${this.superviseInputSelected}'`; - } - return; - } - if (this.projector && this.projector.dataSet) { - let numMatches = this.projector.dataSet.points.filter( - (p) => p.metadata[this.superviseColumn].toString().trim() === value - ).length; - - if (numMatches === 0) { - this.superviseInputLabel = 'Label not found'; - } else { - if (this.projector.dataSet.superviseInput != value) { - this.superviseInputLabel = `Supervise without '${value}' [${numMatches} points]`; - } + } + private downloadMetadataClicked() { + if ( + this.projector && + this.projector.dataSet && + this.projector.dataSet.spriteAndMetadataInfo + ) { + let tsvFile = this.projector.dataSet.spriteAndMetadataInfo.stats + .map((s) => s.name) + .join('\t'); + this.projector.dataSet.spriteAndMetadataInfo.pointsInfo.forEach((p) => { + let vals = []; + for (const column in p) { + vals.push(p[column]); } - } + tsvFile += '\n' + vals.join('\t'); + }); + const textBlob = new Blob([tsvFile], {type: 'text/plain'}); + const anyDownloadMetadataLink = this.$.downloadMetadataLink as any; + anyDownloadMetadataLink.download = 'metadata-edited.tsv'; + anyDownloadMetadataLink.href = window.URL.createObjectURL(textBlob); + anyDownloadMetadataLink.click(); } - - private superviseInputChange() { - let value = this.superviseInput.trim(); - if (value == null || value.trim() === '') { - this.superviseInputSelected = ''; + } + private superviseInputTyping() { + let value = this.superviseInput.trim(); + if (value == null || value.trim() === '') { + if (this.superviseInputSelected === '') { this.superviseInputLabel = 'No ignored label'; - this.setSupervision(this.superviseColumn, ''); - return; + } else { + this.superviseInputLabel = `Supervising without '${this.superviseInputSelected}'`; } - if (this.projector && this.projector.dataSet) { - let numMatches = this.projector.dataSet.points.filter( - (p) => p.metadata[this.superviseColumn].toString().trim() === value - ).length; - - if (numMatches === 0) { - this.superviseInputLabel = `Supervising without '${this.superviseInputSelected}'`; - } else { - this.superviseInputSelected = value; - this.superviseInputLabel = `Supervising without '${value}' [${numMatches} points]`; - this.setSupervision(this.superviseColumn, value); + return; + } + if (this.projector && this.projector.dataSet) { + let numMatches = this.projector.dataSet.points.filter( + (p) => p.metadata[this.superviseColumn].toString().trim() === value + ).length; + if (numMatches === 0) { + this.superviseInputLabel = 'Label not found'; + } else { + if (this.projector.dataSet.superviseInput != value) { + this.superviseInputLabel = `Supervise without '${value}' [${numMatches} points]`; } } } - - private superviseColumnChanged() { - this.superviseInput = ''; - this.superviseInputChange(); - } - - private setSupervision(superviseColumn: string, superviseInput: string) { - if (this.projector && this.projector.dataSet) { - this.projector.dataSet.setSupervision(superviseColumn, superviseInput); + } + private superviseInputChange() { + let value = this.superviseInput.trim(); + if (value == null || value.trim() === '') { + this.superviseInputSelected = ''; + this.superviseInputLabel = 'No ignored label'; + this.setSupervision(this.superviseColumn, ''); + return; + } + if (this.projector && this.projector.dataSet) { + let numMatches = this.projector.dataSet.points.filter( + (p) => p.metadata[this.superviseColumn].toString().trim() === value + ).length; + if (numMatches === 0) { + this.superviseInputLabel = `Supervising without '${this.superviseInputSelected}'`; + } else { + this.superviseInputSelected = value; + this.superviseInputLabel = `Supervising without '${value}' [${numMatches} points]`; + this.setSupervision(this.superviseColumn, value); } } - - setNormalizeData(normalizeData: boolean) { - this.normalizeData = normalizeData; - } - - _selectedTensorChanged() { - this.projector.updateDataSet(null, null, null); - if (this.selectedTensor == null) { - return; - } - this.dataProvider.retrieveTensor( - this.selectedRun, - this.selectedTensor, - (ds) => { - let metadataFile = this.getEmbeddingInfoByName(this.selectedTensor) - .metadataPath; - this.dataProvider.retrieveSpriteAndMetadata( - this.selectedRun, - this.selectedTensor, - (metadata) => { - this.projector.updateDataSet(ds, metadata, metadataFile); - } - ); - } - ); - this.projector.setSelectedTensor( - this.selectedRun, - this.getEmbeddingInfoByName(this.selectedTensor) - ); + } + private superviseColumnChanged() { + this.superviseInput = ''; + this.superviseInputChange(); + } + private setSupervision(superviseColumn: string, superviseInput: string) { + if (this.projector && this.projector.dataSet) { + this.projector.dataSet.setSupervision(superviseColumn, superviseInput); } - - _generateUiForNewCheckpointForRun(selectedRun) { - this.dataProvider.retrieveProjectorConfig(selectedRun, (info) => { - this.projectorConfig = info; - let names = this.projectorConfig.embeddings - .map((e) => e.tensorName) - .filter((name) => { - let shape = this.getEmbeddingInfoByName(name).tensorShape; - return shape.length === 2 && shape[0] > 1 && shape[1] > 1; - }) - .sort((a, b) => { - let embA = this.getEmbeddingInfoByName(a); - let embB = this.getEmbeddingInfoByName(b); - - // Prefer tensors with metadata. - if (util.xor(!!embA.metadataPath, !!embB.metadataPath)) { - return embA.metadataPath ? -1 : 1; - } - - // Prefer non-generated tensors. - let isGenA = util.tensorIsGenerated(a); - let isGenB = util.tensorIsGenerated(b); - if (util.xor(isGenA, isGenB)) { - return isGenB ? -1 : 1; - } - - // Prefer bigger tensors. - let sizeA = embA.tensorShape[0]; - let sizeB = embB.tensorShape[0]; - if (sizeA !== sizeB) { - return sizeB - sizeA; - } - - // Sort alphabetically by tensor name. - return a <= b ? -1 : 1; - }); - this.tensorNames = names.map((name) => { - return {name, shape: this.getEmbeddingInfoByName(name).tensorShape}; - }); - const wordBreakablePath = this.addWordBreaks( - this.projectorConfig.modelCheckpointPath + } + setNormalizeData(normalizeData: boolean) { + this.normalizeData = normalizeData; + } + @observe('selectedTensor') + _selectedTensorChanged() { + this.projector.updateDataSet(null, null, null); + if (this.selectedTensor == null) { + return; + } + this.dataProvider.retrieveTensor( + this.selectedRun, + this.selectedTensor, + (ds) => { + let metadataFile = this.getEmbeddingInfoByName(this.selectedTensor) + .metadataPath; + this.dataProvider.retrieveSpriteAndMetadata( + this.selectedRun, + this.selectedTensor, + (metadata) => { + this.projector.updateDataSet(ds, metadata, metadataFile); + } ); - const checkpointFile = this.$$('#checkpoint-file') as HTMLSpanElement; - checkpointFile.innerHTML = wordBreakablePath; - checkpointFile.title = this.projectorConfig.modelCheckpointPath; - - // If in demo mode, let the order decide which tensor to load by default. - const defaultTensor = - this.projector.servingMode === 'demo' - ? this.projectorConfig.embeddings[0].tensorName - : names[0]; - if (this.selectedTensor === defaultTensor) { - // Explicitly call the observer. Polymer won't call it if the previous - // string matches the current string. - this._selectedTensorChanged(); - } else { - this.selectedTensor = defaultTensor; - } - }); - } - - _selectedLabelOptionChanged() { - this.projector.setSelectedLabelOption(this.selectedLabelOption); - } - - _selectedColorOptionNameChanged() { - let colorOption: ColorOption; - for (let i = 0; i < this.colorOptions.length; i++) { - if (this.colorOptions[i].name === this.selectedColorOptionName) { - colorOption = this.colorOptions[i]; - break; - } - } - if (!colorOption) { - return; } - - this.showForceCategoricalColorsCheckbox = !!colorOption.tooManyUniqueValues; - - if (colorOption.map == null) { - this.colorLegendRenderInfo = null; - } else if (colorOption.items) { - let items = colorOption.items.map((item) => { - return { - color: colorOption.map(item.label), - label: item.label, - count: item.count, - }; + ); + this.projector.setSelectedTensor( + this.selectedRun, + this.getEmbeddingInfoByName(this.selectedTensor) + ); + } + @observe('selectedRun') + _generateUiForNewCheckpointForRun(selectedRun) { + this.dataProvider.retrieveProjectorConfig(selectedRun, (info) => { + this.projectorConfig = info; + let names = this.projectorConfig.embeddings + .map((e) => e.tensorName) + .filter((name) => { + let shape = this.getEmbeddingInfoByName(name).tensorShape; + return shape.length === 2 && shape[0] > 1 && shape[1] > 1; + }) + .sort((a, b) => { + let embA = this.getEmbeddingInfoByName(a); + let embB = this.getEmbeddingInfoByName(b); + // Prefer tensors with metadata. + if (util.xor(!!embA.metadataPath, !!embB.metadataPath)) { + return embA.metadataPath ? -1 : 1; + } + // Prefer non-generated tensors. + let isGenA = util.tensorIsGenerated(a); + let isGenB = util.tensorIsGenerated(b); + if (util.xor(isGenA, isGenB)) { + return isGenB ? -1 : 1; + } + // Prefer bigger tensors. + let sizeA = embA.tensorShape[0]; + let sizeB = embB.tensorShape[0]; + if (sizeA !== sizeB) { + return sizeB - sizeA; + } + // Sort alphabetically by tensor name. + return a <= b ? -1 : 1; }); - this.colorLegendRenderInfo = {items, thresholds: null}; + this.tensorNames = names.map((name) => { + return {name, shape: this.getEmbeddingInfoByName(name).tensorShape}; + }); + const wordBreakablePath = this.addWordBreaks( + this.projectorConfig.modelCheckpointPath + ); + const checkpointFile = this.$$('#checkpoint-file') as HTMLSpanElement; + checkpointFile.innerHTML = wordBreakablePath; + checkpointFile.title = this.projectorConfig.modelCheckpointPath; + // If in demo mode, let the order decide which tensor to load by default. + const defaultTensor = + this.projector.servingMode === 'demo' + ? this.projectorConfig.embeddings[0].tensorName + : names[0]; + if (this.selectedTensor === defaultTensor) { + // Explicitly call the observer. Polymer won't call it if the previous + // string matches the current string. + this._selectedTensorChanged(); } else { - this.colorLegendRenderInfo = { - items: null, - thresholds: colorOption.thresholds, - }; + this.selectedTensor = defaultTensor; } - this.projector.setSelectedColorOption(colorOption); - } + }); + } - private tensorWasReadFromFile(rawContents: ArrayBuffer, fileName: string) { - parseRawTensors(rawContents, (ds) => { - const checkpointFile = this.$$('#checkpoint-file') as HTMLSpanElement; - checkpointFile.innerText = fileName; - checkpointFile.title = fileName; - this.projector.updateDataSet(ds); - }); + @observe('selectedLabelOption') + _selectedLabelOptionChanged() { + this.projector.setSelectedLabelOption(this.selectedLabelOption); + } + @observe('selectedColorOptionName') + _selectedColorOptionNameChanged() { + let colorOption: ColorOption; + for (let i = 0; i < this.colorOptions.length; i++) { + if (this.colorOptions[i].name === this.selectedColorOptionName) { + colorOption = this.colorOptions[i]; + break; + } } - - private metadataWasReadFromFile( - rawContents: ArrayBuffer, - fileName: string - ) { - parseRawMetadata(rawContents, (metadata) => { - this.projector.updateDataSet( - this.projector.dataSet, - metadata, - fileName - ); + if (!colorOption) { + return; + } + this.showForceCategoricalColorsCheckbox = !!colorOption.tooManyUniqueValues; + if (colorOption.map == null) { + this.colorLegendRenderInfo = null; + } else if (colorOption.items) { + let items = colorOption.items.map((item) => { + return { + color: colorOption.map(item.label), + label: item.label, + count: item.count, + }; }); + this.colorLegendRenderInfo = {items, thresholds: null}; + } else { + this.colorLegendRenderInfo = { + items: null, + thresholds: colorOption.thresholds, + }; } - - private getEmbeddingInfoByName(tensorName: string): EmbeddingInfo { - for (let i = 0; i < this.projectorConfig.embeddings.length; i++) { - const e = this.projectorConfig.embeddings[i]; - if (e.tensorName === tensorName) { - return e; - } + this.projector.setSelectedColorOption(colorOption); + } + private tensorWasReadFromFile(rawContents: ArrayBuffer, fileName: string) { + parseRawTensors(rawContents, (ds) => { + const checkpointFile = this.$$('#checkpoint-file') as HTMLSpanElement; + checkpointFile.innerText = fileName; + checkpointFile.title = fileName; + this.projector.updateDataSet(ds); + }); + } + private metadataWasReadFromFile(rawContents: ArrayBuffer, fileName: string) { + parseRawMetadata(rawContents, (metadata) => { + this.projector.updateDataSet(this.projector.dataSet, metadata, fileName); + }); + } + private getEmbeddingInfoByName(tensorName: string): EmbeddingInfo { + for (let i = 0; i < this.projectorConfig.embeddings.length; i++) { + const e = this.projectorConfig.embeddings[i]; + if (e.tensorName === tensorName) { + return e; } } - - private setupUploadButtons() { - // Show and setup the upload button. - const fileInput = this.$$('#file') as HTMLInputElement; - fileInput.onchange = () => { - const file: File = fileInput.files[0]; - // Clear out the value of the file chooser. This ensures that if the user - // selects the same file, we'll re-read it. - fileInput.value = ''; - const fileReader = new FileReader(); - fileReader.onload = (evt) => { - const content: ArrayBuffer = fileReader.result; - this.tensorWasReadFromFile(content, file.name); - }; - fileReader.readAsArrayBuffer(file); + } + private setupUploadButtons() { + // Show and setup the upload button. + const fileInput = this.$$('#file') as HTMLInputElement; + fileInput.onchange = () => { + const file: File = fileInput.files[0]; + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + fileInput.value = ''; + const fileReader = new FileReader(); + fileReader.onload = (evt) => { + const content: ArrayBuffer = fileReader.result as ArrayBuffer; + this.tensorWasReadFromFile(content, file.name); }; - - const uploadButton = this.$$('#upload-tensors') as HTMLButtonElement; - uploadButton.onclick = () => { - fileInput.click(); + fileReader.readAsArrayBuffer(file); + }; + const uploadButton = this.$$('#upload-tensors') as HTMLButtonElement; + uploadButton.onclick = () => { + fileInput.click(); + }; + // Show and setup the upload metadata button. + const fileMetadataInput = this.$$('#file-metadata') as HTMLInputElement; + fileMetadataInput.onchange = () => { + const file: File = fileMetadataInput.files[0]; + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + fileMetadataInput.value = ''; + const fileReader = new FileReader(); + fileReader.onload = (evt) => { + const contents: ArrayBuffer = fileReader.result as ArrayBuffer; + this.metadataWasReadFromFile(contents, file.name); }; - - // Show and setup the upload metadata button. - const fileMetadataInput = this.$$('#file-metadata') as HTMLInputElement; - fileMetadataInput.onchange = () => { - const file: File = fileMetadataInput.files[0]; - // Clear out the value of the file chooser. This ensures that if the user - // selects the same file, we'll re-read it. - fileMetadataInput.value = ''; - const fileReader = new FileReader(); - fileReader.onload = (evt) => { - const contents: ArrayBuffer = fileReader.result; - this.metadataWasReadFromFile(contents, file.name); + fileReader.readAsArrayBuffer(file); + }; + const uploadMetadataButton = this.$$( + '#upload-metadata' + ) as HTMLButtonElement; + uploadMetadataButton.onclick = () => { + fileMetadataInput.click(); + }; + if (this.projector.servingMode !== 'demo') { + (this.$$('#publish-container') as HTMLElement).style.display = 'none'; + (this.$$('#upload-tensors-step-container') as HTMLElement).style.display = + 'none'; + (this.$$('#upload-metadata-label') as HTMLElement).style.display = 'none'; + } + (this.$$('#demo-data-buttons-container') as HTMLElement).style.display = + 'flex'; + // Fill out the projector config. + const projectorConfigTemplate = this.$$( + '#projector-config-template' + ) as HTMLTextAreaElement; + const projectorConfigTemplateJson: ProjectorConfig = { + embeddings: [ + { + tensorName: 'My tensor', + tensorShape: [1000, 50], + tensorPath: 'https://raw.githubusercontent.com/.../tensors.tsv', + metadataPath: + 'https://raw.githubusercontent.com/.../optional.metadata.tsv', + }, + ], + }; + this.setProjectorConfigTemplateJson( + projectorConfigTemplate, + projectorConfigTemplateJson + ); + // Set up optional field checkboxes. + const spriteFieldCheckbox = this.$$( + '#config-sprite-checkbox' + ) as HTMLInputElement; + spriteFieldCheckbox.onchange = () => { + if ((spriteFieldCheckbox as any).checked) { + projectorConfigTemplateJson.embeddings[0].sprite = { + imagePath: 'https://github.com/.../optional.sprite.png', + singleImageDim: [32, 32], }; - fileReader.readAsArrayBuffer(file); - }; - - const uploadMetadataButton = this.$$( - '#upload-metadata' - ) as HTMLButtonElement; - uploadMetadataButton.onclick = () => { - fileMetadataInput.click(); - }; - - if (this.projector.servingMode !== 'demo') { - (this.$$('#publish-container') as HTMLElement).style.display = 'none'; - (this.$$( - '#upload-tensors-step-container' - ) as HTMLElement).style.display = 'none'; - (this.$$('#upload-metadata-label') as HTMLElement).style.display = - 'none'; + } else { + delete projectorConfigTemplateJson.embeddings[0].sprite; } - - (this.$$('#demo-data-buttons-container') as HTMLElement).style.display = - 'flex'; - - // Fill out the projector config. - const projectorConfigTemplate = this.$$( - '#projector-config-template' - ) as HTMLTextAreaElement; - const projectorConfigTemplateJson: ProjectorConfig = { - embeddings: [ - { - tensorName: 'My tensor', - tensorShape: [1000, 50], - tensorPath: 'https://raw.githubusercontent.com/.../tensors.tsv', - metadataPath: - 'https://raw.githubusercontent.com/.../optional.metadata.tsv', - }, - ], - }; this.setProjectorConfigTemplateJson( projectorConfigTemplate, projectorConfigTemplateJson ); - - // Set up optional field checkboxes. - const spriteFieldCheckbox = this.$$( - '#config-sprite-checkbox' - ) as HTMLInputElement; - spriteFieldCheckbox.onchange = () => { - if ((spriteFieldCheckbox as any).checked) { - projectorConfigTemplateJson.embeddings[0].sprite = { - imagePath: 'https://github.com/.../optional.sprite.png', - singleImageDim: [32, 32], - }; - } else { - delete projectorConfigTemplateJson.embeddings[0].sprite; - } - this.setProjectorConfigTemplateJson( - projectorConfigTemplate, - projectorConfigTemplateJson - ); - }; - const bookmarksFieldCheckbox = this.$$( - '#config-bookmarks-checkbox' - ) as HTMLInputElement; - bookmarksFieldCheckbox.onchange = () => { - if ((bookmarksFieldCheckbox as any).checked) { - projectorConfigTemplateJson.embeddings[0].bookmarksPath = - 'https://raw.githubusercontent.com/.../bookmarks.txt'; - } else { - delete projectorConfigTemplateJson.embeddings[0].bookmarksPath; - } - this.setProjectorConfigTemplateJson( - projectorConfigTemplate, - projectorConfigTemplateJson - ); - }; - const metadataFieldCheckbox = this.$$( - '#config-metadata-checkbox' - ) as HTMLInputElement; - metadataFieldCheckbox.onchange = () => { - if ((metadataFieldCheckbox as HTMLInputElement).checked) { - projectorConfigTemplateJson.embeddings[0].metadataPath = - 'https://raw.githubusercontent.com/.../optional.metadata.tsv'; - } else { - delete projectorConfigTemplateJson.embeddings[0].metadataPath; - } - this.setProjectorConfigTemplateJson( - projectorConfigTemplate, - projectorConfigTemplateJson - ); - }; - - // Update the link and the readonly shareable URL. - const projectorConfigUrlInput = this.$$( - '#projector-config-url' - ) as HTMLInputElement; - const projectorConfigDemoUrlInput = this.$$('#projector-share-url'); - const projectorConfigDemoUrlLink = this.$$('#projector-share-url-link'); - projectorConfigUrlInput.onchange = () => { - let projectorDemoUrl = - location.protocol + - '//' + - location.host + - location.pathname + - '?config=' + - (projectorConfigUrlInput as HTMLInputElement).value; - - (projectorConfigDemoUrlInput as HTMLInputElement).value = projectorDemoUrl; - (projectorConfigDemoUrlLink as HTMLLinkElement).href = projectorDemoUrl; - }; - } - - private setProjectorConfigTemplateJson( - projectorConfigTemplate: HTMLTextAreaElement, - config: ProjectorConfig - ) { - projectorConfigTemplate.value = JSON.stringify( - config, - null, - /** replacer */ 2 /** white space */ + }; + const bookmarksFieldCheckbox = this.$$( + '#config-bookmarks-checkbox' + ) as HTMLInputElement; + bookmarksFieldCheckbox.onchange = () => { + if ((bookmarksFieldCheckbox as any).checked) { + projectorConfigTemplateJson.embeddings[0].bookmarksPath = + 'https://raw.githubusercontent.com/.../bookmarks.txt'; + } else { + delete projectorConfigTemplateJson.embeddings[0].bookmarksPath; + } + this.setProjectorConfigTemplateJson( + projectorConfigTemplate, + projectorConfigTemplateJson ); - } - - _getNumTensorsLabel(): string { - return this.tensorNames.length === 1 - ? '1 tensor' - : this.tensorNames.length + ' tensors'; - } - - _getNumRunsLabel(): string { - return this.runNames.length === 1 - ? '1 run' - : this.runNames.length + ' runs'; - } - - _hasChoice(choices: any[]): boolean { - return choices.length > 0; - } - - _hasChoices(choices: any[]): boolean { - return choices.length > 1; - } - - _openDataDialog(): void { - this.$.dataDialog.open(); - } - - _openConfigDialog(): void { - this.$.projectorConfigDialog.open(); - } + }; + const metadataFieldCheckbox = this.$$( + '#config-metadata-checkbox' + ) as HTMLInputElement; + metadataFieldCheckbox.onchange = () => { + if ((metadataFieldCheckbox as HTMLInputElement).checked) { + projectorConfigTemplateJson.embeddings[0].metadataPath = + 'https://raw.githubusercontent.com/.../optional.metadata.tsv'; + } else { + delete projectorConfigTemplateJson.embeddings[0].metadataPath; + } + this.setProjectorConfigTemplateJson( + projectorConfigTemplate, + projectorConfigTemplateJson + ); + }; + // Update the link and the readonly shareable URL. + const projectorConfigUrlInput = this.$$( + '#projector-config-url' + ) as HTMLInputElement; + const projectorConfigDemoUrlInput = this.$$('#projector-share-url'); + const projectorConfigDemoUrlLink = this.$$('#projector-share-url-link'); + projectorConfigUrlInput.onchange = () => { + let projectorDemoUrl = + location.protocol + + '//' + + location.host + + location.pathname + + '?config=' + + (projectorConfigUrlInput as HTMLInputElement).value; + (projectorConfigDemoUrlInput as HTMLInputElement).value = projectorDemoUrl; + (projectorConfigDemoUrlLink as HTMLLinkElement).href = projectorDemoUrl; + }; } - - customElements.define(DataPanel.prototype.is, DataPanel); -} // namespace vz_projector + private setProjectorConfigTemplateJson( + projectorConfigTemplate: HTMLTextAreaElement, + config: ProjectorConfig + ) { + projectorConfigTemplate.value = JSON.stringify( + config, + null, + /** replacer */ 2 /** white space */ + ); + } + _getNumTensorsLabel(): string { + return this.tensorNames.length === 1 + ? '1 tensor' + : this.tensorNames.length + ' tensors'; + } + _getNumRunsLabel(): string { + return this.runNames.length === 1 + ? '1 run' + : this.runNames.length + ' runs'; + } + _hasChoice(choices: any[]): boolean { + return choices.length > 0; + } + _hasChoices(choices: any[]): boolean { + return choices.length > 1; + } + _openDataDialog(): void { + (this.$.dataDialog as any).open(); + } + _openConfigDialog(): void { + (this.$.projectorConfigDialog as any).open(); + } +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-input.html b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-input.html deleted file mode 100644 index e5a4a0f71e..0000000000 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-input.html +++ /dev/null @@ -1,71 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-input.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-input.ts index 540dd593bb..bbb343f844 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-input.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-input.ts @@ -12,103 +12,137 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - // tslint:disable-next-line - export let ProjectorInputPolymer = PolymerElement({ - is: 'vz-projector-input', - properties: {label: String, message: String}, - }); - - export interface InputChangedListener { - (value: string, inRegexMode: boolean): void; - } - - /** Input control with custom capabilities (e.g. regex). */ - export class ProjectorInput extends ProjectorInputPolymer { - private textChangedListeners: InputChangedListener[]; - private paperInput: HTMLInputElement; - private inRegexModeButton: HTMLButtonElement; - private inRegexMode: boolean; - - /** Message that will be displayed at the bottom of the input control. */ - message: string; - - /** Subscribe to be called everytime the input changes. */ - registerInputChangedListener(listener: InputChangedListener) { - this.textChangedListeners.push(listener); - } - - ready() { - super.ready(); - this.inRegexMode = false; - this.textChangedListeners = []; - this.paperInput = this.$$('paper-input') as HTMLInputElement; - this.inRegexModeButton = this.$$('paper-button') as HTMLButtonElement; - this.paperInput.setAttribute('error-message', 'Invalid regex'); - - this.paperInput.addEventListener('input', () => { - this.onTextChanged(); - }); - - this.paperInput.addEventListener('keydown', (event) => { - event.stopPropagation(); - }); - - this.inRegexModeButton.addEventListener('click', () => - this.onClickRegexModeButton() - ); - this.updateRegexModeDisplaySlashes(); - this.onTextChanged(); - } - - private onClickRegexModeButton() { - this.inRegexMode = (this.inRegexModeButton as any).active; - this.updateRegexModeDisplaySlashes(); - this.onTextChanged(); - } - - private notifyInputChanged(value: string, inRegexMode: boolean) { - this.textChangedListeners.forEach((l) => l(value, inRegexMode)); - } - - private onTextChanged() { - try { - if (this.inRegexMode) { - new RegExp(this.paperInput.value); - } - } catch (invalidRegexException) { - this.paperInput.setAttribute('invalid', 'true'); - this.message = ''; - this.notifyInputChanged(null, true); - return; +import {PolymerElement, html} from '@polymer/polymer'; +import {LegacyElementMixin} from '@polymer/polymer/lib/legacy/legacy-element-mixin'; +import {customElement, property} from '@polymer/decorators'; + +import '@polymer/paper-button'; +import '@polymer/paper-input/paper-input'; +import '@polymer/paper-tooltip'; + +import './styles'; + +export interface InputChangedListener { + (value: string, inRegexMode: boolean): void; +} + +@customElement('vz-projector-input') +export class ProjectorInput extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + + + +
/
+
/
+
+ .* +
+
+ + Enable/disable regex mode. + + [[message]] + `; + @property({type: String}) + label: string; + + /** Message that will be displayed at the bottom of the input control. */ + @property({type: String}) + message: string; + + private textChangedListeners: InputChangedListener[]; + private paperInput: HTMLInputElement; + private inRegexModeButton: HTMLButtonElement; + private inRegexMode: boolean; + + /** Subscribe to be called everytime the input changes. */ + registerInputChangedListener(listener: InputChangedListener) { + this.textChangedListeners.push(listener); + } + ready() { + super.ready(); + this.inRegexMode = false; + this.textChangedListeners = []; + this.paperInput = this.$$('paper-input') as HTMLInputElement; + this.inRegexModeButton = this.$$('paper-button') as HTMLButtonElement; + this.paperInput.setAttribute('error-message', 'Invalid regex'); + this.paperInput.addEventListener('input', () => { + this.onTextChanged(); + }); + this.paperInput.addEventListener('keydown', (event) => { + event.stopPropagation(); + }); + this.inRegexModeButton.addEventListener('click', () => + this.onClickRegexModeButton() + ); + this.updateRegexModeDisplaySlashes(); + this.onTextChanged(); + } + private onClickRegexModeButton() { + this.inRegexMode = (this.inRegexModeButton as any).active; + this.updateRegexModeDisplaySlashes(); + this.onTextChanged(); + } + private notifyInputChanged(value: string, inRegexMode: boolean) { + this.textChangedListeners.forEach((l) => l(value, inRegexMode)); + } + private onTextChanged() { + try { + if (this.inRegexMode) { + new RegExp(this.paperInput.value); + } + } catch (invalidRegexException) { + this.paperInput.setAttribute('invalid', 'true'); + this.message = ''; + this.notifyInputChanged(null, true); + return; } - - setValue(value: string, inRegexMode: boolean) { - (this.inRegexModeButton as any).active = inRegexMode; - this.paperInput.value = value; - this.onClickRegexModeButton(); + this.paperInput.removeAttribute('invalid'); + this.notifyInputChanged(this.paperInput.value, this.inRegexMode); + } + private updateRegexModeDisplaySlashes() { + const slashes = this.paperInput.querySelectorAll('.slash'); + const display = this.inRegexMode ? '' : 'none'; + for (let i = 0; i < slashes.length; i++) { + (slashes[i] as HTMLDivElement).style.display = display; } } - - customElements.define(ProjectorInput.prototype.is, ProjectorInput); -} // namespace vz_projector + getValue(): string { + return this.paperInput.value; + } + getInRegexMode(): boolean { + return this.inRegexMode; + } + setValue(value: string, inRegexMode: boolean) { + (this.inRegexModeButton as any).active = inRegexMode; + this.paperInput.value = value; + this.onClickRegexModeButton(); + } +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.html b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.html deleted file mode 100644 index a7ab5f66bc..0000000000 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.html +++ /dev/null @@ -1,351 +0,0 @@ - - - - - - - - - - - - - - - - - diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.html.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.html.ts new file mode 100644 index 0000000000..0fec5995ac --- /dev/null +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.html.ts @@ -0,0 +1,333 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {html} from '@polymer/polymer'; + +export const template = html` + + +
+
+ + + +
+
+ + + + + + +
+
+
+ + + +
+`; diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.ts index 1b2b13c120..cab902fa2d 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-inspector-panel.ts @@ -12,536 +12,473 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - /** Limit the number of search results we show to the user. */ - const LIMIT_RESULTS = 100; - const DEFAULT_NEIGHBORS = 100; - - // tslint:disable-next-line - export let InspectorPanelPolymer = PolymerElement({ - is: 'vz-projector-inspector-panel', - properties: { - selectedMetadataField: String, - metadataFields: Array, - metadataColumn: String, - numNN: {type: Number, value: DEFAULT_NEIGHBORS}, - updateNumNN: Object, - spriteMeta: Object, // type: `SpriteMetadata` - showNeighborImages: { - type: Boolean, - value: true, - observer: '_refreshNeighborsList', - }, - spriteImagesAvailable: { - type: Boolean, - value: true, - observer: '_refreshNeighborsList', - }, - }, - }); - - type SpriteMetadata = { - imagePath?: string; - singleImageDim?: number[]; - aspectRatio?: number; - nCols?: number; - }; - - export class InspectorPanel extends InspectorPanelPolymer { - distFunc: DistanceFunction; - numNN: number; - - private projectorEventContext: ProjectorEventContext; - - private selectedMetadataField: string; - private metadataFields: string[]; - private metadataColumn: string; - private spriteMeta: SpriteMetadata; - private displayContexts: string[]; - private projector: Projector; - private selectedPointIndices: number[]; - private neighborsOfFirstPoint: knn.NearestEntry[]; - private showNeighborImages: boolean; - private spriteImagesAvailable: boolean; - private searchBox: ProjectorInput; - - private resetFilterButton: HTMLButtonElement; - private setFilterButton: HTMLButtonElement; - private clearSelectionButton: HTMLButtonElement; - private limitMessage: HTMLDivElement; - - ready() { - super.ready(); - this.resetFilterButton = this.$$('.reset-filter') as HTMLButtonElement; - this.setFilterButton = this.$$('.set-filter') as HTMLButtonElement; - this.clearSelectionButton = this.$$( - '.clear-selection' - ) as HTMLButtonElement; - this.limitMessage = this.$$('.limit-msg') as HTMLDivElement; - this.searchBox = this.$$('#search-box') as ProjectorInput; - this.displayContexts = []; +import {PolymerElement} from '@polymer/polymer'; +import {LegacyElementMixin} from '@polymer/polymer/lib/legacy/legacy-element-mixin'; +import {customElement, observe, property} from '@polymer/decorators'; + +import '@polymer/paper-dropdown-menu/paper-dropdown-menu'; +import '@polymer/paper-icon-button'; +import '@polymer/paper-item'; +import '@polymer/paper-listbox'; +import '@polymer/paper-slider'; +import '@polymer/paper-tooltip'; + +import {DistanceFunction, SpriteAndMetadataInfo, State} from './data'; +import {template} from './vz-projector-inspector-panel.html'; +import {ProjectorInput} from './vz-projector-input'; +import './vz-projector-input'; +import {Projector} from './vz-projector'; +import {dist2color, normalizeDist} from './projectorScatterPlotAdapter'; +import {ProjectorEventContext} from './projectorEventContext'; +import * as knn from './knn'; +import * as vector from './vector'; +import * as util from './util'; + +const LIMIT_RESULTS = 100; +const DEFAULT_NEIGHBORS = 100; + +type SpriteMetadata = { + imagePath?: string; + singleImageDim?: number[]; + aspectRatio?: number; + nCols?: number; +}; + +@customElement('vz-projector-inspector-panel') +export class InspectorPanel extends LegacyElementMixin(PolymerElement) { + static readonly template = template; + + @property({type: String}) + selectedMetadataField: string; + + @property({type: Array}) + metadataFields: Array; + + @property({type: String}) + metadataColumn: string; + + @property({type: Number}) + numNN: number = DEFAULT_NEIGHBORS; + + @property({type: Object}) + spriteMeta: SpriteMetadata; + + @property({type: Boolean}) + showNeighborImages: boolean = true; + + @property({type: Boolean}) + spriteImagesAvailable: Boolean = true; + + distFunc: DistanceFunction; + private projectorEventContext: ProjectorEventContext; + private displayContexts: string[]; + private projector: Projector; + private selectedPointIndices: number[]; + private neighborsOfFirstPoint: knn.NearestEntry[]; + private searchBox: ProjectorInput; + private resetFilterButton: HTMLButtonElement; + private setFilterButton: HTMLButtonElement; + private clearSelectionButton: HTMLButtonElement; + private limitMessage: HTMLDivElement; + private _currentNeighbors: any; + + ready() { + super.ready(); + this.resetFilterButton = this.$$('.reset-filter') as HTMLButtonElement; + this.setFilterButton = this.$$('.set-filter') as HTMLButtonElement; + this.clearSelectionButton = this.$$( + '.clear-selection' + ) as HTMLButtonElement; + this.limitMessage = this.$$('.limit-msg') as HTMLDivElement; + this.searchBox = this.$$('#search-box') as ProjectorInput; + this.displayContexts = []; + } + initialize( + projector: Projector, + projectorEventContext: ProjectorEventContext + ) { + this.projector = projector; + this.projectorEventContext = projectorEventContext; + this.setupUI(projector); + projectorEventContext.registerSelectionChangedListener( + (selection, neighbors) => this.updateInspectorPane(selection, neighbors) + ); + } + /** Updates the nearest neighbors list in the inspector. */ + private updateInspectorPane( + indices: number[], + neighbors: knn.NearestEntry[] + ) { + this.neighborsOfFirstPoint = neighbors; + this.selectedPointIndices = indices; + this.updateFilterButtons(indices.length + neighbors.length); + this.updateNeighborsList(neighbors); + if (neighbors.length === 0) { + this.updateSearchResults(indices); + } else { + this.updateSearchResults([]); } - - initialize( - projector: Projector, - projectorEventContext: ProjectorEventContext + } + private enableResetFilterButton(enabled: boolean) { + this.resetFilterButton.disabled = !enabled; + } + restoreUIFromBookmark(bookmark: State) { + this.enableResetFilterButton(bookmark.filteredPoints != null); + } + metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) { + let labelIndex = -1; + this.metadataFields = spriteAndMetadata.stats.map((stats, i) => { + if (!stats.isNumeric && labelIndex === -1) { + labelIndex = i; + } + return stats.name; + }); + if ( + spriteAndMetadata.spriteMetadata && + spriteAndMetadata.spriteMetadata.imagePath ) { - this.projector = projector; - this.projectorEventContext = projectorEventContext; - this.setupUI(projector); - projectorEventContext.registerSelectionChangedListener( - (selection, neighbors) => this.updateInspectorPane(selection, neighbors) - ); + const [ + spriteWidth, + spriteHeight, + ] = spriteAndMetadata.spriteMetadata.singleImageDim; + this.spriteMeta = { + imagePath: spriteAndMetadata.spriteImage.src, + aspectRatio: spriteWidth / spriteHeight, + nCols: Math.floor(spriteAndMetadata.spriteImage.width / spriteWidth), + singleImageDim: [spriteWidth, spriteHeight], + }; + } else { + this.spriteMeta = {}; } - - /** Updates the nearest neighbors list in the inspector. */ - private updateInspectorPane( - indices: number[], - neighbors: knn.NearestEntry[] + this.spriteImagesAvailable = !!this.spriteMeta.imagePath; + if ( + this.selectedMetadataField == null || + this.metadataFields.filter((name) => name === this.selectedMetadataField) + .length === 0 ) { - this.neighborsOfFirstPoint = neighbors; - this.selectedPointIndices = indices; - - this.updateFilterButtons(indices.length + neighbors.length); - this.updateNeighborsList(neighbors); - if (neighbors.length === 0) { - this.updateSearchResults(indices); - } else { - this.updateSearchResults([]); - } - } - - private enableResetFilterButton(enabled: boolean) { - this.resetFilterButton.disabled = !enabled; - } - - restoreUIFromBookmark(bookmark: State) { - this.enableResetFilterButton(bookmark.filteredPoints != null); + // Make the default label the first non-numeric column. + this.selectedMetadataField = this.metadataFields[Math.max(0, labelIndex)]; } - - metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) { - let labelIndex = -1; - this.metadataFields = spriteAndMetadata.stats.map((stats, i) => { - if (!stats.isNumeric && labelIndex === -1) { - labelIndex = i; - } - return stats.name; - }); - - if ( - spriteAndMetadata.spriteMetadata && - spriteAndMetadata.spriteMetadata.imagePath - ) { - const [ - spriteWidth, - spriteHeight, - ] = spriteAndMetadata.spriteMetadata.singleImageDim; - - this.spriteMeta = { - imagePath: spriteAndMetadata.spriteImage.src, - aspectRatio: spriteWidth / spriteHeight, - nCols: Math.floor(spriteAndMetadata.spriteImage.width / spriteWidth), - singleImageDim: [spriteWidth, spriteHeight], - }; - } else { - this.spriteMeta = {}; - } - this.spriteImagesAvailable = !!this.spriteMeta.imagePath; - - if ( - this.selectedMetadataField == null || - this.metadataFields.filter( - (name) => name === this.selectedMetadataField - ).length === 0 - ) { - // Make the default label the first non-numeric column. - this.selectedMetadataField = this.metadataFields[ - Math.max(0, labelIndex) - ]; - } - this.updateInspectorPane( - this.selectedPointIndices, - this.neighborsOfFirstPoint - ); - } - - datasetChanged() { - this.enableResetFilterButton(false); + this.updateInspectorPane( + this.selectedPointIndices, + this.neighborsOfFirstPoint + ); + } + datasetChanged() { + this.enableResetFilterButton(false); + } + @observe('showNeighborImages', 'spriteImagesAvailable') + _refreshNeighborsList() { + this.updateNeighborsList(); + } + metadataEditorContext(enabled: boolean, metadataColumn: string) { + if (!this.projector || !this.projector.dataSet) { + return; } - - _refreshNeighborsList() { - this.updateNeighborsList(); + let stat = this.projector.dataSet.spriteAndMetadataInfo.stats.filter( + (s) => s.name === metadataColumn + ); + if (!enabled || stat.length === 0 || stat[0].tooManyUniqueValues) { + this.removeContext('.metadata-info'); + return; } - - metadataEditorContext(enabled: boolean, metadataColumn: string) { - if (!this.projector || !this.projector.dataSet) { - return; - } - - let stat = this.projector.dataSet.spriteAndMetadataInfo.stats.filter( - (s) => s.name === metadataColumn + this.metadataColumn = metadataColumn; + this.addContext('.metadata-info'); + let list = this.$$('.metadata-list') as HTMLDivElement; + list.innerHTML = ''; + let entries = stat[0].uniqueEntries.sort((a, b) => a.count - b.count); + let maxCount = entries[entries.length - 1].count; + entries.forEach((e) => { + const metadataElement = document.createElement('div'); + metadataElement.className = 'metadata'; + const metadataElementLink = document.createElement('a'); + metadataElementLink.className = 'metadata-link'; + metadataElementLink.title = e.label; + const labelValueElement = document.createElement('div'); + labelValueElement.className = 'label-and-value'; + const labelElement = document.createElement('div'); + labelElement.className = 'label'; + labelElement.style.color = dist2color(this.distFunc, maxCount, e.count); + labelElement.innerText = e.label; + const valueElement = document.createElement('div'); + valueElement.className = 'value'; + valueElement.innerText = e.count.toString(); + labelValueElement.appendChild(labelElement); + labelValueElement.appendChild(valueElement); + const barElement = document.createElement('div'); + barElement.className = 'bar'; + const barFillElement = document.createElement('div'); + barFillElement.className = 'fill'; + barFillElement.style.borderTopColor = dist2color( + this.distFunc, + maxCount, + e.count ); - if (!enabled || stat.length === 0 || stat[0].tooManyUniqueValues) { - this.removeContext('.metadata-info'); - return; + barFillElement.style.width = + normalizeDist(this.distFunc, maxCount, e.count) * 100 + '%'; + barElement.appendChild(barFillElement); + for (let j = 1; j < 4; j++) { + const tickElement = document.createElement('div'); + tickElement.className = 'tick'; + tickElement.style.left = (j * 100) / 4 + '%'; + barElement.appendChild(tickElement); } - - this.metadataColumn = metadataColumn; - this.addContext('.metadata-info'); - let list = this.$$('.metadata-list') as HTMLDivElement; - list.innerHTML = ''; - - let entries = stat[0].uniqueEntries.sort((a, b) => a.count - b.count); - let maxCount = entries[entries.length - 1].count; - - entries.forEach((e) => { - const metadataElement = document.createElement('div'); - metadataElement.className = 'metadata'; - - const metadataElementLink = document.createElement('a'); - metadataElementLink.className = 'metadata-link'; - metadataElementLink.title = e.label; - - const labelValueElement = document.createElement('div'); - labelValueElement.className = 'label-and-value'; - - const labelElement = document.createElement('div'); - labelElement.className = 'label'; - labelElement.style.color = dist2color(this.distFunc, maxCount, e.count); - labelElement.innerText = e.label; - - const valueElement = document.createElement('div'); - valueElement.className = 'value'; - valueElement.innerText = e.count.toString(); - - labelValueElement.appendChild(labelElement); - labelValueElement.appendChild(valueElement); - - const barElement = document.createElement('div'); - barElement.className = 'bar'; - - const barFillElement = document.createElement('div'); - barFillElement.className = 'fill'; - barFillElement.style.borderTopColor = dist2color( - this.distFunc, - maxCount, - e.count - ); - barFillElement.style.width = - normalizeDist(this.distFunc, maxCount, e.count) * 100 + '%'; - barElement.appendChild(barFillElement); - - for (let j = 1; j < 4; j++) { - const tickElement = document.createElement('div'); - tickElement.className = 'tick'; - tickElement.style.left = (j * 100) / 4 + '%'; - barElement.appendChild(tickElement); - } - - metadataElementLink.appendChild(labelValueElement); - metadataElementLink.appendChild(barElement); - metadataElement.appendChild(metadataElementLink); - list.appendChild(metadataElement); - - metadataElementLink.onclick = () => { - this.projector.metadataEdit(metadataColumn, e.label); - }; - }); + metadataElementLink.appendChild(labelValueElement); + metadataElementLink.appendChild(barElement); + metadataElement.appendChild(metadataElementLink); + list.appendChild(metadataElement); + metadataElementLink.onclick = () => { + this.projector.metadataEdit(metadataColumn, e.label); + }; + }); + } + private addContext(context: string) { + if (this.displayContexts.indexOf(context) === -1) { + this.displayContexts.push(context); } - - private addContext(context: string) { - if (this.displayContexts.indexOf(context) === -1) { - this.displayContexts.push(context); - } - this.displayContexts.forEach((c) => { - (this.$$(c) as HTMLDivElement).style.display = 'none'; - }); - (this.$$(context) as HTMLDivElement).style.display = null; + this.displayContexts.forEach((c) => { + (this.$$(c) as HTMLDivElement).style.display = 'none'; + }); + (this.$$(context) as HTMLDivElement).style.display = null; + } + private removeContext(context: string) { + this.displayContexts = this.displayContexts.filter((c) => c !== context); + (this.$$(context) as HTMLDivElement).style.display = 'none'; + if (this.displayContexts.length > 0) { + let lastContext = this.displayContexts[this.displayContexts.length - 1]; + (this.$$(lastContext) as HTMLDivElement).style.display = null; } - - private removeContext(context: string) { - this.displayContexts = this.displayContexts.filter((c) => c !== context); - (this.$$(context) as HTMLDivElement).style.display = 'none'; - - if (this.displayContexts.length > 0) { - let lastContext = this.displayContexts[this.displayContexts.length - 1]; - (this.$$(lastContext) as HTMLDivElement).style.display = null; - } + } + private updateSearchResults(indices: number[]) { + const container = this.$$('.matches-list') as HTMLDivElement; + const list = container.querySelector('.list') as HTMLDivElement; + list.innerHTML = ''; + if (indices.length === 0) { + this.removeContext('.matches-list'); + return; } - - private updateSearchResults(indices: number[]) { - const container = this.$$('.matches-list') as HTMLDivElement; - const list = container.querySelector('.list') as HTMLDivElement; - list.innerHTML = ''; - if (indices.length === 0) { - this.removeContext('.matches-list'); - return; - } - this.addContext('.matches-list'); - - this.limitMessage.style.display = - indices.length <= LIMIT_RESULTS ? 'none' : null; - indices = indices.slice(0, LIMIT_RESULTS); - - for (let i = 0; i < indices.length; i++) { - const index = indices[i]; - - const row = document.createElement('div'); - row.className = 'row'; - - const label = this.getLabelFromIndex(index); - const rowLink = document.createElement('a'); - rowLink.className = 'label'; - rowLink.title = label; - rowLink.innerText = label; - - rowLink.onmouseenter = () => { - this.projectorEventContext.notifyHoverOverPoint(index); - }; - rowLink.onmouseleave = () => { - this.projectorEventContext.notifyHoverOverPoint(null); - }; - rowLink.onclick = () => { - this.projectorEventContext.notifySelectionChanged([index]); - }; - - row.appendChild(rowLink); - list.appendChild(row); - } + this.addContext('.matches-list'); + this.limitMessage.style.display = + indices.length <= LIMIT_RESULTS ? 'none' : null; + indices = indices.slice(0, LIMIT_RESULTS); + for (let i = 0; i < indices.length; i++) { + const index = indices[i]; + const row = document.createElement('div'); + row.className = 'row'; + const label = this.getLabelFromIndex(index); + const rowLink = document.createElement('a'); + rowLink.className = 'label'; + rowLink.title = label; + rowLink.innerText = label; + rowLink.onmouseenter = () => { + this.projectorEventContext.notifyHoverOverPoint(index); + }; + rowLink.onmouseleave = () => { + this.projectorEventContext.notifyHoverOverPoint(null); + }; + rowLink.onclick = () => { + this.projectorEventContext.notifySelectionChanged([index]); + }; + row.appendChild(rowLink); + list.appendChild(row); } - - private getLabelFromIndex(pointIndex: number): string { - const metadata = this.projector.dataSet.points[pointIndex].metadata[ - this.selectedMetadataField + } + private getLabelFromIndex(pointIndex: number): string { + const metadata = this.projector.dataSet.points[pointIndex].metadata[ + this.selectedMetadataField + ]; + return metadata !== undefined ? String(metadata) : `Unknown #${pointIndex}`; + } + private spriteImageRenderer() { + const spriteImagePath = this.spriteMeta.imagePath; + const {aspectRatio, nCols} = this.spriteMeta as any; + const paddingBottom = 100 / aspectRatio + '%'; + const backgroundSize = `${nCols * 100}% ${nCols * 100}%`; + const backgroundImage = `url(${CSS.escape(spriteImagePath)})`; + return (neighbor: knn.NearestEntry): HTMLElement => { + const spriteElementImage = document.createElement('div'); + spriteElementImage.className = 'sprite-image'; + spriteElementImage.style.backgroundImage = backgroundImage; + spriteElementImage.style.paddingBottom = paddingBottom; + spriteElementImage.style.backgroundSize = backgroundSize; + const [row, col] = [ + Math.floor(neighbor.index / nCols), + neighbor.index % nCols, + ]; + const [top, left] = [ + (row / (nCols - 1)) * 100, + (col / (nCols - 1)) * 100, ]; - return metadata !== undefined - ? String(metadata) - : `Unknown #${pointIndex}`; + spriteElementImage.style.backgroundPosition = `${left}% ${top}%`; + return spriteElementImage; + }; + } + private updateNeighborsList(neighbors?: knn.NearestEntry[]) { + neighbors = neighbors || this._currentNeighbors; + this._currentNeighbors = neighbors; + if (neighbors == null) { + return; } - - private spriteImageRenderer() { - const spriteImagePath = this.spriteMeta.imagePath; - const {aspectRatio, nCols} = this.spriteMeta; - const paddingBottom = 100 / aspectRatio + '%'; - const backgroundSize = `${nCols * 100}% ${nCols * 100}%`; - const backgroundImage = `url(${CSS.escape(spriteImagePath)})`; - - return (neighbor: knn.NearestEntry): HTMLElement => { - const spriteElementImage = document.createElement('div'); - spriteElementImage.className = 'sprite-image'; - spriteElementImage.style.backgroundImage = backgroundImage; - spriteElementImage.style.paddingBottom = paddingBottom; - spriteElementImage.style.backgroundSize = backgroundSize; - const [row, col] = [ - Math.floor(neighbor.index / nCols), - neighbor.index % nCols, - ]; - const [top, left] = [ - (row / (nCols - 1)) * 100, - (col / (nCols - 1)) * 100, - ]; - spriteElementImage.style.backgroundPosition = `${left}% ${top}%`; - - return spriteElementImage; - }; + const nnlist = this.$$('.nn-list') as HTMLDivElement; + nnlist.innerHTML = ''; + if (neighbors.length === 0) { + this.removeContext('.nn'); + return; } - - private updateNeighborsList(neighbors?: knn.NearestEntry[]) { - neighbors = neighbors || this._currentNeighbors; - this._currentNeighbors = neighbors; - if (neighbors == null) { - return; - } - - const nnlist = this.$$('.nn-list') as HTMLDivElement; - nnlist.innerHTML = ''; - - if (neighbors.length === 0) { - this.removeContext('.nn'); - return; + this.addContext('.nn'); + this.searchBox.message = ''; + const minDist = neighbors.length > 0 ? neighbors[0].dist : 0; + if (this.spriteImagesAvailable && this.showNeighborImages) { + var imageRenderer = this.spriteImageRenderer(); + } + for (let i = 0; i < neighbors.length; i++) { + const neighbor = neighbors[i]; + const neighborElement = document.createElement('div'); + neighborElement.className = 'neighbor'; + const neighborElementLink = document.createElement('a'); + neighborElementLink.className = 'neighbor-link'; + neighborElementLink.title = this.getLabelFromIndex(neighbor.index); + const labelValueElement = document.createElement('div'); + labelValueElement.className = 'label-and-value'; + const labelElement = document.createElement('div'); + labelElement.className = 'label'; + labelElement.style.color = dist2color( + this.distFunc, + neighbor.dist, + minDist + ); + labelElement.innerText = this.getLabelFromIndex(neighbor.index); + const valueElement = document.createElement('div'); + valueElement.className = 'value'; + valueElement.innerText = neighbor.dist.toFixed(3); + labelValueElement.appendChild(labelElement); + labelValueElement.appendChild(valueElement); + const barElement = document.createElement('div'); + barElement.className = 'bar'; + const barFillElement = document.createElement('div'); + barFillElement.className = 'fill'; + barFillElement.style.borderTopColor = dist2color( + this.distFunc, + neighbor.dist, + minDist + ); + barFillElement.style.width = + normalizeDist(this.distFunc, neighbor.dist, minDist) * 100 + '%'; + barElement.appendChild(barFillElement); + for (let j = 1; j < 4; j++) { + const tickElement = document.createElement('div'); + tickElement.className = 'tick'; + tickElement.style.left = (j * 100) / 4 + '%'; + barElement.appendChild(tickElement); } - this.addContext('.nn'); - - this.searchBox.message = ''; - const minDist = neighbors.length > 0 ? neighbors[0].dist : 0; - if (this.spriteImagesAvailable && this.showNeighborImages) { - var imageRenderer = this.spriteImageRenderer(); + const neighborElementImage = imageRenderer(neighbor); + neighborElement.appendChild(neighborElementImage); } - - for (let i = 0; i < neighbors.length; i++) { - const neighbor = neighbors[i]; - - const neighborElement = document.createElement('div'); - neighborElement.className = 'neighbor'; - - const neighborElementLink = document.createElement('a'); - neighborElementLink.className = 'neighbor-link'; - neighborElementLink.title = this.getLabelFromIndex(neighbor.index); - - const labelValueElement = document.createElement('div'); - labelValueElement.className = 'label-and-value'; - - const labelElement = document.createElement('div'); - labelElement.className = 'label'; - labelElement.style.color = dist2color( - this.distFunc, - neighbor.dist, - minDist - ); - labelElement.innerText = this.getLabelFromIndex(neighbor.index); - - const valueElement = document.createElement('div'); - valueElement.className = 'value'; - valueElement.innerText = neighbor.dist.toFixed(3); - - labelValueElement.appendChild(labelElement); - labelValueElement.appendChild(valueElement); - - const barElement = document.createElement('div'); - barElement.className = 'bar'; - - const barFillElement = document.createElement('div'); - barFillElement.className = 'fill'; - barFillElement.style.borderTopColor = dist2color( - this.distFunc, - neighbor.dist, - minDist - ); - barFillElement.style.width = - normalizeDist(this.distFunc, neighbor.dist, minDist) * 100 + '%'; - barElement.appendChild(barFillElement); - - for (let j = 1; j < 4; j++) { - const tickElement = document.createElement('div'); - tickElement.className = 'tick'; - tickElement.style.left = (j * 100) / 4 + '%'; - barElement.appendChild(tickElement); - } - - if (this.spriteImagesAvailable && this.showNeighborImages) { - const neighborElementImage = imageRenderer(neighbor); - neighborElement.appendChild(neighborElementImage); - } - - neighborElementLink.appendChild(labelValueElement); - neighborElementLink.appendChild(barElement); - neighborElement.appendChild(neighborElementLink); - nnlist.appendChild(neighborElement); - - neighborElementLink.onmouseenter = () => { - this.projectorEventContext.notifyHoverOverPoint(neighbor.index); - }; - neighborElementLink.onmouseleave = () => { - this.projectorEventContext.notifyHoverOverPoint(null); - }; - neighborElementLink.onclick = () => { - this.projectorEventContext.notifySelectionChanged([neighbor.index]); - }; - } - } - - private updateFilterButtons(numPoints: number) { - if (numPoints > 1) { - this.setFilterButton.innerText = `Isolate ${numPoints} points`; - this.setFilterButton.disabled = null; - this.clearSelectionButton.disabled = null; - } else { - this.setFilterButton.disabled = true; - this.clearSelectionButton.disabled = true; - } - } - - private setupUI(projector: Projector) { - this.distFunc = vector.cosDist; - const eucDist = this.$$('.distance a.euclidean') as HTMLLinkElement; - eucDist.onclick = () => { - const links = this.root.querySelectorAll('.distance a'); - for (let i = 0; i < links.length; i++) { - util.classed(links[i] as HTMLElement, 'selected', false); - } - util.classed(eucDist as HTMLElement, 'selected', true); - - this.distFunc = vector.dist; - this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); - const neighbors = projector.dataSet.findNeighbors( - this.selectedPointIndices[0], - this.distFunc, - this.numNN - ); - this.updateNeighborsList(neighbors); + neighborElementLink.appendChild(labelValueElement); + neighborElementLink.appendChild(barElement); + neighborElement.appendChild(neighborElementLink); + nnlist.appendChild(neighborElement); + neighborElementLink.onmouseenter = () => { + this.projectorEventContext.notifyHoverOverPoint(neighbor.index); }; - - const cosDist = this.$$('.distance a.cosine') as HTMLLinkElement; - cosDist.onclick = () => { - const links = this.root.querySelectorAll('.distance a'); - for (let i = 0; i < links.length; i++) { - util.classed(links[i] as HTMLElement, 'selected', false); - } - util.classed(cosDist, 'selected', true); - - this.distFunc = vector.cosDist; - this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); - const neighbors = projector.dataSet.findNeighbors( - this.selectedPointIndices[0], - this.distFunc, - this.numNN - ); - this.updateNeighborsList(neighbors); + neighborElementLink.onmouseleave = () => { + this.projectorEventContext.notifyHoverOverPoint(null); }; - - // Called whenever the search text input changes. - const updateInput = (value: string, inRegexMode: boolean) => { - if (value == null || value.trim() === '') { - this.searchBox.message = ''; - this.projectorEventContext.notifySelectionChanged([]); - return; - } - const indices = projector.dataSet.query( - value, - inRegexMode, - this.selectedMetadataField - ); - if (indices.length === 0) { - this.searchBox.message = '0 matches.'; - } else { - this.searchBox.message = `${indices.length} matches.`; - } - this.projectorEventContext.notifySelectionChanged(indices); + neighborElementLink.onclick = () => { + this.projectorEventContext.notifySelectionChanged([neighbor.index]); }; - this.searchBox.registerInputChangedListener((value, inRegexMode) => { - updateInput(value, inRegexMode); - }); - - // Filtering dataset. - this.setFilterButton.onclick = () => { - const indices = this.selectedPointIndices.concat( - this.neighborsOfFirstPoint.map((n) => n.index) - ); - projector.filterDataset(indices); - this.enableResetFilterButton(true); - this.updateFilterButtons(0); - }; - - this.resetFilterButton.onclick = () => { - projector.resetFilterDataset(); - this.enableResetFilterButton(false); - }; - - this.clearSelectionButton.onclick = () => { - projector.adjustSelectionAndHover([]); - }; - this.enableResetFilterButton(false); } - - private updateNumNN() { - if (this.selectedPointIndices != null) { - this.projectorEventContext.notifySelectionChanged([ - this.selectedPointIndices[0], - ]); + } + private updateFilterButtons(numPoints: number) { + if (numPoints > 1) { + this.setFilterButton.innerText = `Isolate ${numPoints} points`; + this.setFilterButton.disabled = null; + this.clearSelectionButton.disabled = null; + } else { + this.setFilterButton.disabled = true; + this.clearSelectionButton.disabled = true; + } + } + private setupUI(projector: Projector) { + this.distFunc = vector.cosDist; + const eucDist = this.$$('.distance a.euclidean') as HTMLLinkElement; + eucDist.onclick = () => { + const links = this.root.querySelectorAll('.distance a'); + for (let i = 0; i < links.length; i++) { + util.classed(links[i] as HTMLElement, 'selected', false); + } + util.classed(eucDist as HTMLElement, 'selected', true); + this.distFunc = vector.dist; + this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); + const neighbors = projector.dataSet.findNeighbors( + this.selectedPointIndices[0], + this.distFunc, + this.numNN + ); + this.updateNeighborsList(neighbors); + }; + const cosDist = this.$$('.distance a.cosine') as HTMLLinkElement; + cosDist.onclick = () => { + const links = this.root.querySelectorAll('.distance a'); + for (let i = 0; i < links.length; i++) { + util.classed(links[i] as HTMLElement, 'selected', false); + } + util.classed(cosDist, 'selected', true); + this.distFunc = vector.cosDist; + this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); + const neighbors = projector.dataSet.findNeighbors( + this.selectedPointIndices[0], + this.distFunc, + this.numNN + ); + this.updateNeighborsList(neighbors); + }; + // Called whenever the search text input changes. + const updateInput = (value: string, inRegexMode: boolean) => { + if (value == null || value.trim() === '') { + this.searchBox.message = ''; + this.projectorEventContext.notifySelectionChanged([]); + return; } + const indices = projector.dataSet.query( + value, + inRegexMode, + this.selectedMetadataField + ); + if (indices.length === 0) { + this.searchBox.message = '0 matches.'; + } else { + this.searchBox.message = `${indices.length} matches.`; + } + this.projectorEventContext.notifySelectionChanged(indices); + }; + this.searchBox.registerInputChangedListener((value, inRegexMode) => { + updateInput(value, inRegexMode); + }); + // Filtering dataset. + this.setFilterButton.onclick = () => { + const indices = this.selectedPointIndices.concat( + this.neighborsOfFirstPoint.map((n) => n.index) + ); + projector.filterDataset(indices); + this.enableResetFilterButton(true); + this.updateFilterButtons(0); + }; + this.resetFilterButton.onclick = () => { + projector.resetFilterDataset(); + this.enableResetFilterButton(false); + }; + this.clearSelectionButton.onclick = () => { + projector.adjustSelectionAndHover([]); + }; + this.enableResetFilterButton(false); + } + private updateNumNN() { + if (this.selectedPointIndices != null) { + this.projectorEventContext.notifySelectionChanged([ + this.selectedPointIndices[0], + ]); } } - - customElements.define(InspectorPanel.prototype.is, InspectorPanel); -} // namespace vz_projector +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-legend.html b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-legend.html deleted file mode 100644 index b738e5132f..0000000000 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-legend.html +++ /dev/null @@ -1,84 +0,0 @@ - - - - - - - - - - diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-legend.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-legend.ts index 0d217ca773..62b68f06e2 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-legend.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-legend.ts @@ -12,86 +12,135 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - // tslint:disable-next-line - export let LegendPolymer = PolymerElement({ - is: 'vz-projector-legend', - properties: {renderInfo: {type: Object, observer: '_renderInfoChanged'}}, - }); - export interface ColorLegendRenderInfo { - // To be used for categorical map. - items: ColorLegendItem[]; - // To be used for gradient map. - thresholds: ColorLegendThreshold[]; - } - - /** An item in the categorical color legend. */ - export interface ColorLegendItem { - color: string; - label: string; - count: number; - } - - /** An item in the gradient color legend. */ - export interface ColorLegendThreshold { - color: string; - value: number; - } +import {PolymerElement, html} from '@polymer/polymer'; +import {LegacyElementMixin} from '@polymer/polymer/lib/legacy/legacy-element-mixin'; +import {customElement, observe, property} from '@polymer/decorators'; + +import './styles'; + +export interface ColorLegendRenderInfo { + // To be used for categorical map. + items: ColorLegendItem[]; + // To be used for gradient map. + thresholds: ColorLegendThreshold[]; +} +/** An item in the categorical color legend. */ +export interface ColorLegendItem { + color: string; + label: string; + count: number; +} +/** An item in the gradient color legend. */ +export interface ColorLegendThreshold { + color: string; + value: number; +} + +@customElement('vz-projector-legend') +export class Legend extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + + + + + + `; + @property({type: Object}) + renderInfo: ColorLegendRenderInfo; + + @observe('renderInfo') + _renderInfoChanged() { + if (this.renderInfo == null) { + return; } - - private getOffset(value: number): string { - const min = this.renderInfo.thresholds[0].value; - const max = this.renderInfo.thresholds[ - this.renderInfo.thresholds.length - 1 - ].value; - return ((100 * (value - min)) / (max - min)).toFixed(2) + '%'; + if (this.renderInfo.thresholds) { + // is under dom-if so we should wait for it to be + // inserted in the dom tree using async(). + this.async(() => this.setupLinearGradient()); } - - private setupLinearGradient() { - const linearGradient = this.$$('#gradient') as SVGLinearGradientElement; - - const width = (this.$$('svg.gradient') as SVGElement).clientWidth; - - // Set the svg to be the width of its parent. - (this.$$('svg.gradient rect') as SVGRectElement).style.width = - width + 'px'; - - // Remove all children from before. - linearGradient.innerHTML = ''; - - // Add a child in for each gradient threshold. - this.renderInfo.thresholds.forEach((t) => { - const stopElement = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'stop' - ); - stopElement.setAttribute('offset', this.getOffset(t.value)); - stopElement.setAttribute('stop-color', t.color); - }); + } + _getLastThreshold(): number { + if (this.renderInfo == null || this.renderInfo.thresholds == null) { + return; } + return this.renderInfo.thresholds[this.renderInfo.thresholds.length - 1] + .value; } - - customElements.define(Legend.prototype.is, Legend); -} // namespace vz_projector + private getOffset(value: number): string { + const min = this.renderInfo.thresholds[0].value; + const max = this.renderInfo.thresholds[ + this.renderInfo.thresholds.length - 1 + ].value; + return ((100 * (value - min)) / (max - min)).toFixed(2) + '%'; + } + private setupLinearGradient() { + const linearGradient = this.$$('#gradient') as SVGLinearGradientElement; + const width = (this.$$('svg.gradient') as SVGElement).clientWidth; + // Set the svg to be the width of its parent. + (this.$$('svg.gradient rect') as SVGRectElement).style.width = width + 'px'; + // Remove all children from before. + linearGradient.innerHTML = ''; + // Add a child in for each gradient threshold. + this.renderInfo.thresholds.forEach((t) => { + const stopElement = document.createElementNS( + 'http://www.w3.org/2000/svg', + 'stop' + ); + stopElement.setAttribute('offset', this.getOffset(t.value)); + stopElement.setAttribute('stop-color', t.color); + }); + } +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-metadata-card.html b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-metadata-card.html deleted file mode 100644 index b6545a1d9c..0000000000 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-metadata-card.html +++ /dev/null @@ -1,104 +0,0 @@ - - - - - - - - -
-
- - -
- - -
- -
-
-
- - - - diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-metadata-card.ts b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-metadata-card.ts index 4f96c71baf..38e1c038e9 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-metadata-card.ts +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-metadata-card.ts @@ -12,64 +12,144 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -namespace vz_projector { - // tslint:disable-next-line - export let MetadataCardPolymer = PolymerElement({ - is: 'vz-projector-metadata-card', - properties: { - hasMetadata: {type: Boolean, value: false}, - isCollapsed: {type: Boolean, value: false}, - collapseIcon: {type: String, value: 'expand-less'}, - metadata: {type: Array}, - label: String, - }, - }); - - export class MetadataCard extends MetadataCardPolymer { - hasMetadata: boolean; - isCollapsed: boolean; - collapseIcon: string; - metadata: Array<{key: string; value: string}>; - label: string; - - private labelOption: string; - private pointMetadata: PointMetadata; - - /** Handles toggle of metadata-container. */ - _toggleMetadataContainer() { - (this.$$('#metadata-container') as any).toggle(); - this.isCollapsed = !this.isCollapsed; - this.set( - 'collapseIcon', - this.isCollapsed ? 'expand-more' : 'expand-less' - ); - } +import {PolymerElement, html} from '@polymer/polymer'; +import {LegacyElementMixin} from '@polymer/polymer/lib/legacy/legacy-element-mixin'; +import {customElement, property} from '@polymer/decorators'; - updateMetadata(pointMetadata?: PointMetadata) { - this.pointMetadata = pointMetadata; - this.hasMetadata = pointMetadata != null; - - if (pointMetadata) { - let metadata = []; - for (let metadataKey in pointMetadata) { - if (!pointMetadata.hasOwnProperty(metadataKey)) { - continue; - } - metadata.push({key: metadataKey, value: pointMetadata[metadataKey]}); - } +import '@polymer/iron-collapse'; +import '@polymer/paper-icon-button'; +import '@polymer/iron-collapse'; +import '@polymer/paper-icon-button'; - this.metadata = metadata; - this.label = '' + this.pointMetadata[this.labelOption]; +import {PointMetadata} from './data'; + +@customElement('vz-projector-metadata-card') +export class MetadataCard extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + + + `; + + @property({type: Boolean}) + hasMetadata: boolean = false; + + @property({type: Boolean}) + isCollapsed: boolean = false; + + @property({type: String}) + collapseIcon: string = 'expand-less'; + + @property({type: Array}) + metadata: Array<{ + key: string; + value: string; + }>; + + @property({type: String}) + label: string; + + private labelOption: string; + private pointMetadata: PointMetadata; + /** Handles toggle of metadata-container. */ + _toggleMetadataContainer() { + (this.$$('#metadata-container') as any).toggle(); + this.isCollapsed = !this.isCollapsed; + this.set('collapseIcon', this.isCollapsed ? 'expand-more' : 'expand-less'); + } + updateMetadata(pointMetadata?: PointMetadata) { + this.pointMetadata = pointMetadata; + this.hasMetadata = pointMetadata != null; + if (pointMetadata) { + let metadata = []; + for (let metadataKey in pointMetadata) { + if (!pointMetadata.hasOwnProperty(metadataKey)) { + continue; + } + metadata.push({key: metadataKey, value: pointMetadata[metadataKey]}); + } + this.metadata = metadata; + this.label = '' + this.pointMetadata[this.labelOption]; } } - - customElements.define(MetadataCard.prototype.is, MetadataCard); -} // namespace vz_projector + setLabelOption(labelOption: string) { + this.labelOption = labelOption; + if (this.pointMetadata) { + this.label = '' + this.pointMetadata[this.labelOption]; + } + } +} diff --git a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-projections-panel.html b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-projections-panel.html.ts similarity index 93% rename from tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-projections-panel.html rename to tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-projections-panel.html.ts index 97f14eed1b..189960e546 100644 --- a/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-projections-panel.html +++ b/tensorboard/plugins/projector/polymer3/vz_projector/vz-projector-projections-panel.html.ts @@ -1,6 +1,4 @@ - - - - - - - - - - - - - - - - - - - - - -