-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathsliceLayer.m
41 lines (34 loc) · 1.21 KB
/
sliceLayer.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
classdef sliceLayer < nnet.layer.Layer
%#codegen
% Custom layer used for channel grouping.
%
% sliceLayer splits dimension 3 of its input (treated as the channel
% dimension) into GROUPS equal-sized contiguous slices and outputs slice
% number GROUP_ID (1-based). For example, with 8 input channels,
% groups = 4 and group_id = 2, the layer outputs channels 3:4.
%
% Copyright 2021 The MathWorks, Inc.

    properties
        % Value supplied by the caller as CON; stored but not read by this
        % class. NOTE(review): presumably used by import/graph-building code
        % to link this layer to another one - confirm against callers.
        connectID
        % Number of equal-sized channel groups the input is split into.
        groups
        % 1-based index of the group this layer outputs
        % (validated to satisfy 1 <= group_id <= groups).
        group_id
    end

    methods
        function layer = sliceLayer(name,con,groups,group_id)
            % sliceLayer  Create a channel-slicing layer.
            %   layer = sliceLayer(name, con, groups, group_id)
            %     name     - layer name
            %     con      - value stored in the connectID property
            %     groups   - positive integer number of channel groups
            %     group_id - integer index of the group to output,
            %                1 <= group_id <= groups

            % Validate arguments before populating the object so that an
            % invalid call fails immediately, not later inside predict.
            assert(group_id>0,'group_id must great zero! it must start index from 1');
            assert(groups>0 && groups==floor(groups), ...
                'groups must be a positive integer.');
            assert(group_id==floor(group_id) && group_id<=groups, ...
                'group_id must be an integer no larger than groups.');

            % Set layer name.
            layer.Name = name;
            % Set layer description.
            layer.Description = [num2str(groups), ' groups,group_id: ', ...
                num2str(group_id), ' sliceLayer '];
            % Set layer type.
            layer.Type = 'sliceLayer';
            % Set other properties.
            layer.connectID = con;
            layer.groups = groups;
            layer.group_id = group_id;
        end

        function Z = predict(layer, X)
            % predict  Return the slice of X that belongs to group group_id.
            %   Selection is along dimension 3; dimensions 1, 2 and 4 are
            %   passed through unchanged. size(X,3) must be divisible by
            %   layer.groups.
            channels = size(X,3);
            % A non-divisible channel count gives a fractional slice width.
            % Group 1 would then silently return a truncated slice, and the
            % later groups would fail with non-integer indices.
            assert(mod(channels,layer.groups)==0, ...
                'Number of input channels must be divisible by groups.');
            deltaChannels = channels/layer.groups;
            selectStart = (layer.group_id-1)*deltaChannels+1;
            selectEnd = layer.group_id*deltaChannels;
            Z = X(:,:,selectStart:selectEnd,:);
        end
    end
end