In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import sys

import matplotlib.pyplot as plt

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import timm

assert timm.__version__ == "0.3.2"  # version check
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

import models_mae

from engine_pretrain import train_one_epoch
from main_train import main

# Number of CIFAR-10 test images loaded in the single batch further down; it is
# also the exclusive upper bound for the random index used to pick one to visualize.
test_size: int = 100

In [None]:
# Launch one training process per visible GPU. torch.multiprocessing.spawn calls
# main(process_index, world_size) in each of the `world_size` child processes.
# NOTE(review): presumably this runs the full training and writes the
# ./output_dir/checkpoint-*.pth files the next cell loads — confirm; skip this
# cell on re-runs if the checkpoint already exists.
world_size = torch.cuda.device_count()
print('world_size = %d' % world_size)

# Rendezvous settings read by the spawned processes (single-node, localhost).
os.environ.update(
    MASTER_ADDR='localhost',
    MASTER_PORT='12355',
    WORLD_SIZE=str(world_size),
)

torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size)

In [None]:
# Rebuild the MAE model and its AdamW optimizer, then restore both from the
# epoch-199 pre-training checkpoint for the visualization cells below.
# NOTE(review): this cell runs in the notebook process, outside the spawned
# workers, so misc.get_world_size() presumably returns 1 here rather than the
# training world size — confirm. The `lr` computed below is replaced anyway by
# the learning rate stored in the checkpoint when optimizer.load_state_dict runs.
eff_batch_size = 200 * 1 * misc.get_world_size()
lr = 1e-4 * eff_batch_size / 256

model = models_mae.__dict__['mae_vit'](norm_pix_loss=False)
model.to('cuda')

param_groups = optim_factory.add_weight_decay(model, 0.3)
optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.95))

# Load onto the CPU first, so tensors are not restored onto whatever device
# they were saved from. load_state_dict then copies the weights into the
# model's CUDA parameters, and the optimizer casts its state to the params'
# device.
checkpoint = torch.load('./output_dir/checkpoint-199.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

# Only inference follows in this notebook, so switch training-mode layers off.
model.eval()

In [None]:
# Evaluation preprocessing: resize to 32x32 (interpolation=3, presumably PIL's
# BICUBIC), convert to a tensor, then normalize with the same ImageNet
# mean/std that show_image uses to undo the normalization for display.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Load a single batch of `test_size` images from the CIFAR-10 test split.
# NOTE(review): ImageFolder lists samples grouped by class and shuffle=False,
# so this first batch presumably contains only the first class(es) — confirm
# whether a class-balanced sample was intended.
dataset_test = datasets.ImageFolder(
    root=os.path.join('./data/cifar10', 'test'),
    transform=transform,
)
data_loader_test = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=test_size,
    shuffle=False,
)
images, labels = next(iter(data_loader_test))

In [None]:
# Per-channel ImageNet statistics. They match the Normalize transform applied
# to the test images, and show_image uses them to invert that normalization.
imagenet_mean, imagenet_std = (
    np.array([0.485, 0.456, 0.406]),
    np.array([0.229, 0.224, 0.225]),
)

def show_image(image, title='', sat=False):
    """Draw one channel-last image on the current matplotlib axes.

    Args:
        image: tensor of shape [H, W, 3] (asserted on the last dimension).
        title: axes title, drawn at font size 16.
        sat: when False, undo the ImageNet normalization (image * std + mean),
            scale to 0..255, clip and cast to int before display; when True,
            display ``image`` cast to int with no de-normalization.
    """
    assert image.shape[2] == 3
    if sat:
        pixels = image.int()
    else:
        denormalized = image * imagenet_std + imagenet_mean
        pixels = torch.clip(denormalized * 255, 0, 255).int()
    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')

def run_one_image(x, model, sat=False, mask_ratio=0.75):
    """Mask one image with an MAE model and plot four panels side by side.

    The panels are: original, masked input, reconstruction, and the
    reconstruction with the visible (unmasked) patches pasted back in.

    Args:
        x: a single image tensor, assumed channel-first (C, H, W) as taken from
            the test batch; a batch dimension is added internally.
        model: MAE model whose call ``model(x, mask_ratio=...)`` returns
            ``(loss, pred, mask)`` and which exposes ``unpatchify`` and
            ``patch_embed.patch_size``.
        sat: forwarded to ``show_image``. When False, panels are de-normalized
            with the ImageNet statistics before display.
        mask_ratio: fraction of patches to mask (default 0.75, the previously
            hard-coded value).
    """
    # Follow the model's device instead of hard-coding 'cuda'.
    device = next(model.parameters()).device

    # Make it a batch: (C, H, W) -> (1, C, H, W).
    x = x.unsqueeze(dim=0)

    # Pure inference: no_grad avoids building an autograd graph for the
    # forward pass and the unpatchify calls.
    with torch.no_grad():
        _loss, y, mask = model(x.to(device), mask_ratio=mask_ratio)

        y = model.unpatchify(y)
        y = torch.einsum('nchw->nhwc', y).cpu()

        # Expand the per-patch mask to per-pixel values: (N, L) -> (N, L, p*p*3),
        # then back to image layout. 1 is removing, 0 is keeping.
        patch = model.patch_embed.patch_size[0]
        mask = mask.unsqueeze(-1).repeat(1, 1, patch ** 2 * 3)
        mask = model.unpatchify(mask)
        mask = torch.einsum('nchw->nhwc', mask).cpu()

    x = torch.einsum('nchw->nhwc', x).cpu()

    # Masked image, and MAE reconstruction pasted with the visible patches.
    im_masked = x * (1 - mask)
    im_paste = x * (1 - mask) + y * mask

    panels = [
        (x[0], "original"),
        (im_masked[0], "masked"),
        (y[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]

    # Size this figure explicitly rather than mutating the global plt.rcParams.
    plt.figure(figsize=(24, 24))
    for position, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        show_image(image, title, sat)

    plt.show()

In [None]:
# Pick one image from the loaded test batch and visualize it twice: once
# de-normalized with the ImageNet statistics (sat=False) and once raw (sat=True).
# NOTE(review): the RNG is unseeded, so a different image is chosen on every run.
idx = np.random.randint(0, test_size)
print(f'image id: {idx} - label: {labels[idx].item()}')

image_test = images[idx]
print(image_test.shape)

run_one_image(image_test, model, sat=False)
run_one_image(image_test, model, sat=True)