Skip to content

Commit

Permalink
feat(shader-ast): major updates
Browse files Browse the repository at this point in the history
- add initial collection of re-usable shader functions
  - SDF primitives & combinators
  - raymarch helpers
  - fog/falloff functions
  - clamp / fit
  - lambert / diffuse lighting
- add constantFolding() tree optimizer
- add userland function dependencies (mandatory, but still unused)
- optimize single component swizzles in JS target
- add more node type checkers, update walk()
- update types
  • Loading branch information
postspectacular committed Jun 17, 2019
1 parent b313a56 commit 51d42b4
Show file tree
Hide file tree
Showing 13 changed files with 770 additions and 61 deletions.
4 changes: 2 additions & 2 deletions packages/shader-ast/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"build:es6": "tsc --declaration",
"build:bundle": "../../scripts/bundle-module",
"test": "rimraf build && tsc -p test/tsconfig.json && nyc mocha build/test/*.js",
"clean": "rimraf *.js *.d.ts .nyc_output build coverage doc lib",
"clean": "rimraf *.js *.d.ts .nyc_output build coverage doc lib codegen std",
"cover": "yarn test && nyc report --reporter=lcov",
"doc": "node_modules/.bin/typedoc --mode modules --out doc src",
"pub": "yarn build && yarn publish --access public"
Expand Down Expand Up @@ -50,4 +50,4 @@
"access": "public"
},
"sideEffects": false
}
}
18 changes: 12 additions & 6 deletions packages/shader-ast/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ export type Vec = "vec2" | "vec3" | "vec4";
export type IVec = "ivec2" | "ivec3" | "ivec4";
export type BVec = "bvec2" | "bvec3" | "bvec4";
export type Mat = "mat2" | "mat3" | "mat4";
export type Prim = "f32" | "i32" | "u32" | Vec;
export type Comparable = "f32" | "i32";
export type Prim = "f32" | Vec;
export type Int = "i32" | "u32";
export type Comparable = "f32" | Int;
export type Numeric = number | Term<"f32"> | Term<"i32"> | Term<"u32">;

export type Assignable<T extends Type> = Sym<T> | Swizzle<T> | Index<T>;
Expand Down Expand Up @@ -337,6 +338,10 @@ export interface Term<T extends Type> {
type: T;
}

export interface Scoped {
scope: Scope;
}

export interface Lit<T extends Type> extends Term<T> {
val: any;
info?: string;
Expand Down Expand Up @@ -412,10 +417,10 @@ export interface FuncArg<T extends Type> extends Term<T> {
opts: SymOpts;
}

export interface Func<T extends Type> extends Term<T> {
export interface Func<T extends Type> extends Term<T>, Scoped {
id: string;
args: Sym<any>[];
scope: Scope;
deps: Func<any>[];
}

export interface TaggedFn0<T extends Type> extends Func0<T>, Func<T> {
Expand Down Expand Up @@ -509,11 +514,10 @@ export interface FnCall<T extends Type> extends Term<T> {
info?: string;
}

export interface ForLoop extends Term<"void"> {
export interface ForLoop extends Term<"void">, Scoped {
init?: Decl<any>;
test: Term<"bool">;
iter?: Term<any>;
body: Scope;
}

export interface TargetImpl extends Record<Tag, Fn<any, string>> {
Expand All @@ -523,6 +527,7 @@ export interface TargetImpl extends Record<Tag, Fn<any, string>> {
call_i: Fn<FnCall<any>, string>;
decl: Fn<Decl<any>, string>;
fn: Fn<Func<any>, string>;
for: Fn<ForLoop, string>;
idx: Fn<Index<any>, string>;
if: Fn<Branch, string>;
lit: Fn<Lit<any>, string>;
Expand All @@ -532,4 +537,5 @@ export interface TargetImpl extends Record<Tag, Fn<any, string>> {
scope: Fn<Scope, string>;
swizzle: Fn<Swizzle<any>, string>;
sym: Fn<Sym<any>, string>;
ternary: Fn<Ternary<any>, string>;
}
183 changes: 140 additions & 43 deletions packages/shader-ast/src/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import {
Index,
Indexable,
IndexTypeMap,
Int,
IVec,
Lit,
Mat,
Expand Down Expand Up @@ -84,6 +85,21 @@ let symID = 0;

export const gensym = () => `_sym${symID++}`;

export const isF32 = (t: Term<any>) => t.type === "f32";

export const isI32 = (t: Term<any>) => t.type === "i32";

export const isU32 = (t: Term<any>) => t.type === "u32";

export const isLit = (t: Term<any>) => t.tag === "lit";

export const isLitF32 = (t: Term<any>) => isLit(t) && isF32(t);

export const isLitI32 = (t: Term<any>) => isLit(t) && isI32(t);

export const isLitNumeric = (t: Term<any>) =>
isLit(t) && (isF32(t) || isI32(t) || isU32(t));

export const isVec = (t: Term<any>) => t.type.indexOf("vec") >= 0;

export const isMat = (t: Term<any>) => t.type.indexOf("mat") >= 0;
Expand All @@ -92,34 +108,78 @@ export const itemType = (type: Type) => <Type>type.replace("[]", "");

export const wrapF32 = (x?: Numeric) => (isNumber(x) ? float(x) : x);

export const numberWithMatchingType = (t: Term<Prim>, x: number) =>
export const numberWithMatchingType = (t: Term<Prim | Int>, x: number) =>
t.type[0] === "i" ? int(x) : float(x);

export const children = (t: Term<any>) =>
t.tag == "fn"
export const scopeChildren = (t: Term<any>) =>
t.tag === "fn" || t.tag === "for"
? (<Func<any>>t).scope.body
: t.tag === "if"
? (<Branch>t).f
? (<Branch>t).t.body.concat((<Branch>t).f!.body)
: (<Branch>t).t.body
: undefined;

export const allChildren = (t: Term<any>) =>
t.tag === "scope"
? (<Scope>t).body
: t.tag === "fn" || t.tag === "for"
? (<Func<any>>t).scope.body
: t.tag == "if"
: t.tag === "if"
? (<Branch>t).f
? (<Branch>t).t.body.concat((<Branch>t).f!.body)
: (<Branch>t).t.body
: t.tag === "ternary"
? [(<Ternary<any>>t).t, (<Ternary<any>>t).f]
: t.tag === "ret"
? [(<FuncReturn<any>>t).val]
: t.tag === "call" || t.tag === "call_i"
? (<FnCall<any>>t).args
: t.tag === "sym" && (<Sym<any>>t).init
? [(<Sym<any>>t).init]
: t.tag === "op2"
? [(<Op2<any>>t).l, (<Op2<any>>t).r]
: isVec(t) || isMat(t)
? (<Lit<any>>t).val
: undefined;

/**
* Traverses given AST in depth-first order and applies `visit` and
* `children` fns to each node. Descends only further if `children`
* returns an array of child nodes. The `visit` function must accept 2
* args: the accumulator (`acc`) given to `walk` and a tree node. The
* return value of `visit` is ignored. `walk` itself returns the
* possibly updated `acc`.
*
* If `pre` is true (default), the `visit` function will be called prior
* to visiting a node's children. If false, the visitor is called on the
* way back up.
*
* @param visit
* @param children
* @param acc
* @param tree
* @param pre
*/
export const walk = <T>(
visit: Fn2<T, Term<any>, void>,
visit: Fn2<T, Term<any>, T>,
children: Fn<Term<any>, Term<any>[] | undefined>,
acc: T,
t: Term<any> | Term<any>[]
tree: Term<any> | Term<any>[],
pre = true
) => {
if (isArray(t)) {
t.forEach((x) => walk(visit, children, acc, x));
if (isArray(tree)) {
tree.forEach((x) => (acc = walk(visit, children, acc, x, pre)));
} else {
visit(acc, t);
const c = children(t);
c && walk(visit, children, acc, c);
pre && (acc = visit(acc, tree));
const c = children(tree);
c && (acc = walk(visit, children, acc, c, pre));
!pre && (acc = visit(acc, tree));
}
return acc;
};

export function sym<T extends Type>(init: Term<T>): Sym<T>;
export function sym<T extends Type>(type: T): Sym<T>;
export function sym<T extends Type>(type: T, opts: SymOpts): Sym<T>;
export function sym<T extends Type>(type: T, init: Term<T>): Sym<T>;
Expand All @@ -130,12 +190,16 @@ export function sym<T extends Type>(type: T, id: string, opts: SymOpts): Sym<T>;
export function sym<T extends Type>(type: T, opts: SymOpts, init: Term<T>): Sym<T>;
// prettier-ignore
export function sym<T extends Type>(type: T, id: string, opts: SymOpts, init: Term<T>): Sym<T>;
export function sym<T extends Type>(type: T, ...xs: any[]): Sym<any> {
export function sym<T extends Type>(type: any, ...xs: any[]): Sym<any> {
let id: string;
let opts: SymOpts;
let init: Term<T>;
switch (xs.length) {
case 0:
if (!isString(type)) {
init = type;
type = init.type;
}
break;
case 1:
if (isString(xs[0])) {
Expand Down Expand Up @@ -396,14 +460,14 @@ export const op2 = (
};
};

export const inc = <T extends Prim>(t: Term<T>): Op2<T> =>
export const inc = <T extends Prim | Int>(t: Term<T>): Op2<T> =>
<Op2<any>>add(<Term<any>>t, <Term<any>>numberWithMatchingType(t, 1));

export const dec = <T extends Prim>(t: Term<T>): Op2<T> =>
export const dec = <T extends Prim | Int>(t: Term<T>): Op2<T> =>
<Op2<any>>sub(<Term<any>>t, <Term<any>>numberWithMatchingType(t, 1));

// prettier-ignore
export function add<A extends Prim | IVec | Mat, B extends A>(l: Term<A>, b: Term<B>): Op2<A>;
export function add<A extends Prim | Int | IVec | Mat, B extends A>(l: Term<A>, b: Term<B>): Op2<A>;
// prettier-ignore
export function add<T extends Vec | Mat>(l: Term<"f32">, b: Term<T>): Op2<T>;
// prettier-ignore
Expand All @@ -418,7 +482,7 @@ export function add(l: Term<any>, r: Term<any>): Op2<any> {
}

// prettier-ignore
export function sub<A extends Prim | IVec | Mat, B extends A>(l: Term<A>, b: Term<B>): Op2<A>;
export function sub<A extends Prim | Int | IVec | Mat, B extends A>(l: Term<A>, b: Term<B>): Op2<A>;
// prettier-ignore
export function sub<T extends Vec | Mat>(l: Term<"f32">, b: Term<T>): Op2<T>;
// prettier-ignore
Expand All @@ -432,7 +496,7 @@ export function sub(l: Term<any>, r: Term<any>): Op2<any> {
}

// prettier-ignore
export function mul<A extends Prim | IVec | Mat, B extends A>(l: Term<A>, b: Term<B>): Op2<A>;
export function mul<A extends Prim | Int | IVec | Mat, B extends A>(l: Term<A>, b: Term<B>): Op2<A>;
// prettier-ignore
export function mul<T extends Vec | Mat>(l: Term<"f32">, b: Term<T>): Op2<T>;
// prettier-ignore
Expand All @@ -452,7 +516,7 @@ export function mul(l: Term<any>, r: Term<any>): Op2<any> {
}

// prettier-ignore
export function div<A extends Prim | IVec | Mat, B extends A>(l: Term<A>, b: Term<B>): Op2<A>;
export function div<A extends Prim | Int | IVec | Mat, B extends A>(l: Term<A>, b: Term<B>): Op2<A>;
// prettier-ignore
export function div<T extends Vec | Mat>(l: Term<"f32">, b: Term<T>): Op2<T>;
// prettier-ignore
Expand All @@ -465,7 +529,8 @@ export function div(l: Term<any>, r: Term<any>): Op2<any> {
return op2("/", l, r);
}

export const neg = <T extends Prim | IVec | Mat>(val: Term<T>) => op1("-", val);
export const neg = <T extends Prim | Int | IVec | Mat>(val: Term<T>) =>
op1("-", val);

export const not = (val: Term<"bool">) => op1("!", val);
export const or = (a: Term<"bool">, b: Term<"bool">) => op2("||", a, b);
Expand All @@ -490,59 +555,91 @@ export const scope = (body: Term<any>[], global = false): Scope => ({
global
});

/**
* DO NOT USE YET!
*
* TODO add func dep ordering
*/
export const program = (entry: Func<any>) => scope([entry], true);

const defArg = <T extends Type>([type, id, opts]: Arg<T>): FuncArg<T> => ({
tag: "arg",
type,
id: id || gensym(),
opts: { q: "in", ...opts }
});

/**
* Defines a new function with up to 8 typed checked arguments.
*
* @param type return type
* @param name function name
* @param args arg types / names / opts
* @param body function body closure
* @param deps array of userland functions called from this function
*/
// prettier-ignore
export function defn<T extends Type>(type: T, name: string, args: [], body: FnBody0): TaggedFn0<T>;
export function defn<T extends Type>(type: T, name: string, args: [], body: FnBody0, deps?: Func<any>[]): TaggedFn0<T>;
// prettier-ignore
export function defn<T extends Type, A extends Type>(type: T, name: string, args: Arg1<A>, body: FnBody1<A>): TaggedFn1<A,T>;
export function defn<T extends Type, A extends Type>(type: T, name: string, args: Arg1<A>, body: FnBody1<A>, deps?: Func<any>[]): TaggedFn1<A,T>;
// prettier-ignore
export function defn<T extends Type, A extends Type, B extends Type>(type: T, name: string, args: Arg2<A,B>, body: FnBody2<A,B>): TaggedFn2<A,B,T>;
export function defn<T extends Type, A extends Type, B extends Type>(type: T, name: string, args: Arg2<A,B>, body: FnBody2<A,B>, deps?: Func<any>[]): TaggedFn2<A,B,T>;
// prettier-ignore
export function defn<T extends Type, A extends Type, B extends Type, C extends Type>(type: T, name: string, args: Arg3<A,B,C>, body: FnBody3<A,B,C>): TaggedFn3<A,B,C,T>;
export function defn<T extends Type, A extends Type, B extends Type, C extends Type>(type: T, name: string, args: Arg3<A,B,C>, body: FnBody3<A,B,C>, deps?: Func<any>[]): TaggedFn3<A,B,C,T>;
// prettier-ignore
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type>(type: T, name: string, args: Arg4<A,B,C,D>, body: FnBody4<A,B,C,D>): TaggedFn4<A,B,C,D,T>;
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type>(type: T, name: string, args: Arg4<A,B,C,D>, body: FnBody4<A,B,C,D>, deps?: Func<any>[]): TaggedFn4<A,B,C,D,T>;
// prettier-ignore
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type, E extends Type>(type: T, name: string, args: Arg5<A,B,C,D,E>, body: FnBody5<A,B,C,D,E>): TaggedFn5<A,B,C,D,E,T>;
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type, E extends Type>(type: T, name: string, args: Arg5<A,B,C,D,E>, body: FnBody5<A,B,C,D,E>, deps?: Func<any>[]): TaggedFn5<A,B,C,D,E,T>;
// prettier-ignore
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type, E extends Type, F extends Type>(type: T, name: string, args: Arg6<A,B,C,D,E,F>, body: FnBody6<A,B,C,D,E,F>): TaggedFn6<A,B,C,D,E,F,T>;
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type, E extends Type, F extends Type>(type: T, name: string, args: Arg6<A,B,C,D,E,F>, body: FnBody6<A,B,C,D,E,F>, deps?: Func<any>[]): TaggedFn6<A,B,C,D,E,F,T>;
// prettier-ignore
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type, E extends Type, F extends Type, G extends Type>(type: T, name: string, args: Arg7<A,B,C,D,E,F,G>, body: FnBody7<A,B,C,D,E,F,G>): TaggedFn7<A,B,C,D,E,F,G,T>;
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type, E extends Type, F extends Type, G extends Type>(type: T, name: string, args: Arg7<A,B,C,D,E,F,G>, body: FnBody7<A,B,C,D,E,F,G>, deps?: Func<any>[]): TaggedFn7<A,B,C,D,E,F,G,T>;
// prettier-ignore
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type, E extends Type, F extends Type, G extends Type, H extends Type>(type: T, name: string, args: Arg8<A,B,C,D,E,F,G,H>, body: FnBody8<A,B,C,D,E,F,G,H>): TaggedFn8<A,B,C,D,E,F,G,H,T>;
export function defn<T extends Type, A extends Type, B extends Type, C extends Type, D extends Type, E extends Type, F extends Type, G extends Type, H extends Type>(type: T, name: string, args: Arg8<A,B,C,D,E,F,G,H>, body: FnBody8<A,B,C,D,E,F,G,H>, deps?: Func<any>[]): TaggedFn8<A,B,C,D,E,F,G,H,T>;
// prettier-ignore
export function defn(type: Type, id: string, _args: Arg<any>[], _body: (...xs: Sym<any>[]) => Term<any>[]): Func<any> {
export function defn(type: Type, id: string, _args: Arg<any>[], _body: (...xs: Sym<any>[]) => Term<any>[], deps: Func<any>[]=[]): Func<any> {
const args = _args.map(defArg);
const body = _body(...args.map((x) => sym(x.type, x.id, x.opts)));
// count & check returns
const returns = walk(
(acc: FuncReturn<any>[], t) =>
t.tag === "ret" && acc.push(t),
children,
[],
(n, t) => {
if(t.tag === "ret") {
assert(
t.type === type,
`wrong return type for function '${id}', expected ${type}, got ${
t.type
}`
);
n++;
}
return n;
},
scopeChildren,
0,
body
);
const mismatched = returns.find((t) => t.type !== type);
if (mismatched) {
throw new Error(
`wrong return type for function '${id}', expected ${type}, got ${
mismatched.type
}`
);
} else if (type !== "void" && !returns.length) {
if (type !== "void" && !returns) {
throw new Error(`function '${id}' must return a value of type ${type}`);
}
// verify all non-builtin functions called are also
// provided as deps to ensure complete call graph later
walk(
(_, t) => t.tag === "call" && assert(
!!deps.find((y) => y.id === (<FnCall<any>>t).id),
`function '${id}' calls function '${(<FnCall<any>>t).id}' not given in deps`
),
allChildren,
<any>null,
body
);
const $: any = (...xs: any[]) => funcall(id, type, ...xs);
return Object.assign($, <Func<any>>{
tag: "fn",
type,
id,
args,
scope: scope(body)
deps,
scope: scope(body),
});
}

Expand Down Expand Up @@ -645,7 +742,7 @@ export function forLoop(...xs: any[]): ForLoop {
init: init ? decl(init) : undefined,
test: test(init!),
iter: iter ? iter(init!) : undefined,
body: scope(body(init!))
scope: scope(body(init!))
};
}

Expand Down
Loading

0 comments on commit 51d42b4

Please sign in to comment.