diff --git a/index.html b/index.html index f76ebb9..5d14a16 100644 --- a/index.html +++ b/index.html @@ -10,7 +10,6 @@ body, html { margin: 0; - color: white; } diff --git a/src/clone-node.ts b/src/clone-node.ts index 0a28c0f..f8ae30c 100644 --- a/src/clone-node.ts +++ b/src/clone-node.ts @@ -84,10 +84,6 @@ export async function cloneNode( && (isHTMLElementNode(node) || isSVGElementNode(node))) { const computedStyle = ownerWindow.getComputedStyle(node) - if (computedStyle.display === 'none') { - return ownerDocument.createComment(node.tagName.toLowerCase()) - } - const cloned = await cloneElement(node, context) const clonedStyle = cloned.style diff --git a/src/copy-css-styles.ts b/src/copy-css-styles.ts index cdfd1d0..553949e 100644 --- a/src/copy-css-styles.ts +++ b/src/copy-css-styles.ts @@ -17,8 +17,8 @@ export function copyCssStyles( context: Context, ) { const clonedStyle = cloned.style - const defaultStyle = getDefaultStyle(node.nodeName, null, context) - const diffStyle = getDiffStyle(computedStyle, defaultStyle, node) + const defaultStyle = getDefaultStyle(node, null, context) + const diffStyle = getDiffStyle(computedStyle, defaultStyle) for (const [name, [value, priority]] of Object.entries(diffStyle)) { if (ignoredStyle.includes(name)) continue diff --git a/src/copy-pseudo-class.ts b/src/copy-pseudo-class.ts index 7dff682..e34d8d2 100644 --- a/src/copy-pseudo-class.ts +++ b/src/copy-pseudo-class.ts @@ -43,7 +43,7 @@ export function copyPseudoClass( if (!content || content === 'none') return const klasses = [uuid()] - const defaultStyle = getDefaultStyle(node.nodeName, pseudoClass, context) + const defaultStyle = getDefaultStyle(node, pseudoClass, context) const cloneStyle = [ `content: '${ content.replace(/'|"/g, '') }';`, ] diff --git a/src/get-default-style.ts b/src/get-default-style.ts index 4790abe..76be746 100644 --- a/src/get-default-style.ts +++ b/src/get-default-style.ts @@ -1,11 +1,42 @@ -import { uuid } from './utils' +import { isSVGElementNode, uuid } from './utils' import type { Context } from './context' -export function getDefaultStyle(nodeName: string, pseudoElement: string | null, context: Context) { - nodeName = nodeName.toLowerCase() +const ignoredStyles = [ + 'width', + 'height', +] + +const includedAttributes = [ + 'stroke', + 'fill', +] + +export function getDefaultStyle( + node: HTMLElement | SVGElement, + pseudoElement: string | null, + context: Context, +) { const { defaultComputedStyles, ownerDocument } = context - const key = `${ nodeName }${ pseudoElement ?? '' }` + + const nodeName = node.nodeName.toLowerCase() + const isSvgNode = isSVGElementNode(node) && nodeName !== 'svg' + const attributes = isSvgNode + ? includedAttributes + .map(name => [name, node.getAttribute(name)]) + .filter(([, value]) => value !== null) + : [] + + const key = [ + isSvgNode && 'svg', + nodeName, + attributes.map((name, value) => `${ name }=${ value }`).join(','), + pseudoElement, + ] + .filter(Boolean) + .join(':') + if (defaultComputedStyles.has(key)) return defaultComputedStyles.get(key)! + let sandbox = context.sandbox if (!sandbox) { if (ownerDocument) { @@ -21,24 +52,35 @@ export function getDefaultStyle(nodeName: string, pseudoElement: string | null, } } if (!sandbox) return {} + const sandboxWindow = sandbox.contentWindow if (!sandboxWindow) return {} const sandboxDocument = sandboxWindow.document - const el = sandboxDocument.createElement(nodeName) - sandboxDocument.body.appendChild(el) - // Ensure that there is some content, so properties like margin are applied + + let root: HTMLElement | SVGSVGElement + let el: Element + if (isSvgNode) { + root = sandboxDocument.createElementNS('http://www.w3.org/2000/svg', 'svg') + el = root.ownerDocument.createElementNS(root.namespaceURI, nodeName) + attributes.forEach(([name, value]) => { + el.setAttributeNS(null, name!, value!) + }) + root.appendChild(el) + } else { + root = el = sandboxDocument.createElement(nodeName) + } el.textContent = ' ' - const style = sandboxWindow.getComputedStyle(el, pseudoElement) + sandboxDocument.body.appendChild(root) + const computedStyle = sandboxWindow.getComputedStyle(el, pseudoElement) const styles: Record = {} - for (let i = style.length - 1; i >= 0; i--) { - const name = style.item(i) - if (name === 'width' || name === 'height') { - styles[name] = 'auto' - } else { - styles[name] = style.getPropertyValue(name) - } + for (let len = computedStyle.length, i = 0; i < len; i++) { + const name = computedStyle.item(i) + if (ignoredStyles.includes(name)) continue + styles[name] = computedStyle.getPropertyValue(name) } - sandboxDocument.body.removeChild(el) + sandboxDocument.body.removeChild(root) + defaultComputedStyles.set(key, styles) + return styles } diff --git a/src/get-diff-style.ts b/src/get-diff-style.ts index bacdc70..0e0e6b3 100644 --- a/src/get-diff-style.ts +++ b/src/get-diff-style.ts @@ -6,7 +6,6 @@ const getPrefix = (name: string) => name export function getDiffStyle( style: CSSStyleDeclaration, defaultStyle: Record, - node?: HTMLElement | SVGElement, ) { const diffStyle: Record = {} const diffStylePrefixs: string[] = [] @@ -23,11 +22,7 @@ export function getDiffStyle( prefixTree[prefix][name] = [value, priority] } - if ( - defaultStyle[name] === value - && !priority - && (node && !node.getAttribute(name)) - ) continue + if (defaultStyle[name] === value && !priority) continue if (prefix) { diffStylePrefixs.push(prefix) diff --git a/test/fixtures/svg.color.html b/test/fixtures/svg.color.html index 1c5ea4d..6ae7994 100644 --- a/test/fixtures/svg.color.html +++ b/test/fixtures/svg.color.html @@ -25,9 +25,9 @@ height="120" requiredExtensions="http://www.w3.org/1999/xhtml" > - + diff --git a/test/fixtures/svg.symbol.html b/test/fixtures/svg.symbol.html new file mode 100644 index 0000000..de27480 --- /dev/null +++ b/test/fixtures/svg.symbol.html @@ -0,0 +1,20 @@ + diff --git a/test/fixtures/svg.symbol.png b/test/fixtures/svg.symbol.png new file mode 100644 index 0000000..ed4a8f0 Binary files /dev/null and b/test/fixtures/svg.symbol.png differ