## Some experiments with different versions of Attention

In [1]:
%load_ext autoreload

In [2]:
%autoreload
import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
# from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

In [19]:
# From annotated_diffusion 
#
class Attention_1(nn.Module):
    """Multi-head softmax self-attention over the pixels of a feature map.

    Queries, keys and values are produced by one bias-free 1x1 convolution;
    a second 1x1 convolution maps the concatenated heads back to ``dim``
    channels.  Shapes and value ranges of the intermediates are printed on
    every forward pass.

    Args:
        dim: Number of input and output channels.
        heads: Number of attention heads.
        dim_head: Number of channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        hidden_dim = heads * dim_head
        print('hidden_dim:', hidden_dim)
        self.to_qkv = nn.Conv2d(dim, 3 * hidden_dim, kernel_size=1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, kernel_size=1)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [b, c, h, w]; same shape out."""
        _, _, height, width = x.shape

        # Split the projection into q, k, v and flatten the spatial axes so
        # each tensor is [b, heads, dim_head, h*w].
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        print('q shape:', q.shape, ', k shape:', k.shape, ', v shape:', v.shape)

        q = q * self.scale

        # Pairwise similarity between every query position i and key position j.
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        print('torch.max(sim):', torch.max(sim), ', torch.min(sim):', torch.min(sim))
        print('q*vt:  sim shape:', sim.shape)
        row_max = sim.amax(dim=-1, keepdim=True)
        print('sim.amax:', row_max.shape)

        # Subtract the per-row maximum before the softmax for numerical
        # stability; it is detached so it does not enter the gradient.
        sim = sim - row_max.detach()
        print('torch.max(sim):', torch.max(sim), ', torch.min(sim):', torch.min(sim))
        attn = torch.softmax(sim, dim=-1)
        print('attn shape:', attn.shape)
        print('torch.min(attn):', torch.min(attn), ', torch.max(attn):', torch.max(attn))

        # Weighted sum of the values, then fold the heads back into channels
        # and restore the spatial layout.
        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        print('out shape:', out.shape)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=height, y=width)
        print('out shape:', out.shape)
        return self.to_out(out)


In [20]:
# From my code
class Attention_2(nn.Module):
    """Multi-head self-attention built from ``nn.Linear`` projections.

    The feature map is group-normalised, flattened so that each pixel is a
    token with ``dim`` features, run through scaled dot-product attention and
    reshaped back to the input layout.  Intermediate shapes are printed on
    every forward pass.

    Args:
        dim: Number of input and output channels.
        num_heads: Number of attention heads.
        dim_head: Number of features per head.
        numgroups: Number of groups for the ``GroupNorm`` applied first.
        dropout: Accepted but unused; no dropout layer is created.
    """

    def __init__(self, dim, num_heads=4, dim_head=32, numgroups=8, dropout=0.):
        super().__init__()
        inner_dim = num_heads * dim_head
        print('inner_dim:', inner_dim)
        # The output projection is only skipped when a single head already
        # has exactly `dim` features.
        needs_projection = num_heads != 1 or dim_head != dim

        self.heads = num_heads
        self.scale = float(dim_head) ** -0.5
        self.attention_norm = nn.GroupNorm(numgroups, dim)
        self.attend = nn.Softmax(dim=-1)
        # Open question from the author: use Conv2d instead of Linear here?
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [b, c, h, w]; same shape out."""
        b, c, h, w = x.shape

        # GroupNorm normalises over channel groups, so it is applied while the
        # tensor is still laid out as [b, c, h*w].
        tokens = self.attention_norm(x.reshape(b, c, h * w))
        print('in_attn shape:', tokens.shape)
        tokens = tokens.transpose(1, 2)
        print('in_attn shape:', tokens.shape)

        # One projection yields q, k and v; each is split into heads.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
            for part in self.to_qkv(tokens).chunk(3, dim=-1)
        )
        print('q shape:', q.shape, ', k shape:', k.shape, ', v shape:', v.shape)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        print('q*vt shape:', dots.shape)
        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        print('out shape:', out.shape)
        out = rearrange(out, 'b h n d -> b n (h d)')
        print('out shape:', out.shape)
        out = self.to_out(out)
        print('out shape:', out.shape)
        # Back from [b, h*w, c] to the original [b, c, h, w] layout.
        out = out.transpose(1, 2).reshape(b, c, h, w)
        print('out shape:', out.shape)
        return out

In [21]:
# Uses pytorch's MultiheadAttention

class Attention_3(nn.Module):
    """Self-attention over feature-map pixels using ``nn.MultiheadAttention``.

    The input is group-normalised, flattened so that every spatial position
    is one token with ``dim`` features, passed through multi-head
    self-attention and reshaped back to the input layout.

    Args:
        dim: Number of input channels; also the attention embedding size.
        num_heads: Number of attention heads.
        numgroups: Number of groups for the ``GroupNorm`` applied first.
        dropout: Dropout probability forwarded to ``nn.MultiheadAttention``.
    """

    def __init__(self, dim, num_heads=4, numgroups=8, dropout=0.):
        super().__init__()
        self.attention_norms = nn.GroupNorm(numgroups, dim)
        self.attentions = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [N, C, h, w]; same shape out."""
        batch_size, channels, h, w = x.shape
        in_attn = x.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norms(in_attn)
        # batch_first=True expects [N, L, E]: here L = h*w is the sequence
        # length and E = C is the embedding dimension.
        in_attn = in_attn.transpose(1, 2)
        # The attention weights are never used, so do not ask the module to
        # build and average the [N, L, L] weight matrix.
        out_attn, _ = self.attentions(in_attn, in_attn, in_attn, need_weights=False)
        return out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)


In [22]:
# From annotated_diffusion

class LinearAttention(nn.Module):
    """Linear-attention variant over the pixels of a feature map.

    Instead of a position-by-position similarity matrix, keys and values are
    first combined into a per-head ``dim_head x dim_head`` context, which is
    then applied to the queries.  The output passes through a 1x1 convolution
    followed by a single-group ``GroupNorm``.  Intermediate shapes are
    printed on every forward pass.

    Args:
        dim: Number of input and output channels.
        heads: Number of attention heads.
        dim_head: Number of channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        hidden_dim = heads * dim_head
        print('hidden_dim:', hidden_dim)
        self.to_qkv = nn.Conv2d(dim, 3 * hidden_dim, kernel_size=1, bias=False)
        project = nn.Conv2d(hidden_dim, dim, kernel_size=1)
        self.to_out = nn.Sequential(project, nn.GroupNorm(1, dim))

    def forward(self, x):
        """Apply linear attention to ``x`` of shape [b, c, h, w]; same shape out."""
        _, _, height, width = x.shape

        # Each of q, k, v becomes [b, heads, dim_head, h*w].
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        print('q shape:', q.shape, ', k shape:', k.shape, ', v shape:', v.shape)

        # Queries are normalised across the per-head channel axis, keys
        # across the flattened spatial axis.
        q = torch.softmax(q, dim=-2) * self.scale
        k = torch.softmax(k, dim=-1)

        # Summing over positions n first gives a small d x e context per head.
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        print('context shape:', context.shape)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        print('out shape:', out.shape)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=height, y=width)
        print('out shape:', out.shape) 
        out = self.to_out(out)
        print('out shape:', out.shape) 
        return out

In [23]:
# Sample input shared by the attention cells: batch 2, 64 channels, 32x32 pixels.
x = torch.randn(2, 64, 32, 32)
print('in x shape:', x.shape)

in x shape: torch.Size([2, 64, 32, 32])


In [24]:
# Run Attention_1 on the sample tensor and print the shape of the result.
attn1 = Attention_1(dim=64, heads=4, dim_head=32)
out = attn1(x)
print(out.shape)

hidden_dim: 128
q shape: torch.Size([2, 4, 32, 1024]) , k shape: torch.Size([2, 4, 32, 1024]) , v shape: torch.Size([2, 4, 32, 1024])
torch.max(sim): tensor(2.0663, grad_fn=<MaxBackward1>) , torch.min(sim): tensor(-2.1919, grad_fn=<MinBackward1>)
q*vt:  sim shape: torch.Size([2, 4, 1024, 1024])
sim.amax: torch.Size([2, 4, 1024, 1])
torch.max(sim): tensor(0., grad_fn=<MaxBackward1>) , torch.min(sim): tensor(-4.1100, grad_fn=<MinBackward1>)
attn shape: torch.Size([2, 4, 1024, 1024])
torch.min(attn): tensor(9.2696e-05, grad_fn=<MinBackward1>) , torch.max(attn): tensor(0.0070, grad_fn=<MaxBackward1>)
out shape: torch.Size([2, 4, 1024, 32])
out shape: torch.Size([2, 128, 32, 32])
torch.Size([2, 64, 32, 32])


In [25]:
# Run Attention_2 on the sample tensor and print the shape of the result.
attn2 = Attention_2(dim=64, num_heads=4, dim_head=32)
out = attn2(x)
print(out.shape)

inner_dim: 128
in_attn shape: torch.Size([2, 64, 1024])
in_attn shape: torch.Size([2, 1024, 64])
q shape: torch.Size([2, 4, 1024, 32]) , k shape: torch.Size([2, 4, 1024, 32]) , v shape: torch.Size([2, 4, 1024, 32])
q*vt shape: torch.Size([2, 4, 1024, 1024])
out shape: torch.Size([2, 4, 1024, 32])
out shape: torch.Size([2, 1024, 128])
out shape: torch.Size([2, 1024, 64])
out shape: torch.Size([2, 64, 32, 32])
torch.Size([2, 64, 32, 32])


In [17]:
# Run Attention_3 on the sample tensor and print the shape of the result.
attn3 = Attention_3(dim=64, num_heads=4)
out = attn3(x)
print(out.shape)

torch.Size([2, 64, 32, 32])


In [18]:
# Run LinearAttention on the sample tensor and print the shape of the result.
attn4 = LinearAttention(dim=64, heads=4, dim_head=32)
out = attn4(x)
print(out.shape)

hidden_dim: 128
q shape: torch.Size([2, 4, 32, 1024]) , k shape: torch.Size([2, 4, 32, 1024]) , v shape: torch.Size([2, 4, 32, 1024])
context shape: torch.Size([2, 4, 32, 32])
out shape: torch.Size([2, 4, 32, 1024])
out shape: torch.Size([2, 128, 32, 32])
out shape: torch.Size([2, 64, 32, 32])
torch.Size([2, 64, 32, 32])


In [None]:
# Manual, stage-by-stage attention computation on a larger tensor: prints the
# shape after every step and compares weight-initialisation statistics for
# the qkv projection.
dim = 256
heads = 4
dim_head = 128
inner_dim = dim_head * heads
numgroups = 8

x = torch.randn([2, 256, 32, 32])
print('in x shape:', x.shape)

b, c, h, w = x.shape
norm = nn.GroupNorm(numgroups, dim)
# Use the *normalised* tensor, and move channels last with a transpose:
# reshaping [b, c, h, w] straight to [b, h*w, c] would only reinterpret the
# memory layout and mix channel and pixel values.
in_attn = norm(x).reshape(b, c, h * w)
in_attn = in_attn.transpose(1, 2)  # [b, (h*w), c] i.e. [b, seq, emb_dim]
print('in_attn shape:', in_attn.shape)


to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
# Statistics of the default nn.Linear initialisation.
print('to_qkv, mean:', torch.mean(to_qkv.weight.data), ', std:', torch.std(to_qkv.weight.data))
# Hand-written Xavier-style std, treating the layer as one dim -> inner_dim
# projection.  math.sqrt is used because numpy is not imported in this notebook.
nn.init.normal_(to_qkv.weight.data, mean=0., std=math.sqrt(2 / (dim + inner_dim)))
print('to_qkv, mean:', torch.mean(to_qkv.weight.data), ', std:', torch.std(to_qkv.weight.data))
# Built-in Xavier initialisation, which derives its fans from the full
# stacked (3 * inner_dim, dim) weight.
nn.init.xavier_normal_(to_qkv.weight.data)
print('to_qkv, mean:', torch.mean(to_qkv.weight.data), ', std:', torch.std(to_qkv.weight.data))


qkv = to_qkv(in_attn)
print('out shape:', qkv.shape)

qkv = qkv.chunk(3, dim=-1)
print('q shape:', qkv[0].shape)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=heads), qkv)
print('q shape:', q.shape)


# NOTE: the raw q @ k^T scores are used below without scaling or softmax;
# this cell only inspects tensor shapes.
dots = torch.matmul(q, k.transpose(-1, -2))
print('dots shape:', dots.shape)

out = torch.matmul(dots, v)
print('1 out shape:', out.shape)

out = rearrange(out, 'b h n d -> b n (h d)')
print('2 out shape:', out.shape)

to_out = nn.Linear(inner_dim, dim)

out = to_out(out)
print('3 out shape:', out.shape)

# Back from [b, h*w, c] to the original [b, c, h, w] layout.
out = out.transpose(1, 2).reshape(b, c, h, w)
print('4 out shape:', out.shape)
