Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion src/math/webgl/shader_compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ export function makeShader(inputsInfo: InputInfo[], outputShape: ShapeInfo,
getOutputSamplingSnippet(outputShape.logicalShape, outTexShape);
const source = [
SHADER_PREFIX, inputPrefixSnippet, SAMPLE_1D_SNIPPET, SAMPLE_2D_SNIPPET,
SAMPLE_3D_SNIPPET, inputSamplingSnippet, outputSamplingSnippet, userCode
SAMPLE_3D_SNIPPET, SAMPLE_4D_SNIPPET, inputSamplingSnippet,
outputSamplingSnippet, userCode
].join('\n');
return source;
}
Expand All @@ -63,6 +64,10 @@ function getInputSamplingSnippet(
res += getSampler3D(
inInfo.name, shape as [number, number, number], texShape);
break;
case 4:
res += getSampler4D(
inInfo.name, shape as [number, number, number, number], texShape);
break;
default:
throw new Error(
`${shape.length}-D input sampling` +
Expand Down Expand Up @@ -93,6 +98,9 @@ function getOutputSamplingSnippet(
case 3:
return getOutput3DCoords(outShape as [number, number, number],
outTexShape);
case 4:
return getOutput4DCoords(outShape as [number, number, number, number],
outTexShape);
default:
throw new Error(
`${outShape.length}-D output sampling is not yet supported`);
Expand Down Expand Up @@ -144,6 +152,19 @@ const SAMPLE_3D_SNIPPET = `
}
`;

const SAMPLE_4D_SNIPPET = `
float sample4D(sampler2D texture, float texNumR, float texNumC, float stride0,
float stride1, float stride2, float row, float col, float depth,
float depth2) {
float index = dot(vec4(row, col, depth, depth2),
vec4(stride0, stride1, stride2, 1.0));
float texR = floor(index / texNumC);
float texC = mod(index, texNumC);
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
return texture2D(texture, uv).r;
}
`;

function getOutput1DCoords(
shape: [number], texShape: [number, number]): string {
if (texShape[0] === 1) {
Expand Down Expand Up @@ -185,6 +206,30 @@ function getOutput3DCoords(shape: [number, number, number],
`;
}

function getOutput4DCoords(shape: [number, number, number, number],
texShape: [number, number]): string {
const stride2 = shape[3];
const stride1 = shape[2] * stride2;
const stride0 = shape[1] * stride1;
return `
vec4 getOutputCoords() {
vec2 resTexRC = floor(gl_FragCoord.yx);
float index = dot(resTexRC, vec2(${texShape[1]}.0, 1.0));

float r = floor(index / ${stride0}.0);
index -= r * ${stride0}.0;

float c = floor(index / ${stride1}.0);
index -= c * ${stride1}.0;

float d = floor(index / ${stride2}.0);
float d2 = mod(index, ${stride2}.0);

return vec4(r, c, d, d2);
}
`;
}

function getOutput2DCoords(
shape: [number, number], texShape: [number, number]): string {
if (util.arraysEqual(shape, texShape)) {
Expand Down Expand Up @@ -265,6 +310,24 @@ function getSampler3D(
`;
}

function getSampler4D(
texName: string, shape: [number, number, number, number],
texShape: [number, number]): string {
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const tR = texShape[0];
const tC = texShape[1];
const stride2 = shape[3];
const stride1 = shape[2] * stride2;
const stride0 = shape[1] * stride1;

return `
float ${funcName}(float row, float col, float depth, float depth2) {
return sample4D(${texName}, ${tR}.0, ${tC}.0, ${stride0}.0, ${stride1}.0,
${stride2}.0, row, col, depth, depth2);
}
`;
}

function getSampler2D(
texName: string, shape: [number, number],
texShape: [number, number]): string {
Expand Down