Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 2 additions & 211 deletions tfjs-core/src/ops/array_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import * as util from '../util';
import {getAxesPermutation, getInnerMostAxes} from './axis_util';
import {concat} from './concat_split';
import {op} from './operation';
import {MPRandGauss, RandGamma, UniformRandom} from './rand';
import {zeros, zerosLike} from './tensor_ops';

/**
Expand Down Expand Up @@ -101,207 +100,6 @@ function eye_(
}
}

/**
* Creates a `tf.Tensor` with values sampled from a normal distribution.
*
* ```js
* tf.randomNormal([2, 2]).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param mean The mean of the normal distribution.
* @param stdDev The standard deviation of the normal distribution.
* @param dtype The data type of the output.
* @param seed The seed for the random number generator.
*/
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function randomNormal_<R extends Rank>(
shape: ShapeMap[R], mean = 0, stdDev = 1, dtype?: 'float32'|'int32',
seed?: number): Tensor<R> {
if (dtype != null && (dtype as DataType) === 'bool') {
throw new Error(`Unsupported data type ${dtype}`);
}
const randGauss =
new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed);
const res = buffer(shape, dtype);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = randGauss.nextValue();
}
return res.toTensor();
}

/**
* Creates a `tf.Tensor` with values sampled from a truncated normal
* distribution.
*
* ```js
* tf.truncatedNormal([2, 2]).print();
* ```
*
* The generated values follow a normal distribution with specified mean and
* standard deviation, except that values whose magnitude is more than 2
* standard deviations from the mean are dropped and re-picked.
*
* @param shape An array of integers defining the output tensor shape.
* @param mean The mean of the normal distribution.
* @param stdDev The standard deviation of the normal distribution.
* @param dtype The data type of the output tensor.
* @param seed The seed for the random number generator.
*/
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function truncatedNormal_<R extends Rank>(
shape: ShapeMap[R], mean = 0, stdDev = 1, dtype?: 'float32'|'int32',
seed?: number): Tensor<R> {
if (dtype != null && (dtype as DataType) === 'bool') {
throw new Error(`Unsupported data type ${dtype}`);
}
const randGauss =
new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed);
const res = buffer(shape, dtype);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = randGauss.nextValue();
}
return res.toTensor();
}

/**
* Creates a `tf.Tensor` with values sampled from a gamma distribution.
*
* ```js
* tf.randomGamma([2, 2], 1).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param alpha The shape parameter of the gamma distribution.
* @param beta The inverse scale parameter of the gamma distribution. Defaults
* to 1.
* @param dtype The data type of the output. Defaults to float32.
* @param seed The seed for the random number generator.
*/
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function randomGamma_<R extends Rank>(
shape: ShapeMap[R], alpha: number, beta = 1,
dtype: 'float32'|'int32' = 'float32', seed?: number): Tensor<R> {
if (beta == null) {
beta = 1;
}
if (dtype == null) {
dtype = 'float32';
}
if (dtype !== 'float32' && dtype !== 'int32') {
throw new Error(`Unsupported data type ${dtype}`);
}
const rgamma = new RandGamma(alpha, beta, dtype, seed);
const res = buffer(shape, dtype);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = rgamma.nextValue();
}
return res.toTensor();
}

/**
* Creates a `tf.Tensor` with values sampled from a uniform distribution.
*
* The generated values follow a uniform distribution in the range [minval,
* maxval). The lower bound minval is included in the range, while the upper
* bound maxval is excluded.
*
* ```js
* tf.randomUniform([2, 2]).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param minval The lower bound on the range of random values to generate.
* Defaults to 0.
* @param maxval The upper bound on the range of random values to generate.
* Defaults to 1.
* @param dtype The data type of the output tensor. Defaults to 'float32'.
*/
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function randomUniform_<R extends Rank>(
shape: ShapeMap[R], minval = 0, maxval = 1, dtype: DataType = 'float32',
seed?: number|string): Tensor<R> {
const res = buffer(shape, dtype);
const random = new UniformRandom(minval, maxval, null, seed);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = random.nextValue();
}
return res.toTensor();
}

/**
* Creates a `tf.Tensor` with values sampled from a random number generator
* function defined by the user.
*
* @param shape An array of integers defining the output tensor shape.
* @param randFunction A random number generator function which is called
* for each element in the output tensor.
* @param dtype The data type of the output tensor. Defaults to 'float32'.
*/
function rand_<R extends Rank>(
shape: ShapeMap[R], randFunction: () => number,
dtype?: DataType): Tensor<R> {
const size = util.sizeFromShape(shape);

let values = null;
if (dtype == null || dtype === 'float32') {
values = new Float32Array(size);
} else if (dtype === 'int32') {
values = new Int32Array(size);
} else if (dtype === 'bool') {
values = new Uint8Array(size);
} else {
throw new Error(`Unknown data type ${dtype}`);
}

for (let i = 0; i < size; i++) {
values[i] = randFunction();
}
return ENGINE.makeTensor(values, shape, dtype) as Tensor<R>;
}

/**
* Creates a `tf.Tensor` with values drawn from a multinomial distribution.
*
* ```js
* const probs = tf.tensor([.75, .25]);
* tf.multinomial(probs, 3).print();
* ```
*
* @param logits 1D array with unnormalized log-probabilities, or
* 2D array of shape `[batchSize, numOutcomes]`. See the `normalized`
* parameter.
* @param numSamples Number of samples to draw for each row slice.
* @param seed The seed number.
* @param normalized Whether the provided `logits` are normalized true
* probabilities (sum to 1). Defaults to false.
* @return 1D array of shape `[numSamples]`, or 2D array of shape
* `[batchSize, numSamples]`, depending on the rank of the input.
*/
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function multinomial_(
logits: Tensor1D|Tensor2D|TensorLike, numSamples: number, seed?: number,
normalized = false): Tensor1D|Tensor2D {
const $logits = convertToTensor(logits, 'logits', 'multinomial');
const numOutcomes = $logits.size;
const origRank = $logits.rank;
if (numOutcomes < 2) {
throw new Error(
`Error in multinomial: you need at least 2 outcomes, but got ` +
`${numOutcomes}.`);
}
if (origRank > 2) {
throw new Error(`Rank of probabilities must be 1 or 2, but is ${origRank}`);
}
seed = seed || Math.random();
const logits2D = origRank === 1 ? $logits.as2D(1, -1) : $logits as Tensor2D;
const res = ENGINE.runKernelFunc(
backend => backend.multinomial(logits2D, normalized, numSamples, seed),
{logits2D});

return origRank === 1 ? res.as1D() : res;
}

/**
* Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
* value `onValue` (defaults to 1), while all other locations take value
Expand Down Expand Up @@ -1100,7 +898,7 @@ async function setdiff1dAsync_(
* zeros.
*/
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function buffer<R extends Rank, D extends DataType = 'float32'>(
export function buffer<R extends Rank, D extends DataType = 'float32'>(
shape: ShapeMap[R], dtype: D = 'float32' as D,
values?: DataTypeMap[D]): TensorBuffer<R, D> {
dtype = dtype || 'float32' as D;
Expand All @@ -1125,8 +923,7 @@ function print<T extends Tensor>(x: T, verbose = false): void {
}

export {
buffer, // Not wrapped in op() since no tensors.
print // Not wrapped in op() since no need to increase stack trace.
print // Not wrapped in op() since no need to increase stack trace.
};

export const batchToSpaceND = op({batchToSpaceND_});
Expand All @@ -1136,22 +933,16 @@ export const cumsum = op({cumsum_});
export const depthToSpace = op({depthToSpace_});
export const expandDims = op({expandDims_});
export const eye = op({eye_});
export const multinomial = op({multinomial_});
export const oneHot = op({oneHot_});
export const pad = op({pad_});
export const pad1d = op({pad1d_});
export const pad2d = op({pad2d_});
export const pad3d = op({pad3d_});
export const pad4d = op({pad4d_});
export const rand = op({rand_});
export const randomNormal = op({randomNormal_});
export const randomGamma = op({randomGamma_});
export const randomUniform = op({randomUniform_});
export const reshape = op({reshape_});
export const spaceToBatchND = op({spaceToBatchND_});
export const squeeze = op({squeeze_});
export const stack = op({stack_});
export const tile = op({tile_});
export const truncatedNormal = op({truncatedNormal_});
export const unstack = op({unstack_});
export const setdiff1dAsync = setdiff1dAsync_;
Loading