In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

In [79]:
class SaveOutput:
    """Callable with a forward-hook signature that records each output it receives.

    Every call appends ``module_out`` to ``self.outputs`` in call order.
    """

    def __init__(self):
        # Start from an empty record; clear() creates the list.
        self.clear()

    def clear(self):
        """Drop everything recorded so far by rebinding ``outputs`` to a new list."""
        self.outputs = list()

    def __call__(self, module, module_in, module_out):
        """Record ``module_out``; ``module`` and ``module_in`` are ignored.

        Returns ``None``.
        """
        self.outputs.append(module_out)


class Net(nn.Module):
    """Small CNN: two conv -> ReLU -> max-pool stages, then two ReLU-activated linear layers.

    ``fc1`` takes 320 flattened features, which corresponds to 20 channels of
    4 x 4 after both stages (e.g. for 1 x 28 x 28 inputs).
    """

    def __init__(self):
        # Create every sub-module here so forward() can refer to it later
        # by the attribute name declared below.
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    # The forward function defines the structure of the network.
    # It takes a single input here, but can be changed to accept more if needed.
    def forward(self, input):
        """Run both conv stages, flatten per sample, and apply the two linear layers."""
        features = input
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            features = pool(F.relu(conv(features)))

        # Any Python code may be used when defining the model structure;
        # autograd will handle all of it correctly.
        # if x.gt(0) > x.numel() / 2:
        #      ...
        #
        # You can even write a loop and reuse the same module inside it.
        # Modules no longer hold transient state, so using one several times
        # during the forward pass is fine.
        # while x.norm(2) < 10:
        #    x = self.conv1(x)

        # Flatten everything except the batch dimension.
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return F.relu(self.fc2(hidden))


# Build the network and a single recorder that every hook will share.
net = Net()
save_output = SaveOutput()

# Attach the recorder to each module yielded by named_modules() and keep the
# returned handles in hook_handles.
# NOTE(review): per the PyTorch docs named_modules() also yields the network
# itself (with an empty name), so the root module gets a hook too - confirm
# that this is intended.
hook_handles = []
for name, layer in net.named_modules():
    print('=================================', name, layer, sep='\n')
    handle = layer.register_forward_hook(save_output)
    hook_handles.append(handle)
print(net)


Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)
conv1
Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
pool1
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
conv2
Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
pool2
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
fc1
Linear(in_features=320, out_features=50, bias=True)
fc2
Linear(in_features=50, out_features=10, bias=True)
Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(10, 20, kernel_size=(5, 5

In [72]:
# Show every recorded output except the last one.
# NOTE(review): presumably the last entry is the root module's own output,
# because the root module was hooked as well - confirm.
print(save_output.outputs[:-1])

[tensor([-0.0232, -0.2018,  0.5397, -0.6820, -0.7194, -1.0043,  0.1066,  0.5058,
        -0.5767,  0.8421], grad_fn=<AddBackward0>), tensor([ 0.1940,  0.2531, -0.2500, -0.1275, -0.3000], grad_fn=<AddBackward0>), tensor([0.5303], grad_fn=<AddBackward0>)]


In [40]:
# For every parameter of the network, print its qualified name, its current
# gradient and whether it is trainable (one value per line).
for name, param in net.named_parameters():
    print(name, param.grad, param.requires_grad, sep='\n')


fc1.weight
None
True
fc1.bias
None
True
fc2.weight
None
True
fc2.bias
None
True
fc3.weight
None
True
fc3.bias
None
True


In [41]:
# Freeze the network in place. freeze_parameters is defined in another cell
# of this notebook; it turns off requires_grad on every parameter of `net`.
freeze_parameters(net)

In [42]:
# Re-inspect the parameters: print each one's qualified name, gradient and
# requires_grad flag (one value per line).
for name, param in net.named_parameters():
    print(name, param.grad, param.requires_grad, sep='\n')


fc1.weight
None
False
fc1.bias
None
False
fc2.weight
None
False
fc2.bias
None
False
fc3.weight
None
False
fc3.bias
None
False


In [80]:
def _pod_collapse(features, collapse_channels):
    """Square ``features`` and pool the result into one flat vector per sample.

    Every mode indexes dimensions 1-3, so ``features`` must be 4-D with the
    batch in dimension 0. Raises ``ValueError`` for an unknown mode.
    """
    squared = torch.pow(features, 2)
    batch = squared.shape[0]
    if collapse_channels == "channels":
        return squared.sum(dim=1).view(batch, -1)  # sum over dim 1
    if collapse_channels == "width":
        return squared.sum(dim=2).view(batch, -1)  # sum over dim 2
    if collapse_channels == "height":
        return squared.sum(dim=3).view(batch, -1)  # sum over dim 3
    if collapse_channels == "gap":
        # Global average pooling over the last two dims -> (batch, dim 1).
        return F.adaptive_avg_pool2d(squared, (1, 1))[..., 0, 0]
    if collapse_channels == "spatial":
        # Concatenate the dim-3 pooled and dim-2 pooled descriptors.
        pooled_3 = squared.sum(dim=3).view(batch, -1)
        pooled_2 = squared.sum(dim=2).view(batch, -1)
        return torch.cat([pooled_3, pooled_2], dim=-1)
    raise ValueError("Unknown method to collapse: {}".format(collapse_channels))


def pod(
    list_attentions_a,
    list_attentions_b,
    collapse_channels="spatial",
    normalize=True,):
    """Average, over paired feature maps, the distance between their pooled squares.

    For each pair ``(a, b)`` both tensors are squared, collapsed into one
    vector per sample according to ``collapse_channels``, optionally
    L2-normalized, and compared with a per-sample 2-norm that is averaged over
    the batch. The per-pair values are then averaged over all pairs.

    Args:
        list_attentions_a: non-empty sequence of 4-D tensors.
        list_attentions_b: sequence of tensors paired element-wise with
            ``list_attentions_a``; must have the same length.
        collapse_channels: one of ``"channels"``, ``"width"``, ``"height"``,
            ``"gap"`` or ``"spatial"``.
        normalize: if true, L2-normalize each collapsed vector along dim 1.

    Returns:
        A 0-dim tensor on the device of ``list_attentions_a[0]``.

    Raises:
        ValueError: if the lists are empty, differ in length, or
            ``collapse_channels`` is not a known mode.
    """
    # zip() would silently drop unmatched entries while the mean below still
    # divides by len(list_attentions_a), so reject mismatched inputs up front.
    if len(list_attentions_a) != len(list_attentions_b):
        raise ValueError(
            "pod expects lists of equal length, got {} and {}".format(
                len(list_attentions_a), len(list_attentions_b)))
    if len(list_attentions_a) == 0:
        raise ValueError("pod expects at least one pair of feature maps")

    loss = torch.tensor(0.).to(list_attentions_a[0].device)
    for a, b in zip(list_attentions_a, list_attentions_b):
        a = _pod_collapse(a, collapse_channels)
        b = _pod_collapse(b, collapse_channels)

        if normalize:
            a = F.normalize(a, dim=1, p=2)
            b = F.normalize(b, dim=1, p=2)

        # Per-sample 2-norm of the difference via the documented torch.norm
        # (the original relied on the undocumented torch.frobenius_norm).
        layer_loss = torch.mean(torch.norm(a - b, p=2, dim=-1))
        loss = loss + layer_loss
    return loss / len(list_attentions_a)

In [17]:
class SaveOutput:
    """Callable with a forward-hook signature that records each output it receives.

    Recorded values live in the private ``_outputs`` list and are exposed
    through the read-only ``outputs`` property. The backing attribute must have
    a different name from the property: a property called ``outputs`` with no
    setter makes ``self.outputs = []`` raise ``AttributeError``, and a getter
    that returns ``self.outputs`` would call itself forever.
    """

    def __init__(self):
        self._outputs = []

    def __call__(self, module, module_in, module_out):
        """Record ``module_out``; ``module`` and ``module_in`` are ignored.

        Returns ``None``.
        """
        self._outputs.append(module_out)

    def clear(self):
        """Forget everything recorded so far.

        Rebinds to a new list, so a list previously obtained from ``outputs``
        keeps its contents.
        """
        self._outputs = []

    @property
    def outputs(self):
        """List of recorded outputs, in call order."""
        return self._outputs
    


def freeze_parameters(model):
    """Freeze ``model`` in place: drop stored gradients and disable training.

    For every parameter of ``model``, ``grad`` is reset to ``None`` and
    ``requires_grad`` is set to ``False``.

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        None.
    """
    for parameter in model.parameters():
        # Act on the loop variable itself; the original referenced an
        # undefined name (`param`) here, which only worked by accident when a
        # notebook global of that name happened to exist.
        parameter.grad = None
        parameter.requires_grad = False

class global_and_online_model(nn.Module):
    """Pair a trainable online model with a frozen copy of a global model.

    Forward hooks record the output of every module of both models so that
    ``forward(..., online_target=True)`` can compare the intermediate outputs
    of the two models with ``pod``.

    Args:
        args: object providing ``collapse_channels`` and ``pod_normalize``,
            which are forwarded to ``pod``.
        online_model: module used as-is (not copied).
        global_model: module that is deep-copied and then frozen with
            ``freeze_parameters``; the caller's instance is left untouched.
    """

    def __init__(self,args,online_model,global_model):
        super(global_and_online_model,self).__init__()
        self.args=args
        self.online_model=online_model
        self.global_model=copy.deepcopy(global_model)
        freeze_parameters(self.global_model)

        # One recorder per model, shared by all hooks of that model.
        self.online_save_output = SaveOutput()
        self.global_save_output = SaveOutput()

        # Keep the handles on the instance so the hooks can be removed later
        # (see remove_hooks); the original discarded them as locals.
        self.online_hook_handles = [
            layer.register_forward_hook(self.online_save_output)
            for layer in self.online_model.modules()
        ]
        self.global_hook_handles = [
            layer.register_forward_hook(self.global_save_output)
            for layer in self.global_model.modules()
        ]

    def remove_hooks(self):
        """Detach every forward hook registered by ``__init__``."""
        for handle in self.online_hook_handles + self.global_hook_handles:
            handle.remove()
        self.online_hook_handles = []
        self.global_hook_handles = []

    def forward(self, x,online_target=False):
        """Run the online model, optionally also computing the ``pod`` loss.

        Args:
            x: input passed to the online model (and to the global model when
                ``online_target`` is true).
            online_target: if false, return only the online model's output.

        Returns:
            The online model's output, or the tuple
            ``(online_output, activation_loss)`` when ``online_target`` is true.
        """
        if not online_target:
            online_out = self.online_model(x)
            # The hooks fire on every forward pass; drop what they recorded so
            # the recorder does not grow without bound.
            self.online_save_output.clear()
            return online_out

        self.online_save_output.clear()
        self.global_save_output.clear()

        online_out = self.online_model(x)
        # [:-1] drops the last recorded output.
        # NOTE(review): presumably that is the root module's own output, since
        # modules() also yields the model itself - confirm.
        online_outputs = self.online_save_output.outputs[:-1]

        # NOTE(review): the original never ran the global model (it only
        # deep-copied the online output), which left its recorder empty.
        # Feeding the same input to the frozen copy is the assumed intent.
        with torch.no_grad():
            self.global_model(x)
        global_outputs = self.global_save_output.outputs[:-1]

        # The slices above are independent lists, so the recorders can be
        # emptied now instead of holding the activations until the next call.
        self.online_save_output.clear()
        self.global_save_output.clear()

        activation_loss = pod(
            list_attentions_a=online_outputs,
            list_attentions_b=global_outputs,
            collapse_channels=self.args.collapse_channels,
            normalize=self.args.pod_normalize,
        )

        return online_out, activation_loss
            
        
        

            