# MHA attention mechanism applied over the activations of the previous layer

For the sake of these experiments, I assume that all layers have the same number of neurons. This way we can apply the MHA directly, without sampling the activations. This assumption simplifies the implementation and the interpretation of the results. However, it would still be possible to apply this mechanism when the number of neurons changes between layers.

In [57]:
import torch
import torch.nn as nn
import time



In [59]:
# Sanity check: attention evaluated on the full 4-D tensors (all samples and
# heads in one batched matmul) must agree, for sample 0 / head 0, with the
# same computation carried out on that single 2-D slice.

# shape: (batch_size, num_heads, num_layer, activation_size)
attn = torch.randn(16, 10, 8, 8)

# activations in the current layer
a = torch.randn(16, 10, 8, 1)

print()
print('Without softmax: ')
# batched over every sample and head, then keep sample 0 / head 0
print(
    torch.matmul(attn, attn.transpose(-1, -2))
    .div(torch.sqrt(torch.tensor(8)))
    .matmul(a)
    .squeeze(-1)[0, 0]
)

# the same product restricted to the (sample 0, head 0) slice
print(
    attn[0, 0]
    .matmul(attn[0, 0].T)
    .div(torch.sqrt(torch.tensor(8)))
    .matmul(a[0, 0])
    .squeeze(-1)
)


print()
print('With softmax: ')
# same comparison, with a softmax over the last dimension of the scores
print(
    torch.softmax(
        torch.matmul(attn, attn.transpose(-1, -2)) / torch.sqrt(torch.tensor(8)),
        dim=-1,
    )
    .matmul(a)
    .squeeze(-1)[0, 0]
)

print(
    torch.softmax(
        attn[0, 0].matmul(attn[0, 0].transpose(-1, -2)) / torch.sqrt(torch.tensor(8)),
        dim=-1,
    )
    .matmul(a[0, 0])
    .squeeze(-1)
)




Without softmax: 
tensor([-2.8145,  9.0366,  3.9772,  3.2464,  1.7831,  4.6573,  4.3377,  1.5610])
tensor([-2.8145,  9.0366,  3.9772,  3.2464,  1.7831,  4.6573,  4.3377,  1.5610])

With softmax: 
tensor([0.3660, 1.0815, 1.0566, 1.1317, 0.9139, 0.5136, 0.7670, 0.5848])
tensor([0.3660, 1.0815, 1.0566, 1.1317, 0.9139, 0.5136, 0.7670, 0.5848])


In [90]:
# Toy example of the final step of the attention: the per-head outputs
# (dimension 1) are flattened into a single row per sample.
a = torch.randn(1, 4, 2)

print(a, a.shape, sep="\n")

print(a.reshape(1, a.shape[1] * a.shape[2]))

tensor([[[ 1.1409, -1.6327],
         [-0.9834, -1.2864],
         [-0.5916,  0.3915],
         [-0.3678,  0.4640]]])
torch.Size([1, 4, 2])
tensor([[ 1.1409, -1.6327, -0.9834, -1.2864, -0.5916,  0.3915, -0.3678,  0.4640]])


In [91]:
def head_batched_attention_mechanism(Q, K, V):
    """Scaled dot-product attention evaluated for every head in one batch.

    Computes ``softmax(Q @ K^T / sqrt(d)) @ V`` where ``d`` is the size of
    the last dimension of ``Q``. The softmax runs over the last dimension of
    the score matrix.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, num_layer, 1), the activations of the
            current layer as a column vector. Its second-to-last dimension
            must match ``num_layer`` because the score matrix is
            (num_layer, num_layer); the notebook examples use
            ``num_layer == activation_size``.

    Returns:
        Tensor of shape (batch_size, num_heads, num_layer).
    """
    # Scale by the actual feature size instead of a hard-coded 8 so the
    # function stays correct for any activation size.
    scale = Q.size(-1) ** 0.5

    # (batch_size, num_heads, num_layer, num_layer)
    scores = torch.matmul(Q, K.transpose(-1, -2)) / scale
    weights = torch.softmax(scores, dim=-1)

    # (batch_size, num_heads, num_layer, 1) -> (batch_size, num_heads, num_layer)
    return torch.matmul(weights, V).squeeze(-1)



# Smoke test: run the head-batched attention on random data, flatten the
# heads and project the 10 * 8 values of every sample back to 8 features.

# shape: (batch_size, num_heads, num_layer, activation_size)
attn = torch.randn(16, 10, 8, 8)

# activations in the current layer
a = torch.randn(16, 10, 8, 1)

nn.Linear(10 * 8, 8)(
    head_batched_attention_mechanism(attn, attn, a).flatten(start_dim=1)
).shape

torch.Size([16, 8])

In [98]:
def head_batched_attention_mechanism(Q, K, V):
    """Scaled dot-product attention evaluated for every head in one batch.

    Computes ``softmax(Q @ K^T / sqrt(d)) @ V`` where ``d`` is the size of
    the last dimension of ``Q``. The softmax runs over the last dimension of
    the score matrix.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, num_layer, 1), the activations of the
            current layer as a column vector. Its second-to-last dimension
            must match ``num_layer`` because the score matrix is
            (num_layer, num_layer); the notebook examples use
            ``num_layer == activation_size``.

    Returns:
        Tensor of shape (batch_size, num_heads, num_layer).
    """
    # Scale by the actual feature size instead of a hard-coded 8 so the
    # function stays correct for any activation size.
    scale = Q.size(-1) ** 0.5

    # (batch_size, num_heads, num_layer, num_layer)
    scores = torch.matmul(Q, K.transpose(-1, -2)) / scale
    weights = torch.softmax(scores, dim=-1)

    # (batch_size, num_heads, num_layer, 1) -> (batch_size, num_heads, num_layer)
    return torch.matmul(weights, V).squeeze(-1)



class MultiHeadAttention(nn.Module):
    """Attention over the stored activations of earlier layers.

    ``Q`` and ``K`` are linearly projected, replicated once per head and
    passed, together with the projected ``V``, to
    ``head_batched_attention_mechanism``. The per-head results are flattened
    and mapped back to ``dim_emb`` features by ``W_O``.

    NOTE(review): every head receives the same inputs and the same
    ``W_Q`` / ``W_K`` / ``W_V`` projections, so all heads produce identical
    outputs and ``W_O`` sees ``n_head`` copies of one result. Giving each
    head its own projection would change the parameter shapes, so it is left
    as is here.

    Args:
        dim_emb: size of the activation vectors (last dimension of Q/K/V).
        n_head: number of heads; ``dim_emb`` must be divisible by it.
    """

    def __init__(self, dim_emb, n_head) -> None:
        super().__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        self.dim_emb = dim_emb
        self.n_head = n_head

        self.W_Q = nn.Linear(dim_emb, dim_emb)
        self.W_K = nn.Linear(dim_emb, dim_emb)
        self.W_V = nn.Linear(dim_emb, dim_emb)

        # the heads are concatenated before the output projection
        self.W_O = nn.Linear(dim_emb * n_head, dim_emb)

    def forward(self, Q, K, V):
        """Apply the attention.

        Args:
            Q: (batch_size, num_layer, activation_size)
            K: (batch_size, num_layer, activation_size)
            V: (batch_size, activation_size), activations of the current layer.

        Returns:
            Tensor of shape (batch_size, dim_emb).
        """
        # Only the first and last sizes are needed; the middle one is the
        # number of stored layers.
        batch_size, _, activation_size = Q.size()

        # The former ``Q.to(device)`` / ``K.to(device)`` / ``V.to(device)``
        # calls discarded their result (``Tensor.to`` is not in-place) and
        # were no-ops, so they are dropped. Inputs must already live on the
        # same device as this module's parameters.

        # Project once, then replicate across heads with a zero-copy view.
        # This yields the same values as stacking n_head copies of the input
        # and projecting each copy.
        # (batch_size, num_heads, num_layer, activation_size)
        q = self.W_Q(Q).unsqueeze(1).expand(-1, self.n_head, -1, -1)
        k = self.W_K(K).unsqueeze(1).expand(-1, self.n_head, -1, -1)
        # (batch_size, num_heads, activation_size, 1)
        v = (
            self.W_V(V)
            .reshape(batch_size, 1, activation_size, 1)
            .expand(-1, self.n_head, -1, -1)
        )

        # apply attention mechanism and concatenate the heads
        out_attention = head_batched_attention_mechanism(q, k, v).reshape(
            batch_size, self.n_head * activation_size
        )

        # apply linear transformation
        return self.W_O(out_attention)




# Smoke test: Q and K are (16, 8, 8), V is (16, 8); the displayed value is the output shape.
MultiHeadAttention(8, 4)(torch.randn(16, 8, 8), torch.randn(16, 8, 8), torch.randn(16, 8)).shape

torch.Size([16, 8])

torch.Size([16, 10, 8, 8])

In [None]:
# Bare reference to the function object; it is not called, so this line computes nothing.
torch.stack

tensor([[[-0.2142, -0.2232],
         [ 0.4802, -0.6151],
         [ 1.0031,  0.8155],
         [ 1.4394,  0.9198]]])
torch.Size([1, 4, 2])
tensor([[-0.2142, -0.2232,  0.4802, -0.6151,  1.0031,  0.8155,  1.4394,  0.9198]])


In [41]:
def head_batched_attention_mechanism(Q, K, V):
    """Feature-wise attention evaluated for every head in one batch.

    Computes ``softmax(Q^T @ K / sqrt(d)) @ V`` where ``d`` is the size of
    the last dimension of ``Q``. The ``num_layer`` axis is contracted, so
    the score matrix is (activation_size, activation_size) whatever the
    number of stored layers. The softmax runs over its last dimension.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, activation_size, 1) # activations in the current layer

    Returns:
        attention: (batch_size, num_heads, activation_size)
    """
    # Scale by the actual activation size instead of a hard-coded 8 so the
    # function stays correct for any layer width.
    scale = Q.size(-1) ** 0.5

    # (batch_size, num_heads, activation_size, activation_size)
    scores = torch.matmul(Q.transpose(-1, -2), K) / scale
    weights = torch.softmax(scores, dim=-1)

    # (batch_size, num_heads, activation_size, 1) -> (batch_size, num_heads, activation_size)
    return torch.matmul(weights, V).squeeze(-1)



class MultiHeadAttention(nn.Module):
    """Attention over the stored activations of earlier layers.

    ``Q`` and ``K`` are linearly projected, replicated once per head and
    passed, together with the projected ``V``, to
    ``head_batched_attention_mechanism``. The per-head results are flattened
    and mapped back to ``dim_emb`` features by ``W_O``.

    NOTE(review): every head receives the same inputs and the same
    ``W_Q`` / ``W_K`` / ``W_V`` projections, so all heads produce identical
    outputs and ``W_O`` sees ``n_head`` copies of one result. Giving each
    head its own projection would change the parameter shapes, so it is left
    as is here.

    Args:
        dim_emb: size of the activation vectors (last dimension of Q/K/V).
        n_head: number of heads; ``dim_emb`` must be divisible by it.
    """

    def __init__(self, dim_emb, n_head) -> None:
        super().__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        self.dim_emb = dim_emb
        self.n_head = n_head

        self.W_Q = nn.Linear(dim_emb, dim_emb)
        self.W_K = nn.Linear(dim_emb, dim_emb)
        self.W_V = nn.Linear(dim_emb, dim_emb)

        # the heads are concatenated before the output projection
        self.W_O = nn.Linear(dim_emb * n_head, dim_emb)

    def forward(self, Q, K, V):
        """Apply the attention.

        Args:
            Q: (batch_size, num_layer, activation_size)
            K: (batch_size, num_layer, activation_size)
            V: (batch_size, activation_size), activations of the current layer.

        Returns:
            Tensor of shape (batch_size, dim_emb).
        """
        # Only the first and last sizes are needed; the middle one is the
        # number of stored layers.
        batch_size, _, activation_size = Q.size()

        # The former ``Q.to(device)`` / ``K.to(device)`` / ``V.to(device)``
        # calls discarded their result (``Tensor.to`` is not in-place) and
        # were no-ops, so they are dropped. Inputs must already live on the
        # same device as this module's parameters.

        # Project once, then replicate across heads with a zero-copy view.
        # This yields the same values as stacking n_head copies of the input
        # and projecting each copy.
        # (batch_size, num_heads, num_layer, activation_size)
        q = self.W_Q(Q).unsqueeze(1).expand(-1, self.n_head, -1, -1)
        k = self.W_K(K).unsqueeze(1).expand(-1, self.n_head, -1, -1)
        # (batch_size, num_heads, activation_size, 1)
        v = (
            self.W_V(V)
            .reshape(batch_size, 1, activation_size, 1)
            .expand(-1, self.n_head, -1, -1)
        )

        # apply attention mechanism and concatenate the heads
        out_attention = head_batched_attention_mechanism(q, k, v).reshape(
            batch_size, self.n_head * activation_size
        )

        # apply linear transformation
        return self.W_O(out_attention)
    



def get_activations_per_object(activations):
    """Regroup layer-major activations so that the sample axis comes first.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations);
        entry ``[i, l]`` holds the activations of sample ``i`` at layer ``l``.

    """
    # Swapping the first two axes replaces the former per-sample Python loop
    # followed by torch.stack; unlike that loop it also accepts an empty
    # batch. ``contiguous`` gives a densely laid-out result, as the stacked
    # version did.
    return activations.transpose(0, 1).contiguous()

def get_layer_activations(activations):
    """Stack per-layer activations into one tensor with the sample axis first.

    Args:
        activations (sequence of torch.Tensor): one tensor per layer, each of
            shape (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations)

    """
    # Stacking along dim=1 puts the layer axis directly behind the batch
    # axis. This gives the same result as stacking along dim 0 and then
    # swapping the first two axes.
    return torch.stack(activations, dim=1)


# # I have the batched activations for each layer for each sample
# layers = torch.stack([torch.randn(batch_size, number_activations) for _ in range(10)])
# print(layers.shape, '(num_layers, batch_size, number_activations)')

# # I get the activations for each layer for each sample
# obj_activations = torch.stack([layers[:,i,:] for i in range(layers.shape[1])])
# print(obj_activations.shape, '(nr_object, num_layers, activation_for_each_layer)')


# Smoke test of the full pipeline on random data.
batch_size = 16
number_activations = 8

# activations of the current layer, one row per sample
a = torch.randn(batch_size, number_activations)


# ten random "layers" of activations, regrouped per sample by get_layer_activations
attn = get_layer_activations([torch.randn(batch_size, number_activations) for _ in range(10)])


# the stored activations act as both queries and keys, the current activations as values
MultiHeadAttention(8, 4)(attn, attn, a)

# list(extract_activations_per_sample(
#             extract_activations_layers(layers), 
#             mask=False
#         )
# )

torch.Size([16, 8])

In [26]:
# Ten random per-layer activation tensors; uses batch_size and
# number_activations defined in an earlier cell.
activations = [torch.randn(batch_size, number_activations) for _ in range(10)]

# displayed value: the shape of a single layer's activations
activations[0].size()

torch.Size([16, 8])

In [49]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F


####################################################################

def get_activations_per_object(activations):
    """Regroup layer-major activations so that the sample axis comes first.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations);
        entry ``[i, l]`` holds the activations of sample ``i`` at layer ``l``.

    """
    # Swapping the first two axes replaces the former per-sample Python loop
    # followed by torch.stack; unlike that loop it also accepts an empty
    # batch. ``contiguous`` gives a densely laid-out result, as the stacked
    # version did.
    return activations.transpose(0, 1).contiguous()

def get_layer_activations(activations):
    """Stack per-layer activations into one tensor with the sample axis first.

    Args:
        activations (sequence of torch.Tensor): one tensor per layer, each of
            shape (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations)

    """
    # Stacking along dim=1 puts the layer axis directly behind the batch
    # axis. This gives the same result as stacking along dim 0 and then
    # swapping the first two axes.
    return torch.stack(activations, dim=1)

####################################################################


class MLPWD(nn.Module):
    """Small MLP whose two hidden ``LinW`` layers attend over earlier activations.

    The input is flattened to 784 features, mapped to 8 units by ``l1``,
    passed through the two ``LinW`` layers ``l2`` and ``l3`` and finally
    mapped to 10 outputs by ``l4``. Indexing and ``len`` expose the ``LinW``
    layers only.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(784, 8)
        self.l2 = LinW(in_features=8, out_features=8, depth=0)
        self.l3 = LinW(in_features=8, out_features=8, depth=1, layers=[self.l2])
        self.l4 = nn.Linear(8, 10)
        self.gelu = nn.GELU()
        # Plain list used by __getitem__/__len__; both modules are already
        # registered through the l2/l3 attributes.
        self.layers = [self.l2, self.l3]

    def forward(self, x):
        """Return the 10 output values for a batch ``x``.

        The outputs of ``l1`` and ``l2`` are recorded *before* the GELU and
        handed to the following ``LinW`` layer, which attends over them.
        """
        # Named ``history`` rather than ``repr`` so the builtin is not shadowed.
        history = []

        x = self.l1(self.flatten(x))
        history.append(x)

        x = self.l2(self.gelu(x), history)
        history.append(x)

        x = self.gelu(self.l3(self.gelu(x), history))
        return self.l4(x)

    def __getitem__(self, idx):
        return self.layers[idx]

    def __len__(self):
        return len(self.layers)
    
    

class LinW(nn.Linear):
    """Linear layer whose input is first mixed with earlier activations.

    The incoming activations are passed through a ``MultiHeadAttention``
    block (4 heads) that uses the recorded activations of earlier layers as
    queries and keys. The result is then fed to the ordinary linear map of
    this layer.

    Args:
        in_features: input width; also the embedding size of the attention.
        out_features: output width.
        depth: position of this layer among the ``LinW`` layers.
        layers: optional sequence of preceding layers; only the first
            ``depth`` entries are kept. It is stored on the instance but not
            used by ``forward``.
    """

    def __init__(self, in_features, out_features, depth, layers=None):
        super().__init__(in_features=in_features, out_features=out_features)
        self.depth = depth
        # ``None`` replaces the former mutable ``[]`` default, which was one
        # list object shared by every instance created without ``layers``.
        self.layers = [] if layers is None else layers[:self.depth]
        self.mha = MultiHeadAttention(in_features, 4)

    def forward(self, input, activations=None):
        """Attend over ``activations`` and apply the linear map.

        Args:
            input: (batch_size, in_features) activations of the current layer.
            activations: list with one (batch_size, in_features) tensor per
                earlier layer. At least one entry is required because the
                tensors are stacked; an empty list makes the stacking fail.

        Returns:
            Tensor of shape (batch_size, out_features).
        """
        if activations is None:
            activations = []
        activations = get_layer_activations(activations)
        attended = self.mha(activations, activations, input)
        return F.linear(attended, self.weight, self.bias)

# Training script: fit MLPWD on MNIST and report the accuracy after every epoch.
EPOCHS = 10
BATCH_SIZE = 16

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST is downloaded to ./data on first use; images are converted to tensors.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = MLPWD().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# NOTE(review): with step_size=10 and one scheduler step per epoch, the first
# halving of the learning rate falls after the 10th step, i.e. after the last
# of the EPOCHS = 10 epochs - confirm this is intended.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# model[i] / len(model) go through MLPWD.__getitem__ / __len__ and list the LinW layers.
print("LinW layers:", "\n".join([f"Depth {model[i].depth}: {model[i]}" for i in range(len(model))]), sep="\n\n")

# `train` and `evaluate` are not defined in this cell; presumably they come
# from another cell of the notebook - TODO confirm their signatures.
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(device, model, test_loader)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%, Test accuracy: {test_accuracy:.2f}%')
    lr_scheduler.step()


LinW layers:

Depth 0: LinW(
  in_features=8, out_features=8, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=8, out_features=8, bias=True)
    (W_K): Linear(in_features=8, out_features=8, bias=True)
    (W_V): Linear(in_features=8, out_features=8, bias=True)
    (W_O): Linear(in_features=32, out_features=8, bias=True)
  )
)
Depth 1: LinW(
  in_features=8, out_features=8, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=8, out_features=8, bias=True)
    (W_K): Linear(in_features=8, out_features=8, bias=True)
    (W_V): Linear(in_features=8, out_features=8, bias=True)
    (W_O): Linear(in_features=32, out_features=8, bias=True)
  )
)
Epoch 1/10, Training Loss: 0.8250, Training Accuracy: 71.66%, Test accuracy: 85.08%
Epoch 2/10, Training Loss: 0.4606, Training Accuracy: 86.76%, Test accuracy: 87.75%
Epoch 3/10, Training Loss: 0.3847, Training Accuracy: 89.09%, Test accuracy: 89.36%
Epoch 4/10, Training Loss: 0.3413, Training Accuracy: 90.31%, Test

In [50]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F


####################################################################

def get_activations_per_object(activations):
    """Regroup layer-major activations so that the sample axis comes first.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations);
        entry ``[i, l]`` holds the activations of sample ``i`` at layer ``l``.

    """
    # Swapping the first two axes replaces the former per-sample Python loop
    # followed by torch.stack; unlike that loop it also accepts an empty
    # batch. ``contiguous`` gives a densely laid-out result, as the stacked
    # version did.
    return activations.transpose(0, 1).contiguous()

def get_layer_activations(activations):
    """Stack per-layer activations into one tensor with the sample axis first.

    Args:
        activations (sequence of torch.Tensor): one tensor per layer, each of
            shape (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations)

    """
    # Stacking along dim=1 puts the layer axis directly behind the batch
    # axis. This gives the same result as stacking along dim 0 and then
    # swapping the first two axes.
    return torch.stack(activations, dim=1)

####################################################################


class MLPWD(nn.Module):
    """Small MLP whose two hidden ``LinW`` layers attend over earlier activations.

    The input is flattened to 784 features, mapped to 8 units by ``l1``,
    passed through the two ``LinW`` layers ``l2`` and ``l3`` and finally
    mapped to 10 outputs by ``l4``. Indexing and ``len`` expose the ``LinW``
    layers only.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(784, 8)
        self.l2 = LinW(in_features=8, out_features=8, depth=0)
        self.l3 = LinW(in_features=8, out_features=8, depth=1, layers=[self.l2])
        self.l4 = nn.Linear(8, 10)
        self.gelu = nn.GELU()
        # Plain list used by __getitem__/__len__; both modules are already
        # registered through the l2/l3 attributes.
        self.layers = [self.l2, self.l3]

    def forward(self, x):
        """Return the 10 output values for a batch ``x``.

        The outputs of ``l1`` and ``l2`` are recorded *before* the GELU and
        handed to the following ``LinW`` layer, which attends over them.
        """
        # Named ``history`` rather than ``repr`` so the builtin is not shadowed.
        history = []

        x = self.l1(self.flatten(x))
        history.append(x)

        x = self.l2(self.gelu(x), history)
        history.append(x)

        x = self.gelu(self.l3(self.gelu(x), history))
        return self.l4(x)

    def __getitem__(self, idx):
        return self.layers[idx]

    def __len__(self):
        return len(self.layers)
    
    

class LinW(nn.Linear):
    """Linear layer whose input is first mixed with earlier activations.

    The incoming activations are passed through a ``MultiHeadAttention``
    block (2 heads) that uses the recorded activations of earlier layers as
    queries and keys. The result is then fed to the ordinary linear map of
    this layer.

    Args:
        in_features: input width; also the embedding size of the attention.
        out_features: output width.
        depth: position of this layer among the ``LinW`` layers.
        layers: optional sequence of preceding layers; only the first
            ``depth`` entries are kept. It is stored on the instance but not
            used by ``forward``.
    """

    def __init__(self, in_features, out_features, depth, layers=None):
        super().__init__(in_features=in_features, out_features=out_features)
        self.depth = depth
        # ``None`` replaces the former mutable ``[]`` default, which was one
        # list object shared by every instance created without ``layers``.
        self.layers = [] if layers is None else layers[:self.depth]
        self.mha = MultiHeadAttention(in_features, 2)

    def forward(self, input, activations=None):
        """Attend over ``activations`` and apply the linear map.

        Args:
            input: (batch_size, in_features) activations of the current layer.
            activations: list with one (batch_size, in_features) tensor per
                earlier layer. At least one entry is required because the
                tensors are stacked; an empty list makes the stacking fail.

        Returns:
            Tensor of shape (batch_size, out_features).
        """
        if activations is None:
            activations = []
        activations = get_layer_activations(activations)
        attended = self.mha(activations, activations, input)
        return F.linear(attended, self.weight, self.bias)

# Training script: same experiment as the previous cell, with a larger batch
# size and fewer epochs.
EPOCHS = 5
BATCH_SIZE = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST is downloaded to ./data on first use; images are converted to tensors.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = MLPWD().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# NOTE(review): step_size=10 exceeds EPOCHS = 5 and the scheduler is stepped
# once per epoch, so the learning rate is presumably never reduced in this
# run - confirm this is intended.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# model[i] / len(model) go through MLPWD.__getitem__ / __len__ and list the LinW layers.
print("LinW layers:", "\n".join([f"Depth {model[i].depth}: {model[i]}" for i in range(len(model))]), sep="\n\n")

# `train` and `evaluate` are not defined in this cell; presumably they come
# from another cell of the notebook - TODO confirm their signatures.
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(device, model, test_loader)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%, Test accuracy: {test_accuracy:.2f}%')
    lr_scheduler.step()


LinW layers:

Depth 0: LinW(
  in_features=8, out_features=8, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=8, out_features=8, bias=True)
    (W_K): Linear(in_features=8, out_features=8, bias=True)
    (W_V): Linear(in_features=8, out_features=8, bias=True)
    (W_O): Linear(in_features=16, out_features=8, bias=True)
  )
)
Depth 1: LinW(
  in_features=8, out_features=8, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=8, out_features=8, bias=True)
    (W_K): Linear(in_features=8, out_features=8, bias=True)
    (W_V): Linear(in_features=8, out_features=8, bias=True)
    (W_O): Linear(in_features=16, out_features=8, bias=True)
  )
)
Epoch 1/5, Training Loss: 0.9245, Training Accuracy: 69.55%, Test accuracy: 83.89%
Epoch 2/5, Training Loss: 0.4443, Training Accuracy: 87.33%, Test accuracy: 88.97%
Epoch 3/5, Training Loss: 0.3844, Training Accuracy: 89.08%, Test accuracy: 90.16%
Epoch 4/5, Training Loss: 0.3439, Training Accuracy: 90.35%, Test acc