-
Notifications
You must be signed in to change notification settings - Fork 18
/
operand.ts
116 lines (100 loc) · 2.96 KB
/
operand.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import {MLGraphBuilder} from './graph_builder';
import {Operation} from './operation';
import {ArrayBufferView} from './types';
import * as utils from './utils';
/**
 * Element data types an operand can carry. Each member's string value equals
 * its key.
 *
 * [spec](https://webmachinelearning.github.io/webnn/#enumdef-MLOperandDataType)
 */
export enum MLOperandDataType {
  float32 = 'float32',
  float16 = 'float16',
  int32 = 'int32',
  uint32 = 'uint32',
  int8 = 'int8',
  uint8 = 'uint8',
}
/**
 * Describes an operand: its element data type and its shape.
 *
 * [spec](https://webmachinelearning.github.io/webnn/#dictdef-mloperanddescriptor)
 */
export interface MLOperandDescriptor {
  /** Element data type of the operand. */
  dataType: MLOperandDataType;
  /**
   * The operand's shape (what `MLOperand.shape()` returns a copy of).
   *
   * NOTE(review): although declared as required, scalar constants in this
   * file are built with `{dataType: type} as MLOperandDescriptor`, so this
   * field can be undefined at runtime. Readers should guard against that.
   */
  dimensions: number[];
}
/**
 * [spec](https://webmachinelearning.github.io/webnn/#api-mloperand)
 */
export class MLOperand {
  private readonly builder_: MLGraphBuilder;
  readonly desc: MLOperandDescriptor;

  /** @internal */
  get builder(): MLGraphBuilder {
    return this.builder_;
  }

  /** @internal */
  constructor(builder: MLGraphBuilder, desc: MLOperandDescriptor) {
    this.builder_ = builder;
    this.desc = desc;
  }

  /** Returns the element data type stored in the descriptor. */
  dataType(): MLOperandDataType {
    return this.desc.dataType;
  }

  /**
   * Returns a fresh copy of the descriptor's dimensions, so callers may mutate
   * the result freely. Returns `[]` when the descriptor has no `dimensions`
   * (scalar constants are created that way).
   */
  shape(): number[] {
    const dims = this.desc.dimensions;
    return dims ? dims.slice() : [];
  }

  /**
   * Returns the number of dimensions. A descriptor without `dimensions` is
   * rank 0, matching `shape()` returning `[]`. Previously this threw a
   * TypeError for such descriptors.
   */
  rank(): number {
    const dims = this.desc.dimensions;
    return dims ? dims.length : 0;
  }
}
/** @internal */
export class InputOperand extends MLOperand {
  /** Name given to this graph input. */
  readonly name: string;

  /**
   * @param name Input name; asserted to be a string.
   * @param desc Operand descriptor; validated via
   *     `utils.validateOperandDescriptor` after the base class stores it.
   * @param builder Owning graph builder.
   */
  constructor(
      name: string, desc: MLOperandDescriptor, builder: MLGraphBuilder) {
    super(builder, desc);
    const nameIsString = typeof name === 'string';
    utils.assert(nameIsString, 'The name parameter is invalid');
    this.name = name;
    utils.validateOperandDescriptor(desc);
  }
}
/** @internal */
export class ConstantOperand extends MLOperand {
  /** A number for scalar constants, a copied typed array for tensors. */
  readonly value: number|ArrayBufferView;

  /**
   * Creates a scalar constant.
   *
   * The descriptor built here has no `dimensions`. The base-class `shape()`
   * reports `[]` for such a descriptor.
   *
   * @param value Scalar value; checked with `utils.validateValueType`.
   * @param type Data type. It defaults to float32, but because `builder`
   *     follows it, callers must pass `undefined` explicitly to get the default.
   * @param builder Owning graph builder.
   */
  static createScalar(
      value: number, type: MLOperandDataType = MLOperandDataType.float32,
      builder: MLGraphBuilder): ConstantOperand {
    // Own-property check: the `in` operator also matches inherited names such
    // as 'toString' or 'constructor' from Object.prototype, which are not data
    // types. Enum keys equal their values, so a key lookup validates the value.
    utils.assert(
        Object.prototype.hasOwnProperty.call(MLOperandDataType, type),
        'The operand data type is invalid.');
    utils.validateValueType(value, type);
    return new ConstantOperand(
        {dataType: type} as MLOperandDescriptor, value, builder);
  }

  /**
   * Creates a tensor constant from a typed array.
   *
   * @param desc Operand descriptor; validated, then checked against `value`.
   * @param value Typed array holding the tensor data. It is copied, so later
   *     writes to the caller's array do not affect the constant.
   * @param builder Owning graph builder.
   */
  static createTensor(
      desc: MLOperandDescriptor, value: ArrayBufferView,
      builder: MLGraphBuilder): ConstantOperand {
    utils.assert(
        utils.isTypedArray(value),
        'Only ArrayBufferView value type is supported.');
    utils.validateOperandDescriptor(desc);
    utils.validateTypedArray(value, desc.dataType, desc.dimensions);
    return new ConstantOperand(desc, value.slice(), builder);
  }

  private constructor(
      desc: MLOperandDescriptor, value: number|ArrayBufferView,
      builder: MLGraphBuilder) {
    super(builder, desc);
    this.value = value;
  }
}
/** @ignore */
export class OutputOperand extends MLOperand {
  /**
   * @param operation The operation producing this operand. Its `builder` is
   *     passed to the base class, and the operation is kept as `operation`.
   * @param desc Descriptor of the produced value.
   */
  constructor(readonly operation: Operation, desc: MLOperandDescriptor) {
    super(operation.builder, desc);
  }
}