# MTAN
- MTAN (Multi-Task Attention Network) is a state-of-the-art architecture for multi-task learning
- The official implementation is a bit messy, so here I propose a simpler version and test that it produces the same outputs
- I provide implementations of MTAN on VGG16 and ResNet backbones
- To keep it simple enough for a notebook, these implementations are for multi-domain classification rather than dense prediction
- Extending them to dense prediction can be done by attaching decoders/DeepLab heads and potentially keeping track of pooling indices
- Using a dedicated attention module, I managed to significantly reduce the size of the VGG/ResNet code
- This notebook is structured as follows: notes on MTAN, my simplified implementation, and tests against the original (copy-pasted) source code.

![](https://vitalab.github.io/article/images/MTAN/architecture.png)

Notes on MTAN:
- Soft-parameter sharing method, so each task has its own parameters/activations at every level, and hence a different computational graph
- There is one attention module per block/stage
- The attention module learns to output a mask per task, which is then applied to the representations of the shared backbone
- Each attention module has 3 inputs: two from the shared backbone, and the output of the previous attention module
- Each attention module itself has task-specific and shared parameters
- The task-specific attention parameters only compute the mask
- The shared attention parameters act as an alternative backbone, hence they are built similarly to the backbone architecture
- The inputs to the attention module are intermediate outputs of each backbone block/stage. Where these outputs are taken from seems somewhat arbitrary and architecture dependent. One input (the one the task-specific mask is applied to) is always the output of the backbone stage. The other input, which goes into the computation of the mask, can be taken either from the beginning or from somewhere in the middle of the block/stage.
- The task-specific mask is computed from both the backbone features and the output of the previous attention module; these are concatenated, hence the doubled input channels of the attention convs.
- The implementation becomes messy because of the inputs to the attention module: they require outputs of conv layers/blocks from inside the backbone stages. This is inconvenient, since one cannot just encapsulate the backbone; one has to run each layer manually and retrieve the selected intermediate representations.
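To make the data flow of one task's attention step concrete, here is a torch-free toy sketch in which every feature map is collapsed to a list of per-channel scalars and the two task-specific 1x1 convs are replaced by a single hypothetical weight matrix (`weights`). BN, ReLU and the shared conv at the end are left out, so this only illustrates the concatenation and the sigmoid gating described above, not the real layers.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def toy_attention_step(in1, in2, att_in, weights):
    """One task's attention step, with each "feature map" reduced to a list of per-channel scalars.

    in1:     backbone features used to compute the mask (length C)
    in2:     backbone features the mask is applied to (length C)
    att_in:  output of the previous attention module (length C), or None for the first stage
    weights: hypothetical C x len(x) matrix standing in for the task-specific 1x1 convs
    """
    x = list(in1) if att_in is None else list(in1) + list(att_in)  # concatenation doubles the channels
    if any(len(row) != len(x) for row in weights):
        raise ValueError(f"each weight row must have {len(x)} entries")
    mask = [sigmoid(sum(w * v for w, v in zip(row, x))) for row in weights]  # values in (0, 1)
    return [m * f for m, f in zip(mask, in2)]  # element-wise gating of the shared features


first = toy_attention_step([1.0, 2.0], [4.0, -2.0], None, [[0.0, 0.0], [0.0, 0.0]])
print(first)  # [2.0, -1.0]: a zero pre-activation gives a 0.5 mask everywhere
```

Because the mask lies strictly between 0 and 1, the attention can only rescale the shared features channel by channel; it never creates new ones.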

# Simple Implementation
- The key is to define a dedicated attention module in a smart way
- Because the shared part of the attention module is architecture dependent, it has to be passed in as an argument
- Also, generally speaking, VGG and ResNet architectures appear to be parametrised (as in having architecture hyper-parameters), but in truth they have many idiosyncrasies that make them much less generalisable. Therefore, sometimes the most elegant way is to just hardcode indices or stages.
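As an example of such hardcoded indices, the VGG16 pairs used later can be recovered from the standard VGG16 layer configuration (torchvision's config "D", where numbers are conv widths and `"M"` is a maxpool). Counting one collected activation per conv-bn-relu, `in1` is the first activation of each stage and `in2` the last. The helper name is hypothetical; this sketches where the numbers come from rather than something the implementation calls.

```python
# Standard VGG16 layer configuration: numbers are conv widths, "M" is a 2x2 maxpool.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"]


def attention_input_indices(cfg):
    """Per stage, the (first, last) index among the collected conv-bn-relu outputs."""
    pairs, stage, relu_idx = [], [], 0
    for item in cfg:
        if item == "M":  # a maxpool closes the current stage
            pairs.append((stage[0], stage[-1]))
            stage = []
        else:  # each conv-bn-relu contributes one collected activation
            stage.append(relu_idx)
            relu_idx += 1
    return pairs


print(attention_input_indices(VGG16_CFG))  # [(0, 1), (2, 3), (4, 6), (7, 9), (10, 12)]
```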

In [1]:
import torch
import itertools
from torch import nn
from typing import List, Dict
import torch.nn.functional as F
from torchvision.models import vgg16_bn, resnet18, resnet50, resnet101
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1


class ConvBNReLU(nn.Sequential):
    """Shorthand for conv layer"""
    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

def init_model(model: nn.Module):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
    return model

## MTAN Attention
- The `shared_att` sub-module has to be defined outside, as it depends on the backbone

In [2]:
class MTANAttention(nn.Module):
    """
    Combines task specific attentions and shared parameters
    Task-specific params, which compute the attention mask, are always defined the same way
    Shared parameters depend on the architecture, as they try to emulate components of the backbone
    """
    
    def __init__(self, tasks: List[str], in_channels: int, mid_channels: int, out_channels: int, shared_att: nn.Module = None):
        """
        :param in_channels: If this module receives inputs from a previous attention module, then in_channels has to be 2x the equivalent backbone input
        :param shared_att: Architecture dependent - for VGG it's a 3x3 conv, for ResNets it's a bottleneck layer; optional, if missing the module outputs the masked in2
        """
        super().__init__()
        self.task_att = nn.ModuleDict({
            task: nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=1, padding=0),
                    nn.BatchNorm2d(mid_channels),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, padding=0),
                    nn.BatchNorm2d(out_channels),
                    nn.Sigmoid(),
                )
            for task in tasks
        }) 
        self.shared_att = shared_att if shared_att is not None else nn.Sequential()
            
    def forward(self, task: str, in1: torch.Tensor, in2: torch.Tensor, att_in: torch.Tensor = None):
        """
        :param in1: the output from somewhere in the shared backbone block/stage (first vertical line in figure)
        :param in2: the output from the shared backbone  block/stage (second vertical line in figure)
        :param att_in: Optional; the output of the previous attention module
        """
        out = in1
        if att_in is not None:
            out = torch.cat([out, att_in], dim=1) # merge output from previous attention module if available
        out = self.task_att[task](out)  # The two task-specific convs
        out = out * in2 # element-wise multiplication with second input (from shared blocks)
        out = self.shared_att(out) # The shared conv at the end
        return out
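Since the task-specific branch is duplicated for every task while the shared convs are not, a quick way to see what each extra task costs is to count the learnable parameters of one `task_att` branch. This sketch assumes the layer stack defined above (1x1 convs with bias, affine BatchNorm); BatchNorm running statistics are buffers and are not counted. The sizes are the ones passed to `MTANAttention` in the VGG16 model; the function name is hypothetical.

```python
def task_att_params(in_channels: int, mid_channels: int, out_channels: int) -> int:
    """Learnable parameters of one branch: conv1x1 -> BN -> ReLU -> conv1x1 -> BN -> Sigmoid."""
    conv1 = in_channels * mid_channels + mid_channels    # 1x1 kernel weights + bias
    bn1 = 2 * mid_channels                               # BN weight and bias
    conv2 = mid_channels * out_channels + out_channels
    bn2 = 2 * out_channels
    return conv1 + bn1 + conv2 + bn2                     # ReLU and Sigmoid have no parameters


# (in_channels, mid_channels, out_channels) of the five attention modules in MTANVGG16
VGG_ATT_SIZES = [(64, 64, 64), (256, 128, 128), (512, 256, 256), (1024, 512, 512), (1024, 512, 512)]
per_task = sum(task_att_params(*sizes) for sizes in VGG_ATT_SIZES)
print(per_task)  # 1835648 extra parameters for every task added to the VGG16 model
```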

## VGG16

In [3]:
class MTANVGG16(nn.Module):
    """
    For multi-domain classification. Attaches a classifier on top of outputs from last attention module.
    Using vgg16_bn model as backbone from torchvision
    """
    
    def __init__(self, tasks: Dict[str, int]):
        """
        :param tasks: {task_name: n_classes} e.g {"T1": 5, "T2": 3}
        """
        super().__init__()
        self.tasks = tasks
        self.shared_layers = vgg16_bn().features[:-1] # dropping the last maxpool
        # Not worth trying to use loops
        self.attentions = nn.ModuleList([
            MTANAttention(tasks, 64, 64, 64, ConvBNReLU(64, 128)),
            MTANAttention(tasks, 2 * 128, 128, 128, ConvBNReLU(128, 256)), # in_channels is twice as big due to concatenation
            MTANAttention(tasks, 2 * 256, 256, 256, ConvBNReLU(256, 512)), 
            MTANAttention(tasks, 2 * 512, 512, 512, ConvBNReLU(512, 512)),
            MTANAttention(tasks, 2 * 512, 512, 512, ConvBNReLU(512, 512)),
        ])
        self.classifier = nn.ModuleDict({task: nn.Linear(512, task_classes) for task, task_classes in self.tasks.items()})
        
    def forward(self, task: str, X):
        """
        Classification output for a single task
        """
        
        # Need to keep activations from conv-bn-relu layers, not more granular than that
        sh_outs = []
        for layer in self.shared_layers:
            X = layer(X)
            if isinstance(layer, nn.ReLU):  # check the layer, not the output tensor
                sh_outs.append(X)
        
        maxpool = lambda x: F.max_pool2d(x, kernel_size=2, stride=2)
        sh_out_idx = [(0, 1), (2, 3), (4, 6), (7, 9), (10, 12)] # indices of the relevant outputs from the shared backbone that go into the attention module
        att_out = None
        for block, (in1_idx, in2_idx) in enumerate(sh_out_idx):
            att_out = maxpool(self.attentions[block](task, sh_outs[in1_idx], sh_outs[in2_idx], att_out))
            
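        # Global average pooling so nn.Linear(512, ...) fits any input size (a no-op for 32x32 inputs)
        out = torch.flatten(F.adaptive_avg_pool2d(att_out, 1), 1)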
        out = self.classifier[task](out)
        return out
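The constructor arguments in `MTANVGG16` follow from the five stage widths: every module after the first receives the backbone features concatenated with the previous shared conv output, and each shared conv maps to the width of the next stage. The sketch below rebuilds that list and checks the consistency. It also counts the features that reach the classifier if the last attention output is simply flattened: the five 2x2 maxpools divide the side by 32, so a plain flatten only matches `nn.Linear(512, ...)` for 32x32 inputs, and larger inputs need global pooling first. Function names are hypothetical.

```python
FILTERS = [64, 128, 256, 512, 512]  # output widths of the five VGG16 stages


def vgg_attention_args(filters):
    """(in_channels, mid_channels, out_channels, shared_out) for each attention module."""
    args = []
    for i, c in enumerate(filters):
        in_ch = c if i == 0 else 2 * c  # later stages also concatenate the previous attention output
        shared_out = filters[min(i + 1, len(filters) - 1)]  # the last shared conv keeps 512 channels
        args.append((in_ch, c, c, shared_out))
    return args


def flattened_features(input_side, channels=512, n_pools=5):
    """Features left after flattening the last attention output without any global pooling."""
    side = input_side
    for _ in range(n_pools):  # one 2x2 / stride-2 maxpool after every attention module
        side //= 2
    return channels * side * side


args = vgg_attention_args(FILTERS)
# From the second stage on, in_channels = backbone width + previous shared output width
assert all(args[i][0] == FILTERS[i] + args[i - 1][3] for i in range(1, len(FILTERS)))
print(args)
print(flattened_features(32), flattened_features(256))  # 512 32768
```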

## ResNets

In [4]:
class MTANResNet(nn.Module):
    """
    Should work for both Basic and Bottleneck based ResNets
    Assuming the first input to the attention module is always the output of the penultimate block in a stage
    Task-specific layers have a 4x bottleneck
    Uses maxpool between stages
    """
    def __init__(self, tasks: Dict[str, int], backbone: nn.Module = None):
        """
        :param tasks: {task_name: n_classes} e.g {"T1": 5, "T2": 3}
        :param backbone: torchvision ResNet to use; defaults to a fresh resnet50() per instance
        """
        super().__init__()
        backbone = backbone if backbone is not None else resnet50()  # no default instance shared across models
        self.tasks = tasks
        self.shared_conv = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        
        # splitting each stage in two for the in1 and in2 inputs; and storing as one list
        self.shared_layers = nn.ModuleList(
            itertools.chain.from_iterable(
                [(layer[:-1], layer[-1]) 
                 for layer in [backbone.__getattr__(f"layer{i}") for i in range(1,5)]
                ]
        ))
        
        BlockCls = self.shared_layers[-1].__class__  # hopefully BasicBlock or Bottleneck
        exp = BlockCls.expansion  # 4 for Bottleneck, 1 for BasicBlock
        # Shared block maps ch_in -> exp * planes channels; the 1x1 downsample matches the residual branch
        shared_att = lambda ch_in, planes: BlockCls(ch_in, planes, downsample=nn.Sequential(conv1x1(ch_in, exp * planes), nn.BatchNorm2d(exp * planes)))

        ch = [l.conv1.in_channels for l in self.shared_layers[1::2]]  # input channels of the first conv of each stage's last block
        self.attentions = nn.ModuleList([
            MTANAttention(tasks, ch[0], ch[0] // 4, ch[0], shared_att(ch[0], ch[1] // exp)),
            MTANAttention(tasks, 2 * ch[1], ch[1] // 4, ch[1], shared_att(ch[1], ch[2] // exp)),
            MTANAttention(tasks, 2 * ch[2], ch[2] // 4, ch[2], shared_att(ch[2], ch[3] // exp)),
            MTANAttention(tasks, 2 * ch[3], ch[3] // 4, ch[3]),  # final attention doesn't have a shared part
        ])
        self.classifier = nn.ModuleDict({task: nn.Linear(ch[3], task_classes) for task, task_classes in self.tasks.items()})
        
    def forward(self, task: str, X):
        """
        Classification output for a single task
        """
        out = self.shared_conv(X)
        sh_outs = [out := l(out) for l in self.shared_layers]
        
        maxpool = lambda x: F.max_pool2d(x, kernel_size=2, stride=2)
        att_out = None
        for stage in range(4):
            att_out = self.attentions[stage](task, sh_outs[stage * 2], sh_outs[stage * 2 + 1], att_out)
            if stage < 3:
                # NOTE: This can/should be changed if backbone is different (e.g uses dilation)
                att_out = maxpool(att_out)
        
        out = F.adaptive_avg_pool2d(att_out, 1)  # global average pool; same as avg_pool2d(att_out, 8) for 256x256 inputs
        out = torch.flatten(out, 1)
        out = self.classifier[task](out)
        return out
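For the concatenation in the next attention module to line up, each shared block must output exactly as many channels as the next backbone stage. A torchvision block built with `planes` outputs `expansion * planes` channels (`Bottleneck.expansion == 4`, `BasicBlock.expansion == 1`), so the `planes` argument has to be the next stage width divided by the block's expansion. The sketch below derives the per-stage arguments for resnet50 and resnet18 and checks that rule; the function names are hypothetical.

```python
def resnet_attention_args(stage_widths, expansion):
    """(in_channels, mid_channels, out_channels, shared_planes) for the four attention modules.

    stage_widths: output channels of layer1..layer4, e.g. [256, 512, 1024, 2048] for resnet50
    expansion:    torchvision block expansion (Bottleneck.expansion == 4, BasicBlock.expansion == 1)
    shared_planes is the `planes` argument of the shared block (None for the last stage, which has none).
    """
    args = []
    for i, c in enumerate(stage_widths):
        in_ch = c if i == 0 else 2 * c  # later stages also concatenate the previous attention output
        planes = stage_widths[i + 1] // expansion if i + 1 < len(stage_widths) else None
        args.append((in_ch, c // 4, c, planes))  # task-specific branch uses a 4x bottleneck
    return args


def shared_outputs_match(stage_widths, expansion):
    """A block built with `planes` outputs expansion * planes channels; that must equal the next stage width."""
    args = resnet_attention_args(stage_widths, expansion)
    return all(expansion * planes == stage_widths[i + 1] for i, (_, _, _, planes) in enumerate(args[:-1]))


print(resnet_attention_args([256, 512, 1024, 2048], expansion=4))  # resnet50 (Bottleneck)
print(resnet_attention_args([64, 128, 256, 512], expansion=1))     # resnet18 (BasicBlock)

# Dividing by a hardcoded 4 instead of the expansion breaks BasicBlock networks:
# BasicBlock(64, 128 // 4) would output 32 channels, while the next stage expects 128.
assert 1 * (128 // 4) != 128
```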

# Tests
- Instead of running benchmarks, I just compare the output of my implementation to the output of the original
- The official MTAN implementation does not have a library structure, so I cannot easily import individual modules. I decided to copy-paste the implementations and comment out everything that is not relevant to the comparison.
- I copy layer weights between the two models rather than fiddling with the random seed
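The weight-copying cells pair convolutions purely by position with `zip`, which silently stops at the shorter list. A stricter pairing step, sketched here on hypothetical `(name, shape)` lists, fails loudly when the counts or shapes differ. In the notebook such lists could be built with something like `[(n, tuple(m.weight.shape)) for n, m in model.named_modules() if isinstance(m, nn.Conv2d)]`; that wiring is an assumption and not part of the original tests.

```python
def pair_by_order(src, dst):
    """Pair tensors by position, refusing silent truncation or shape mismatches.

    src, dst: lists of (name, shape) tuples in module order.
    Returns a list of (src_name, dst_name) pairs.
    """
    if len(src) != len(dst):
        raise ValueError(f"{len(src)} source vs {len(dst)} destination tensors")
    pairs = []
    for (s_name, s_shape), (d_name, d_shape) in zip(src, dst):
        if tuple(s_shape) != tuple(d_shape):
            raise ValueError(f"shape mismatch: {s_name}{tuple(s_shape)} vs {d_name}{tuple(d_shape)}")
        pairs.append((s_name, d_name))
    return pairs


ref_convs = [("encoder_block.0.0", (64, 3, 3, 3)), ("conv_block_enc.0.0", (64, 64, 3, 3))]
my_convs = [("shared_layers.0", (64, 3, 3, 3)), ("shared_layers.3", (64, 64, 3, 3))]
print(pair_by_order(ref_convs, my_convs))
```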

## VGG16
- VGG16 for multi-domain classification is never actually used in the paper; however, we can compare against the encoder part of SegNet

### Original Implementation
- Commenting out the decoder part
- Outputting only the last attention outputs of the encoder (which would feed into classifiers if there were any)
- No other changes to the original code

In [5]:
class SegNet(nn.Module):
    """
    This is the original implementation @https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_mtan.py
    """
    def __init__(self):
        super(SegNet, self).__init__()
        # initialise network parameters
        filter = [64, 128, 256, 512, 512]
        self.class_nb = 13

        # define encoder decoder layers
        self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
        self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        for i in range(4):
            self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
            self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))

        # define convolution layer
        self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        for i in range(4):
            if i == 0:
                self.conv_block_enc.append(self.conv_layer([filter[i + 1], filter[i + 1]]))
                self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
            else:
                self.conv_block_enc.append(nn.Sequential(self.conv_layer([filter[i + 1], filter[i + 1]]),
                                                         self.conv_layer([filter[i + 1], filter[i + 1]])))
                self.conv_block_dec.append(nn.Sequential(self.conv_layer([filter[i], filter[i]]),
                                                         self.conv_layer([filter[i], filter[i]])))

        # define task attention layers
        self.encoder_att = nn.ModuleList([nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])])
        self.decoder_att = nn.ModuleList([nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])])
        self.encoder_block_att = nn.ModuleList([self.conv_layer([filter[0], filter[1]])])
        self.decoder_block_att = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])

        for j in range(3):
            if j < 2:
                self.encoder_att.append(nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])]))
                self.decoder_att.append(nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])]))
            for i in range(4):
                self.encoder_att[j].append(self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]]))
                self.decoder_att[j].append(self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]]))

        for i in range(4):
            if i < 3:
                self.encoder_block_att.append(self.conv_layer([filter[i + 1], filter[i + 2]]))
                self.decoder_block_att.append(self.conv_layer([filter[i + 1], filter[i]]))
            else:
                self.encoder_block_att.append(self.conv_layer([filter[i + 1], filter[i + 1]]))
                self.decoder_block_att.append(self.conv_layer([filter[i + 1], filter[i + 1]]))

        self.pred_task1 = self.conv_layer([filter[0], self.class_nb], pred=True)
        self.pred_task2 = self.conv_layer([filter[0], 1], pred=True)
        self.pred_task3 = self.conv_layer([filter[0], 3], pred=True)

        # define pooling and unpooling functions
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)

        self.logsigma = nn.Parameter(torch.FloatTensor([-0.5, -0.5, -0.5]))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def conv_layer(self, channel, pred=False):
        if not pred:
            conv_block = nn.Sequential(
                nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=3, padding=1),
                nn.BatchNorm2d(num_features=channel[1]),
                nn.ReLU(inplace=True),
            )
        else:
            conv_block = nn.Sequential(
                nn.Conv2d(in_channels=channel[0], out_channels=channel[0], kernel_size=3, padding=1),
                nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0),
            )
        return conv_block

    def att_layer(self, channel):
        att_block = nn.Sequential(
            nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0),
            nn.BatchNorm2d(channel[1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=channel[1], out_channels=channel[2], kernel_size=1, padding=0),
            nn.BatchNorm2d(channel[2]),
            nn.Sigmoid(),
        )
        return att_block

    def forward(self, x):
        g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ([0] * 5 for _ in range(5))
        for i in range(5):
            g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))

        # define attention list for tasks
        atten_encoder, atten_decoder = ([0] * 3 for _ in range(2))
        for i in range(3):
            atten_encoder[i], atten_decoder[i] = ([0] * 5 for _ in range(2))
        for i in range(3):
            for j in range(5):
                atten_encoder[i][j], atten_decoder[i][j] = ([0] * 3 for _ in range(2))

        # define global shared network
        for i in range(5):
            if i == 0:
                g_encoder[i][0] = self.encoder_block[i](x)
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
            else:
                g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])

        for i in range(5):
            if i == 0:
                g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
            else:
                g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])

        # define task dependent attention module
        for i in range(3):
            for j in range(5):
                if j == 0:
                    atten_encoder[i][j][0] = self.encoder_att[i][j](g_encoder[j][0])
                    atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
                    atten_encoder[i][j][2] = self.encoder_block_att[j](atten_encoder[i][j][1])
                    atten_encoder[i][j][2] = F.max_pool2d(atten_encoder[i][j][2], kernel_size=2, stride=2)
                else:
                    atten_encoder[i][j][0] = self.encoder_att[i][j](torch.cat((g_encoder[j][0], atten_encoder[i][j - 1][2]), dim=1))
                    atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
                    atten_encoder[i][j][2] = self.encoder_block_att[j](atten_encoder[i][j][1])
                    atten_encoder[i][j][2] = F.max_pool2d(atten_encoder[i][j][2], kernel_size=2, stride=2)
        
        return [att[-1][2] for att in atten_encoder] 


# MY EDIT
#             for j in range(5):
#                 if j == 0:
#                     atten_decoder[i][j][0] = F.interpolate(atten_encoder[i][-1][-1], scale_factor=2, mode='bilinear', align_corners=True)
#                     atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](atten_decoder[i][j][0])
#                     atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1))
#                     atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]
#                 else:
#                     atten_decoder[i][j][0] = F.interpolate(atten_decoder[i][j - 1][2], scale_factor=2, mode='bilinear', align_corners=True)
#                     atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](atten_decoder[i][j][0])
#                     atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1))
#                     atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]

#         # define task prediction layers
#         t1_pred = F.log_softmax(self.pred_task1(atten_decoder[0][-1][-1]), dim=1)
#         t2_pred = self.pred_task2(atten_decoder[1][-1][-1])
#         t3_pred = self.pred_task3(atten_decoder[2][-1][-1])
#         t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)

#         return [t1_pred, t2_pred, t3_pred], self.logsigma

### Test Wrapper
- Wrapping my MTANVGG16 to output the encoder representations for each task, since we don't use the classifiers

In [6]:
class TestVGG(MTANVGG16):
    """
    Just early return representations before classifier
    """
    def __init__(self):
        super().__init__({str(i): i for i in range(3)}) # class counts don't matter: classifiers are unused here
    
    def forward(self, X):
        sh_outs = []
        for layer in self.shared_layers:
            X = layer(X)
            if isinstance(layer, nn.ReLU):
                sh_outs.append(X)
        
        maxpool = lambda x: F.max_pool2d(x, kernel_size=2, stride=2)
        sh_out_idx = [(0, 1), (2, 3), (4, 6), (7, 9), (10, 12)] # indices of the relevant outputs from the shared backbone that go into the attention module
        
        res = []
        for task in ["0", "1", "2"]:
            att_out = None
            for block, (in1_idx, in2_idx) in enumerate(sh_out_idx):
                att_out = maxpool(self.attentions[block](task, sh_outs[in1_idx], sh_outs[in2_idx], att_out))
            res.append(att_out)
        return res

### Initialisation
- For the two implementations to produce the same outputs, they need to have the same network parameters
- We only really need to copy the conv weights: everything else that is used either has no parameters or is initialised to constants by `init_model`.

In [7]:
# init here to set the conv biases and BN layers
my = init_model(TestVGG())
ref = init_model(SegNet())

# COPY conv weights
with torch.no_grad():
    # attentions
    for block in range(5):
        for task in range(3):
            my.attentions[block].task_att[str(task)][0].weight.copy_(ref.encoder_att[task][block][0].weight)
            my.attentions[block].task_att[str(task)][3].weight.copy_(ref.encoder_att[task][block][3].weight)
        my.attentions[block].shared_att[0].weight.copy_(ref.encoder_block_att[block][0].weight)

    # backbone
    ref_backbone = nn.Sequential(*list(itertools.chain(*[list(itertools.chain(l1, l2)) for l1, l2 in zip(ref.encoder_block, ref.conv_block_enc)])))
    ref_backbone_convs = [m for m in ref_backbone.modules() if isinstance(m, nn.Conv2d)]
    my_backbone_convs = [m for m in my.shared_layers.modules() if isinstance(m, nn.Conv2d)]
    for my_conv, ref_conv in zip(my_backbone_convs, ref_backbone_convs):
        my_conv.weight.copy_(ref_conv.weight)



### Comparing Outputs
- Works :)

In [8]:
X = torch.randn(1, 3, 256, 256)

my_out = my(X)
ref_out = ref(X)

print("Outputs match for VGG16:", all([(torch.equal(y1, y2)) for y1, y2 in zip(my_out, ref_out)]))

Outputs match for VGG16: True


## ResNets
- Not used for multi-domain classification in the paper
- Used as part of the DeepLab architecture for dense predictions
- The backbone can be swapped for other ResNets

### Original Implementation
- We can comment out the DeepLab part and just keep the backbone
- The original implementation uses dilated ResNets from its own code; I would have had to copy-paste those as well, so to keep things simple I use a standard ResNet as the backbone instead
- Adding max pooling to the attention outputs of stages 2 and 3 to make up for not using dilation
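The max pooling in the attention branch exists only to match the resolution of the next backbone stage. Tracing the cumulative stride of each stage shows why a standard ResNet needs a pool after attention modules 1 to 3, while a dilated backbone needs one only after the first. The dilated case assumes an output stride of 8 (as produced by, e.g., torchvision's `resnet50(replace_stride_with_dilation=[False, True, True])`), which is consistent with the original code only downsampling after the first attention block. The last line shows that the 8x8 map pooled by `F.avg_pool2d(a_4_i, 8)` in the reference is specific to 256x256 inputs; input sides divisible by 32 are assumed throughout.

```python
def stage_strides(layer_strides, stem_stride=4):
    """Cumulative downsampling factor after each ResNet stage (conv1 + maxpool give the stem stride of 4)."""
    strides, s = [], stem_stride
    for st in layer_strides:
        s *= st
        strides.append(s)
    return strides


def pools_needed(strides):
    """Factor by which each attention output must be downsampled to match the next stage."""
    return [strides[i + 1] // strides[i] for i in range(len(strides) - 1)]


standard = stage_strides([1, 2, 2, 2])  # standard ResNet: layer2-4 each halve the resolution
dilated = stage_strides([1, 2, 1, 1])   # assumed output stride 8: layer3/4 use dilation instead of stride
print(standard, pools_needed(standard))  # [4, 8, 16, 32] [2, 2, 2]
print(dilated, pools_needed(dilated))    # [4, 8, 8, 8] [2, 1, 1]
print(256 // standard[-1], 224 // standard[-1])  # 8 7: last-stage side for 256x256 and 224x224 inputs
```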

In [9]:
class MTANDeepLabv3(nn.Module):
    """
    Orig Implementation from @https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_resnet_mtan/resnet_mtan.py
    """
    
    def __init__(self):
        super(MTANDeepLabv3, self).__init__()
#         backbone = ResnetDilated(resnet.__dict__['resnet50'](pretrained=True)) # MY EDIT
        backbone = resnet50() # Use default resnet instead of dilated for testing; just to minimise dependencies
        ch = [256, 512, 1024, 2048]
        
        self.tasks = ['segmentation', 'depth', 'normal']
        self.num_out_channels = {'segmentation': 13, 'depth': 1, 'normal': 3}
        
#         self.shared_conv = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu1, backbone.maxpool) # MY EDIT
        self.shared_conv = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)

        
        
        # We will apply the attention over the last bottleneck layer in the ResNet. 
        self.shared_layer1_b = backbone.layer1[:-1] 
        self.shared_layer1_t = backbone.layer1[-1]

        self.shared_layer2_b = backbone.layer2[:-1]
        self.shared_layer2_t = backbone.layer2[-1]

        self.shared_layer3_b = backbone.layer3[:-1]
        self.shared_layer3_t = backbone.layer3[-1]

        self.shared_layer4_b = backbone.layer4[:-1]
        self.shared_layer4_t = backbone.layer4[-1]

        # Define task specific attention modules using a similar bottleneck design in residual block
        # (to avoid large computations)
        self.encoder_att_1 = nn.ModuleList([self.att_layer(ch[0], ch[0] // 4, ch[0]) for _ in self.tasks])
        self.encoder_att_2 = nn.ModuleList([self.att_layer(2 * ch[1], ch[1] // 4, ch[1]) for _ in self.tasks])
        self.encoder_att_3 = nn.ModuleList([self.att_layer(2 * ch[2], ch[2] // 4, ch[2]) for _ in self.tasks])
        self.encoder_att_4 = nn.ModuleList([self.att_layer(2 * ch[3], ch[3] // 4, ch[3]) for _ in self.tasks])

        # Define task shared attention encoders using residual bottleneck layers
        # We do not apply shared attention encoders at the last layer,
        # so the attended features will be directly fed into the task-specific decoders.
        self.encoder_block_att_1 = self.conv_layer(ch[0], ch[1] // 4)
        self.encoder_block_att_2 = self.conv_layer(ch[1], ch[2] // 4)
        self.encoder_block_att_3 = self.conv_layer(ch[2], ch[3] // 4)
        
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2)

        # Define task-specific decoders using ASPP modules
#         self.decoders = nn.ModuleList([DeepLabHead(2048, self.num_out_channels[t]) for t in self.tasks])
        
    def forward(self, x, out_size=None):
        # Shared convolution
        x = self.shared_conv(x)
        
        # Shared ResNet block 1
        u_1_b = self.shared_layer1_b(x)
        u_1_t = self.shared_layer1_t(u_1_b)

        # Shared ResNet block 2
        u_2_b = self.shared_layer2_b(u_1_t)
        u_2_t = self.shared_layer2_t(u_2_b)

        # Shared ResNet block 3
        u_3_b = self.shared_layer3_b(u_2_t)
        u_3_t = self.shared_layer3_t(u_3_b)
        
        # Shared ResNet block 4
        u_4_b = self.shared_layer4_b(u_3_t)
        u_4_t = self.shared_layer4_t(u_4_b)

        # Attention block 1 -> Apply attention over last residual block
        a_1_mask = [att_i(u_1_b) for att_i in self.encoder_att_1]  # Generate task specific attention map
        a_1 = [a_1_mask_i * u_1_t for a_1_mask_i in a_1_mask]  # Apply task specific attention map to shared features
        a_1 = [self.down_sampling(self.encoder_block_att_1(a_1_i)) for a_1_i in a_1]
    
        # Attention block 2 -> Apply attention over last residual block
        a_2_mask = [att_i(torch.cat((u_2_b, a_1_i), dim=1)) for a_1_i, att_i in zip(a_1, self.encoder_att_2)]
        a_2 = [a_2_mask_i * u_2_t for a_2_mask_i in a_2_mask]
#         a_2 = [self.encoder_block_att_2(a_2_i) for a_2_i in a_2] # MY EDIT
        a_2 = [self.down_sampling(self.encoder_block_att_2(a_2_i)) for a_2_i in a_2] # NOTE: Add max pooling (like in the Visual Decathlon impl) because of the regular (non-dilated) ResNet backbone
        
        # Attention block 3 -> Apply attention over last residual block
        a_3_mask = [att_i(torch.cat((u_3_b, a_2_i), dim=1)) for a_2_i, att_i in zip(a_2, self.encoder_att_3)]
        a_3 = [a_3_mask_i * u_3_t for a_3_mask_i in a_3_mask]
#         a_3 = [self.encoder_block_att_3(a_3_i) for a_3_i in a_3] # MY EDIT
        a_3 = [self.down_sampling(self.encoder_block_att_3(a_3_i)) for a_3_i in a_3] # NOTE: Add max pooling (like in the Visual Decathlon impl) because of the regular (non-dilated) ResNet backbone
        
        # Attention block 4 -> Apply attention over last residual block (without final encoder)
        a_4_mask = [att_i(torch.cat((u_4_b, a_3_i), dim=1)) for a_3_i, att_i in zip(a_3, self.encoder_att_4)]
        a_4 = [a_4_mask_i * u_4_t for a_4_mask_i in a_4_mask]
        
        pred = [F.avg_pool2d(a_4_i, 8) for a_4_i in a_4]
        return pred
        
#         # Task specific decoders
#         out = [0 for _ in self.tasks]
#         for i, t in enumerate(self.tasks):
#             out[i] = F.interpolate(self.decoders[i](a_4[i]), size=out_size, mode='bilinear', align_corners=True)
#             if t == 'segmentation':
#                 out[i] = F.log_softmax(out[i], dim=1)
#             if t == 'normal':
#                 out[i] = out[i] / torch.norm(out[i], p=2, dim=1, keepdim=True)
#         return out
    
    def att_layer(self, in_channel, intermediate_channel, out_channel):
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=intermediate_channel, kernel_size=1, padding=0),
            nn.BatchNorm2d(intermediate_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=intermediate_channel, out_channels=out_channel, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channel),
            nn.Sigmoid())
        
    def conv_layer(self, in_channel, out_channel):
        downsample = nn.Sequential(conv1x1(in_channel, 4 * out_channel, stride=1),
                                   nn.BatchNorm2d(4 * out_channel))
        return Bottleneck(in_channel, out_channel, downsample=downsample)

### Test Wrapper
- Wrapping my MTANResNet to output the encoder representations for each task, since we don't use the classifiers

In [10]:
class TestResNet(MTANResNet):
    def __init__(self):
        super().__init__({str(i): i for i in range(3)}) # class counts don't matter: classifiers are unused here
    
    def forward(self, X):
        out = self.shared_conv(X)
        sh_outs = [out := l(out) for l in self.shared_layers]
        
        maxpool = lambda x: F.max_pool2d(x, kernel_size=2, stride=2)
        
        res = []
        for task in ["0", "1", "2"]:
            att_out = None
            for stage in range(4):
                att_out = self.attentions[stage](task, sh_outs[stage * 2], sh_outs[stage * 2 + 1], att_out)
                if stage < 3:
                    # NOTE: This can/should be changed if backbone is different (e.g uses dilation)
                    att_out = maxpool(att_out)

            out = F.avg_pool2d(att_out, 8)
            res.append(out)
        
        return res

### Initialisation
- For the two implementations to produce the same outputs, they need to have the same network parameters
- We only really need to copy the conv weights: everything else that is used either has no parameters or is initialised to constants by `init_model`.

In [11]:
# init here to set the conv biases and BN layers
my = init_model(TestResNet())
ref = init_model(MTANDeepLabv3())

# COPY conv weights
with torch.no_grad():
    # attentions
    for stage in range(4):
        my_att = my.attentions[stage]
        
        # task specific
        for task in range(3):
            my_att.task_att[str(task)][0].weight.copy_(getattr(ref, f"encoder_att_{stage+1}")[task][0].weight)
            my_att.task_att[str(task)][3].weight.copy_(getattr(ref, f"encoder_att_{stage+1}")[task][3].weight)
        
        # shared
        if len(list(my_att.shared_att.modules())) > 1:
            ref_convs = [m for m in getattr(ref, f"encoder_block_att_{stage+1}").modules() if isinstance(m, nn.Conv2d)]
            my_convs = [m for m in my_att.shared_att.modules() if isinstance(m, nn.Conv2d)]
            for my_conv, ref_conv in zip(my_convs, ref_convs):
                my_conv.weight.copy_(ref_conv.weight)

    # backbone
    my.shared_conv[0].weight.copy_(ref.shared_conv[0].weight)
    ref_backbone = [f"shared_layer{stage}_{split}" for stage in range(1, 5) for split in ["b", "t"]]
    ref_backbone = nn.Sequential(*[getattr(ref, layer) for layer in ref_backbone])
    ref_backbone_convs = [m for m in ref_backbone.modules() if isinstance(m, nn.Conv2d)]
    my_backbone_convs = [m for m in my.shared_layers.modules() if isinstance(m, nn.Conv2d)]
    for my_conv, ref_conv in zip(my_backbone_convs, ref_backbone_convs):
        my_conv.weight.copy_(ref_conv.weight)

### Comparing Outputs
- Works :)

In [12]:
X = torch.randn(1, 3, 256, 256)

my_out = my(X)
ref_out = ref(X)

print("Outputs match for ResNet50:", all([torch.equal(y1, y2) for y1, y2 in zip(my_out, ref_out)]))

Outputs match for ResNet50: True
