In [6]:
import torch
import torch.nn.functional as F
import numpy as np
import math

In [None]:
def make_quaternion_mul(kernel, concat_dim=0):
    """Expand a quaternion kernel into its real-valued Hamilton matrix.

    The last dimension of ``kernel`` is split into four equal parts, read as
    the r, i, j and k components, which are then laid out as four signed
    row blocks::

        [ r, -i, -j, -k ]
        [ i,  r, -k,  j ]
        [ j,  k,  r, -i ]
        [ k, -j,  i,  r ]

    Args:
        kernel: Tensor whose last dimension has a size divisible by 4.
        concat_dim: Dimension along which the four row blocks are joined.

    Returns:
        The four row blocks concatenated along ``concat_dim``. Each block has
        the same shape as ``kernel``, so a 2-D ``(n, d)`` kernel with
        ``concat_dim=0`` yields a ``(4 * n, d)`` tensor.

    Raises:
        ValueError: If the last dimension of ``kernel`` is not divisible by 4.
    """
    width = kernel.shape[-1]
    # torch.chunk does not reject uneven sizes: it returns unequal (or fewer)
    # chunks, which would either fail obscurely or silently build a matrix
    # that is not a Hamilton product. Fail early with a clear message instead.
    if width % 4 != 0:
        raise ValueError(
            "kernel's last dimension must be divisible by 4, got {}".format(width)
        )

    r, i, j, k = torch.chunk(kernel, 4, dim=-1)
    layout = (
        (r, -i, -j, -k),
        (i, r, -k, j),
        (j, k, r, -i),
        (k, -j, i, r),
    )
    row_blocks = [torch.cat(row, dim=-1) for row in layout]
    return torch.cat(row_blocks, dim=concat_dim)


In [None]:
def get_r(x, dim=1):
    """Return the first of the four chunks of ``x`` along ``dim`` (the r part)."""
    components = torch.chunk(x, chunks=4, dim=dim)
    return components[0]

def get_i(x, dim=1):
    """Return the second of the four chunks of ``x`` along ``dim`` (the i part)."""
    components = torch.chunk(x, chunks=4, dim=dim)
    return components[1]

def get_j(x, dim=1):
    """Return the third of the four chunks of ``x`` along ``dim`` (the j part)."""
    components = torch.chunk(x, chunks=4, dim=dim)
    return components[2]

def get_k(x, dim=1):
    """Return the fourth of the four chunks of ``x`` along ``dim`` (the k part)."""
    components = torch.chunk(x, chunks=4, dim=dim)
    return components[3]


In [None]:
def quaternion_attention(a, b):
    """Compute dot-product attention scores between two quaternion sequences.

    Both inputs are split into four equal parts along their last dimension
    (r, i, j, k). The score matrices follow the Hamilton product, with every
    component-wise product replaced by ``a_part @ b_part^T``::

        r = rr' - ii' - jj' - kk'
        i = ri' + ir' + jk' - kj'
        j = rj' - ik' + jr' + ki'
        k = rk' + ij' - ji' + kr'

    Args:
        a: Tensor of shape ``(..., al, 4 * d)``.
        b: Tensor of shape ``(..., bl, 4 * d)``.

    Returns:
        A list ``[r, i, j, k]`` of four score tensors, each of shape
        ``(..., al, bl)``.
    """
    # The former debug prints (a banner plus both full input tensors) were
    # removed: they flooded stdout on every call.
    ar, ai, aj, ak = torch.chunk(a, 4, dim=-1)
    # Transpose each component of b once instead of once per product term.
    br_t, bi_t, bj_t, bk_t = (
        part.transpose(-2, -1) for part in torch.chunk(b, 4, dim=-1)
    )

    r = (torch.matmul(ar, br_t) - torch.matmul(ai, bi_t)
         - torch.matmul(aj, bj_t) - torch.matmul(ak, bk_t))
    i = (torch.matmul(ar, bi_t) + torch.matmul(ai, br_t)
         + torch.matmul(aj, bk_t) - torch.matmul(ak, bj_t))
    j = (torch.matmul(ar, bj_t) - torch.matmul(ai, bk_t)
         + torch.matmul(aj, br_t) + torch.matmul(ak, bi_t))
    k = (torch.matmul(ar, bk_t) + torch.matmul(ai, bj_t)
         - torch.matmul(aj, bi_t) + torch.matmul(ak, br_t))

    return [r, i, j, k]


In [None]:
def quaternion_dot_product_att(a, b):
    """Quaternion dot-product scores for two 3-D sequences.

    ``a`` and ``b`` are flattened to rows of width ``a.shape[2]``; the rows of
    ``a`` are tiled ``b.shape[1]`` times and the rows of ``b`` are tiled
    ``a.shape[1]`` times. The tiled tensors are multiplied row by row with
    ``quaternion_dot``, the product is viewed as ``(bsz, -1, al * bl)`` and
    summed over its middle dimension.

    NOTE(review): rows are paired purely by position after tiling — confirm
    this gives the intended pairing of ``a`` and ``b`` positions.

    Returns:
        Tensor of shape ``(-1, al * bl)``.
    """
    bsz, bl = b.shape[0], b.shape[1]
    al, d = a.shape[1], a.shape[2]
    pair_count = al * bl

    # Flatten to rows, then tile each operand by the other's sequence length.
    a_rows = a.view(-1, d).repeat(bl, 1)
    b_rows = b.view(-1, d).repeat(al, 1)

    scores = quaternion_dot(a_rows, b_rows).view(bsz, -1, pair_count)
    return scores.sum(dim=1).view(-1, pair_count)

In [None]:
def quaternion_dot_3d(q0, q1):
    """Apply ``quaternion_dot`` to two 3-D tensors.

    Both inputs are flattened to rows of width ``q0.shape[2]``, multiplied
    with ``quaternion_dot``, and the result is reshaped to
    ``(-1, q0.shape[1], q0.shape[2])``.

    Args:
        q0: 3-D tensor; its second and third sizes define the output layout.
        q1: Tensor with the same number of elements per row as ``q0``
            (presumably the same shape as ``q0`` — confirm with callers).

    Returns:
        Tensor of shape ``(-1, q0.shape[1], q0.shape[2])``.
    """
    seq_len, width = q0.shape[1], q0.shape[2]

    # reshape instead of view: view raises on non-contiguous inputs (for
    # example transposed or sliced tensors), reshape copies only when needed.
    flat0 = q0.reshape(-1, width)
    flat1 = q1.reshape(-1, width)

    product = quaternion_dot(flat0, flat1)
    return product.reshape(-1, seq_len, width)

In [None]:
def quaternion_dot(q0, q1, dim=1):
    """Element-wise Hamilton product of two quaternion-valued tensors.

    Each operand is split into four equal chunks along ``dim`` (the r, i, j
    and k components) and combined as::

        r = r0*r1 - i0*i1 - j0*j1 - k0*k1
        i = r0*i1 + i0*r1 + j0*k1 - k0*j1
        j = r0*j1 - i0*k1 + j0*r1 + k0*i1
        k = r0*k1 + i0*j1 - j0*i1 + k0*r1

    Args:
        q0: Left operand.
        q1: Right operand; must have the same size as ``q0`` along ``dim``.
        dim: Axis holding the four quaternion components. The default of 1
            matches the previous behaviour for 2-D ``(rows, 4 * d)`` inputs.

    Returns:
        Tensor with the r, i, j, k results concatenated along ``dim``.

    Raises:
        ValueError: If the sizes along ``dim`` differ or are not divisible
            by 4.
    """
    size0, size1 = q0.shape[dim], q1.shape[dim]
    # Uneven chunks would broadcast against each other and yield values that
    # are not a quaternion product, so reject them up front.
    if size0 != size1 or size0 % 4 != 0:
        raise ValueError(
            "q0 and q1 must have equal sizes divisible by 4 along dim {}, "
            "got {} and {}".format(dim, size0, size1)
        )

    # Split each operand once rather than re-chunking for every term.
    r0, i0, j0, k0 = torch.chunk(q0, 4, dim=dim)
    r1, i1, j1, k1 = torch.chunk(q1, 4, dim=dim)

    r = r0 * r1 - i0 * i1 - j0 * j1 - k0 * k1
    i = r0 * i1 + i0 * r1 + j0 * k1 - k0 * j1
    j = r0 * j1 - i0 * k1 + j0 * r1 + k0 * i1
    k = r0 * k1 + i0 * j1 - j0 * i1 + k0 * r1

    return torch.cat([r, i, j, k], dim=dim)

In [None]:
def quaternion_concat(x, dim):
    """Concatenate quaternion tensors component by component.

    Every tensor in ``x`` is split into four chunks along ``dim``. Matching
    chunks (all first chunks, all second chunks, ...) are concatenated along
    ``dim``, and the four grouped results are concatenated along ``dim``
    again, so the output keeps the r | i | j | k block layout.

    Args:
        x: Iterable of tensors to combine.
        dim: Dimension used both for splitting and for concatenation.

    Returns:
        The combined tensor.
    """
    chunked = [torch.chunk(tensor, 4, dim=dim) for tensor in x]
    grouped = [
        torch.cat([parts[component] for parts in chunked], dim)
        for component in range(4)
    ]
    return torch.cat(grouped, dim)

In [None]:
def quaternion_ffn_3d(x, dim, num_layers=1, activation=None):
    """Apply the quaternion feed-forward layer to a 3-D input.

    The input ``[bsz x seq_len x features]`` is flattened to rows, passed
    through ``quaternion_ffn``, and reshaped to ``(-1, seq_len, dim)``.

    Args:
        x: 3-D input tensor.
        dim: Output feature size, forwarded to ``quaternion_ffn``.
        num_layers: Forwarded unchanged to ``quaternion_ffn``.
        activation: Optional callable forwarded to ``quaternion_ffn``.

    Returns:
        Tensor of shape ``(-1, x.shape[1], dim)``.
    """
    print("QFFN layer..")
    seq_len, in_features = x.shape[1], x.shape[2]

    # reshape instead of view: view raises on non-contiguous inputs (for
    # example transposed or sliced tensors), reshape copies only when needed.
    flat = x.reshape(-1, in_features)
    out = quaternion_ffn(flat, dim, num_layers=num_layers, activation=activation)
    return out.reshape(-1, seq_len, dim)

In [None]:
def quaternion_ffn(x, dim, num_layers=1, activation=None):
    """Quaternion feed-forward layer for a 2-D input.

    A kernel of shape ``(x.shape[1] // 4, dim)`` is sampled from a standard
    normal distribution, expanded with ``make_quaternion_mul`` and
    right-multiplied onto ``x``.

    Note:
        The kernel is sampled anew on every call and is not returned or
        stored, so no weights persist between calls. ``hamilton_product``
        takes an explicit kernel when the weights must be kept.

    Args:
        x: 2-D tensor ``[rows x features]``.
        dim: Number of columns of the sampled kernel.
        num_layers: Accepted for interface compatibility; not used here.
        activation: Optional callable applied to the output.

    Returns:
        ``x @ hamilton``, passed through ``activation`` when one is given.
    """
    input_dim = x.shape[1] // 4
    # Create the kernel with x's dtype and device; the default
    # (float32 on the default device) makes matmul fail for float64 or
    # accelerator-resident inputs.
    kernel = torch.nn.Parameter(
        torch.randn(input_dim, dim, dtype=x.dtype, device=x.device)
    )
    hamilton = make_quaternion_mul(kernel)

    output = torch.matmul(x, hamilton)
    if activation:
        output = activation(output)

    return output


In [4]:
def hamilton_product(x, kernel):
    """Right-multiply ``x`` by the Hamilton matrix expanded from ``kernel``."""
    return torch.matmul(x, make_quaternion_mul(kernel))

In [5]:
# Example: push one random batch through the 3-D quaternion feed-forward wrapper.
batch_size, seq_length, embed_dim = 2, 5, 8

# Random input of shape (batch_size, seq_length, embed_dim).
x = torch.randn(batch_size, seq_length, embed_dim)

# embed_dim doubles as the requested output width, so the result has x's shape.
ffn_output = quaternion_ffn_3d(x, dim=embed_dim)

print("Input:", x, sep="\n")
print("\nFFN Output:", ffn_output, sep="\n")


QFFN layer..
Input:
tensor([[[-0.4064,  0.2875,  0.6615, -0.9883, -0.9946, -0.4287,  1.1147,
          -0.1516],
         [-0.8567, -0.8068, -0.3999,  1.5470, -0.4792, -0.7334,  0.3172,
          -0.5948],
         [-1.3753, -1.3093, -0.9861,  0.3092, -0.3756,  0.4984, -0.0749,
          -0.0754],
         [ 0.5036,  2.2707, -0.7800,  0.6164,  0.7807, -1.1588,  1.3719,
          -2.5887],
         [ 0.3680, -1.6838, -0.9956, -0.4719,  0.6074, -0.0849, -0.6777,
           1.5089]],

        [[ 0.7396,  0.4635, -0.3600,  0.2646,  1.5537,  0.8283,  0.9151,
           1.2679],
         [ 0.8087,  0.8427,  0.1244,  0.3937,  0.6388, -1.5309,  0.5833,
          -1.2415],
         [-1.1429, -1.2020, -0.6896, -0.1691,  0.3244, -1.3988, -1.2236,
          -0.3101],
         [ 1.5469,  0.0957,  1.2384, -0.4310,  1.6011,  1.6669,  0.5646,
           1.3517],
         [ 0.3872,  0.1423, -1.6086, -0.9811, -0.0216,  0.1643,  1.1826,
           1.5435]]])

FFN Output:
tensor([[[-0.7941, -0.6765, -2.36