In [1]:
import torch
from torch import nn
from torchinfo import summary
from torchvision.models.resnet import resnet50
from torch.nn import BatchNorm2d, Softmax2d
from torchvision.models.resnet import ResNet

In [5]:
a, b = torch.randn(10, 50), torch.randn(20, 50)
torch.nn.functional.softmax(torch.cdist(a, b, p=2), dim=-1).shape

torch.Size([10, 20])

In [7]:
torch.nn.Conv2d(1024 + 2048, 10, 1, bias=False).weight.shape

torch.Size([10, 3072, 1, 1])

In [30]:
class IndividualLandmarkNet(torch.nn.Module):
    def __init__(self, init_model: ResNet, num_landmarks: int = 8,
                 num_classes: int = 2000, landmark_dropout: float = 0.3) -> None:
        """
        Parameters
        ----------
        init_model: ResNet
            The pretrained ResNet model
        num_landmarks: int
            Number of landmarks to detect
        num_classes: int
            Number of classes for the classification
        landmark_dropout: float
            Probability of dropping out a given landmark
        """
        super().__init__()

        # The base model
        self.num_landmarks = num_landmarks
        self.conv1 = init_model.conv1
        self.bn1 = init_model.bn1
        self.relu = init_model.relu
        self.maxpool = init_model.maxpool
        self.layer1 = init_model.layer1
        self.layer2 = init_model.layer2
        self.layer3 = init_model.layer3
        self.layer4 = init_model.layer4
        self.finalpool = torch.nn.AdaptiveAvgPool2d(1)

        # New part of the model
        self.softmax: Softmax2d = torch.nn.Softmax2d()
        self.batchnorm = BatchNorm2d(11)
        self.fc_landmarks = torch.nn.Conv2d(1024 + 2048, num_landmarks + 1, 1, bias=False)
        self.fc_class_landmarks = torch.nn.Linear(1024 + 2048, num_classes, bias=False)
        self.modulation = torch.nn.Parameter(torch.ones((1,1024 + 2048,num_landmarks + 1)))
        self.dropout = torch.nn.Dropout(landmark_dropout)
        self.dropout_full_landmarks = torch.nn.Dropout1d(landmark_dropout)

    def forward(self, x: torch.Tensor):
        """

        Parameters
        ----------
        x: torch.Tensor
            Input image

        Returns
        -------
        all_features: torch.Tensor
            Features per landmark
        maps: torch.Tensor
            Attention maps per landmark
        scores: torch.Tensor
            Classification scores per landmark
        """
        # Pretrained ResNet part of the model
        x = self.conv1(x) # shape: [b, 64, h1, w1], e.g. h1=w1=112
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x) # shape: [b, 64, h2, w2], e.g. h2=w2=56
        x = self.layer1(x) # shape: [b, 256, h2, w2], e.g. h2=w2=56
        x = self.layer2(x) # shape: [b, 512, h3, w3], e.g. h2=w2=28
        l3 = self.layer3(x) # shape: [b, 1024, h3, w3], e.g. h2=w2=28
        x = self.layer4(l3) # shape: [b, 2048, h4, w4], e.g. h2=w2=7
        # x = torch.nn.functional.upsample_bilinear(x, size=(l3.shape[-2], l3.shape[-1]))
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear') # shape: [b, 2048, h, w], e.g. h=w=14
        x = torch.cat((x, l3), dim=1) # shape: [b, 2048 + 1024, h, w], e.g. h=w=14

        # Compute per landmark attention maps
        # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps resnet, a = convolution kernel
        batch_size = x.shape[0]
        ab = self.fc_landmarks(x) # shape: [b, nlandmark + 1, h, w]
        print('ab shape:', ab.shape)
        b_sq = x.pow(2).sum(1, keepdim=True) # shape: [b, 1, h, w]
        print('b_sq shape:', b_sq.shape)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1) # shape: [b, nlandmark + 1, h, w]
        print('b_sq shape:', b_sq.shape)
        print('fc_landmarks.weight shape:', self.fc_landmarks.weight.shape)
        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2], x.shape[-1])
        print('a_sq shape:', a_sq.shape)
        a_sq = a_sq.permute(1, 0, 2, 3)
        maps = b_sq - 2 * ab + a_sq
        maps = -maps

        # Softmax so that the attention maps for each pixel add up to 1
        print('maps shape:', maps.shape)
        maps = self.softmax(maps)

        # Use maps to get weighted average features per landmark
        feature_tensor = x
        all_features = ((maps).unsqueeze(1) * feature_tensor.unsqueeze(2)).mean(-1).mean(-1)
        print('all_features shape:', all_features.shape)

        # Classification based on the landmarks
        all_features_modulated = all_features * self.modulation
        all_features_modulated = self.dropout_full_landmarks(all_features_modulated.permute(0,2,1)).permute(0,2,1)
        scores = self.fc_class_landmarks(all_features_modulated.permute(0, 2, 1)).permute(0, 2, 1)

        return all_features, maps, scores

In [31]:
resnet = resnet50()
landmark_net = IndividualLandmarkNet(init_model=resnet)
x = torch.randn(3, 224, 224)

In [32]:
feats, maps, scores = landmark_net(x.unsqueeze(0))

ab shape: torch.Size([1, 9, 14, 14])
b_sq shape: torch.Size([1, 1, 14, 14])
b_sq shape: torch.Size([1, 9, 14, 14])
fc_landmarks.weight shape: torch.Size([9, 3072, 1, 1])
a_sq shape: torch.Size([9, 1, 14, 14])
maps shape: torch.Size([1, 9, 14, 14])
all_features shape: torch.Size([1, 3072, 9])


In [2]:
class IndividualLandmarkNetModified(torch.nn.Module):
    def __init__(self, init_model: ResNet, num_landmarks: int = 8,
                 num_classes: int = 2000, landmark_dropout: float = 0.3) -> None:
        """
        Parameters
        ----------
        init_model: ResNet
            The pretrained ResNet model
        num_landmarks: int
            Number of landmarks to detect
        num_classes: int
            Number of classes for the classification
        landmark_dropout: float
            Probability of dropping out a given landmark
        """
        super().__init__()

        # The base model
        self.num_landmarks = num_landmarks
        self.conv1 = init_model.conv1
        self.bn1 = init_model.bn1
        self.relu = init_model.relu
        self.maxpool = init_model.maxpool
        self.layer1 = init_model.layer1
        self.layer2 = init_model.layer2
        self.layer3 = init_model.layer3
        self.layer4 = init_model.layer4
        self.finalpool = torch.nn.AdaptiveAvgPool2d(1)

        # New part of the model
        self.softmax: Softmax2d = torch.nn.Softmax2d()
        self.batchnorm = BatchNorm2d(11)
        # self.fc_landmarks = torch.nn.Conv2d(1024 + 2048, num_landmarks + 1, 1, bias=False)
        self.landmarks = torch.nn.Parameter(torch.randn(num_landmarks + 1, 1024 + 2048))

        self.fc_class_landmarks = torch.nn.Linear(1024 + 2048, num_classes, bias=False)
        self.modulation = torch.nn.Parameter(torch.ones((1,1024 + 2048,num_landmarks + 1)))
        self.dropout = torch.nn.Dropout(landmark_dropout)
        self.dropout_full_landmarks = torch.nn.Dropout1d(landmark_dropout)

    def forward(self, x: torch.Tensor):
        """

        Parameters
        ----------
        x: torch.Tensor
            Input image

        Returns
        -------
        all_features: torch.Tensor
            Features per landmark
        maps: torch.Tensor
            Attention maps per landmark
        scores: torch.Tensor
            Classification scores per landmark
        """
        # Pretrained ResNet part of the model
        x = self.conv1(x) # shape: [b, 64, h1, w1], e.g. h1=w1=112
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x) # shape: [b, 64, h2, w2], e.g. h2=w2=56
        x = self.layer1(x) # shape: [b, 256, h2, w2], e.g. h2=w2=56
        x = self.layer2(x) # shape: [b, 512, h3, w3], e.g. h2=w2=28
        l3 = self.layer3(x) # shape: [b, 1024, h3, w3], e.g. h2=w2=28
        x = self.layer4(l3) # shape: [b, 2048, h4, w4], e.g. h2=w2=7
        # x = torch.nn.functional.upsample_bilinear(x, size=(l3.shape[-2], l3.shape[-1]))
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear') # shape: [b, 2048, h, w], e.g. h=w=14
        x = torch.cat((x, l3), dim=1) # shape: [b, 2048 + 1024, h, w], e.g. h=w=14

        # Compute per landmark attention maps
        b, c, h, w = x.shape
        x_flat = x.reshape(b, c, h*w).permute(0, 2, 1) # shape: [b, h*w, 2048 + 1024]
        maps = torch.cdist(x_flat, self.landmarks, p=2) # shape: [b, h*w, nlandmarks]
        maps = maps.permute(0, 2, 1).reshape(b, -1, h, w) # shape: [b, nlandmarks, h, w]
        # Softmax so that the attention maps for each pixel add up to 1
        maps = self.softmax(-maps)

        # Use maps to get weighted average features per landmark
        feature_tensor = x
        all_features = ((maps).unsqueeze(1) * feature_tensor.unsqueeze(2)).mean(-1).mean(-1)

        # Classification based on the landmarks
        all_features_modulated = all_features * self.modulation
        all_features_modulated = self.dropout_full_landmarks(all_features_modulated.permute(0,2,1)).permute(0,2,1)
        scores = self.fc_class_landmarks(all_features_modulated.permute(0, 2, 1)).permute(0, 2, 1)

        return all_features, maps, scores

In [4]:
resnet = resnet50()
landmark_net = IndividualLandmarkNetModified(init_model=resnet)
x = torch.randn(3, 224, 224)
summary(landmark_net)

Layer (type:depth-idx)                   Param #
IndividualLandmarkNetModified            55,296
├─Conv2d: 1-1                            9,408
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─Bottleneck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─Bottleneck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13           

In [49]:
feats, maps, scores = landmark_net(x.unsqueeze(0))

maps shape: torch.Size([1, 9, 14, 14])
all_features shape: torch.Size([1, 3072, 9])
