# Introducing a mechanism of mutations and crossover only on the attention maps 

Here we introduce a mechanism of mutations and crossover applied only to the attention maps. The idea is to individually evolve, through crossover and mutations, the attention map of each head obtained in the previous notebook. These attention maps are then used to compute the activations of the next layer: each new activation is a weighted average of the activations in the previous layer, where the weights are the (mutated or crossed-over) attention maps. The objective is to evolve the attention maps to obtain a better representation of the input, exploiting fast weights that speed up convergence.

In [1]:
import torch
import torch.nn as nn
import time

In [22]:
def head_batched_attention_mechanism(Q, K, V):
    """Per-head attention that re-weights the current-layer activations.

    For every (sample, head) pair the attention map is
    ``softmax(Q^T K / sqrt(activation_size))`` and it is used to take a
    weighted average of the activations ``V``.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, activation_size, 1) -- activations in the
            current layer.

    Returns:
        torch.Tensor: (batch_size, num_heads, activation_size)
    """
    # Scale by the activation size read from the input instead of a
    # hard-coded 8 (identical result for the 8-unit layers used here).
    activation_size = Q.shape[-1]

    # Q^T K contracts over the layer axis:
    # (B, H, A, L) @ (B, H, L, A) -> (B, H, A, A)
    scores = torch.matmul(Q.transpose(-1, -2), K) / activation_size ** 0.5
    attention_map = torch.softmax(scores, dim=-1)

    # (B, H, A, A) @ (B, H, A, 1) -> (B, H, A, 1) -> (B, H, A)
    return (attention_map @ V).squeeze(-1)



class MultiHeadAttention(nn.Module):
    """Multi-head attention over the activation history of earlier layers.

    Every head receives the same Q/K/V inputs and shares the projections
    W_Q, W_K and W_V, so heads can only differ if
    ``head_batched_attention_mechanism`` treats them differently. The
    per-head outputs are concatenated and projected back by W_O.

    Args:
        dim_emb (int): activation size; must match the last dimension of
            Q, K and V.
        n_head (int): number of heads.
    """

    def __init__(self, dim_emb, n_head) -> None:
        super().__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        self.dim_emb = dim_emb
        self.n_head = n_head

        self.W_Q = nn.Linear(dim_emb, dim_emb)
        self.W_K = nn.Linear(dim_emb, dim_emb)
        self.W_V = nn.Linear(dim_emb, dim_emb)

        # concatenated head outputs: n_head * dim_emb -> dim_emb
        self.W_O = nn.Linear(dim_emb*n_head, dim_emb)

    def forward(self, Q, K, V):
        """Apply attention.

        Args:
            Q: (batch_size, num_layer, activation_size)
            K: (batch_size, num_layer, activation_size)
            V: (batch_size, activation_size) -- current-layer activations.

        Returns:
            torch.Tensor: (batch_size, dim_emb)
        """
        # Q is (batch, layers, activations); keep the two trailing dims apart.
        batch_size, _num_layers, activation_size = Q.size()

        # Tensor.to returns a new tensor, so the result must be reassigned.
        # Follow the module's own weights rather than "cuda if available".
        device = self.W_Q.weight.device
        Q, K, V = Q.to(device), K.to(device), V.to(device)

        # Give every head the same inputs without copying:
        # (B, L, A) -> (B, H, L, A) and (B, A) -> (B, H, A)
        Q = Q.unsqueeze(1).expand(-1, self.n_head, *Q.shape[1:])
        K = K.unsqueeze(1).expand(-1, self.n_head, *K.shape[1:])
        V = V.unsqueeze(1).expand(-1, self.n_head, *V.shape[1:])

        # apply linear transformation
        Q = self.W_Q(Q)
        K = self.W_K(K)
        V = self.W_V(V).reshape(batch_size, self.n_head, activation_size, 1)

        # (B, H, A) -> (B, H * A)
        out_attention = head_batched_attention_mechanism(Q, K, V).reshape(batch_size, self.n_head*activation_size)

        return self.W_O(out_attention)
    



def get_activations_per_object(activations):
    """Regroup a per-layer activation stack into a per-sample one.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations); always
        a new contiguous tensor (never a view of ``activations``).
    """
    # Swapping the first two axes equals stacking activations[:, i, :] for
    # every i, but is one vectorized op and also works for an empty batch.
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)

def get_layer_activations(activations):
    """Stack a list of per-layer activations into a per-sample history.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of
            shape (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations)

    Raises:
        RuntimeError: if ``activations`` is empty (raised by ``torch.stack``).
    """
    # Stacking on dim=1 yields the (sample, layer, unit) layout directly.
    return torch.stack(list(activations), dim=1)


# Smoke test: run one MultiHeadAttention layer on a random activation history.
batch_size = 16
number_activations = 8

# current-layer activations: (batch_size, number_activations)
a = torch.randn(batch_size, number_activations)

# history of 10 layers: (batch_size, 10, number_activations)
attn = get_layer_activations([torch.randn(batch_size, number_activations) for _ in range(10)])

# expected: torch.Size([16, 8]), i.e. (batch_size, number_activations)
print(MultiHeadAttention(number_activations, 4)(attn, attn, a).shape)

torch.Size([16, 8])

In [265]:
# modified attention mechanism

### GLOBAL VARIABLES
# Fraction of the attention-map columns (the rightmost ones) affected by
# attention_map_crossover.
CROSSOVER_MAGNITUDE = 0.3
# Half-width of the multiplicative noise in mutate_attention_map: entries are
# scaled by factors drawn uniformly from [1 - 0.3, 1 + 0.3].
MUTATION_FACTOR = 0.3



################################################################

def attention_map_crossover(attention_map, crossover_magnitude=None):
    """Apply crossover to the attention map of every (sample, head) pair.

    For each ``attention_map[b, h]`` two rows are drawn at random and their
    rightmost ``int(num_columns * crossover_magnitude)`` entries are swapped.
    The swapped values are detached, so no gradient flows through the
    exchanged entries. The tensor is modified in place and also returned.

    Args:
        attention_map (torch.Tensor): shape (batch_size, number_of_heads, activation_size, activation_size)
        crossover_magnitude (float, optional): fraction of columns, counted
            from the right, that is exchanged. Defaults to the module-level
            CROSSOVER_MAGNITUDE.

    Returns:
        torch.Tensor: ``attention_map`` itself, shape unchanged.
    """
    if crossover_magnitude is None:
        crossover_magnitude = CROSSOVER_MAGNITUDE

    dim_batch, number_of_heads, number_of_rows = attention_map.shape[:3]
    if number_of_rows == 0:
        # no rows to pick from (torch.randint(0, 0) would raise)
        return attention_map

    number_of_columns = attention_map.shape[-1]
    # columns [crossover_index:] form the exchanged tail
    crossover_index = number_of_columns - int(number_of_columns * crossover_magnitude)

    for idx_batch in range(dim_batch):
        for idx_head in range(number_of_heads):
            # two random rows (same RNG draws, in the same order, as before)
            row_1 = torch.randint(0, number_of_rows, (1,)).item()
            row_2 = torch.randint(0, number_of_rows, (1,)).item()

            # Clone both tails BEFORE writing. Iterating a tensor yields views
            # of its storage, so an element-wise swap reads the value it has
            # just overwritten and degenerates into copying row_2 onto row_1.
            tail_1 = attention_map[idx_batch, idx_head, row_1, crossover_index:].detach().clone()
            tail_2 = attention_map[idx_batch, idx_head, row_2, crossover_index:].detach().clone()
            attention_map[idx_batch, idx_head, row_1, crossover_index:] = tail_2
            attention_map[idx_batch, idx_head, row_2, crossover_index:] = tail_1

    return attention_map

################################################################

def mutate_attention_map(attention_map, mutation_factor=None):
    """Mutate the attention map with elementwise multiplicative noise.

    Every entry is multiplied by an independent factor drawn uniformly from
    [1 - mutation_factor, 1 + mutation_factor].

    Args:
        attention_map (torch.Tensor): shape (batch_size, num_heads, activation_size, activation_size)
        mutation_factor (float, optional): half-width of the noise interval.
            Defaults to the module-level MUTATION_FACTOR.

    Returns:
        torch.Tensor: a new tensor with the same shape, dtype and device.
    """
    if mutation_factor is None:
        mutation_factor = MUTATION_FACTOR
    # Sample the noise on the map's own device/dtype: a CPU noise tensor
    # cannot be multiplied with a CUDA attention map.
    noise = torch.empty_like(attention_map).uniform_(1 - mutation_factor, 1 + mutation_factor)
    return attention_map * noise
    
    
################################################################


def head_batched_attention_mechanism(Q, K, V):
    """Per-head attention whose scores are randomly mutated or crossed over.

    The scaled scores ``Q^T K / sqrt(activation_size)`` are perturbed before
    the softmax:

    * with probability 0.6 by ``mutate_attention_map``;
    * otherwise by ``attention_map_crossover``.

    The resulting attention map takes a weighted average of ``V``.
    NOTE(review): the perturbation runs on every call, including under
    ``model.eval()``, so evaluation is stochastic.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, activation_size, 1) -- activations in the
            current layer.

    Returns:
        torch.Tensor: (batch_size, num_heads, activation_size)
    """
    # Draw the branch first so the RNG stream is consumed in the same order.
    p = torch.rand(1)

    # (B, H, A, L) @ (B, H, L, A) -> (B, H, A, A), scaled by sqrt(A)
    activation_size = Q.shape[-1]
    scores = torch.matmul(Q.transpose(-1, -2), K) / activation_size ** 0.5

    if p <= 0.6:
        scores = mutate_attention_map(scores)
    else:
        scores = attention_map_crossover(scores)

    attention_map = torch.softmax(scores, dim=-1)
    # (B, H, A, A) @ (B, H, A, 1) -> (B, H, A)
    return (attention_map @ V).squeeze(-1)



class MultiHeadAttention(nn.Module):
    """Multi-head attention over the activation history of earlier layers.

    Every head receives the same Q/K/V inputs and shares the projections
    W_Q, W_K and W_V, so heads can only differ if
    ``head_batched_attention_mechanism`` treats them differently. The
    per-head outputs are concatenated and projected back by W_O.

    Args:
        dim_emb (int): activation size; must match the last dimension of
            Q, K and V.
        n_head (int): number of heads.
    """

    def __init__(self, dim_emb, n_head) -> None:
        super().__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        self.dim_emb = dim_emb
        self.n_head = n_head

        self.W_Q = nn.Linear(dim_emb, dim_emb)
        self.W_K = nn.Linear(dim_emb, dim_emb)
        self.W_V = nn.Linear(dim_emb, dim_emb)

        # concatenated head outputs: n_head * dim_emb -> dim_emb
        self.W_O = nn.Linear(dim_emb*n_head, dim_emb)

    def forward(self, Q, K, V):
        """Apply attention.

        Args:
            Q: (batch_size, num_layer, activation_size)
            K: (batch_size, num_layer, activation_size)
            V: (batch_size, activation_size) -- current-layer activations.

        Returns:
            torch.Tensor: (batch_size, dim_emb)
        """
        # Q is (batch, layers, activations); keep the two trailing dims apart.
        batch_size, _num_layers, activation_size = Q.size()

        # Tensor.to returns a new tensor, so the result must be reassigned.
        # Follow the module's own weights rather than "cuda if available".
        device = self.W_Q.weight.device
        Q, K, V = Q.to(device), K.to(device), V.to(device)

        # Give every head the same inputs without copying:
        # (B, L, A) -> (B, H, L, A) and (B, A) -> (B, H, A)
        Q = Q.unsqueeze(1).expand(-1, self.n_head, *Q.shape[1:])
        K = K.unsqueeze(1).expand(-1, self.n_head, *K.shape[1:])
        V = V.unsqueeze(1).expand(-1, self.n_head, *V.shape[1:])

        # apply linear transformation
        Q = self.W_Q(Q)
        K = self.W_K(K)
        V = self.W_V(V).reshape(batch_size, self.n_head, activation_size, 1)

        # (B, H, A) -> (B, H * A)
        out_attention = head_batched_attention_mechanism(Q, K, V).reshape(batch_size, self.n_head*activation_size)

        return self.W_O(out_attention)
    



def get_activations_per_object(activations):
    """Regroup a per-layer activation stack into a per-sample one.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations); always
        a new contiguous tensor (never a view of ``activations``).
    """
    # Swapping the first two axes equals stacking activations[:, i, :] for
    # every i, but is one vectorized op and also works for an empty batch.
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)

def get_layer_activations(activations):
    """Stack a list of per-layer activations into a per-sample history.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of
            shape (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations)

    Raises:
        RuntimeError: if ``activations`` is empty (raised by ``torch.stack``).
    """
    # Stacking on dim=1 yields the (sample, layer, unit) layout directly.
    return torch.stack(list(activations), dim=1)



# EXAMPLE
#######################################################################
#######################################################################
#######################################################################

# inputs
batch_size = 16
number_activations = 8

#######################################################################
number_layers = 10

activations = [torch.randn(batch_size, number_activations) for _ in range(number_layers)]
# layered activations: (batch_size, number_layers, number_activations)
attn = get_layer_activations(activations)
print(len(activations), activations[0].size(),'(num_layers, batch_size, number_activations)')

print(activations[0].size()) # (batch_size, number_activations)

print(attn.shape) # (batch_size, number_layers, number_activations)

# Layer output for a single sample (index 1): shape (number_activations,).
# This is the layer's output, not an attention map. The last layer's
# activations serve as V so the cell does not depend on earlier cells.
print(MultiHeadAttention(number_activations, 4)(attn, attn, activations[-1])[1].shape)

# layer output for the whole batch
out = MultiHeadAttention(number_activations, 4)(attn, attn, activations[-1])

print(out.shape) # (batch_size, activation_size)


10 torch.Size([16, 8]) (num_layers, batch_size, number_activations)
torch.Size([16, 8])
torch.Size([16, 10, 8])
torch.Size([8])


torch.Size([16, 8])

In [262]:
# # # EXAMPLE element-wise multiplication
# # mat = torch.stack([torch.Tensor([i for i in range(3)]) for _ in range(3)], 1) 
# # print(mat)
# # print(torch.mul(mat, mat)) # element-wise multiplication

# # # multiply two torch matrix element-wise
# # mutation_factor = 0.03

# # # how to mutate the attention map
# # mutated_attention_map = torch.mul(attention_map, torch.randn(attention_map.shape).uniform_(1-mutation_factor,1+mutation_factor))
# # # get the difference between the two matrices mutated_attention_map and attention_map
# # print('mutation magnitude:',abs(torch.sum(mutated_attention_map-attention_map).detach().numpy()))


# # indexing over the columns of the attention map

# # def crossover_attention_map(attention_map, crossover_magnitude):


# import multiprocessing

# print(multiprocessing.cpu_count())

# def attention_map_crossover(attention_map):
    
#     crossover_magnitude = 0.3
    
#     dim_batch = attention_map.shape[0]
#     number_of_heads = attention_map.shape[1]
    
#     for idx_batch in range(dim_batch):
#         for idx_head in range(number_of_heads):
            
#             print(idx_head)
            
#             crossover_index = attention_map.shape[2] - int(attention_map.shape[2]*crossover_magnitude)
                
#             random_index_1 = torch.randint(0, attention_map.shape[2],(1,))[0]
#             random_index_2 = torch.randint(0, attention_map.shape[2],(1,))[0]
            
#             for idx, (x_1, x_2) in enumerate(zip(attention_map[idx_batch][idx_head][random_index_1][crossover_index:].detach(), attention_map[idx_batch][idx_head][random_index_2][crossover_index:].detach())):
                
#                 print(attention_map[idx_batch][idx_head].shape, random_index_1, random_index_2, crossover_index)
                            
#                 print(x_1, x_2,idx_batch, idx_head, idx)
                
#                 attention_map[idx_batch][idx_head][random_index_1][crossover_index+idx] = x_2 # make crossover
#                 attention_map[idx_batch][idx_head][random_index_2][crossover_index+idx] = x_1 # make crossover
    
#     return attention_map
    

# import time

# start = time.time()
# attention_map_crossover(attention_map).shape
# end = time.time()

# print('time required to perform crossover:',end-start)

# torch.rand(1)

#### Evaluation

Evaluate the layer with few neurons on the MNIST dataset. The results are good even with few neurons: the evidence shows that introducing mutations and crossover on the attention maps used to weight the activations of the next layer allows the network to converge. It would be interesting to see the evaluation with more neurons and on more complex datasets.

In [267]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F


def evaluate(device, model, dataloader):
    """Return the top-1 accuracy of ``model`` on ``dataloader``, in percent.

    The model is put in eval mode and run without gradient tracking.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            predictions = logits.data.max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()
    return 100 * n_correct / n_seen

def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    The loss is averaged per sample over ``len(train_loader.dataset)``.
    Accuracy is measured on each batch's outputs, computed before that
    batch's optimizer step.
    """
    model.train()
    running_loss, n_correct, n_seen = 0.0, 0, 0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * batch_images.size(0)
        predictions = logits.max(1).indices
        n_seen += batch_labels.size(0)
        n_correct += predictions.eq(batch_labels).sum().item()

    mean_loss = running_loss / len(train_loader.dataset)
    return mean_loss, 100. * n_correct / n_seen

### GLOBAL VARIABLES
# Fraction of the attention-map columns (the rightmost ones) affected by
# attention_map_crossover.
CROSSOVER_MAGNITUDE = 0.3
# Half-width of the multiplicative noise in mutate_attention_map: entries are
# scaled by factors drawn uniformly from [1 - 0.3, 1 + 0.3].
MUTATION_FACTOR = 0.3



################################################################

def attention_map_crossover(attention_map, crossover_magnitude=None):
    """Apply crossover to the attention map of every (sample, head) pair.

    For each ``attention_map[b, h]`` two rows are drawn at random and their
    rightmost ``int(num_columns * crossover_magnitude)`` entries are swapped.
    The swapped values are detached, so no gradient flows through the
    exchanged entries. The tensor is modified in place and also returned.

    Args:
        attention_map (torch.Tensor): shape (batch_size, number_of_heads, activation_size, activation_size)
        crossover_magnitude (float, optional): fraction of columns, counted
            from the right, that is exchanged. Defaults to the module-level
            CROSSOVER_MAGNITUDE.

    Returns:
        torch.Tensor: ``attention_map`` itself, shape unchanged.
    """
    if crossover_magnitude is None:
        crossover_magnitude = CROSSOVER_MAGNITUDE

    dim_batch, number_of_heads, number_of_rows = attention_map.shape[:3]
    if number_of_rows == 0:
        # no rows to pick from (torch.randint(0, 0) would raise)
        return attention_map

    number_of_columns = attention_map.shape[-1]
    # columns [crossover_index:] form the exchanged tail
    crossover_index = number_of_columns - int(number_of_columns * crossover_magnitude)

    for idx_batch in range(dim_batch):
        for idx_head in range(number_of_heads):
            # two random rows (same RNG draws, in the same order, as before)
            row_1 = torch.randint(0, number_of_rows, (1,)).item()
            row_2 = torch.randint(0, number_of_rows, (1,)).item()

            # Clone both tails BEFORE writing. Iterating a tensor yields views
            # of its storage, so an element-wise swap reads the value it has
            # just overwritten and degenerates into copying row_2 onto row_1.
            tail_1 = attention_map[idx_batch, idx_head, row_1, crossover_index:].detach().clone()
            tail_2 = attention_map[idx_batch, idx_head, row_2, crossover_index:].detach().clone()
            attention_map[idx_batch, idx_head, row_1, crossover_index:] = tail_2
            attention_map[idx_batch, idx_head, row_2, crossover_index:] = tail_1

    return attention_map

################################################################

def mutate_attention_map(attention_map, mutation_factor=None):
    """Mutate the attention map with elementwise multiplicative noise.

    Every entry is multiplied by an independent factor drawn uniformly from
    [1 - mutation_factor, 1 + mutation_factor].

    Args:
        attention_map (torch.Tensor): shape (batch_size, num_heads, activation_size, activation_size)
        mutation_factor (float, optional): half-width of the noise interval.
            Defaults to the module-level MUTATION_FACTOR.

    Returns:
        torch.Tensor: a new tensor with the same shape, dtype and device.
    """
    if mutation_factor is None:
        mutation_factor = MUTATION_FACTOR
    # Sample the noise on the map's own device/dtype: a CPU noise tensor
    # cannot be multiplied with a CUDA attention map.
    noise = torch.empty_like(attention_map).uniform_(1 - mutation_factor, 1 + mutation_factor)
    return attention_map * noise
    
    
################################################################


def head_batched_attention_mechanism(Q, K, V):
    """Per-head attention whose scores are randomly mutated or crossed over.

    The scaled scores ``Q^T K / sqrt(activation_size)`` are perturbed before
    the softmax:

    * with probability 0.6 by ``mutate_attention_map``;
    * otherwise by ``attention_map_crossover``.

    The resulting attention map takes a weighted average of ``V``.
    NOTE(review): the perturbation runs on every call, including under
    ``model.eval()``, so evaluation is stochastic.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, activation_size, 1) -- activations in the
            current layer.

    Returns:
        torch.Tensor: (batch_size, num_heads, activation_size)
    """
    # Draw the branch first so the RNG stream is consumed in the same order.
    p = torch.rand(1)

    # (B, H, A, L) @ (B, H, L, A) -> (B, H, A, A), scaled by sqrt(A)
    activation_size = Q.shape[-1]
    scores = torch.matmul(Q.transpose(-1, -2), K) / activation_size ** 0.5

    if p <= 0.6:
        scores = mutate_attention_map(scores)
    else:
        scores = attention_map_crossover(scores)

    attention_map = torch.softmax(scores, dim=-1)
    # (B, H, A, A) @ (B, H, A, 1) -> (B, H, A)
    return (attention_map @ V).squeeze(-1)



class MultiHeadAttention(nn.Module):
    """Multi-head attention over the activation history of earlier layers.

    Every head receives the same Q/K/V inputs and shares the projections
    W_Q, W_K and W_V, so heads can only differ if
    ``head_batched_attention_mechanism`` treats them differently. The
    per-head outputs are concatenated and projected back by W_O.

    Args:
        dim_emb (int): activation size; must match the last dimension of
            Q, K and V.
        n_head (int): number of heads.
    """

    def __init__(self, dim_emb, n_head) -> None:
        super().__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        self.dim_emb = dim_emb
        self.n_head = n_head

        self.W_Q = nn.Linear(dim_emb, dim_emb)
        self.W_K = nn.Linear(dim_emb, dim_emb)
        self.W_V = nn.Linear(dim_emb, dim_emb)

        # concatenated head outputs: n_head * dim_emb -> dim_emb
        self.W_O = nn.Linear(dim_emb*n_head, dim_emb)

    def forward(self, Q, K, V):
        """Apply attention.

        Args:
            Q: (batch_size, num_layer, activation_size)
            K: (batch_size, num_layer, activation_size)
            V: (batch_size, activation_size) -- current-layer activations.

        Returns:
            torch.Tensor: (batch_size, dim_emb)
        """
        # Q is (batch, layers, activations); keep the two trailing dims apart.
        batch_size, _num_layers, activation_size = Q.size()

        # Tensor.to returns a new tensor, so the result must be reassigned.
        # Follow the module's own weights rather than "cuda if available".
        device = self.W_Q.weight.device
        Q, K, V = Q.to(device), K.to(device), V.to(device)

        # Give every head the same inputs without copying:
        # (B, L, A) -> (B, H, L, A) and (B, A) -> (B, H, A)
        Q = Q.unsqueeze(1).expand(-1, self.n_head, *Q.shape[1:])
        K = K.unsqueeze(1).expand(-1, self.n_head, *K.shape[1:])
        V = V.unsqueeze(1).expand(-1, self.n_head, *V.shape[1:])

        # apply linear transformation
        Q = self.W_Q(Q)
        K = self.W_K(K)
        V = self.W_V(V).reshape(batch_size, self.n_head, activation_size, 1)

        # (B, H, A) -> (B, H * A)
        out_attention = head_batched_attention_mechanism(Q, K, V).reshape(batch_size, self.n_head*activation_size)

        return self.W_O(out_attention)
    



def get_activations_per_object(activations):
    """Regroup a per-layer activation stack into a per-sample one.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations); always
        a new contiguous tensor (never a view of ``activations``).
    """
    # Swapping the first two axes equals stacking activations[:, i, :] for
    # every i, but is one vectorized op and also works for an empty batch.
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)

def get_layer_activations(activations):
    """Stack a list of per-layer activations into a per-sample history.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of
            shape (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations)

    Raises:
        RuntimeError: if ``activations`` is empty (raised by ``torch.stack``).
    """
    # Stacking on dim=1 yields the (sample, layer, unit) layout directly.
    return torch.stack(list(activations), dim=1)

####################################################################

def get_activations_per_object(activations):
    """Regroup a per-layer activation stack into a per-sample one.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations); always
        a new contiguous tensor (never a view of ``activations``).
    """
    # Swapping the first two axes equals stacking activations[:, i, :] for
    # every i, but is one vectorized op and also works for an empty batch.
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)

def get_layer_activations(activations):
    """Stack a list of per-layer activations into a per-sample history.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of
            shape (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations)

    Raises:
        RuntimeError: if ``activations`` is empty (raised by ``torch.stack``).
    """
    # Stacking on dim=1 yields the (sample, layer, unit) layout directly.
    return torch.stack(list(activations), dim=1)

####################################################################


class MLPWD(nn.Module):
    """Small MNIST MLP whose hidden layers attend over earlier activations.

    Pipeline: flatten -> Linear(784, 8) -> LinW(depth 0) -> LinW(depth 1)
    -> Linear(8, 10), with GELU between stages. Each LinW layer receives
    the list of pre-activation outputs produced so far.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(784, 8)
        self.l2 = LinW(in_features=8, out_features=8, depth=0)
        self.l3 = LinW(in_features=8, out_features=8, depth=1, layers=[self.l2])
        self.l4 = nn.Linear(8, 10)
        self.gelu = nn.GELU()
        # Plain list (not nn.ModuleList): l2/l3 are already registered as
        # submodules above; this only backs __getitem__ / __len__.
        self.layers = [self.l2, self.l3]

    def forward(self, x):
        # Pre-activation outputs of earlier layers, consumed by the LinW
        # layers. Named to avoid shadowing the builtin ``repr``.
        history = []
        x = self.l1(self.flatten(x))
        history.append(x)
        x = self.l2(self.gelu(x), history)
        history.append(x)
        x = self.gelu(self.l3(self.gelu(x), history))
        return self.l4(x)

    def __getitem__(self, idx):
        return self.layers[idx]

    def __len__(self):
        return len(self.layers)
    

class LinW(nn.Linear):
    """Linear layer whose input is first mixed by attention over the
    activation history of earlier layers.

    Args:
        in_features (int): as for nn.Linear; also the attention size.
        out_features (int): as for nn.Linear.
        depth (int): position of this layer; only the first ``depth`` entries
            of ``layers`` are kept in ``self.layers``.
        layers (list, optional): earlier layers. Stored but not used by
            ``forward``.
    """

    def __init__(self, in_features, out_features, depth, layers=None):
        super().__init__(in_features=in_features, out_features=out_features)
        self.depth = depth
        # Always a fresh list: a mutable ``[]`` default would be shared (and
        # stored) by every instance created without ``layers``.
        self.layers = list(layers[:self.depth]) if layers else []
        self.mha = MultiHeadAttention(in_features, 4)

    def forward(self, input, activations=None):
        """Mix ``input`` with attention over ``activations``, then apply the
        linear map.

        Args:
            input: current activations, shape (batch_size, in_features).
            activations: list of earlier-layer outputs, each of shape
                (batch_size, in_features). Must be non-empty: an empty list
                makes ``torch.stack`` raise.

        Returns:
            torch.Tensor: (batch_size, out_features)
        """
        if activations is None:
            activations = []
        history = get_layer_activations(activations)
        mixed = self.mha(history, history, input)
        return F.linear(mixed, self.weight, self.bias)

# Train MLPWD on MNIST and report train/test accuracy after every epoch.
EPOCHS = 10
BATCH_SIZE = 16

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST is downloaded into ./data on first use.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = MLPWD().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# NOTE(review): with step_size=10 and EPOCHS=10 the LR is halved only after
# the last epoch, so training effectively runs at a constant lr=1e-3.
# Lower step_size if decay during training was intended.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# model[i] indexes MLPWD.layers, i.e. the two LinW layers.
print("LinW layers:", "\n".join([f"Depth {model[i].depth}: {model[i]}" for i in range(len(model))]), sep="\n\n")

for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(device, model, test_loader)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%, Test accuracy: {test_accuracy:.2f}%')
    lr_scheduler.step()


LinW layers:

Depth 0: LinW(
  in_features=8, out_features=8, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=8, out_features=8, bias=True)
    (W_K): Linear(in_features=8, out_features=8, bias=True)
    (W_V): Linear(in_features=8, out_features=8, bias=True)
    (W_O): Linear(in_features=32, out_features=8, bias=True)
  )
)
Depth 1: LinW(
  in_features=8, out_features=8, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=8, out_features=8, bias=True)
    (W_K): Linear(in_features=8, out_features=8, bias=True)
    (W_V): Linear(in_features=8, out_features=8, bias=True)
    (W_O): Linear(in_features=32, out_features=8, bias=True)
  )
)
Epoch 1/10, Training Loss: 0.9341, Training Accuracy: 68.37%, Test accuracy: 83.95%
Epoch 2/10, Training Loss: 0.4378, Training Accuracy: 87.18%, Test accuracy: 89.38%
Epoch 3/10, Training Loss: 0.3512, Training Accuracy: 89.89%, Test accuracy: 89.93%
Epoch 4/10, Training Loss: 0.3149, Training Accuracy: 90.98%, Test

Evaluation on CIFAR10

In [270]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F


def evaluate(device, model, dataloader):
    """Return the top-1 accuracy of ``model`` on ``dataloader``, in percent.

    The model is put in eval mode and run without gradient tracking.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            predictions = logits.data.max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()
    return 100 * n_correct / n_seen

def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    The loss is averaged per sample over ``len(train_loader.dataset)``.
    Accuracy is measured on each batch's outputs, computed before that
    batch's optimizer step.
    """
    model.train()
    running_loss, n_correct, n_seen = 0.0, 0, 0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * batch_images.size(0)
        predictions = logits.max(1).indices
        n_seen += batch_labels.size(0)
        n_correct += predictions.eq(batch_labels).sum().item()

    mean_loss = running_loss / len(train_loader.dataset)
    return mean_loss, 100. * n_correct / n_seen

### GLOBAL VARIABLES
# Fraction of the attention-map columns (the rightmost ones) affected by
# attention_map_crossover.
CROSSOVER_MAGNITUDE = 0.3
# Half-width of the multiplicative noise in mutate_attention_map: entries are
# scaled by factors drawn uniformly from [1 - 0.3, 1 + 0.3].
MUTATION_FACTOR = 0.3



################################################################

def attention_map_crossover(attention_map, crossover_magnitude=None):
    """Apply crossover to the attention map of every (sample, head) pair.

    For each ``attention_map[b, h]`` two rows are drawn at random and their
    rightmost ``int(num_columns * crossover_magnitude)`` entries are swapped.
    The swapped values are detached, so no gradient flows through the
    exchanged entries. The tensor is modified in place and also returned.

    Args:
        attention_map (torch.Tensor): shape (batch_size, number_of_heads, activation_size, activation_size)
        crossover_magnitude (float, optional): fraction of columns, counted
            from the right, that is exchanged. Defaults to the module-level
            CROSSOVER_MAGNITUDE.

    Returns:
        torch.Tensor: ``attention_map`` itself, shape unchanged.
    """
    if crossover_magnitude is None:
        crossover_magnitude = CROSSOVER_MAGNITUDE

    dim_batch, number_of_heads, number_of_rows = attention_map.shape[:3]
    if number_of_rows == 0:
        # no rows to pick from (torch.randint(0, 0) would raise)
        return attention_map

    number_of_columns = attention_map.shape[-1]
    # columns [crossover_index:] form the exchanged tail
    crossover_index = number_of_columns - int(number_of_columns * crossover_magnitude)

    for idx_batch in range(dim_batch):
        for idx_head in range(number_of_heads):
            # two random rows (same RNG draws, in the same order, as before)
            row_1 = torch.randint(0, number_of_rows, (1,)).item()
            row_2 = torch.randint(0, number_of_rows, (1,)).item()

            # Clone both tails BEFORE writing. Iterating a tensor yields views
            # of its storage, so an element-wise swap reads the value it has
            # just overwritten and degenerates into copying row_2 onto row_1.
            tail_1 = attention_map[idx_batch, idx_head, row_1, crossover_index:].detach().clone()
            tail_2 = attention_map[idx_batch, idx_head, row_2, crossover_index:].detach().clone()
            attention_map[idx_batch, idx_head, row_1, crossover_index:] = tail_2
            attention_map[idx_batch, idx_head, row_2, crossover_index:] = tail_1

    return attention_map

################################################################

def mutate_attention_map(attention_map, mutation_factor=None):
    """ Mutate the attention map by making an elementwise multiplication with
    a random tensor with values between 1-mutation_factor and 1+mutation_factor.

    Every entry gets its own independent factor drawn uniformly from
    [1 - mutation_factor, 1 + mutation_factor].

    Args:
        attention_map (torch.Tensor): shape (batch_size, num_heads, activation_size, activation_size)
        mutation_factor (float, optional): half-width of the uniform range. Defaults to
            the module-level ``MUTATION_FACTOR`` (read at call time).

    Returns:
        torch.Tensor: new tensor with the same shape, dtype and device as
        ``attention_map``; the input is not modified.
    """
    if mutation_factor is None:
        mutation_factor = MUTATION_FACTOR
    # Draw the factors directly with the input's device and dtype. A CPU float32 noise
    # tensor cannot be multiplied with a CUDA attention map, and it would up-cast
    # half-precision inputs.
    noise = torch.empty_like(attention_map).uniform_(1 - mutation_factor, 1 + mutation_factor)
    return attention_map * noise
    
    
################################################################


def head_batched_attention_mechanism(Q, K, V, *, training=True):
    """Attention with an evolutionary perturbation of each head's attention map.

    The scores ``Q^T @ K / sqrt(8)`` contract over the ``num_layer`` axis. This gives
    one (activation_size x activation_size) map per sample and head. When
    ``training`` is true, a single random draw ``p`` decides how the whole batch of
    maps is perturbed before the softmax:

    * ``p <= 0.6``: ``mutate_attention_map`` (mutation only);
    * otherwise:    ``attention_map_crossover`` (crossover only).

    The softmax-normalised map is then used to average the current-layer activations ``V``.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, activation_size, 1) # activations in the current layer
        training (bool, keyword-only): False skips the perturbation and draws no
            random number, e.g. for deterministic evaluation. Defaults to True.

    Returns:
        attention: (batch_size, num_heads, activation_size)
    """
    # NOTE(review): the scale is a fixed sqrt(8), not sqrt(num_layer) (the contracted
    # axis) or sqrt(activation_size). It is kept so that existing runs stay comparable.
    scale = 8 ** 0.5

    perturb = None
    if training:
        # one draw per call, taken before any random numbers used by the perturbation
        if torch.rand(1).item() <= 0.6:
            perturb = mutate_attention_map
        else:
            perturb = attention_map_crossover

    # (batch_size, num_heads, activation_size, activation_size)
    scores = torch.matmul(Q.transpose(-1, -2), K) / scale
    if perturb is not None:
        scores = perturb(scores)

    # (batch_size, num_heads, activation_size, 1) -> drop the trailing axis
    return (torch.softmax(scores, dim=-1) @ V).squeeze(-1)



class MultiHeadAttention(nn.Module):
    """Per-sample attention between earlier-layer activations (Q, K) and current activations (V).

    Every head receives the same inputs and uses the same W_Q / W_K / W_V projections.
    The head outputs are concatenated and mixed back to ``dim_emb`` features by ``W_O``.

    NOTE(review): because the projections are shared, the heads can only differ through
    the random perturbation inside ``head_batched_attention_mechanism``. Confirm this is
    intended.
    """

    def __init__(self, dim_emb, n_head) -> None:
        super(MultiHeadAttention, self).__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        self.dim_emb = dim_emb
        self.n_head = n_head

        # projections shared by all heads
        self.W_Q = nn.Linear(dim_emb, dim_emb)
        self.W_K = nn.Linear(dim_emb, dim_emb)
        self.W_V = nn.Linear(dim_emb, dim_emb)

        # mixes the concatenated outputs of the n_head heads back to dim_emb features
        self.W_O = nn.Linear(dim_emb*n_head, dim_emb)

    def forward(self, Q, K, V):
        """
        Args:
            Q: (batch_size, num_layer, activation_size)
            K: (batch_size, num_layer, activation_size)
            V: (batch_size, activation_size) # activations in the current layer

        Returns:
            torch.Tensor: (batch_size, dim_emb)

        ``activation_size`` must equal ``dim_emb`` (it is the input size of the projections).
        """
        batch_size, _, activation_size = Q.size()

        # Project once, then broadcast over the head axis without copying. Every head gets
        # the same projection because the weights are shared. Inputs are used on the
        # device they arrive on (the module is moved by the caller).
        # (batch_size, num_heads, num_layer, activation_size)
        Q = self.W_Q(Q).unsqueeze(1).expand(-1, self.n_head, -1, -1)
        K = self.W_K(K).unsqueeze(1).expand(-1, self.n_head, -1, -1)
        # (batch_size, num_heads, activation_size, 1)
        V = self.W_V(V).unsqueeze(1).expand(-1, self.n_head, -1).reshape(batch_size, self.n_head, activation_size, 1)

        # (batch_size, num_heads, activation_size) -> concatenate the heads
        out_attention = head_batched_attention_mechanism(Q, K, V).reshape(batch_size, self.n_head*activation_size)

        return self.W_O(out_attention)
    



def get_activations_per_object(activations):
    """ Get the activations for each object per layer.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: new contiguous tensor of shape (batch_size, num_layers, number_activations),
        whose entry ``[i]`` is ``activations[:, i, :]``.
    """
    # Swapping the first two axes is exactly "stack activations[:, i, :] over i". Unlike
    # torch.stack over a list, it also works for an empty batch. The clone keeps the
    # result a fresh copy that never aliases the input.
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)

def get_layer_activations(activations):
    """ Get the activations for each layer for each sample.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of shape
            (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations), with
        ``out[:, l, :] == activations[l]``.

    Raises:
        RuntimeError: from ``torch.stack`` when ``activations`` is empty or the
            per-layer shapes differ.
    """
    # Stacking along dim 1 yields the (batch, layer, feature) layout in a single copy.
    # This is the same result as stacking along dim 0 and then regrouping per object.
    return torch.stack(list(activations), dim=1)


####################################################################

def get_activations_per_object(activations):
    """ Get the activations for each object per layer.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: new contiguous tensor of shape (batch_size, num_layers, number_activations),
        whose entry ``[i]`` is ``activations[:, i, :]``.
    """
    # Swapping the first two axes is exactly "stack activations[:, i, :] over i". Unlike
    # torch.stack over a list, it also works for an empty batch. The clone keeps the
    # result a fresh copy that never aliases the input.
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)

def get_layer_activations(activations):
    """ Get the activations for each layer for each sample.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of shape
            (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations), with
        ``out[:, l, :] == activations[l]``.

    Raises:
        RuntimeError: from ``torch.stack`` when ``activations`` is empty or the
            per-layer shapes differ.
    """
    # Stacking along dim 1 yields the (batch, layer, feature) layout in a single copy.
    # This is the same result as stacking along dim 0 and then regrouping per object.
    return torch.stack(list(activations), dim=1)

####################################################################


class MLPWD(nn.Module):
    """MLP whose two hidden ``LinW`` layers attend over earlier layer outputs.

    Data flow: flatten -> l1 (3072 -> 32) -> GELU -> l2 (LinW) -> GELU -> l3 (LinW)
    -> GELU -> l4 (32 -> 10). The pre-activation outputs of l1 and l2 are collected
    and passed to the LinW layers as their activation history. Inputs must flatten
    to 3072 features (e.g. 3x32x32 images).

    ``model[i]`` and ``len(model)`` expose the LinW layers only.
    """

    def __init__(self):
        super().__init__()
        # Keep this construction order. It fixes the sub-module registration order
        # (and therefore ``parameters()``) and the RNG draws made by the initialisers.
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(3072, 32)
        self.l2 = LinW(in_features=32, out_features=32, depth=0)
        self.l3 = LinW(in_features=32, out_features=32, depth=1, layers=[self.l2])
        self.l4 = nn.Linear(32, 10)
        self.gelu = nn.GELU()
        # plain list of the LinW layers in depth order (they are registered above already)
        self.layers = [self.l2, self.l3]

    def forward(self, x):
        history = []  # pre-activation outputs of l1 and l2

        out_l1 = self.l1(self.flatten(x))
        history.append(out_l1)

        out_l2 = self.l2(self.gelu(out_l1), history)
        history.append(out_l2)

        out_l3 = self.l3(self.gelu(out_l2), history)
        return self.l4(self.gelu(out_l3))

    def __getitem__(self, idx):
        """Return the ``idx``-th LinW layer."""
        return self.layers[idx]

    def __len__(self):
        """Number of LinW layers."""
        return len(self.layers)
    

class LinW(nn.Linear):
    """``nn.Linear`` applied to the output of a 4-head ``MultiHeadAttention``.

    The attention uses the layer input as V and the stacked earlier activations as
    Q and K. Its output then goes through this layer's own weight and bias.

    Args:
        in_features (int): input size; also the embedding size of ``self.mha``.
        out_features (int): output size.
        depth (int): position of this layer; only the first ``depth`` entries of
            ``layers`` are kept.
        layers (list, optional): earlier layers. They are stored in a plain Python list,
            so they are not registered as sub-modules, and ``forward`` does not use them.
            Defaults to an empty list.
    """

    def __init__(self, in_features, out_features, depth, layers=None):
        super(LinW, self).__init__(in_features=in_features, out_features=out_features)
        self.depth = depth
        # fresh list per instance: never alias the caller's list or a shared default
        self.layers = [] if layers is None else list(layers[:self.depth])
        self.mha = MultiHeadAttention(in_features, 4)

    def forward(self, input, activations=None):
        """
        Args:
            input (torch.Tensor): current activations, (batch_size, in_features).
            activations (list[torch.Tensor]): earlier layer outputs, each of shape
                (batch_size, in_features); must not be empty.

        Returns:
            torch.Tensor: (batch_size, out_features)

        Raises:
            ValueError: if ``activations`` is missing or empty.
        """
        if activations is None or len(activations) == 0:
            raise ValueError("LinW.forward requires a non-empty list of previous activations")
        activations = get_layer_activations(activations)
        return F.linear(self.mha(activations, activations, input), self.weight, self.bias)

# --- Training run: MLPWD on CIFAR-10 --------------------------------------------
EPOCHS = 10
BATCH_SIZE = 120

# Device for the model and for the batches passed to train()/evaluate().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CIFAR-10 splits stored under ./data (download=True fetches them if missing).
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Training batches are shuffled; the test order is fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = MLPWD().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# NOTE(review): the scheduler is stepped once per epoch with step_size=10, so the
# first halving would only happen after the 10th (last) epoch. The LR is therefore
# constant for this whole run. Confirm whether a decay was intended.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Print the LinW layers (model[i] / len(model) go through MLPWD.layers).
print("LinW layers:", "\n".join([f"Depth {model[i].depth}: {model[i]}" for i in range(len(model))]), sep="\n\n")

# Per epoch: one training pass, one test evaluation, one scheduler step.
# NOTE(review): the attention-map mutation/crossover does not look at model.training,
# so the test accuracy measured in eval mode still includes random perturbations.
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(device, model, test_loader)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%, Test accuracy: {test_accuracy:.2f}%')
    lr_scheduler.step()


Files already downloaded and verified
Files already downloaded and verified
LinW layers:

Depth 0: LinW(
  in_features=32, out_features=32, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=32, out_features=32, bias=True)
    (W_K): Linear(in_features=32, out_features=32, bias=True)
    (W_V): Linear(in_features=32, out_features=32, bias=True)
    (W_O): Linear(in_features=128, out_features=32, bias=True)
  )
)
Depth 1: LinW(
  in_features=32, out_features=32, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=32, out_features=32, bias=True)
    (W_K): Linear(in_features=32, out_features=32, bias=True)
    (W_V): Linear(in_features=32, out_features=32, bias=True)
    (W_O): Linear(in_features=128, out_features=32, bias=True)
  )
)
Epoch 1/10, Training Loss: 1.9556, Training Accuracy: 27.93%, Test accuracy: 34.38%


KeyboardInterrupt: 

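Here I test whether adding an ADD & NORM step after the crossover improves the results. ADD & NORM means adding the pre-crossover attention map back as a residual and then applying layer normalization.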


In [273]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F


def evaluate(device, model, dataloader):
    """Return the top-1 accuracy (in percent) of ``model`` on ``dataloader``.

    The model is switched to eval mode (and left there) and run without gradients.

    Args:
        device: device every batch is moved to.
        model: module whose output has the class scores along dim 1.
        dataloader: iterable of ``(images, labels)`` batches.

    Returns:
        float: ``100 * correct / total``, or 0.0 when the dataloader yields no samples.
    """
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # guard against ZeroDivisionError for an empty loader
    return 100 * correct / total if total else 0.0

def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: module whose output has the class scores along dim 1.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function; presumably returns the batch mean (as
            ``nn.CrossEntropyLoss`` does by default). Confirm for other losses.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.

    Returns:
        tuple[float, float]: the per-sample mean loss and the accuracy in percent, both
        averaged over the samples actually seen; ``(0.0, 0.0)`` for an empty loader.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # re-weight the batch-mean loss by the batch size to get a per-sample mean
        train_loss += loss.item() * batch_size
        total += batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()

    if total == 0:
        return 0.0, 0.0
    # Divide by the samples seen. This equals len(dataset) for a plain DataLoader, but
    # stays correct with drop_last=True or a subset sampler.
    return train_loss / total, 100. * correct / total

### GLOBAL VARIABLES
# Hyper-parameters of the evolutionary perturbations, read at call time by
# attention_map_crossover and mutate_attention_map.
CROSSOVER_MAGNITUDE = 0.3  # fraction of columns (counted from the right) swapped between two rows
MUTATION_FACTOR = 0.3      # each score is scaled by a factor drawn from U(1 - 0.3, 1 + 0.3)



################################################################
# attention_map_crossover + ADD & NORM

def attention_map_crossover(attention_map, crossover_magnitude=None):
    """ Apply a single-point crossover to the attention map of each head, followed by ADD & NORM.

    For every (sample, head) pair, two rows are picked at random and their last
    ``int(activation_size * crossover_magnitude)`` columns are swapped. If both picks
    are the same row, nothing changes for that head. The pre-crossover map is then
    added back as a residual. The sum is layer-normalised over each head's whole
    (activation_size x activation_size) map, without a learnable scale or shift.

    Args:
        attention_map (torch.Tensor): shape (batch_size, number_of_heads, activation_size, activation_size).
            Modified in place.
        crossover_magnitude (float, optional): fraction of the columns, counted from the
            right, that are swapped. Defaults to the module-level ``CROSSOVER_MAGNITUDE``
            (read at call time).

    Returns:
        torch.Tensor: ``attention_map`` itself, after the in-place update.
    """
    if crossover_magnitude is None:
        crossover_magnitude = CROSSOVER_MAGNITUDE

    dim_batch, number_of_heads, num_rows, num_cols = attention_map.shape
    device = attention_map.device

    # first swapped column; it is the same for every head
    crossover_index = num_cols - int(num_cols * crossover_magnitude)

    # Residual: an independent copy of the pre-crossover map. A plain view would change
    # together with the swap below and turn the ADD into a doubling of the crossed map.
    residual = attention_map.clone()

    # One pair of rows per (sample, head). The broadcast index grids select row
    # rows_k[b, h] of head (b, h).
    batch_idx = torch.arange(dim_batch, device=device).view(-1, 1)
    head_idx = torch.arange(number_of_heads, device=device).view(1, -1)
    rows_1 = torch.randint(0, num_rows, (dim_batch, number_of_heads), device=device)
    rows_2 = torch.randint(0, num_rows, (dim_batch, number_of_heads), device=device)

    # Advanced indexing returns copies, so both tails are snapshots taken before
    # either write. This makes the swap symmetric.
    tail_1 = attention_map[batch_idx, head_idx, rows_1, crossover_index:]
    tail_2 = attention_map[batch_idx, head_idx, rows_2, crossover_index:]
    attention_map[batch_idx, head_idx, rows_1, crossover_index:] = tail_2
    attention_map[batch_idx, head_idx, rows_2, crossover_index:] = tail_1

    # ADD & NORM, once per head. The functional form runs on the input's device and
    # equals a freshly built nn.LayerNorm (weight 1, bias 0, eps 1e-5).
    attention_map.copy_(F.layer_norm(attention_map + residual, attention_map.shape[-2:]))
    return attention_map

################################################################

def mutate_attention_map(attention_map, mutation_factor=None):
    """ Mutate the attention map by making an elementwise multiplication with
    a random tensor with values between 1-mutation_factor and 1+mutation_factor.

    Every entry gets its own independent factor drawn uniformly from
    [1 - mutation_factor, 1 + mutation_factor].

    Args:
        attention_map (torch.Tensor): shape (batch_size, num_heads, activation_size, activation_size)
        mutation_factor (float, optional): half-width of the uniform range. Defaults to
            the module-level ``MUTATION_FACTOR`` (read at call time).

    Returns:
        torch.Tensor: new tensor with the same shape, dtype and device as
        ``attention_map``; the input is not modified.
    """
    if mutation_factor is None:
        mutation_factor = MUTATION_FACTOR
    # Draw the factors directly with the input's device and dtype. A CPU float32 noise
    # tensor cannot be multiplied with a CUDA attention map, and it would up-cast
    # half-precision inputs.
    noise = torch.empty_like(attention_map).uniform_(1 - mutation_factor, 1 + mutation_factor)
    return attention_map * noise
    
    
################################################################


def head_batched_attention_mechanism(Q, K, V, *, training=True):
    """Attention with an evolutionary perturbation of each head's attention map.

    The scores ``Q^T @ K / sqrt(8)`` contract over the ``num_layer`` axis. This gives
    one (activation_size x activation_size) map per sample and head. When
    ``training`` is true, a single random draw ``p`` decides how the whole batch of
    maps is perturbed before the softmax:

    * ``p <= 0.6``: ``mutate_attention_map`` (mutation only);
    * otherwise:    ``attention_map_crossover`` (crossover only).

    The softmax-normalised map is then used to average the current-layer activations ``V``.

    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, activation_size, 1) # activations in the current layer
        training (bool, keyword-only): False skips the perturbation and draws no
            random number, e.g. for deterministic evaluation. Defaults to True.

    Returns:
        attention: (batch_size, num_heads, activation_size)
    """
    # NOTE(review): the scale is a fixed sqrt(8), not sqrt(num_layer) (the contracted
    # axis) or sqrt(activation_size). It is kept so that existing runs stay comparable.
    scale = 8 ** 0.5

    perturb = None
    if training:
        # one draw per call, taken before any random numbers used by the perturbation
        if torch.rand(1).item() <= 0.6:
            perturb = mutate_attention_map
        else:
            perturb = attention_map_crossover

    # (batch_size, num_heads, activation_size, activation_size)
    scores = torch.matmul(Q.transpose(-1, -2), K) / scale
    if perturb is not None:
        scores = perturb(scores)

    # (batch_size, num_heads, activation_size, 1) -> drop the trailing axis
    return (torch.softmax(scores, dim=-1) @ V).squeeze(-1)



class MultiHeadAttention(nn.Module):
    """Per-sample attention between earlier-layer activations (Q, K) and current activations (V).

    Every head receives the same inputs and uses the same W_Q / W_K / W_V projections.
    The head outputs are concatenated and mixed back to ``dim_emb`` features by ``W_O``.

    NOTE(review): because the projections are shared, the heads can only differ through
    the random perturbation inside ``head_batched_attention_mechanism``. Confirm this is
    intended.
    """

    def __init__(self, dim_emb, n_head) -> None:
        super(MultiHeadAttention, self).__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        self.dim_emb = dim_emb
        self.n_head = n_head

        # projections shared by all heads
        self.W_Q = nn.Linear(dim_emb, dim_emb)
        self.W_K = nn.Linear(dim_emb, dim_emb)
        self.W_V = nn.Linear(dim_emb, dim_emb)

        # mixes the concatenated outputs of the n_head heads back to dim_emb features
        self.W_O = nn.Linear(dim_emb*n_head, dim_emb)

    def forward(self, Q, K, V):
        """
        Args:
            Q: (batch_size, num_layer, activation_size)
            K: (batch_size, num_layer, activation_size)
            V: (batch_size, activation_size) # activations in the current layer

        Returns:
            torch.Tensor: (batch_size, dim_emb)

        ``activation_size`` must equal ``dim_emb`` (it is the input size of the projections).
        """
        batch_size, _, activation_size = Q.size()

        # Project once, then broadcast over the head axis without copying. Every head gets
        # the same projection because the weights are shared. Inputs are used on the
        # device they arrive on (the module is moved by the caller).
        # (batch_size, num_heads, num_layer, activation_size)
        Q = self.W_Q(Q).unsqueeze(1).expand(-1, self.n_head, -1, -1)
        K = self.W_K(K).unsqueeze(1).expand(-1, self.n_head, -1, -1)
        # (batch_size, num_heads, activation_size, 1)
        V = self.W_V(V).unsqueeze(1).expand(-1, self.n_head, -1).reshape(batch_size, self.n_head, activation_size, 1)

        # (batch_size, num_heads, activation_size) -> concatenate the heads
        out_attention = head_batched_attention_mechanism(Q, K, V).reshape(batch_size, self.n_head*activation_size)

        return self.W_O(out_attention)
    



def get_activations_per_object(activations):
    """ Get the activations for each object per layer.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: new contiguous tensor of shape (batch_size, num_layers, number_activations),
        whose entry ``[i]`` is ``activations[:, i, :]``.
    """
    # Swapping the first two axes is exactly "stack activations[:, i, :] over i". Unlike
    # torch.stack over a list, it also works for an empty batch. The clone keeps the
    # result a fresh copy that never aliases the input.
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)

def get_layer_activations(activations):
    """ Get the activations for each layer for each sample.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of shape
            (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations), with
        ``out[:, l, :] == activations[l]``.

    Raises:
        RuntimeError: from ``torch.stack`` when ``activations`` is empty or the
            per-layer shapes differ.
    """
    # Stacking along dim 1 yields the (batch, layer, feature) layout in a single copy.
    # This is the same result as stacking along dim 0 and then regrouping per object.
    return torch.stack(list(activations), dim=1)


####################################################################

def get_activations_per_object(activations):
    """ Get the activations for each object per layer.

    Args:
        activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

    Returns:
        torch.Tensor: new contiguous tensor of shape (batch_size, num_layers, number_activations),
        whose entry ``[i]`` is ``activations[:, i, :]``.
    """
    # Swapping the first two axes is exactly "stack activations[:, i, :] over i". Unlike
    # torch.stack over a list, it also works for an empty batch. The clone keeps the
    # result a fresh copy that never aliases the input.
    return activations.transpose(0, 1).clone(memory_format=torch.contiguous_format)

def get_layer_activations(activations):
    """ Get the activations for each layer for each sample.

    Args:
        activations (Sequence[torch.Tensor]): one tensor per layer, each of shape
            (batch_size, number_activations).

    Returns:
        torch.Tensor: shape (batch_size, num_layers, number_activations), with
        ``out[:, l, :] == activations[l]``.

    Raises:
        RuntimeError: from ``torch.stack`` when ``activations`` is empty or the
            per-layer shapes differ.
    """
    # Stacking along dim 1 yields the (batch, layer, feature) layout in a single copy.
    # This is the same result as stacking along dim 0 and then regrouping per object.
    return torch.stack(list(activations), dim=1)

####################################################################


class MLPWD(nn.Module):
    """MLP whose two hidden ``LinW`` layers attend over earlier layer outputs.

    Data flow: flatten -> l1 (3072 -> 32) -> GELU -> l2 (LinW) -> GELU -> l3 (LinW)
    -> GELU -> l4 (32 -> 10). The pre-activation outputs of l1 and l2 are collected
    and passed to the LinW layers as their activation history. Inputs must flatten
    to 3072 features (e.g. 3x32x32 images).

    ``model[i]`` and ``len(model)`` expose the LinW layers only.
    """

    def __init__(self):
        super().__init__()
        # Keep this construction order. It fixes the sub-module registration order
        # (and therefore ``parameters()``) and the RNG draws made by the initialisers.
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(3072, 32)
        self.l2 = LinW(in_features=32, out_features=32, depth=0)
        self.l3 = LinW(in_features=32, out_features=32, depth=1, layers=[self.l2])
        self.l4 = nn.Linear(32, 10)
        self.gelu = nn.GELU()
        # plain list of the LinW layers in depth order (they are registered above already)
        self.layers = [self.l2, self.l3]

    def forward(self, x):
        history = []  # pre-activation outputs of l1 and l2

        out_l1 = self.l1(self.flatten(x))
        history.append(out_l1)

        out_l2 = self.l2(self.gelu(out_l1), history)
        history.append(out_l2)

        out_l3 = self.l3(self.gelu(out_l2), history)
        return self.l4(self.gelu(out_l3))

    def __getitem__(self, idx):
        """Return the ``idx``-th LinW layer."""
        return self.layers[idx]

    def __len__(self):
        """Number of LinW layers."""
        return len(self.layers)
    

class LinW(nn.Linear):
    """``nn.Linear`` applied to the output of a 4-head ``MultiHeadAttention``.

    The attention uses the layer input as V and the stacked earlier activations as
    Q and K. Its output then goes through this layer's own weight and bias.

    Args:
        in_features (int): input size; also the embedding size of ``self.mha``.
        out_features (int): output size.
        depth (int): position of this layer; only the first ``depth`` entries of
            ``layers`` are kept.
        layers (list, optional): earlier layers. They are stored in a plain Python list,
            so they are not registered as sub-modules, and ``forward`` does not use them.
            Defaults to an empty list.
    """

    def __init__(self, in_features, out_features, depth, layers=None):
        super(LinW, self).__init__(in_features=in_features, out_features=out_features)
        self.depth = depth
        # fresh list per instance: never alias the caller's list or a shared default
        self.layers = [] if layers is None else list(layers[:self.depth])
        self.mha = MultiHeadAttention(in_features, 4)

    def forward(self, input, activations=None):
        """
        Args:
            input (torch.Tensor): current activations, (batch_size, in_features).
            activations (list[torch.Tensor]): earlier layer outputs, each of shape
                (batch_size, in_features); must not be empty.

        Returns:
            torch.Tensor: (batch_size, out_features)

        Raises:
            ValueError: if ``activations`` is missing or empty.
        """
        if activations is None or len(activations) == 0:
            raise ValueError("LinW.forward requires a non-empty list of previous activations")
        activations = get_layer_activations(activations)
        return F.linear(self.mha(activations, activations, input), self.weight, self.bias)

# --- Training run: MLPWD with the ADD & NORM crossover on CIFAR-10 --------------
EPOCHS = 10
BATCH_SIZE = 64

# Device for the model and for the batches passed to train()/evaluate().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CIFAR-10 splits stored under ./data (download=True fetches them if missing).
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Training batches are shuffled; the test order is fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = MLPWD().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# NOTE(review): the scheduler is stepped once per epoch with step_size=10, so the
# first halving would only happen after the 10th (last) epoch. The LR is therefore
# constant for this whole run. Confirm whether a decay was intended.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Print the LinW layers (model[i] / len(model) go through MLPWD.layers).
print("LinW layers:", "\n".join([f"Depth {model[i].depth}: {model[i]}" for i in range(len(model))]), sep="\n\n")

# Per epoch: one training pass, one test evaluation, one scheduler step.
# NOTE(review): the attention-map mutation/crossover does not look at model.training,
# so the test accuracy measured in eval mode still includes random perturbations.
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_accuracy = evaluate(device, model, test_loader)
    print(f'Epoch {epoch + 1}/{EPOCHS}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%, Test accuracy: {test_accuracy:.2f}%')
    lr_scheduler.step()


Files already downloaded and verified
Files already downloaded and verified
LinW layers:

Depth 0: LinW(
  in_features=32, out_features=32, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=32, out_features=32, bias=True)
    (W_K): Linear(in_features=32, out_features=32, bias=True)
    (W_V): Linear(in_features=32, out_features=32, bias=True)
    (W_O): Linear(in_features=128, out_features=32, bias=True)
  )
)
Depth 1: LinW(
  in_features=32, out_features=32, bias=True
  (mha): MultiHeadAttention(
    (W_Q): Linear(in_features=32, out_features=32, bias=True)
    (W_K): Linear(in_features=32, out_features=32, bias=True)
    (W_V): Linear(in_features=32, out_features=32, bias=True)
    (W_O): Linear(in_features=128, out_features=32, bias=True)
  )
)


KeyboardInterrupt: 