Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tfjs-core/src/backends/backend_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export function castTensor<T extends Tensor>(
const real = backend.real(x);
const result = real.cast(dtype);
real.dispose();
return result;
return result as T;
}
if (dtype === 'int32') {
return backend.int(x);
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ export class Engine implements TensorTracker, DataMover {
dtype?: DataType): Variable {
name = name || this.nextVariableId().toString();
if (dtype != null && dtype !== initialValue.dtype) {
initialValue = initialValue.asType(dtype);
initialValue = initialValue.cast(dtype);
}
const v = new Variable(initialValue, trainable, name, this.nextTensorId());
if (this.state.registeredVariables[v.name] != null) {
Expand Down
32 changes: 32 additions & 0 deletions tfjs-core/src/public/chained_ops/abs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

// TODO update import path once op is modularized.
import {abs} from '../../ops/ops';
import {Tensor} from '../../tensor';
import {Rank} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
abs<T extends Tensor>(this: T): T;
}
}

Tensor.prototype.abs = function<T extends Tensor>(this: T) {
this.throwIfDisposed();
return abs(this);
};
32 changes: 32 additions & 0 deletions tfjs-core/src/public/chained_ops/acos.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

// TODO update import path once op is modularized.
import {acos} from '../../ops/ops';
import {Tensor} from '../../tensor';
import {Rank} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
acos<T extends Tensor>(this: T): T;
}
}

Tensor.prototype.acos = function<T extends Tensor>(this: T) {
this.throwIfDisposed();
return acos(this);
};
32 changes: 32 additions & 0 deletions tfjs-core/src/public/chained_ops/acosh.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

// TODO update import path once op is modularized.
import {acosh} from '../../ops/ops';
import {Tensor} from '../../tensor';
import {Rank} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
acosh<T extends Tensor>(this: T): T;
}
}

Tensor.prototype.acosh = function<T extends Tensor>(this: T) {
this.throwIfDisposed();
return acosh(this);
};
36 changes: 36 additions & 0 deletions tfjs-core/src/public/chained_ops/add_strict.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

// TODO update import path once op is modularized.
import {addStrict} from '../../ops/ops';
import {Tensor} from '../../tensor';
import {Rank, TensorLike} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
addStrict<T extends Tensor>(this: T, x: T|TensorLike): T;
}
}

/**
* @deprecated strict variants of ops have been deprecated
*/
Tensor.prototype.addStrict = function<T extends Tensor>(
this: T, x: T|TensorLike) {
this.throwIfDisposed();
return addStrict(this, x);
};
5 changes: 2 additions & 3 deletions tfjs-core/src/public/chained_ops/arg_min.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ declare module '../../tensor' {
}
}

Tensor.prototype.argMin = function<T extends Tensor>(axis?: number): T {
Tensor.prototype.argMin = function<T extends Tensor>(axis: number): T {
this.throwIfDisposed();
// tslint:disable-next-line: no-unnecessary-type-assertion
return argMin(this, axis) as T;
return argMin(this, axis);
};
33 changes: 33 additions & 0 deletions tfjs-core/src/public/chained_ops/as1d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {reshape} from '../../ops/reshape';
import {Tensor} from '../../tensor';
import {Rank} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
as1D<T extends Tensor>(): Tensor1D;
}
}

/** Converts a `tf.Tensor` to a `tf.Tensor1D`. */
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.as1D = function<T extends Tensor>(): T {
this.throwIfDisposed();
return reshape(this, [this.size]) as T;
};
39 changes: 39 additions & 0 deletions tfjs-core/src/public/chained_ops/as2d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {reshape} from '../../ops/reshape';
import {Tensor} from '../../tensor';
import {Rank} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
as2D<T extends Tensor>(rows: number, columns: number): Tensor2D;
}
}

/**
* Converts a `tf.Tensor` to a `tf.Tensor2D`.
*
* @param rows Number of rows in `tf.Tensor2D`.
* @param columns Number of columns in `tf.Tensor2D`.
*/
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.as2D = function<T extends Tensor>(
rows: number, columns: number): T {
this.throwIfDisposed();
return reshape(this, [rows, columns]) as T;
};
41 changes: 41 additions & 0 deletions tfjs-core/src/public/chained_ops/as3d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {reshape} from '../../ops/reshape';
import {Tensor} from '../../tensor';
import {Rank} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
as3D<T extends Tensor>(rows: number, columns: number, depth: number):
Tensor3D;
}
}

/**
* Converts a `tf.Tensor` to a `tf.Tensor3D`.
*
* @param rows Number of rows in `tf.Tensor3D`.
* @param columns Number of columns in `tf.Tensor3D`.
* @param depth Depth of `tf.Tensor3D`.
*/
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.as3D = function<T extends Tensor>(
rows: number, columns: number, depth: number): T {
this.throwIfDisposed();
return reshape(this, [rows, columns, depth]) as T;
};
42 changes: 42 additions & 0 deletions tfjs-core/src/public/chained_ops/as4d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {reshape} from '../../ops/reshape';
import {Tensor} from '../../tensor';
import {Rank} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
as4D<T extends Tensor>(
rows: number, columns: number, depth: number, depth2: number): Tensor4D;
}
}

/**
* Converts a `tf.Tensor` to a `tf.Tensor4D`.
*
* @param rows Number of rows in `tf.Tensor4D`.
* @param columns Number of columns in `tf.Tensor4D`.
* @param depth Depth of `tf.Tensor4D`.
* @param depth2 4th dimension of `tf.Tensor4D`.
*/
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.as4D = function<T extends Tensor>(
rows: number, columns: number, depth: number, depth2: number): T {
this.throwIfDisposed();
return reshape(this, [rows, columns, depth, depth2]) as T;
};
45 changes: 45 additions & 0 deletions tfjs-core/src/public/chained_ops/as5d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {reshape} from '../../ops/reshape';
import {Tensor} from '../../tensor';
import {Rank} from '../../types';

declare module '../../tensor' {
interface Tensor<R extends Rank = Rank> {
as5D<T extends Tensor>(
rows: number, columns: number, depth: number, depth2: number,
depth3: number): Tensor5D;
}
}

/**
* Converts a `tf.Tensor` to a `tf.Tensor5D`.
*
* @param rows Number of rows in `tf.Tensor5D`.
* @param columns Number of columns in `tf.Tensor5D`.
* @param depth Depth of `tf.Tensor5D`.
* @param depth2 4th dimension of `tf.Tensor5D`.
* @param depth3 5th dimension of 'tf.Tensor5D'
*/
/** @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.as5D = function<T extends Tensor>(
rows: number, columns: number, depth: number, depth2: number,
depth3: number): T {
this.throwIfDisposed();
return reshape(this, [rows, columns, depth, depth2, depth3]) as T;
};
Loading