Skip to content

Commit

Permalink
diagnose option: get_entry to print a whole row (#11308)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #11308

Pull Request resolved: #11299

Reviewed By: xianjiec

Differential Revision: D9652844

fbshipit-source-id: 650d550317bfbed0c1f25ae7d74286cfc7c3ac70
  • Loading branch information
Wakeupbuddy authored and facebook-github-bot committed Sep 7, 2018
1 parent 2946b02 commit c59c1a2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
20 changes: 13 additions & 7 deletions caffe2/python/modeling/get_entry_from_blobs.py
Expand Up @@ -33,21 +33,23 @@ class GetEntryFromBlobs(NetModifier):
blobs: list of blobs to get entry from
logging_frequency: frequency for printing entry values to logs
i1, i2: the first, second dimension of the blob. (currently, we assume
the blobs to be 2-dimensional blobs)
the blobs to be 2-dimensional blobs). When i2 = -1, print all entries
in blob[i1]
"""

def __init__(self, blobs, logging_frequency, i1=0, i2=0):
self._blobs = blobs
self._logging_frequency = logging_frequency
self._i1 = i1
self._i2 = i2
self._field_name_suffix = '_{0}_{1}'.format(i1, i2)
self._field_name_suffix = '_{0}_{1}'.format(i1, i2) if i2 >= 0 \
else '_{0}_all'.format(i1)

def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
modify_output_record=False):

i1, i2 = [self._i1, self._i2]
if i1 < 0 or i2 < 0:
if i1 < 0:
raise ValueError('index is out of range')

for blob_name in self._blobs:
Expand All @@ -57,16 +59,20 @@ def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
blob, net.Name()))

blob_i1 = net.Slice([blob], starts=[i1, 0], ends=[i1 + 1, -1])
blob_i1_i2 = net.Slice([blob_i1],
net.NextScopedBlob(prefix=blob + '_{0}_{1}'.format(i1, i2)),
starts=[0, i2], ends=[-1, i2 + 1])
if self._i2 == -1:
blob_i1_i2 = net.Copy([blob_i1],
[net.NextScopedBlob(prefix=blob + '_{0}_all'.format(i1))])
else:
blob_i1_i2 = net.Slice([blob_i1],
net.NextScopedBlob(prefix=blob + '_{0}_{1}'.format(i1, i2)),
starts=[0, i2], ends=[-1, i2 + 1])

if self._logging_frequency >= 1:
net.Print(blob_i1_i2, [], every_n=self._logging_frequency)

if modify_output_record:
output_field_name = str(blob) + self._field_name_suffix
output_scalar = schema.Scalar((np.float, (1,)), blob_i1_i2)
output_scalar = schema.Scalar((np.float), blob_i1_i2)

if net.output_record() is None:
net.set_output_record(
Expand Down
16 changes: 12 additions & 4 deletions caffe2/python/modeling/get_entry_from_blobs_test.py
Expand Up @@ -60,7 +60,7 @@ def test_get_entry_from_blobs_modify_output_record(self):

# no operator name set, will use default
brew.fc(model, fc1, "fc2", dim_in=4, dim_out=4)
i1, i2 = np.random.randint(4, size=2)
i1, i2 = np.random.randint(4), np.random.randint(5) - 1
net_modifier = GetEntryFromBlobs(
blobs=['fc1_w', 'fc2_w'],
logging_frequency=10,
Expand All @@ -74,10 +74,18 @@ def test_get_entry_from_blobs_modify_output_record(self):
workspace.RunNetOnce(model.net)

fc1_w = workspace.FetchBlob('fc1_w')
fc1_w_entry = workspace.FetchBlob('fc1_w_{0}_{1}'.format(i1, i2))
if i2 < 0:
fc1_w_entry = workspace.FetchBlob('fc1_w_{0}_all'.format(i1))
else:
fc1_w_entry = workspace.FetchBlob('fc1_w_{0}_{1}'.format(i1, i2))

self.assertEqual(fc1_w_entry.size, 1)
self.assertEqual(fc1_w_entry[0], fc1_w[i1][i2])
if i2 < 0:
self.assertEqual(fc1_w_entry.size, 4)
for j in range(4):
self.assertEqual(fc1_w_entry[0][j], fc1_w[i1][j])
else:
self.assertEqual(fc1_w_entry.size, 1)
self.assertEqual(fc1_w_entry[0], fc1_w[i1][i2])

assert 'fc1_w' + net_modifier.field_name_suffix() in\
model.net.output_record().field_blobs(),\
Expand Down

0 comments on commit c59c1a2

Please sign in to comment.