
If the last layer of my network is a convolutional layer, how should I modify the code in Conv2dObserver? Thanks #10

Open
cvJie opened this issue Oct 9, 2019 · 1 comment


cvJie commented Oct 9, 2019

import torch
import torch.nn as nn

# Meltable is this repository's base class for replaceable modules
# (defined elsewhere in the repo).

class Conv2dObserver(Meltable):
    def __init__(self, conv):
        super(Conv2dObserver, self).__init__()
        assert isinstance(conv, nn.Conv2d)
        self.conv = conv
        # Per-channel activity accumulators; channels with nonzero
        # entries survive melt().
        self.in_mask = torch.zeros(conv.in_channels).to('cpu')
        self.out_mask = torch.zeros(conv.out_channels).to('cpu')
        self.f_hook = conv.register_forward_hook(self._forward_hook)

    def extra_repr(self):
        return '(%d, %d) -> (%d, %d)' % (
            self.conv.in_channels, self.conv.out_channels,
            int((self.in_mask != 0).sum()), int((self.out_mask != 0).sum()))

    def _forward_hook(self, m, _in, _out):
        # Accumulate the absolute activation mass of each input channel.
        x = _in[0]
        self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

    def _backward_hook(self, grad):
        # Accumulate the absolute gradient mass of each output channel,
        # then pass a constant gradient upstream.
        self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
        new_grad = torch.ones_like(grad)
        return new_grad

    def forward(self, x):
        output = self.conv(x)
        noise = torch.zeros_like(output).normal_()
        output = output + noise
        if self.training:
            output.register_hook(self._backward_hook)
        return output

    def melt(self):
        if self.conv.groups == 1:
            groups = 1
        elif self.conv.groups == self.conv.out_channels:
            # Depthwise convolution: one group per surviving channel.
            groups = int((self.out_mask != 0).sum())
        else:
            assert False

        print("in_channels:", int((self.in_mask != 0).sum()))
        print("out_channels:", int((self.out_mask != 0).sum()))
        print("kernel_size:", self.conv.kernel_size)
        print("stride:", self.conv.stride)
        print("padding:", self.conv.padding)
        print("dilation:", self.conv.dilation)
        print("groups:", groups)
        print("bias:", (self.conv.bias is not None))

        # Build a smaller Conv2d that holds only the surviving channels.
        replacer = nn.Conv2d(
            in_channels=int((self.in_mask != 0).sum()),
            out_channels=int((self.out_mask != 0).sum()),
            kernel_size=self.conv.kernel_size,
            stride=self.conv.stride,
            padding=self.conv.padding,
            dilation=self.conv.dilation,
            groups=groups,
            bias=(self.conv.bias is not None)
        ).to(self.conv.weight.device)

        with torch.no_grad():
            if self.conv.groups == 1:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0])
            else:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0])
            if self.conv.bias is not None:
                replacer.bias.set_(self.conv.bias[self.out_mask != 0])
        return replacer

    @classmethod
    def transform(cls, net):
        # Recursively replace every nn.Conv2d in `net` with an observer.
        r = []
        def _inject(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    _inject(modules[k]._modules)
                if isinstance(modules[k], nn.Conv2d):
                    modules[k] = Conv2dObserver(modules[k])
                    r.append(modules[k])
        _inject(net._modules)
        return r
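
A quick note on the reductions in the two hooks above: for a tensor of shape (N, C, H, W), summing over dims 2 and 3 collapses the spatial axes and summing over dim 0 collapses the batch, leaving one scalar of accumulated |activation| (or |gradient|) per channel. A minimal shape check (the tensor sizes here are just an illustrative assumption):

import torch

x = torch.randn(4, 16, 32, 32)  # (N, C, H, W)
v = x.abs().sum(2, keepdim=True).sum(3, keepdim=True).sum(0, keepdim=True).view(-1)
print(v.shape)  # torch.Size([16]) -- one accumulator entry per channel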
youzhonghui (Owner) commented

If the last layer is a convolutional layer, it should still be possible for Conv2dObserver to handle it.
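
For illustration, a minimal sketch of this (the toy network, input sizes, and loss below are assumptions; Conv2dObserver is the class pasted above, and Meltable is the base class defined elsewhere in this repository):

import torch
import torch.nn as nn

# Hypothetical toy network whose final layer is a Conv2d.
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 8, 3, padding=1),  # last layer is convolutional
)

observers = Conv2dObserver.transform(net)  # wraps every nn.Conv2d
net.train()

out = net(torch.randn(4, 3, 32, 32))
# Any scalar backward pass fires _backward_hook on each observer's output,
# including the final conv's, so its out_mask is populated as well.
out.abs().sum().backward()

pruned = [obs.melt() for obs in observers]

In this sketch no accumulated entry is exactly zero, so every channel survives and each melted layer keeps its original shape; in the full pipeline the gate modules are expected to zero out pruned channels first, which is what makes the masks sparse and the melted convolutions smaller.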
