In [1]:
import torch
import torch.nn as nn
from time import time
# from resnet_real import ResNet18
# from resnet_quat import ResNet18_quat
from tqdm import tqdm, trange
from htorch import layers as quatnn
from torch.nn import init
from htorch.quaternion import QuaternionTensor as Q

GPU = torch.device('cuda:0')
bn_time = 0

In [2]:
class QBatchNorm2d_0(nn.Module):
    """
    Quaternion batch normalization 2d
    please check whitendxd in cplx module at https://github.com/ivannz
    """

    def __init__(self,
                 in_channels,
                 affine=True,
                 training=True,
                 eps=1e-5,
                 momentum=0.9,
                 track_running_stats=True):
        """
        @type in_channels: int
        @type affine: bool
        @type training: bool
        @type eps: float
        @type momentum: float
        @type track_running_stats: bool
        """
        super(QBatchNorm2d_0, self).__init__()
        self.in_channels = in_channels

        self.affine = affine
        self.training = training
        self.track_running_stats = track_running_stats
        self.register_buffer('eye', torch.diag(torch.cat([torch.Tensor([eps])] * 4)).unsqueeze(0))

        if self.affine:
            self.weight = torch.nn.Parameter(torch.zeros(4, 4, in_channels))
            self.bias = torch.nn.Parameter(torch.zeros(4, in_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(4, in_channels))
            self.register_buffer('running_cov', torch.zeros(in_channels, 4, 4))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_cov', None)

        self.momentum = momentum

        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_cov.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.constant_(self.weight[0, 0], 0.5)
            init.constant_(self.weight[1, 1], 0.5)
            init.constant_(self.weight[2, 2], 0.5)
            init.constant_(self.weight[3, 3], 0.5)

    def forward(self, x):
        # print(f"\t\t\tBN: {self.in_channels = }")
        x = torch.stack(torch.chunk(x, 4, 1), 1).permute(1, 0, 2, 3, 4)
        axes, d = (1, *range(3, x.dim())), x.shape[0]
        shape = 1, x.shape[2], *([1] * (x.dim() - 3))

        if self.training:
            mean = x.mean(dim=axes)
            if self.running_mean is not None:
                with torch.no_grad():
                    self.running_mean = self.momentum * self.running_mean + \
                                        (1.0 - self.momentum) * mean
        else:
            mean = self.running_mean

        x = x - mean.reshape(d, *shape)

        if self.training:
            perm = x.permute(2, 0, *axes).flatten(2, -1)
            cov = torch.matmul(perm, perm.transpose(-1, -2)) / perm.shape[-1]

            if self.running_cov is not None:
                with torch.no_grad():
                    self.running_cov = self.momentum * self.running_cov + \
                                       (1.0 - self.momentum) * cov

        else:
            cov = self.running_cov

        ell = torch.cholesky(cov + self.eye, upper=True)
        soln = torch.triangular_solve(
            x.unsqueeze(-1).permute(*range(1, x.dim()), 0, -1),
            ell.reshape(*shape, d, d)
        )

        wht = soln.solution.squeeze(-1)
        z = torch.stack(torch.unbind(wht, dim=-1), dim=0)

        if self.affine:
            weight = self.weight.view(4, 4, *shape)
            scaled = torch.stack([
                z[0] * weight[0, 0] + z[1] * weight[0, 1] + z[2] * weight[0, 2] + z[3] * weight[0, 3],
                z[0] * weight[1, 0] + z[1] * weight[1, 1] + z[2] * weight[1, 2] + z[3] * weight[1, 3],
                z[0] * weight[2, 0] + z[1] * weight[2, 1] + z[2] * weight[2, 2] + z[3] * weight[2, 3],
                z[0] * weight[3, 0] + z[1] * weight[3, 1] + z[2] * weight[3, 2] + z[3] * weight[3, 3],
            ], dim=0)
            z = scaled + self.bias.reshape(4, *shape)

        z = torch.cat(torch.chunk(z, 4, 0), 2).squeeze()
        
        if z.dim() == 2:
            z = z.reshape(z.shape[0], z.shape[1], 1, 1)

        # print(f"shape before returning: {z.shape = }")

        return Q(z)

In [3]:
model_real = nn.BatchNorm2d(8).to(GPU)
model_quat = QBatchNorm2d_0(2).to(GPU)

In [4]:
batchx = torch.randn(256, 8, 224, 224).to(GPU)

In [5]:
model_quat(batchx).shape

L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A).mH().
This transform will produce equivalent results for all valid (symmetric positive definite) inputs. (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1691.)
  ell = torch.cholesky(cov + self.eye, upper=True)
torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2191.)
  soln = torch.triangular_solve(


torch.Size([128, 8, 224, 224])

In [6]:
repeat = 10000
e = 10

start = time()
for i in trange(repeat):
    output = model_real(batchx)
end = time()
real_time = (end - start)

start = time()
for i in trange(repeat//e):
    output = model_quat(batchx)
end = time()
quat_time = (end - start)*e

print(f"Real: {real_time}")
print(f"Quat: {quat_time}")

100%|██████████| 10000/10000 [00:23<00:00, 417.67it/s]
100%|██████████| 1000/1000 [00:38<00:00, 25.86it/s]

Real: 23.94432520866394
Quat: 386.65738582611084





In [7]:
print(f"Quat is about {quat_time/real_time:.1f}x slower than Real")

Quat is about 16.1x slower than Real


In [8]:
model_real.state_dict().keys()

odict_keys(['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked'])

In [9]:
model_quat.state_dict().keys()

odict_keys(['weight', 'bias', 'eye', 'running_mean', 'running_cov'])

In [10]:
print(model_real.weight.shape)
print(model_quat.weight.shape)

torch.Size([8])
torch.Size([4, 4, 2])


In [11]:
model_quat.weight

Parameter containing:
tensor([[[0.5000, 0.5000],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000]],

        [[0.0000, 0.0000],
         [0.5000, 0.5000],
         [0.0000, 0.0000],
         [0.0000, 0.0000]],

        [[0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.5000, 0.5000],
         [0.0000, 0.0000]],

        [[0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.5000, 0.5000]]], device='cuda:0', requires_grad=True)

In [12]:
import numpy as np

In [13]:
x = np.random.rand(5,2)
g = np.random.rand(1,2)
b = np.random.rand(1,2)

In [14]:
b

array([[0.19776855, 0.48728672]])

In [15]:
x*g

array([[0.03381753, 0.20481149],
       [0.5940217 , 0.2110181 ],
       [0.20750374, 0.19137958],
       [0.34746205, 0.19015002],
       [0.01904878, 0.03207484]])

In [16]:
(x*g) + b

array([[0.23158609, 0.69209821],
       [0.79179025, 0.69830482],
       [0.40527229, 0.6786663 ],
       [0.5452306 , 0.67743674],
       [0.21681733, 0.51936156]])