import torch
import torch.nn as nn

class Conv2dObserver(Meltable):  # Meltable is the base class defined elsewhere in this repo
    def __init__(self, conv):
        super(Conv2dObserver, self).__init__()
        assert isinstance(conv, nn.Conv2d)
        self.conv = conv
        self.in_mask = torch.zeros(conv.in_channels).to('cpu')
        self.out_mask = torch.zeros(conv.out_channels).to('cpu')
        self.f_hook = conv.register_forward_hook(self._forward_hook)

    def extra_repr(self):
        return '(%d, %d) -> (%d, %d)' % (
            self.conv.in_channels, self.conv.out_channels,
            int((self.in_mask != 0).sum()), int((self.out_mask != 0).sum()))

    def _forward_hook(self, m, _in, _out):
        # Accumulate per-input-channel activation mass.
        x = _in[0]
        self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

    def _backward_hook(self, grad):
        # Accumulate per-output-channel gradient mass, then pass a constant
        # gradient upstream so the observation does not disturb training.
        self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
        new_grad = torch.ones_like(grad)
        return new_grad

    def forward(self, x):
        output = self.conv(x)
        noise = torch.zeros_like(output).normal_()
        output = output + noise
        if self.training:
            output.register_hook(self._backward_hook)
        return output

    def melt(self):
        # Build a smaller Conv2d that keeps only the channels whose
        # accumulated masks are non-zero.
        if self.conv.groups == 1:
            groups = 1
        elif self.conv.groups == self.conv.out_channels:
            groups = int((self.out_mask != 0).sum())
        else:
            assert False

        print("in_channels:", int((self.in_mask != 0).sum()))
        print("out_channels:", int((self.out_mask != 0).sum()))
        print("kernel_size:", self.conv.kernel_size)
        print("stride:", self.conv.stride)
        print("padding:", self.conv.padding)
        print("dilation:", self.conv.dilation)
        print("groups:", groups)
        print("bias:", (self.conv.bias is not None))

        replacer = nn.Conv2d(
            in_channels=int((self.in_mask != 0).sum()),
            out_channels=int((self.out_mask != 0).sum()),
            kernel_size=self.conv.kernel_size,
            stride=self.conv.stride,
            padding=self.conv.padding,
            dilation=self.conv.dilation,
            groups=groups,
            bias=(self.conv.bias is not None)
        ).to(self.conv.weight.device)

        with torch.no_grad():
            if self.conv.groups == 1:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0])
            else:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0])

            if self.conv.bias is not None:
                replacer.bias.set_(self.conv.bias[self.out_mask != 0])

        return replacer

    @classmethod
    def transform(cls, net):
        # Recursively swap every nn.Conv2d in the module tree for an observer.
        r = []
        def _inject(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    _inject(modules[k]._modules)
                if isinstance(modules[k], nn.Conv2d):
                    modules[k] = Conv2dObserver(modules[k])
                    r.append(modules[k])
        _inject(net._modules)
        return r
If the last layer is a convolutional layer, it should also be possible for Conv2dObserver to handle it.
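To make that concrete, here is a minimal usage sketch for melting a network whose final layer is a conv. It assumes the Conv2dObserver class above (and its Meltable base from this repo) is in scope; the toy network, input shape, and loss below are made up for illustration.

# Hypothetical toy network whose final layer is a conv (no classifier head).
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 8, 3, padding=1),  # last layer is a conv
)

observers = Conv2dObserver.transform(net)  # wraps both convs, the final one included
net.train()

# One observation pass: forward hooks fill in_mask, backward hooks fill out_mask.
x = torch.randn(4, 3, 32, 32)
net(x).sum().backward()

# The observer on the final conv melts like any other.
pruned_last = observers[-1].melt()
print(pruned_last)

Note that out_mask is only populated during training-mode backward passes (the backward hook is registered in forward only when self.training is True), so melting before any backward pass would leave every output channel masked out.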