Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(scatterplot) support multiple legends #2016

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions packages/scatterplot/src/ScatterPlot.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { createElement, Fragment, ReactNode, useMemo } from 'react'
import { SvgWrapper, Container, useDimensions, CartesianMarkers } from '@nivo/core'
import { Axes, Grid } from '@nivo/axes'
import { BoxLegendSvg } from '@nivo/legends'
import { useScatterPlot } from './hooks'
import { svgDefaultProps } from './props'
import { ScatterPlotAnnotations } from './ScatterPlotAnnotations'
import { Nodes } from './Nodes'
import { Mesh } from './Mesh'
import { ScatterPlotDatum, ScatterPlotLayerId, ScatterPlotSvgProps } from './types'
import { ScatterPlotLegends } from './ScatterPlotLegends'

type InnerScatterPlotProps<RawDatum extends ScatterPlotDatum> = Omit<
ScatterPlotSvgProps<RawDatum>,
Expand Down Expand Up @@ -39,6 +39,7 @@ const InnerScatterPlot = <RawDatum extends ScatterPlotDatum>({
axisLeft = svgDefaultProps.axisLeft,
annotations = svgDefaultProps.annotations,
isInteractive = svgDefaultProps.isInteractive,
initialHiddenIds = [],
useMesh = svgDefaultProps.useMesh,
debugMesh = svgDefaultProps.debugMesh,
onMouseEnter,
Expand All @@ -48,6 +49,7 @@ const InnerScatterPlot = <RawDatum extends ScatterPlotDatum>({
tooltip = svgDefaultProps.tooltip,
markers = svgDefaultProps.markers,
legends = svgDefaultProps.legends,
legendLabel,
role = svgDefaultProps.role,
ariaLabel,
ariaLabelledBy,
Expand All @@ -59,7 +61,7 @@ const InnerScatterPlot = <RawDatum extends ScatterPlotDatum>({
partialMargin
)

const { xScale, yScale, nodes, legendData } = useScatterPlot<RawDatum>({
const { xScale, yScale, nodes, legendsData, toggleSerie } = useScatterPlot<RawDatum>({
data,
xScaleSpec,
xFormat,
Expand All @@ -69,7 +71,10 @@ const InnerScatterPlot = <RawDatum extends ScatterPlotDatum>({
height: innerHeight,
nodeId,
nodeSize,
initialHiddenIds,
colors,
legends,
legendLabel,
})

const customLayerProps = useMemo(
Expand Down Expand Up @@ -184,15 +189,15 @@ const InnerScatterPlot = <RawDatum extends ScatterPlotDatum>({
}

if (layers.includes('legends')) {
layerById.legends = legends.map((legend, i) => (
<BoxLegendSvg
key={i}
{...legend}
containerWidth={innerWidth}
containerHeight={innerHeight}
data={legendData}
layerById.legends = (
<ScatterPlotLegends
key="legends"
width={innerWidth}
height={innerHeight}
legends={legendsData}
toggleSerie={toggleSerie}
/>
))
)
}

return (
Expand Down
30 changes: 19 additions & 11 deletions packages/scatterplot/src/ScatterPlotCanvas.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const InnerScatterPlotCanvas = <RawDatum extends ScatterPlotDatum>({
onClick,
tooltip = canvasDefaultProps.tooltip,
legends = canvasDefaultProps.legends,
legendLabel,
canvasRef,
}: InnerScatterPlotCanvasProps<RawDatum>) => {
const canvasEl = useRef<HTMLCanvasElement | null>(null)
Expand All @@ -69,7 +70,7 @@ const InnerScatterPlotCanvas = <RawDatum extends ScatterPlotDatum>({
partialMargin
)

const { xScale, yScale, nodes, legendData } = useScatterPlot<RawDatum>({
const { xScale, yScale, nodes, legendsData } = useScatterPlot<RawDatum>({
data,
xScaleSpec,
xFormat,
Expand All @@ -79,7 +80,10 @@ const InnerScatterPlotCanvas = <RawDatum extends ScatterPlotDatum>({
height: innerHeight,
nodeId,
nodeSize,
initialHiddenIds: Array<string>(),
colors,
legends,
legendLabel,
})

const boundAnnotations = useScatterPlotAnnotations<RawDatum>(nodes, annotations)
Expand Down Expand Up @@ -111,7 +115,8 @@ const InnerScatterPlotCanvas = <RawDatum extends ScatterPlotDatum>({
canvasEl.current.width = outerWidth * pixelRatio
canvasEl.current.height = outerHeight * pixelRatio

const ctx = canvasEl.current.getContext('2d')!
const ctx = canvasEl.current.getContext('2d')
if (!ctx) return

ctx.scale(pixelRatio, pixelRatio)

Expand Down Expand Up @@ -163,22 +168,22 @@ const InnerScatterPlotCanvas = <RawDatum extends ScatterPlotDatum>({
renderNode(ctx, node)
})
} else if (layer === 'mesh') {
if (debugMesh) {
renderVoronoiToCanvas(ctx, voronoi!)
if (debugMesh && voronoi) {
renderVoronoiToCanvas(ctx, voronoi)
if (currentNode) {
renderVoronoiCellToCanvas(ctx, voronoi!, currentNode.index)
renderVoronoiCellToCanvas(ctx, voronoi, currentNode.index)
}
}
} else if (layer === 'legends') {
legends.forEach(legend => {
legendsData.forEach(([legend, data]) =>
renderLegendToCanvas(ctx, {
...legend,
data: legendData,
data: legend.data ?? data,
containerWidth: innerWidth,
containerHeight: innerHeight,
theme,
})
})
)
} else if (typeof layer === 'function') {
layer(ctx, customLayerProps)
} else {
Expand All @@ -203,12 +208,14 @@ const InnerScatterPlotCanvas = <RawDatum extends ScatterPlotDatum>({
nodes,
enableGridX,
enableGridY,
gridXValues,
gridYValues,
axisTop,
axisRight,
axisBottom,
axisLeft,
legends,
legendData,
legendsData,
debugMesh,
voronoi,
currentNode,
Expand All @@ -219,13 +226,14 @@ const InnerScatterPlotCanvas = <RawDatum extends ScatterPlotDatum>({

const getNodeFromMouseEvent = useCallback(
event => {
const [x, y] = getRelativeCursor(canvasEl.current!, event)
if (!canvasEl.current) return null
const [x, y] = getRelativeCursor(canvasEl.current, event)
if (!isCursorInRect(margin.left, margin.top, innerWidth, innerHeight, x, y)) return null

const nodeIndex = delaunay.find(x - margin.left, y - margin.top)
return nodes[nodeIndex]
},
[canvasEl, margin, innerWidth, innerHeight, delaunay]
[nodes, canvasEl, margin, innerWidth, innerHeight, delaunay]
)

const handleMouseHover = useCallback(
Expand Down
22 changes: 22 additions & 0 deletions packages/scatterplot/src/ScatterPlotLegends.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { BoxLegendSvg } from '@nivo/legends'
import { ScatterPlotLegendsProps } from './types'

export const ScatterPlotLegends = ({
width,
height,
legends,
toggleSerie,
}: ScatterPlotLegendsProps) => (
<>
{legends.map(([legend, data], i) => (
<BoxLegendSvg
key={i}
{...legend}
containerWidth={width}
containerHeight={height}
data={data}
toggleSerie={legend.toggleSerie && toggleSerie}
/>
))}
</>
)
99 changes: 71 additions & 28 deletions packages/scatterplot/src/compute.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ import isString from 'lodash/isString'
import isNumber from 'lodash/isNumber'
import isPlainObject from 'lodash/isPlainObject'
import { scaleLinear } from 'd3-scale'
import { ComputedSerie } from '@nivo/scales'
import { computeXYScalesForSeries, ScaleSpec } from '@nivo/scales'
import {
ScatterPlotCommonProps,
ScatterPlotDataProps,
ScatterPlotDatum,
ScatterPlotNodeData,
ScatterPlotNodeDynamicSizeSpec,
ScatterPlotRawNodeData,
} from './types'
import { OrdinalColorScale } from '@nivo/colors'

const isDynamicSizeSpec = <RawDatum extends ScatterPlotDatum>(
size: ScatterPlotCommonProps<RawDatum>['nodeSize']
Expand Down Expand Up @@ -48,40 +51,80 @@ export const getNodeSizeGenerator = <RawDatum extends ScatterPlotDatum>(
throw new Error('nodeSize is invalid, it should be either a function, a number or an object')
}

export const computePoints = <RawDatum extends ScatterPlotDatum>({
series,
export const computeRawSeriesPoints = <RawDatum extends ScatterPlotDatum>({
data,
xScaleSpec,
yScaleSpec,
width,
height,
formatX,
formatY,
getNodeId,
}: {
series: ComputedSerie<{ id: string | number }, RawDatum>[]
data: ScatterPlotDataProps<RawDatum>['data']
xScaleSpec: ScaleSpec
yScaleSpec: ScaleSpec
width: number
height: number
formatX: (value: RawDatum['x']) => string | number
formatY: (value: RawDatum['x']) => string | number
getNodeId: (d: Omit<ScatterPlotNodeData<RawDatum>, 'id' | 'size' | 'color'>) => string
}): Omit<ScatterPlotNodeData<RawDatum>, 'size' | 'color'>[] => {
const points: Omit<ScatterPlotNodeData<RawDatum>, 'size' | 'color'>[] = []

series.forEach(serie => {
serie.data.forEach((d, serieIndex) => {
const point: Omit<ScatterPlotNodeData<RawDatum>, 'id' | 'size' | 'color'> = {
index: points.length,
serieIndex,
serieId: serie.id,
x: d.position.x as number,
xValue: d.data.x,
formattedX: formatX(d.data.x),
y: d.position.y as number,
yValue: d.data.y,
formattedY: formatY(d.data.y),
data: d.data,
}

points.push({
...point,
id: getNodeId(point),
getNodeId: (d: Omit<ScatterPlotRawNodeData<RawDatum>, 'id'>) => string
}) => {
const { series, xScale, yScale } = computeXYScalesForSeries<{ id: string | number }, RawDatum>(
data,
xScaleSpec,
yScaleSpec,
width,
height
)
let offset = 0 // allows giving each data point a unique index
const rawSeriesNodes: ScatterPlotRawNodeData<RawDatum>[][] = series
.filter(serie => serie.data.length > 0)
.map((serie, serieIndex) => {
const points = serie.data.map((d, i) => {
const point: Omit<ScatterPlotRawNodeData<RawDatum>, 'id'> = {
index: offset + i,
serieIndex,
serieId: serie.id,
x: d.position.x as number,
xValue: d.data.x,
formattedX: formatX(d.data.x),
y: d.position.y as number,
yValue: d.data.y,
formattedY: formatY(d.data.y),
data: d.data,
}
return {
...point,
id: getNodeId(point),
}
})
offset = offset + points.length
return points
})
})

return points
return { rawSeriesNodes, xScale, yScale }
}

export const computeStyledPoints = <RawDatum extends ScatterPlotDatum>({
rawSeriesNodes,
hiddenIds,
getColor,
getNodeSize,
}: {
rawSeriesNodes: ScatterPlotRawNodeData<RawDatum>[][]
hiddenIds: string[]
getColor: OrdinalColorScale<{ serieId: string | number }>
getNodeSize: (datum: ScatterPlotRawNodeData<RawDatum>) => number
}): ScatterPlotNodeData<RawDatum>[][] => {
return rawSeriesNodes
.filter(rawNodes => !hiddenIds.includes(String(rawNodes[0].serieId)))
.map(rawNodes => {
const color = getColor({ serieId: rawNodes[0].serieId })
return rawNodes.map(rawNode => ({
...rawNode,
color,
size: getNodeSize(rawNode),
}))
})
}
Loading