diff --git a/src/math/webgl/shader_compiler.ts b/src/math/webgl/shader_compiler.ts index ef4f772bd6..89e2eae914 100644 --- a/src/math/webgl/shader_compiler.ts +++ b/src/math/webgl/shader_compiler.ts @@ -37,7 +37,8 @@ export function makeShader(inputsInfo: InputInfo[], outputShape: ShapeInfo, getOutputSamplingSnippet(outputShape.logicalShape, outTexShape); const source = [ SHADER_PREFIX, inputPrefixSnippet, SAMPLE_1D_SNIPPET, SAMPLE_2D_SNIPPET, - SAMPLE_3D_SNIPPET, inputSamplingSnippet, outputSamplingSnippet, userCode + SAMPLE_3D_SNIPPET, SAMPLE_4D_SNIPPET, inputSamplingSnippet, + outputSamplingSnippet, userCode ].join('\n'); return source; } @@ -63,6 +64,10 @@ function getInputSamplingSnippet( res += getSampler3D( inInfo.name, shape as [number, number, number], texShape); break; + case 4: + res += getSampler4D( + inInfo.name, shape as [number, number, number, number], texShape); + break; default: throw new Error( `${shape.length}-D input sampling` + @@ -93,6 +98,9 @@ function getOutputSamplingSnippet( case 3: return getOutput3DCoords(outShape as [number, number, number], outTexShape); + case 4: + return getOutput4DCoords(outShape as [number, number, number, number], + outTexShape); default: throw new Error( `${outShape.length}-D output sampling is not yet supported`); @@ -144,6 +152,19 @@ const SAMPLE_3D_SNIPPET = ` } `; +const SAMPLE_4D_SNIPPET = ` + float sample4D(sampler2D texture, float texNumR, float texNumC, float stride0, + float stride1, float stride2, float row, float col, float depth, + float depth2) { + float index = dot(vec4(row, col, depth, depth2), + vec4(stride0, stride1, stride2, 1.0)); + float texR = floor(index / texNumC); + float texC = mod(index, texNumC); + vec2 uv = (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR); + return texture2D(texture, uv).r; + } +`; + function getOutput1DCoords( shape: [number], texShape: [number, number]): string { if (texShape[0] === 1) { @@ -185,6 +206,30 @@ function getOutput3DCoords(shape: [number, number, number], `; } +function getOutput4DCoords(shape: [number, number, number, number], + texShape: [number, number]): string { + const stride2 = shape[3]; + const stride1 = shape[2] * stride2; + const stride0 = shape[1] * stride1; + return ` + vec4 getOutputCoords() { + vec2 resTexRC = floor(gl_FragCoord.yx); + float index = dot(resTexRC, vec2(${texShape[1]}.0, 1.0)); + + float r = floor(index / ${stride0}.0); + index -= r * ${stride0}.0; + + float c = floor(index / ${stride1}.0); + index -= c * ${stride1}.0; + + float d = floor(index / ${stride2}.0); + float d2 = mod(index, ${stride2}.0); + + return vec4(r, c, d, d2); + } + `; +} + function getOutput2DCoords( shape: [number, number], texShape: [number, number]): string { if (util.arraysEqual(shape, texShape)) { @@ -265,6 +310,24 @@ function getSampler3D( `; } +function getSampler4D( + texName: string, shape: [number, number, number, number], + texShape: [number, number]): string { + const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + const tR = texShape[0]; + const tC = texShape[1]; + const stride2 = shape[3]; + const stride1 = shape[2] * stride2; + const stride0 = shape[1] * stride1; + + return ` + float ${funcName}(float row, float col, float depth, float depth2) { + return sample4D(${texName}, ${tR}.0, ${tC}.0, ${stride0}.0, ${stride1}.0, + ${stride2}.0, row, col, depth, depth2); + } +`; +} + function getSampler2D( texName: string, shape: [number, number], texShape: [number, number]): string {