-
Notifications
You must be signed in to change notification settings - Fork 1
/
local.m
73 lines (63 loc) · 2.27 KB
/
local.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
classdef Blob < handle
  % Blob  MATLAB handle-class wrapper around a blob handle from the caffe_ MEX
  % interface. Every accessor forwards to a caffe_('blob_*', ...) call on the
  % stored handle; this class adds argument validation and the cast of
  % incoming data/diff arrays to single.
  %
  % NOTE(review): the axis order of shapes reported/accepted here is whatever
  % caffe_ uses (presumably MATLAB column-major, i.e. reversed relative to
  % Caffe's C++ order) -- confirm against the MEX source.

  properties (Access = private)
    % Opaque blob handle forwarded to every caffe_ call.
    hBlob_self
  end

  methods
    function self = Blob(hBlob_blob)
      % Blob  Wrap an existing blob handle.
      %   Raises (via CHECK) 'invalid Blob handle' when is_valid_handle
      %   rejects HBLOB_BLOB.
      CHECK(is_valid_handle(hBlob_blob), 'invalid Blob handle');
      self.hBlob_self = hBlob_blob;
    end

    function shape = shape(self)
      % shape  Return the blob's current shape as reported by caffe_.
      shape = caffe_('blob_get_shape', self.hBlob_self);
    end

    function reshape(self, shape)
      % reshape  Resize the blob to SHAPE.
      %   SHAPE must be empty or a real numeric row vector of finite,
      %   non-negative integers; it is converted to double before the call.
      shape = self.check_and_preprocess_shape(shape);
      caffe_('blob_reshape', self.hBlob_self, shape);
    end

    function data = get_data(self)
      % get_data  Return the blob's data array as provided by caffe_.
      data = caffe_('blob_get_data', self.hBlob_self);
    end

    function set_data(self, data)
      % set_data  Overwrite the blob's data.
      %   DATA must be numeric and its size must match the blob's current
      %   shape; it is cast to single before being handed to caffe_.
      data = self.check_and_preprocess_data(data);
      caffe_('blob_set_data', self.hBlob_self, data);
    end

    function diff = get_diff(self)
      % get_diff  Return the blob's diff array as provided by caffe_.
      diff = caffe_('blob_get_diff', self.hBlob_self);
    end

    function set_diff(self, diff)
      % set_diff  Overwrite the blob's diff.
      %   DIFF must be numeric and its size must match the blob's current
      %   shape; it is cast to single before being handed to caffe_.
      diff = self.check_and_preprocess_data(diff);
      caffe_('blob_set_diff', self.hBlob_self, diff);
    end
  end

  methods (Access = private)
    function shape = check_and_preprocess_shape(~, shape)
      % Validate a requested shape and return it as double.
      %   Accepts [] or a real numeric row vector of finite, non-negative
      %   integers; anything else raises via CHECK.
      CHECK(isempty(shape) || (isnumeric(shape) && isrow(shape)), ...
        'shape must be a integer row vector');
      % The check above only inspects type and orientation, so values such
      % as 2.5, -1, NaN, Inf or complex entries previously reached caffe_
      % despite the message promising an integer vector. A non-empty shape
      % is guaranteed numeric here, so the value tests below are safe.
      if ~isempty(shape)
        CHECK(isreal(shape) && all(isfinite(shape)) && all(shape >= 0) ...
          && all(shape == fix(shape)), ...
          'shape must contain only finite non-negative integers');
      end
      shape = double(shape);
    end

    function data = check_and_preprocess_data(self, data)
      % Validate DATA against the blob's shape and return it as single.
      %   Raises via CHECK if DATA is not numeric or its size does not match.
      CHECK(isnumeric(data), 'data or diff must be numeric types');
      self.check_data_size_matches(data);
      if ~isa(data, 'single')
        data = single(data);
      end
    end

    function check_data_size_matches(self, data)
      % Raise (via CHECK) unless size(DATA) equals the blob's shape.
      %   MATLAB arrays always have at least two dimensions and size() omits
      %   trailing singleton dimensions beyond the second, so both sides are
      %   normalized before comparing:
      %     * a blob with an empty shape is compared as [1 1];
      %     * a blob whose shape has one entry N is compared as [N 1];
      %     * size(DATA) is padded with trailing 1s up to the blob's rank.
      self_shape_extended = self.shape();
      if isempty(self_shape_extended)
        self_shape_extended = [1, 1];
      elseif isscalar(self_shape_extended)
        self_shape_extended = [self_shape_extended, 1];
      end
      % ones(1, k) with k <= 0 is empty, so DATA with more dimensions than
      % the blob keeps its full size and fails the comparison below.
      data_size_extended = ...
        [size(data), ones(1, length(self_shape_extended) - ndims(data))];
      % isequal checks length and values in one step and returns false
      % (instead of raising an && operand error) if orientations differ.
      is_matched = isequal(self_shape_extended, data_size_extended);
      CHECK(is_matched, ...
        sprintf('%s, input data/diff size: [ %s] vs target blob shape: [ %s]', ...
        'input data/diff size does not match target blob shape', ...
        sprintf('%d ', data_size_extended), sprintf('%d ', self_shape_extended)));
    end
  end
end