Skip to content

Commit

Permalink
Fix forward of pow(-f, i)
Browse files Browse the repository at this point in the history
  • Loading branch information
praeclarum committed Jul 2, 2023
1 parent c9a908b commit 0f1c13c
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 18 deletions.
25 changes: 21 additions & 4 deletions src/expr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

export type ExprCode = number | string;

export type ExprNodeType = "apply" | "block" | "assign" | "if" | "negate" | "return" | "statements" | "var" | "+" | "-" | "*" | "/" | "==" | "!=" | "<" | "<=" | ">" | ">=" | "&&" | "||" | "!" | "~" | "^" | "%" | "?";
export type ExprNodeType = "apply" | "block" | "assign" | "if" | "negate" | "return" | "statements" | "var" | "+" | "-" | "*" | "/" | "==" | "!=" | "<" | "<=" | ">" | ">=" | "&&" | "||" | "&" | "|" | "^" | "!" | "~" | "^" | "%" | "?";

export type ExprAtom = string | ManifestNumber;
export type ExprCell = [ExprNodeType, ExprNode[]];
Expand Down Expand Up @@ -82,7 +82,11 @@ function lexn(code: string): (string | ManifestNumber)[] {
s += code[i];
i++;
}
const numType = hasDecimal ? "floatAbstract" : "intAbstract";
let numType: ManifestNumberType = hasDecimal ? "floatAbstract" : "intAbstract";
if (i < n && code[i] === "f") {
i++;
numType = "floatAbstract";
}
tokens.push(new ManifestNumber(numType, parseFloat(s)));
continue;
}
Expand Down Expand Up @@ -229,7 +233,7 @@ export function parseCode(code: ExprCode): ExprNode {
}
const t2 = tokens[i];
if (typeof t2 !== "string" || t2 !== ":") {
throw new Error("Expected :");
throw new Error(`Expected ':', got ${t2}`);
}
const expr3 = parseConditional(i + 1);
if (expr3 === null) {
Expand Down Expand Up @@ -271,7 +275,10 @@ export function parseCode(code: ExprCode): ExprNode {
const parseAddOrSubtract =genericParseSeparatedList(parseMultiplyOrDivide, ["+", "-"]);
const parseRelational = genericParseSeparatedList(parseAddOrSubtract, ["<", ">", "<=", ">="]);
const parseEquality = genericParseSeparatedList(parseRelational, ["==", "!="]);
const parseLogicalAnd = genericParseSeparatedList(parseEquality, ["&&"]);
const parseAnd = genericParseSeparatedList(parseEquality, ["&"]);
const parseExclusiveOr = genericParseSeparatedList(parseAnd, ["|"]);
const parseInclusiveOr = genericParseSeparatedList(parseExclusiveOr, ["|"]);
const parseLogicalAnd = genericParseSeparatedList(parseInclusiveOr, ["&&"]);
const parseLogicalOr = genericParseSeparatedList(parseLogicalAnd, ["||"]);
function parseUnary(i: number): ParseState|null {
if (i >= tokens.length) {
Expand Down Expand Up @@ -539,6 +546,16 @@ export function exprNodeToString(ast: ExprNode): string {
return `(${exprNodeToString(ast[1][0])} && ${exprNodeToString(ast[1][1])})`;
case "||":
return `(${exprNodeToString(ast[1][0])} || ${exprNodeToString(ast[1][1])})`;
case "&":
return `(${exprNodeToString(ast[1][0])} & ${exprNodeToString(ast[1][1])})`;
case "|":
return `(${exprNodeToString(ast[1][0])} | ${exprNodeToString(ast[1][1])})`;
case "^":
return `(${exprNodeToString(ast[1][0])} ^ ${exprNodeToString(ast[1][1])})`;
case "&&":
return `(${exprNodeToString(ast[1][0])} && ${exprNodeToString(ast[1][1])})`;
case "||":
return `(${exprNodeToString(ast[1][0])} || ${exprNodeToString(ast[1][1])})`;
case "?":
return `(${exprNodeToString(ast[1][0])} ? ${exprNodeToString(ast[1][1])} : ${exprNodeToString(ast[1][2])})`;
case "apply":
Expand Down
4 changes: 3 additions & 1 deletion src/op_table.ts
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ export const registry: AnOpSpec[] = [
{
name: "pow",
type: "binary",
forward: "output = pow(input, other)",
// forward: "output = pow(input, other)",
forward: `output = input >= 0 ? pow(input, other) :
(fract(other) == 0 ? (pow(-input, other) * ((i32(other) & 1) != 0 ? -1f : 1f)) : pow(input, other))`,
backward: "inputGrad = outputGrad * other * pow(input, other - 1.0); otherGrad = outputGrad * pow(input, other) * log(input)",
},
// quantized_batch_norm: quantization
Expand Down
52 changes: 52 additions & 0 deletions src/ops.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,58 @@ test("numel", async () => {
expect(x.numel()).toEqual(6);
});

test("pow(3) of 31x3 random numbers", async () => {
const batchSize = 31;
const pointsArray: number[][] = [];
for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) {
pointsArray.push([(Math.random()-0.5)*2.1, (Math.random()-0.5)*2.1, (Math.random()-0.5)*2.1]);
}
const points = tensor({data:pointsArray, requiresGrad: true});
const pointsAr = await points.toArrayAsync() as number[][];
console.log(pointsAr);
const y = points.pow(3);
expect(y.shape).toEqual([batchSize, 3]);
const yArray = await y.toArrayAsync() as number[][];
for (let b = 0; b < batchSize; b++) {
for (let i = 0; i < 3; i++) {
expect(yArray[b][i]).not.toBeNaN();
expect(yArray[b][i]).toBeCloseTo(pointsArray[b][i] * pointsArray[b][i] * pointsArray[b][i]);
}
}
const loss = y.sum();
const lossValue = await loss.toArrayAsync() as number;
expect(lossValue).not.toBeNaN();
loss.backward();
expect(points.grad).not.toBeNull();
const gradArray = await points.grad!.toArrayAsync() as number[][];
console.log(gradArray);
for (let b = 0; b < batchSize; b++) {
for (let i = 0; i < 3; i++) {
expect(gradArray[b][i]).not.toBeNaN();
expect(gradArray[b][i]).toBeCloseTo(3 * pointsArray[b][i] * pointsArray[b][i]);
}
}
});

test("pow(3) grad of 1 negative number", async () => {
const pointsArray: number[] = [-3.14];
const points = tensor({data:pointsArray, requiresGrad: true});
const y = points.pow(3);
const yArray = await y.toArrayAsync() as number[][];
expect(yArray[0]).not.toBeNaN();
expect(yArray[0]).toBeCloseTo(pointsArray[0] * pointsArray[0] * pointsArray[0]);
const loss = y.abs();
const lossValue = await loss.toArrayAsync() as number;
expect(lossValue).not.toBeNaN();
loss.backward();
expect(points.grad).not.toBeNull();
expect(y.grad).not.toBeNull();
const gradArray = await points.grad!.toArrayAsync() as number[][];
console.log(gradArray);
expect(gradArray[0]).not.toBeNaN();
expect(gradArray[0]).toBeCloseTo(-3 * pointsArray[0] * pointsArray[0]);
});

test("reshape [2, 3] to [3, 2]", async () => {
const a = tensor([[25.0, -102.0, -1.0], [7.0, -95.0, -38.0]]);
expect(a.shape).toEqual([2, 3]);
Expand Down
2 changes: 1 addition & 1 deletion src/ops_opgen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ export function positive(input: Tensor): Tensor {
/**
* Calculates:
* ```js
* output = pow(input, other)
* output = input
* ```
*
* Gradient:
Expand Down
2 changes: 1 addition & 1 deletion src/optim.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ test("sgd mlp train loop", async () => {
const sphereSDF = (batchedPoints: Tensor) => {
// sqrt(x^2 + y^2 + z^2) - radius
// const distanceToCenterSq = batchedPoints.pow(2).sum(1, true);
const squaredPoint = batchedPoints.mul(batchedPoints);
const squaredPoint = batchedPoints.pow(2);
const distanceToCenterSq = squaredPoint.sum(1, true);
const distanceToCenter = distanceToCenterSq.sqrt();
const distanceToSurface = distanceToCenter.sub(radius);
Expand Down
4 changes: 2 additions & 2 deletions src/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2133,7 +2133,7 @@ export class Tensor extends TensorBase {
/**
* Calculates:
* ```js
* output = pow(input, other)
* output = input
* ```
*
* Gradient:
Expand All @@ -2150,7 +2150,7 @@ export class Tensor extends TensorBase {
/**
* Calculates:
* ```js
* output = pow(input, other)
* output = input
* ```
*
* Gradient:
Expand Down
19 changes: 10 additions & 9 deletions web/tests/testfw.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ function test(description, callback) { testreg.push({ description, callback });

class Expect {
constructor(value, truth) { this.value = value; this.truth = truth; }
toBe(expected) { if (this.truth(!Object.is(this.value, expected))) { throw new Error(`Expected «${this.value}» to be «${expected}»`); } }
toBeCloseTo(expected, precision) { const expDiff = Math.pow(10, -precision)/2; if (this.truth(Math.abs(this.value - expected) >= expDiff)) { throw new Error(`Expected «${this.value}» to be close to «${expected}» (diff: < ${expDiff.toFixed(precision+1)})`); } }
toBeGreaterThan(expected) { if (this.truth(!(this.value > expected))) { throw new Error(`Expected «${this.value}» to be greater than «${expected}»`); } }
toBeGreaterThanOrEqual(expected) { if (this.truth(!(this.value >= expected))) { throw new Error(`Expected «${this.value}» to be greater than or equal to «${expected}»`); } }
toBeInstanceOf(expected) { if (this.truth(!(this.value instanceof expected))) { throw new Error(`Expected «${this.value}» to be instance of «${expected}»`); } }
toBeLessThan(expected) { if (this.truth(!(this.value < expected))) { throw new Error(`Expected «${this.value}» to be less than «${expected}»`); } }
toBeLessThanOrEqual(expected) { if (this.truth(!(this.value <= expected))) { throw new Error(`Expected «${this.value}» to be less than or equal to «${expected}»`); } }
toBeNaN() { if (this.truth(!Number.isNaN(this.value))) { throw new Error(`Expected «${this.value}» to be NaN`); } }
toBeNull() { if (this.truth(this.value !== null)) { throw new Error(`Expected «${this.value}» to be null`); } }
get toBeText() { return this.truth(true) ? "to be" : "to not be"; }
toBe(expected) { if (this.truth(!Object.is(this.value, expected))) { throw new Error(`Expected «${this.value}» ${this.toBeText} «${expected}»`); } }
toBeCloseTo(expected, precision) { const expDiff = Math.pow(10, -(precision||2))/2; if (this.truth(Math.abs(this.value - expected) >= expDiff)) { throw new Error(`Expected «${this.value}» ${this.toBeText} close to «${expected}» (diff: < ${expDiff.toFixed((precision||2)+1)})`); } }
toBeGreaterThan(expected) { if (this.truth(!(this.value > expected))) { throw new Error(`Expected «${this.value}» ${this.toBeText} greater than «${expected}»`); } }
toBeGreaterThanOrEqual(expected) { if (this.truth(!(this.value >= expected))) { throw new Error(`Expected «${this.value}» ${this.toBeText} greater than or equal to «${expected}»`); } }
toBeInstanceOf(expected) { if (this.truth(!(this.value instanceof expected))) { throw new Error(`Expected «${this.value}» ${this.toBeText} instance of «${expected}»`); } }
toBeLessThan(expected) { if (this.truth(!(this.value < expected))) { throw new Error(`Expected «${this.value}» ${this.toBeText} less than «${expected}»`); } }
toBeLessThanOrEqual(expected) { if (this.truth(!(this.value <= expected))) { throw new Error(`Expected «${this.value}» ${this.toBeText} less than or equal to «${expected}»`); } }
toBeNaN() { if (this.truth(!Number.isNaN(this.value))) { throw new Error(`Expected «${this.value}» ${this.toBeText} NaN`); } }
toBeNull() { if (this.truth(this.value !== null)) { throw new Error(`Expected «${this.value}» ${this.toBeText} null`); } }
toEqual(expected) { if (this.truth(!eq(this.value, expected))) { throw new Error(`Expected «${JSON.stringify(this.value)}» to equal «${JSON.stringify(expected)}»`); } }
toHaveLength(expected) { if (this.truth(this.value.length !== expected)) { throw new Error(`Expected «${this.value}» to have length «${expected}»`); } }
toThrow(expected) {
Expand Down

0 comments on commit 0f1c13c

Please sign in to comment.