# Cross Stitch Networks
- Paper: Cross-Stitch Networks for Multi-task Learning - Misra 2016 CVPR
- Reference pytorch implementations: [Vandenhende](https://github.com/SimonVandenhende/Multi-Task-Learning-PyTorch/blob/master/models/cross_stitch.py); [MTAN](https://github.com/lorenmt/mtan/blob/master/im2im_pred/model_segnet_cross.py)
    - Existing implementations seem unnecessarily complicated or tightly coupled, hence this notebook
- Simple idea - keep a separate backbone per task and regularly combine their activations (linearly)
    - Backbones are not strictly task-specific - through cross-stitching they are optimised jointly
    - How much of each backbone remains task-specific is controlled by the mixing parameters (alpha, beta)
    - Mixing parameters are initialised (and thought of) as a convex combination; however, there are no constraints during training, which makes them effectively a linear combination
    - Alpha dictates how much weight to give to the task's own backbone, while beta dictates how much weight to give to the other tasks' backbones (activations)
    - During the forward pass the input has to pass through all backbones simultaneously and be cross-stitched
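For two tasks the mixing can be spelled out by hand. Below is a minimal sketch in plain PyTorch with toy shapes and values. It assumes per-channel units stored as `W[i, j, c]`, the same layout as `CrossStitchLayer.units` further down. It also shows that a single SGD step already moves the convex initialisation off the simplex:

```python
import torch

# Two tasks (A, B) whose backbones each produce activations with C channels.
C = 3
x_a = torch.ones(1, C, 2, 2)       # task A activations (N, C, H, W)
x_b = 2 * torch.ones(1, C, 2, 2)   # task B activations

# W[i, j, c]: weight on channel c of task j's activations when computing task i's output.
# Initialised as a convex combination: alpha on the diagonal, beta off the diagonal.
alpha, beta = 0.9, 0.1
W = torch.full((2, 2, C), beta)
W[0, 0] = alpha
W[1, 1] = alpha
W.requires_grad_(True)


def cross_stitch(xa, xb):
    out = []
    for i in range(2):
        # Per-channel weights, broadcast over (N, H, W)
        wa, wb = W[i, 0].view(1, -1, 1, 1), W[i, 1].view(1, -1, 1, 1)
        out.append(wa * xa + wb * xb)
    return out


out_a, out_b = cross_stitch(x_a, x_b)
print(out_a[0, :, 0, 0])  # 0.9 * 1 + 0.1 * 2 = 1.1 for every channel
print(out_b[0, :, 0, 0])  # 0.1 * 1 + 0.9 * 2 = 1.9 for every channel

# One SGD step: nothing keeps the rows of W summing to 1, so they become a general linear combination.
loss = (out_a + out_b).sum()
loss.backward()
with torch.no_grad():
    W -= 0.01 * W.grad
print(W.sum(dim=1)[:, 0])  # row sums drift away from 1.0 (to 0.88 here)
```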

![](https://d3i71xaburhd42.cloudfront.net/2976605dc3b73377696537291d45f09f1ab1fbf5/3-Figure3-1.png)
- Downsides:
    - The number of parameters grows linearly with the number of tasks, since every task adds a full backbone (a lot)
    - Works only in single-input multi-output settings (not multi-domain)
    - Because of the high parameter count, it might not be possible to fit in memory if the backbone is large or we have many tasks
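A rough way to quantify the parameter growth: the backbones are duplicated per task, while each stitched stage only adds `n_tasks^2 x channels` mixing weights. The helper below is a hypothetical sketch. The 11,689,512 figure is the parameter count of torchvision's `resnet18`, including its `fc` layer, which is unused by the backbones here but still held in the module:

```python
def cross_stitch_param_count(backbone_params: int, n_tasks: int, stage_channels: list) -> int:
    """Hypothetical helper: total parameters of a per-channel cross-stitch network."""
    duplicated_backbones = n_tasks * backbone_params     # one full backbone per task
    mixing_weights = n_tasks ** 2 * sum(stage_channels)  # one [n_tasks x n_tasks] unit per channel per stage
    return duplicated_backbones + mixing_weights


RESNET18_PARAMS = 11_689_512            # torchvision resnet18, fc layer included
RESNET18_CHANNELS = [64, 128, 256, 512]

for n_tasks in (1, 2, 4):
    print(n_tasks, cross_stitch_param_count(RESNET18_PARAMS, n_tasks, RESNET18_CHANNELS))
# 1 11690472
# 2 23382864
# 4 46773408
```

The mixing weights stay in the thousands. Essentially all of the growth comes from duplicating the backbones.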

![](https://d3i71xaburhd42.cloudfront.net/2976605dc3b73377696537291d45f09f1ab1fbf5/4-Figure4-1.png)

- Implementation notes:
    - The authors suggest placing cross-stitch units after pooling layers; in ResNets the equivalent would be after each stage (when downsampling)
    - Cross-stitching is meant to be applied per channel. A "unit" refers to one [n_tasks x n_tasks] matrix for one channel.
        - There is, however, one case in the paper where it is applied per layer (one matrix shared by all channels). The MTAN implementation does this.
    - The suggested initialisation for cross-stitch units is unclear: the paper has ablations but does not specify what the default should be
        - Vandenhende uses (a=0.9, b=0.1), while Liu uses (a=1, b=1)
    - Backbone initialisation - the paper prefers single-task pretraining. For a fair comparison with related work, the backbones should be initialised identically. In practice the backbones are pre-trained networks.
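To make the granularity and initialisation options concrete, here is a hypothetical `init_units` helper (not taken from either reference implementation). It builds the mixing tensor with alpha on the task diagonal and beta elsewhere, either per channel or once per layer:

```python
import torch


def init_units(n_tasks: int, alpha: float, beta: float, num_channels=None) -> torch.Tensor:
    """Hypothetical helper: alpha on the task diagonal, beta everywhere else.

    num_channels=None -> one [n_tasks x n_tasks] matrix for the whole layer
    num_channels=C    -> one matrix per channel, shape [n_tasks, n_tasks, C]
    """
    shape = (n_tasks, n_tasks) if num_channels is None else (n_tasks, n_tasks, num_channels)
    units = torch.full(shape, beta)
    idx = torch.arange(n_tasks)
    units[idx, idx] = alpha  # sets the diagonal (for every channel in the per-channel case)
    return units


per_channel = init_units(2, alpha=0.9, beta=0.1, num_channels=64)
per_layer = init_units(2, alpha=1.0, beta=1.0)
print(per_channel.shape, per_layer.shape)  # torch.Size([2, 2, 64]) torch.Size([2, 2])

# With identically initialised backbones every task sees the same activations x at a stitch,
# so each stitched output starts as (alpha + (n_tasks - 1) * beta) * x.
print(per_channel.sum(dim=1))  # row sums of 1: outputs initially equal x
print(per_layer.sum(dim=1))    # row sums of 2: outputs initially equal 2 * x
```

Under the identical-backbone assumption, the (0.9, 0.1) initialisation therefore starts out reproducing each single backbone's activations. The (1, 1) initialisation starts with every stitched activation doubled.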
    
- Tests
    - Testing against the Vandenhende implementation for ResNet backbones. It had to be modified to decouple it from the rest of the library.



## Standalone CrossStitch Implementation
- Backbone/encoder logic only

- Minimal Implementation
    - Only for resnets, very similar to reference

- Generic Implementation
    - works for resnets/vggs
    - has a bit of bloat to determine stages and channel sizes
    - Factory method to instantiate common resnet/vgg arch based on str
        - patches up vggs to have "stage" structure
        - identical task backbones

- Use cases:
    - DeepLabs - Encapsulate and add heads
    - SegNet: Extend to patch maxpools and override forward to handle indices; encapsulate and add decoders
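The "encapsulate and add heads" use case could look roughly like the following sketch. `MultiTaskDenseModel` and `_ToyBackbone` are hypothetical names. The toy backbone only stands in for a cross-stitch backbone that returns `{task: features}`. The bilinear upsampling to input size mirrors the commented-out head logic in the Vandenhende code further down:

```python
import torch
from torch import nn
import torch.nn.functional as F


class MultiTaskDenseModel(nn.Module):
    """Hypothetical wrapper: cross-stitch backbone (dict output) + one head per task."""

    def __init__(self, backbone: nn.Module, heads: dict):
        super().__init__()
        self.backbone = backbone           # any module mapping a tensor -> {task: features}
        self.heads = nn.ModuleDict(heads)  # {task: head}

    def forward(self, x: torch.Tensor) -> dict:
        img_size = x.shape[-2:]
        feats = self.backbone(x)
        # Each task's head sees its own cross-stitched features; predictions are upsampled to input size
        return {
            task: F.interpolate(head(feats[task]), size=img_size, mode="bilinear", align_corners=False)
            for task, head in self.heads.items()
        }


class _ToyBackbone(nn.Module):
    """Stand-in for a cross-stitch backbone: returns {task: 8-channel map at 1/4 resolution}."""

    def __init__(self, tasks):
        super().__init__()
        self.tasks = list(tasks)
        self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=4, padding=1)

    def forward(self, x):
        f = self.conv(x)
        return {t: f for t in self.tasks}


tasks = ["seg", "depth"]
model = MultiTaskDenseModel(
    _ToyBackbone(tasks),
    heads={"seg": nn.Conv2d(8, 21, kernel_size=1), "depth": nn.Conv2d(8, 1, kernel_size=1)},
)
out = model(torch.rand(2, 3, 64, 64))
print({t: tuple(o.shape) for t, o in out.items()})  # {'seg': (2, 21, 64, 64), 'depth': (2, 1, 64, 64)}
```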

In [1]:
import copy
import torch
import torchvision
from torch import nn
from collections import OrderedDict


class CrossStitchLayer(nn.Module):
    """
    Keeps all mixing weights in a single tensor. Dict input and output.
    :author: @xapharius
    """

    def __init__(self, tasks: list, num_channels: int, alpha=0.9, beta=0.1):
        super().__init__()
        self.num_channels = num_channels
        self.task_idx = {t: idx for idx, t in enumerate(tasks)}

        # units[task1][task2] returns the weights that should be applied on the channels of task2's backbone when computing the output of task1
        self.units = nn.Parameter(torch.Tensor(len(tasks), len(tasks), num_channels))
        self.init(alpha, beta)

    def __repr__(self):
        return f"CrossStitchLayer(tasks={list(self.task_idx.keys())}, num_channels={self.num_channels})"

    def __getitem__(self, item):
        """allows for [ti, tj] indexing"""
        ti, tj = item
        return self.units[self.task_idx[ti], self.task_idx[tj]]

    def init(self, alpha: float, beta: float):
        self.units.data.fill_(beta)
        n_tasks = len(self.task_idx)
        self.units.data[range(n_tasks), range(n_tasks), :] = alpha
        return self

    def forward(self, x: dict) -> dict:
        return {
            ti: torch.sum(
                torch.stack([self[ti, tj].view(1, -1, 1, 1) * x[tj] for tj in x]),
                dim=0,
            )
            for ti in x
        }


class CrossStitchResNet(nn.Module):
    """
    Resnets only, requires passing num channels for each stage. 
    Generic implementation is lengthier but has more helper functions.
    """
    STAGES = ["layer1", "layer2", "layer3", "layer4"]

    def __init__(self, backbones: dict, channels: list, alpha=0.9, beta=0.1):
        super().__init__()
        self.backbones = nn.ModuleDict(backbones)
        self.tasks = list(backbones.keys())

        cs = {stage: CrossStitchLayer(self.tasks, c_, alpha, beta) for stage, c_ in zip(self.STAGES, channels)}
        self.cross_stitch = nn.ModuleDict(cs)

    def forward(self, x: torch.Tensor) -> dict:
        x = {
            task: m.maxpool(m.relu(m.bn1(m.conv1(x))))
            for task, m in self.backbones.items()
        }

        for stage in self.STAGES:
            for task in self.tasks:
                x[task] = getattr(self.backbones[task], stage)(x[task])
            x = self.cross_stitch[stage](x)
        return x


# cs = CrossStitchLayer(["task1", "task2"], num_channels=2)
# cs({"task1": torch.ones(1, 2, 4, 4), "task2": -torch.ones(1, 2, 4, 4)})  # inputs must be (N, C, H, W)
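
The stack-and-sum in `CrossStitchLayer.forward` loops over task pairs in Python. Suppose the task features were stacked into one tensor with a leading task axis. That is an assumption, since the notebook keeps them in a dict. The same per-channel mixing could then be written as a single `torch.einsum` contraction. A sketch checking that the two forms agree:

```python
import torch

# units[i, j, c]: weight on channel c of task j's features when computing task i's output
n_tasks, N, C, H, W = 3, 2, 4, 5, 5
units = torch.rand(n_tasks, n_tasks, C)
feats = torch.rand(n_tasks, N, C, H, W)  # task features stacked along a leading task axis

# Loop form, mirroring CrossStitchLayer.forward: out[i] = sum over source tasks j
loop = torch.stack([
    sum(units[i, j].view(1, -1, 1, 1) * feats[j] for j in range(n_tasks))
    for i in range(n_tasks)
])

# Same contraction in one call: out[i, n, c, h, w] = sum_j units[i, j, c] * feats[j, n, c, h, w]
fused = torch.einsum("ijc,jnchw->inchw", units, feats)

print(torch.allclose(loop, fused, atol=1e-6))  # True
```

Both forms materialise all task activations together (the original `torch.stack` does the same). The einsum version mainly removes the Python loop over task pairs rather than saving memory.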

In [2]:
class GenericCrossStitchBackbone(nn.Module):
    """
    Only the backbone/encoder logic, extend or encapsulate to add heads.
    For ResNet-type models with stages named "layerX". VGGs can be patched to have the same structure.
    Has factories, and figures out the number of channels in each stage automatically.
    :author: @xapharius
    """

    def __init__(self, backbones: dict, alpha=0.9, beta=0.1):
        """
        :param backbones: {task: resnet}
        :param alpha, beta: same/other task weight
        """
        super().__init__()
        self.backbones = nn.ModuleDict(backbones)
        self.tasks = list(backbones.keys())

        _bb = backbones[self.tasks[0]]
        self.stages = [f"layer{i}" for i in range(1, 6) if hasattr(_bb, f"layer{i}")]
        
        cs = dict()
        for stage in self.stages:
            channels = self.get_out_channels(getattr(_bb, stage))
            cs[stage] = CrossStitchLayer(self.tasks, channels, alpha, beta)
        self.cross_stitch = nn.ModuleDict(cs)

    def forward(self, x: torch.Tensor) -> dict:
        x = {task: x for task in self.tasks}

        # Handle resnet pre-stage layers
        if hasattr(self.backbones[self.tasks[0]], "conv1"):
            x = {
                task: m.maxpool(m.relu(m.bn1(m.conv1(x[task]))))
                for task, m in self.backbones.items()
            }

        for stage in self.stages:
            for task in self.tasks:
                x[task] = getattr(self.backbones[task], stage)(x[task])
            x = self.cross_stitch[stage](x)
        return x

    @staticmethod
    def get_out_channels(model: nn.Module):
        """Based on the last conv or batchnorm layer in the model"""
        m = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.BatchNorm2d))][-1]
        if isinstance(m, nn.Conv2d):
            return m.out_channels
        if isinstance(m, nn.BatchNorm2d):
            return m.num_features

    @staticmethod
    def patch_vgg(model: nn.Module) -> nn.Module:
        """Split vgg layers into stages similar to resnets"""
        maxpool_idx = [ix for ix, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)]
        return nn.Sequential(
            OrderedDict({
                f"layer{i+1}": nn.Sequential(*model.features[ix_start + 1 : ix_end + 1])
                for i, (ix_start, ix_end) in enumerate(
                    zip([-1] + maxpool_idx, maxpool_idx)
                )
            })
        )

    @classmethod
    def factory(
        cls, tasks: list, model="resnet18", pretrained=None, alpha=0.9, beta=0.1
    ):
        """For predefined resnets/vggs, same init for each task"""
        bb = getattr(torchvision.models, model)(weights=pretrained)
        if "vgg" in model:
            bb = cls.patch_vgg(bb)
        backbones = {task: copy.deepcopy(bb) for task in tasks}
        return cls(backbones, alpha, beta)


# GenericCrossStitchBackbone.factory(tasks=["t1", "t2"], model="resnet18", pretrained="IMAGENET1K_V1")
# GenericCrossStitchBackbone.factory(tasks=["t1", "t2"], model="vgg16_bn")
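
The slicing in `patch_vgg` pairs each `MaxPool2d` index with the previous one, so every stage ends with (and includes) a pooling layer. This lines up with the paper's suggestion to stitch after pooling. Below is a standalone sketch of the same logic, run on a small hypothetical VGG-style `features` block rather than torchvision's `vgg16_bn`:

```python
import torch
from torch import nn


def split_into_stages(features: nn.Sequential) -> dict:
    """Same slicing as patch_vgg: each stage ends with (and includes) a MaxPool2d."""
    pool_idx = [i for i, m in enumerate(features) if isinstance(m, nn.MaxPool2d)]
    # Pair each pool with the previous one (-1 for the first stage) to get the slice boundaries
    return {
        f"layer{k + 1}": features[start + 1 : end + 1]
        for k, (start, end) in enumerate(zip([-1] + pool_idx, pool_idx))
    }


# Toy VGG-style feature extractor (hypothetical, much smaller than torchvision's vgg16_bn)
toy_features = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
)
stages = split_into_stages(toy_features)

x = torch.rand(1, 3, 16, 16)
for name, stage in stages.items():
    x = stage(x)
    print(name, len(stage), tuple(x.shape))
# layer1 4 (1, 8, 8, 8)
# layer2 7 (1, 16, 4, 4)
```

torchvision's VGG `features` ends with a `MaxPool2d`, so no feature layers are lost. Only `avgpool` and `classifier` are discarded, since `patch_vgg` slices `model.features` alone.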

# Vandenhende Implementation
https://github.com/SimonVandenhende/Multi-Task-Learning-PyTorch/blob/master/models/cross_stitch.py
- Test against this implementation
- Requires a bit of patching, as it is tightly coupled to the library; changes are marked `# MY EDIT`

In [3]:
#
# Authors: Simon Vandenhende
# Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)

""" 
    Implementation of cross-stitch networks
    https://arxiv.org/abs/1604.03539
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelWiseMultiply(nn.Module):
    def __init__(self, num_channels):
        super(ChannelWiseMultiply, self).__init__()
        self.param = nn.Parameter(torch.FloatTensor(num_channels), requires_grad=True)

    def init_value(self, value):
        with torch.no_grad():
            self.param.data.fill_(value)

    def forward(self, x):
        return torch.mul(self.param.view(1,-1,1,1), x)


class CrossStitchUnit(nn.Module):
    def __init__(self, tasks, num_channels, alpha, beta):
        super(CrossStitchUnit, self).__init__()
        self.cross_stitch_unit = nn.ModuleDict({t: nn.ModuleDict({t: ChannelWiseMultiply(num_channels) for t in tasks}) for t in tasks})

        for t_i in tasks:
            for t_j in tasks:
                if t_i == t_j:
                    self.cross_stitch_unit[t_i][t_j].init_value(alpha)
                else:
                    self.cross_stitch_unit[t_i][t_j].init_value(beta)

    def forward(self, task_features):
        out = {}
        for t_i in task_features.keys():
            prod = torch.stack([self.cross_stitch_unit[t_i][t_j](task_features[t_j]) for t_j in task_features.keys()])
            out[t_i] = torch.sum(prod, dim=0)
        return out
           

class CrossStitchNetwork(nn.Module):
    """ 
        Implementation of cross-stitch networks.
        We insert a cross-stitch unit, to combine features from the task-specific backbones
        after every stage.
       
        Argument: 
            backbone: 
                nn.ModuleDict object which contains pre-trained task-specific backbones.
                {task: backbone for task in p.TASKS.NAMES}
        
            heads: 
                nn.ModuleDict object which contains the task-specific heads.
                {task: head for task in p.TASKS.NAMES}
        
            stages: 
                list of stages where we insert a cross-stitch unit between the task-specific backbones.
                Note: the backbone modules require a method 'forward_stage' to get feature representations
                at the respective stages.
        
            channels: 
                dict which contains the number of channels in every stage
        
            alpha, beta: 
                floats for initializing cross-stitch units (see paper)
        
    """
    def __init__(self, p, backbone: nn.ModuleDict, heads: nn.ModuleDict, 
                    stages: list, channels: dict, alpha: float, beta: float):
        super(CrossStitchNetwork, self).__init__()

        # Tasks, backbone and heads
        # self.tasks = p.TASKS.NAMES
        self.tasks = p # MY EDIT
        self.backbone = backbone
        self.heads = heads
        self.stages = stages

        # Cross-stitch units
        self.cross_stitch = nn.ModuleDict({stage: CrossStitchUnit(self.tasks, channels[stage], alpha, beta) for stage in stages})


    def forward(self, x):
        img_size = x.size()[-2:]
        x = {task: x for task in self.tasks} # Feed as input to every single-task network

        x = {task: m.maxpool(m.relu(m.bn1(m.conv1(x[task])))) for task, m in self.backbone.items()} # MY EDIT
        
        # Backbone
        for stage in self.stages:
    
            # Forward through next stage of task-specific network
            for task in self.tasks:
                # x[task] = self.backbone[task].forward_stage(x[task], stage)
                x[task] = getattr(self.backbone[task], stage)(x[task]) # MY EDIT
            
            # Cross-stitch the task-specific features
            x = self.cross_stitch[stage](x)
        return x
        # Task-specific heads
        # out = {task: self.heads[task](x[task]) for task in self.tasks}
        # out = {task: F.interpolate(out[task], img_size, mode='bilinear') for task in self.tasks} 

        # return out

# Test

In [4]:
X = torch.rand(1, 3, 224, 224)

In [5]:
# Reference Implementation

from torchvision.models import resnet18

p = ["task1", "task2"]
stages = ["layer1", "layer2", "layer3", "layer4"]
channels = {'layer1': 64, 'layer2': 128, 'layer3': 256, 'layer4': 512}

backbones = nn.ModuleDict({task: resnet18(weights="IMAGENET1K_V1") for task in p})
heads = nn.ModuleDict({task: nn.Sequential() for task in p})
model_ref = CrossStitchNetwork(p, backbones, heads, stages, channels, alpha=0.9, beta=0.1)

out_ref = model_ref(X)

In [6]:
# Minimal ResNet CrossStitch

backbones = nn.ModuleDict({task: resnet18(weights="IMAGENET1K_V1") for task in p})
channels = [64, 128, 256, 512]

model_my = CrossStitchResNet(backbones, channels, alpha=0.9, beta=0.1)
out_my = model_my(X)

print("Matching backbone outputs:")
for task in out_ref.keys():
    print(task, torch.equal(out_ref[task], out_my[task]))

Matching backbone outputs:
task1 True
task2 True


In [7]:
# Generic CrossStitch

model_my = GenericCrossStitchBackbone.factory(tasks=["task1", "task2"], pretrained="IMAGENET1K_V1", alpha=0.9, beta=0.1)
out_my = model_my(X)

print("Matching backbone outputs:")
for task in out_ref.keys():
    print(task, torch.equal(out_ref[task], out_my[task]))

Matching backbone outputs:
task1 True
task2 True
