In [8]:
import torch
import torch.nn.functional as F
from torch import nn


In [None]:
# qlib

In [None]:
def make_quaternion_mul(kernel, concat_dim=0):
    r, i, j, k = torch.chunk(kernel, 4, dim=-1)
    r2 = torch.cat([r, -i, -j, -k], dim=-1)
    i2 = torch.cat([i, r, -k, j], dim=-1)
    j2 = torch.cat([j, k, r, -i], dim=-1)
    k2 = torch.cat([k, -j, i, r], dim=-1)
    hamilton = torch.cat([r2, i2, j2, k2], dim=concat_dim)
    return hamilton

def get_r(x, dim=1):
    return torch.chunk(x, 4, dim=dim)[0]

def get_i(x, dim=1):
    return torch.chunk(x, 4, dim=dim)[1]

def get_j(x, dim=1):
    return torch.chunk(x, 4, dim=dim)[2]

def get_k(x, dim=1):
    return torch.chunk(x, 4, dim=dim)[3]

def quaternion_attention(a, b):
    """Performs dot product attention between two quaternion sequences."""
    print("light Attention!")
    print(a)
    print(b)
    
    ar, ai, aj, ak = torch.chunk(a, 4, dim=-1)
    br, bi, bj, bk = torch.chunk(b, 4, dim=-1)
    
    r = torch.matmul(ar, br.transpose(-2, -1)) - torch.matmul(ai, bi.transpose(-2, -1)) - torch.matmul(aj, bj.transpose(-2, -1)) - torch.matmul(ak, bk.transpose(-2, -1))
    i = torch.matmul(ar, bi.transpose(-2, -1)) + torch.matmul(ai, br.transpose(-2, -1)) + torch.matmul(aj, bk.transpose(-2, -1)) - torch.matmul(ak, bj.transpose(-2, -1))
    j = torch.matmul(ar, bj.transpose(-2, -1)) - torch.matmul(ai, bk.transpose(-2, -1)) + torch.matmul(aj, br.transpose(-2, -1)) + torch.matmul(ak, bi.transpose(-2, -1))
    k = torch.matmul(ar, bk.transpose(-2, -1)) + torch.matmul(ai, bj.transpose(-2, -1)) - torch.matmul(aj, bi.transpose(-2, -1)) + torch.matmul(ak, br.transpose(-2, -1))
    
    return [r, i, j, k]

def quaternion_dot_product_att(a, b):
    """Wrapper for two sequences."""
    al = a.shape[1]
    bl = b.shape[1]
    d = a.shape[2]
    bsz = b.shape[0]
    
    a = a.view(-1, d)
    a = a.repeat(bl, 1)
    b = b.view(-1, d)
    b = b.repeat(al, 1)
    
    att = quaternion_dot(a, b)
    att = att.view(bsz, -1, al * bl)
    att = torch.sum(att, dim=1)
    
    return att.view(-1, al * bl)

def quaternion_dot_3d(q0, q1):
    d = q0.shape[2]
    sq = q0.shape[1]
    
    q0 = q0.view(-1, d)
    q1 = q1.view(-1, d)
    
    out = quaternion_dot(q0, q1)
    return out.view(-1, sq, d)

def quaternion_dot(q0, q1):
    """Quaternion product between 2 quaternions."""
    q1_r = get_r(q1)
    q1_i = get_i(q1)
    q1_j = get_j(q1)
    q1_k = get_k(q1)
    
    r_base = q0 * q1
    r = get_r(r_base) - get_i(r_base) - get_j(r_base) - get_k(r_base)
    
    i_base = q0 * torch.cat([q1_i, q1_r, q1_k, q1_j], dim=1)
    i = get_r(i_base) + get_i(i_base) + get_j(i_base) - get_k(i_base)
    
    j_base = q0 * torch.cat([q1_j, q1_k, q1_r, q1_i], dim=1)
    j = get_r(j_base) - get_i(j_base) + get_j(j_base) + get_k(j_base)
    
    k_base = q0 * torch.cat([q1_k, q1_j, q1_i, q1_r], dim=1)
    k = get_r(k_base) + get_i(k_base) - get_j(k_base) + get_k(k_base)
    
    return torch.cat([r, i, j, k], dim=1)

def quaternion_concat(x, dim):
    """Concatenates quaternion components individually."""
    output = [[] for _ in range(4)]
    for _x in x:
        sp = torch.chunk(_x, 4, dim=dim)
        for i in range(4):
            output[i].append(sp[i])
    
    final = []
    for o in output:
        o = torch.cat(o, dim)
        final.append(o)
    
    return torch.cat(final, dim)

def quaternion_ffn_3d(x, dim, num_layers=1, activation=None):
    """Quaternion Feed-forward layers to 3D input [bsz x seq_len x dim]."""
    print("QFFN layer..")
    _d = x.shape[2]
    sq = x.shape[1]
    
    x = x.view(-1, _d)
    x = quaternion_ffn(x, dim, num_layers=num_layers, activation=activation)
    return x.view(-1, sq, dim)

def quaternion_ffn(x, dim, num_layers=1, activation=None):
    """Implements quaternion feed-forward layer."""
    input_dim = x.shape[1] // 4
    kernel = torch.nn.Parameter(torch.randn(input_dim, dim))
    hamilton = make_quaternion_mul(kernel)
    
    output = torch.matmul(x, hamilton)
    if activation:
        output = activation(output)
    
    return output

def hamilton_product(x, kernel):
    h = make_quaternion_mul(kernel)
    return torch.matmul(x, h)



In [2]:
#check point

In [None]:
def compute_attention_component(antecedent,
                                total_depth,
                                filter_width=1,
                                padding="valid",
                                name="c",
                                vars_3d_num_heads=0):
    input_depth = antecedent.size(-1)
    initializer_stddev = input_depth ** -0.5
    if "q" in name:
        depth_per_head = total_depth
        initializer_stddev *= depth_per_head ** -0.5

    if vars_3d_num_heads > 0:
        assert filter_width == 1
        input_depth = antecedent.size(-1)
        depth_per_head = total_depth // vars_3d_num_heads
        initializer_stddev = input_depth ** -0.5
        if "q" in name:
            initializer_stddev *= depth_per_head ** -0.5
        var = nn.Parameter(torch.randn(input_depth, vars_3d_num_heads, total_depth // vars_3d_num_heads) * initializer_stddev)
        var = var.to(antecedent.dtype)
        var = var.view(input_depth, total_depth)
        return torch.einsum('bld,df->blf', antecedent, var)
    
    if filter_width == 1:
        return quarternion_ffn_3d(antecedent, total_depth, name=name,
                                  init=torch.nn.init.normal_(torch.empty(input_depth, total_depth), mean=0, std=initializer_stddev))
    else:
        return F.conv1d(antecedent.permute(0, 2, 1), 
                        torch.randn(total_depth, input_depth, filter_width) * initializer_stddev, 
                        padding=padding).permute(0, 2, 1)


In [None]:
def compute_qkv(query_antecedent,
                memory_antecedent,
                total_key_depth,
                total_value_depth,
                q_filter_width=1,
                kv_filter_width=1,
                q_padding="valid",
                kv_padding="valid",
                vars_3d_num_heads=0):
    if memory_antecedent is None:
        memory_antecedent = query_antecedent
    q = compute_attention_component(
        query_antecedent,
        total_key_depth,
        q_filter_width,
        q_padding,
        "q",
        vars_3d_num_heads=vars_3d_num_heads)
    k = compute_attention_component(
        memory_antecedent,
        total_key_depth,
        kv_filter_width,
        kv_padding,
        "k",
        vars_3d_num_heads=vars_3d_num_heads)
    v = compute_attention_component(
        memory_antecedent,
        total_value_depth,
        kv_filter_width,
        kv_padding,
        "v",
        vars_3d_num_heads=vars_3d_num_heads)
    return q, k, v

In [None]:
def split_heads(x, num_heads):
    batch_size, length, depth = x.size()
    depth_per_head = depth // num_heads
    x = x.view(batch_size, length, num_heads, depth_per_head)
    return x.permute(0, 2, 1, 3)


In [None]:
def combine_heads(x):
    batch_size, num_heads, length, depth_per_head = x.size()
    return x.permute(0, 2, 1, 3).contiguous().view(batch_size, length, num_heads * depth_per_head)

In [None]:
def dot_product_attention(q, k, v, bias, dropout_rate=0.0, image_shapes=None,
                          save_weights_to=None, make_image_summary=True,
                          dropout_broadcast_dims=None):
    logits = torch.matmul(q, k.transpose(-2, -1))
    if bias is not None:
        logits += bias
    weights = F.softmax(logits, dim=-1)
    if dropout_rate > 0.0:
        weights = F.dropout(weights, p=dropout_rate)
    return torch.matmul(weights, v)

In [None]:
def quaternion_dot_product_attention(q, k, v, bias, dropout_rate=0.0, image_shapes=None,
                                     save_weights_to=None, make_image_summary=True,
                                     dropout_broadcast_dims=None):
    print("Using QDP attention..")
    logits = torch.matmul(q, k.transpose(-2, -1))
    if bias is not None:
        logits += bias
    weights = F.softmax(logits, dim=-1)
    if dropout_rate > 0.0:
        weights = F.dropout(weights, p=dropout_rate)
    return torch.matmul(weights, v)


In [None]:
#def quarternion_ffn_3d(x, output_depth, name='output_transform', init=None):
#    return F.linear(x, init)


In [None]:
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        attention_type="dot_product",
                        max_relative_position=None,
                        heads_share_relative_embedding=False,
                        add_relative_to_values=False,
                        image_shapes=None,
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="valid",
                        kv_padding="valid",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        vars_3d=False,
                        is_training=False,
                        **kwargs):
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of attention heads (%d)." % (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of attention heads (%d)." % (total_value_depth, num_heads))

    vars_3d_num_heads = num_heads if vars_3d else 0
    
    if cache is None or memory_antecedent is None:
        q, k, v = compute_qkv(query_antecedent, memory_antecedent,
                              total_key_depth, total_value_depth, q_filter_width,
                              kv_filter_width, q_padding, kv_padding,
                              vars_3d_num_heads=vars_3d_num_heads)
    if cache is not None:
        if attention_type not in ["dot_product", "dot_product_relative", "quaternion_dot_product"]:
            raise NotImplementedError("Caching is not guaranteed to work with attention types other than dot_product.")
        if bias is None:
            raise ValueError("Bias required for caching. See function docstring for details.")

        if memory_antecedent is not None:
            q = compute_attention_component(query_antecedent, total_key_depth,
                                            q_filter_width, q_padding, "q",
                                            vars_3d_num_heads=vars_3d_num_heads)
            k = cache["k_encdec"]
            v = cache["v_encdec"]
        else:
            k = split_heads(k, num_heads)
            v = split_heads(v, num_heads)
            decode_loop_step = kwargs.get("decode_loop_step")
            if decode_loop_step is None:
                k = cache["k"] = torch.cat([cache["k"], k], dim=2)
                v = cache["v"] = torch.cat([cache["v"], v], dim=2)
            else:
                tmp_k = cache["k"].permute(2, 0, 1, 3)
                tmp_k[decode_loop_step] = k.squeeze(2)
                k = cache["k"] = tmp_k.permute(1, 2, 0, 3)
                tmp_v = cache["v"].permute(2, 0, 1, 3)
                tmp_v[decode_loop_step] = v.squeeze(2)
                v = cache["v"] = tmp_v.permute(1, 2, 0, 3)

    q = split_heads(q, num_heads)
    if cache is None:
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
        q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):
        x = attention_type(q, k, v, **kwargs)
        if isinstance(x, tuple):
            x, additional_returned_value = x
    elif attention_type == "dot_product":
        x = dot_product_attention(q, k, v, bias, dropout_rate, image_shapes,
                                  save_weights_to=save_weights_to,
                                  make_image_summary=make_image_summary,
                                  dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == 'quaternion_dot_product':
        print("Using QDP attention..")
        x = quaternion_dot_product_attention(q, k, v, bias, dropout_rate, image_shapes,
                                             save_weights_to=save_weights_to,
                                             make_image_summary=make_image_summary,
                                             dropout_broadcast_dims=dropout_broadcast_dims)

    x = combine_heads(x)
    x = quarternion_ffn_3d(x, output_depth, name='output_transform',
                           init=torch.nn.init.normal_(torch.empty(output_depth, output_depth), mean=0, std=output_depth ** -0.5))

    if additional_returned_value is not None:
        return x, additional_returned_value
    return x


In [33]:

# Testing the updated multihead_attention function
query_antecedent = torch.randn(1, 2, 4)
memory_antecedent = torch.randn(1, 8, 4)
bias = None
total_key_depth = 4
total_value_depth = 4
output_depth = 4
num_heads = 2
dropout_rate = 0.1
attention_type = "quaternion_dot_product"
vars_3d = False

output = multihead_attention(query_antecedent, memory_antecedent, bias,
                             total_key_depth, total_value_depth, output_depth,
                             num_heads, dropout_rate, attention_type, vars_3d=vars_3d)
print(output.shape)


Using QDP attention..
Using QDP attention..
torch.Size([1, 2, 4])


In [34]:
output

tensor([[[-0.2630, -0.1027, -0.0266, -0.0309],
         [ 0.3399,  0.1480,  0.0696,  0.0332]]])