Large diffs are not rendered by default.

@@ -0,0 +1,18 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

export {MobileNet} from './mobilenet';
@@ -0,0 +1,179 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line:max-line-length
import {Array1D, Array3D, Array4D, CheckpointLoader, initializeGPU, Model, NDArray, NDArrayMathCPU, NDArrayMath, NDArrayMathGPU, Scalar} from 'deeplearn';
import {IMAGENET_CLASSES} from './imagenet_classes';

const GOOGLE_CLOUD_STORAGE_DIR =
'https://storage.googleapis.com/learnjs-data/checkpoint_zoo/';

export class MobileNet implements Model {
  // Checkpoint variables, keyed by their TensorFlow variable name.
  private variables: {[varName: string]: NDArray};

  // Pre-processing constants: pixels in [0, 255] are mapped into [-1, 1]
  // via x / (255 / 2) - 1.
  private PREPROCESS_DIVISOR = Scalar.new(255.0 / 2);
  private ONE = Scalar.ONE;

  constructor(private math: NDArrayMath) {
    // TODO(nsthorat): This awful hack is because we need to share the global
    // GPGPU between deeplearn loaded from standalone as well as the internal
    // deeplearn that gets compiled as part of this model. Remove this once we
    // decouple NDArray from storage mechanism.
    initializeGPU(
        (this.math as NDArrayMathGPU).getGPGPUContext(),
        (this.math as NDArrayMathGPU).getTextureManager());
  }

  /**
   * Loads the MobileNet v1 (1.0, 224) checkpoint variables from Google
   * Cloud Storage. Must complete before predict() is called.
   */
  async load(): Promise<void> {
    const checkpointLoader = new CheckpointLoader(
        GOOGLE_CLOUD_STORAGE_DIR + 'mobilenet_v1_1.0_224/');
    this.variables = await checkpointLoader.getAllVariables();
  }

  /**
   * Infers through MobileNet; assumes variables have been loaded. Applies
   * the standard [-1, 1] ImageNet pre-processing before inference.
   *
   * @param input un-preprocessed input image Array.
   * @return The pre-softmax logits as a 1-D array.
   */
  predict(input: Array3D): Array1D {
    return this.math.scope(() => {
      // Map [0, 255] pixel values into [-1, 1].
      const preprocessed = this.math.subtract(
          this.math.arrayDividedByScalar(input, this.PREPROCESS_DIVISOR),
          this.ONE) as Array3D;

      // Stem convolution followed by 13 depthwise-separable blocks.
      // Blocks 2, 4, 6 and 12 downsample with stride 2.
      let net = this.convBlock(preprocessed, [2, 2]);
      for (let blockID = 1; blockID <= 13; blockID++) {
        const strides: [number, number] =
            (blockID === 2 || blockID === 4 || blockID === 6 ||
             blockID === 12) ? [2, 2] : [1, 1];
        net = this.depthwiseConvBlock(net, strides, blockID);
      }

      // Global average pool, then the 1x1 classifier convolution.
      const pooled = this.math.avgPool(net, net.shape[0], 1, 0);
      const logits = this.math.conv2d(
          pooled,
          this.variables['MobilenetV1/Logits/Conv2d_1c_1x1/weights'] as
              Array4D,
          this.variables['MobilenetV1/Logits/Conv2d_1c_1x1/biases'] as Array1D,
          1, 'same');

      return logits.as1D();
    });
  }

  // Stem block: convolution (no bias) -> batch-norm -> ReLU6.
  private convBlock(inputs: Array3D, strides: [number, number]) {
    const prefix = 'MobilenetV1/Conv2d_0';

    const conv = this.math.conv2d(
        inputs, this.variables[prefix + '/weights'] as Array4D,
        null,  // this convolutional layer does not use bias
        strides, 'same');

    const normed = this.math.batchNormalization3D(
        conv, this.variables[prefix + '/BatchNorm/moving_mean'] as Array1D,
        this.variables[prefix + '/BatchNorm/moving_variance'] as Array1D, .001,
        this.variables[prefix + '/BatchNorm/gamma'] as Array1D,
        this.variables[prefix + '/BatchNorm/beta'] as Array1D);

    // clip(x, 0, 6) is a simple implementation of ReLU6.
    return this.math.clip(normed, 0, 6);
  }

  // Depthwise-separable block: depthwise conv -> BN -> ReLU6 ->
  // pointwise (1x1) conv -> BN -> ReLU6.
  private depthwiseConvBlock(
      inputs: Array3D, strides: [number, number], blockID: number) {
    const dwPrefix = `MobilenetV1/Conv2d_${blockID}_depthwise`;
    const pwPrefix = `MobilenetV1/Conv2d_${blockID}_pointwise`;

    const dwConv = this.math.depthwiseConv2D(
        inputs, this.variables[dwPrefix + '/depthwise_weights'] as Array4D,
        strides, 'same') as Array3D;

    const dwNormed = this.math.batchNormalization3D(
        dwConv, this.variables[dwPrefix + '/BatchNorm/moving_mean'] as Array1D,
        this.variables[dwPrefix + '/BatchNorm/moving_variance'] as Array1D,
        .001, this.variables[dwPrefix + '/BatchNorm/gamma'] as Array1D,
        this.variables[dwPrefix + '/BatchNorm/beta'] as Array1D);

    const dwActivated = this.math.clip(dwNormed, 0, 6);

    const pwConv = this.math.conv2d(
        dwActivated, this.variables[pwPrefix + '/weights'] as Array4D,
        null,  // this convolutional layer does not use bias
        [1, 1], 'same');

    const pwNormed = this.math.batchNormalization3D(
        pwConv, this.variables[pwPrefix + '/BatchNorm/moving_mean'] as Array1D,
        this.variables[pwPrefix + '/BatchNorm/moving_variance'] as Array1D,
        .001, this.variables[pwPrefix + '/BatchNorm/gamma'] as Array1D,
        this.variables[pwPrefix + '/BatchNorm/beta'] as Array1D);

    return this.math.clip(pwNormed, 0, 6);
  }

  /**
   * Get the topK classes for pre-softmax logits. Returns a map of className
   * to softmax-normalized probability.
   *
   * @param logits Pre-softmax logits array.
   * @param topK How many top classes to return.
   */
  async getTopKClasses(logits: Array1D, topK: number):
      Promise<{[className: string]: number}> {
    const predictions = this.math.softmax(logits);
    // topK is computed on the CPU math backend.
    const topk = new NDArrayMathCPU().topK(predictions, topK);
    const classIndices = await topk.indices.data();
    const classValues = await topk.values.data();

    const topClassesToProbability: {[className: string]: number} = {};
    classIndices.forEach((classIndex: number, i: number) => {
      topClassesToProbability[IMAGENET_CLASSES[classIndex]] = classValues[i];
    });
    return topClassesToProbability;
  }

  // Frees every buffer owned by this model's checkpoint variables.
  dispose() {
    Object.keys(this.variables)
        .forEach((varName) => this.variables[varName].dispose());
  }
}
@@ -0,0 +1,31 @@
{
"name": "deeplearn-mobilenet",
"version": "0.1.2",
"description": "Pretrained MobileNet model in deeplearn.js",
"main": "dist/index.js",
"unpkg": "dist/bundle.js",
"types": "dist/index.d.ts",
"peerDependencies": {
"deeplearn": "~0.3.11"
},
"repository": {
"type": "git",
"url": "https://github.com/PAIR-code/deeplearnjs.git"
},
"devDependencies": {
"deeplearn": "~0.3.11",
"mkdirp": "~0.5.1",
"tsify": "~3.0.3",
"tslint": "~5.8.0",
"typescript": "~2.6.1",
"uglifyjs": "~2.4.11",
"watchify": "~3.9.0"
},
"scripts": {
"prep": "yarn && mkdirp dist",
"build": "browserify --standalone mobilenet mobilenet.ts -p [tsify] -o dist/bundle.js",
"lint": "tslint -p . -t verbose",
"publish-npm": "tsc --sourceMap false && yarn build && npm publish"
},
"license": "Apache-2.0"
}
@@ -0,0 +1,11 @@
{
"extends": "../../tsconfig.json",
"include": [
"index.ts", "mobiletnet.ts"
],
"exclude": [
"node_modules/",
"dist/"
]
}

@@ -202,4 +202,4 @@ export class SqueezeNet implements Model {
this.variables[varName].dispose();
}
}
}
}
@@ -17,6 +17,10 @@ acorn@^4.0.3:
version "4.0.13"
resolved "https://registry.yarnpkg.com/acorn/-/acorn-4.0.13.tgz#105495ae5361d697bd195c825192e1ad7f253787"

acorn@^5.2.1:
version "5.2.1"
resolved "https://registry.yarnpkg.com/acorn/-/acorn-5.2.1.tgz#317ac7821826c22c702d66189ab8359675f135d7"

ajv@^4.9.1:
version "4.11.8"
resolved "https://registry.yarnpkg.com/ajv/-/ajv-4.11.8.tgz#82ffb02b29e662ae53bdc20af15947706739c536"
@@ -474,8 +478,8 @@ combined-stream@^1.0.5, combined-stream@~1.0.5:
delayed-stream "~1.0.0"

commander@^2.9.0:
version "2.11.0"
resolved "https://registry.yarnpkg.com/commander/-/commander-2.11.0.tgz#157152fd1e7a6c8d98a5b715cf376df928004563"
version "2.12.2"
resolved "https://registry.yarnpkg.com/commander/-/commander-2.12.2.tgz#0f5946c427ed9ec0d91a46bb9def53e54650e555"

concat-map@0.0.1:
version "0.0.1"
@@ -504,8 +508,8 @@ constants-browserify@~1.0.0:
resolved "https://registry.yarnpkg.com/constants-browserify/-/constants-browserify-1.0.0.tgz#c20b96d8c617748aaf1c16021760cd27fcb8cb75"

convert-source-map@^1.1.0:
version "1.5.0"
resolved "https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-1.5.0.tgz#9acd70851c6d5dfdd93d9282e5edf94a03ff46b5"
version "1.5.1"
resolved "https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-1.5.1.tgz#b8278097b9bc229365de5c62cf5fcaed8b5599e5"

convert-source-map@~1.1.0:
version "1.1.3"
@@ -619,14 +623,14 @@ des.js@^1.0.0:
minimalistic-assert "^1.0.0"

detect-libc@^1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/detect-libc/-/detect-libc-1.0.2.tgz#71ad5d204bf17a6a6ca8f450c61454066ef461e1"
version "1.0.3"
resolved "https://registry.yarnpkg.com/detect-libc/-/detect-libc-1.0.3.tgz#fa137c4bd698edf55cd5cd02ac559f91a4c4ba9b"

detective@^4.0.0:
version "4.5.0"
resolved "https://registry.yarnpkg.com/detective/-/detective-4.5.0.tgz#6e5a8c6b26e6c7a254b1c6b6d7490d98ec91edd1"
version "4.7.0"
resolved "https://registry.yarnpkg.com/detective/-/detective-4.7.0.tgz#6276e150f9e50829ad1f90ace4d9a2304188afcf"
dependencies:
acorn "^4.0.3"
acorn "^5.2.1"
defined "^1.0.0"

diff@^3.2.0:
@@ -716,10 +720,14 @@ extglob@^0.3.1:
dependencies:
is-extglob "^1.0.0"

extsprintf@1.3.0, extsprintf@^1.2.0:
extsprintf@1.3.0:
version "1.3.0"
resolved "https://registry.yarnpkg.com/extsprintf/-/extsprintf-1.3.0.tgz#96918440e3041a7a414f8c52e3c574eb3c3e1e05"

extsprintf@^1.2.0:
version "1.4.0"
resolved "https://registry.yarnpkg.com/extsprintf/-/extsprintf-1.4.0.tgz#e2689f8f356fad62cca65a3a91c5df5f9551692f"

filename-regex@^2.0.0:
version "2.0.1"
resolved "https://registry.yarnpkg.com/filename-regex/-/filename-regex-2.0.1.tgz#c1c4b9bee3e09725ddb106b75c1e301fe2f18b26"
@@ -947,8 +955,8 @@ inherits@2.0.1:
resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.1.tgz#b17d08d326b4423e568eff719f91b0b1cbdf69f1"

ini@~1.3.0:
version "1.3.4"
resolved "https://registry.yarnpkg.com/ini/-/ini-1.3.4.tgz#0537cb79daf59b59a1a517dff706c86ec039162e"
version "1.3.5"
resolved "https://registry.yarnpkg.com/ini/-/ini-1.3.5.tgz#eee25f56db1c9ec6085e0c22778083f596abf927"

inline-source-map@~0.6.0:
version "0.6.2"
@@ -1790,8 +1798,8 @@ tsconfig@^5.0.3:
strip-json-comments "^2.0.0"

tsify@~3.0.3:
version "3.0.3"
resolved "https://registry.yarnpkg.com/tsify/-/tsify-3.0.3.tgz#a032e1a6a71c2621c3f25c0415459d53b70b9ec0"
version "3.0.4"
resolved "https://registry.yarnpkg.com/tsify/-/tsify-3.0.4.tgz#3c862c934aeeff705290de9ad2af8d197ac5bb03"
dependencies:
convert-source-map "^1.1.0"
fs.realpath "^1.0.0"
@@ -1821,8 +1829,8 @@ tslint@~5.8.0:
tsutils "^2.12.1"

tsutils@^2.12.1:
version "2.12.2"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.12.2.tgz#ad58a4865d17ec3ddb6631b6ca53be14a5656ff3"
version "2.13.0"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.13.0.tgz#0f52b6aabbc4216e72796b66db028c6cf173e144"
dependencies:
tslib "^1.7.1"

@@ -1845,8 +1853,8 @@ typedarray@~0.0.5:
resolved "https://registry.yarnpkg.com/typedarray/-/typedarray-0.0.6.tgz#867ac74e3864187b1d3d47d996a78ec5c8830777"

typescript@~2.6.1:
version "2.6.1"
resolved "https://registry.yarnpkg.com/typescript/-/typescript-2.6.1.tgz#ef39cdea27abac0b500242d6726ab90e0c846631"
version "2.6.2"
resolved "https://registry.yarnpkg.com/typescript/-/typescript-2.6.2.tgz#3c5b6fd7f6de0914269027f03c0946758f7673a4"

uglifyjs@~2.4.11:
version "2.4.11"
@@ -0,0 +1,33 @@
# MobileNet model

This package contains a standalone MobileNet-based object detection model — You Only Look Once (YOLO).

## Installation
You can use this as standalone es5 bundle like this:

```html
<script src="https://unpkg.com/deeplearn-mobilenet"></script>
```

Or you can install it via npm for use in a TypeScript / ES6 project.

```sh
npm install deeplearn-yolo_mobilenet --save-dev
```

## Usage

Check out [demo.html](https://github.com/PAIR-code/deeplearnjs/blob/master/yolo_mobilenet/demo.html)
for an example with ES5.

To run the demo, use the following:

```bash
cd models/yolo_mobilenet
npm run prep
npm run build
# Starts a webserver, navigate to localhost:8000/demo.html.
python -m SimpleHTTPServer
```
@@ -0,0 +1,123 @@
<script src="dist/bundle.js"></script>

<style>
#cat {
position: relative;
top: 0px;
left: 0px;
}

#canvas {
position: absolute;
top: 0px;
left: 0px;
}
</style>

<script src="https://unpkg.com/deeplearn"></script>


<img id="cat" height="416" width="416" crossorigin="anonymous"></img>
<canvas id="canvas" height="416" width="416"></canvas>
<br />
<div id="result">Status</div>
<br />
<select id="image" oninput="loadImage()">
<option value="raccoon1.jpg">raccoon1</option>
<option value="raccoon2.jpg">raccoon2</option>
<option value="raccoon3.jpg">raccoon3</option>
<option value="raccoon4.jpg">raccoon4</option>
<option value="raccoon5.jpg">raccoon5</option>
</select>

<script type="text/javascript">
// Demo driver: loads the YOLO MobileNet weights, runs detection on the
// selected image, and draws labelled bounding boxes onto the overlay canvas.
const cat = document.getElementById('cat');
const resultElement = document.getElementById('result');
const canvas = document.getElementById('canvas');

// Fixed: context/width/height were implicit globals; declare them.
const context = canvas.getContext('2d');
const width = canvas.width;
const height = canvas.height;

const math = new dl.NDArrayMathGPU();
const yoloMobileNet = new yolo_mobilenet.YoloMobileNetDetection(math);

resultElement.innerText = 'Downloading weights ...';
yoloMobileNet.load().then(loadImage);

// Loads the image currently chosen in the <select>; assigning `src`
// triggers `cat.onload`, which runs the detection below.
function loadImage() {
  const input = document.getElementById('image');
  cat.src = input.value;
}

cat.onload = async () => {
  resultElement.innerText = 'Predicting...';
  const pixels = dl.Array3D.fromPixels(cat);

  // Time the forward pass and the CPU post-processing separately.
  // Fixed: t0/t1 were declared twice with `var` and shadowed confusingly.
  const inferenceStart = performance.now();
  const result = await yoloMobileNet.predict(pixels);
  const inferenceTime = performance.now() - inferenceStart;

  const postStart = performance.now();
  const boxes = await yoloMobileNet.interpretNetout(result);
  const postProcessingTime = performance.now() - postStart;

  context.clearRect(0, 0, canvas.width, canvas.height);
  context.beginPath();

  for (const box of boxes) {
    // Box coordinates are normalized; scale to canvas pixels. (x, y) is
    // the box center, so shift by half the size to get the top-left corner.
    const x = (box.x - box.w / 2) * width;
    const y = (box.y - box.h / 2) * height;
    const w = box.w * width;
    const h = box.h * height;

    // Draw the rectangle bounding box.
    context.strokeStyle = box.getColor();
    context.lineWidth = 5;
    context.rect(x, y, w, h);
    context.stroke();

    // Draw the label and the probability on a filled background strip.
    const label = box.getLabel() + ' ' + box.getMaxProb().toFixed(2).toString();
    const font = '24px serif';

    context.font = font;
    context.textBaseline = 'top';
    context.fillStyle = box.getColor();
    const textWidth = context.measureText(label).width;
    context.fillRect(x - 2, y - 24, textWidth, parseInt(font, 10));

    context.fillStyle = 'rgb(255,255,255)';
    context.fillText(label, x - 2, y - 24);
  }

  // Fixed: "Post precessing" typo in the status message.
  resultElement.innerText = 'Complete!, Inference time: ' +
      Math.round(inferenceTime) + 'ms' +
      ', Post processing time: ' + Math.round(postProcessingTime) + 'ms';
};
</script>

<script>
// When the user clicks on the button, toggle between hiding and showing the
// dropdown content.
// NOTE(review): this page contains no element with id "myDropdown" and no
// ".dropbtn" / ".dropdown-content" elements, so this dropdown helper appears
// vestigial (likely copied from another demo) — confirm before removing.
function toggleList() {
document.getElementById("myDropdown").classList.toggle("show");
}

// Close the dropdown menu if the user clicks outside of it: any element
// with class "dropdown-content" that currently has the "show" class has
// that class removed, unless the click landed on a ".dropbtn" element.
window.onclick = function(event) {
if (!event.target.matches('.dropbtn')) {

var dropdowns = document.getElementsByClassName("dropdown-content");
var i;
for (i = 0; i < dropdowns.length; i++) {
var openDropdown = dropdowns[i];
if (openDropdown.classList.contains('show')) {
openDropdown.classList.remove('show');
}
}
}
}
</script>
@@ -0,0 +1,18 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

export {YoloMobileNetDetection} from './yolo_mobilenet';
@@ -0,0 +1,93 @@
/**
 * A single detected box: center (x, y), size (w, h), an objectness
 * confidence `c`, and per-class probabilities `probs`.
 */
export class BoundingBox {
  public x: number;
  public y: number;
  public w: number;
  public h: number;
  public c: number;
  public probs: Float32Array;

  // Lazily-computed caches; -1 means "not computed yet".
  private maxProb = -1;
  private maxIndx = -1;

  public static LABELS = ['raccoon'];
  public static COLORS = ['rgb(43,206,72)'];

  constructor(x: number, y: number, w: number, h: number, conf: number,
              probs: Float32Array) {
    this.x = x;
    this.y = y;
    this.w = w;
    this.h = h;
    this.c = conf;
    this.probs = probs;
  }

  /** Largest entry in `probs` (cached after the first call). */
  public getMaxProb(): number {
    if (this.maxProb === -1) {
      this.maxProb = this.probs.reduce((best, p) => Math.max(best, p));
    }
    return this.maxProb;
  }

  /** Label of the most probable class. */
  public getLabel(): string {
    return BoundingBox.LABELS[this.maxIndex()];
  }

  /** Display color of the most probable class. */
  public getColor(): string {
    return BoundingBox.COLORS[this.maxIndex()];
  }

  /** Intersection-over-union between this box and `box`. */
  public iou(box: BoundingBox): number {
    const intersection = this.intersect(box);
    const union = this.w * this.h + box.w * box.h - intersection;
    return intersection / union;
  }

  /** Index of the most probable class (cached after the first call). */
  private maxIndex(): number {
    if (this.maxIndx === -1) {
      this.maxIndx = this.probs.indexOf(this.getMaxProb());
    }
    return this.maxIndx;
  }

  /** Area of the overlap region between the two boxes. */
  private intersect(box: BoundingBox): number {
    const overlapW = this.overlap(
        [this.x - this.w / 2, this.x + this.w / 2],
        [box.x - box.w / 2, box.x + box.w / 2]);
    const overlapH = this.overlap(
        [this.y - this.h / 2, this.y + this.h / 2],
        [box.y - box.h / 2, box.y + box.h / 2]);
    return overlapW * overlapH;
  }

  /** Length of the overlap of two 1-D intervals; 0 when they are disjoint. */
  private overlap(intervalA: [number, number],
                  intervalB: [number, number]): number {
    const [aLo, aHi] = intervalA;
    const [bLo, bHi] = intervalB;
    if (bHi < aLo || aHi < bLo) {
      return 0;
    }
    return Math.min(aHi, bHi) - Math.max(aLo, bLo);
  }
}
@@ -0,0 +1,31 @@
{
"name": "deeplearn-mobilenet",
"version": "0.1.2",
"description": "Pretrained MobileNet model in deeplearn.js",
"main": "dist/index.js",
"unpkg": "dist/bundle.js",
"types": "dist/index.d.ts",
"peerDependencies": {
"deeplearn": "~0.3.11"
},
"repository": {
"type": "git",
"url": "https://github.com/PAIR-code/deeplearnjs.git"
},
"devDependencies": {
"deeplearn": "~0.3.11",
"mkdirp": "~0.5.1",
"tsify": "~3.0.3",
"tslint": "~5.8.0",
"typescript": "~2.6.1",
"uglifyjs": "~2.4.11",
"watchify": "~3.9.0"
},
"scripts": {
"prep": "yarn && mkdirp dist",
"build": "browserify --standalone yolo_mobilenet yolo_mobilenet.ts -p [tsify] -o dist/bundle.js",
"lint": "tslint -p . -t verbose",
"publish-npm": "tsc --sourceMap false && yarn build && npm publish"
},
"license": "Apache-2.0"
}
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@@ -0,0 +1,10 @@
{
"extends": "../../tsconfig.json",
"include": [
"index.ts", "yolo_mobilenet.ts"
],
"exclude": [
"node_modules/",
"dist/"
]
}
@@ -0,0 +1,252 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line:max-line-length
import {Array1D, Array3D, Array4D, CheckpointLoader, initializeGPU, Model, NDArray, NDArrayMath, NDArrayMathGPU, Scalar} from 'deeplearn';
import {BoundingBox} from './mobilenet_utils';

const GOOGLE_CLOUD_STORAGE_DIR =
'https://storage.googleapis.com/learnjs-data/checkpoint_zoo/';

export class YoloMobileNetDetection implements Model {
  // Checkpoint variables, keyed by their TensorFlow variable name.
  private variables: {[varName: string]: NDArray};

  // yolo variables
  // Pre-processing constants: pixels in [0, 255] are mapped into [-1, 1]
  // via x / (255 / 2) - 1.
  private PREPROCESS_DIVISOR = Scalar.new(255.0 / 2);
  private ONE = Scalar.ONE;
  // Minimum class score for a detection to survive thresholding.
  private THRESHOLD = 0.3;
  private THRESHOLD_SCALAR = Scalar.new(this.THRESHOLD);
  // Anchor-box priors as (width, height) pairs, in grid-cell units.
  private ANCHORS: number[] = [0.57273, 0.677385, 1.87446, 2.06253, 3.33843,
      5.47434, 7.88282, 3.52778, 9.77052, 9.16828];

  constructor(private math: NDArrayMath) {
    // TODO(nsthorat): This awful hack is because we need to share the global
    // GPGPU between deeplearn loaded from standalone as well as the internal
    // deeplearn that gets compiled as part of this model. Remove this once we
    // decouple NDArray from storage mechanism.
    initializeGPU(
        (this.math as NDArrayMathGPU).getGPGPUContext(),
        (this.math as NDArrayMathGPU).getTextureManager());
  }

  /**
   * Loads the YOLO MobileNet checkpoint variables from Google Cloud
   * Storage. Must complete before predict() is called.
   */
  async load(): Promise<void> {
    const checkpointLoader = new CheckpointLoader(
        GOOGLE_CLOUD_STORAGE_DIR + 'yolo_mobilenet_v1_1.0_416/');
    this.variables = await checkpointLoader.getAllVariables();
  }

  /**
   * Infers through the network; assumes variables have been loaded. Applies
   * the standard [-1, 1] pre-processing before inference.
   *
   * @param input un-preprocessed input image Array.
   * @return Raw network output reshaped to
   *     [gridH, gridW, numBoxes, 5 + numClasses] = [13, 13, 5, 6];
   *     decode it with interpretNetout().
   */
  predict(input: Array3D): Array4D {
    const netout = this.math.scope(() => {
      // Map [0, 255] pixel values into [-1, 1].
      const preprocessedInput = this.math.subtract(
          this.math.arrayDividedByScalar(input, this.PREPROCESS_DIVISOR),
          this.ONE) as Array3D;

      // MobileNet backbone: stem conv, then 13 depthwise-separable blocks.
      // Blocks 2, 4, 6 and 12 downsample with stride 2.
      const x1 = this.convBlock(preprocessedInput, [2, 2]);
      const x2 = this.depthwiseConvBlock(x1, [1, 1], 1);

      const x3 = this.depthwiseConvBlock(x2, [2, 2], 2);
      const x4 = this.depthwiseConvBlock(x3, [1, 1], 3);

      const x5 = this.depthwiseConvBlock(x4, [2, 2], 4);
      const x6 = this.depthwiseConvBlock(x5, [1, 1], 5);

      const x7 = this.depthwiseConvBlock(x6, [2, 2], 6);
      const x8 = this.depthwiseConvBlock(x7, [1, 1], 7);
      const x9 = this.depthwiseConvBlock(x8, [1, 1], 8);
      const x10 = this.depthwiseConvBlock(x9, [1, 1], 9);
      const x11 = this.depthwiseConvBlock(x10, [1, 1], 10);
      const x12 = this.depthwiseConvBlock(x11, [1, 1], 11);

      const x13 = this.depthwiseConvBlock(x12, [2, 2], 12);
      const x14 = this.depthwiseConvBlock(x13, [1, 1], 13);

      // Detection head: a 1x1 conv producing 5 boxes x 6 channels per cell.
      const x15 = this.math.conv2d(
          x14, this.variables['conv_23/kernel'] as Array4D,
          this.variables['conv_23/bias'] as Array1D, [1, 1], 'same');

      return x15.as4D(13, 13, 5, 6);
    });

    return netout;
  }

  // Stem block: convolution (no bias) -> batch-norm -> ReLU6.
  private convBlock(inputs: Array3D, strides: [number, number]) {
    const x1 = this.math.conv2d(
        inputs, this.variables['conv1/kernel'] as Array4D,
        null,  // this convolutional layer does not use bias
        strides, 'same');

    const x2 = this.math.batchNormalization3D(
        x1, this.variables['conv1_bn/moving_mean'] as Array1D,
        this.variables['conv1_bn/moving_variance'] as Array1D, .001,
        this.variables['conv1_bn/gamma'] as Array1D,
        this.variables['conv1_bn/beta'] as Array1D);

    // clip(x, 0, 6) is a simple implementation of ReLU6.
    return this.math.clip(x2, 0, 6);
  }

  // Depthwise-separable block: depthwise conv -> BN -> ReLU6 ->
  // pointwise (1x1) conv -> BN -> ReLU6.
  private depthwiseConvBlock(
      inputs: Array3D, strides: [number, number], blockID: number) {
    const dwPadding = 'conv_dw_' + String(blockID) + '';
    const pwPadding = 'conv_pw_' + String(blockID) + '';

    const x1 = this.math.depthwiseConv2D(
        inputs, this.variables[dwPadding + '/depthwise_kernel'] as Array4D,
        strides, 'same') as Array3D;

    const x2 = this.math.batchNormalization3D(
        x1, this.variables[dwPadding + '_bn/moving_mean'] as Array1D,
        this.variables[dwPadding + '_bn/moving_variance'] as Array1D, .001,
        this.variables[dwPadding + '_bn/gamma'] as Array1D,
        this.variables[dwPadding + '_bn/beta'] as Array1D);

    const x3 = this.math.clip(x2, 0, 6);

    const x4 = this.math.conv2d(
        x3, this.variables[pwPadding + '/kernel'] as Array4D,
        null,  // this convolutional layer does not use bias
        [1, 1], 'same');

    const x5 = this.math.batchNormalization3D(
        x4, this.variables[pwPadding + '_bn/moving_mean'] as Array1D,
        this.variables[pwPadding + '_bn/moving_variance'] as Array1D, .001,
        this.variables[pwPadding + '_bn/gamma'] as Array1D,
        this.variables[pwPadding + '_bn/beta'] as Array1D);

    return this.math.clip(x5, 0, 6);
  }

  /**
   * Decodes the raw network output into thresholded, non-max-suppressed
   * bounding boxes.
   *
   * @param netout Raw output from predict(), shaped
   *     [gridH, gridW, numBoxes, 5 + numClasses].
   * @return Boxes whose best class probability exceeds THRESHOLD.
   */
  async interpretNetout(netout: Array4D): Promise<BoundingBox[]> {
    // interpret the output by the network
    const GRID_H = netout.shape[0];
    const GRID_W = netout.shape[1];
    const BOX = netout.shape[2];
    const CLASS = netout.shape[3] - 5;
    const boxes: BoundingBox[] = [];

    // Adjust confidence predictions: channel 4 holds objectness.
    const confidence = this.math.sigmoid(this.math.slice4D(
        netout, [0, 0, 0, 4], [GRID_H, GRID_W, BOX, 1]));

    // Class scores = softmax(class logits) * objectness, then zero out
    // anything at or below THRESHOLD via a relu/step mask.
    let classes = this.math.softmax(this.math.slice4D(
        netout, [0, 0, 0, 5], [GRID_H, GRID_W, BOX, CLASS]));
    classes = this.math.multiply(classes, confidence) as Array4D;
    const mask = this.math.step(this.math.relu(
        this.math.subtract(classes, this.THRESHOLD_SCALAR)));
    classes = this.math.multiply(classes, mask) as Array4D;

    const objectLikelihood = this.math.sum(classes, 3);
    const objectLikelihoodValues = objectLikelihood.getValues();

    for (let i = 0; i < objectLikelihoodValues.length; i++) {
      if (objectLikelihoodValues[i] > 0) {
        const [row, col, box] = objectLikelihood.indexToLoc(i) as number[];

        const conf = confidence.get(row, col, box, 0);
        const probs = this.math.slice4D(
            classes, [row, col, box, 0],
            [1, 1, 1, CLASS]).getValues() as Float32Array;
        const xywh = this.math.slice4D(
            netout, [row, col, box, 0], [1, 1, 1, 4]).getValues();

        // Decode box geometry: centers are sigmoid offsets inside the grid
        // cell; sizes scale the anchor priors exponentially. All values are
        // normalized to [0, 1] of the image.
        let x = xywh[0];
        let y = xywh[1];
        let w = xywh[2];
        let h = xywh[3];
        x = (col + this.sigmoid(x)) / GRID_W;
        y = (row + this.sigmoid(y)) / GRID_H;
        w = this.ANCHORS[2 * box + 0] * Math.exp(w) / GRID_W;
        h = this.ANCHORS[2 * box + 1] * Math.exp(h) / GRID_H;

        boxes.push(new BoundingBox(x, y, w, h, conf, probs));
      }
    }

    // Suppress non-maximal boxes, per class.
    for (let cls = 0; cls < CLASS; cls++) {
      const allProbs = boxes.map((box) => box.probs[cls]);
      const indices = new Array(allProbs.length);

      for (let i = 0; i < allProbs.length; ++i) {
        indices[i] = i;
      }

      // BUGFIX: sort box indices by DESCENDING class probability. The
      // previous comparator `(a, b) => allProbs[a] > allProbs[b] ? 1 : 0`
      // was invalid (never returned a negative value, violating the sort
      // contract) and ordered low-probability boxes first, which made NMS
      // suppress the highest-probability detections instead of keeping them.
      indices.sort((a, b) => allProbs[b] - allProbs[a]);

      for (let i = 0; i < allProbs.length; i++) {
        const indexI = indices[i];

        if (boxes[indexI].probs[cls] === 0) {
          continue;
        }
        // Suppress every lower-ranked box overlapping this one too much.
        for (let j = i + 1; j < allProbs.length; j++) {
          const indexJ = indices[j];

          if (boxes[indexI].iou(boxes[indexJ]) > 0.4) {
            boxes[indexJ].probs[cls] = 0;
          }
        }
      }
    }

    // Obtain the most likely boxes.
    const likelyBoxes = [];

    for (const box of boxes) {
      if (box.getMaxProb() > this.THRESHOLD) {
        likelyBoxes.push(box);
      }
    }

    return likelyBoxes;
  }

  // Scalar logistic sigmoid, used to decode box center offsets on the CPU.
  private sigmoid(x: number): number {
    return 1. / (1. + Math.exp(-x));
  }

  // Frees every buffer owned by this model's checkpoint variables.
  dispose() {
    for (const varName in this.variables) {
      this.variables[varName].dispose();
    }
  }
}