Skip to content

Commit

Permalink
Merge pull request #2208 from quantified-uncertainty/pointset-improve…
Browse files Browse the repository at this point in the history
…ments

Pointset refactorings
  • Loading branch information
berekuk committed Aug 17, 2023
2 parents 56044dc + b7ea48c commit 803e27c
Show file tree
Hide file tree
Showing 17 changed files with 232 additions and 342 deletions.
2 changes: 1 addition & 1 deletion packages/squiggle-lang/__tests__/dist/Scale_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ describe("Scale logarithm", () => {
(-Math.log2(high - low) / 2) * (high ** 2 - low ** 2);
if (!meanResult.ok) {
expect(meanResult.value).toEqual(
operationDistError(NegativeInfinityError)
operationDistError(new NegativeInfinityError())
);
} else {
expect(meanResult.value).toBeCloseTo(meanAnalytical);
Expand Down
2 changes: 1 addition & 1 deletion packages/squiggle-lang/__tests__/library/pointset_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ describe("Mean of mixture is weighted average of means", () => {
fc.float({ min: toFloat32(1e-7), max: 100, noNaN: true }),
async (normalMean, normalStdev, betaA, betaB, x, y) => {
// normaalize is due to https://github.com/quantified-uncertainty/squiggle/issues/1400 bug
const squiggleString = `mean(mixture(normal(${normalMean},${normalStdev}), beta(${betaA},${betaB}), [${x}, ${y}])->normalize)`;
const squiggleString = `mean(mixture(Sym.normal(${normalMean},${normalStdev}), Sym.beta(${betaA},${betaB}), [${x}, ${y}])->normalize)`;
const res = await testRun(squiggleString);
const weightDenom = x + y;
const normalWeight = x / weightDenom;
Expand Down
27 changes: 4 additions & 23 deletions packages/squiggle-lang/src/PointSet/Continuous.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ export class ContinuousShape implements PointSet<ContinuousShape> {
);
}

isEmpty() {
return this.xyShape.xs.length === 0;
}
toContinuous() {
return this;
}
Expand Down Expand Up @@ -380,31 +383,9 @@ export const sum = (continuousShapes: ContinuousShape[]): ContinuousShape => {
}, empty());
};

export const reduce = <E>(
continuousShapes: ContinuousShape[],
fn: (v1: number, v2: number) => Result.result<number, E>,
integralSumCachesFn: (v1: number, v2: number) => number | undefined = () =>
undefined
): Result.result<ContinuousShape, E> => {
let acc = empty();
for (const shape of continuousShapes) {
const result = combinePointwise(
acc,
shape,
fn,
undefined,
integralSumCachesFn
);
if (!result.ok) {
return result;
}
acc = result.value;
}
return Result.Ok(acc);
};

/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
each discrete data point, and then adds them all together. */
//TODO WARNING: The combineAlgebraicallyWithDiscrete will break for subtraction and division, like, discrete - continous
export const combineAlgebraicallyWithDiscrete = (
op: ConvolutionOperation,
t1: ContinuousShape,
Expand Down
20 changes: 3 additions & 17 deletions packages/squiggle-lang/src/PointSet/Discrete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ export class DiscreteShape implements PointSet<DiscreteShape> {
);
}

isEmpty() {
return this.xyShape.xs.length === 0;
}
toContinuous() {
return undefined;
}
Expand Down Expand Up @@ -276,23 +279,6 @@ export const combinePointwise = <E>(
);
};

export const reduce = <E>(
shapes: DiscreteShape[],
fn: (v1: number, v2: number) => Result.result<number, E>,
integralSumCachesFn: (v1: number, v2: number) => number | undefined = () =>
undefined
): Result.result<DiscreteShape, E> => {
let acc = empty();
for (const shape of shapes) {
const result = combinePointwise(acc, shape, fn, integralSumCachesFn);
if (!result.ok) {
return result;
}
acc = result.value;
}
return Result.Ok(acc);
};

/* This multiples all of the data points together and creates a new discrete distribution from the results.
Data points at the same xs get added together. It may be a good idea to downsample t1 and t2 before and/or the result after. */
export const combineAlgebraically = (
Expand Down
68 changes: 29 additions & 39 deletions packages/squiggle-lang/src/PointSet/Mixed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import * as Discrete from "./Discrete.js";
import * as MixedPoint from "./MixedPoint.js";
import * as Result from "../utility/result.js";
import * as Common from "./Common.js";
import { AnyPointSet } from "./PointSet.js";
import { ContinuousShape } from "./Continuous.js";
import { DiscreteShape } from "./Discrete.js";
import { ConvolutionOperation, PointSet } from "./PointSet.js";
Expand Down Expand Up @@ -52,6 +51,9 @@ export class MixedShape implements PointSet<MixedShape> {
return Math.max(this.continuous.maxX(), this.discrete.maxX());
}

isEmpty() {
return this.continuous.isEmpty() && this.discrete.isEmpty();
}
toContinuous() {
return this.continuous;
}
Expand All @@ -70,6 +72,15 @@ export class MixedShape implements PointSet<MixedShape> {
}

normalize() {
if (this.isEmpty()) {
return this; // still not normalized, throw an error?
}
if (this.continuous.isEmpty()) {
return this.discrete.normalize().toMixed();
}
if (this.discrete.isEmpty()) {
return this.continuous.normalize().toMixed();
}
const continuousIntegralSum = this.continuous.integralSum();
const discreteIntegralSum = this.discrete.integralSum();

Expand Down Expand Up @@ -232,7 +243,11 @@ export class MixedShape implements PointSet<MixedShape> {
const discreteMean = this.discrete.mean();
const continuousMean = this.continuous.mean();
// means are already weighted by subshape probabilities
return (discreteMean + continuousMean) / this.integralSum();
return (
(discreteMean * this.discrete.integralSum() +
continuousMean * this.continuous.integralSum()) /
this.integralSum()
);
}
variance(): number {
// the combined mean is the weighted sum of the two:
Expand Down Expand Up @@ -264,26 +279,6 @@ export class MixedShape implements PointSet<MixedShape> {
}
}

// let totalLength = (t: t): int => {
// let continuousLength = t.continuous.xyShape->XYShape.T.length
// let discreteLength = t.discrete.xyShape->XYShape.T.length

// continuousLength + discreteLength
// }

// let scaleBy = (t: t, scale): t => {
// let scaledDiscrete = Discrete.scaleBy(t.discrete, scale)
// let scaledContinuous = Continuous.scaleBy(t.continuous, scale)
// let scaledIntegralCache = E.O.bind(t.integralCache, v => Some(Continuous.scaleBy(v, scale)))
// let scaledIntegralSumCache = E.O.bind(t.integralSumCache, s => Some(s *. scale))
// make(
// ~discrete=scaledDiscrete,
// ~continuous=scaledContinuous,
// ~integralSumCache=scaledIntegralSumCache,
// ~integralCache=scaledIntegralCache,
// )
// }

export const combineAlgebraically = (
op: ConvolutionOperation,
t1: MixedShape,
Expand Down Expand Up @@ -320,6 +315,7 @@ export const combineAlgebraically = (
t2.discrete,
"Second"
);

const continuousConvResult = Continuous.sum([
ccConvResult,
dcConvResult,
Expand Down Expand Up @@ -358,19 +354,18 @@ export const combinePointwise = <E>(
v2: ContinuousShape
) => ContinuousShape | undefined = () => undefined
): Result.result<MixedShape, E> => {
const isDefined = <T>(argument: T | undefined): argument is T => {
return argument !== undefined;
};

const reducedDiscrete = Discrete.reduce(
[t1, t2].map((t) => t.toDiscrete()).filter(isDefined),
const reducedDiscrete = Discrete.combinePointwise(
t1.toDiscrete(),
t2.toDiscrete(),
fn,
integralSumCachesFn
);

const reducedContinuous = Continuous.reduce(
[t1, t2].map((t) => t.toContinuous()).filter(isDefined),
const reducedContinuous = Continuous.combinePointwise(
t1.toContinuous(),
t2.toContinuous(),
fn,
undefined,
integralSumCachesFn
);

Expand Down Expand Up @@ -398,13 +393,13 @@ export const combinePointwise = <E>(
);
};

export const buildMixedShape = ({
export function buildMixedShape({
continuous,
discrete,
}: {
continuous?: ContinuousShape;
discrete?: DiscreteShape;
}): AnyPointSet | undefined => {
}): MixedShape | undefined {
continuous ??= new ContinuousShape({
integralSumCache: 0,
xyShape: { xs: [], ys: [] },
Expand All @@ -417,12 +412,7 @@ export const buildMixedShape = ({
const dLength = discrete.xyShape.xs.length;
if (cLength < 2 && dLength == 0) {
return undefined;
} else if (cLength < 2) {
return discrete;
} else if (dLength == 0) {
return continuous;
} else {
const mixedDist = new MixedShape({ continuous, discrete });
return mixedDist;
return new MixedShape({ continuous, discrete });
}
};
}
20 changes: 1 addition & 19 deletions packages/squiggle-lang/src/PointSet/PointSet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export interface PointSet<T> {
| ((cache: ContinuousShape) => ContinuousShape | undefined)
): result<T, E>;
xToY(x: number): MixedPoint;
isEmpty(): boolean;
toContinuous(): ContinuousShape | undefined;
toDiscrete(): DiscreteShape | undefined;
toMixed(): MixedShape;
Expand All @@ -59,25 +60,6 @@ export interface PointSet<T> {
variance(): number;
}

//TODO WARNING: The combineAlgebraicallyWithDiscrete will break for subtraction and division, like, discrete - continous
export const combineAlgebraically = (
op: ConvolutionOperation,
t1: AnyPointSet,
t2: AnyPointSet
): AnyPointSet => {
if (t1 instanceof ContinuousShape && t2 instanceof ContinuousShape) {
return Continuous.combineAlgebraically(op, t1, t2);
} else if (t1 instanceof DiscreteShape && t2 instanceof ContinuousShape) {
return Continuous.combineAlgebraicallyWithDiscrete(op, t2, t1, "First");
} else if (t1 instanceof ContinuousShape && t2 instanceof DiscreteShape) {
return Continuous.combineAlgebraicallyWithDiscrete(op, t1, t2, "Second");
} else if (t1 instanceof DiscreteShape && t2 instanceof DiscreteShape) {
return Discrete.combineAlgebraically(op, t1, t2);
} else {
return Mixed.combineAlgebraically(op, t1.toMixed(), t2.toMixed());
}
};

export const combinePointwise = <E>(
t1: AnyPointSet,
t2: AnyPointSet,
Expand Down

3 comments on commit 803e27c

@vercel
Copy link

@vercel vercel bot commented on 803e27c Aug 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vercel
Copy link

@vercel vercel bot commented on 803e27c Aug 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vercel
Copy link

@vercel vercel bot commented on 803e27c Aug 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.