Skip to content

Commit

Permalink
Merge pull request #226 from mingmingtasd/npu_fp16
Browse files Browse the repository at this point in the history
Add NPU device type and three fp16 models for image classification
  • Loading branch information
huningxin committed May 7, 2024
2 parents 3ce8c0d + c53f0d3 commit 480ab02
Show file tree
Hide file tree
Showing 9 changed files with 566 additions and 38 deletions.
12 changes: 12 additions & 0 deletions common/component/component.js
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,18 @@ $(document).ready(async () => {
"title",
"WebNN is supported, disable WebNN Polyfill."
);
// Disable WebNN NPU backend if failed to find a capable NPU adapter.
try {
await navigator.ml.createContext({deviceType: 'npu'});
} catch (error) {
$('#webnn_npu').parent().addClass('disabled');
$('#webnn_npu').parent().addClass('btn-outline-secondary');
$('#webnn_npu').parent().removeClass('btn-outline-info');
$('#webnn_npu').parent().attr(
"title",
"Unable to find a capable NPU adapter."
);
}
}
}
$("#webnnstatus").html("supported").addClass("webnn-status-true");
Expand Down
17 changes: 17 additions & 0 deletions common/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ export function handleClick(cssSelectors, disabled = true) {
}
}

/**
* Handle button UI, disable or enable the button.
* @param {String} selector, css selector.
* @param {Boolean} disabled, disable or enable the button.
*/
export function handleBtnUI(selector, disabled = true) {
if (disabled) {
$(selector).addClass('disabled');
$(selector).addClass('btn-outline-secondary');
$(selector).removeClass('btn-outline-info');
} else {
$(selector).removeClass('disabled');
$(selector).removeClass('btn-outline-secondary');
$(selector).addClass('btn-outline-info');
}
}

/**
* Show flexible alert messages
* @param {String} msg, alert message.
Expand Down
1 change: 1 addition & 0 deletions image_classification/.eslintrc.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module.exports = {
globals: {
'MLGraphBuilder': 'readonly',
'tf': 'readonly',
},
};
176 changes: 176 additions & 0 deletions image_classification/efficientnet_fp16_nchw.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
'use strict';

import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';

// EfficientNet fp16 model with 'nchw' input layout
export class EfficientNetFP16Nchw {
constructor() {
this.context_ = null;
this.builder_ = null;
this.graph_ = null;
this.weightsUrl_ = weightsOrigin() +
'/test-data/models/efficientnet_fp16_nchw_optimized/weights/';
this.inputOptions = {
mean: [0.485, 0.456, 0.406],
std: [0.229, 0.224, 0.225],
norm: true,
inputLayout: 'nchw',
labelUrl: './labels/labels1000.txt',
inputDimensions: [1, 3, 224, 224],
};
this.outputDimensions = [1, 1000];
}

async buildConv_(input, name, blockName, clip = false, options = {}) {
let prefix = '';
if (blockName !== '') {
prefix = this.weightsUrl_ + 'block' + blockName + '_conv' +
name;
} else {
prefix = this.weightsUrl_ + 'conv' + name;
}
const weight = buildConstantByNpy(this.builder_, prefix + '_w.npy',
'float16');
options.bias = await buildConstantByNpy(this.builder_, prefix + '_b.npy',
'float16');
if (clip) {
return this.builder_.clamp(
this.builder_.conv2d(await input, await weight, options),
{minValue: 0, maxValue: 6});
}
return this.builder_.conv2d(await input, await weight, options);
}

async buildGemm_(input, name) {
const prefix = this.weightsUrl_ + 'dense' + name;
const weightName = prefix + '_w.npy';
const weight = buildConstantByNpy(this.builder_, weightName,
'float16');
const biasName = prefix + '_b.npy';
const bias = buildConstantByNpy(this.builder_, biasName,
'float16');
const options =
{c: this.builder_.reshape(await bias, [1, 1000])};
return await this.builder_.gemm(await input, await weight, options);
}

async buildBottleneck_(input, blockName, group, pad = 1) {
const conv1 = this.buildConv_(input, '0', blockName, true);
const conv2 = this.buildConv_(conv1, '1', blockName, true,
{groups: group, padding: [pad, pad, pad, pad]});
const conv3 = this.buildConv_(conv2, '2', blockName);
return this.builder_.add(await conv3, await input);
}

async buildBottlenecks_(input, blockNames, group, pad = 1) {
let result = input;
for (let i = 0; i < blockNames.length; i++) {
const bottleneck = await this.buildBottleneck_(result, blockNames[i],
group, pad);
result = bottleneck;
}
return result;
}

async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
this.builder_ = new MLGraphBuilder(this.context_);
let data = this.builder_.input('input', {
dataType: 'float32',
dimensions: this.inputOptions.inputDimensions,
});
data = this.builder_.cast(data, 'float16');
// Block 0
const conv1 = this.buildConv_(
data, '0', '0', true, {padding: [0, 1, 0, 1], strides: [2, 2]});
const conv2 = this.buildConv_(conv1, '1', '0', true,
{groups: 32, padding: [1, 1, 1, 1]});
const conv3 = this.buildConv_(conv2, '2', '0');

// Block 1
const conv4 = this.buildConv_(conv3, '0', '1', true);
const conv5 = this.buildConv_(conv4, '1', '1', true,
{groups: 144, padding: [0, 1, 0, 1], strides: [2, 2]});
const conv6 = this.buildConv_(conv5, '2', '1');

// Block 2~4
const bottleneck4 = this.buildBottlenecks_(conv6,
['2', '3', '4'], 192);

// Block 5
const conv7 = this.buildConv_(bottleneck4, '0', '5', true);
const conv8 = this.buildConv_(conv7, '1', '5', true,
{groups: 192, padding: [1, 2, 1, 2], strides: [2, 2]});
const conv9 = this.buildConv_(conv8, '2', '5');

// Block 6~8
const bottleneck8 = this.buildBottlenecks_(conv9,
['6', '7', '8'], 336, 2);

// Block 9
const conv10 = this.buildConv_(bottleneck8, '0', '9', true);
const conv11 = this.buildConv_(conv10, '1', '9', true,
{groups: 336, padding: [0, 1, 0, 1], strides: [2, 2]});
const conv12 = this.buildConv_(conv11, '2', '9');

// Block 10~14
const bottleneck14 = this.buildBottlenecks_(conv12,
['10', '11', '12', '13', '14'], 672);

// Block 15
const conv13 = this.buildConv_(bottleneck14, '0', '15', true);
const conv14 = this.buildConv_(conv13, '1', '15', true,
{groups: 672, padding: [2, 2, 2, 2]});
const conv15 = this.buildConv_(conv14, '2', '15');

// Block 16~20
const bottleneck20 = await this.buildBottlenecks_(conv15,
['16', '17', '18', '19', '20'], 960, 2);

// Block 21
const conv16 = this.buildConv_(bottleneck20, '0', '21', true);
const conv17 = this.buildConv_(conv16, '1', '21', true,
{groups: 960, padding: [1, 2, 1, 2], strides: [2, 2]});
const conv18 = this.buildConv_(conv17, '2', '21');

// Block 22~28
const bottleneck28 = this.buildBottlenecks_(conv18,
['22', '23', '24', '25', '26', '27', '28'], 1632, 2);

// Block 29
const conv19 = this.buildConv_(bottleneck28, '0', '29', true);
const conv20 = this.buildConv_(conv19, '1', '29', true,
{groups: 1632, padding: [1, 1, 1, 1]});
const conv21 = this.buildConv_(conv20, '2', '29');

const conv22 = this.buildConv_(conv21, '0', '', true);
const pool1 = this.builder_.averagePool2d(await conv22);
const reshape = this.builder_.reshape(pool1, [1, 1280]);
const gemm = this.buildGemm_(reshape, '0');
if (contextOptions.deviceType === 'npu') {
return this.builder_.cast(await gemm, 'float32');
} else {
const softmax = this.builder_.softmax(await gemm);
return this.builder_.cast(softmax, 'float32');
}
}

async build(outputOperand) {
this.graph_ = await this.builder_.build({'output': outputOperand});
}

// Release the constant tensors of a model
dispose() {
// dispose() is only available in webnn-polyfill
if (this.graph_ !== null && 'dispose' in this.graph_) {
this.graph_.dispose();
}
}

async compute(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputs = {'output': outputBuffer};
const results = await this.context_.compute(this.graph_, inputs, outputs);
return results;
}
}
46 changes: 37 additions & 9 deletions image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_gpu" autocomplete="off">WebNN (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_npu" autocomplete="off">WebNN (NPU)
</label>
</div>
</div>
</div>
Expand All @@ -61,21 +64,43 @@
</div>
</div>
</div> -->
<div class="row mb-2 align-items-center">
<div class="col-1 col-md-1">
<span>Data Type</span>
</div>
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="dataTypeBtns">
<label class="btn btn-outline-info" id="float32Label" active>
<input type="radio" name="layout" id="float32" autocomplete="off" checked>Float32
</label>
<label class="btn btn-outline-info" id="float16Label">
<input type="radio" name="layout" id="float16" autocomplete="off">Float16
</label>
</div>
</div>
</div>
<div class="row align-items-center">
<div class="col col-md-1">
<span>Model</span>
</div>
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="modelBtns">
<label class="btn btn-outline-info">
<input type="radio" name="model" id="mobilenet" autocomplete="off">MobileNet V2
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="squeezenet" autocomplete="off">SqueezeNet
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="resnet50" autocomplete="off">ResNet V2 50
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="mobilenet" autocomplete="off">MobileNet V2
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="squeezenet" autocomplete="off">SqueezeNet
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="resnet50v2" autocomplete="off">ResNet 50 V2
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="resnet50v1" autocomplete="off">ResNet 50 V1
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="efficientnet" autocomplete="off">EfficientNet
</label>

</div>
</div>
</div>
Expand Down Expand Up @@ -213,6 +238,9 @@ <h2 class="text-uppercase text-info">No model selected</h2>
<script src="https://cdn.jsdelivr.net/npm/popper.js@1.16.1/dist/umd/popper.min.js"
integrity="sha384-9/reFTGAW83EW2RDu2S0VKaIzap3H66lZH81PoYlFhbGU+6BZp6G7niu735Sk7lN"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.9.0/dist/tf.min.js"
integrity="sha256-28ZvjeNGrGNEIj9/2D8YAPE6Vm5JSvvDs+LI4ED31x8="
crossorigin="anonymous"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"
integrity="sha384-B4gt1jrGC7Jh4AgTPSdUtOBvfO8shuf57BaghqFfPlYxofvL8/KUEfYiJOMMV+rV"
crossorigin="anonymous"></script>
Expand Down

0 comments on commit 480ab02

Please sign in to comment.