In [48]:
import torch
import torch.nn.functional as F
import numpy as np
import math

In [49]:
# Fix the RNG so that "Restart Kernel -> Run All" reproduces the same tensors and outputs.
SEED = 0
torch.manual_seed(SEED)

# Define input tensors
a = torch.randn(2, 3, 4)  # batch size 2, sequence length 3, 4 quaternion components (r, i, j, k)
b = torch.randn(2, 3, 4)
x = torch.randn(2, 3, 4)  # input for the FFN
dim = 4  # output size of the quaternion feed-forward network (must be divisible by 4)
num_layers = 2  # number of FFN layers (NOTE: quaternion_ffn currently applies only one layer)
activation = F.relu  # activation applied to the FFN output


In [50]:
def make_quaternion_mul(kernel, concat_dim=0):
    '''Build the real-valued Hamilton matrix for a quaternion-valued weight.

    The last axis of `kernel` is split into four equal blocks (r, i, j, k). They are
    arranged into the four "rows" of the Hamilton product matrix

        [ r, -i, -j, -k]
        [ i,  r, -k,  j]
        [ j,  k,  r, -i]
        [ k, -j,  i,  r]

    (each row concatenated along the last axis), and the rows are then stacked along
    `concat_dim`.

    Args:
        kernel: tensor whose last dimension is divisible by 4, e.g. an
            (input_dim, dim) weight.
        concat_dim: axis along which the four rows are joined. With the default 0 and
            a kernel of shape (n, 4m) the result has shape (4n, 4m).

    Returns:
        The Hamilton matrix as a tensor.

    Raises:
        ValueError: if the last dimension of `kernel` is not divisible by 4. Without
            this check torch.chunk would return fewer than four blocks or unequal
            blocks, silently producing a malformed matrix.
    '''
    last = kernel.shape[-1]
    if last % 4 != 0:
        raise ValueError(
            f"make_quaternion_mul: last dimension of kernel must be divisible by 4, got {last}"
        )
    r, i, j, k = torch.chunk(kernel, 4, dim=-1)
    r2 = torch.cat([r, -i, -j, -k], dim=-1)
    i2 = torch.cat([i, r, -k, j], dim=-1)
    j2 = torch.cat([j, k, r, -i], dim=-1)
    k2 = torch.cat([k, -j, i, r], dim=-1)
    return torch.cat([r2, i2, j2, k2], dim=concat_dim)


In [51]:
# Example for make_quaternion_mul: one quaternion weight (a 1x4 kernel) expands into a 4x4 real matrix.
torch.manual_seed(0)  # reproducible kernel values for this demo cell
kernel = torch.randn(1, 4)
hamilton = make_quaternion_mul(kernel)
print("Hamilton product kernel:\n", hamilton)
print(hamilton.shape)

# Sanity checks: 4x4 result whose diagonal is the real part r repeated four times.
assert hamilton.shape == (4, 4)
assert torch.equal(hamilton.diagonal(), kernel[0, 0].repeat(4))

Hamilton product kernel:
 tensor([[-0.6688,  1.0368, -2.5359,  0.2173],
        [-1.0368, -0.6688,  0.2173,  2.5359],
        [ 2.5359, -0.2173, -0.6688,  1.0368],
        [-0.2173, -2.5359, -1.0368, -0.6688]])
torch.Size([4, 4])


In [52]:
def _quaternion_component(x, index, dim=1):
    """Return the `index`-th (0=r, 1=i, 2=j, 3=k) of four equal slices of `x` along `dim`.

    Raises:
        ValueError: if the size of `x` along `dim` is not divisible by 4. torch.chunk
            would otherwise return unequal slices, or fewer than four of them.
    """
    size = x.shape[dim]
    if size % 4 != 0:
        raise ValueError(
            f"quaternion component axis must be divisible by 4, got size {size} on dim {dim}"
        )
    return torch.chunk(x, 4, dim=dim)[index]


def get_r(x, dim=1):
    """Real (first) quaternion block of `x` along `dim`."""
    return _quaternion_component(x, 0, dim)


def get_i(x, dim=1):
    """`i` (second) quaternion block of `x` along `dim`."""
    return _quaternion_component(x, 1, dim)


def get_j(x, dim=1):
    """`j` (third) quaternion block of `x` along `dim`."""
    return _quaternion_component(x, 2, dim)


def get_k(x, dim=1):
    """`k` (fourth) quaternion block of `x` along `dim`."""
    return _quaternion_component(x, 3, dim)


In [53]:
def quaternion_attention(a, b, verbose=False):
    """Compute Hamilton-product attention scores between two quaternion sequences.

    The last axis of `a` and `b` is split into four equal blocks (r, i, j, k). For every
    pair of positions, the four components of the Hamilton product are formed from
    batched matrix products over the trailing two axes:

        r = ar.br - ai.bi - aj.bj - ak.bk
        i = ar.bi + ai.br + aj.bk - ak.bj
        j = ar.bj - ai.bk + aj.br + ak.bi
        k = ar.bk + ai.bj - aj.bi + ak.br

    Args:
        a: tensor (..., al, d) with d divisible by 4 (presumably (bsz, al, d)).
        b: tensor (..., bl, d) with the same d.
        verbose: if True, print the debug banner and both inputs (these prints were
            previously unconditional).

    Returns:
        List [r, i, j, k] of tensors, each of shape (..., al, bl).

    Raises:
        ValueError: if the last dimension of `a` or `b` is not divisible by 4.
    """
    if verbose:
        print("light Attention!")
        print(a)
        print(b)

    for name, t in (("a", a), ("b", b)):
        if t.shape[-1] % 4 != 0:
            raise ValueError(
                f"quaternion_attention: last dim of {name} must be divisible by 4, got {t.shape[-1]}"
            )

    ar, ai, aj, ak = torch.chunk(a, 4, dim=-1)
    br, bi, bj, bk = torch.chunk(b, 4, dim=-1)

    def scores(p, q):
        # Batched p @ q^T over the trailing (seq, feature) axes -> (..., al, bl).
        return torch.matmul(p, q.transpose(-2, -1))

    r = scores(ar, br) - scores(ai, bi) - scores(aj, bj) - scores(ak, bk)
    i = scores(ar, bi) + scores(ai, br) + scores(aj, bk) - scores(ak, bj)
    j = scores(ar, bj) - scores(ai, bk) + scores(aj, br) + scores(ak, bi)
    k = scores(ar, bk) + scores(ai, bj) - scores(aj, bi) + scores(ak, br)

    return [r, i, j, k]


In [54]:
def quaternion_dot_product_att(a, b):
    """Pairwise quaternion scores between two sequences, flattened per batch element.

    For every batch element n, every position p of `a` and every position q of `b`,
    the element-wise Hamilton product quaternion_dot(a[n, p], b[n, q]) is computed
    and summed over its feature axis.

    The previous implementation tiled `a` and `b` with `repeat`, which paired rows
    modulo their lengths. That missed or duplicated (p, q) pairs and mixed batch
    elements. It then summed over an axis produced by a `view` that interleaved
    unrelated rows. The output shape is unchanged.

    NOTE(review): summing all four Hamilton components into one score follows the
    apparent intent of the original reduction. Confirm this is the desired score.

    Args:
        a: tensor (bsz, al, d), d divisible by 4.
        b: tensor (bsz, bl, d); batch axes must be broadcast-compatible with `a`.

    Returns:
        Tensor (bsz, al * bl); entry [n, p * bl + q] scores the pair (a[n, p], b[n, q]).
    """
    # (bsz, al, 1, d) and (bsz, 1, bl, d) -> both (bsz, al, bl, d): every (p, q) pair
    # within the same batch element.
    a_pairs, b_pairs = torch.broadcast_tensors(a.unsqueeze(2), b.unsqueeze(1))
    bsz, al, bl, d = a_pairs.shape

    att = quaternion_dot(a_pairs.reshape(-1, d), b_pairs.reshape(-1, d))  # (bsz*al*bl, d)
    return att.reshape(bsz, al * bl, d).sum(dim=-1)

In [55]:
def quaternion_dot_3d(q0, q1):
    """Element-wise Hamilton product for 3D inputs [bsz x seq_len x d].

    The batch and sequence axes are flattened so quaternion_dot sees 2D (N, d)
    tensors, with quaternion blocks on axis 1. The result is reshaped back to
    (-1, seq_len, d).

    Args:
        q0, q1: tensors of the same shape (bsz, seq_len, d), d divisible by 4.

    Returns:
        Tensor of shape (bsz, seq_len, d).
    """
    d = q0.shape[2]
    sq = q0.shape[1]

    # reshape (not view) so non-contiguous inputs (e.g. transposed tensors) are accepted
    out = quaternion_dot(q0.reshape(-1, d), q1.reshape(-1, d))
    return out.reshape(-1, sq, d)

In [56]:
def quaternion_dot(q0, q1, dim=1):
    """Element-wise Hamilton product between two quaternion tensors.

    The axis `dim` of each input is split into four equal blocks (r, i, j, k), and
    the product is taken lane by lane:

        r = r0*r1 - i0*i1 - j0*j1 - k0*k1
        i = r0*i1 + i0*r1 + j0*k1 - k0*j1
        j = r0*j1 - i0*k1 + j0*r1 + k0*i1
        k = r0*k1 + i0*j1 - j0*i1 + k0*r1

    Args:
        q0, q1: tensors whose size along `dim` is divisible by 4; other axes broadcast.
        dim: axis holding the quaternion blocks (default 1, as before).

    Returns:
        Tensor laid out as [r | i | j | k] along `dim`, same shape as the broadcast inputs.

    Raises:
        ValueError: if either input's size along `dim` is not divisible by 4.
            Without this check torch.chunk would produce unequal blocks that broadcast
            silently into meaningless results.
    """
    for name, q in (("q0", q0), ("q1", q1)):
        if q.shape[dim] % 4 != 0:
            raise ValueError(
                f"quaternion_dot: size of {name} along dim {dim} must be divisible by 4, "
                f"got {q.shape[dim]}"
            )

    r0, i0, j0, k0 = torch.chunk(q0, 4, dim=dim)
    r1, i1, j1, k1 = torch.chunk(q1, 4, dim=dim)

    r = r0 * r1 - i0 * i1 - j0 * j1 - k0 * k1
    i = r0 * i1 + i0 * r1 + j0 * k1 - k0 * j1
    j = r0 * j1 - i0 * k1 + j0 * r1 + k0 * i1
    k = r0 * k1 + i0 * j1 - j0 * i1 + k0 * r1

    return torch.cat([r, i, j, k], dim=dim)

In [57]:
def quaternion_concat(x, dim):
    """Concatenate quaternion tensors component by component along `dim`.

    Each tensor in `x` is split into four parts along `dim`. All first (r) parts are
    joined, then all i, j and k parts, and the four joined blocks are concatenated.
    The result therefore stays laid out as [r | i | j | k] along `dim`.
    """
    pieces = [torch.chunk(tensor, 4, dim=dim) for tensor in x]
    per_component = [
        torch.cat([parts[c] for parts in pieces], dim)
        for c in range(4)
    ]
    return torch.cat(per_component, dim)

In [58]:
def quaternion_ffn_3d(x, dim, num_layers=1, activation=None, verbose=False):
    """Apply quaternion_ffn to a 3D input [bsz x seq_len x in_dim].

    The batch and sequence axes are flattened into one so quaternion_ffn receives
    a 2D (bsz * seq_len, in_dim) tensor. Its (bsz * seq_len, dim) result is
    reshaped back to (bsz, seq_len, dim).

    Args:
        x: tensor (bsz, seq_len, in_dim).
        dim: output feature size, forwarded to quaternion_ffn.
        num_layers: forwarded to quaternion_ffn.
        activation: optional callable forwarded to quaternion_ffn.
        verbose: if True, print the "QFFN layer.." banner (previously always printed).

    Returns:
        Tensor of shape (bsz, seq_len, dim).
    """
    if verbose:
        print("QFFN layer..")
    sq = x.shape[1]
    in_dim = x.shape[2]

    # reshape (not view) so non-contiguous inputs are accepted
    flat = x.reshape(-1, in_dim)
    out = quaternion_ffn(flat, dim, num_layers=num_layers, activation=activation)
    return out.reshape(-1, sq, dim)

In [59]:
def quaternion_ffn(x, dim, num_layers=1, activation=None, kernel=None):
    """Apply one quaternion linear layer (plus optional activation) to a 2D input.

    The feature axis of `x` consists of four equal quaternion blocks. A weight of
    shape (input_dim, dim) is expanded by make_quaternion_mul into a
    (4 * input_dim, dim) Hamilton matrix, and `x` is right-multiplied by it.

    Args:
        x: 2D tensor (N, 4 * input_dim).
        dim: output feature size; must be divisible by 4.
        num_layers: accepted for API compatibility but currently ignored; exactly one
            layer is applied.
        activation: optional callable applied to the layer output.
        kernel: optional (input_dim, dim) weight to reuse across calls. When omitted, a
            fresh `torch.nn.Parameter(torch.randn(input_dim, dim))` is created on every
            call (the original behaviour). The output is then random per call, and the
            weight is not registered with any module or optimizer.

    Returns:
        Tensor of shape (N, dim).

    Raises:
        ValueError: if the feature size of `x` or `dim` is not divisible by 4, or if
            `kernel` does not have shape (input_dim, dim).
    """
    in_features = x.shape[1]
    if in_features % 4 != 0:
        raise ValueError(
            f"quaternion_ffn: input feature size must be divisible by 4, got {in_features}"
        )
    if dim % 4 != 0:
        raise ValueError(f"quaternion_ffn: dim must be divisible by 4, got {dim}")

    input_dim = in_features // 4
    if kernel is None:
        kernel = torch.nn.Parameter(torch.randn(input_dim, dim))
    elif tuple(kernel.shape) != (input_dim, dim):
        raise ValueError(
            f"quaternion_ffn: kernel must have shape {(input_dim, dim)}, got {tuple(kernel.shape)}"
        )

    hamilton = make_quaternion_mul(kernel)  # (4 * input_dim, dim)
    output = torch.matmul(x, hamilton)
    if activation is not None:
        output = activation(output)

    return output


In [60]:
def hamilton_product(x, kernel):
    """Right-multiply `x` by the Hamilton matrix expanded from `kernel`.

    make_quaternion_mul (with its default concat_dim=0) turns the quaternion weight
    `kernel` into its real block matrix. The result is `torch.matmul(x, matrix)`.
    """
    return torch.matmul(x, make_quaternion_mul(kernel))

In [61]:
# Example: run the quaternion FFN over a random (batch_size, seq_length, embed_dim) input.
torch.manual_seed(0)  # reproducible input and (randomly initialised) FFN weights

batch_size = 2
seq_length = 5
embed_dim = 8  # must be divisible by 4 (four quaternion components)

# Create random input tensor (batch_size, seq_length, embed_dim)
x = torch.randn(batch_size, seq_length, embed_dim)

# Apply quaternion feed-forward network
ffn_output = quaternion_ffn_3d(x, dim=embed_dim)
assert ffn_output.shape == (batch_size, seq_length, embed_dim)

print("Input:")
print(x)
print("\nFFN Output:")
print(ffn_output)


QFFN layer..
Input:
tensor([[[-0.3506,  0.0970,  1.3514, -0.6169,  1.3352, -1.4132,  0.1767,
           0.8505],
         [ 0.7452,  0.0193,  0.2961, -0.2805, -0.7586,  0.3983, -1.3985,
          -0.0962],
         [-0.0128, -2.1986, -0.5678,  0.4595, -0.6890,  0.6323,  1.1189,
           1.2924],
         [-0.1032,  2.8755, -0.7394,  0.2886, -0.1981,  0.4650, -0.5502,
          -0.4260],
         [ 0.1850, -0.4664,  2.1679,  0.5651, -1.1531, -0.7372, -0.4274,
           1.3448]],

        [[ 0.3386,  0.9734,  1.2382,  0.6949, -1.7217, -0.8094,  0.1479,
          -2.2291],
         [ 1.9987,  0.8493,  0.2982,  2.4384, -0.3508,  0.6973,  1.3037,
           1.2777],
         [ 0.4218, -0.4563,  0.1082, -1.4774,  0.2493,  0.1442,  1.3681,
          -0.6411],
         [ 0.2299,  0.1423,  0.8334, -0.8379, -0.0913, -1.3382,  0.0513,
          -0.6998],
         [-1.6995,  1.1139,  0.8475,  0.3015, -0.5085, -1.4477,  0.6160,
          -0.9741]]])

FFN Output:
tensor([[[ 3.5576, -2.2303, -2.21