In [27]:
# Code borrowed from: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/attention.py

import torch
from torch import nn, einsum
from einops import rearrange, repeat

def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing;
    only ``None`` itself yields ``False``.
    """
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    Args:
        val: The preferred value. Anything other than ``None`` (including
            falsy values such as ``0``) is returned unchanged.
        d: The fallback. If it is a plain Python function (``def`` or
            ``lambda``) it is called with no arguments and its result is
            returned, so expensive fallbacks are only built when needed.
            Any other object, including classes and other callables, is
            returned as-is.

    Returns:
        ``val`` when it is not ``None``; otherwise ``d()`` or ``d`` as
        described above.
    """
    # Imported locally because this file never imports `isfunction` at the
    # top level; without it every fallback path raised NameError.
    from inspect import isfunction

    # Same check as the `exists` helper, inlined so this function is
    # self-contained.
    if val is not None:
        return val
    return d() if isfunction(d) else d

class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention of ``x`` over a context sequence.

    Queries are projected from ``x``; keys and values are projected from
    ``context``. When no context is supplied, ``x`` is used for both.

    Args:
        query_dim: Feature width of the query input and of the output.
        context_dim: Feature width of the context; falls back to
            ``query_dim`` when ``None``.
        heads: Number of attention heads.
        dim_head: Feature width of each head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        # Keys/values share the query width unless a context width is given.
        kv_dim = default(context_dim, query_dim)
        # Width of the concatenated per-head projections.
        proj_dim = dim_head * heads

        self.heads = heads
        # 1 / sqrt(d_k) factor applied to the QK^T similarities.
        self.scale = dim_head ** -0.5

        # Bias-free input projections, created in q, k, v order.
        self.to_q = nn.Linear(query_dim, proj_dim, bias=False)
        self.to_k = nn.Linear(kv_dim, proj_dim, bias=False)
        self.to_v = nn.Linear(kv_dim, proj_dim, bias=False)

        # Output head: map back to the query width, then apply dropout.
        output_layers = [nn.Linear(proj_dim, query_dim), nn.Dropout(dropout)]
        self.to_out = nn.Sequential(*output_layers)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` and return the projected result.

        Args:
            x: Query input, laid out as (batch, length, query_dim).
            context: Key/value input, laid out as (batch, length, context_dim);
                ``x`` is used when this is ``None``.
            mask: Optional mask whose leading axis is the batch; all trailing
                axes are flattened into the key axis. Positions where the
                mask is False are suppressed. Presumably a boolean tensor,
                since it is inverted with ``~`` -- verify against callers.

        Returns:
            The output of ``to_out`` applied to the merged heads, laid out as
            (batch, query length, query_dim).

        Note:
            Intermediate tensor shapes are printed on every call.
        """
        num_heads = self.heads

        # Linear projections; keys and values both come from the context.
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        print('q.shape :', q.shape, 'k.shape :', k.shape, 'v.shape :', v.shape)

        # Fold the head axis into the batch axis.
        # b: batch, n: length, h: head, d: per-head dim
        def split_heads(t):
            return rearrange(t, 'b n (h d) -> (b h) n d', h=num_heads)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        print('q.shape :', q.shape, 'k.shape :', k.shape, 'v.shape :', v.shape)

        # Scaled similarity: QK^T / sqrt(d_k).
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        print('sim.shape :', sim.shape)

        if mask is not None:
            # Collapse every non-batch axis of the mask: (b, ...) -> (b, n).
            flat_mask = rearrange(mask, 'b ... -> b (...)')
            # Most negative finite value of sim's dtype (not a literal -inf).
            fill_value = -torch.finfo(sim.dtype).max
            # One copy per head, broadcast over queries: (b, n) -> (b*h, 1, n).
            head_mask = repeat(flat_mask, 'b j -> (b h) () j', h=num_heads)
            # Overwrite, in place, the similarities where the mask is False.
            sim.masked_fill_(~head_mask, fill_value)

        # Normalise over the key axis, then take the weighted sum of values.
        attn = sim.softmax(dim=-1)
        weighted = einsum('b i j, b j d -> b i d', attn, v)

        # Unfold heads out of the batch axis and concatenate them again.
        merged = rearrange(weighted, '(b h) n d -> b n (h d)', h=num_heads)
        return self.to_out(merged)

In [28]:
# Build the layer under test: 8 heads of width 64, i.e. a 512-wide inner projection.
cross_attention = CrossAttention(query_dim=512, context_dim=512, heads=8, dim_head=64, dropout=0.)

# Lengths of the query input (x) and of the context; they differ so the two
# sequence axes can be told apart in the printed shapes.
source_length = 10
target_length = 20
# Random inputs: batch size 2, feature width 512 (matches query_dim / context_dim above).
x = torch.randn(2, source_length, 512)
context = torch.randn(2, target_length, 512)
# x supplies the queries; context supplies the keys and values.
y = cross_attention(x, context)
# The result keeps x's batch size and length, with query_dim as the last axis.
print(y.shape)

q.shape : torch.Size([2, 10, 512]) k.shape : torch.Size([2, 20, 512]) v.shape : torch.Size([2, 20, 512])
q.shape : torch.Size([16, 10, 64]) k.shape : torch.Size([16, 20, 64]) v.shape : torch.Size([16, 20, 64])
sim.shape : torch.Size([16, 10, 20])
torch.Size([2, 10, 512])
