In [1]:
!pip install pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import torch
import torch.nn as nn

import torch.nn.functional as F

import pytorch_lightning as pl

In [3]:
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
### GLOBAL VARIABLES
CROSSOVER_MAGNITUDE = 0.3
MUTATION_FACTOR = 0.3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

################################################################

def attention_map_crossover(attention_map):
    """ Apply the crossover over the attention maps of each head.
    The crosseover consists in picking a random index in the maxtrix over
    the columns and swapping the values in between the columns of the
    attention map.
    
    Args:
        attention_map (torch.Tensor): shape (batch_size, number_of_heads, activation_size, activation_size)
        
    Returns:
        torch.Tensor: shape (batch_size, number_of_heads, activation_size, activation_size)
    
    """
    
    # get the crossover magnitude
    crossover_magnitude = CROSSOVER_MAGNITUDE
    
    # get the batch size
    dim_batch = attention_map.shape[0]
    
    # get the number of heads
    number_of_heads = attention_map.shape[1]
    
    for idx_batch in range(dim_batch):
        for idx_head in range(number_of_heads):
            
            # get the crossover index
            crossover_index = attention_map.shape[2] - int(attention_map.shape[2]*crossover_magnitude)
            
            # get two random indexes
            random_index_1 = torch.randint(0, attention_map.shape[2],(1,))[0]
            random_index_2 = torch.randint(0, attention_map.shape[2],(1,))[0]
            
            # swap the values in that position over the columns
            for idx, (x_1, x_2) in enumerate(zip(attention_map[idx_batch][idx_head][random_index_1][crossover_index:].detach(), attention_map[idx_batch][idx_head][random_index_2][crossover_index:].detach())):
                
                # debug
                # print(attention_map[idx_batch][idx_head].shape, random_index_1, random_index_2, crossover_index)         
                # print(x_1, x_2,idx_batch, idx_head, idx)
                
                # swap the values in that position over the columns
                attention_map[idx_batch][idx_head][random_index_1][crossover_index+idx] = x_2 # make crossover
                attention_map[idx_batch][idx_head][random_index_2][crossover_index+idx] = x_1 # make crossover
    
    return attention_map

################################################################

def mutate_attention_map(attention_map):
    """ Mutate the attention map by making an elementwise multiplication with 
    a random tensor with values between 1-mutation_factor and 1+mutation_factor
    
    Args:
        attention_map (torch.Tensor): shape (batch_size, num_heads, activation_size, activation_size)
        
    Returns:    
        torch.Tensor: shape (batch_size, num_heads, activation_size, activation_size)
    
    
    """
    # get the mutation factor
    mutation_factor = MUTATION_FACTOR
    # return the mutated attention map
    # multiplied elementwise with a 
    # random matrix with values between 
    # 1-mutation_factor and 1+mutation_factor
    return torch.mul(attention_map, torch.randn(attention_map.shape).uniform_(1-mutation_factor,1+mutation_factor).to(attention_map.device))
    
    
################################################################


def head_batched_attention_mechanism(Q, K, V):
    """
    Args:
        Q: (batch_size, num_heads, num_layer, activation_size)
        K: (batch_size, num_heads, num_layer, activation_size)
        V: (batch_size, num_heads, activation_size, 1) # activations in the current layer

    Returns:
        attention: (batch_size, num_heads, activation_size)

        # attention mechanism
        # # (batch_size, num_heads, activation_size, activation_size)
        # attention = torch.matmul(Q, K.transpose(-1,-2))
        # attention = attention / torch.sqrt(torch.tensor(activation_size).float())

        # # (batch_size, num_heads, activation_size, 1)
        # attention = nn.Softmax(dim=-1)(attention)

        # # (batch_size, num_heads, activation_size, 1)
        # attention = torch.matmul(attention, V)

        # # (batch_size, num_heads, activation_size)
        # attention = attention.squeeze(-1)

    """
    
    # with probability p
    p = torch.rand(1)
    
    # p <= 0.6 apply the mutation only
    if p <= 0.6:
        return (nn.Softmax(dim=-1)(
                    mutate_attention_map(torch.matmul( 
                                Q.transpose(-1,-2) , 
                                K
                    )/torch.sqrt(torch.tensor(V.shape[2])))    
                ) @ V).squeeze(-1)
        
    # p > 0.6 apply the crossover only
    else:
        return (nn.Softmax(dim=-1)(
                attention_map_crossover(torch.matmul( 
                            Q.transpose(-1,-2) , 
                            K
                )/torch.sqrt(torch.tensor(V.shape[2])))    
            ) @ V).squeeze(-1)



class LinW_Attention(nn.Module):
    def __init__(self, dim_emb = 32, n_head = 2) -> None:
        super(LinW_Attention, self).__init__()

        assert dim_emb % n_head == 0, 'dim_emb must be divisible by n_head'

        # dim of emb
        self.dim_emb = dim_emb
        # number of heads
        self.n_head = n_head

        # linear transformations
        # which limit the dimension of the saliency map
        flatten_activations = 512*3*3
        self.W_Q = nn.Linear(flatten_activations, dim_emb)
        self.W_K = nn.Linear(flatten_activations, dim_emb)
        self.W_V = nn.Linear(flatten_activations, dim_emb)

        # linear to go back to the original
        # dimension
        self.W_O = nn.Linear(self.n_head*self.dim_emb, dim_emb)

    def forward(self, Q, K, V):
        # get the shape of the input
        batch_size, activation_size, activation_size = Q.size()
        
        # reshape Q, K, V
        # parallelize over the number of heads
        # (batch_size, num_heads, num_layer, activation_size)
        Q = torch.stack([Q for _ in range(self.n_head)], 1)
        K = torch.stack([K for _ in range(self.n_head)], 1)
        V = torch.stack([V for _ in range(self.n_head)], 1)

        # apply linear transformation
        Q = self.W_Q(Q)
        K = self.W_K(K)
        # (batch_size, self.n_head, activation_size, 1)
        V = self.W_V(V.unsqueeze(-2)).permute(0,1,3,2)
        # print(Q.shape, K.shape, V.shape)

        # V = V.reshape(batch_size, self.n_head, activation_size, 1)

        # apply attention mechanism + concatenate over the number of heads
        out_attention = head_batched_attention_mechanism(Q, K, V).reshape(batch_size, self.n_head*self.dim_emb)

        # apply linear transformation
        return self.W_O(out_attention)

In [5]:
import torch
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional, cast
from torch import Tensor
from collections import OrderedDict 

class RN50_forward_hook(nn.Module):
    def __init__(self, output_layers, *args):
        super().__init__(*args)
        self.output_layers = output_layers
        #print(self.output_layers)
        self.selected_out = OrderedDict()

        #PRETRAINED MODEL
        self.pretrained = nn.Sequential(*list(torchvision.models.resnet50(pretrained=True).children())[:-2])
        for param in self.pretrained.parameters():
            param.requires_grad = False

        self.fhooks = []

        for i,l in enumerate(list(self.pretrained._modules.keys())):
            if i in self.output_layers:
                self.fhooks.append(getattr(self.pretrained,l).register_forward_hook(self.forward_hook(l)))
    
    def forward_hook(self,layer_name):
        def hook(module, input, output):
            self.selected_out[layer_name] = output
        return hook

    def forward(self, x):
        out = self.pretrained(x)
        return out, self.selected_out



class LinW_Module(nn.Module):
    def __init__(self, num_classes=10):
        super(LinW_Module, self).__init__()

        self.rn50 = RN50_forward_hook(output_layers = [5,6,7])
        self.pooling = torch.nn.AdaptiveAvgPool3d((512, 3, 3))

        self.attention_block = LinW_Attention()
        
        act = 32
        self.classify = nn.Sequential(
            nn.Linear(act, act*2),
            nn.GELU(),
            nn.Linear(act*2, num_classes),
            nn.GELU()
        )

        
    def get_activations_per_object(self, activations):
        """ Get the activations for each object per layer

        Args:
            activations (torch.Tensor): shape (num_layers, batch_size, number_activations)

        Returns:
            torch.Tensor: shape (nr_object, num_layers, activation_for_each_layer)

        """
        return torch.stack([activations[:,i,:] for i in range(activations.shape[1])])

    def get_layer_activations(self, activations):
        """ Get the activations for each layer for each sample

        Args:
            activations (list): shape [(batch_size, number_activations), (batch_size, number_activations), ...]
            batch_size (int): batch size
            number_activations (int): number of activations

        Returns:
            torch.Tensor: shape (nr_object, num_layers, activation_for_each_layer)

        """
        return self.get_activations_per_object(torch.stack(activations))

    def forward(self, x):
        # with torch.no_grad():
        # out: activation of the final layer (n_batch, )
        # layerout: activation of the preceding layers (n_layers, n_batch, n_activations, n_activations)
        out, layerout = self.rn50(x)
        

        # out, layerout = vgg(torch.stack([trainset[idx][0] for idx in range(10)]))
        # print(pooling(layerout['24']).flatten().shape, pooling(layerout['31']).flatten().shape)

        # extract and apply adaptive avg pooling on the
        # activations + flatten the activations
        representations = []
        for key in layerout.keys():
          representations.append(self.pooling(layerout[key]).flatten(start_dim=1, end_dim=- 1))

        del layerout

        # precedincg activations
        # representations shape: [(n_batch, n_activations)]*n_layers
        # [(n_batch, n_activations)]*n_layers -> (nr_sample, num_layers, activation_for_each_layer)
        representations = self.get_layer_activations(representations)
        
        # combine the representations with the past activations
        out = self.attention_block(representations, representations, self.pooling(out).flatten(start_dim=1, end_dim=- 1))

        return self.classify(out)



import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from torchmetrics.classification import Accuracy, F1Score, AveragePrecision

class Net(pl.LightningModule):
    def __init__(self, n_class = 10, loss = nn.CrossEntropyLoss()):
        super().__init__()
        
        self.linw_finetune_module = LinW_Module()
        self.loss = loss
        self.accuracy = Accuracy(task="multiclass", num_classes=n_class)
        self.f1_score = F1Score(task="multiclass", num_classes=n_class)
        self.average_precision = AveragePrecision(task="multiclass", num_classes=n_class)

    def forward(self, x):
        return self.linw_finetune_module(x)

    def training_step(self, batch, batch_idx):

        x, y = batch

        pred = self(x)
        loss = self.loss(pred, y)

        # Logging to TensorBoard by default
        self.log('train_loss', loss)
        self.log('train_accuracy', self.accuracy(pred, y))
        self.log('train_f1_score', self.f1_score(pred, y))
        self.log('train_average_precision', self.average_precision(pred, y))

        return loss

    def validation_step(self, batch, batch_idx):

        x, y = batch

        pred = self(x)
        loss = self.loss(pred, y)

        # Logging to TensorBoard by default
        self.log('val_loss', loss)
        self.log('val_accuracy', self.accuracy(pred, y))
        self.log('val_f1_score', self.f1_score(pred, y))
        self.log('val_average_precision', self.average_precision(pred, y))

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=5e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
        return [optimizer], [scheduler]


In [6]:
# import torch
# import torch.nn as nn
# from typing import Type, Any, Callable, Union, List, Optional, cast
# from torch import Tensor
# from collections import OrderedDict 

# class RN50_forward_hook(nn.Module):
#     def __init__(self, output_layers, *args):
#         super().__init__(*args)
#         self.output_layers = output_layers
#         #print(self.output_layers)
#         self.selected_out = OrderedDict()

#         #PRETRAINED MODEL
#         self.pretrained = nn.Sequential(*list(torchvision.models.resnet50(pretrained=True).children())[:-2])
#         for param in self.pretrained.parameters():
#             param.requires_grad = False

#         self.fhooks = []

#         for i,l in enumerate(list(self.pretrained._modules.keys())):
#             if i in self.output_layers:
#                 self.fhooks.append(getattr(self.pretrained,l).register_forward_hook(self.forward_hook(l)))
    
#     def forward_hook(self,layer_name):
#         def hook(module, input, output):
#             self.selected_out[layer_name] = output
#         return hook

#     def forward(self, x):
#         out = self.pretrained(x)
#         return out, self.selected_out


# net = RN50_forward_hook(output_layers = [3,4,5,6,7])



# o, l = net(torch.stack([trainset[key][0] for key in range(10)]))

In [7]:
# pooling = torch.nn.AdaptiveAvgPool3d((256, 7, 7))

# [pooling(l[i]).shape for i in l]

In [None]:
from torch.utils.data import DataLoader, TensorDataset

batch_size = 256

############################################################################

# CIFAR10
transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# MNIST
'''transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.Grayscale(num_output_channels=3),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
     ])'''

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

############################################################################

# init model
model = Net().cuda()

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=4, verbose=True, mode="min")
checkpoint_callback = ModelCheckpoint(dirpath='/content/drive/MyDrive/Colab Notebooks/bio inspired/rn50_checkpoint/', monitor="val_loss", mode="min")

print(model)
print('number of parameter: ',sum(p.numel() for p in model.parameters() if p.requires_grad)/1000000.0, 'M')

trainer = pl.Trainer(accelerator='auto', max_epochs=30, callbacks=[early_stop_callback, checkpoint_callback])

# train the model
trainer.fit(model, trainloader, testloader)


Files already downloaded and verified
Files already downloaded and verified




Net(
  (linw_finetune_module): LinW_Module(
    (rn50): RN50_forward_hook(
      (pretrained): Sequential(
        (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (4): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, 

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                 | Type                       | Params
--------------------------------------------------------------------
0 | linw_finetune_module | LinW_Module                | 24.0 M
1 | loss                 | CrossEntropyLoss           | 0     
2 | accuracy             | MulticlassAccuracy         | 0     
3 | f1_score             | MulticlassF1Score          | 0     
4 | average_precision    | MulticlassAveragePrecision | 0     
---------------------

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved. New best score: 2.269


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.035 >= min_delta = 0.0. New best score: 2.234


In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/