In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as F
import numpy as np

In [2]:
# batch_size, in_channels, seq_len
class ACNN(nn.Module):
    def __init__(self,in_channels, out_channels, kernel_size, seq_len, N):
        super(ACNN,self).__init__()
        self.W_e = nn.Conv1d(in_channels, out_channels,kernel_size)
        self.W_a = nn.Conv1d(in_channels, out_channels,kernel_size)
        self.params = nn.ParameterDict({
                'b_e': nn.Parameter(torch.randn(1)),
                'b_a': nn.Parameter(torch.randn(1)),
                'v':nn.Parameter(torch.rand(seq_len))
        })
        self.W_b_list = nn.ModuleList()
        self.N = N
        for i in range(N):
            self.W_b_list.append(nn.Conv1d(in_channels = in_channels, 
                                        out_channels = out_channels, 
                                        kernel_size = kernel_size,bias = True))

        # self.b_e = b_e
        # self.b_a = b_a
        # self.v = v #(seq_len,)
        self.linear_combination = nn.Linear(seq_len*2,N)
    def forward(self, x):
        e_t = self.W_e(x) + self.params['b_e']#好像Bias = True就可以了
        a_t = torch.matmul(torch.tanh(self.W_a(x) + self.params['b_a']),self.params['v'])
        a_t = a_t.unsqueeze(2).expand(-1,-1,e_t.shape[2])
        u = torch.sum(a_t * e_t, dim = 1)
        sigma = torch.sqrt(torch.sum(a_t*(e_t*e_t),dim = 1) - u*u)#batch,seq_len
        C_acnn = torch.cat([u,sigma], dim = 1)#batch,seq_len*2
        beta = self.linear_combination(C_acnn)#batch,N 有N个filter
        out = sum([conv(x)* beta.unsqueeze(1).unsqueeze(2).expand(-1,conv(x).shape[1],conv(x).shape[2]) for conv,beta in zip(self.W_b_list, beta.T)])/self.N
        return out 

In [3]:
acnn_layer = ACNN(512,512,1,135,3)

In [4]:
conv1_layer = nn.Conv1d(512,512,1)

In [5]:
input_feature = torch.rand(10,512,135)

In [6]:
conv_output = conv1_layer(input_feature)

In [7]:
conv_output.shape

torch.Size([10, 512, 135])

In [8]:
acnn_layer(input_feature).shape

torch.Size([10, 512, 135])

In [86]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchsnooper import snoop
class ACNN(nn.Module):
    def __init__(self,in_channels, out_channels, kernel_size, c, N):
        super(ACNN,self).__init__()
        self.W_e = nn.Conv1d(in_channels, in_channels,1, bias = True)
        self.W_a = nn.Conv1d(in_channels, 1, 1, bias = True)
        self.params = nn.ParameterDict({
                'v':nn.Parameter(torch.rand(c))
        })
        self.W_b_list = nn.ModuleList()
        self.N = N
        for i in range(N):
            self.W_b_list.append(nn.Conv1d(in_channels = in_channels, 
                                        out_channels = out_channels, 
                                        kernel_size = kernel_size,bias = True))
        self.linear_combination = nn.Linear(in_channels*2,N)
    @snoop()
    def forward(self, x):
        e_t = self.W_e(x)
        a_t = torch.tanh(self.W_a(x))
        a_t = a_t.expand(-1,e_t.shape[1],-1)
        u = torch.sum(a_t * e_t, dim = 2)
        sigma = torch.sqrt(torch.sum(a_t*(e_t*e_t),dim = 2) - u*u)
        C_acnn = torch.cat([u,sigma], dim = 1)
        beta = self.linear_combination(C_acnn)
        out = sum([conv(x)* beta.unsqueeze(1).unsqueeze(2).expand(-1,conv(x).shape[1],conv(x).shape[2]) for conv,beta in zip(self.W_b_list, beta.T)])/self.N
        return out 

In [87]:
x = torch.rand(42,39,135)


In [88]:
model = ACNN(39,512,2,135,3)

In [89]:
model(x)

Source path:... <ipython-input-86-0564964b7bd3>
Starting var:.. self = ACNN(  (W_e): Conv1d(39, 39, kernel_size=(1,), s...inear(in_features=78, out_features=3, bias=True))
Starting var:.. x = tensor<(42, 39, 135), float32, cpu>
15:44:39.921010 call        22     def forward(self, x):
15:44:39.926994 line        23         e_t = self.W_e(x)
New var:....... e_t = tensor<(42, 39, 135), float32, cpu, grad>
15:44:39.930983 line        24         a_t = torch.tanh(self.W_a(x))
New var:....... a_t = tensor<(42, 1, 135), float32, cpu, grad>
15:44:39.935970 line        25         a_t = a_t.expand(-1,e_t.shape[1],-1)
Modified var:.. a_t = tensor<(42, 39, 135), float32, cpu, grad>
15:44:39.939959 line        26         u = torch.sum(a_t * e_t, dim = 2)
New var:....... u = tensor<(42, 39), float32, cpu, grad>
15:44:39.945943 line        27         sigma = torch.sqrt(torch.sum(a_t*(e_t*e_t),dim = 2) - u*u)
New var:....... sigma = tensor<(42, 39), float32, cpu, grad, has_nan>
15:44:39.951928 line    

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        ...,

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]

In [72]:
v = torch.rand(42,3)

In [73]:
w_b_list = nn.ModuleList()
for i in range(3):
    w_b_list.append(nn.Conv1d(in_channels = 39, 
                                out_channels = 512, 
                                kernel_size = 1,bias = True))

In [74]:
g = torch.rand(42,39,135)

In [75]:
s = w_b_list[0](g).shape

In [76]:
out = sum([conv(g)* beta.unsqueeze(1).unsqueeze(2).expand(-1,conv(g).shape[1],conv(g).shape[2]) for conv,beta in zip(w_b_list, v.T)])/2

In [81]:
for conv,beta in zip(w_b_list, v.T):
    #转职一下就是把每一个conv的系数提到前面来，这样在for的时候就是一个一个取出来的
    print(beta.shape)
    print(beta.unsqueeze(1).unsqueeze(2).shape)
    print(beta.unsqueeze(1).unsqueeze(2).expand(-1,conv(g).shape[1],conv(g).shape[2]).shape)

torch.Size([42])
torch.Size([42, 1, 1])
torch.Size([42, 512, 135])
torch.Size([42])
torch.Size([42, 1, 1])
torch.Size([42, 512, 135])
torch.Size([42])
torch.Size([42, 1, 1])
torch.Size([42, 512, 135])


In [59]:
j = v.unsqueeze(1)

In [54]:
j.shape

torch.Size([2, 1, 42])

In [56]:
s

torch.Size([42, 512, 135])

In [60]:
k = j.expand(42,s[1],s[2])

RuntimeError: The expanded size of the tensor (135) must match the existing size (2) at non-singleton dimension 2.  Target sizes: [42, 512, 135].  Tensor sizes: [42, 1, 2]

In [47]:
for conv, beta in zip(w_b_list, v.T):
    print(beta)

tensor([0.0777, 0.3039, 0.2183, 0.6788, 0.5681, 0.6752, 0.2077, 0.6903, 0.5494,
        0.9487, 0.3374, 0.5979, 0.4314, 0.2119, 0.8146, 0.8913, 0.5831, 0.3882,
        0.2667, 0.2575, 0.1272, 0.0443, 0.4212, 0.8634, 0.8978, 0.8854, 0.5142,
        0.6452, 0.7947, 0.5458, 0.7659, 0.0603, 0.1320, 0.4148, 0.8459, 0.0586,
        0.6056, 0.6519, 0.9173, 0.5364, 0.8322, 0.1111])
tensor([0.8051, 0.3497, 0.3009, 0.4961, 0.7294, 0.2902, 0.4880, 0.4592, 0.9567,
        0.1518, 0.1337, 0.2897, 0.2017, 0.8400, 0.0687, 0.3372, 0.6599, 0.6890,
        0.4962, 0.1497, 0.0524, 0.9495, 0.2119, 0.8436, 0.2531, 0.7784, 0.4149,
        0.8176, 0.5467, 0.2176, 0.4249, 0.3640, 0.3473, 0.3341, 0.6063, 0.0459,
        0.9975, 0.0631, 0.1189, 0.9486, 0.9079, 0.7226])
