Skip to content

Commit 0eb89a5

Browse files
committed
fix: truncated-normal generator
1 parent 7948ccd commit 0eb89a5

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

lib/backend/tensorflow-backend.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ class TensorflowBackend {
188188

189189
addNoiseOne(img, noiseImg) {
190190
return this._tf.tidy(() => {
191-
const black = this._tf.zerosLike(img);
192-
const white = this._tf.onesLike(img).mul(255);
193-
const res = img.add(noiseImg).maximum(black).minimum(white).toInt();
191+
const black = this._tf.zerosLike(img, 'int32');
192+
const white = this._tf.onesLike(img, 'int32').mul(255);
193+
const res = img.add(noiseImg.round().toInt()).maximum(black).minimum(white);
194194
return res;
195195
});
196196
}

lib/generators/truncated-normal-noise.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ new ia.TruncatedNormalNoiseGenerator({
8080
class TruncatedNormalNoiseGenerator extends GaussianNoiseGenerator {
8181
buildHasard(o) {
8282
if (typeof (this.backend.truncatedNormal) === 'function') {
83+
// Tensorflow specific code
84+
// TO DO move this into tfjs backend folder
8385
const imagesProps = h.array({
8486
size: o.nImages,
8587
value: h.object({

test/augmenters/additive-truncated-normal-noise.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ const nImages = 5;
77

88
test('additiveTruncatedNormalNoise not perChannel', macroAugmenter, AdditiveTruncatedNormalNoise, {
99
inputFilenames: new Array(nImages).fill('lenna.png'),
10-
// BackendLibs: [require('opencv4nodejs')],
10+
backendLibs: [require('@tensorflow/tfjs-node')],
1111
expectImg(t, mats1, mats2, backend) {
1212
const metadatas = backend.getMetadata(mats1);
1313
const metadata = metadatas[0];
@@ -23,8 +23,8 @@ test('additiveTruncatedNormalNoise not perChannel', macroAugmenter, AdditiveTrun
2323
const m2 = backend.imageToArray(mats2);
2424

2525
backend.forEachPixel(diff, ([b, g, r], batchIndex, rowIndex, colIndex) => {
26-
// Console.log(m2[batchIndex] && m2[batchIndex][rowIndex])
2726
if (m2[batchIndex][rowIndex][colIndex].slice(0, 3).indexOf(255) === -1 && m2[batchIndex][rowIndex][colIndex].slice(0, 3).indexOf(0) === -1 && (r !== g || g !== b)) {
27+
console.log(m2[batchIndex] && m2[batchIndex][rowIndex][colIndex], [r, g, b]);
2828
count++;
2929
}
3030
});

0 commit comments

Comments
 (0)