In [53]:
import torch
from functools import partial
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F

from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# regular attention

def attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    **kwargs
):
    scale = q.shape[-1] ** -0.5
    q = q * scale
    
    print(q.shape, k.shape)
    print(q[0][0][0].dot(k[0][0][0]))
    print(q[0][0][0].dot(k[0][0][1]))
    print(q[0][0][1].dot(k[0][0][0]))
    print(q[0][0][1].dot(k[0][0][1]))
    sim = einsum('b h i d, b h j d -> b h i j', q, k)
    print(sim, sim.shape, sim[0][0][0], F.softmax(sim[0][0][0]))
    # sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    attn = sim.softmax(dim = -1)
    print(attn)

    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    return out

# main class

class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False,
        memory_efficient = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.dropout = dropout
        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # memory efficient attention related parameters
        # can be overriden on forward
        self.memory_efficient = memory_efficient
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None,
        memory_efficient = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        memory_efficient = default(memory_efficient, self.memory_efficient)
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        attn_fn = attention

        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, 
                    k_bucket_size = k_bucket_size, dropout = self.dropout, training = self.training)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
cross_attn = Attention(
    dim = 8,
    dim_head = 8,
    heads = 1,
    memory_efficient = False,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 3, 8).cuda()
context = torch.randn(1, 3, 8).cuda()

out = cross_attn(x, context = context) # (1, 65536, 512)

torch.Size([1, 1, 3, 8]) torch.Size([1, 1, 3, 8])
tensor(-0.3593, device='cuda:0', grad_fn=<DotBackward0>)
tensor(-0.1886, device='cuda:0', grad_fn=<DotBackward0>)
tensor(0.2340, device='cuda:0', grad_fn=<DotBackward0>)
tensor(0.4919, device='cuda:0', grad_fn=<DotBackward0>)
tensor([[[[-0.3593, -0.1886, -0.2142],
          [ 0.2340,  0.4919, -0.2945],
          [ 0.1237,  0.2328, -0.0557]]]], device='cuda:0',
       grad_fn=<ViewBackward0>) torch.Size([1, 1, 3, 3]) tensor([-0.3593, -0.1886, -0.2142], device='cuda:0', grad_fn=<SelectBackward0>) tensor([0.2992, 0.3549, 0.3459], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([[[[0.2992, 0.3549, 0.3459],
          [0.3468, 0.4488, 0.2044],
          [0.3389, 0.3779, 0.2832]]]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)


  print(sim, sim.shape, sim[0][0][0], F.softmax(sim[0][0][0]))


torch.Size([1, 1, 2, 8]) torch.Size([1, 1, 2, 8])


RuntimeError: 1D tensors expected, but got 0D and 0D tensors