Skip to content

Commit

Permalink
Merge pull request #1984 from Sec-ant/refactor-bezier-function
Browse files Browse the repository at this point in the history
refactor Bezier function
  • Loading branch information
moklick committed Mar 23, 2022
2 parents 903b795 + 3708402 commit b0bec16
Show file tree
Hide file tree
Showing 4 changed files with 323 additions and 122 deletions.
107 changes: 70 additions & 37 deletions src/components/ConnectionLine/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,7 @@ import shallow from 'zustand/shallow';
import { useStore } from '../../store';
import { getBezierPath } from '../Edges/BezierEdge';
import { getSmoothStepPath } from '../Edges/SmoothStepEdge';
import {
HandleElement,
ConnectionLineType,
ConnectionLineComponent,
HandleType,
Node,
ReactFlowState,
Position,
} from '../../types';
import { ConnectionLineType, ConnectionLineComponent, HandleType, Node, ReactFlowState, Position } from '../../types';
import { getSimpleBezierPath } from '../Edges/SimpleBezierEdge';

interface ConnectionLineProps {
Expand All @@ -29,13 +21,6 @@ interface ConnectionLineProps {

const selector = (s: ReactFlowState) => ({ nodeInternals: s.nodeInternals, transform: s.transform });

const getSourceHandle = (handleId: string | null, sourceNode: Node, connectionHandleType: HandleType) => {
const handleTypeInverted = connectionHandleType === 'source' ? 'target' : 'source';
const handleBound = sourceNode.handleBounds?.[connectionHandleType] || sourceNode.handleBounds?.[handleTypeInverted];

return handleId ? handleBound?.find((d: HandleElement) => d.id === handleId) : handleBound?.[0];
};

export default ({
connectionNodeId,
connectionHandleId,
Expand All @@ -51,43 +36,92 @@ export default ({
const handleId = connectionHandleId;

const { nodeInternals, transform } = useStore(selector, shallow);
const sourceNode = useRef<Node | undefined>(nodeInternals.get(nodeId));
const fromNode = useRef<Node | undefined>(nodeInternals.get(nodeId));

if (
!sourceNode.current ||
!sourceNode.current ||
!fromNode.current ||
!fromNode.current ||
!isConnectable ||
!sourceNode.current.handleBounds?.[connectionHandleType]
!fromNode.current.handleBounds?.[connectionHandleType]
) {
return null;
}

const sourceHandle = getSourceHandle(handleId, sourceNode.current, connectionHandleType);
const sourceHandleX = sourceHandle ? sourceHandle.x + sourceHandle.width / 2 : (sourceNode.current?.width ?? 0) / 2;
const sourceHandleY = sourceHandle ? sourceHandle.y + sourceHandle.height / 2 : sourceNode.current?.height ?? 0;
const sourceX = (sourceNode.current.positionAbsolute?.x || 0) + sourceHandleX;
const sourceY = (sourceNode.current.positionAbsolute?.y || 0) + sourceHandleY;

const targetX = (connectionPositionX - transform[0]) / transform[2];
const targetY = (connectionPositionY - transform[1]) / transform[2];
const handleBound = fromNode.current.handleBounds?.[connectionHandleType];
const fromHandle = handleId ? handleBound?.find((d) => d.id === handleId) : handleBound?.[0];
const fromHandleX = fromHandle ? fromHandle.x + fromHandle.width / 2 : (fromNode.current?.width ?? 0) / 2;
const fromHandleY = fromHandle ? fromHandle.y + fromHandle.height / 2 : fromNode.current?.height ?? 0;
const fromX = (fromNode.current.positionAbsolute?.x || 0) + fromHandleX;
const fromY = (fromNode.current.positionAbsolute?.y || 0) + fromHandleY;

const toX = (connectionPositionX - transform[0]) / transform[2];
const toY = (connectionPositionY - transform[1]) / transform[2];

const fromPosition = fromHandle?.position;

let toPosition: Position | undefined;
switch (fromPosition) {
case Position.Left:
toPosition = Position.Right;
break;
case Position.Right:
toPosition = Position.Left;
break;
case Position.Top:
toPosition = Position.Bottom;
break;
case Position.Bottom:
toPosition = Position.Top;
break;
}

const isRightOrLeft = sourceHandle?.position === Position.Left || sourceHandle?.position === Position.Right;
const targetPosition = isRightOrLeft ? Position.Left : Position.Top;
let sourceX: number,
sourceY: number,
sourcePosition: Position | undefined,
targetX: number,
targetY: number,
targetPosition: Position | undefined;

switch (connectionHandleType) {
case 'source':
{
sourceX = fromX;
sourceY = fromY;
sourcePosition = fromPosition;
targetX = toX;
targetY = toY;
targetPosition = toPosition;
}
break;
case 'target':
{
sourceX = toX;
sourceY = toY;
sourcePosition = toPosition;
targetX = fromX;
targetY = fromY;
targetPosition = fromPosition;
}
break;
}

if (CustomConnectionLineComponent) {
return (
<g className="react-flow__connection">
<CustomConnectionLineComponent
sourceX={sourceX}
sourceY={sourceY}
sourcePosition={sourceHandle?.position}
sourcePosition={sourcePosition}
targetX={targetX}
targetY={targetY}
targetPosition={targetPosition}
connectionLineType={connectionLineType}
connectionLineStyle={connectionLineStyle}
sourceNode={sourceNode.current as Node}
sourceHandle={sourceHandle}
fromNode={fromNode.current}
fromHandle={fromHandle}
// backward compatibility, mark as deprecated?
sourceNode={fromNode.current}
sourceHandle={fromHandle}
/>
</g>
);
Expand All @@ -98,16 +132,15 @@ export default ({
const pathParams = {
sourceX,
sourceY,
sourcePosition: sourceHandle?.position,
sourcePosition,
targetX,
targetY,
targetPosition,
};

if (connectionLineType === ConnectionLineType.Bezier) {
// @TODO: we need another getBezier function, that handles a connection line.
// Since we don't know the target position, we can't use the default bezier function here.
dAttr = getBezierPath({ ...pathParams, curvature: 0 });
// we assume the destination position is opposite to the source position
dAttr = getBezierPath(pathParams);
} else if (connectionLineType === ConnectionLineType.Step) {
dAttr = getSmoothStepPath({
...pathParams,
Expand Down
184 changes: 106 additions & 78 deletions src/components/Edges/BezierEdge.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import React, { memo } from 'react';
import { EdgeProps, Position } from '../../types';
import BaseEdge from './BaseEdge';
import { getCenter } from './utils';

export interface GetBezierPathParams {
sourceX: number;
Expand All @@ -11,12 +10,56 @@ export interface GetBezierPathParams {
targetY: number;
targetPosition?: Position;
curvature?: number;
centerX?: number;
centerY?: number;
}

// @TODO: refactor getBezierPath function. It's too long and hard to understand.
// We should reuse the curvature handling for top/bottom and left/right.
interface GetControlWithCurvatureParams {
pos: Position;
x1: number;
y1: number;
x2: number;
y2: number;
c: number;
}

function calculateControlOffset(distance: number, curvature: number): number {
if (distance >= 0) {
return 0.5 * distance;
} else {
return curvature * 25 * Math.sqrt(-distance);
}
}

function getControlWithCurvature({ pos, x1, y1, x2, y2, c }: GetControlWithCurvatureParams): [number, number] {
let ctX: number, ctY: number;
switch (pos) {
case Position.Left:
{
ctX = x1 - calculateControlOffset(x1 - x2, c);
ctY = y1;
}
break;
case Position.Right:
{
ctX = x1 + calculateControlOffset(x2 - x1, c);
ctY = y1;
}
break;
case Position.Top:
{
ctX = x1;
ctY = y1 - calculateControlOffset(y1 - y2, c);
}
break;
case Position.Bottom:
{
ctX = x1;
ctY = y1 + calculateControlOffset(y2 - y1, c);
}
break;
}
return [ctX, ctY];
}

export function getBezierPath({
sourceX,
sourceY,
Expand All @@ -25,78 +68,62 @@ export function getBezierPath({
targetY,
targetPosition = Position.Top,
curvature = 0.25,
centerX,
centerY,
}: GetBezierPathParams): string {
const leftAndRight = [Position.Left, Position.Right];
const hasCurvature = curvature > 0;
const [_centerX, _centerY] = getCenter({ sourceX, sourceY, targetX, targetY });

if (leftAndRight.includes(sourcePosition) && leftAndRight.includes(targetPosition)) {
const cX = typeof centerX !== 'undefined' ? centerX : _centerX;
const distanceX = targetX - sourceX;
const absDistanceX = Math.abs(distanceX);
const amtX = (Math.sqrt(absDistanceX) / 2) * (50 * curvature);

let hx1 = cX;
let hx2 = cX;

if (hasCurvature) {
const sourceAndTargetRight = sourcePosition === Position.Right && targetPosition === Position.Right;
const sourceAndTargetLeft = sourcePosition === Position.Left && targetPosition === Position.Left;

hx1 = sourceX + amtX;
hx2 = targetX - amtX;

if (sourceAndTargetLeft) {
hx1 = sourceX - amtX;
} else if (sourceAndTargetRight) {
hx2 = targetX + amtX;
} else if (sourcePosition === Position.Left && targetX <= sourceX) {
hx1 = cX;
hx2 = cX;
} else if (sourcePosition === Position.Left && targetX > sourceX) {
hx1 = sourceX - amtX;
hx2 = targetX + amtX;
}
}

return `M${sourceX},${sourceY} C${hx1},${sourceY} ${hx2},${targetY}, ${targetX},${targetY}`;
} else if (leftAndRight.includes(targetPosition)) {
return `M${sourceX},${sourceY} Q${sourceX},${targetY} ${targetX},${targetY}`;
} else if (leftAndRight.includes(sourcePosition)) {
return `M${sourceX},${sourceY} Q${targetX},${sourceY} ${targetX},${targetY}`;
}

const cY = typeof centerY !== 'undefined' ? centerY : _centerY;
const distanceY = targetY - sourceY;
const absDistanceY = Math.abs(distanceY);
const amtY = (Math.sqrt(absDistanceY) / 2) * (50 * curvature);

let hy1 = cY;
let hy2 = cY;

if (hasCurvature) {
hy1 = sourceY + amtY;
hy2 = targetY - amtY;

const sourceAndTargetTop = sourcePosition === Position.Top && targetPosition === Position.Top;
const sourceAndTargetBottom = sourcePosition === Position.Bottom && targetPosition === Position.Bottom;

if (sourceAndTargetTop) {
hy1 = targetY - amtY;
} else if (sourceAndTargetBottom) {
hy2 = targetY + amtY;
} else if (sourcePosition === Position.Top && targetY <= sourceY) {
hy1 = cY;
hy2 = cY;
} else if (sourcePosition === Position.Top && targetY > sourceY) {
hy1 = sourceY - amtY;
hy2 = targetY + amtY;
}
}
const [sourceControlX, sourceControlY] = getControlWithCurvature({
pos: sourcePosition,
x1: sourceX,
y1: sourceY,
x2: targetX,
y2: targetY,
c: curvature,
});
const [targetControlX, targetControlY] = getControlWithCurvature({
pos: targetPosition,
x1: targetX,
y1: targetY,
x2: sourceX,
y2: sourceY,
c: curvature,
});
return `M${sourceX},${sourceY} C${sourceControlX},${sourceControlY} ${targetControlX},${targetControlY} ${targetX},${targetY}`;
}

return `M${sourceX},${sourceY} C${sourceX},${hy1} ${targetX},${hy2} ${targetX},${targetY}`;
// @TODO: this function will recalculate the control points
// one option is to let getXXXPath() return center points
// but will introduce breaking changes
// the getCenter() of other types of edges might need to change, too
export function getBezierCenter({
sourceX,
sourceY,
sourcePosition = Position.Bottom,
targetX,
targetY,
targetPosition = Position.Top,
curvature = 0.25,
}: GetBezierPathParams): [number, number, number, number] {
const [sourceControlX, sourceControlY] = getControlWithCurvature({
pos: sourcePosition,
x1: sourceX,
y1: sourceY,
x2: targetX,
y2: targetY,
c: curvature,
});
const [targetControlX, targetControlY] = getControlWithCurvature({
pos: targetPosition,
x1: targetX,
y1: targetY,
x2: sourceX,
y2: sourceY,
c: curvature,
});
// cubic bezier t=0.5 mid point, not the actual mid point, but easy to calculate
// https://stackoverflow.com/questions/67516101/how-to-find-distance-mid-point-of-bezier-curve
const centerX = sourceX * 0.125 + sourceControlX * 0.375 + targetControlX * 0.375 + targetX * 0.125;
const centerY = sourceY * 0.125 + sourceControlY * 0.375 + targetControlY * 0.375 + targetY * 0.125;
const xOffset = Math.abs(centerX - sourceX);
const yOffset = Math.abs(centerY - sourceY);
return [centerX, centerY, xOffset, yOffset];
}

export default memo(
Expand All @@ -118,16 +145,17 @@ export default memo(
markerStart,
curvature,
}: EdgeProps) => {
const [centerX, centerY] = getCenter({ sourceX, sourceY, targetX, targetY, sourcePosition, targetPosition });
const path = getBezierPath({
const params = {
sourceX,
sourceY,
sourcePosition,
targetX,
targetY,
targetPosition,
curvature,
});
};
const path = getBezierPath(params);
const [centerX, centerY] = getBezierCenter(params);

return (
<BaseEdge
Expand Down

0 comments on commit b0bec16

Please sign in to comment.