From 0c5f30842552925d6b216667b3d56836fe321633 Mon Sep 17 00:00:00 2001 From: kangyizhang Date: Tue, 17 Dec 2019 14:30:52 -0800 Subject: [PATCH 1/4] multi output --- tfjs-node/src/saved_model.ts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tfjs-node/src/saved_model.ts b/tfjs-node/src/saved_model.ts index 8df10f388a3..5b0bff7e157 100644 --- a/tfjs-node/src/saved_model.ts +++ b/tfjs-node/src/saved_model.ts @@ -284,9 +284,10 @@ export class TFSavedModel implements InferenceModel { let inputTensors: Tensor[] = []; if (inputs instanceof Tensor) { inputTensors.push(inputs); - return this.backend.runSavedModel( + const result = this.backend.runSavedModel( this.sessionId, inputTensors, Object.values(this.inputNodeNames), - Object.values(this.outputNodeNames))[0]; + Object.values(this.outputNodeNames)); + return result.length > 1 ? result : result[0]; } else if (Array.isArray(inputs)) { inputTensors = inputs; return this.backend.runSavedModel( From 5ffd8ace82b313656f89bb0975c2a6678fbef459 Mon Sep 17 00:00:00 2001 From: kangyizhang Date: Tue, 17 Dec 2019 14:39:42 -0800 Subject: [PATCH 2/4] save --- tfjs-node/src/saved_model_test.ts | 17 +++++++++++++++++ .../saved_model.pb | Bin 0 -> 9702 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 172 bytes .../variables/variables.index | Bin 0 -> 236 bytes 4 files changed, 17 insertions(+) create mode 100644 tfjs-node/test_objects/saved_model/model_single_input_multi_output/saved_model.pb create mode 100644 tfjs-node/test_objects/saved_model/model_single_input_multi_output/variables/variables.data-00000-of-00001 create mode 100644 tfjs-node/test_objects/saved_model/model_single_input_multi_output/variables/variables.index diff --git a/tfjs-node/src/saved_model_test.ts b/tfjs-node/src/saved_model_test.ts index 3a9bed7a7b7..89d5c1f1e1e 100644 --- a/tfjs-node/src/saved_model_test.ts +++ b/tfjs-node/src/saved_model_test.ts @@ -429,6 +429,23 @@ describe('SavedModel', () => { model2.dispose(); }); + it('execute model with single inputs and outputs', async () => { + const model = await tf.node.loadSavedModel( + './test_objects/saved_model/model_single_input_multi_output', ['serve'], + 'serving_default'); + const input = tf.tensor1d([1, 2, 3], 'int32'); + const output = model.predict(input) as tf.Tensor[]; + const output1 = output[0]; + const output2 = output[1]; + expect(output1.shape).toEqual(input.shape); + expect(output1.dtype).toBe(input.dtype); + expect(output2.shape).toEqual(input.shape); + expect(output2.dtype).toBe(input.dtype); + test_util.expectArraysClose(await output1.data(), [2, 4, 6]); + test_util.expectArraysClose(await output2.data(), [1, 2, 3]); + model.dispose(); + }); + it('execute model with multiple inputs and outputs', async () => { const model = await tf.node.loadSavedModel( './test_objects/saved_model/model_multi_output', ['serve'], diff --git a/tfjs-node/test_objects/saved_model/model_single_input_multi_output/saved_model.pb b/tfjs-node/test_objects/saved_model/model_single_input_multi_output/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..7f19c27d6a52b840b6bba9ef1beb5b8761407555 GIT binary patch literal 9702 zcmc&)&2JmW73XkCk$U=~MzTzeWLc)|*j~pbx$+0i25v>cj_o*>D#^N4ixs()Xj3G! zyOiVtz2wr=2#OXB8X%`am!v)#Es9=pDNyv>OV9ladhQ`W-|lk1cPaVnhQ zoS_MQ9)=+Cq`lgK6&P7-N>zEU-K@zbff$?=&s%=S1 zL)vM|s~tKDX+>7sU8O2#r`dRq7>hXXre6>O-xD~!p-Gxt?>5&YMQdn{c1y0UNX;f4 zfrO;&ssgv!XyH~pMuo&Ca-5q2eV9VCp=y_P&aPZ0NTo698yyBxr-^#ERmBBE)po1i z*rh|NrZif+EKXxN2BYSv*iqVA+cP%@WAc$)?P_gB)Q@pb>I{4aBgp;(sa4bYXCD4@ zr(qa3)}%&DR@^VyB6JHLkIW@iSvipTq;S1(y&%jS7@yaa zs&GxXb#3?7jp|ZusZ=8LUq6N#q%Gx8H=VqX0HF3de#A1Ec!BUUFfN(~CK_a@Gy(B) zOkZS@g6=(XZ6|MFlUN1z#f3LnOO9)G3S-zD0QuLGr|qV zhy)AHNdvLsK5qbKI2bkkkD6ZiU($1AYWj0-EMJok8dbTxvi7uG;Qxg4raQkPm2XKp zMPrpe5Lkp{nox28sdSuj+F={U(;Nl0vD=cgE;=BTX$i)+me(IHS03LJx0WA2z2~Ty z%^&2iFK=$HKde06ytk3}u7kVuqrkf^KG^~u*zbMAIxO|2q9hONFsU0dI-yvT>r%I= ziAU5Ol-@#tU&f=uNmQI+K{A+&U4Fb5R4{T1agO78U$kH}3n}0}UVEB}=?|#63=!gw zVkN?71MY*nlw;^XV`ls^H=X9Mz@)gVNVSIC(!_Sh3^mh~`eNd>8#i$&hO-OnIan0M zMyoC>a;qwfw(iA4Me1~9MO?ZGt1zxg2l!lXG-V^wc;ws1U~T7iDl-fE2|mb7SK%j238rC4NTdiCX=nq>_I824pv(5``WcvW%U0AB%#<}zm!){? z4(iTKGfChCl6n(^fZz5JpfI=q?Fj0z)HkurDBX{nQcXcp;8N3r8)oSae@Pc!Zha9J z!i0xW33I3{mP$VP(LW0Cmm&B9!)1RMZ_MM zi5oTariS(+%OQLK>l@B;m`0CDq`Cg`ZK4P{VkZ#w{-c3#hhNl%KDRy#r`^I6MMD^( zC@j58{9>O$`e9Fq)P7=%5h0*#B875+Ww#@i@F^ghnAIZ_GE)JpBQc~5erWL>4$Yl$ zjD=2G$J^}W&bh6JPFV}YqP+wG)fh%&1mYG>!nlNi9#OI&ym?M4pG(ImMj@CMyO<^@fEPIZBSo34{PKc%6j9NZnihau(+HyG>C+YX>6(K>cKAp}2t$*Rih% zgbWAZUca+*ALgI6cgoE++Us6h)$)qmX*Q}-xscZ+^*Mfc@K99La$#1EJpvAirQE08 zCPS}eLCVVprb2vmqBU-WIzS4yEdLzsd}=$m`N^BEMwmDf8@gc>*B12y*yjlr-p_n5 zhblQl@q-)f=QHq6(l^J_2^i5q3sz)uy>XOH(Mxb%&%XzPXqm8K&pxBh5_$zL`WF^M zmvWSv_z3qbidd+rS~lg&6e%N9&`)J$SKbm<_T=jGPP@_4)MGNVf#hrn$0Th&h(-2m z0#h4CUHJ$X<&AiN*(d_sN2&EA3-d2kiU#7pDr)k6M=aftOU2r4xwdm>=Y~|@DV1&) z>hi5y!cw7DEZn@A*K?rmTH_NM&&JtffL3>3 zkAq50?eQK}1@P2jcDRzl8YWgIaW@MbtHY@ceJ8ngzXxLT#?Egj)ZjUuGKFKmBQW7R zi24VZOS`sw*n;mhlRjsD%8`aWIG|A3x4K0cc9u8qbBqfEORD}|ergHxLFGPLPJ+hM z@k-xfuW~{$=hj7vXMu7?&vM2w2o+BE2u1aDWD##fL={8197SxO&G^$6Bjc4(n~_+> zqp}qH4}qQ&(%vUyK3$Zp$_In>e25@zg&?Rg#q6ivSusfbVS~jUzv89?f{jCr&%lsT zMyu2)6A8!mP{>|q5R3209f(`i9Q&HUWZ}+0y08 z2R&ONm3eXH!M&AF)>a=r*)*<5Rx6M0t!#=PuP?7Xa7&_}(GxWKS(_BU;Y0+eXtO8$ zDT~z63fQFjjQ(!~?1W>R{d_>;Zo^-&X>^OrG3=Z@@8`i;W0detqGYLU^QFNo8Lz2( zmTcC{?5_kq2qx8h)lRUz#E;PQ(S zF2Ti!8f}ca1#ebs_!s*P_!of=yw1Qs$Vqo4K^%MrW+hd>9I;?9R&v&3cAB1pGlA)% zy}-B`ao>e3Wle;eSvL0EvO77)g!+fX;xQRsH?MJY+&f}FBlLYmV8;GFNLc=VVigUT zL#sdLMa1oaAW2_@^c?|5eAIdLV)3&MW#Q5KytBYhx*bRSE7rJZU<%Iq}1ZO-B zjU-d{$z_u?6-(&wD{g(;UU^Db7Xs!*10yWjtXVLyxKlLY!^7c|h%!5bjiN;?7tIg* z1%$Xn2qXO+kcD0!g>v9#l4D;I$l3}R4EjXpHNw9mEMgkWyQ%S;VuTmr|53``S%G`KKq=w>{p0N=#$(2FpW9a zKZNl)xvjVT{U`TaG_rT1-)K13u}dx@94i(aSq(~zOQvkpk*{dGVI)Bw7_j8Jie0j) z;LrGBz$k!u6>D6_GOq=>%}Ka0NTd>G!sECO9fFdlA!Bb19de2EVP+I>STZVBKu%w; zT%dCiH|}^zU0{(-JE0WMorNiVx0<-$uE|YtAAhcZEl7%`y=k;O=Jp9+tCk`NUchDy zoyC{^UEFdZ!?75>3hCZY2MY9TWZy8bxY%!TQQw&{)1&Ln_F=wvdLAzdYmKVL--T16 z(W$feX|x)8OO;-v=cA^;WR&R%@rMLFKC!(%noiKP)nv)dBLv))@i!igs(D@C-gzch zHHtU%8LSn`8U3an8v%CL&KN@s=o$MErqW637j*6vj5)2aeW}yi_3uQ$N-(PFL-F3| zEApEA4Tj76-DY!H+10XE4oSX!Be%Nj`Ul_}|Z@3u65M~Z}7E&0ITCFy=>hM=RD&<~ru?*oy zkGWw;y^JSN032MYTpFP9hiyMXx1#y|5hZapd(6$i#4(5d{xXiC?Z`MlKQ%XZ%w>C{ z1!L4)QPiL5%RyR?b^U~eV{Q((V=e~MueeE=(9dw4m0Nzr-34qW*>*qqGLF)-zv>J- L(?#l;Hi-EjpoA?) literal 0 HcmV?d00001 diff --git a/tfjs-node/test_objects/saved_model/model_single_input_multi_output/variables/variables.data-00000-of-00001 b/tfjs-node/test_objects/saved_model/model_single_input_multi_output/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..8020ed848344efdf229e7ae14782f4e5ea9c8074 GIT binary patch literal 172 zcmZQzaByH^U|^WXD3hnA&ZWS`%E2haSjNT9!6d|7Ql8Jn$H6SbRh*fgmsnC-lv>PX zC1kn`>FDI+8XxB96Y46&5tdk#nV6K5DkW20QjIPm;e9( literal 0 HcmV?d00001 diff --git a/tfjs-node/test_objects/saved_model/model_single_input_multi_output/variables/variables.index b/tfjs-node/test_objects/saved_model/model_single_input_multi_output/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..a2750a61168867a3dd111384bfa79654f3c9c839 GIT binary patch literal 236 zcmZQzVB=tvV&Y(Akl~AW_HcFf4)FK%3vqPvagFzP@^W;s-Hf literal 0 HcmV?d00001 From 41a7ee2efaa74247d8246d72ad41c7f2b75f5850 Mon Sep 17 00:00:00 2001 From: kangyizhang Date: Tue, 17 Dec 2019 14:40:26 -0800 Subject: [PATCH 3/4] add tests --- .../saved_model.pb | Bin 0 -> 9702 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 172 bytes .../variables/variables.index | Bin 0 -> 236 bytes 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tfjs-node-gpu/test_objects/saved_model/model_single_input_multi_output/saved_model.pb create mode 100644 tfjs-node-gpu/test_objects/saved_model/model_single_input_multi_output/variables/variables.data-00000-of-00001 create mode 100644 tfjs-node-gpu/test_objects/saved_model/model_single_input_multi_output/variables/variables.index diff --git a/tfjs-node-gpu/test_objects/saved_model/model_single_input_multi_output/saved_model.pb b/tfjs-node-gpu/test_objects/saved_model/model_single_input_multi_output/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..7f19c27d6a52b840b6bba9ef1beb5b8761407555 GIT binary patch literal 9702 zcmc&)&2JmW73XkCk$U=~MzTzeWLc)|*j~pbx$+0i25v>cj_o*>D#^N4ixs()Xj3G! zyOiVtz2wr=2#OXB8X%`am!v)#Es9=pDNyv>OV9ladhQ`W-|lk1cPaVnhQ zoS_MQ9)=+Cq`lgK6&P7-N>zEU-K@zbff$?=&s%=S1 zL)vM|s~tKDX+>7sU8O2#r`dRq7>hXXre6>O-xD~!p-Gxt?>5&YMQdn{c1y0UNX;f4 zfrO;&ssgv!XyH~pMuo&Ca-5q2eV9VCp=y_P&aPZ0NTo698yyBxr-^#ERmBBE)po1i z*rh|NrZif+EKXxN2BYSv*iqVA+cP%@WAc$)?P_gB)Q@pb>I{4aBgp;(sa4bYXCD4@ zr(qa3)}%&DR@^VyB6JHLkIW@iSvipTq;S1(y&%jS7@yaa zs&GxXb#3?7jp|ZusZ=8LUq6N#q%Gx8H=VqX0HF3de#A1Ec!BUUFfN(~CK_a@Gy(B) zOkZS@g6=(XZ6|MFlUN1z#f3LnOO9)G3S-zD0QuLGr|qV zhy)AHNdvLsK5qbKI2bkkkD6ZiU($1AYWj0-EMJok8dbTxvi7uG;Qxg4raQkPm2XKp zMPrpe5Lkp{nox28sdSuj+F={U(;Nl0vD=cgE;=BTX$i)+me(IHS03LJx0WA2z2~Ty z%^&2iFK=$HKde06ytk3}u7kVuqrkf^KG^~u*zbMAIxO|2q9hONFsU0dI-yvT>r%I= ziAU5Ol-@#tU&f=uNmQI+K{A+&U4Fb5R4{T1agO78U$kH}3n}0}UVEB}=?|#63=!gw zVkN?71MY*nlw;^XV`ls^H=X9Mz@)gVNVSIC(!_Sh3^mh~`eNd>8#i$&hO-OnIan0M zMyoC>a;qwfw(iA4Me1~9MO?ZGt1zxg2l!lXG-V^wc;ws1U~T7iDl-fE2|mb7SK%j238rC4NTdiCX=nq>_I824pv(5``WcvW%U0AB%#<}zm!){? z4(iTKGfChCl6n(^fZz5JpfI=q?Fj0z)HkurDBX{nQcXcp;8N3r8)oSae@Pc!Zha9J z!i0xW33I3{mP$VP(LW0Cmm&B9!)1RMZ_MM zi5oTariS(+%OQLK>l@B;m`0CDq`Cg`ZK4P{VkZ#w{-c3#hhNl%KDRy#r`^I6MMD^( zC@j58{9>O$`e9Fq)P7=%5h0*#B875+Ww#@i@F^ghnAIZ_GE)JpBQc~5erWL>4$Yl$ zjD=2G$J^}W&bh6JPFV}YqP+wG)fh%&1mYG>!nlNi9#OI&ym?M4pG(ImMj@CMyO<^@fEPIZBSo34{PKc%6j9NZnihau(+HyG>C+YX>6(K>cKAp}2t$*Rih% zgbWAZUca+*ALgI6cgoE++Us6h)$)qmX*Q}-xscZ+^*Mfc@K99La$#1EJpvAirQE08 zCPS}eLCVVprb2vmqBU-WIzS4yEdLzsd}=$m`N^BEMwmDf8@gc>*B12y*yjlr-p_n5 zhblQl@q-)f=QHq6(l^J_2^i5q3sz)uy>XOH(Mxb%&%XzPXqm8K&pxBh5_$zL`WF^M zmvWSv_z3qbidd+rS~lg&6e%N9&`)J$SKbm<_T=jGPP@_4)MGNVf#hrn$0Th&h(-2m z0#h4CUHJ$X<&AiN*(d_sN2&EA3-d2kiU#7pDr)k6M=aftOU2r4xwdm>=Y~|@DV1&) z>hi5y!cw7DEZn@A*K?rmTH_NM&&JtffL3>3 zkAq50?eQK}1@P2jcDRzl8YWgIaW@MbtHY@ceJ8ngzXxLT#?Egj)ZjUuGKFKmBQW7R zi24VZOS`sw*n;mhlRjsD%8`aWIG|A3x4K0cc9u8qbBqfEORD}|ergHxLFGPLPJ+hM z@k-xfuW~{$=hj7vXMu7?&vM2w2o+BE2u1aDWD##fL={8197SxO&G^$6Bjc4(n~_+> zqp}qH4}qQ&(%vUyK3$Zp$_In>e25@zg&?Rg#q6ivSusfbVS~jUzv89?f{jCr&%lsT zMyu2)6A8!mP{>|q5R3209f(`i9Q&HUWZ}+0y08 z2R&ONm3eXH!M&AF)>a=r*)*<5Rx6M0t!#=PuP?7Xa7&_}(GxWKS(_BU;Y0+eXtO8$ zDT~z63fQFjjQ(!~?1W>R{d_>;Zo^-&X>^OrG3=Z@@8`i;W0detqGYLU^QFNo8Lz2( zmTcC{?5_kq2qx8h)lRUz#E;PQ(S zF2Ti!8f}ca1#ebs_!s*P_!of=yw1Qs$Vqo4K^%MrW+hd>9I;?9R&v&3cAB1pGlA)% zy}-B`ao>e3Wle;eSvL0EvO77)g!+fX;xQRsH?MJY+&f}FBlLYmV8;GFNLc=VVigUT zL#sdLMa1oaAW2_@^c?|5eAIdLV)3&MW#Q5KytBYhx*bRSE7rJZU<%Iq}1ZO-B zjU-d{$z_u?6-(&wD{g(;UU^Db7Xs!*10yWjtXVLyxKlLY!^7c|h%!5bjiN;?7tIg* z1%$Xn2qXO+kcD0!g>v9#l4D;I$l3}R4EjXpHNw9mEMgkWyQ%S;VuTmr|53``S%G`KKq=w>{p0N=#$(2FpW9a zKZNl)xvjVT{U`TaG_rT1-)K13u}dx@94i(aSq(~zOQvkpk*{dGVI)Bw7_j8Jie0j) z;LrGBz$k!u6>D6_GOq=>%}Ka0NTd>G!sECO9fFdlA!Bb19de2EVP+I>STZVBKu%w; zT%dCiH|}^zU0{(-JE0WMorNiVx0<-$uE|YtAAhcZEl7%`y=k;O=Jp9+tCk`NUchDy zoyC{^UEFdZ!?75>3hCZY2MY9TWZy8bxY%!TQQw&{)1&Ln_F=wvdLAzdYmKVL--T16 z(W$feX|x)8OO;-v=cA^;WR&R%@rMLFKC!(%noiKP)nv)dBLv))@i!igs(D@C-gzch zHHtU%8LSn`8U3an8v%CL&KN@s=o$MErqW637j*6vj5)2aeW}yi_3uQ$N-(PFL-F3| zEApEA4Tj76-DY!H+10XE4oSX!Be%Nj`Ul_}|Z@3u65M~Z}7E&0ITCFy=>hM=RD&<~ru?*oy zkGWw;y^JSN032MYTpFP9hiyMXx1#y|5hZapd(6$i#4(5d{xXiC?Z`MlKQ%XZ%w>C{ z1!L4)QPiL5%RyR?b^U~eV{Q((V=e~MueeE=(9dw4m0Nzr-34qW*>*qqGLF)-zv>J- L(?#l;Hi-EjpoA?) literal 0 HcmV?d00001 diff --git a/tfjs-node-gpu/test_objects/saved_model/model_single_input_multi_output/variables/variables.data-00000-of-00001 b/tfjs-node-gpu/test_objects/saved_model/model_single_input_multi_output/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..8020ed848344efdf229e7ae14782f4e5ea9c8074 GIT binary patch literal 172 zcmZQzaByH^U|^WXD3hnA&ZWS`%E2haSjNT9!6d|7Ql8Jn$H6SbRh*fgmsnC-lv>PX zC1kn`>FDI+8XxB96Y46&5tdk#nV6K5DkW20QjIPm;e9( literal 0 HcmV?d00001 diff --git a/tfjs-node-gpu/test_objects/saved_model/model_single_input_multi_output/variables/variables.index b/tfjs-node-gpu/test_objects/saved_model/model_single_input_multi_output/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..a2750a61168867a3dd111384bfa79654f3c9c839 GIT binary patch literal 236 zcmZQzVB=tvV&Y(Akl~AW_HcFf4)FK%3vqPvagFzP@^W;s-Hf literal 0 HcmV?d00001 From aa29fd0c420c7b5466dacd432bc9148968f74ea6 Mon Sep 17 00:00:00 2001 From: kangyizhang Date: Wed, 18 Dec 2019 10:16:07 -0800 Subject: [PATCH 4/4] add comment for test model --- tfjs-node/src/saved_model_test.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tfjs-node/src/saved_model_test.ts b/tfjs-node/src/saved_model_test.ts index 89d5c1f1e1e..99fcaceddc5 100644 --- a/tfjs-node/src/saved_model_test.ts +++ b/tfjs-node/src/saved_model_test.ts @@ -429,7 +429,8 @@ describe('SavedModel', () => { model2.dispose(); }); - it('execute model with single inputs and outputs', async () => { + it('execute model with single input and multiple outputs', async () => { + // This test model behaves as: f(x)=[2*x, x] const model = await tf.node.loadSavedModel( './test_objects/saved_model/model_single_input_multi_output', ['serve'], 'serving_default'); @@ -446,7 +447,8 @@ describe('SavedModel', () => { model.dispose(); }); - it('execute model with multiple inputs and outputs', async () => { + it('execute model with multiple inputs and multiple outputs', async () => { + // This test model behaves as: f(x, y)=[2*x, y] const model = await tf.node.loadSavedModel( './test_objects/saved_model/model_multi_output', ['serve'], 'serving_default');