In [2]:
import torch

In [3]:
# Seed the global RNG so every random tensor created below is reproducible
# under Restart Kernel -> Run All.
torch.manual_seed(0)

n = 10  # size of the second axis of every (d, n) tensor in this notebook
d = 64  # size of the first axis; later viewed as (sqrt_d, sqrt_d), so it must be a perfect square

# Integer square root in plain Python, verified explicitly: a non-square d
# fails right here with a clear message instead of at a later .view() call.
sqrt_d = round(d ** 0.5)
if sqrt_d * sqrt_d != d:
    raise ValueError(f"d must be a perfect square, got {d}")

# Downsampling

In [4]:
# Random stand-in input: uniform samples laid out as (d, n).
data = torch.rand((d, n))
data.size()

torch.Size([64, 10])

In [5]:
# Uniform-random downsampling weights with shape (sqrt_d, sqrt_d, n).
weights = torch.rand((sqrt_d, sqrt_d, n))
weights.size()

torch.Size([8, 8, 10])

In [7]:
# Split the first axis (length d) into two axes of length sqrt_d each.
# .view() reinterprets the same storage; no data is copied.
data_reshaped = data.view((sqrt_d, sqrt_d, n))
data_reshaped.size()

torch.Size([8, 8, 10])

In [8]:
# Element-wise weighting of the reshaped data by the downsampling weights.
product = torch.mul(data_reshaped, weights)
product.size()

torch.Size([8, 8, 10])

In [9]:
# Sum over axis 1. keepdim=True leaves a length-1 axis behind, which is
# printed first and then removed with squeeze.
updated_product = product.sum(dim=1, keepdim=True)
print(updated_product.size())
updated_product = updated_product.squeeze(dim=1)
print(updated_product.size())  # the (d, n) input is now reduced to (sqrt_d, n)

torch.Size([8, 1, 10])
torch.Size([8, 10])


# Upsampling

In [12]:
# Independent copy of the downsampled tensor, used as the input to upsampling.
attention_output = torch.clone(updated_product)
print(attention_output.size())

torch.Size([8, 10])


In [13]:
# Upsampling weights of shape (sqrt_d, sqrt_d, n), drawn from a standard normal.
up_weights = torch.randn((sqrt_d, sqrt_d, n))
print(up_weights.size())

torch.Size([8, 8, 10])


In [15]:
# Flatten the (sqrt_d, n) tensor into a single row of sqrt_d * n values.
attention_output_reshaped = attention_output.reshape(1, -1)
print(attention_output_reshaped.size())
# Stack sqrt_d copies of that whole flattened row on top of each other.
attention_output_reshaped = attention_output_reshaped.repeat((sqrt_d, 1))
print(attention_output_reshaped.size())
# Unfold into the layout of up_weights. After this step element [i, j, :]
# equals attention_output[j, :], i.e. the full matrix is tiled along axis 0.
# NOTE(review): the downsampling step reduced axis 1, so an inverse would
# presumably want [i, j, :] == attention_output[i, :] instead -- confirm
# which layout is intended.
attention_output_reshaped = attention_output_reshaped.view_as(up_weights)
print(attention_output_reshaped.size())

torch.Size([1, 80])
torch.Size([8, 80])
torch.Size([8, 8, 10])


In [16]:
# Scale the expanded tensor element-wise by the upsampling weights.
revived_output = torch.mul(up_weights, attention_output_reshaped)
revived_output.size()

torch.Size([8, 8, 10])

# Quick Attention

In [22]:
# Standard-normal query and key tensors, each of shape (d, n).
# The tuple is evaluated left to right, so q is sampled before k.
q, k = torch.randn(d, n), torch.randn(d, n)
print(f'q: {q.shape}')
print(f'k: {k.shape}')

q: torch.Size([64, 10])
k: torch.Size([64, 10])


In [23]:
# Sum the key over axis 1, keeping that axis with length 1 -> (d, 1).
collective_k = torch.sum(k, dim=1, keepdim=True)
collective_k.size()

torch.Size([64, 1])

In [24]:
# Materialise the broadcast: copy the single column n times along axis 1,
# giving a (d, n) tensor.
collective_k_bc = collective_k.repeat((1, n))
collective_k_bc.size()

torch.Size([64, 10])

In [25]:
# Element-wise (not matrix) product of the query with the summed key.
qk = torch.mul(q, collective_k_bc)
qk.size()

torch.Size([64, 10])

In [26]:
# Softmax along axis 1, so each of the d rows sums to 1.
attention_weights = qk.softmax(dim=1)
attention_weights.size()

torch.Size([64, 10])

In [27]:
# Standard-normal value tensor of shape (d, n).
v = torch.randn((d, n))
v.size()

torch.Size([64, 10])

In [29]:
# Sum the value tensor over axis 1, keeping that axis with length 1 -> (d, 1).
collective_v = torch.sum(v, dim=1, keepdim=True)
collective_v.size()

torch.Size([64, 1])

In [30]:
# Explicit broadcast of the summed value: copy the column n times -> (d, n).
collective_v_bc = collective_v.repeat((1, n))
collective_v_bc.size()

torch.Size([64, 10])

In [32]:
# Weight the broadcast summed value element-wise by the attention weights.
output = torch.mul(collective_v_bc, attention_weights)
print(output.size())

torch.Size([64, 10])
