
<p align="center">
 <img src="../../img/pseudo-attention.svg" width="66%">
</p>


image: ../../img/pseudo-attention.png

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F


####################################################################

def evaluate(device, model, dataloader):
    """Return the top-1 accuracy of ``model`` over ``dataloader``, in percent.

    Args:
        device: Device every batch is moved to before the forward pass.
        model: Module mapping a batch of inputs to scores whose dimension 1
            indexes the classes. It is switched to eval mode and left there.
        dataloader: Iterable yielding ``(images, labels)`` tensor pairs.

    Returns:
        float: ``100 * correct / total`` over every sample seen.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    n_hits, n_seen = 0, 0
    # No gradients are needed for scoring.
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            scores = model(batch_x)
            # torch.max over dim 1 returns (values, indices); keep the indices.
            guesses = torch.max(scores.data, 1)[1]
            n_seen += batch_y.size(0)
            n_hits += (guesses == batch_y).sum().item()
    return 100 * n_hits / n_seen

####################################################################

def train(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module to optimise; it is put in training mode.
        train_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimiser stepped once per batch.
        device: Device every batch is moved to before the forward pass.

    Returns:
        tuple[float, float]: ``(mean loss per sample, accuracy in percent)``,
        both computed over the samples actually processed.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by the batch size so a smaller final batch
        # contributes proportionally (assumes criterion averages per batch).
        loss_sum += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    # Normalise by the samples seen, not len(train_loader.dataset): the two
    # differ when the loader drops the last batch or samples a subset.
    train_loss = loss_sum / total
    accuracy = 100. * correct / total
    return train_loss, accuracy

################################################################

def head_batched_attention_mechanism(Q, K, V):
    """Mix the current activations with softmax attention, for all heads at once.

    The score matrix is ``Q^T K``: the layer axis is contracted, leaving one
    score per pair of activation positions. Scores are divided by
    ``sqrt(activation_size)`` (taken from ``Q``'s last dimension rather than
    a fixed constant), normalised with a softmax over the last axis, and
    used to take a weighted sum of ``V``.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, activation_size, 1) # activations in the current layer

    Returns:
        attention: (batch_size, num_heads, activation_size)
    """
    activation_size = Q.size(-1)

    # (batch_size, num_heads, activation_size, activation_size)
    scores = torch.matmul(Q.transpose(-1, -2), K) / activation_size ** 0.5
    weights = torch.softmax(scores, dim=-1)

    # (batch_size, num_heads, activation_size, 1) -> drop the trailing axis
    return torch.matmul(weights, V).squeeze(-1)


################################################################


class LinW_Attention_Module(nn.Module):
    """Attention block that re-weights the current activations using the
    activations recorded at earlier layers.

    Queries, keys and values each go through their own linear projection,
    are combined by ``head_batched_attention_mechanism``, and the per-head
    results are concatenated and projected back to ``dim_emb`` features.

    NOTE(review): every head receives the same input and shares the same
    ``W_Q`` / ``W_K`` / ``W_V`` weights, so all heads produce identical
    outputs; only ``W_O`` treats them differently.
    """

    def __init__(self, dim_emb, n_head) -> None:
        """
        Args:
            dim_emb: Size of the activation vectors (last input dimension).
            n_head: Number of heads the inputs are replicated over.
        """
        super().__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        self.dim_emb = dim_emb
        self.n_head = n_head

        self.W_Q = nn.Linear(dim_emb, dim_emb)
        self.W_K = nn.Linear(dim_emb, dim_emb)
        self.W_V = nn.Linear(dim_emb, dim_emb)

        # Heads are concatenated before the output projection.
        self.W_O = nn.Linear(dim_emb*n_head, dim_emb)

    def _replicate_per_head(self, tensor):
        """Insert a head axis at dim 1 as a broadcast view (no data copy)."""
        return tensor.unsqueeze(1).expand(-1, self.n_head, *tensor.shape[1:])

    def forward(self, Q, K, V):
        """Attend over earlier-layer activations.

        Args:
            Q: (batch_size, num_layer, activation_size)
            K: (batch_size, num_layer, activation_size)
            V: (batch_size, activation_size) activations of the current layer

        Returns:
            Tensor of shape (batch_size, dim_emb).
        """
        # Q must be 3-D; the middle axis (number of layers) is not needed here.
        batch_size, _, activation_size = Q.size()

        # Inputs are used on the device they arrive on. The previous
        # ``Q.to(device)`` calls discarded their result, so they never moved
        # anything; the caller is responsible for device placement.

        # (batch_size, num_heads, num_layer, activation_size)
        Q = self._replicate_per_head(Q)
        K = self._replicate_per_head(K)
        # (batch_size, num_heads, activation_size)
        V = self._replicate_per_head(V)

        # apply linear transformation
        Q = self.W_Q(Q)
        K = self.W_K(K)
        V = self.W_V(V).reshape(batch_size, self.n_head, activation_size, 1)

        # (batch_size, num_heads, activation_size) -> heads concatenated
        out_attention = head_batched_attention_mechanism(Q, K, V).reshape(
            batch_size, self.n_head*activation_size
        )

        # project the concatenated heads back to dim_emb features
        return self.W_O(out_attention)


####################################################################

def get_activations_per_object(activations):
    """Reorder layer-major activations so the sample axis comes first.

    Implemented as a single axis swap instead of a Python loop over the
    batch, so an empty batch yields an empty result rather than an error.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: contiguous tensor of shape
        (batch_size, num_layers, number_activations). It may share storage
        with the input when the swapped layout is already contiguous.

    """
    return activations.transpose(0, 1).contiguous()

####################################################################

def get_layer_activations(activations):
    """Stack per-layer activations into one tensor with the sample axis first.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of
            shape (batch_size, number_activations); all shapes must match.

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations)

    Raises:
        RuntimeError: If ``activations`` is empty or the shapes differ
            (raised by ``torch.stack``).

    """
    # Stacking along dim 1 gives the per-sample layout in one step, instead
    # of stacking along dim 0 and then re-stacking sample by sample.
    return torch.stack(activations, dim=1)

####################################################################


class MLPWD(nn.Module):
    """Four-layer classifier (784 inputs -> 8 -> 8 -> 8 -> 10 outputs) whose
    two hidden layers are ``LinW`` layers.

    Each ``LinW`` layer receives, besides its input, the list of
    pre-activation outputs produced by the layers before it. Indexing and
    ``len()`` on the model expose the ``LinW`` layers only.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(784, 8)
        self.l2 = LinW(in_features=8, out_features=8, depth=0)
        self.l3 = LinW(in_features=8, out_features=8, depth=1, layers=[self.l2])
        self.l4 = nn.Linear(8, 10)
        self.gelu = nn.GELU()
        # The LinW layers, in forward order, for __getitem__ / __len__.
        self.layers = [self.l2, self.l3]

    def forward(self, x):
        """Map a batch of inputs to 10 output scores per sample."""
        # Pre-GELU outputs of the layers run so far, oldest first.
        history = []

        hidden = self.l1(self.flatten(x))
        history.append(hidden)

        hidden = self.l2(self.gelu(hidden), history)
        history.append(hidden)

        hidden = self.l3(self.gelu(hidden), history)
        return self.l4(self.gelu(hidden))

    def __getitem__(self, idx):
        """Return the ``idx``-th ``LinW`` layer."""
        return self.layers[idx]

    def __len__(self):
        """Return the number of ``LinW`` layers."""
        return len(self.layers)
    

class LinW(nn.Linear):
    """Linear layer that first re-weights its input by attending over the
    activations of earlier layers.

    The input is passed as the value, and the stacked earlier activations as
    both query and key, to a ``LinW_Attention_Module``; the ordinary linear
    map (``weight`` / ``bias``) is then applied to the attention output.
    """

    def __init__(self, in_features, out_features, depth, layers=None, n_head=4):
        """
        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            depth: Position of this layer among the ``LinW`` layers; at most
                ``depth`` entries of ``layers`` are kept.
            layers: Optional sequence of earlier layers. ``None`` (the
                default) means no earlier layers.
            n_head: Number of attention heads; must divide ``in_features``.
        """
        super().__init__(in_features=in_features, out_features=out_features)
        self.depth = depth
        # Always store a fresh list so instances never share (or alias) the
        # caller's list or a default object.
        self.layers = list(layers[:depth]) if layers else []
        self.mha = LinW_Attention_Module(in_features, n_head)

    def forward(self, input, activations=()):
        """Apply attention over ``activations``, then the linear map.

        Args:
            input: Current activations, used as the attention values.
            activations: Non-empty sequence of earlier-layer activation
                tensors; stacking an empty sequence raises ``RuntimeError``.

        Returns:
            Output of the linear map applied to the attention result.
        """
        stacked = get_layer_activations(activations)
        attended = self.mha(stacked, stacked, input)
        return F.linear(attended, self.weight, self.bias)

# ---- hyper-parameters ----
EPOCHS = 10
BATCH_SIZE = 16

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- data: MNIST train / test splits, downloaded to ./data on first use ----
mnist_options = dict(root='./data', download=True, transform=transforms.ToTensor())
train_dataset = datasets.MNIST(train=True, **mnist_options)
test_dataset = datasets.MNIST(train=False, **mnist_options)

# Only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# ---- model and optimisation setup ----
model = MLPWD().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# NOTE(review): step_size equals EPOCHS and the scheduler is stepped once per
# epoch, so the learning rate is only halved after the final epoch.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Summarise the LinW layers before training starts.
layer_descriptions = [f"Depth {model[i].depth}: {model[i]}" for i in range(len(model))]
print("LinW layers:", "\n".join(layer_descriptions), sep="\n\n")

# ---- training loop: one train pass and one test evaluation per epoch ----
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(device, model, test_loader)
    epoch_report = f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%, Test accuracy: {test_accuracy:.2f}%'
    print(epoch_report)
    lr_scheduler.step()
