In [1]:
import torch
import torchvision

from mae.decoder import *
from mae.encoder import *

In [4]:
# Hyperparameters shared by the encoder and decoder configs. The decoder is fed
# the encoder's output tokens directly (see the decoder call further down), so
# image geometry and the encoder width must agree between the two configs.
IMAGE_SIZE = 224
PATCH_SIZE = 16
NUM_CHANNELS = 3
ENCODER_HIDDEN_SIZE = 768
MASK_RATIO = 0.75

encoder_config = EncoderConfig(
    image_size=IMAGE_SIZE,
    hidden_size=ENCODER_HIDDEN_SIZE,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    num_channels=NUM_CHANNELS,
    patch_size=PATCH_SIZE,
    layer_norm_eps=1e-6,
    attention_dropout=0.0,
    num_image_tokens=None,
    do_random_mask=True,
    mask_ratio=MASK_RATIO,
)

decoder_config = DecoderConfig(
    image_size=IMAGE_SIZE,
    # Width of the incoming encoder tokens; presumably must equal the encoder's
    # hidden_size since encoder_op is passed straight in — verify in mae.decoder.
    in_proj_dim=ENCODER_HIDDEN_SIZE,
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    num_channels=NUM_CHANNELS,
    patch_size=PATCH_SIZE,
    layer_norm_eps=1e-6,
    attention_dropout=0.0,
    do_loss_calculation=True,
)

In [5]:
# No pretrained weights are loaded here, so parameters come from the models'
# (presumably random) initialisation, and the encoder is configured with
# do_random_mask=True. Seeding once before construction makes the shapes,
# masks, loss and reconstruction below reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

encoder = EncoderModel(encoder_config)
decoder = DecoderModel(decoder_config)

In [6]:
# Force 3-channel RGB decoding so the tensor always matches num_channels=3,
# even if the file is grayscale (1 channel) or has an alpha channel (4).
img_raw = torchvision.io.read_image("dog.jpg", mode=torchvision.io.ImageReadMode.RGB)
# Resize to the model input resolution (must match image_size in the configs).
# antialias is passed explicitly because torchvision's default for tensor
# inputs changed between releases.
img_resized = torchvision.transforms.functional.resize(img_raw, (224, 224), antialias=True)
# Add a leading batch dimension and scale uint8 [0, 255] to float [0, 1].
img = img_resized.unsqueeze(0) / 255.0

In [7]:
# Build a small batch of identical images to exercise batched inputs.
# Taking only the first image before repeating keeps this cell idempotent:
# re-running it no longer grows the batch (2 -> 4 -> 8 ...).
BATCH_SIZE = 2
img = img[:1].repeat(BATCH_SIZE, 1, 1, 1)
img.shape

torch.Size([2, 3, 224, 224])

In [8]:
encoder_op, mask, ids_restore = encoder(img)
# Sanity check on the visible tokens: 224 / 16 = 14 patches per side (196 total).
# With mask_ratio=0.75, int(196 * 0.25) = 49 tokens of width 768 are expected,
# matching the captured output. NOTE(review): assumes the encoder keeps
# int(num_patches * (1 - mask_ratio)) tokens and adds no CLS token — confirm
# in mae.encoder if the config changes.
n_patches = (224 // 16) ** 2
n_visible = int(n_patches * (1 - 0.75))
assert encoder_op.shape == (img.shape[0], n_visible, 768), encoder_op.shape
encoder_op.shape

torch.Size([2, 49, 768])

In [9]:
# Decode from the encoder outputs; img is passed as the second argument and,
# with do_loss_calculation=True, the decoder also returns a loss (presumably
# an MSE reconstruction loss against img).
op, loss = decoder((encoder_op, mask, ids_restore), img)
# Invariants: the reconstruction has the input image's shape, and the loss is a scalar.
assert op.shape == img.shape, (op.shape, img.shape)
assert loss.dim() == 0, loss.shape
op.shape, loss.shape, loss

(torch.Size([2, 3, 224, 224]),
 torch.Size([]),
 tensor(0.8701, grad_fn=<MseLossBackward0>))

In [10]:
# Display the decoder's reconstruction tensor.
# NOTE(review): this renders the full (2, 3, 224, 224) tensor; a slice such as
# op[0, 0, :4, :4] or summary statistics would keep the notebook output compact.
op

tensor([[[[ 0.7226,  0.3313, -0.4810,  ...,  0.0546, -0.6842, -0.0195],
          [ 0.5536, -0.8628,  0.4235,  ..., -0.6943,  0.8448, -0.1242],
          [-0.3881, -0.0930,  0.0834,  ...,  1.0735, -0.3513, -0.1858],
          ...,
          [-0.8425,  0.7944,  0.4843,  ..., -0.6284, -1.2117, -0.4688],
          [-0.1908, -0.0426,  0.0454,  ...,  0.4292, -0.7883,  0.2271],
          [ 0.1726, -0.0528, -1.1564,  ...,  0.6047, -0.0196, -0.2510]],

         [[-0.2031,  0.1729,  0.4070,  ...,  0.3841, -0.7907,  0.0975],
          [-0.0536, -0.4314,  0.0327,  ...,  0.6603,  0.2337,  0.1603],
          [ 0.4947, -0.2547, -0.2665,  ...,  0.1633, -0.6230,  0.0787],
          ...,
          [ 0.4314, -0.6206,  0.5678,  ..., -1.0114, -0.5201, -0.4156],
          [-0.1480, -0.0195, -0.1213,  ..., -0.3604, -0.2285, -0.1394],
          [ 0.2935,  0.5748,  0.2411,  ..., -0.1690,  0.2752, -0.1123]],

         [[-0.7959,  0.2464,  0.6870,  ...,  0.4973,  0.3346,  0.5141],
          [ 0.4473,  0.0768, -