From 7c62b52cda0cfbb9af04b747b605e677f0e5c0f9 Mon Sep 17 00:00:00 2001 From: Brad Heintz Date: Tue, 15 Oct 2019 10:16:03 -0700 Subject: [PATCH 1/5] added correctness tests for classification models --- .../AlexnetTester.test_alexnet_expect.pkl | Bin 0 -> 543 bytes ...DensenetTester.test_densenet121_expect.pkl | Bin 0 -> 543 bytes ...DensenetTester.test_densenet161_expect.pkl | Bin 0 -> 543 bytes ...DensenetTester.test_densenet169_expect.pkl | Bin 0 -> 543 bytes ...DensenetTester.test_densenet201_expect.pkl | Bin 0 -> 543 bytes .../GooglenetTester.test_googlenet_expect.pkl | Bin 0 -> 543 bytes ...ptionV3Tester.test_inception_v3_expect.pkl | Bin 0 -> 543 bytes .../MNASNetTester.test_mnasnet0_5_expect.pkl | Bin 0 -> 543 bytes .../MNASNetTester.test_mnasnet0_75_expect.pkl | Bin 0 -> 543 bytes .../MNASNetTester.test_mnasnet1_0_expect.pkl | Bin 0 -> 543 bytes .../MNASNetTester.test_mnasnet1_3_expect.pkl | Bin 0 -> 543 bytes ...bilenetTester.test_mobilenet_v2_expect.pkl | Bin 0 -> 543 bytes ...st_mobilenetv2_residual_setting_expect.pkl | Bin 0 -> 543 bytes .../ResnetTester.test_resnet101_expect.pkl | Bin 0 -> 543 bytes .../ResnetTester.test_resnet152_expect.pkl | Bin 0 -> 543 bytes .../ResnetTester.test_resnet18_expect.pkl | Bin 0 -> 543 bytes .../ResnetTester.test_resnet34_expect.pkl | Bin 0 -> 543 bytes .../ResnetTester.test_resnet50_expect.pkl | Bin 0 -> 543 bytes ...netTester.test_resnext101_32x8d_expect.pkl | Bin 0 -> 543 bytes ...snetTester.test_resnext50_32x4d_expect.pkl | Bin 0 -> 543 bytes ...netTester.test_wide_resnet101_2_expect.pkl | Bin 0 -> 543 bytes ...snetTester.test_wide_resnet50_2_expect.pkl | Bin 0 -> 543 bytes ...tTester.test_shufflenet_v2_x0_5_expect.pkl | Bin 0 -> 543 bytes ...tTester.test_shufflenet_v2_x1_0_expect.pkl | Bin 0 -> 543 bytes ...tTester.test_shufflenet_v2_x1_5_expect.pkl | Bin 0 -> 543 bytes ...tTester.test_shufflenet_v2_x2_0_expect.pkl | Bin 0 -> 543 bytes ...ezenetTester.test_squeezenet1_0_expect.pkl | Bin 0 -> 543 bytes ...ezenetTester.test_squeezenet1_1_expect.pkl | Bin 0 -> 543 bytes .../VGGNetTester.test_vgg11_bn_expect.pkl | Bin 0 -> 543 bytes .../expect/VGGNetTester.test_vgg11_expect.pkl | Bin 0 -> 543 bytes .../VGGNetTester.test_vgg13_bn_expect.pkl | Bin 0 -> 543 bytes .../expect/VGGNetTester.test_vgg13_expect.pkl | Bin 0 -> 543 bytes .../VGGNetTester.test_vgg16_bn_expect.pkl | Bin 0 -> 543 bytes .../expect/VGGNetTester.test_vgg16_expect.pkl | Bin 0 -> 543 bytes .../VGGNetTester.test_vgg19_bn_expect.pkl | Bin 0 -> 543 bytes .../expect/VGGNetTester.test_vgg19_expect.pkl | Bin 0 -> 543 bytes test/test_models.py | 417 +++++++++++++----- 37 files changed, 314 insertions(+), 103 deletions(-) create mode 100644 test/expect/AlexnetTester.test_alexnet_expect.pkl create mode 100644 test/expect/DensenetTester.test_densenet121_expect.pkl create mode 100644 test/expect/DensenetTester.test_densenet161_expect.pkl create mode 100644 test/expect/DensenetTester.test_densenet169_expect.pkl create mode 100644 test/expect/DensenetTester.test_densenet201_expect.pkl create mode 100644 test/expect/GooglenetTester.test_googlenet_expect.pkl create mode 100644 test/expect/InceptionV3Tester.test_inception_v3_expect.pkl create mode 100644 test/expect/MNASNetTester.test_mnasnet0_5_expect.pkl create mode 100644 test/expect/MNASNetTester.test_mnasnet0_75_expect.pkl create mode 100644 test/expect/MNASNetTester.test_mnasnet1_0_expect.pkl create mode 100644 test/expect/MNASNetTester.test_mnasnet1_3_expect.pkl create mode 100644 test/expect/MobilenetTester.test_mobilenet_v2_expect.pkl create mode 100644 test/expect/MobilenetTester.test_mobilenetv2_residual_setting_expect.pkl create mode 100644 test/expect/ResnetTester.test_resnet101_expect.pkl create mode 100644 test/expect/ResnetTester.test_resnet152_expect.pkl create mode 100644 test/expect/ResnetTester.test_resnet18_expect.pkl create mode 100644 test/expect/ResnetTester.test_resnet34_expect.pkl create mode 100644 test/expect/ResnetTester.test_resnet50_expect.pkl create mode 100644 test/expect/ResnetTester.test_resnext101_32x8d_expect.pkl create mode 100644 test/expect/ResnetTester.test_resnext50_32x4d_expect.pkl create mode 100644 test/expect/ResnetTester.test_wide_resnet101_2_expect.pkl create mode 100644 test/expect/ResnetTester.test_wide_resnet50_2_expect.pkl create mode 100644 test/expect/ShufflenetTester.test_shufflenet_v2_x0_5_expect.pkl create mode 100644 test/expect/ShufflenetTester.test_shufflenet_v2_x1_0_expect.pkl create mode 100644 test/expect/ShufflenetTester.test_shufflenet_v2_x1_5_expect.pkl create mode 100644 test/expect/ShufflenetTester.test_shufflenet_v2_x2_0_expect.pkl create mode 100644 test/expect/SqueezenetTester.test_squeezenet1_0_expect.pkl create mode 100644 test/expect/SqueezenetTester.test_squeezenet1_1_expect.pkl create mode 100644 test/expect/VGGNetTester.test_vgg11_bn_expect.pkl create mode 100644 test/expect/VGGNetTester.test_vgg11_expect.pkl create mode 100644 test/expect/VGGNetTester.test_vgg13_bn_expect.pkl create mode 100644 test/expect/VGGNetTester.test_vgg13_expect.pkl create mode 100644 test/expect/VGGNetTester.test_vgg16_bn_expect.pkl create mode 100644 test/expect/VGGNetTester.test_vgg16_expect.pkl create mode 100644 test/expect/VGGNetTester.test_vgg19_bn_expect.pkl create mode 100644 test/expect/VGGNetTester.test_vgg19_expect.pkl diff --git a/test/expect/AlexnetTester.test_alexnet_expect.pkl b/test/expect/AlexnetTester.test_alexnet_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..217eed006da3b820e2c3bd39f6439dee8f5ed9b8 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3Y3i2~2wqO^LQywRVB^1l8HQ zCJH>*{b|`E>t}MRd*&_fw~;t*Wy591Ws~;4W={nFyWK83zV8n3Fx<`a*3qWH)M1ZA zqtk8+?Qp9W8|&TIuAj4J`CMd^BVMzIC5dB?%)$9Kb^LWUl7}_+d|qF@ThBps&mTUO zT@I1kEG;6kZC1R@vfg0bxaXVkt6jTVckbrdD7^b0d;adnrB|)F<_hmQ;k&^qzKnCv qq=?2n>;BBxGr`zzSHSFBmKmD__WXGD)=Hzx+girKeD|GsIeP#&_Q3K0 literal 0 HcmV?d00001 diff --git a/test/expect/DensenetTester.test_densenet121_expect.pkl b/test/expect/DensenetTester.test_densenet121_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..bb88c7faffad86407a474fb0f2ff47486270ef07 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ%ytJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3Y3Klk@kpM`!LA?h~wb~EKR`^yKEySD@9u-Q_c`4Ww7a%NW?#*wKYOQWS?-hH?``k&DcEj) z?0Wl*B%%F2vzhF-+`6%MYTojFoMyfDeX05TmxZhDm&@(g_i)4YeKP$K`@e1Wvj6dS z?SAL>mHSs57P3908ERL!?Vf$8!&JL|oB8{ER;=2)F?{m=3px2VT1_4HPei=!Co%IK rIQ`ygzvxom{Yo!9_uKq?w%0&wy3MJlTlVj)zGYXHaCP4SwT;#Q+TqBB literal 0 HcmV?d00001 diff --git a/test/expect/DensenetTester.test_densenet161_expect.pkl b/test/expect/DensenetTester.test_densenet161_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e8f1b2e2af5fdc406a15f97a10420fa6c2f8cf95 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cnbIbC+aQ}XtG5jlZ5gkfQ9i%j-J93#LcPTIpSsMxKRQX&zJD>#ewS#C zeT$fH+I~}gVaHXs+|K{k{eADv57{;pf3uyx@339QW!=3jNeugn_cPiToa)f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3clApHA&_QtYz(>DRq?L3qPH z?HkMO7<|pA-(>4Qt$g3U zmsj?76t~+gO5eM8AAiukcN66H9aopP|HRO+m!Xr-`rrS4+Z#vo_Z~K=-#hzcpY664 zZu@wS^4i(e{Ix5}j@++N*}M1C#+ZF^ho{-?f9GQVp+k?M7|FPD@ rF6Ha$eK~$Q`^_#U+U0$=-{11E-_GOxcH0TdV{P|LI<)VRWQ#ojan;F8 literal 0 HcmV?d00001 diff --git a/test/expect/DensenetTester.test_densenet201_expect.pkl b/test/expect/DensenetTester.test_densenet201_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e4b9d65585ffeec2475062d111887eb7bb7635d1 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cm)LcYDXP2%l-OyS*7NMycjfdGn+LB7?40fi*{5ARv`_y+#Qvh2Np{?MUHgpuC-2d+*MbDD8zhEhCU$WbJpX0u^eW&EV+G%DV zx7|>bX!rKVd%Hs_?fVvKciSB(SZg=!f$6>pI#2dpcT(B6*hq5Um4!|FDlbad2W5QN rm+i7|U$ZigUCv&c{eC6^_RgzA_EouS?CVJEv^jO@!QOqPoAvf*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJVPIis zYHVz2RLBgrJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cs!BI_ASyKF8n81G5&oxexL z<)Fn`}ZNIQLAbTx0X*?1w!8Zf~t+{`cf*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ%ytJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3Y302U~ns$}aQ$V5jbHyEVoy zX2lX;)}>MBW=N#_RULTZ+i-2ixsZ^|vyvYe{9BkxeNSbt@iUOF@SC^K`P}`V^Ug6m zIqG*aR{Nao?pVJk6Skd8+;;EWv}GLr4BKS<%7hO2wQuM27fIIk@tK=;mTQCExsADp z&jxVbI<2uz<=hf*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ%ytJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cm9L3W!N&Q=&tsM&0+{qoVK zBMX9zGiKCmVrG44#K7EcyoR@T(~HcwjWZ73-h6H2a>E@r&u(_Jnq#d0N6xq}pl9<9 z!8OJucHNt`9qXJ9I s%}FblZk~2x`X*7+Uz=XA7#mxhKfbx3Zp!AbE4wyN?w@D;N$;>R0Pw!WssI20 literal 0 HcmV?d00001 diff --git a/test/expect/MNASNetTester.test_mnasnet0_75_expect.pkl b/test/expect/MNASNetTester.test_mnasnet0_75_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..530c8eeafc6b30a3897497f6b90a1d31655a2020 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cm)wkw;ict6?r%zNgh?!(87 z1pNdz6^2=FVh&>4v_sQ-)8g;|qg$`z3@<6Q8cki`W4O=eqT!s$v73SzR~g;>_j%(< z(~Cxj{>E+6+pD=LSDbazhj)D&8;(8LSkU@f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3Y3A4L1#9);%}y&YHQYC(~uq zx=H+-cuQ4{=KNT?X@Pj}ra4h6n<{56+BCUA)yVSSu1%Nruo(q3>lvi9?KIL8YTCqf zRou|maKpwQBEcKJs4U$u$JutHhpm}mZ86WLOAlj=W*pvWIA`-#qc)A5hWq?eH^rRi z-1LMgaN~pBU-f&QDsHmypSEdVw~3+Q_HrY?-W^7PZk?M#=J;>iq;!2_9KXBK#`uXw r=bz2qsHi&6aKQzh4MEOUo6hZM-tcKzgVFC!VMF<20Ys literal 0 HcmV?d00001 diff --git a/test/expect/MNASNetTester.test_mnasnet1_3_expect.pkl b/test/expect/MNASNetTester.test_mnasnet1_3_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a79038ce3eb57966eb13f8abc4cd9153342c0c92 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cnm(Ju_9EYIF}s;tCN@sHt# zcENKS?ibwLAnpHT!}@qeLz7a54K?jgH`HaZ85YG{-*CdlUcdC9tYOV5eS?>eqBr(@ zQ`#66+iAeKa?^&tKNf8`venBlEY!-7fx%w;amn_L;U5Y&8oh7WAm#dWgJR&X4M#4V z*w9xUv!QLnGlLt=JJ;>Y+qzLwq|-ohsg>dP2a<-9l-mt6trZNj+}>{p5K%Rl^R#}Q q%)_7hy_q_OvWAm49IR*B7&OmkgYd=q2KVwrH!jU&GCbp!vH<{%+`>Wt literal 0 HcmV?d00001 diff --git a/test/expect/MobilenetTester.test_mobilenet_v2_expect.pkl b/test/expect/MobilenetTester.test_mobilenet_v2_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..54ee307c01bf7f06dc7bcc6deaa51da65b1e0501 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3Y3k0><@=H`wc|ZFN}B>6*G> zO*Py46nDlAruUoHA1h|q@ToX&eQ4Wy{aHMQdO{iD8}6O<*B44VtY6)nV31L&v_7Fs z%b?&$?uH{xS?e=|bT|C@9=QGiyU&J0k(CA;)`)KS@R)CX!J=}#kY%O%TKCqh=Skw; z@W6BKI_=L@`q8Z>>lb`uSZ}4uup#Na=7zWbg$*p~rmnj(<)hw{M{EWjuObZgEd9Fv q*MaByYXprqv;}wRiyX;Zzr;&UulDPE{R9@y4QFmN8SKmJSq}h$2f&a3 literal 0 HcmV?d00001 diff --git a/test/expect/MobilenetTester.test_mobilenetv2_residual_setting_expect.pkl b/test/expect/MobilenetTester.test_mobilenetv2_residual_setting_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c7733885caf89e9ce68afb9421b695b4dac60d1c GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cnqcbRtf87#LHs(Q8K3)jUR zQ>C|CxcF_`K5c5$4l!}(ow;oGJEO(9EDYvv+wn8>-j0taOLneOVYbM)vc__fHmjx9 z8_k_>&IxXBxE`|2BHwoB3^!?uo{79WrpNbhSFw!Qxx((Eg{-QU#ex@8En>^`ckbDh zvU5*Z|IXeDE{iKQF_z7hUw0U9*4h!naM|LC?Q#phwouE1>GLceTx{N!wY16d2j4q$ q<0UJ1y8CxnDD0TMBdB7_j+ZuKmJ;VnEN`St+*vj4o%uE1+#LYZp}|T3 literal 0 HcmV?d00001 diff --git a/test/expect/ResnetTester.test_resnet101_expect.pkl b/test/expect/ResnetTester.test_resnet101_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..ba62eb8e625ee6ba33352509e629de3477da6857 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cmN4K~LZo36U}r7duqTqWxA z^~(d-d8b>BRq0l`Wj&E~d$y40*sGJ%jy6s_;(XCC;mCc12d;MAtB>Zn?$`qA>rXt(>5yj&0L*FENJdfqj`eY%_bllMp4J^GI3A3Alky=&6Z qz(-eI%bOP;o}1rvO!1THG5uQwZd>Za-275LA2xls$!)Q2`7r=_f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cmyMEg_IHWYd>`3Rog)p7W= z&xhJmP7=3HCr{k(E$}S##7dCp;O=I+4)VI)EXy$b9Y2r+$GA6C_{PH%$n`_dF)0+#mPxn+wpZZZOdYVIW zk@xof7N^vtFL*u9>OZ~w*OyZ|T4i3*C--f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ zYG7hwQOFFoJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3clC>yFr6%Sg04i?<(j0QgvSY`QII}GuyDj?%Y*xyV8GV_7g+4?3>M{XaA@2n4Ll1wtc-v0{6et zP}!IBIewp2e1YBCnf*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ7MM9Ws8>3)2EKBL{rvR612!dx*_l!YCpLNHAOOQ~#e4t& literal 0 HcmV?d00001 diff --git a/test/expect/ResnetTester.test_resnet50_expect.pkl b/test/expect/ResnetTester.test_resnet50_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..1a94550e33665cbf557128f3bc5f271ba01491a3 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ zWMOJ#QOFFoJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cmgRfiA!R`PUgcd2x!QOj_= zRQlq8oQ1DL$BQox+1e`{j2AN=XbsdnaR2Q)how^vIrP1ld~kKlHiv+XkE{}0f*qe7 z)^wZ_z576w45Pz~S!v4U^NsB_d@9IM;IQ4_$SB|BDT;4zOIBU_akY z+R@^}6Ni+WUmSKU4|NnR>vc>iU~@F-I&@%*rKO|t#vaFcR_=~&B9V?Y6Ac`yuZkV) qJUZ#%s*er_zpf2CXyv2i=>B4<-6qA>1LyZN9%Q#WdVp!p?E?VAmBOL` literal 0 HcmV?d00001 diff --git a/test/expect/ResnetTester.test_resnext101_32x8d_expect.pkl b/test/expect/ResnetTester.test_resnext101_32x8d_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..b2dd8c42da419ffa46dd5321f8cf392c554afa7f GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cm&R9jo$E64W;nR)NcI}vI7 z%G71=fkUQN0^68vZM|3RDJ_-XJIz*P&m++Y+hwb?_5{3Zx6<=6-Scc>$KJrpGJ6-W zuNhmrcRPAZvlyhj&XTuHUQB|8AEvx5}O~636$%2TiiRa7=D*LQ4Ct{q1FY rHneTBx$UF2_sa7Y8)>J>d%bJ}_e|JUwD*Ci#NPELZ)~?bJ!TC6m7Kz* literal 0 HcmV?d00001 diff --git a/test/expect/ResnetTester.test_resnext50_32x4d_expect.pkl b/test/expect/ResnetTester.test_resnext50_32x4d_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..fd4b9d49c49ae2fbd225ac2e58e047789c6d3b96 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ%ytJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3ck}8oqlyjC^-*Khn6j^sblf zxrT!_)2{xo_FLRwZFAqzmdn0#x6Q@ZHua)iw!f6zY#x46+EZa9wKp$gs`WL?7+WE? zLYu%V=WKpNRqrlrUb83d%uH)mDUQ7t-)*o_-QI7r_V1}Z&#a&Aj%pI!^K0L&-80Wj zu(8%t+|%&r{H~hfKfAMo?QInNr1x?@XR=*tsJ3_8o7z2#)u!1z3twoXB>ij;UsL}c qDXUbQ8OIy<2uQf>)!AyYSL%lCUI8aTo4O^l_gs3byw~BZr4;}fEW&{R literal 0 HcmV?d00001 diff --git a/test/expect/ResnetTester.test_wide_resnet101_2_expect.pkl b/test/expect/ResnetTester.test_wide_resnet101_2_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..8aef5fb29093728bea6e7f534c17edcb86013743 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cnmhbFqQ-W57-AhGY5XYRaX z$98EPTf8^>7=L!gu~v>5ZVT>iJT}G5*X`zht7DF*JKg%6ULIv}Q9hQp>YZrtCFyFPANaC8peYPYwKqKf*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cnJ`;#3SOKag~`wnM@jm4mj6LLKh?WN@@z^Yp-)SAqMjewQC~yyADj zYm=YjD(6=F&!-g*1nw(voKR4G;MB&a`*+qIbU0Qiaj?H<@_`dh{0F-qu5fss-t9QS z=eNTL<%$Er`-L2Px{MF(>-XOu%(47{Z4$HNg@+1`lDVn-PYC~a$a?g_;o?4p1N*<& sI>xQiba;}teSZ!6=>y5FC-*-)!*|eO#}WJGmx2x$9lve!#f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ%ytJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3Y3|t*`dvB#K#BUAbXnefR#J zX&fhQn5!PztY{3~bGvrxp6_>_TTPoKz56ELNgLzx&wH|S?RGm!Y_{e-T5qFpc7@Hg z>a~0LFMr>2mnUk^e6JOIzW6S&d0M}GkD7kxp2=GhcR!b_-Sgwb(mnU8yKH7G`(XXy zbg#|QIT!bsN7vZ&e_`2OJaeth@0~Aq8J9%tSydUnr^9K%o@rk%?cUz|)n=W{P3t|E tXYE<*`rPLBho{z`b!O}d;5%WnuxRm~ou_;Dd}2Ic6UH^oI;G{l4FIV_&|d%m literal 0 HcmV?d00001 diff --git a/test/expect/ShufflenetTester.test_shufflenet_v2_x1_0_expect.pkl b/test/expect/ShufflenetTester.test_shufflenet_v2_x1_0_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..ff3d93dfc6cc3234f64a17ac5cdbaf37df678bec GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cl)zs~Nq;CZvh_@jpnS7oBL zpkcgq?YWGdOTzc=SvfCd7x&M;J)wW=cHKX*caKNiNt-`ki}ytQp19`;WB;DkcQJd8 z^K;p7f4pt;s<~#*OR=SU#Jl`$tUvGCv-rKso}v{`ZQQaS+FX0P+=fTz&>r6P&-U1| zitO3+f6Ja?rLWe9WZG?J+n?GqS)FaqoJ*E_rs(aqu9S+g$$so+6T?5#Cc)t9o)wFZ r*({!TclVpRW4r4zMQkRBOtM*OriUx>v}k literal 0 HcmV?d00001 diff --git a/test/expect/ShufflenetTester.test_shufflenet_v2_x1_5_expect.pkl b/test/expect/ShufflenetTester.test_shufflenet_v2_x1_5_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a4f1426e95a82b4e9a2805d3c12cc341531adb9d GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3clY##^_0Hcqn4T;Na#MEBxxva2v*YoITn|BqHZCu}8v~I5WXY(s@ll7e94>l9e9=6e&(qLn8 z&3bprP18N??VtBVou0ktf7UhYz^fnjG-=M<>ic1v?@QL$*p~0Jp1HWoW|F}f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cngCjxEmx+v~>Z=Jj6|L(xu z#og7rjiff(aPOGEM_+iJ&EaDkZIU|Eci4FU*}YJFicQyA7i-H?E3NObEwJ(as$^|# zuCgb#xp6o1ovpiLQ)k%RyTGz%;q1xQD;OJWN@i``(eba>CMR>{9=@Mb_Uv7}bWd+< z=AJoiZ+Fi+(`4gq9Ao|WUbaodomp1v|GnE?SUJu5y3QnYhsU$->6p*4 rhjDu3Zb9!YdrC#$?e5z@&szJ*wcUB2T5Z^*EN!?8v~5`b`RxG!L?6W? literal 0 HcmV?d00001 diff --git a/test/expect/SqueezenetTester.test_squeezenet1_0_expect.pkl b/test/expect/SqueezenetTester.test_squeezenet1_0_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9cc5f9a1e188c676f0d7cead585e1bfd64dbcbf4 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZH9raEg80K9VvDkPn_*QV#mru z?54IX2D3qY<$E%AOH1z9faMew?N~DwS-rmY(-y>Uy&`W7qPNBiTJu>d+kOsfwdI-f z&OFvq%eI1($u3Nf*IFfYo9(HP*>-J;3$2-6`q~~y+hFl;y11QOji}YNw`Z(C`j_04 zwfTIE6Kv;pcLv+-JKAl%Vp_mv^WC`srj@hL+DMf*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ%ytJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cl0R^RQ8{n=`l`}Cn5-?pQ+ z8wC8U{d|_&ys5FYdvtw;-7!CR`yx&qyNMlzc89v|+vTiwv5)Hvv);`jY-b@HX@6yF zkB!2-EUTOIR$3QbTxAC`)3KPx?yg*m`R<7AcJ+TE?D#tb?dB9}faUJ+};;<+MIKnWPkWa fm7O48vdypKO7^Axe{FiVxY$J}pSKBK)?o(#wdk{u literal 0 HcmV?d00001 diff --git a/test/expect/VGGNetTester.test_vgg11_bn_expect.pkl b/test/expect/VGGNetTester.test_vgg11_bn_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d48fc986c9e85ba1b148294a46129d06b95afd2d GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3Y1a|Hr!=UkL5p&fsOsa`2e# zw{P8h97{g$37NlT&$Vq|_t?$Kv6&~!y(d9=wGDg7#yxEE2dvptAKGN?JZkH3I(hd) zDJh$xhfaG+_Mfz#u%*W8i?7ID?T}C##)H3XUZ|MuRah^vd*Rm=d%i4gu-S2}c-IGE zHCrCpWwu>A?pn({cw{p}c*dSQbFaP2|9sfqxJv*GWx@l literal 0 HcmV?d00001 diff --git a/test/expect/VGGNetTester.test_vgg11_expect.pkl b/test/expect/VGGNetTester.test_vgg11_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..ef0eecbfb3afa7a1d99e4648f584dc5f8b30a06d GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3clefsc0`eJQl}9FvzV+u>uj zmOr}p%qjc4=fQ$4drWqI-P1BB$0kIcd(SJC)izpT8}~3M9k7m3e`r&)`>5@PQ^~u_ zWu$CwJaXFOeDI|8f}J&1b^aoIw}ge-v>g6rvqQ~nZ^ssi-QC|;?9o};U?Xv^c-MSU zHQRl1%WT#5+_ir2=#fo`=!`w3mR@^Z|9{xCEK|ZZLA}oAp+2|moz4Dxmb~1uSLas3 rZie|vyG{!z?X_+zu+`e-Vf!I9e2-t>mOWV#7x#$WHMi;MkJf*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJaswS2TgZUb3dTe|BcRP7@YJ2b*3vP?w!chpZ=AN=-l7%9 z_89-4ZnI=|=iVDuv9=1^*6d0Xyt?Q9?)tq8ytnT0X`X5GPhDg8gf#}XmG9r!q^-JS z)0)L#Gj;2%JqhCbZ0tHO?qNS6xA(v*#l2T%XWHHpVz-I=Aa3iLC$smp#dn*I9UZ%$ zU7E6ct%hpn$Yn`p!F{+A7hrSqOV6RABsVKZ#@b$8kvU$x8z06fvdD*ylh literal 0 HcmV?d00001 diff --git a/test/expect/VGGNetTester.test_vgg13_expect.pkl b/test/expect/VGGNetTester.test_vgg13_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..044e160ca449fd2b200d27d17fe9e2c9ec8b6369 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cmro(#4eE-|)o6@q)abmjK` zS#@mBO@Zn2HE?bx|ym!-(nJ#PE!_dfI8y2rY8rcIru#_qH03~aqVy|H<+ z=8}y{4uehM_E~#&Nba+_+ljZF+&sUICl$HVnHu zcAvU3W%nL4Wm~fwu{H%heztQMgRHe4GTB~85wmvv{BV!Joy#^`9((Wo;a0lmG?V+@ q8w(Qm%-}w3z3s(BoA6J+Y&fl*_avA}?GcTbVRNXr(}sENG8+KvV!#yu literal 0 HcmV?d00001 diff --git a/test/expect/VGGNetTester.test_vgg16_bn_expect.pkl b/test/expect/VGGNetTester.test_vgg16_bn_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7c5f83594f9c1d3ae478f5624b4e291553ed7511 GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cmNSAzGX7v|V@-D* z!#+P7ooWy3RSQ}6URYGTXGu_-O~4}t+uHBP_BfR&+1wLYzGtnZo9((wp*Ck`H`^=- zZ?usrV6oB9ytBuHv(3hO@3%c?IAv{C{N7{zXnn4=0{7ybEo;U1s;sHsv!Z1Bp04C8 zHt$*(_Izh=v?gHjgZ~@7|Nny*I~xf*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJ%ytJh`B>kj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3cm7*Ms*2mgd;*zSC@Db9nCV zpo4xk4fP(@j~BD-UAL@w&!vzyn<-BiY?u5xwnwK-$>x;s@;w5wZnl}%LTxV1ZMJEP zY_zE>VzJ@Qy|ZT>cbkpM!Ebv)cw}uf{_U|=+mUO%Kz8xYck9LXK3iYE$ESSy9;@^# zHc#6a_NZ|++W6+Iwb|gVZId@!#QN%+W}7onf_pilI&A`SocFTdSK3>4dCl(EJ49{f qu)eXGc2sjuuGN%1Ia4L~9y$ALx1HWfn`buLcmK)c-ka&Pat{DbPr?5H literal 0 HcmV?d00001 diff --git a/test/expect/VGGNetTester.test_vgg19_bn_expect.pkl b/test/expect/VGGNetTester.test_vgg19_bn_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..260f506eb3e4ab12bd7f56f699faeee845ed906e GIT binary patch literal 543 zcmZo*>f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3Y1@0|nbEUSk`BvLm}wS{nCs zNSo~a*K>W()Kq5M3lAUdw#s$g6EEImGlhND?xkBjY}j_n*dEcoV&!DC#@0$|pLO=Z z%spO{?%2HX2)AB$>V|bsovV#(!W^5}XIgtTH7?lfSw4Nw!5y-D#ghN-Y3Q%Co|czy zqj>7@9-a1YyVu=%w}`Oixb(KXPj>nAMt*V rS!?uOz72+ZA3luQ+mLh2%5Gxp-hvOC_81&t-TgE0f*}zGso?xLe`1^NxcRp-EDcxmzHGa6m!KFr6!eT=A^`z zq~;ap7sZzuaRJ?=5y1{J5U4USJ++V#qJ_&XCqJkj2}`59pjgZw7BhZ=*J#V~o5R+X~q`l7arrNlh*Rdb60zzbGZO zC^f|;Gr5GTkV7-5kh7$aE2xkg=+M|g2CP;vCh8dhZ3Y29V+GrL0>(B!s*dcw)7H2r zP~K$kiiy|vC}%L+c07Bu+oRBR&nd|!8y4|TK{oAw+y%)0ww=*c}N4(jXy08pmMZU6uP literal 0 HcmV?d00001 diff --git a/test/test_models.py b/test/test_models.py index 1864d233772..a56a463738b 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -7,17 +7,29 @@ import unittest import traceback import random +import inspect -def set_rng_seed(seed): +EPSILON = 1e-5 # small value for approximate comparisons/assertions +STANDARD_NUM_CLASSES = 50 +STANDARD_INPUT_SHAPE = (1, 3, 224, 224) +STANDARD_SEED = 1729 # https://fburl.com/3i5wkg9p + +def set_rng_seed(seed=STANDARD_SEED): torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) -def get_available_classification_models(): - # TODO add a registration mechanism to torchvision.models - return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + +def subsample_tensor(tensor, num_samples=20): + num_elems = tensor.numel() + if num_elems <= num_samples: + return tensor + + flat_tensor = tensor.flatten() + ith_index = num_elems // num_samples + return flat_tensor[ith_index - 1::ith_index] def get_available_segmentation_models(): @@ -40,22 +52,31 @@ def get_available_video_models(): # they are not yet supported in JIT. script_test_models = [ "deeplabv3_resnet101", - "mobilenet_v2", - "resnext50_32x4d", "fcn_resnet101", - "googlenet", - "densenet121", - "resnet18", - "alexnet", - "shufflenet_v2_x1_0", - "squeezenet1_0", - "vgg11", - "inception_v3", 'r3d_18', ] class ModelTester(TestCase): + + # create random tensor with given shape using synced RNG state + # caching because these tests take pretty long already (instantiating models and all) + TEST_INPUTS = {} + def _get_test_input(self, shape=STANDARD_INPUT_SHAPE): + # NOTE not thread-safe, but should give same results even if multi-threaded testing gave a race condition + # giving consistent results is kind of the point of this helper method + if shape not in self.TEST_INPUTS: + set_rng_seed(STANDARD_SEED) + self.TEST_INPUTS[shape] = torch.rand(shape) + return self.TEST_INPUTS[shape] + + # create a randomly-weighted model w/ synced RNG state + def _get_test_model(self, callable, **kwargs): + set_rng_seed(STANDARD_SEED) + model = callable(**kwargs) + model.eval() + return model + def check_script(self, model, name): if name not in script_test_models: return @@ -69,16 +90,268 @@ def check_script(self, model, name): msg = str(e) + str(tb) self.assertTrue(scriptable, msg) - def _test_classification_model(self, name, input_shape): - # passing num_class equal to a number other than 1000 helps in making the test - # more enforcing in nature - model = models.__dict__[name](num_classes=50) - self.check_script(model, name) - model.eval() - x = torch.rand(input_shape) - out = model(x) - self.assertEqual(out.shape[-1], 50) + def _check_scriptable(self, model, expected): + if expected is None: # we don't check scriptability for all models + return + + actual = True + msg = '' + try: + torch.jit.script(model) + except Exception as e: + tb = traceback.format_exc() + actual = False + msg = str(e) + str(tb) + self.assertEqual(actual, expected, msg) + + + +class ClassificationCoverageTester(TestCase): + + # Find all models exposed by torchvision.models factory methods (with assumptions) + def get_available_classification_models(self): + # TODO add a registration mechanism to torchvision.models + return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + + # Recursively gather test methods from all classification testers + def get_test_methods_for_class(self, klass): + all_methods = inspect.getmembers(klass, predicate=inspect.isfunction) + test_methods = set([method[0] for method in all_methods if method[0].startswith('test_')]) + for child in klass.__subclasses__(): + test_methods = test_methods.union(self.get_test_methods_for_class(child)) + return test_methods + + # Verify that all models exposed by torchvision.models factory methods + # have corresponding test methods + # NOTE This does not include some of the extra tests (such as Resnet + # dilation) and says nothing about the correctness of the test, nor + # of the model. It just enforces a naming scheme on the tests, and + # verifies that all models have a corresponding test name. + def test_classification_model_coverage(self): + model_names = self.get_available_classification_models() + test_names = self.get_test_methods_for_class(ClassificationModelTester) + + for model_name in model_names: + test_name = 'test_' + model_name + self.assertTrue(test_name in test_names) + + + +class ClassificationModelTester(ModelTester): + def _infer_for_test_with(self, model, test_input): + return model(test_input) + + def _check_classification_output_shape(self, test_output, num_classes): + self.assertEqual(test_output.shape, (1, num_classes)) + + # NOTE Depends on presence of test data fixture. See common_utils.py for + # details on creating fixtures. + def _check_model_correctness(self, model, test_input, num_classes=STANDARD_NUM_CLASSES): + test_output = self._infer_for_test_with(model, test_input) + self._check_classification_output_shape(test_output, num_classes) + self.assertExpected(test_output) + return test_output + + # NOTE override this in a child class + def _get_input_shape(self): + return STANDARD_INPUT_SHAPE + + def _test_classification_model(self, model_callable, num_classes=STANDARD_NUM_CLASSES, **kwargs): + model = self._get_test_model(model_callable, num_classes=num_classes, **kwargs) + self._check_scriptable(model, True) # currently, all expected to be scriptable + test_input = self._get_test_input(shape=self._get_input_shape()) + test_output = self._check_model_correctness(model, test_input) + return model, test_input, test_output + + + +class AlexnetTester(ClassificationModelTester): + def test_alexnet(self): + self._test_classification_model(models.alexnet) + + + +# TODO add test for aux_logits arg to factory method +# TODO add test for transform_input arg to factory method +class InceptionV3Tester(ClassificationModelTester): + def _get_input_shape(self): + return (1, 3, 299, 299) + + def test_inception_v3(self): + self._test_classification_model(models.inception_v3) + + + +class SqueezenetTester(ClassificationModelTester): + def test_squeezenet1_0(self): + self._test_classification_model(models.squeezenet1_0) + + def test_squeezenet1_1(self): + self._test_classification_model(models.squeezenet1_1) + + + +# TODO add test for width_mult arg to factory method +class MobilenetTester(ClassificationModelTester): + def test_mobilenet_v2(self): + self._test_classification_model(models.mobilenet_v2) + + def test_mobilenetv2_residual_setting(self): + self._test_classification_model(models.mobilenet_v2, inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]]) + + + +# TODO add test for aux_logits arg to factory method +# TODO add test for transform_input arg to factory method +class GooglenetTester(ClassificationModelTester): + def test_googlenet(self): + self._test_classification_model(models.googlenet) + + + +class VGGNetTester(ClassificationModelTester): + def test_vgg11(self): + self._test_classification_model(models.vgg11) + + def test_vgg11_bn(self): + self._test_classification_model(models.vgg11_bn) + + def test_vgg13(self): + self._test_classification_model(models.vgg13) + + def test_vgg13_bn(self): + self._test_classification_model(models.vgg13_bn) + + def test_vgg16(self): + self._test_classification_model(models.vgg16) + + def test_vgg16_bn(self): + self._test_classification_model(models.vgg16_bn) + + def test_vgg19(self): + self._test_classification_model(models.vgg19) + + def test_vgg19_bn(self): + self._test_classification_model(models.vgg19_bn) + + + +# TODO add test for dropout arg to factory method +class MNASNetTester(ClassificationModelTester): + def test_mnasnet0_5(self): + self._test_classification_model(models.mnasnet0_5) + + def test_mnasnet0_75(self): + self._test_classification_model(models.mnasnet0_75) + + def test_mnasnet1_0(self): + self._test_classification_model(models.mnasnet1_0) + + def test_mnasnet1_3(self): + self._test_classification_model(models.mnasnet1_3) + + + +# TODO add test for bn_size arg to factory method +# TODO add test for drop_rate arg to factory method +class DensenetTester(ClassificationModelTester): + def _test_densenet_plus_mem_eff(self, model_callable): + model, test_input, test_output = self._test_classification_model(model_callable) + + # above, we perform the standard correctness test against the test fixture, and capture key test params + # below, we check that the memory efficient/time inefficient DenseNet implementation behaves like the "standard" one + me_model = self._get_test_model(model_callable, num_classes=STANDARD_NUM_CLASSES, memory_efficient=True) + me_model.load_state_dict(model.state_dict()) # xfer weights over + me_output = self._infer_for_test_with(me_model, test_input) + test_output.squeeze(0) + me_output.squeeze(0) + self.assertTrue((test_output - me_output).abs().max() < EPSILON) + + def test_densenet121(self): + self._test_densenet_plus_mem_eff(models.densenet121) + + def test_densenet161(self): + self._test_densenet_plus_mem_eff(models.densenet161) + + def test_densenet169(self): + self._test_densenet_plus_mem_eff(models.densenet169) + + def test_densenet201(self): + self._test_densenet_plus_mem_eff(models.densenet201) + + + +class ShufflenetTester(ClassificationModelTester): + def test_shufflenet_v2_x0_5(self): + self._test_classification_model(models.shufflenet_v2_x0_5) + + def test_shufflenet_v2_x1_0(self): + self._test_classification_model(models.shufflenet_v2_x1_0) + + def test_shufflenet_v2_x1_5(self): + self._test_classification_model(models.shufflenet_v2_x1_5) + + def test_shufflenet_v2_x2_0(self): + self._test_classification_model(models.shufflenet_v2_x2_0) + + +# TODO add test for zero_init_residual arg to factory method +# TODO add test for norm_layer arg to factory method +class ResnetTester(ClassificationModelTester): + def _get_scriptability_value(self): + return True + + def test_resnet18(self): + self._test_classification_model(models.resnet18) + + def test_resnet34(self): + self._test_classification_model(models.resnet34) + + def test_resnet50(self): + self._test_classification_model(models.resnet50) + + def test_resnet101(self): + self._test_classification_model(models.resnet101) + + def test_resnet152(self): + self._test_classification_model(models.resnet152) + + def test_resnext50_32x4d(self): + self._test_classification_model(models.resnext50_32x4d) + + def test_resnext101_32x8d(self): + self._test_classification_model(models.resnext101_32x8d) + + def test_wide_resnet50_2(self): + self._test_classification_model(models.wide_resnet50_2) + + def test_wide_resnet101_2(self): + self._test_classification_model(models.wide_resnet101_2) + + def _make_sliced_model(self, model, stop_layer): + layers = OrderedDict() + for name, layer in model.named_children(): + layers[name] = layer + if name == stop_layer: + break + new_model = torch.nn.Sequential(layers) + return new_model + + def test_resnet_dilation(self): + # TODO improve tests to also check that each layer has the right dimensionality + for i in product([False, True], [False, True], [False, True]): + model = models.__dict__["resnet50"](replace_stride_with_dilation=i) + model = self._make_sliced_model(model, stop_layer="layer4") + model.eval() + x = self._get_test_input(STANDARD_INPUT_SHAPE) + out = model(x) + f = 2 ** sum(i) + self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f)) + + + +class SegmentationModelTester(ModelTester): def _test_segmentation_model(self, name): # passing num_class equal to a number other than 1000 helps in making the test # more enforcing in nature @@ -90,6 +363,9 @@ def _test_segmentation_model(self, name): out = model(x) self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) + + +class DetectionModelTester(ModelTester): def _test_detection_model(self, name): set_rng_seed(0) model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False) @@ -102,16 +378,6 @@ def _test_detection_model(self, name): self.assertIs(model_input[0], x) self.assertEqual(len(out), 1) - def subsample_tensor(tensor): - num_elems = tensor.numel() - num_samples = 20 - if num_elems <= num_samples: - return tensor - - flat_tensor = tensor.flatten() - ith_index = num_elems // num_samples - return flat_tensor[ith_index - 1::ith_index] - def compute_mean_std(tensor): # can't compute mean of integral tensor tensor = tensor.to(torch.double) @@ -132,64 +398,6 @@ def compute_mean_std(tensor): self.assertTrue("scores" in out[0]) self.assertTrue("labels" in out[0]) - def _test_video_model(self, name): - # the default input shape is - # bs * num_channels * clip_len * h *w - input_shape = (1, 3, 4, 112, 112) - # test both basicblock and Bottleneck - model = models.video.__dict__[name](num_classes=50) - self.check_script(model, name) - x = torch.rand(input_shape) - out = model(x) - self.assertEqual(out.shape[-1], 50) - - def _make_sliced_model(self, model, stop_layer): - layers = OrderedDict() - for name, layer in model.named_children(): - layers[name] = layer - if name == stop_layer: - break - new_model = torch.nn.Sequential(layers) - return new_model - - def test_memory_efficient_densenet(self): - input_shape = (1, 3, 300, 300) - x = torch.rand(input_shape) - - for name in ['densenet121', 'densenet169', 'densenet201', 'densenet161']: - model1 = models.__dict__[name](num_classes=50, memory_efficient=True) - params = model1.state_dict() - model1.eval() - out1 = model1(x) - out1.sum().backward() - - model2 = models.__dict__[name](num_classes=50, memory_efficient=False) - model2.load_state_dict(params) - model2.eval() - out2 = model2(x) - - max_diff = (out1 - out2).abs().max() - - self.assertTrue(max_diff < 1e-5) - - def test_resnet_dilation(self): - # TODO improve tests to also check that each layer has the right dimensionality - for i in product([False, True], [False, True], [False, True]): - model = models.__dict__["resnet50"](replace_stride_with_dilation=i) - model = self._make_sliced_model(model, stop_layer="layer4") - model.eval() - x = torch.rand(1, 3, 224, 224) - out = model(x) - f = 2 ** sum(i) - self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f)) - - def test_mobilenetv2_residual_setting(self): - model = models.__dict__["mobilenet_v2"](inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]]) - model.eval() - x = torch.rand(1, 3, 224, 224) - out = model(x) - self.assertEqual(out.shape[-1], 1000) - def test_fasterrcnn_double(self): model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False) model.double() @@ -205,16 +413,19 @@ def test_fasterrcnn_double(self): self.assertTrue("labels" in out[0]) -for model_name in get_available_classification_models(): - # for-loop bodies don't define scopes, so we have to save the variables - # we want to close over in some way - def do_test(self, model_name=model_name): - input_shape = (1, 3, 224, 224) - if model_name in ['inception_v3']: - input_shape = (1, 3, 299, 299) - self._test_classification_model(model_name, input_shape) - setattr(ModelTester, "test_" + model_name, do_test) +class VideoModelTester(ModelTester): + def _test_video_model(self, name): + # the default input shape is + # bs * num_channels * clip_len * h *w + input_shape = (1, 3, 4, 112, 112) + # test both basicblock and Bottleneck + model = models.video.__dict__[name](num_classes=50) + self.check_script(model, name) + x = torch.rand(input_shape) + out = model(x) + self.assertEqual(out.shape[-1], 50) + for model_name in get_available_segmentation_models(): @@ -223,7 +434,7 @@ def do_test(self, model_name=model_name): def do_test(self, model_name=model_name): self._test_segmentation_model(model_name) - setattr(ModelTester, "test_" + model_name, do_test) + setattr(SegmentationModelTester, "test_" + model_name, do_test) for model_name in get_available_detection_models(): @@ -232,7 +443,7 @@ def do_test(self, model_name=model_name): def do_test(self, model_name=model_name): self._test_detection_model(model_name) - setattr(ModelTester, "test_" + model_name, do_test) + setattr(DetectionModelTester, "test_" + model_name, do_test) for model_name in get_available_video_models(): @@ -240,7 +451,7 @@ def do_test(self, model_name=model_name): def do_test(self, model_name=model_name): self._test_video_model(model_name) - setattr(ModelTester, "test_" + model_name, do_test) + setattr(VideoModelTester, "test_" + model_name, do_test) if __name__ == '__main__': unittest.main() From 884a4ca7cd0e14eaca642256261f9a2277055b66 Mon Sep 17 00:00:00 2001 From: Brad Heintz Date: Tue, 15 Oct 2019 12:07:26 -0700 Subject: [PATCH 2/5] flake8 fixes --- test/test_models.py | 45 +++++++++++++++------------------------------ 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index a56a463738b..572ed2a1ae3 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -10,10 +10,11 @@ import inspect -EPSILON = 1e-5 # small value for approximate comparisons/assertions +EPSILON = 1e-5 # small value for approximate comparisons/assertions STANDARD_NUM_CLASSES = 50 STANDARD_INPUT_SHAPE = (1, 3, 224, 224) -STANDARD_SEED = 1729 # https://fburl.com/3i5wkg9p +STANDARD_SEED = 1729 # https://fburl.com/3i5wkg9p + def set_rng_seed(seed=STANDARD_SEED): torch.manual_seed(seed) @@ -21,7 +22,6 @@ def set_rng_seed(seed=STANDARD_SEED): np.random.seed(seed) - def subsample_tensor(tensor, num_samples=20): num_elems = tensor.numel() if num_elems <= num_samples: @@ -62,6 +62,7 @@ class ModelTester(TestCase): # create random tensor with given shape using synced RNG state # caching because these tests take pretty long already (instantiating models and all) TEST_INPUTS = {} + def _get_test_input(self, shape=STANDARD_INPUT_SHAPE): # NOTE not thread-safe, but should give same results even if multi-threaded testing gave a race condition # giving consistent results is kind of the point of this helper method @@ -91,7 +92,7 @@ def check_script(self, model, name): self.assertTrue(scriptable, msg) def _check_scriptable(self, model, expected): - if expected is None: # we don't check scriptability for all models + if expected is None: # we don't check scriptability for all models return actual = True @@ -105,7 +106,6 @@ def _check_scriptable(self, model, expected): self.assertEqual(actual, expected, msg) - class ClassificationCoverageTester(TestCase): # Find all models exposed by torchvision.models factory methods (with assumptions) @@ -136,7 +136,6 @@ def test_classification_model_coverage(self): self.assertTrue(test_name in test_names) - class ClassificationModelTester(ModelTester): def _infer_for_test_with(self, model, test_input): return model(test_input) @@ -158,19 +157,17 @@ def _get_input_shape(self): def _test_classification_model(self, model_callable, num_classes=STANDARD_NUM_CLASSES, **kwargs): model = self._get_test_model(model_callable, num_classes=num_classes, **kwargs) - self._check_scriptable(model, True) # currently, all expected to be scriptable + self._check_scriptable(model, True) # currently, all expected to be scriptable test_input = self._get_test_input(shape=self._get_input_shape()) test_output = self._check_model_correctness(model, test_input) return model, test_input, test_output - class AlexnetTester(ClassificationModelTester): def test_alexnet(self): self._test_classification_model(models.alexnet) - # TODO add test for aux_logits arg to factory method # TODO add test for transform_input arg to factory method class InceptionV3Tester(ClassificationModelTester): @@ -181,7 +178,6 @@ def test_inception_v3(self): self._test_classification_model(models.inception_v3) - class SqueezenetTester(ClassificationModelTester): def test_squeezenet1_0(self): self._test_classification_model(models.squeezenet1_0) @@ -190,7 +186,6 @@ def test_squeezenet1_1(self): self._test_classification_model(models.squeezenet1_1) - # TODO add test for width_mult arg to factory method class MobilenetTester(ClassificationModelTester): def test_mobilenet_v2(self): @@ -200,7 +195,6 @@ def test_mobilenetv2_residual_setting(self): self._test_classification_model(models.mobilenet_v2, inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]]) - # TODO add test for aux_logits arg to factory method # TODO add test for transform_input arg to factory method class GooglenetTester(ClassificationModelTester): @@ -208,7 +202,6 @@ def test_googlenet(self): self._test_classification_model(models.googlenet) - class VGGNetTester(ClassificationModelTester): def test_vgg11(self): self._test_classification_model(models.vgg11) @@ -235,7 +228,6 @@ def test_vgg19_bn(self): self._test_classification_model(models.vgg19_bn) - # TODO add test for dropout arg to factory method class MNASNetTester(ClassificationModelTester): def test_mnasnet0_5(self): @@ -251,7 +243,6 @@ def test_mnasnet1_3(self): self._test_classification_model(models.mnasnet1_3) - # TODO add test for bn_size arg to factory method # TODO add test for drop_rate arg to factory method class DensenetTester(ClassificationModelTester): @@ -259,9 +250,9 @@ def _test_densenet_plus_mem_eff(self, model_callable): model, test_input, test_output = self._test_classification_model(model_callable) # above, we perform the standard correctness test against the test fixture, and capture key test params - # below, we check that the memory efficient/time inefficient DenseNet implementation behaves like the "standard" one + # below, we check that memory efficient/time inefficient DenseNet implementation behaves like the "standard" one me_model = self._get_test_model(model_callable, num_classes=STANDARD_NUM_CLASSES, memory_efficient=True) - me_model.load_state_dict(model.state_dict()) # xfer weights over + me_model.load_state_dict(model.state_dict()) # xfer weights over me_output = self._infer_for_test_with(me_model, test_input) test_output.squeeze(0) me_output.squeeze(0) @@ -280,7 +271,6 @@ def test_densenet201(self): self._test_densenet_plus_mem_eff(models.densenet201) - class ShufflenetTester(ClassificationModelTester): def test_shufflenet_v2_x0_5(self): self._test_classification_model(models.shufflenet_v2_x0_5) @@ -295,7 +285,6 @@ def test_shufflenet_v2_x2_0(self): self._test_classification_model(models.shufflenet_v2_x2_0) - # TODO add test for zero_init_residual arg to factory method # TODO add test for norm_layer arg to factory method class ResnetTester(ClassificationModelTester): @@ -307,25 +296,25 @@ def test_resnet18(self): def test_resnet34(self): self._test_classification_model(models.resnet34) - + def test_resnet50(self): self._test_classification_model(models.resnet50) - + def test_resnet101(self): self._test_classification_model(models.resnet101) - + def test_resnet152(self): self._test_classification_model(models.resnet152) - + def test_resnext50_32x4d(self): self._test_classification_model(models.resnext50_32x4d) - + def test_resnext101_32x8d(self): self._test_classification_model(models.resnext101_32x8d) - + def test_wide_resnet50_2(self): self._test_classification_model(models.wide_resnet50_2) - + def test_wide_resnet101_2(self): self._test_classification_model(models.wide_resnet101_2) @@ -348,7 +337,6 @@ def test_resnet_dilation(self): out = model(x) f = 2 ** sum(i) self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f)) - class SegmentationModelTester(ModelTester): @@ -364,7 +352,6 @@ def _test_segmentation_model(self, name): self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) - class DetectionModelTester(ModelTester): def _test_detection_model(self, name): set_rng_seed(0) @@ -413,7 +400,6 @@ def test_fasterrcnn_double(self): self.assertTrue("labels" in out[0]) - class VideoModelTester(ModelTester): def _test_video_model(self, name): # the default input shape is @@ -427,7 +413,6 @@ def _test_video_model(self, name): self.assertEqual(out.shape[-1], 50) - for model_name in get_available_segmentation_models(): # for-loop bodies don't define scopes, so we have to save the variables # we want to close over in some way From 5c4c14543823e382e59de7d9877f9a34e009fd2e Mon Sep 17 00:00:00 2001 From: Brad Heintz Date: Tue, 15 Oct 2019 16:58:16 -0700 Subject: [PATCH 3/5] took out bogus link in comment --- test/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_models.py b/test/test_models.py index 572ed2a1ae3..3093b2b4dd0 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -13,7 +13,7 @@ EPSILON = 1e-5 # small value for approximate comparisons/assertions STANDARD_NUM_CLASSES = 50 STANDARD_INPUT_SHAPE = (1, 3, 224, 224) -STANDARD_SEED = 1729 # https://fburl.com/3i5wkg9p +STANDARD_SEED = 1729 def set_rng_seed(seed=STANDARD_SEED): From fa0bd8976f09b27efe718db84d7d07b120def088 Mon Sep 17 00:00:00 2001 From: Brad Heintz Date: Tue, 15 Oct 2019 17:34:37 -0700 Subject: [PATCH 4/5] updated correctness checks to use tolerances more idiomatically --- .../DensenetTester.test_densenet121_expect.pkl | Bin 543 -> 543 bytes .../DensenetTester.test_densenet161_expect.pkl | Bin 543 -> 543 bytes .../DensenetTester.test_densenet169_expect.pkl | Bin 543 -> 543 bytes .../DensenetTester.test_densenet201_expect.pkl | Bin 543 -> 543 bytes test/test_models.py | 6 +++--- 5 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/expect/DensenetTester.test_densenet121_expect.pkl b/test/expect/DensenetTester.test_densenet121_expect.pkl index bb88c7faffad86407a474fb0f2ff47486270ef07..32127953ee8dcd30aebbc8ccc65f9e6172dd485b 100644 GIT binary patch delta 37 jcmbQwGM{C_9d2_&V-pJt6H`kg%Zbk;khzn+7_R~V(+mpa delta 37 jcmbQwGM{C_9d2U-GXo1tBO?=2^NG(Rkhzn+7_R~V(H;uO diff --git a/test/expect/DensenetTester.test_densenet161_expect.pkl b/test/expect/DensenetTester.test_densenet161_expect.pkl index e8f1b2e2af5fdc406a15f97a10420fa6c2f8cf95..7746061cd2cef1b89008f874ac5b7cc8bbc21602 100644 GIT binary patch delta 39 kcmbQwGM{C_9UgN-Q$tI0OH&I|OQVS|!cq8>y&10n0Owi@RsaA1 delta 39 kcmbQwGM{C_9Ufx?b3y&10n0OL3d9smFU diff --git a/test/expect/DensenetTester.test_densenet169_expect.pkl b/test/expect/DensenetTester.test_densenet169_expect.pkl index 4c3541310ddef5a3729ac68bdb02fb1caaedd02e..fe377f88b056bd0fd3d2ec303f275d467ff02104 100644 GIT binary patch delta 39 kcmbQwGM{C_9UgN-6Jui|6GLM|GmD8Y!cq8>y&10n0OQCDBme*a delta 39 kcmbQwGM{C_9Ufx?V{=1e3v&}wL!*f=!cq8>y&10n0OR-zCIA2c diff --git a/test/expect/DensenetTester.test_densenet201_expect.pkl b/test/expect/DensenetTester.test_densenet201_expect.pkl index e4b9d65585ffeec2475062d111887eb7bb7635d1..2185d458666504b51c28f90a230331b099e21f27 100644 GIT binary patch delta 37 jcmbQwGM{C_9d2_&BSRA-QxhWtvx(0mkhzn+7_R~V&%O$< delta 37 jcmbQwGM{C_9d2U-BU4L50|NsyqlwQWkhzn+7_R~V&h!eT diff --git a/test/test_models.py b/test/test_models.py index 3093b2b4dd0..c419994b75c 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -10,7 +10,6 @@ import inspect -EPSILON = 1e-5 # small value for approximate comparisons/assertions STANDARD_NUM_CLASSES = 50 STANDARD_INPUT_SHAPE = (1, 3, 224, 224) STANDARD_SEED = 1729 @@ -148,7 +147,7 @@ def _check_classification_output_shape(self, test_output, num_classes): def _check_model_correctness(self, model, test_input, num_classes=STANDARD_NUM_CLASSES): test_output = self._infer_for_test_with(model, test_input) self._check_classification_output_shape(test_output, num_classes) - self.assertExpected(test_output) + self.assertExpected(test_output, rtol=1e-5, atol=1e-5) return test_output # NOTE override this in a child class @@ -256,7 +255,8 @@ def _test_densenet_plus_mem_eff(self, model_callable): me_output = self._infer_for_test_with(me_model, test_input) test_output.squeeze(0) me_output.squeeze(0) - self.assertTrue((test_output - me_output).abs().max() < EPSILON) + # NOTE testing against same memory fixtures as the non-mem-efficient version + self.assertExpected(test_output, rtol=1e-5, atol=1e-5) def test_densenet121(self): self._test_densenet_plus_mem_eff(models.densenet121) From 8a20ae0a88e5540d5b266b4995bd58e9dc0245a1 Mon Sep 17 00:00:00 2001 From: Brad Heintz Date: Thu, 17 Oct 2019 10:30:53 -0700 Subject: [PATCH 5/5] flake8 fix --- test/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_models.py b/test/test_models.py index c419994b75c..a563c35a0d4 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -115,7 +115,7 @@ def get_available_classification_models(self): # Recursively gather test methods from all classification testers def get_test_methods_for_class(self, klass): all_methods = inspect.getmembers(klass, predicate=inspect.isfunction) - test_methods = set([method[0] for method in all_methods if method[0].startswith('test_')]) + test_methods = {method[0] for method in all_methods if method[0].startswith('test_')} for child in klass.__subclasses__(): test_methods = test_methods.union(self.get_test_methods_for_class(child)) return test_methods