> Masked Image Modelling (MIM) on ConvNextV2

I'm not sure if this works, but I am going to give it a try. We can mask the patch embeddings of ConvNextV2 and fill the masked positions with mask tokens, similar to what is done for Swin Transformers. I am not sure why ConvNextV2 implemented sparse convolutions to achieve the same thing. 

In [None]:
#| default_exp convnextv2mim

In [None]:
#| export 
import torch

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from transformers.utils import ModelOutput
from medct.convnextv2 import ConvNextV2Model3d, ConvNextV2Config3d, ConvNextV2PreTrainedModel3d
from medct.swin3dmim import PixelShuffle3d

In [None]:
#| export 
# Copied from medct.swin3dmim.mask_patches
def mask_patches(num_patches, mask_ratio=0.5):
    """Build a randomly shuffled 0/1 patch mask.

    Args:
        num_patches: Total number of patches in the mask.
        mask_ratio: Fraction of patches to mark as masked.

    Returns:
        A float tensor of shape `(1, num_patches)` holding
        `int(num_patches * (1 - mask_ratio))` zeros (kept patches) and ones
        (masked patches) everywhere else, in random order.
    """
    n_keep = int(num_patches * (1 - mask_ratio))
    n_mask = num_patches - n_keep
    # Kept patches first, masked patches last; the permutation below mixes them.
    ordered = torch.cat([torch.zeros(n_keep), torch.ones(n_mask)])
    shuffle = torch.randperm(num_patches)
    return ordered[shuffle].view(1, -1)

In [None]:
#| export 
@dataclass
class ConvNextV2MaskedImageModelingOutput(ModelOutput):
    """
    Output of `ConvNextV2ForMaskedImageModeling`.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MIM) loss: the L1 reconstruction error averaged over the masked voxels.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, depth, height, width)`):
            Reconstructed voxel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Hidden states returned by the backbone, passed through unchanged.

            NOTE(review): the backbone's last hidden state is a 5-D feature map
            `(batch_size, channels, depth, height, width)`; the per-stage hidden states presumably share that
            layout -- confirm against `ConvNextV2Model3d`.
    """

    loss: Optional[torch.FloatTensor] = None
    reconstruction: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None

In [None]:
# Small two-stage, single-channel 3D configuration used for the smoke tests below.
config = ConvNextV2Config3d(num_channels=1, image_size=(96, 192, 192), patch_size=(8, 16, 16), hidden_sizes=[40, 80], depths=[2, 2])
# Bare backbone built with a mask token so masked positions can be passed to it.
model = ConvNextV2Model3d(config, use_mask_token=True)

In [None]:
# Forward a random volume whose channel count and spatial size match `config`.
out = model(torch.randn((1, 1, 96, 192, 192)))

In [None]:
out[0].shape

torch.Size([1, 80, 6, 6, 6])

In [None]:
len(config.hidden_sizes)

2

In [None]:
#| export 
# Copied from transformers.models.swin.modeling_swin.SwinForMaskedImageModeling
class ConvNextV2ForMaskedImageModeling(ConvNextV2PreTrainedModel3d):
    """
    ConvNextV2 3D encoder with a light-weight decoder for masked image modeling (MIM).

    The encoder (`ConvNextV2Model3d` built with `use_mask_token=True`) receives the volume together with
    `bool_masked_pos`. Its final feature map is projected by a 1x1x1 convolution and rearranged by
    `PixelShuffle3d` back towards the input resolution. The loss is the L1 reconstruction error averaged over
    the masked voxels only.
    """

    def __init__(self, config):
        """
        Build the encoder and the reconstruction decoder.

        Args:
            config: Model configuration. `patch_size` and `image_size` must each hold three values
                (depth, height, width). `config.encoder_stride` is (over)written with the total
                encoder downsampling factor per axis.

        Raises:
            NotImplementedError: If `patch_size` or `image_size` is not 3-dimensional.
        """
        super().__init__(config)
        # Validate before indexing: checking the derived stride afterwards can never fail,
        # because it is always built as a 3-tuple.
        if len(config.patch_size) != 3 or len(config.image_size) != 3:
            raise NotImplementedError(
                "patch_size and image_size must each have 3 elements (depth, height, width)"
            )

        num_layers = len(config.hidden_sizes)
        # The patch embedding reduces each axis by `patch_size`; the first stage keeps the resolution and
        # every later stage halves it (a 2-stage encoder maps a 12^3 patch grid to 6^3). The total stride
        # therefore grows geometrically with the number of stages, not linearly.
        # NOTE(review): assumes every stage after the first downsamples by 2 -- confirm against
        # ConvNextV2Encoder3d if custom strides are ever supported.
        stage_downsampling = 2 ** (num_layers - 1)
        config.encoder_stride = tuple(patch * stage_downsampling for patch in config.patch_size)

        grid_d, grid_h, grid_w = self._patch_grid(config)
        self.num_patches = grid_d * grid_h * grid_w
        self.model = ConvNextV2Model3d(config, use_mask_token=True)

        num_features = config.hidden_sizes[-1]
        d_stride, h_stride, w_stride = config.encoder_stride
        # One output channel per voxel of an encoder cell (times image channels); the pixel shuffle then
        # spreads those channels over the spatial axes.
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv3d(
                in_channels=num_features,
                out_channels=(d_stride * h_stride * w_stride) * config.num_channels,
                kernel_size=1,
            ),
            PixelShuffle3d(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _patch_grid(config):
        """Return the (depth, height, width) size of the patch-embedding grid for `config`."""
        return (
            config.image_size[0] // config.patch_size[0],
            config.image_size[1] // config.patch_size[1],
            config.image_size[2] // config.patch_size[2],
        )

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ConvNextV2MaskedImageModelingOutput]:
        """
        Encode the (masked) volume, reconstruct it and optionally compute the MIM loss.

        Args:
            pixel_values: Input volume of shape `(batch_size, num_channels, depth, height, width)`.
            bool_masked_pos: Boolean mask with one flag per patch of the embedding grid, e.g. of shape
                `(batch_size, num_patches)`. It is forwarded to the encoder and marks the patches whose
                voxels count towards the loss. When `None`, no loss is computed.
            output_hidden_states: Forwarded to the encoder.
            return_dict: Whether to return a `ConvNextV2MaskedImageModelingOutput` instead of a tuple;
                defaults to `config.use_return_dict`.

        Returns:
            A `ConvNextV2MaskedImageModelingOutput`, or, when `return_dict` is false, a tuple
            `(loss, reconstruction, *outputs[2:])` with `loss` omitted if it was not computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The encoder's first output is a 5-D feature map, so it feeds the conv decoder without reshaping.
        sequence_output = outputs[0]

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            grid_d, grid_h, grid_w = self._patch_grid(self.config)
            patch_d, patch_h, patch_w = self.config.patch_size

            # Expand the per-patch mask to voxel resolution so it lines up with `pixel_values`.
            bool_masked_pos = bool_masked_pos.reshape(-1, grid_d, grid_h, grid_w)
            mask = (
                bool_masked_pos.repeat_interleave(patch_d, 1)
                .repeat_interleave(patch_h, 2)
                .repeat_interleave(patch_w, 3)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = torch.nn.functional.l1_loss(
                pixel_values, reconstructed_pixel_values, reduction="none"
            )
            # Average over masked voxels only; the epsilon guards against an all-False mask.
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return ConvNextV2MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
        )

In [None]:
# Instantiate the MIM wrapper from the same config and display its module tree.
mim = ConvNextV2ForMaskedImageModeling(config)
mim

ConvNextV2ForMaskedImageModeling(
  (model): ConvNextV2Model3d(
    (embeddings): ConvNextV2Embeddings3d(
      (patch_embeddings): Conv3d(1, 40, kernel_size=(8, 16, 16), stride=(8, 16, 16))
      (layernorm): ConvNextV2LayerNorm3d()
    )
    (encoder): ConvNextV2Encoder3d(
      (stages): ModuleList(
        (0): ConvNextV2Stage3d(
          (downsampling_layer): Identity()
          (layers): Sequential(
            (0): ConvNextV2Layer3d(
              (dwconv): Conv3d(40, 40, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=40)
              (layernorm): ConvNextV2LayerNorm3d()
              (pwconv1): Linear(in_features=40, out_features=160, bias=True)
              (act): GELUActivation()
              (grn): ConvNextV2GRN3d()
              (pwconv2): Linear(in_features=160, out_features=40, bias=True)
              (drop_path): Identity()
            )
            (1): ConvNextV2Layer3d(
              (dwconv): Conv3d(40, 40, kernel_size=(7, 7, 7), stride=(1, 

In [None]:
# Random boolean mask with one flag per patch; True positions are the ones counted in the loss.
bool_masked_pos = torch.randint(low=0, high=2, size=(1, mim.num_patches)).bool()
out = mim(torch.randn((1, 1, 96, 192, 192)), bool_masked_pos=bool_masked_pos)

torch.Size([1, 1728, 40]) torch.Size([1, 1728, 1])


In [None]:
out.loss

tensor(0.7998, grad_fn=<DivBackward0>)

In [None]:
out.reconstruction.shape

torch.Size([1, 1, 96, 192, 192])

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()