![](../../img/pseudo-attention-crossingovermutations.png)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

####################################################################
### GLOBAL VARIABLES
# Fraction of the columns (counted from the end of a row) involved in the
# row crossover performed by attention_map_crossover.
CROSSOVER_MAGNITUDE = 0.3
# Half-width of the multiplicative noise used by mutate_attention_map: each
# attention-map entry is multiplied by a factor drawn uniformly between
# 1 - MUTATION_FACTOR and 1 + MUTATION_FACTOR.
MUTATION_FACTOR = 0.3

####################################################################

def evaluate(device, model, dataloader):
    """Compute the top-1 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and no gradients are tracked.

    Args:
        device: device that every batch of inputs and labels is moved to.
        model: module mapping a batch of inputs to per-class scores of
            shape (batch_size, num_classes).
        dataloader: iterable of ``(inputs, labels)`` batches.

    Returns:
        float: percentage (0-100) of samples whose arg-max class equals the
        label. A loader that yields no samples raises ZeroDivisionError.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            predicted = model(batch_inputs).max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    return 100 * n_correct / n_seen

####################################################################

def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and report its loss and accuracy.

    Args:
        model: module to optimise; it is put in training mode.
        train_loader: iterable of ``(inputs, labels)`` batches that also
            exposes ``.dataset`` (its length normalises the loss).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.

    Returns:
        tuple[float, float]: (sum over batches of loss * batch size, divided
        by ``len(train_loader.dataset)``; percentage of correctly classified
        samples seen during the epoch).
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        scores = model(inputs)
        batch_loss = criterion(scores, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch loss is a per-sample average.
        running_loss += batch_loss.item() * inputs.size(0)
        n_seen += targets.size(0)
        n_correct += (scores.max(1)[1] == targets).sum().item()

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_accuracy = 100. * n_correct / n_seen
    return epoch_loss, epoch_accuracy

################################################################

def attention_map_crossover(attention_map, crossover_magnitude=None):
    """Apply a row crossover, in place, to the attention map of every head.

    For each (sample, head) pair two row indexes are drawn uniformly at
    random and the tails of those two rows -- the last
    ``int(n_cols * crossover_magnitude)`` columns -- are swapped. When both
    draws pick the same row, that (sample, head) map is left unchanged.

    Args:
        attention_map (torch.Tensor): shape (batch_size, number_of_heads,
            activation_size, activation_size). Modified in place.
        crossover_magnitude (float, optional): fraction of columns swapped.
            Defaults to the module-level ``CROSSOVER_MAGNITUDE``.

    Returns:
        torch.Tensor: the same tensor object, after the crossover.
    """
    if crossover_magnitude is None:
        crossover_magnitude = CROSSOVER_MAGNITUDE

    dim_batch, number_of_heads, n_rows = attention_map.shape[:3]
    n_cols = attention_map.shape[-1]
    # Loop-invariant: first column of the swapped tail.
    crossover_index = n_cols - int(n_cols * crossover_magnitude)

    for idx_batch in range(dim_batch):
        for idx_head in range(number_of_heads):
            # Two independent draws per (sample, head) pair.
            row_1 = int(torch.randint(0, n_rows, (1,)).item())
            row_2 = int(torch.randint(0, n_rows, (1,)).item())

            # clone() is essential: slices (and .detach()) are views of the
            # same storage, so writing row_1 first would clobber the values
            # still needed to fill row_2.
            tail_1 = attention_map[idx_batch, idx_head, row_1, crossover_index:].clone()
            tail_2 = attention_map[idx_batch, idx_head, row_2, crossover_index:].clone()

            attention_map[idx_batch, idx_head, row_1, crossover_index:] = tail_2
            attention_map[idx_batch, idx_head, row_2, crossover_index:] = tail_1

    return attention_map

################################################################

def mutate_attention_map(attention_map, mutation_factor=None):
    """Mutate the attention map with multiplicative uniform noise.

    Every entry is multiplied by an independent factor drawn uniformly from
    ``[1 - mutation_factor, 1 + mutation_factor)``. The input is not modified.

    Args:
        attention_map (torch.Tensor): shape (batch_size, num_heads,
            activation_size, activation_size).
        mutation_factor (float, optional): half-width of the noise interval.
            Defaults to the module-level ``MUTATION_FACTOR``.

    Returns:
        torch.Tensor: a new tensor with the same shape as ``attention_map``.
    """
    if mutation_factor is None:
        mutation_factor = MUTATION_FACTOR

    # Sample the noise directly with the map's shape, dtype and device, so no
    # throwaway CPU tensor or host-to-device copy is needed.
    noise = torch.empty_like(attention_map).uniform_(1 - mutation_factor, 1 + mutation_factor)
    return attention_map * noise
    
    
################################################################


def head_batched_attention_mechanism(Q, K, V):
    """Randomly perturbed attention over activations, batched over heads.

    A single uniform draw ``p`` picks the perturbation for the whole call:
    ``p <= 0.6`` mutates the attention map (mutate_attention_map); otherwise
    a row crossover is applied (attention_map_crossover). The perturbed map
    is soft-maxed over its last dimension and applied to ``V``.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, activation_size, 1) # activations in the current layer

    Returns:
        torch.Tensor: (batch_size, num_heads, activation_size)
    """
    p = torch.rand(1)

    # Contract over the layer axis: (..., A, L) @ (..., L, A) -> (..., A, A).
    # This activation-by-activation map can multiply V (..., A, 1) for any
    # number of layers L; Q @ K^T would give (..., L, L), which only fits V
    # when L == A.
    # NOTE(review): the sqrt(8) scale is hard-coded; presumably it matches
    # the 8-unit layer used in this notebook -- confirm before reusing.
    attention_map = torch.matmul(Q.transpose(-1, -2), K) / (8 ** 0.5)

    if p <= 0.6:
        attention_map = mutate_attention_map(attention_map)
    else:
        attention_map = attention_map_crossover(attention_map)

    # (..., A, A) @ (..., A, 1) -> (..., A, 1) -> (..., A)
    return (torch.softmax(attention_map, dim=-1) @ V).squeeze(-1)



class LinW_Attention_Module_C_M(nn.Module):
    """Multi-head wrapper around head_batched_attention_mechanism.

    Every head receives an identical copy of Q, K and V, so any difference
    between heads can only come from randomness inside
    head_batched_attention_mechanism. The per-head outputs are concatenated
    and projected back to ``dim_emb`` by ``W_O``.

    Args:
        dim_emb: activation size; the last dimension of Q/K and of V.
        n_head: number of (replicated) heads.
    """

    def __init__(self, dim_emb, n_head) -> None:
        super().__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        self.dim_emb = dim_emb
        self.n_head = n_head

        # Concatenated heads (n_head * dim_emb) -> dim_emb.
        self.W_O = nn.Linear(dim_emb*n_head, dim_emb)

    def forward(self, Q, K, V):
        """
        Args:
            Q, K: (batch_size, num_layer, activation_size)
            V: (batch_size, activation_size)

        Returns:
            torch.Tensor: (batch_size, dim_emb); requires activation_size == dim_emb.
        """
        batch_size = Q.size(0)
        activation_size = Q.size(-1)

        # Replicate per head as views (no copies):
        # (batch_size, n_head, num_layer, activation_size)
        Q = Q.unsqueeze(1).expand(-1, self.n_head, -1, -1)
        K = K.unsqueeze(1).expand(-1, self.n_head, -1, -1)
        # (batch_size, n_head, activation_size, 1)
        V = V.reshape(batch_size, 1, activation_size, 1).expand(-1, self.n_head, -1, -1)

        # (batch_size, n_head, activation_size) -> (batch_size, n_head * activation_size)
        out_attention = head_batched_attention_mechanism(Q, K, V).reshape(
            batch_size, self.n_head * activation_size
        )

        return self.W_O(out_attention)
    

####################################################################


def get_activations_per_object(activations):
    """Regroup stacked layer activations by object (sample).

    Swaps the first two axes, so element ``[i, l]`` of the result is the
    activation vector of sample ``i`` at layer ``l``.

    NOTE(review): this name is defined again later in this notebook; that
    later definition overrides this one.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: shape (nr_object, num_layers, activation_for_each_layer),
        a new contiguous tensor (never a view of ``activations``).
    """
    # A single transpose + copy replaces stacking batch_size per-sample
    # slices, and also works for an empty batch (stacking an empty list fails).
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)


####################################################################


def get_layer_activations(activations):
    """Stack per-layer activations into one per-sample tensor.

    NOTE(review): this name is defined again later in this notebook; that
    later definition overrides this one.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of
            shape (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations);
        element ``[i, l]`` holds sample ``i``'s activations at layer ``l``.

    Raises:
        RuntimeError: if ``activations`` is empty (torch.stack needs at
            least one tensor).
    """
    # Stacking on dim=1 builds (batch, layers, activations) directly; this
    # equals stacking on dim 0 and then swapping the first two axes.
    return torch.stack(activations, dim=1)


####################################################################


def get_activations_per_object(activations):
    """Regroup stacked layer activations by object (sample).

    Swaps the first two axes, so element ``[i, l]`` of the result is the
    activation vector of sample ``i`` at layer ``l``.

    NOTE(review): this name is already defined earlier in this notebook;
    being later, this definition is the one in effect.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: shape (nr_object, num_layers, activation_for_each_layer),
        a new contiguous tensor (never a view of ``activations``).
    """
    # A single transpose + copy replaces stacking batch_size per-sample
    # slices, and also works for an empty batch (stacking an empty list fails).
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)


####################################################################


def get_layer_activations(activations):
    """Stack per-layer activations into one per-sample tensor.

    NOTE(review): this name is already defined earlier in this notebook;
    being later, this definition is the one in effect.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of
            shape (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations);
        element ``[i, l]`` holds sample ``i``'s activations at layer ``l``.

    Raises:
        RuntimeError: if ``activations`` is empty (torch.stack needs at
            least one tensor).
    """
    # Stacking on dim=1 builds (batch, layers, activations) directly; this
    # equals stacking on dim 0 and then swapping the first two axes.
    return torch.stack(activations, dim=1)


####################################################################


class MLPWD(nn.Module):
    """Small classifier with one attention-reweighted (LinW) hidden layer.

    Pipeline: flatten (3072 inputs, i.e. a flattened 3x32x32 CIFAR-10 image)
    -> Linear(3072, 8) -> LayerNorm -> GELU -> LinW(8, 8), which attends over
    the LayerNorm output -> GELU -> Linear(8, 10).

    ``forward`` returns raw class scores (logits), which is what
    ``nn.CrossEntropyLoss`` expects; arg-max predictions are the same as
    they would be after a softmax.
    """

    def __init__(self):
        super(MLPWD, self).__init__()
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(3072, 8)
        self.layer_norm = nn.LayerNorm(8)
        self.l2 = LinW(in_features=8, out_features=8)
        self.l3 = nn.Linear(8, 10)
        self.gelu = nn.GELU()
        # Kept as a registered submodule (unchanged module tree / printout),
        # but NOT applied in forward: nn.CrossEntropyLoss already applies
        # log-softmax, so a softmax here would be applied twice.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.layer_norm(self.l1(self.flatten(x)))
        # Activations the LinW layer attends over: the normalised l1 output.
        layer_outputs = [x]
        x = self.gelu(x)
        x = self.gelu(self.l2(x, layer_outputs))
        return self.l3(x)

    def _layer_sequence(self):
        # NOTE(review): assumed to be the linear layers in forward order;
        # no ``layers`` attribute is ever assigned on this model.
        return (self.l1, self.l2, self.l3)

    def __getitem__(self, idx):
        return self._layer_sequence()[idx]

    def __len__(self):
        return len(self._layer_sequence())
    

class LinW(nn.Linear):
    """Linear layer whose input is first re-weighted by attention.

    The input vector is passed as V to a 2-head LinW_Attention_Module_C_M,
    with Q = K = the stacked ``activations`` of earlier layers. The attention
    output (size ``in_features``) is then fed through this layer's own
    ``weight`` and ``bias``.
    """

    def __init__(self, in_features, out_features):
        super(LinW, self).__init__(in_features=in_features, out_features=out_features)
        # Output size equals in_features, so it can feed F.linear with self.weight.
        self.mha = LinW_Attention_Module_C_M(in_features, 2)

    def forward(self, input, activations=None):
        """
        Args:
            input: (batch_size, in_features) tensor to transform.
            activations: sequence of earlier-layer tensors, each
                (batch_size, in_features). An empty or missing sequence makes
                the stacking in get_layer_activations fail.

        Returns:
            torch.Tensor: (batch_size, out_features)
        """
        # None instead of a mutable [] default (shared across calls).
        if activations is None:
            activations = []
        layer_activations = get_layer_activations(activations)
        attended = self.mha(layer_activations, layer_activations, input)
        return F.linear(attended, self.weight, self.bias)

# ---- Experiment configuration ----
EPOCHS = 7
BATCH_SIZE = 256

# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CIFAR-10 is downloaded into ./data on first run; images are only converted
# to tensors (ToTensor) -- no normalisation or augmentation is applied.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
# Alternative dataset kept for reference. NOTE(review): MLPWD's first layer
# takes 3072 inputs; MNIST images (1x28x28) would need that size changed.
# train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
# test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Training batches are reshuffled every epoch; the test order is fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = MLPWD().to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-3)
# NOTE(review): nn.CrossEntropyLoss expects unnormalised logits; check that
# MLPWD.forward does not already apply a softmax.
criterion = nn.CrossEntropyLoss()
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# Stepped once per epoch below. NOTE(review): with T_0=10 and EPOCHS=7 the
# first warm restart is never reached, so this is only a partial cosine decay.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)


print(model)

# One training pass and one test-set evaluation per epoch.
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(device, model, test_loader)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%, Test accuracy: {test_accuracy:.2f}%')
    scheduler.step()
