## Let's learn

Let's learn how the attention block works. We want to know what the attention block calculates and what it returns.

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from utils import Normalize
from efficientnet_pytorch import EfficientNet

In [23]:
def init_layer(layer):
    """Initialise a layer in place: Xavier-uniform weight, zero bias.

    The weight tensor is always re-initialised. The bias is filled with
    zeros only when the layer exposes a ``bias`` attribute that is not
    ``None``.
    """
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, "bias", None)
    if bias is None:
        # Layer has no bias attribute, or was built with bias=False.
        return
    bias.data.fill_(0.0)


def init_bn(bn):
    """Reset a batch-norm layer's affine parameters in place.

    The shift (``bias``) is filled with zeros and the scale (``weight``)
    with ones.
    """
    for param, value in ((bn.bias, 0.0), (bn.weight, 1.0)):
        param.data.fill_(value)

class AttBlock(nn.Module):
    """Attention pooling head built from two parallel 1x1 ``Conv1d`` layers.

    ``att`` produces per-position attention logits and ``cla`` produces
    per-position class scores. ``forward`` collapses the last axis by an
    attention-weighted sum, giving one score per output channel.

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels.
        activation: ``"linear"`` (identity) or ``"sigmoid"``; applied to the
            per-position class scores.
        temperature: Stored on the module; ``forward`` does not use it.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear",
                 temperature=1.0):
        super().__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        # Created and initialised here, but not applied in forward().
        self.bn_att = nn.BatchNorm1d(out_features)
        self.init_weights()

    def init_weights(self):
        """Initialise both convolutions and the batch-norm layer."""
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)

    def forward(self, x):
        """Pool ``x`` over its last axis using learned attention weights.

        Args:
            x: Tensor of shape ``(batch, in_features, positions)``.

        Returns:
            A tuple ``(pooled, norm_att, cla)``: ``pooled`` has shape
            ``(batch, out_features)``; ``norm_att`` (softmax over the last
            axis) and ``cla`` both have shape
            ``(batch, out_features, positions)``.
        """
        # The prints trace intermediate shapes (and the pooled values).
        print('input x ',x.shape)

        # tanh bounds the logits to [-1, 1] before the softmax over positions.
        norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1)
        print('input norm_att ',norm_att.shape)
        cla = self.nonlinear_transform(self.cla(x))
        print('cla ',cla.shape)

        x = torch.sum(norm_att * cla, dim=2)
        print('sum shape ',x.shape)
        print('sum shape ',x)

        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation to ``x``.

        Raises:
            ValueError: If ``self.activation`` is neither ``'linear'`` nor
                ``'sigmoid'``.
        """
        if self.activation == 'linear':
            return x
        if self.activation == 'sigmoid':
            return torch.sigmoid(x)
        # Fail loudly instead of returning None, which would only surface
        # later as an opaque TypeError inside forward().
        raise ValueError(
            f"Unsupported activation {self.activation!r}; "
            "expected 'linear' or 'sigmoid'."
        )

class HpaSub(nn.Module):
    """Classification head: global average pooling followed by a small MLP.

    Args:
        classes: Number of output values per sample.
        features: Number of channels in the incoming feature map.
    """

    def __init__(self, classes, features):
        super(HpaSub, self).__init__()
        self.species = nn.Sequential(
            nn.Linear(features, 512),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(512, classes),
        )

    def forward(self, GAP):
        """Average ``GAP`` over its spatial axes and classify the result.

        Args:
            GAP: Feature map of shape ``(batch, features, H, W)``.

        Returns:
            Tensor of shape ``(batch, classes)``.
        """
        # flatten(1) rather than a bare squeeze(): squeeze() would also drop
        # the batch axis when batch == 1 and return shape (classes,).
        pooled = F.avg_pool2d(GAP, GAP.size()[2:]).flatten(1)
        return self.species(pooled)

class HpaModel(nn.Module):
    """Multi-cell classifier: CNN backbone per cell plus attention pooling.

    Each cell image is normalised, projected from 5 to 3 channels with a
    1x1 convolution, encoded by the backbone, globally average-pooled and
    passed through a linear layer. ``AttBlock`` then pools across cells.

    Args:
        classes: Number of output classes.
        device: Forwarded to ``Normalize``.
        base_model_name: Backbone name. Names containing ``'efficientnet'``
            use ``efficientnet_pytorch``; anything else is loaded from the
            ``zhanghang1989/ResNeSt`` torch.hub repository.
        pretrained: Whether to load pretrained backbone weights.
        features: Channel count of the backbone's feature map (this
            notebook uses 1280 with ``'efficientnet-b0'``).
    """

    def __init__(self, classes, device, base_model_name, pretrained, features):
        super(HpaModel, self).__init__()
        self.base_model_name = base_model_name
        # One entry per input channel (5 channels). The last channel uses
        # mean 0 / std 1 -- presumably left unnormalised; confirm against
        # the Normalize implementation in utils.
        mean_list = [0.083170892049318, 0.08627143702844145, 0.05734662013795027, 0.06582942296076659,0.0]
        std_list = [0.13561066140407024, 0.13301454127989584, 0.09142918497144226, 0.15651865713966945,1.]
        self.transform=transforms.Compose([Normalize(mean= mean_list,
                              std= std_list,
                              device = device)])

        if 'efficientnet' in self.base_model_name:
            # Honour `pretrained`: from_name builds the same architecture
            # without loading pretrained weights.
            if pretrained:
                self.model = EfficientNet.from_pretrained(self.base_model_name)
            else:
                self.model = EfficientNet.from_name(self.base_model_name)
        else:
            base_model = torch.hub.load('zhanghang1989/ResNeSt', self.base_model_name, pretrained=pretrained)
            # Drop the last two child modules so the backbone returns a
            # feature map rather than logits.
            layers = list(base_model.children())[:-2]
            self.model = nn.Sequential(*layers)
        # 1x1 convolution mapping the 5 input channels to the 3 channels
        # the backbone receives.
        self.init_layer = nn.Conv2d(in_channels=5, out_channels=3, kernel_size=1, stride=1,bias= True)
        self.fc1 = nn.Linear(features, features, bias=True)
        self.att_block = AttBlock(features, classes, activation="linear")

    def forward(self, x):
        """Classify a batch of multi-cell samples.

        Args:
            x: Tensor of shape ``(batch, cells, C, H, W)``; ``C`` must be 5
                to match ``init_layer``.

        Returns:
            Dict with ``'final_output'`` (attention-pooled scores, shape
            ``(batch, classes)``) and ``'cell_pred'`` (per-cell scores,
            shape ``(batch, classes, cells)``).
        """
        batch_size, cells, C, H, W = x.size()
        # Fold the cell axis into the batch axis so the backbone sees
        # ordinary 4-D image batches.
        c_in = self.transform(x.view(batch_size * cells, C, H, W))
        c_in = F.relu(self.init_layer(c_in))
        if 'efficientnet' in self.base_model_name:
            spe = self.model.extract_features(c_in)
        else:
            spe = self.model(c_in)
        # Global average pool; flatten(1) keeps the batch axis even when
        # batch_size * cells == 1.
        spe = F.avg_pool2d(spe, spe.size()[2:]).flatten(1)
        # Restore the cell axis, then move features to dim 1 because
        # AttBlock's Conv1d layers expect (batch, features, cells).
        spe = F.relu(self.fc1(F.dropout(spe.contiguous().view(batch_size, cells, -1), p=0.5, training=self.training))).permute(0,2,1)
        final_output, norm_att, cell_pred = self.att_block(F.dropout(spe, p=0.5, training=self.training))

        return {'final_output':final_output, 'cell_pred':cell_pred}

In [24]:
# Build the model on CPU with a pretrained EfficientNet-B0 backbone.
model = HpaModel(
    classes=19,
    device=torch.device('cpu'),
    base_model_name='efficientnet-b0',
    pretrained=True,
    features=1280,
)

Loaded pretrained weights for efficientnet-b0


In [12]:
# Dummy all-zero input: 1 sample x 4 cells x 5 channels x 224 x 224 pixels.
input_img = torch.zeros(1, 4, 5, 224, 224, dtype=torch.float32)

In [10]:
# Axes are (batch, cells, channels, height, width), the order HpaModel.forward unpacks.
input_img.shape

torch.Size([1, 4, 5, 224, 224])

In [25]:
# Forward pass: returns a dict with 'final_output' and 'cell_pred'.
# AttBlock.forward prints the intermediate shapes shown below.
output = model(input_img)

input x  torch.Size([1, 1280, 4])
input norm_att  torch.Size([1, 19, 4])
cla  torch.Size([1, 19, 4])
sum shape  torch.Size([1, 19])
sum shape  tensor([[ 0.3346, -0.2422, -0.0350, -0.0470,  0.1430, -0.0996, -0.0047, -0.0552,
          0.4018, -0.2664,  0.1328, -0.0169, -0.0669,  0.1333,  0.1455,  0.0601,
         -0.0338, -0.0407, -0.0215]], grad_fn=<SumBackward1>)
