Permalink
Switch branches/tags
JoelMarcey-patch-1 Jorghi12-PytorchROCmRemove_Deps Jorghi12-patch-2 Jorghi12-patch-6 Jorghi12-transpiler PyRocm_removing_deps SsnL-patch-1 SsnL-patch-2 SsnL-patch-3 SsnL-patch-4 SsnL-patch-5 a5d7abe add-issue-templates always_scriptmodule anderspapitto-patch-1 batch_mm_t build_fixes build_incremental_fix cache_vars caffe2-docker-image-rebuild circleci_all_commits circleci_credentials circleci_disable_all_jobs circleci_fetch_depth circleci_fix_git circleci_flaky_tests circleci_generic_kernel circleci_test circleci_timestamp circleci cleanup_linkage cpp cudastaticfixes docker/rocm-1.8.2 docker/rocm-update export-D9526248 export-D9526650 export-D9526737 export-D9539945 export-D9540025 export-D9545248 export-D9545704 export-D9557315 export-D9561478 export-D9561802 export-D9562197 export-D9562312 export-D9562467 export-D9563464 export-D9563753 export-D9564206 export-D9564516 export-D9578397 export-D9578398 export-D9578399 export-D9578734 export-D9579371 export-D9581560 export-D9583630 export-D9583699 export-D9585627 export-D9613800 export-D9613897 export-D9614321 export-D9623916 export-D9631619 export-D9634750 export-D9634904 export-D9635105 export-D9635292 export-D9644700 export-D9644899 export-D9646190 export-D9648570 export-D9648830 export-D9652088 export-D9652089 export-D9654871 export-D9656548 export-D9657449 export-D9663476 export-D9666612 export-D9670493 export-D9676205 export-D9694097 export-D9694326 export-D9694327 export-D9694918 export-D9697878 export-D9724805 export-D9727134 export-D9727532 export-D9728631 export-D9731326 export-D9756666 export-D9757935 export-D9763422 export-D9763423 export-D9763424 export-D9771708 export-D9775191 export-D9778042 export-D9778043 export-D9779821 export-D9790187 export-D9806425 export-D9810823 export-D9811028 export-D9813544 export-D9813742 export-D9814536 export-D9823457 export-D9826209 export-D9830460 export-D9831230 export-D9831384 export-D9833361 export-D9841660 export-D9847835 export-D9847859 export-D9882726 
export-D9884177 export-D9884563 export-D9889990 export-D9924348 export-D9967509 export-D9968041 export-D9968320 export-D9977058 export-D9977505 export-D9977654 export-D9979976 export-D9980641 export-D9995559 export-D9995633 export-D9996898 export-D10001033 export-D10022853 export-D10024439 export-D10024467 export-D10024485 export-D10024554 export-D10026392 export-D10030556 export-D10030819 export-D10031072 export-D10032707 export-D10033396 export-D10034589 export-D10037265 export-D10050781 export-D10050859 export-D10050905 export-D10051005 export-D10051012 export-D10051078 export-D10051079 export-D10051126 export-D10051202 export-D10051298 export-D10051365 export-D10051424 export-D10052523 export-D10069839 export-D10073519 export-D10111759 export-D10134083 export-D10134319 export-D10139933 export-D10139934 export-D10139935 export-D10150834 export-D10184116 export-D10184117 export-D10200448 export-D10204135 export-D10207890 export-D10209620 export-D10216315 export-D10222739 export-D10227820 export-D10229684 export-D10232118 export-D10232147 export-D10232154 export-D10249293 export-D10251907 export-D10255651 export-D10359443 export-D10371541 export-D10379903 export-D10380678 export-D10392295 export-D10400927 export-D10404407 export-D10415069 export-D10415430 export-D10416051 export-D10419671 export-D10421896 export-D10450290 export-D10454455 export-D10457671 export-D10467239 export-D10467556 export-D10469310 export-D10469960 export-D10476220 export-D10476225 export-D10476226 export-D10476232 export-D10476235 export-D10488399 export-D10492071 export-D10492507 export-D10496244 export-D10513246 export-D10518499 export-D10518929 export-D10520295 export-D10520421 export-D10528061 export-D10853224 export-D10855883 export-D10858024 export-D11669870 export-D12143282 export-D12832080 export-D12848855 export-D12849620 export-D12850690 export-D12850691 export-D12850833 export-D12873145 export-D12874357 export-D12894385 export-D12894386 export-D12912235 export-D12912237 
export-D12912238 export-D12912239 export-D12912240 export-D12912241 export-D12912242 export-D12934074 export-D12936031 export-D12937090 export-D12937091 export-D12964886 export-D12985774 export-D13009482 export-D13011878 export-D13015236 export-D13015239 export-D13024368 export-D13025313 export-D13036478 export-D13046201 export-D13046500 export-D13046722 export-D13047468 export-D13053648 export-D13056152 export-D13062526 export-D13062564 export-D13062604 export-D13062631 export-D13062649 export-D13062706 export-D13066808 export-D13081602 export-D13081603 export-D13081604 export-D13081605 export-D13081606 export-D13081607 export-D13081608 export-D13081609 export-D13081610 export-D13104693 export-D13104694 export-D13105166 export-D13111509 export-D13111712 export-D13111781 export-D13112081 export-D13112298 export-D13113129 export-D13119624 export-D13121531 export-D13128077 export-D13128977 export-D13131338 export-D13141949 export-D13145293 export-D13156470 export-D13156471 export-D13156472 export-D13158474 export-D13158475 export-D13205022 export-D13218540 export-D13221302 export-D13223125 export-D13223126 export-D13223904 export-D13224015 export-D13235001 export-D13241355 export-D13241401 export-D13257847 export-D13258252 export-D13258512 export-D13258513 export-D13266063 export-D13267832 export-D13271560 export-D13272227 export-D13277246 export-D13277567 export-D13283492 export-D13283493 export-D13283494 export-D13283495 export-D13283496 export-D13283497 export-D13285370 export-D13287688 export-D13288655 export-D13304398 export-D13316078 export-D13318594 export-D13318596 export-D13318644 export-D13318645 export-D13336841 export-D13336842 export-D13336843 export-D13336856 export-D13348039 export-D13348040 export-D13348041 export-D13348042 export-D13348044 export-D13349163 export-D13349164 export-D13365817 ext_test_fix ezyang-patch-1 ezyang-patch-2 ezyang-patch-3 ezyang-patch-4 ezyang/retry-type-id-core ezyang/rocm-docker-update fast_dp fb-config fbsync 
gh/ezyang/1/base gh/ezyang/1/head gh/ezyang/1/orig gh/ezyang/2/base gh/ezyang/2/head gh/ezyang/2/orig gh/ezyang/3/base gh/ezyang/3/head gh/ezyang/3/orig gloo_dedup halfconv jerryzh168-patch-1-1 jerryzh168-patch-1 jit_frontend known-good magmatestfix master merge_variable_tensor mkl_set_dynamic nccl_fix new_symbolic_diff nn_c_port oanderso/test random_device readme_fix remove_time scalar_type simple_engine soumith-patch-1 ssnl-9348 suo/anno_test suo/annotations suo/dce2 suo/fix-expect suo/graph-equals suo/ir-parser suo/mm suo/parser suo/schematize suo/slicer tensor-merge tensorimpl_autogradmeta tensorimpl_3_variable_functions tensorimpl_4_AutogradMetaInterface test10 testing/full-caffe2 tmp_enable_scalars v0.4.1 v1.0.0 weak_tensor weak_tracing win_py2.7 windows_jit_error windows_jit_error_12378
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
94 lines (85 sloc) 3.46 KB
# This is a large test that goes through the translation of the bvlc caffenet
# model, runs an example through the whole model, and verifies numerically
# that all the results look right. By default, it is disabled unless you
# explicitly want to run it.
from google.protobuf import text_format
import numpy as np
import os
import sys
# Caffe is an optional dependency. Record whether it (and the translator that
# needs it) could be imported so that the test below can be skipped otherwise.
CAFFE_FOUND = False
try:
    from caffe.proto import caffe_pb2
    from caffe2.python import caffe_translator
    CAFFE_FOUND = True
except Exception as e:
    if "'caffe'" in str(e):
        # The expected failure: the caffe module itself is not installed.
        print(
            "PyTorch/Caffe2 now requires a separate installation of caffe. "
            "Right now, this is not found, so we will skip the caffe "
            "translator test.")
    else:
        # Any other failure still leaves CAFFE_FOUND False (so the test is
        # skipped), but report it instead of swallowing it silently.
        sys.stderr.write(
            "Skipping the caffe translator test because importing its "
            "dependencies failed unexpectedly: {!r}\n".format(e))
from caffe2.python import utils, workspace, test_util
import unittest
def setUpModule():
    """Translate the reference CaffeNet model and run it in the workspace.

    Loads the network definition and the pretrained weights from
    ``data/testdata/caffe_translator``, translates them with
    ``caffe_translator.TranslateModel`` once with ``remove_legacy_pad=True``
    and once with ``remove_legacy_pad=False``, and for each translation
    writes the translated net to disk, feeds the parameters and the dumped
    input data into the workspace, and runs the net once.

    Returns immediately without doing anything when caffe could not be
    imported or the test data directory does not exist.
    """
    data_dir = 'data/testdata/caffe_translator'
    # Do nothing if caffe and test data is not found
    if not (CAFFE_FOUND and os.path.exists(data_dir)):
        return

    # The network definition is a text-format protobuf.
    caffenet = caffe_pb2.NetParameter()
    with open(os.path.join(data_dir, 'deploy.prototxt')) as fid:
        text_format.Merge(fid.read(), caffenet)

    # The pretrained weights are a serialized (binary) protobuf, so the file
    # must be read as bytes: ParseFromString expects the raw wire format.
    caffenet_pretrained = caffe_pb2.NetParameter()
    model_path = os.path.join(data_dir, 'bvlc_reference_caffenet.caffemodel')
    with open(model_path, 'rb') as fid:
        caffenet_pretrained.ParseFromString(fid.read())

    # Input data from the Caffe test code; identical for every translation,
    # so it is loaded only once.
    data = np.load(os.path.join(data_dir, 'data_dump.npy')).astype(np.float32)

    translated_path = os.path.join(
        data_dir, 'bvlc_reference_caffenet.translatedmodel')
    for remove_legacy_pad in [True, False]:
        net, pretrained_params = caffe_translator.TranslateModel(
            caffenet, caffenet_pretrained, is_test=True,
            remove_legacy_pad=remove_legacy_pad
        )
        with open(translated_path, 'w') as fid:
            fid.write(str(net))
        for param in pretrained_params.protos:
            workspace.FeedBlob(
                param.name, utils.Caffe2TensorToNumpyArray(param))
        workspace.FeedBlob('data', data)
        # Actually running the test.
        workspace.RunNetOnce(net.SerializeToString())
@unittest.skipIf(not CAFFE_FOUND,
                 'No Caffe installation found.')
@unittest.skipIf(not os.path.exists('data/testdata/caffe_translator'),
                 'No testdata existing for the caffe translator test. Exiting.')
class TestNumericalEquivalence(test_util.TestCase):
    """Compare blobs fetched from the workspace with the dumped references."""

    def testBlobs(self):
        """Check every listed blob against its ``<name>_dump.npy`` file.

        For each blob the shapes must be equal, and the values must agree
        to 5 decimals after both arrays are divided by the maximum of the
        fetched blob.
        """
        blob_names = (
            "conv1", "pool1", "norm1", "conv2", "pool2", "norm2", "conv3",
            "conv4", "conv5", "pool5", "fc6", "fc7", "fc8", "prob",
        )
        for blob_name in blob_names:
            print('Verifying {}'.format(blob_name))
            dump_path = (
                'data/testdata/caffe_translator/' + blob_name + '_dump.npy')
            actual = workspace.FetchBlob(blob_name)
            expected = np.load(dump_path)
            self.assertEqual(actual.shape, expected.shape)
            # Normalize both sides by the same factor so the decimal
            # tolerance is relative to the magnitude of the fetched blob.
            peak = np.max(actual)
            np.testing.assert_almost_equal(
                actual / peak, expected / peak, decimal=5)
if __name__ == '__main__':
    # This test is opt-in: when the script is started without any
    # command-line argument, explain how to enable it and exit successfully.
    invoked_without_arguments = (len(sys.argv) == 1)
    if not invoked_without_arguments:
        # NOTE(review): unittest.main() parses sys.argv itself, so the extra
        # argument(s) are interpreted as unittest options / test names.
        unittest.main()
    else:
        print(
            'If you do not explicitly ask to run this test, I will not run it. '
            'Pass in any argument to have the test run for you.'
        )
        sys.exit(0)