In [1]:
from mae.decoder import  *
from mae.encoder import *
# Explicit third-party imports stay after the star imports so these names
# always refer to the real packages, whatever the star imports expose.
import torch
import torchvision

In [2]:
# Settings that the encoder and the decoder receive with identical values.
# Keeping them in one place guarantees the two configs cannot drift apart.
shared_vit_kwargs = dict(
    image_size=224,
    patch_size=16,
    num_channels=3,
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    layer_norm_eps=1e-8,
    attention_dropout=0.0,
)

# Encoder: shared settings plus the masking options.
# NOTE(review): random masking is switched off here, so mask_ratio is
# presumably unused for this run -- confirm against mae.encoder.
encoder_config = EncoderConfig(
    **shared_vit_kwargs,
    num_image_tokens=None,
    do_random_mask=False,
    mask_ratio=0.75,
)

# Decoder: shared settings plus its input projection width and loss switch.
# in_proj_dim is given the same value as hidden_size (768), as in the
# original cell; presumably it must match the encoder's output width --
# TODO confirm. Loss calculation is disabled in this config.
decoder_config = DecoderConfig(
    **shared_vit_kwargs,
    in_proj_dim=shared_vit_kwargs["hidden_size"],
    do_loss_calculation=False,
)

In [3]:
# Both models are built straight from their configs; no checkpoint is loaded
# in the cells shown, so the weights come from random initialisation.
# Seed PyTorch's global RNG first so that a Restart & Run All reproduces the
# same weights and therefore the same printed outputs further down.
# (Assumes the models initialise their parameters through torch's RNG.)
torch.manual_seed(0)
encoder = EncoderModel(encoder_config)
decoder = DecoderModel(decoder_config)

In [4]:
# Build the model input from the sample photo in a single expression:
#   1. read "dog.jpg" from the working directory,
#   2. resize it to 224x224 (the image_size used by both configs),
#   3. add a leading batch dimension with unsqueeze(0),
#   4. divide by 255.0 to rescale the pixel values
#      (presumably from 8-bit [0, 255] to [0, 1] -- see read_image docs).
img = (
    torchvision.transforms.functional.resize(
        torchvision.io.read_image("dog.jpg"),
        (224, 224),
    ).unsqueeze(0)
    / 255.0
)

In [5]:
# Duplicate the single image into a batch of two.
# Slicing to the first sample before repeating keeps this cell idempotent:
# on the first run img has batch size 1, so the result is unchanged, and on
# a re-run the batch stays at 2 instead of doubling to 4, 8, ...
img = img[:1].repeat(2, 1, 1, 1)
img.shape

torch.Size([2, 3, 224, 224])

In [6]:
# Forward pass through the encoder. It returns three values: the encoded
# tokens plus `mask` and `ids_restore`, which are handed to the decoder
# unchanged in the next cell.
encoder_op, mask, ids_restore = encoder(img)

# Show the size of the encoded tokens as the cell output.
encoder_op.size()

torch.Size([2, 196, 768])

In [7]:
# Forward pass through the decoder. The first argument is the encoder's
# three outputs packed as one tuple; the second is the image batch.
# The call returns the decoder output and a loss value.
op, loss = decoder(
    (encoder_op, mask, ids_restore),
    img,
)

# Show the output size together with the returned loss.
op.size(), loss

(torch.Size([2, 3, 224, 224]), None)

In [8]:
# Display the raw decoder output tensor (PyTorch abbreviates the repr with "...").
op

tensor([[[[ 0.0803,  0.1212, -1.0792,  ...,  0.0895, -0.1893,  0.2833],
          [-0.4466,  0.7161,  0.4436,  ...,  0.8274,  0.1024, -0.2308],
          [-0.3762, -0.3539, -0.4450,  ..., -0.9515, -0.6690, -0.2326],
          ...,
          [ 0.2359, -1.0694, -0.0058,  ...,  1.2715, -0.6000,  0.1574],
          [-1.0460,  0.7951, -0.9496,  ...,  0.9215, -0.4596,  0.0249],
          [-0.9521,  1.2195, -0.5152,  ..., -0.0985, -0.0838, -0.1849]],

         [[ 0.4451, -0.6161, -0.2058,  ..., -0.2029,  0.1858,  0.1745],
          [-0.1391, -0.2809, -0.3045,  ...,  0.6473,  0.0908, -0.1560],
          [ 0.4616, -0.2726, -0.4020,  ..., -0.2546,  0.3106, -0.2373],
          ...,
          [ 0.4425, -0.3242,  0.9490,  ...,  0.2141,  0.4529,  0.8602],
          [ 0.3228,  0.2255,  1.1811,  ..., -0.3888,  1.2466,  0.8956],
          [-0.2798,  0.3041, -0.0846,  ...,  0.5497, -0.1276, -0.4146]],

         [[-0.2066,  0.1005, -0.1012,  ...,  0.1073,  0.0664, -1.0251],
          [ 0.6611, -0.6934, -