diff --git a/src/strands/strands_transpiler.js b/src/strands/strands_transpiler.js index f6632d7429..b30c3380e2 100644 --- a/src/strands/strands_transpiler.js +++ b/src/strands/strands_transpiler.js @@ -40,6 +40,50 @@ function nodeIsUniform(ancestor) { ); } +function nodeIsUniformCallbackFn(node, names) { + if (!names?.size) return false; + if (node.type === 'FunctionDeclaration' && names.has(node.id?.name)) return true; + if ( + node.type === 'VariableDeclarator' && names.has(node.id?.name) && + (node.init?.type === 'FunctionExpression' || node.init?.type === 'ArrowFunctionExpression') + ) { + return true; + } + return false; +} + +function collectUniformCallbackNames(ast) { + // Sub-pass 1: collect all named function definitions + const namedFunctions = new Set(); + ancestor(ast, { + FunctionDeclaration(node) { + if (node.id) namedFunctions.add(node.id.name); + }, + VariableDeclarator(node) { + if ( + node.id?.type === 'Identifier' && + (node.init?.type === 'FunctionExpression' || node.init?.type === 'ArrowFunctionExpression') + ) { + namedFunctions.add(node.id.name); + } + } + }); + // Sub-pass 2: find which of those names are passed as uniform call arguments + const names = new Set(); + ancestor(ast, { + CallExpression(node) { + if (nodeIsUniform(node)) { + for (const arg of node.arguments) { + if (arg.type === 'Identifier' && namedFunctions.has(arg.name)) { + names.add(arg.name); + } + } + } + } + }); + return names; +} + function nodeIsVarying(node) { return node && node.type === 'CallExpression' && ( @@ -192,8 +236,10 @@ function replaceReferences(node, tempVarMap) { } const ASTCallbacks = { - UnaryExpression(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + UnaryExpression(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } const unaryFnName = UnarySymbolToName[node.operator]; const standardReplacement = (node) => { node.type = 'CallExpression' @@ -236,8 +282,10 @@ const ASTCallbacks = { delete node.argument; delete node.operator; }, - BreakStatement(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + BreakStatement(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } node.callee = { type: 'Identifier', name: '__p5.break' @@ -245,8 +293,10 @@ const ASTCallbacks = { node.arguments = []; node.type = 'CallExpression'; }, - MemberExpression(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + MemberExpression(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } // Skip sets -- these will be converted to .set() method // calls at the AssignmentExpression level if ( @@ -272,8 +322,10 @@ const ASTCallbacks = { node.type = 'CallExpression'; } }, - VariableDeclarator(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + VariableDeclarator(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } if (nodeIsUniform(node.init)) { // Only inject the variable name if the first argument isn't already a string if (node.init.arguments.length === 0 || @@ -298,16 +350,18 @@ const ASTCallbacks = { value: node.id.name } node.init.arguments.unshift(varyingNameLiteral); - _state.varyings[node.id.name] = varyingNameLiteral; + state.varyings[node.id.name] = varyingNameLiteral; } else { // Still track it as a varying even if name wasn't injected - _state.varyings[node.id.name] = node.init.arguments[0]; + state.varyings[node.id.name] = node.init.arguments[0]; } } }, - Identifier(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } - if (_state.varyings[node.name] + Identifier(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } + if (state.varyings[node.name] && !ancestors.some(a => a.type === 'AssignmentExpression' && a.left === node) ) { node.type = 'CallExpression'; @@ -327,8 +381,10 @@ const ASTCallbacks = { }, // The callbacks for AssignmentExpression and BinaryExpression handle // operator overloading including +=, *= assignment expressions - ArrayExpression(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + ArrayExpression(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } const original = JSON.parse(JSON.stringify(node)); node.type = 'CallExpression'; node.callee = { @@ -337,8 +393,10 @@ const ASTCallbacks = { }; node.arguments = [original]; }, - AssignmentExpression(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + AssignmentExpression(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } const unsafeTypes = ['Literal', 'ArrayExpression', 'Identifier']; if (node.operator !== '=') { const methodName = replaceBinaryOperator(node.operator.replace('=','')); @@ -367,7 +425,7 @@ const ASTCallbacks = { node.right = rightReplacementNode; } // Handle direct varying variable assignment: myVarying = value - if (_state.varyings[node.left.name]) { + if (state.varyings[node.left.name]) { node.type = 'ExpressionStatement'; node.expression = { type: 'CallExpression', @@ -412,7 +470,7 @@ const ASTCallbacks = { let varyingName = null; // Check if it's a direct identifier: myVarying.xyz - if (node.left.object.type === 'Identifier' && _state.varyings[node.left.object.name]) { + if (node.left.object.type === 'Identifier' && state.varyings[node.left.object.name]) { varyingName = node.left.object.name; } // Check if it's a getValue() call: myVarying.getValue().xyz @@ -420,7 +478,7 @@ const ASTCallbacks = { node.left.object.callee?.type === 'MemberExpression' && node.left.object.callee.property?.name === 'getValue' && node.left.object.callee.object?.type === 'Identifier' && - _state.varyings[node.left.object.callee.object.name]) { + state.varyings[node.left.object.callee.object.name]) { varyingName = node.left.object.callee.object.name; } @@ -451,10 +509,12 @@ const ASTCallbacks = { } } }, - BinaryExpression(node, _state, ancestors) { + BinaryExpression(node, state, ancestors) { // Don't convert uniform default values to node methods, as // they should be evaluated at runtime, not compiled. - if (ancestors.some(nodeIsUniform)) { return; } + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } // If the left hand side of an expression is one of these types, // we should construct a node from it. const unsafeTypes = ['Literal', 'ArrayExpression', 'Identifier']; @@ -482,10 +542,12 @@ const ASTCallbacks = { }; node.arguments = [node.right]; }, - LogicalExpression(node, _state, ancestors) { + LogicalExpression(node, state, ancestors) { // Don't convert uniform default values to node methods, as // they should be evaluated at runtime, not compiled. - if (ancestors.some(nodeIsUniform)) { return; } + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } // If the left hand side of an expression is one of these types, // we should construct a node from it. const unsafeTypes = ['Literal', 'ArrayExpression', 'Identifier']; @@ -513,8 +575,10 @@ const ASTCallbacks = { }; node.arguments = [node.right]; }, - ConditionalExpression(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + ConditionalExpression(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } // Transform condition ? consequent : alternate // into __p5.strandsTernary(condition, consequent, alternate) const test = node.test; @@ -527,8 +591,10 @@ const ASTCallbacks = { delete node.consequent; delete node.alternate; }, - IfStatement(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + IfStatement(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } // Transform if statement into strandsIf() call // The condition is evaluated directly, not wrapped in a function const condition = node.test; @@ -796,8 +862,10 @@ const ASTCallbacks = { delete node.consequent; delete node.alternate; }, - UpdateExpression(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + UpdateExpression(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } // Transform ++var, var++, --var, var-- into assignment expressions let operator; @@ -828,11 +896,13 @@ const ASTCallbacks = { // Replace the update expression with the assignment expression Object.assign(node, assignmentExpr); delete node.prefix; - this.BinaryExpression(node.right, _state, [...ancestors, node]); - this.AssignmentExpression(node, _state, ancestors); + this.BinaryExpression(node.right, state, [...ancestors, node]); + this.AssignmentExpression(node, state, ancestors); }, - ForStatement(node, _state, ancestors) { - if (ancestors.some(nodeIsUniform)) { return; } + ForStatement(node, state, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) { + return; + } // Transform for statement into strandsFor() call // for (init; test; update) body -> strandsFor(initCb, conditionCb, updateCb, bodyCb, initialVars) @@ -1538,22 +1608,31 @@ function transformFunctionSetCalls(functionNode) { } // Main transformation pass: find and transform functions with .set() calls in control flow -function transformSetCallsInControlFlow(ast) { +function transformSetCallsInControlFlow(ast, names) { const functionsToTransform = []; // Collect functions that have .set() calls in control flow const collectFunctions = { ArrowFunctionExpression(node, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) { + return; + } if (functionHasSetInControlFlow(node)) { functionsToTransform.push(node); } }, FunctionExpression(node, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) { + return; + } if (functionHasSetInControlFlow(node)) { functionsToTransform.push(node); } }, FunctionDeclaration(node, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) { + return; + } if (functionHasSetInControlFlow(node)) { functionsToTransform.push(node); } @@ -1569,12 +1648,15 @@ function transformSetCallsInControlFlow(ast) { } // Main transformation pass: find and transform helper functions with early returns -function transformHelperFunctionEarlyReturns(ast) { +function transformHelperFunctionEarlyReturns(ast, names) { const helperFunctionsToTransform = []; // Collect helper functions that need transformation const collectHelperFunctions = { VariableDeclarator(node, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) { + return; + } const init = node.init; if (init && (init.type === 'ArrowFunctionExpression' || init.type === 'FunctionExpression')) { if (functionHasEarlyReturns(init)) { @@ -1583,6 +1665,9 @@ function transformHelperFunctionEarlyReturns(ast) { } }, FunctionDeclaration(node, ancestors) { + if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) { + return; + } if (functionHasEarlyReturns(node)) { helperFunctionsToTransform.push(node); } @@ -1612,20 +1697,41 @@ export function transpileStrandsToJS(p5, sourceString, srcLocations, scope) { locations: srcLocations }); + // Pre-pass: collect names of functions passed by reference as uniform callbacks + const uniformCallbackNames = collectUniformCallbackNames(ast); + // First pass: transform .set() calls in control flow to use intermediate variables - transformSetCallsInControlFlow(ast); + transformSetCallsInControlFlow(ast, uniformCallbackNames); // Second pass: transform everything except if/for statements using normal ancestor traversal const nonControlFlowCallbacks = { ...ASTCallbacks }; delete nonControlFlowCallbacks.IfStatement; delete nonControlFlowCallbacks.ForStatement; - ancestor(ast, nonControlFlowCallbacks, undefined, { varyings: {} }); + ancestor(ast, nonControlFlowCallbacks, undefined, { varyings: {}, uniformCallbackNames }); // Third pass: transform helper functions with early returns to use __returnValue pattern - transformHelperFunctionEarlyReturns(ast); + transformHelperFunctionEarlyReturns(ast, uniformCallbackNames); // Fourth pass: transform if/for statements in post-order using recursive traversal const postOrderControlFlowTransform = { + CallExpression(node, state, c) { + if (nodeIsUniform(node)) { return; } + if (node.callee) c(node.callee, state); + for (const arg of node.arguments) c(arg, state); + }, + FunctionDeclaration(node, state, c) { + if (state.uniformCallbackNames?.has(node.id?.name)) return; + if (node.body) c(node.body, state); + }, + VariableDeclarator(node, state, c) { + if ( + state.uniformCallbackNames?.has(node.id?.name) && + (node.init?.type === 'FunctionExpression' || node.init?.type === 'ArrowFunctionExpression') + ) { + return; + } + if (node.init) c(node.init, state); + }, IfStatement(node, state, c) { state.inControlFlow++; // First recursively process children @@ -1662,7 +1768,7 @@ export function transpileStrandsToJS(p5, sourceString, srcLocations, scope) { delete node.argument; } }; - recursive(ast, { varyings: {}, inControlFlow: 0 }, postOrderControlFlowTransform); + recursive(ast, { varyings: {}, inControlFlow: 0, uniformCallbackNames }, postOrderControlFlowTransform); const transpiledSource = escodegen.generate(ast); const scopeKeys = Object.keys(scope); const match = /\(?\s*(?:function)?\s*\w*\s*\(([^)]*)\)\s*(?:=>)?\s*{((?:.|\n)*)}\s*;?\s*\)?/ diff --git a/test/unit/webgl/p5.Shader.js b/test/unit/webgl/p5.Shader.js index bdfb0e5476..24c23fe548 100644 --- a/test/unit/webgl/p5.Shader.js +++ b/test/unit/webgl/p5.Shader.js @@ -2184,6 +2184,168 @@ test('returns numbers for builtin globals outside hooks and a strandNode when ca assert.approximately(pixelColor[2], 0, 5); }); + test('handle uniformFloat with control flow in callback', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + // Uniform callback with an if-statement and multiple return paths + const pastOneSecond = myp5.uniformFloat(() => { + if (myp5.frameCount > 1000) { + return 1; + } + return 0; + }); + + myp5.filterColor.begin(); + myp5.filterColor.set(myp5.mix([1, 0, 0, 1], [0, 1, 0, 1], pastOneSecond)); + myp5.filterColor.end(); + }, { myp5 }); + + myp5.background(255, 255, 255); + myp5.filter(testShader); + + // frameCount <= 1000 so pastOneSecond = 0, mix returns [1, 0, 0, 1] = red + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 255, 5); + assert.approximately(pixelColor[1], 0, 5); + assert.approximately(pixelColor[2], 0, 5); + }); + + test('handle uniformFloat with for loop in callback', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + // Uniform callback with a for loop accumulating a value + const brightness = myp5.uniformFloat(() => { + let sum = 0; + for (let i = 0; i < 3; i++) { + sum += i; + } + return sum / 10; // 0+1+2=3, 3/10=0.3 + }); + + myp5.filterColor.begin(); + myp5.filterColor.set([brightness, 0, 0, 1]); + myp5.filterColor.end(); + }, { myp5 }); + + myp5.background(255, 255, 255); + myp5.filter(testShader); + + // brightness = 0.3, so red channel = 0.3 * 255 ≈ 76 + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 0.3 * 255, 5); + assert.approximately(pixelColor[1], 0, 5); + assert.approximately(pixelColor[2], 0, 5); + }); + + test('handle uniformFloat with sub-function call in callback', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + // Uniform callback that calls a sub-function + const brightness = myp5.uniformFloat(() => { + const getValue = () => 0.6; + return getValue(); + }); + + myp5.filterColor.begin(); + myp5.filterColor.set([brightness, 0, 0, 1]); + myp5.filterColor.end(); + }, { myp5 }); + + myp5.background(255, 255, 255); + myp5.filter(testShader); + + // brightness = 0.6, so red channel = 0.6 * 255 ≈ 153 + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 0.6 * 255, 5); + assert.approximately(pixelColor[1], 0, 5); + assert.approximately(pixelColor[2], 0, 5); + }); + + test('handle uniformFloat with control flow in non-inline callback', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + function pastOneSecondValue() { + if (myp5.frameCount > 1000) { + return 1; + } + return 0; + } + const pastOneSecond = myp5.uniformFloat(pastOneSecondValue); + + myp5.filterColor.begin(); + myp5.filterColor.set(myp5.mix([1, 0, 0, 1], [0, 1, 0, 1], pastOneSecond)); + myp5.filterColor.end(); + }, { myp5 }); + + myp5.background(255, 255, 255); + myp5.filter(testShader); + + // frameCount <= 1000 so pastOneSecond = 0, mix returns [1, 0, 0, 1] = red + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 255, 5); + assert.approximately(pixelColor[1], 0, 5); + assert.approximately(pixelColor[2], 0, 5); + }); + + test('handle uniformFloat with for loop in non-inline callback', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + function brightnessValue() { + let sum = 0; + for (let i = 0; i < 3; i++) { + sum += i; + } + return sum / 10; // 0+1+2=3, 3/10=0.3 + } + const brightness = myp5.uniformFloat(brightnessValue); + + myp5.filterColor.begin(); + myp5.filterColor.set([brightness, 0, 0, 1]); + myp5.filterColor.end(); + }, { myp5 }); + + myp5.background(255, 255, 255); + myp5.filter(testShader); + + // brightness = 0.3, so red channel = 0.3 * 255 ≈ 76 + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 0.3 * 255, 5); + assert.approximately(pixelColor[1], 0, 5); + assert.approximately(pixelColor[2], 0, 5); + }); + + test('handle uniformFloat with sub-function call in non-inline callback', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + function getValue() { + return 0.6; + } + function brightnessValue() { + return getValue(); + } + const brightness = myp5.uniformFloat(brightnessValue); + + myp5.filterColor.begin(); + myp5.filterColor.set([brightness, 0, 0, 1]); + myp5.filterColor.end(); + }, { myp5 }); + + myp5.background(255, 255, 255); + myp5.filter(testShader); + + // brightness = 0.6, so red channel = 0.6 * 255 ≈ 153 + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 0.6 * 255, 5); + assert.approximately(pixelColor[1], 0, 5); + assert.approximately(pixelColor[2], 0, 5); + }); + test('handle false .set() in if with content afterwards with flat API', () => { myp5.createCanvas(50, 50, myp5.WEBGL);