In [1]:
import torch
import torchvision

# Local model code; expected to provide EncoderConfig/EncoderModel and DecoderConfig/DecoderModel.
# NOTE(review): prefer explicit names over star imports once the exported API is confirmed.
from encoder import *
from decoder import *

In [2]:
# Settings the encoder and decoder must agree on are defined once, so they cannot drift apart.
IMAGE_SIZE = 224           # input resolution (pixels per side)
PATCH_SIZE = 16            # 224 / 16 = 14 patches per side -> 196 patches per image
NUM_CHANNELS = 3           # RGB input
ENCODER_HIDDEN_SIZE = 768  # width of encoder outputs; also the decoder's input width

encoder_config = EncoderConfig(
    image_size=IMAGE_SIZE,
    hidden_size=ENCODER_HIDDEN_SIZE,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    num_channels=NUM_CHANNELS,
    patch_size=PATCH_SIZE,
    layer_norm_eps=1e-6,
    attention_dropout=0.0,
    num_image_tokens=None,
    do_random_mask=True,
    mask_ratio=0.75,
)

decoder_config = DecoderConfig(
    image_size=IMAGE_SIZE,
    # The decoder is fed encoder_op, so its input projection width is tied to the encoder width.
    in_proj_dim=ENCODER_HIDDEN_SIZE,
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    num_channels=NUM_CHANNELS,
    patch_size=PATCH_SIZE,
    layer_norm_eps=1e-6,
    attention_dropout=0.0,
    do_loss_calculation=True,
    do_norm_pix_loss=False,
)

In [3]:
# Seed torch's global RNG before building the models so Restart & Run All reproduces
# the same weights and the same printed loss.
# Assumes EncoderModel/DecoderModel draw their weight init and random patch masking
# from torch's global RNG. TODO confirm against encoder.py / decoder.py.
SEED = 0
torch.manual_seed(SEED)

encoder = EncoderModel(encoder_config)
decoder = DecoderModel(decoder_config)

In [4]:
# Decode as 3-channel RGB. A grayscale or RGBA file would otherwise produce 1 or 4 channels,
# which does not match num_channels=3 in the model configs.
img_raw = torchvision.io.read_image("dog.jpg", mode=torchvision.io.ImageReadMode.RGB)
# Resize to the model input resolution (must match image_size=224 in the configs).
# antialias=True is passed explicitly because its default differs across torchvision versions.
img_resized = torchvision.transforms.functional.resize(img_raw, (224, 224), antialias=True)
# Add a batch dimension and scale uint8 [0, 255] to float [0, 1].
img = img_resized.unsqueeze(0) / 255.0

In [5]:
# Build a batch of identical images. Slicing img[:1] first keeps this cell idempotent:
# re-running it no longer grows the batch (2 -> 4 -> 8 ...).
BATCH_SIZE = 2
img = img[:1].repeat(BATCH_SIZE, 1, 1, 1)
img.shape

torch.Size([2, 3, 224, 224])

In [6]:
# Encoder forward pass. The result has 49 tokens per image = 196 patches x (1 - mask_ratio),
# consistent with random masking keeping only the visible patches.
# mask, ids_restore and target are forwarded unchanged to the decoder.
encoder_outputs = encoder(img)
encoder_op, mask, ids_restore, target = encoder_outputs
encoder_op.shape

torch.Size([2, 49, 768])

In [7]:
# Decoder forward pass plus loss (do_loss_calculation=True in the config).
# The output has one row per patch (196); its width 768 equals patch_size**2 * num_channels
# (16 * 16 * 3), which suggests per-patch pixel predictions.
# NOTE(review): confirm against DecoderModel.
decoder_inputs = (encoder_op, mask, ids_restore)
op, loss = decoder(decoder_inputs, target)
op.shape, loss

(torch.Size([2, 196, 768]), tensor(1.4868, grad_fn=<DivBackward0>))

In [8]:
def mae_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean squared reconstruction error averaged over masked patches only.

    Useful as an independent cross-check of the decoder's built-in loss.

    Each patch's error is the mean of its squared element differences. These per-patch
    errors are then averaged over the masked patches. This is the MSE the name promises,
    rather than a per-patch sum of squares divided by the patch count.

    Args:
        pred: Predicted patches, shape [batch_size, num_patches, patch_dim].
        target: Ground-truth patches, same shape as ``pred``.
        mask: Shape [batch_size, num_patches]. Nonzero/True marks a masked patch, which is
            included in the loss; 0/False marks a visible patch, which is ignored.
            Integer, float, or bool dtypes are accepted.

    Returns:
        Scalar tensor. Returns 0 when no patch is masked, instead of the NaN that a
        0/0 division would produce.
    """
    # Per-patch MSE: average over the patch's elements -> [batch_size, num_patches].
    per_patch_error = ((pred - target) ** 2).mean(dim=-1)
    weights = mask.to(per_patch_error.dtype)
    # clamp avoids 0/0 when nothing is masked; the numerator is then 0, so the loss is 0.
    return (per_patch_error * weights).sum() / weights.sum().clamp(min=1.0)