# MultiMAE: Multi-modal Multi-task Masked Autoencoders



**Important:** This notebook requires a GPU for installing certain dependencies. Let's see which one we got here:

In [None]:
# Show the GPU (if any) that is attached to this runtime.
!nvidia-smi

## 1 Install dependencies

These cells download the MultiMAE code base, as well as the DPT and Mask2Former repositories that are used to pseudo-label RGB images.

First, we need to downgrade PyTorch to version 1.10.0 due to compatibility issues with Detectron2. Make sure to restart the runtime once after reinstalling PyTorch, and once after installing all other packages.

### 1.1 Downgrade PyTorch

In [1]:
%pip uninstall -y torch torchvision torchaudio torchtext
%pip install torch==1.10.0+cu113 torchvision==0.11.0+cu113 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html

Found existing installation: torch 1.10.0+cu113
Uninstalling torch-1.10.0+cu113:
  Successfully uninstalled torch-1.10.0+cu113
Found existing installation: torchvision 0.11.0+cu113
Uninstalling torchvision-0.11.0+cu113:
  Successfully uninstalled torchvision-0.11.0+cu113
Found existing installation: torchaudio 0.10.0+rocm4.1
Uninstalling torchaudio-0.10.0+rocm4.1:
  Successfully uninstalled torchaudio-0.10.0+rocm4.1
[0mNote: you may need to restart the kernel to use updated packages.
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.10.0+cu113
  Downloading https://download.pytorch.org/whl/cu113/torch-1.10.0%2Bcu113-cp38-cp38-linux_x86_64.whl (1821.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 GB[0m [31m?[0m eta [36m0:00:00[0m[36m0:00:01[0mm00:02[0m
[?25hCollecting torchvision==0.11.0+cu113
  Downloading https://download.pytorch.org/whl/cu113/torchvision-0.11.0%2Bcu113-cp38-cp38-linux_x86_64.whl (21.8 MB)


### 1.2 Install dependencies

**Important**: Before running the following cells, please restart the runtime using the menu bar entries `Runtime > Restart runtime`

In [1]:
import torch
# Major.minor of the installed torch (e.g. "1.10") and the CUDA version reported by torch.
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.version.cuda
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
## Install the detectron2 build that matches the PyTorch version printed above.
## See https://detectron2.readthedocs.io/tutorials/install.html for instructions.
# NOTE(review): the templated command below stays disabled. CUDA_VERSION prints as "11.3" while the
# wheel index path uses the "cu113" form, so the values above are informational only and the URL
# that is actually used is hard-coded.
# %pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/$CUDA_VERSION/torch$TORCH_VERSION/index.html

%pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html

## If you have trouble with the prebuilt Detectron2 wheel, consider installing it from source instead. This takes a few minutes.
#!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

torch:  1.10 ; cuda:  11.3
Looking in links: https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
Collecting detectron2
  Downloading https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/detectron2-0.6%2Bcu113-cp38-cp38-linux_x86_64.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m0m
Installing collected packages: detectron2
Successfully installed detectron2-0.6+cu113
Note: you may need to restart the kernel to use updated packages.


In [None]:
# clone and install Mask2Former
# !git clone https://github.com/facebookresearch/Mask2Former.git

%cd ../../Mask2Former
!pwd

%pip install -U opencv-python
%pip install git+https://github.com/cocodataset/panopticapi.git

%cd mask2former/modeling/pixel_decoder/ops
!python setup.py build install
%cd ../..

In [None]:
# Clone Dense Prediction Transformer repository
# !git clone https://github.com/isl-org/DPT

In [None]:
# Clone MultiMAE repository
# !git clone https://github.com/EPFL-VILAB/MultiMAE
# Install the pinned timm and einops releases.
%pip install timm==0.4.12
%pip install einops==0.3.2

## 2 Imports and model setup

**Important**: Before running the following cells, please restart the runtime **again** using the menu bar entries `Runtime > Restart runtime`

### 2.1 Imports

In [None]:
import sys
# Make the three repository checkouts importable.
# NOTE(review): these paths are relative to the kernel's working directory, whereas the install
# cell enters the Mask2Former checkout via `../../Mask2Former` — confirm where the repositories
# actually live relative to this notebook.
sys.path.append("./Mask2Former")
sys.path.append("./DPT")
sys.path.append("./MultiMAE")

# To suppress DPT and Mask2Former warnings
import warnings
warnings.filterwarnings("ignore")

import os
from tqdm import tqdm
import random
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision import datasets, transforms
from einops import rearrange

from PIL import Image
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# Mask2Former and detectron2 dependencies for semantic segmentation pseudo labeling
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2.data import MetadataCatalog
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic")
from mask2former import add_maskformer2_config

# DPT dependencies for depth pseudo labeling
from dpt.models import DPTDepthModel

from multimae.models.input_adapters import PatchedInputAdapter, SemSegInputAdapter
from multimae.models.output_adapters import SpatialOutputAdapter
from multimae.models.multimae import pretrain_multimae_base
from multimae.utils.data_constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Disable gradient tracking and select the GPU when one is available, otherwise the CPU.
torch.set_grad_enabled(False)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### 2.2 Pseudo labeling model setup

In [None]:
# Initialize Omnidata depth model

!wget https://datasets.epfl.ch/vilab/iccv21/weights/omnidata_rgb2depth_dpt_hybrid.pth -P pretrained_models

omnidata_ckpt = torch.load('./pretrained_models/omnidata_rgb2depth_dpt_hybrid.pth', map_location='cpu')
depth_model = DPTDepthModel()
depth_model.load_state_dict(omnidata_ckpt)
depth_model = depth_model.to(device).eval()

def predict_depth(img):
    """Run `depth_model` on a single image tensor.

    A leading dimension is added to `img`, the values are mapped with
    (x - 0.5) / 0.5, the result is moved to `device`, and the model output
    is returned unchanged.
    """
    with_leading_dim = img.unsqueeze(0)
    normalized = (with_leading_dim - 0.5) / 0.5
    return depth_model(normalized.to(device))

In [None]:
# COCO Mask2Former

import mask2former

# Locate the config inside the Mask2Former checkout that was actually imported (the parent of the
# `mask2former` package directory) instead of relying on a path relative to the current working
# directory, which need not match the location added to sys.path.
MASK2FORMER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(mask2former.__file__)))
MASK2FORMER_CONFIG = os.path.join(
    MASK2FORMER_ROOT,
    "configs", "coco", "panoptic-segmentation", "swin",
    "maskformer2_swin_small_bs16_50ep.yaml",
)

cfg = get_cfg()
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
cfg.merge_from_file(MASK2FORMER_CONFIG)
cfg.MODEL.WEIGHTS = 'https://dl.fbaipublicfiles.com/maskformer/mask2former/coco/panoptic/maskformer2_swin_small_bs16_50ep/model_final_a407fd.pkl'
# Enable the semantic, instance and panoptic test-time outputs.
cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = True
cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = True
semseg_model = DefaultPredictor(cfg)

def predict_semseg(img):
    """Return the argmax over dim 0 of the 'sem_seg' output of `semseg_model` for `img`.

    `img` is permuted with (1, 2, 0), converted to a NumPy array and scaled by 255
    before being passed to the predictor.

    NOTE(review): assumes `img` is a CPU tensor laid out (C, H, W) with values in
    [0, 1] — TODO confirm. detectron2 documents DefaultPredictor's input as BGR;
    verify the channel order of `img` against that.
    """
    hwc_image = img.permute(1, 2, 0).numpy()
    prediction = semseg_model(255 * hwc_image)
    return prediction['sem_seg'].argmax(0)

def plot_semseg(img, semseg, ax):
    """Draw `semseg` on top of `img` with a detectron2 Visualizer and show it on `ax`.

    `img` is permuted with (1, 2, 0) before being handed to the Visualizer, and
    `semseg` is moved to the CPU before drawing. Nothing is returned.
    """
    visualizer = Visualizer(
        img.permute(1, 2, 0),
        coco_metadata,
        scale=1.2,
        instance_mode=ColorMode.IMAGE_BW,
    )
    overlay = visualizer.draw_sem_seg(semseg.cpu()).get_image()
    ax.imshow(overlay)

### 2.3 MultiMAE model setup

In [None]:
# Adapter factories per modality. Each entry holds a partially-applied input adapter and output
# adapter; the remaining arguments are supplied when the adapters are instantiated.
DOMAIN_CONF = {
    'rgb': dict(
        input_adapter=partial(PatchedInputAdapter, num_channels=3, stride_level=1),
        output_adapter=partial(SpatialOutputAdapter, num_channels=3, stride_level=1),
    ),
    'depth': dict(
        input_adapter=partial(PatchedInputAdapter, num_channels=1, stride_level=1),
        output_adapter=partial(SpatialOutputAdapter, num_channels=1, stride_level=1),
    ),
    'semseg': dict(
        input_adapter=partial(
            SemSegInputAdapter,
            num_classes=133,
            dim_class_emb=64,
            interpolate_class_emb=False,
            stride_level=4,
        ),
        output_adapter=partial(SpatialOutputAdapter, num_channels=133, stride_level=4),
    ),
}
# Modalities passed to every output adapter as its context tasks.
DOMAINS = ['rgb', 'depth', 'semseg']

# Instantiate one input adapter per configured modality.
input_adapters = {
    name: DOMAIN_CONF[name]['input_adapter'](patch_size_full=16)
    for name in DOMAIN_CONF
}

# Instantiate one output adapter per configured modality; each is told its own task name and
# receives the full DOMAINS list as context tasks.
output_adapters = {
    name: DOMAIN_CONF[name]['output_adapter'](
        patch_size_full=16,
        dim_tokens=256,
        use_task_queries=True,
        depth=2,
        context_tasks=DOMAINS,
        task=name,
    )
    for name in DOMAIN_CONF
}

# Build the MultiMAE model from the adapters created above.
multimae = pretrain_multimae_base(
    input_adapters=input_adapters,
    output_adapters=output_adapters,
)

# Download (or reuse the cached copy of) the pretrained checkpoint and load its 'model' weights.
CKPT_URL = 'https://github.com/EPFL-VILAB/MultiMAE/releases/download/pretrained-weights/multimae-b_98_rgb+-depth-semseg_1600e_multivit-afff3f8c.pth'
ckpt = torch.hub.load_state_dict_from_url(CKPT_URL, map_location='cpu')
# strict=False tolerates key mismatches between checkpoint and model. Report them so that a wrong
# or partial checkpoint does not go unnoticed.
load_result = multimae.load_state_dict(ckpt['model'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Checkpoint key mismatch:",
        len(load_result.missing_keys), "missing;",
        len(load_result.unexpected_keys), "unexpected",
    )
# Move to the selected device and switch to inference mode.
multimae = multimae.to(device).eval()
