In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def mean_features(x):
    out = x.mean(0)
    while out.dim() > 1:
        out = out.mean(1)
    return out

def fowardbn(x, gam, beta, ):
    momentum = 0.1
    eps = 1e-05
    running_mean = 0
    running_var = 1
    cur_mean = mean_features(x)
    running_mean = (1 - momentum) * running_mean + momentum * cur_mean
    running_var = (1 - momentum) * running_var + momentum * cur_mean
    mean = cur_mean
    cur_var = F.relu(mean_features(x**2) - cur_mean**2)
    var = cur_var
    x_hat = (x.view(-1,x.shape[1]) - cur_mean) / torch.sqrt(cur_var + eps)
    print(gam.shape)
    print(beta.shape)
    out = gam * x_hat + beta
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out.view(x.shape), cache

model2 = nn.BatchNorm1d(5)
input1 = torch.randn(3,5,6, requires_grad=True)
print(input1)
input2 = input1.clone().detach().requires_grad_()
y = model2(input1)

tensor([[[-1.1233,  0.1309,  0.2814, -1.9544, -1.2139,  0.9455],
         [ 2.7893, -0.2317, -1.5156,  1.2039, -1.2037,  0.9130],
         [ 1.2137,  0.6958, -0.4921, -0.8276,  0.8377,  0.9562],
         [-1.5076,  0.6363,  0.5928, -0.6300,  0.0240,  1.9158],
         [-0.5199,  0.1634,  1.4953, -0.4837,  0.2691, -0.7630]],

        [[ 0.1926, -0.1499,  1.5808,  0.2936,  0.4000, -0.7620],
         [-0.0719,  0.3883, -0.1320, -0.8559, -0.8921, -1.5119],
         [-0.6167,  0.0750, -0.8566, -0.0536,  1.6675,  0.4145],
         [ 0.2672, -0.5273, -0.6057, -0.6942,  0.3763, -0.1747],
         [-0.8692, -1.4416,  0.7252,  0.4823,  0.4696,  0.5063]],

        [[-0.5577,  0.0569,  0.0959,  0.2419, -0.4015, -1.3604],
         [ 0.6920,  1.3198,  1.4232,  0.3770, -1.3151, -0.8023],
         [-0.2655,  0.6422,  0.9520,  0.9185,  0.7974, -0.9940],
         [-0.1103, -0.6113, -0.5758, -0.8179,  1.3566,  2.2937],
         [ 0.1024, -0.5505,  0.2461, -0.4399,  0.5919, -0.7221]]],
       requires_gra

In [3]:
y.shape

torch.Size([3, 5, 6])

In [4]:
out, cache = fowardbn(input2, torch.ones(5), torch.zeros(5))

torch.Size([5])
torch.Size([5])


In [5]:
out.shape

torch.Size([3, 5, 6])

In [17]:
from torchsnooper import snoop
class ABN(nn.Module):
    def __init__(self,in_channels, out_channels):
        super(ABN,self).__init__()
        self.W_e = nn.Conv1d(in_channels, out_channels, 1, bias = True)
        self.linear_transform1 = nn.Linear(135,in_channels)
        self.linear_transform2 = nn.Linear(135,in_channels)
    
    def mean_features(self,x):
        out = x.mean(0)
        while out.dim() > 1:
            out = out.mean(1)
        return out
    @snoop()
    def forward(self, x):
        momentum, eps,running_mean,running_var = 0.1, 1e-05, 0, 1
        mean = mean_features(x)
        running_mean = (1 - momentum) * running_mean + momentum * mean
        running_var = (1 - momentum) * running_var + momentum * mean
        var = F.relu(mean_features(x**2) - mean**2)
        x_hat = (x.view(-1,x.shape[1]) - mean) / torch.sqrt(var + eps)
        e_t = torch.tanh(self.W_e(x))
        exp_mean_e_t = torch.exp(torch.mean(e_t,dim  = 2))
        a_t = exp_mean_e_t/torch.sum(exp_mean_e_t,dim = 1).unsqueeze(1).expand(-1,e_t.shape[1])
        a_t = a_t.unsqueeze(2).expand(-1,-1,e_t.shape[2])
        C_abn = torch.sum(a_t * e_t, dim = 1)
        gamma = self.linear_transform1(C_abn).mean(0)
        beta = self.linear_transform2(C_abn).mean(0)
        out = gamma * x_hat + beta
        out = out.view(x.shape)
        return out


In [18]:
test = torch.rand(10,39,135)

In [19]:
model = ABN(39,39)

In [20]:
z = model(test)

Source path:... <ipython-input-17-c886609fcd5f>
Starting var:.. self = ABN(  (W_e): Conv1d(39, 39, kernel_size=(1,), st...ear(in_features=135, out_features=39, bias=True))
Starting var:.. x = tensor<(10, 39, 135), float32, cpu>
15:47:26.051538 call        15     def forward(self, x):
15:47:26.063506 line        16         momentum, eps,running_mean,running_var = 0.1, 1e-05, 0, 1
New var:....... momentum = 0.1
New var:....... eps = 1e-05
New var:....... running_mean = 0
New var:....... running_var = 1
15:47:26.065501 line        17         mean = mean_features(x)
New var:....... mean = tensor<(39,), float32, cpu>
15:47:26.067494 line        18         running_mean = (1 - momentum) * running_mean + momentum * mean
Modified var:.. running_mean = tensor<(39,), float32, cpu>
15:47:26.072485 line        19         running_var = (1 - momentum) * running_var + momentum * mean
Modified var:.. running_var = tensor<(39,), float32, cpu>
15:47:26.077477 line        20         var = F.relu(mean_feat

In [10]:
z.shape

torch.Size([10, 39, 135])

In [11]:
t = torch.rand(10,135)

In [12]:
t.mean(dim = 0 ).shape

torch.Size([135])

In [58]:

class ABN(nn.Module):
    def __init__(self,in_channels):
        super(ABN,self).__init__()
        self.W_e = nn.Conv1d(in_channels, in_channels, 1, bias = True)
        self.linear_transform1 = nn.Linear(in_channels,in_channels,bias = True)
        self.linear_transform2 = nn.Linear(in_channels,in_channels,bias = True)
    
    def mean_features(self,x):
        out = x.mean(0)
        while out.dim() > 1:
            out = out.mean(1)
        return out
    @snoop()
    def forward(self, x):
        momentum, eps,running_mean,running_var = 0.1, 1e-05, 0, 1
        mean = self.mean_features(x).unsqueeze(0).unsqueeze(2).expand(x.shape[0],-1,x.shape[2])
        running_mean = (1 - momentum) * running_mean + momentum * mean# (39,)
        running_var = (1 - momentum) * running_var + momentum * mean# (39,)
        var = (F.relu(mean**2) - mean**2)
        x_hat = (x - mean) / torch.sqrt(var + eps)
        e_t = torch.tanh(self.W_e(x))
        exp_mean_e_t = torch.exp(torch.mean(e_t,dim  = 1))
        a_t = exp_mean_e_t/torch.sum(exp_mean_e_t,dim = 1).unsqueeze(1).expand(-1,exp_mean_e_t.shape[1])
        a_t = a_t.unsqueeze(1).expand(-1,e_t.shape[1],-1)
        C_abn = torch.sum(a_t * e_t, dim = 2)
        gamma = self.linear_transform1(C_abn).unsqueeze(2).expand(-1,-1,x.shape[2])
        beta = self.linear_transform2(C_abn).unsqueeze(2).expand(-1,-1,x.shape[2])
        out = gamma * x_hat + beta
        return out


In [59]:
model = ABN(39,135)
x = torch.rand(41,39,135)

In [60]:
model(x).shape

Source path:... <ipython-input-58-61d4016bf5a5>
Starting var:.. self = ABN(  (W_e): Conv1d(39, 39, kernel_size=(1,), st...near(in_features=39, out_features=39, bias=True))
Starting var:.. x = tensor<(41, 39, 135), float32, cpu>
17:08:25.703611 call        15     def forward(self, x):
17:08:25.710592 line        16         momentum, eps,running_mean,running_var = 0.1, 1e-05, 0, 1
New var:....... momentum = 0.1
New var:....... eps = 1e-05
New var:....... running_mean = 0
New var:....... running_var = 1
17:08:25.712587 line        17         mean = self.mean_features(x).unsqueeze(0).unsqueeze(2).expand(x.shape[0],-1,x.shape[2])
New var:....... mean = tensor<(41, 39, 135), float32, cpu>
17:08:25.715579 line        18         running_mean = (1 - momentum) * running_mean + momentum * mean# (39,)
Modified var:.. running_mean = tensor<(41, 39, 135), float32, cpu>
17:08:25.721563 line        19         running_var = (1 - momentum) * running_var + momentum * mean# (39,)
Modified var:.. running_v

torch.Size([41, 39, 135])