## Masked Autoencoders: Visualization Demo

## Prepare


In [1]:
import sys
import os
import requests

import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image

# check whether run in Colab
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install timm==0.4.5  # 0.3.2 does not work in Colab
    !git clone https://github.com/sysu19351176/Change_Detection_MAE.git
    sys.path.append('./mae')
else:
    sys.path.append('..')
import models_mae

Running in Colab.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm==0.4.5
  Downloading timm-0.4.5-py3-none-any.whl (287 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m287.4/287.4 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.4.5
Cloning into 'Change_Detection_MAE'...
remote: Enumerating objects: 44, done.[K
remote: Counting objects: 100% (2/2), done.[K
remote: Compressing objects: 100% (2/2), done.[K
remote: Total 44 (delta 0), reused 0 (delta 0), pack-reused 42[K
Unpacking objects: 100% (44/44), 867.96 KiB | 7.82 MiB/s, done.


ModuleNotFoundError: ignored

### Define utils

In [None]:
# define the utils

# Per-channel mean and standard deviation (three values, one per channel of an
# [H, W, 3] image). The image-loading cell subtracts the mean and divides by
# the std to normalize its inputs; show_image applies the inverse before
# plotting.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

def show_image(image, title=''):
    """Draw a normalized image on the current matplotlib axes.

    Args:
        image: [H, W, 3] array-like (torch tensor or NumPy array) whose values
            were normalized with ``imagenet_mean`` / ``imagenet_std``.
        title: Text shown above the image.

    Raises:
        AssertionError: If the third dimension of ``image`` is not 3.
    """
    # torch.clip only accepts tensors; as_tensor lets NumPy arrays through as
    # well and returns tensors unchanged.
    image = torch.as_tensor(image)
    # image is [H, W, 3]
    assert image.shape[2] == 3
    std = torch.as_tensor(imagenet_std)
    mean = torch.as_tensor(imagenet_mean)
    # Undo the normalization, rescale to 0..255, clamp, and truncate to ints.
    plt.imshow(torch.clip((image * std + mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Instantiate an architecture from ``models_mae`` and load its weights.

    Args:
        chkpt_dir: Path handed to ``torch.load``; despite the name it is the
            checkpoint file itself, not a directory.
        arch: Name of the constructor to look up on ``models_mae``.

    Returns:
        The constructed model after ``load_state_dict`` has been applied.
    """
    # Look the constructor up by name, then build the network.
    constructor = getattr(models_mae, arch)
    net = constructor()

    # Weights are read onto the CPU; the state dict lives under the 'model' key.
    state = torch.load(chkpt_dir, map_location='cpu')

    # strict=False tolerates key mismatches; the printed report lists them.
    load_report = net.load_state_dict(state['model'], strict=False)
    print(load_report)
    return net

def _to_nchw_batch(image):
    """Turn one [H, W, C] array-like into a [1, C, H, W] tensor."""
    batch = torch.as_tensor(image).unsqueeze(dim=0)
    return torch.einsum('nhwc->nchw', batch)


def run_one_image(imgt1, imgt2, img_label, model, reference='t1'):
    """Run the model on one image pair and plot the masked reconstruction.

    Four panels are drawn: the reference image, the reference image with the
    masked patches zeroed, the raw reconstruction, and the reconstruction
    pasted together with the visible patches.

    Args:
        imgt1: [H, W, C] array-like, first model input.
        imgt2: [H, W, C] array-like, second model input.
        img_label: [H, W, C] array-like, third model input.
        model: Callable as ``model(t1, t2, label)`` returning
            ``(loss, prediction, mask)``; must also provide ``unpatchify`` and
            ``patch_embed.patch_size``.
        reference: ``'t1'`` or ``'t2'`` - which input is shown as "original"
            and combined with the mask.
            NOTE(review): the original code read an undefined variable here,
            so which input the model actually reconstructs could not be
            determined from this notebook - confirm the default.

    Raises:
        ValueError: If ``reference`` is neither ``'t1'`` nor ``'t2'``.
    """
    if reference not in ('t1', 't2'):
        raise ValueError("reference must be 't1' or 't2', got %r" % (reference,))

    # make it a batch-like
    x_t1 = _to_nchw_batch(imgt1)
    x_t2 = _to_nchw_batch(imgt2)
    x_label = _to_nchw_batch(img_label)

    # run MAE; this is inference only, so no autograd graph is needed
    with torch.no_grad():
        _, y, mask = model(x_t1.float(), x_t2.float(), x_label.float())
        y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # visualize the mask: repeat each per-patch flag over the patch's pixels
    patch_size = model.patch_embed.patch_size[0]
    mask = mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, patch_size**2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # image used for the "original" panel, back in channels-last layout
    x = x_t1 if reference == 't1' else x_t2
    x = torch.einsum('nchw->nhwc', x)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    plt.show()
  

### Load an image

In [None]:
# load the image pair and its label
# NOTE(review): these look like temporary preview links; if a download fails,
# replace them with stable URLs or local files.
img_label_url = 'https://uc78c92a63fd851f38bd835aabf1.previews.dropboxusercontent.com/p/thumb/AB1U5YAbVbuOfTvFY969B1xJ4Fg9wQbwPDW9jUjwtZGRNqY8O-dHYDqoun-8nFZPJWIlCTJ06IfP7lFm7a0qkReu7okYxkGafK-Q-T-OcGuCqtrhhVZKLJzFXdRxlaKiG6DeiT-VFKuDIrnHXQFGf9JV86Eu1zr_R1AN7z6GZmstckAUAQme1duCGROdkhNPkW4yenjHQNXb6L8Ng3nqZ_CAvWNziaiEGrpDpJhnjqNNAyK44OsjCmdgN0YTeWDazU-IQKyhsIY72mtEJfRsIu-vwnrU2C9tWVtQG28pxVkgssqQNBpoE_JoWKvEvO-YqIRyP34rhFdjNyx3e9peA1kU3u1dh0nNG-jiBJKZek-Pf1bi0SCV9rxyCK0sza5j4AA/p.png' # from LEVIR-CD256
img_t1_url = 'https://uc65bd97fb2efd30aa82a917f0da.previews.dropboxusercontent.com/p/thumb/AB3JhQPRZZo1G8N8f4J4Ff0jLEgwUCdiuCZcgkMQax_QAOf_U2GZtsX1sDP827daEwfopipFifOSddnXwM0SrqcupcZ_gE5FCCTVCO93XgnglkhI5hz_OPRVISsDuQ_-S6BzkVk_XH_L6IUjGJ1UEHmDzs9INRSN84g_5BG6T2V5PN-HgavsNXRVxmN94iJIa6AssImCwO8kOCYTRyLPLZZ8UdM8-9Ux5f7Z_rYnlLJ1uoSvMVsMCx2KTYwealnhDgR5Wb-_dkAMeHc9ply_NGBTXrubY1-WJGTppsOwxB2Ua6zSjk7Mj5mCXdtkk5Wz1R6Mn1cp5qW_3fPVGh_qFmCzMBBOrvXgFTXiWp9_nVgda-iIfyxyYbmqwqeeTVGSA3TBMO6vof_E2qIdxtJ1anzU/p.png?size=512x512&size_mode=1'
img_t2_url ='https://uc517512112389a6ff959acb5f0e.previews.dropboxusercontent.com/p/thumb/AB1UclKlzzcY0B93FUMREDi0yWBqo_H7ZvliJuVAVcxz7kNVCHsJuvpUYiFVuDzFcrz_zCrG-JfWjqZWB_ROeo7SYu0uDI9AtES0fBWXsh--KbTOyRibpPLuJxB1jDGDdtOU0MNjtAj0Dzb6oTsOqYmKuAdmqeHcnD_4VA27_SxsH7ezN8CjqShCmshTGer5VYkfDtPNIEd_vAFRgXw77ylAJs6u2oa9H8CpfJaFTPysUYZaXuk3thKS2t7VRsX7vM2y98V9mHwrv0VsOVzYT97zXD3dhxXBun2YtuR2n1vHwXb0udvuZYaJCqyVx_gVLI4yZdX_x0gJ3wl6DWHs5pN4yi3t-K2TRv8ALnjqqJc6nQWpgKJ07deK61sK4ip-XaA/p.png'


def load_rgb_image(url, size=(256, 256)):
    """Download ``url`` and return an (H, W, 3) float array scaled to [0, 1].

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.get(url, stream=True, timeout=30)
    # Fail here with a clear HTTP error instead of letting PIL choke on an
    # error page.
    response.raise_for_status()
    # PNGs may be RGBA or palette-based; force three channels.
    image = Image.open(response.raw).convert('RGB').resize(size)
    return np.array(image) / 255.


img_t1 = load_rgb_image(img_t1_url)
img_t2 = load_rgb_image(img_t2_url)
# NOTE(review): the label is only scaled to [0, 1]; whether the model expects
# it normalized like the two images is not visible in this notebook - confirm.
img_label = load_rgb_image(img_label_url)

assert img_t1.shape == (256, 256, 3) and img_t2.shape == (256, 256, 3)
assert img_label.shape == (256, 256, 3)

# normalize by ImageNet mean and std
img_t1 = (img_t1 - imagenet_mean) / imagenet_std
img_t2 = (img_t2 - imagenet_mean) / imagenet_std

plt.rcParams['figure.figsize'] = [5, 5]

# show_image clips with torch, so hand it tensors rather than NumPy arrays
plt.subplot(1, 2, 1)
show_image(torch.tensor(img_t1), "T1")

plt.subplot(1, 2, 2)
show_image(torch.tensor(img_t2), "T2")



### Load a pre-trained MAE model

In [None]:
# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)

# Download the checkpoint unless the file already exists
# (wget -nc / --no-clobber skips the download when the file is present).
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth

# Despite its name, chkpt_dir is the path of the checkpoint file itself.
chkpt_dir = 'mae_visualize_vit_large.pth'
# prepare_model loads the weights with strict=False and prints the load report;
# check that report for missing or unexpected keys.
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')


### Run MAE on the image

In [None]:
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
print('MAE with pixel reconstruction:')
# run_one_image takes the two input images, the label image and the model.
# img_label has to be the decoded label image (an [H, W, C] array), not a URL string.
run_one_image(img_t1, img_t2, img_label, model_mae)

### Load another pre-trained MAE model

In [None]:
# This is an MAE model trained with an extra GAN loss for more realistic generation (ViT-Large, training mask ratio=0.75)

# download checkpoint if not exist
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large_ganloss.pth

chkpt_dir = 'mae_visualize_vit_large_ganloss.pth'
model_mae_gan = prepare_model('mae_visualize_vit_large_ganloss.pth', 'mae_vit_large_patch16')
print('Model loaded.')

### Run MAE on the image

In [None]:
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
print('MAE with extra GAN loss:')
# run_one_image takes the two input images, the label image and the model.
# img_label has to be the decoded label image (an [H, W, C] array), not a URL string.
run_one_image(img_t1, img_t2, img_label, model_mae_gan)