## Implementing Masked Image Modeling (MIM)

In [None]:
#| default_exp swin3dmim

In [None]:
#| export 
import torch
import math
from typing import Tuple, List, Optional, Union
from medct.swin3d import Swin3dModel, Swin3dConfig, Swin3dPreTrainedModel
from transformers.models.swin.modeling_swin import SwinMaskedImageModelingOutput

How is the encoder stride defined?
`config.patch_size * 2 ** (len(config.depths) - 1)` in each direction

In [None]:
config = Swin3dConfig(image_size=(96, 192, 192), depths=[2, 2], num_heads=[3, 6], patch_size=(8, 16, 16), encoder_stride=(16, 32, 32))
model = Swin3dModel(config, add_pooling_layer=False, use_mask_token=True)

> Total patches at the input

In [None]:
num_patches = (model.config.image_size[0] // model.config.patch_size[0]) * \
              (model.config.image_size[1] // model.config.patch_size[1]) * \
              (model.config.image_size[2] // model.config.patch_size[2])
num_patches

1728

> This is how masking is done in the transformers repo. However, it always gives a roughly 50/50 split, so we will later define a custom masking function.

In [None]:
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
bool_masked_pos.unique(return_counts=True)

(tensor([False,  True]), tensor([830, 898]))

In [None]:
pixel_values = torch.randn((1, 1, 96, 192, 192))
pixel_values.shape

torch.Size([1, 1, 96, 192, 192])

In [None]:
outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)

In [None]:
outputs.last_hidden_state.shape

torch.Size([1, 216, 192])

In [None]:
outputs[0].shape

torch.Size([1, 216, 192])

In [None]:
sequence_output = outputs[0].transpose(1, 2)
batch_size, num_channels, sequence_length = sequence_output.shape
depth= height = width = math.ceil(sequence_length**(1/3))
sequence_output = sequence_output.reshape(batch_size, num_channels, depth, height, width)
sequence_output.shape

torch.Size([1, 192, 6, 6, 6])

In [None]:
num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
num_features

192

## Decoder 

In [None]:
d_stride, h_stride, w_stride = config.encoder_stride
decoder = torch.nn.Conv3d(in_channels=num_features, out_channels=d_stride* h_stride*w_stride * config.num_channels, kernel_size=1)
decoder

Conv3d(192, 16384, kernel_size=(1, 1, 1), stride=(1, 1, 1))

In [None]:
out = decoder(sequence_output)
out.shape

torch.Size([1, 16384, 6, 6, 6])

In [None]:
outputs[0]

tensor([[[-1.1033,  0.8807, -0.5596,  ..., -0.7480, -0.0764,  0.6527],
         [ 1.3849, -1.2244,  0.2439,  ...,  0.1049, -0.5222,  0.5008],
         [-0.2783,  0.4175,  0.1892,  ..., -0.6043, -0.2852, -0.2248],
         ...,
         [-2.0543, -1.5189,  2.1532,  ..., -0.5779, -0.2324,  0.5024],
         [ 0.2676, -2.0882, -1.3387,  ...,  0.2376, -0.2337,  0.2955],
         [ 0.3487, -0.5860, -0.1518,  ...,  0.9824,  0.1284, -0.4836]]],
       grad_fn=<NativeLayerNormBackward0>)

In [None]:
#| export 
# Copied from https://github.com/kuoweilai/pixelshuffle3d/blob/9be76091761caf3f3881eb5b3dc4b8da09315ab1/pixelshuffle3d.py#L6C1-L29C79
# Modified to support scale when it is different on different axis. 
class PixelShuffle3d(torch.nn.Module):
    '''
    3D version of pixel shuffle.

    Rearranges a tensor of shape `(B, C * sd * sh * sw, D, H, W)` into
    `(B, C, D * sd, H * sh, W * sw)`, where `(sd, sh, sw)` is the per-axis
    upsample scale.
    '''
    def __init__(self, scale):
        '''
        :param scale: upsample scale; either a single int (applied to all three
            axes) or a sequence of three ints `(depth, height, width)`.
        :raises ValueError: if a sequence is given whose length is not 3.
        '''
        super().__init__()
        # Accept the isotropic (scalar) form as well as the per-axis form.
        if isinstance(scale, int):
            scale = (scale, scale, scale)
        scale = tuple(int(s) for s in scale)
        if len(scale) != 3:
            raise ValueError(f"scale must be an int or a sequence of 3 ints, got {scale}")
        self.scale = scale

    def extra_repr(self):
        # Make the configured scale visible when the module is printed.
        return f"scale={self.scale}"

    def forward(self, input):
        '''
        :param input: tensor of shape `(B, C * sd * sh * sw, D, H, W)`.
        :return: tensor of shape `(B, C, D * sd, H * sh, W * sw)`.
        :raises ValueError: if the channel count is not divisible by `sd * sh * sw`.
        '''
        scale_d, scale_h, scale_w = self.scale
        batch_size, channels, in_depth, in_height, in_width = input.size()

        factor = scale_d * scale_h * scale_w
        if channels % factor != 0:
            raise ValueError(
                f"number of channels ({channels}) must be divisible by the product of scale {self.scale} ({factor})"
            )
        nOut = channels // factor

        # Split the channel axis into (nOut, sd, sh, sw) ...
        input_view = input.contiguous().view(
            batch_size, nOut, scale_d, scale_h, scale_w, in_depth, in_height, in_width
        )
        # ... then interleave each scale factor with its spatial axis:
        # (B, nOut, D, sd, H, sh, W, sw).
        output = input_view.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()

        return output.view(
            batch_size, nOut, in_depth * scale_d, in_height * scale_h, in_width * scale_w
        )

In [None]:
ps = PixelShuffle3d(config.encoder_stride)
ps

PixelShuffle3d()

In [None]:
reconstructed_pixel_values = ps(out)
reconstructed_pixel_values.shape

torch.Size([1, 1, 96, 192, 192])

## Calculate loss

In [None]:
size = (config.image_size[0] // config.patch_size[0], \
        config.image_size[1] // config.patch_size[1], \
        config.image_size[2] // config.patch_size[2])
size 

(12, 12, 12)

In [None]:
bool_masked_pos = bool_masked_pos.reshape(-1, size[0], size[1], size[2])
bool_masked_pos.shape

torch.Size([1, 12, 12, 12])

In [None]:
mask = (
    bool_masked_pos.repeat_interleave(config.patch_size[0], 1)
    .repeat_interleave(config.patch_size[1], 2)
    .repeat_interleave(config.patch_size[2], 3)
    .unsqueeze(1)
    .contiguous()
)
mask.shape

torch.Size([1, 1, 96, 192, 192])

In [None]:
reconstruction_loss = torch.nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / config.num_channels
masked_im_loss

tensor(0.9222, grad_fn=<DivBackward0>)

## Define your own masking
At the top, we used random masking that is roughly 50/50. But what if we want to mask an arbitrary percentage of the patches?

In [None]:
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
bool_masked_pos

tensor([[False, False,  True,  ..., False,  True,  True]])

In [None]:
len_keep = int(num_patches * (1 - 0.6))
len_keep

691

In [None]:
x = torch.cat([torch.zeros((len_keep)), torch.ones((num_patches-len_keep))])
x = x[torch.randperm(num_patches)].view(1, -1)
x

tensor([[1., 0., 1.,  ..., 1., 1., 1.]])

In [None]:
values, count = x.unique(return_counts=True)
values, count

(tensor([0., 1.]), tensor([ 691, 1037]))

In [None]:
count[1]/(count.sum())

tensor(0.6001)

In [None]:
#| export 
def mask_patches(num_patches, mask_ratio=0.5):
    '''
    Build a random patch mask with a fixed fraction of masked patches.

    :param num_patches: total number of patches.
    :param mask_ratio: fraction of patches to mask, in `[0, 1]`.
    :return: float tensor of shape `(1, num_patches)`; 1 marks a masked patch,
        0 a kept patch. Exactly `floor(num_patches * (1 - mask_ratio))` patches
        are kept.
    :raises ValueError: if `mask_ratio` is outside `[0, 1]`.
    '''
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")
    # Round away floating-point noise before flooring: 10 * (1 - 0.9) evaluates to
    # 0.9999999999999998, which a bare int() would truncate to 0 kept patches.
    len_keep = math.floor(round(num_patches * (1 - mask_ratio), 6))
    x = torch.cat([torch.zeros(len_keep), torch.ones(num_patches - len_keep)])
    # Shuffle so masked positions are uniformly random; explicit width also
    # handles num_patches == 0, where view(1, -1) is ambiguous.
    return x[torch.randperm(num_patches)].view(1, num_patches)

In [None]:
x = mask_patches(num_patches, mask_ratio=0.1)
values, count = x.unique(return_counts=True)
count[0]/(count.sum())

tensor(0.8999)

In [None]:
count, x #mask a particular patch=keep it one here

(tensor([1555,  173]), tensor([[0., 0., 0.,  ..., 0., 0., 0.]]))

In [None]:
#| export 
# Copied from transformers.models.swin.modeling_swin.SwinForMaskedImageModeling
class Swin3dForMaskedImageModeling(Swin3dPreTrainedModel):
    '''
    Swin3d encoder with a decoder head for masked image modeling.

    The decoder is a 1x1x1 convolution followed by a 3D pixel shuffle that
    upsamples the final encoder feature map by `config.encoder_stride` along
    each axis to reconstruct the input volume. When `bool_masked_pos` is given,
    an L1 reconstruction loss over the masked voxels is returned as well.
    '''
    def __init__(self, config):
        super().__init__(config)
        if len(config.encoder_stride) !=3: raise NotImplementedError("The length of encoder stride should be 3")
        self.swin = Swin3dModel(config, add_pooling_layer=False, use_mask_token=True)

        # Channel width of the last stage: embed_dim doubles at every stage after the first.
        num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
        d_stride, h_stride, w_stride = config.encoder_stride
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv3d(
                in_channels=num_features, out_channels=(d_stride*h_stride*w_stride) * config.num_channels, kernel_size=1
            ),
            PixelShuffle3d(config.encoder_stride),
        )
        
        # Number of patches at the encoder input (length expected for `bool_masked_pos`).
        self.num_patches = (self.config.image_size[0] // self.config.patch_size[0]) * \
                           (self.config.image_size[1] // self.config.patch_size[1]) * \
                           (self.config.image_size[2] // self.config.patch_size[2])

        # Initialize weights and apply final processing
        self.post_init()

    def _feature_grid(self, sequence_length):
        '''Return the (depth, height, width) of the final feature map holding `sequence_length` tokens.'''
        # Assumes `encoder_stride` is the encoder's total per-axis downsampling
        # factor (the decoder already relies on this to reproduce the input size).
        grid = tuple(
            size // stride
            for size, stride in zip(self.config.image_size, self.config.encoder_stride)
        )
        if grid[0] * grid[1] * grid[2] == sequence_length:
            return grid
        # Fallback for configurations where the stride does not describe the
        # feature map: assume a cubic grid. round() rather than ceil() because
        # float cube roots of perfect cubes can land on either side of the integer.
        side = round(sequence_length ** (1 / 3))
        return (side, side, side)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SwinMaskedImageModelingOutput]:
        '''
        Encode `pixel_values`, reconstruct the volume, and optionally compute the masked L1 loss.

        `bool_masked_pos` marks masked patches (non-zero / True = masked); it is
        reshaped to the input patch grid derived from `config.image_size` and
        `config.patch_size`. Returns a `SwinMaskedImageModelingOutput`, or a
        tuple `(loss?, reconstruction, ...)` when `return_dict` is false.
        '''
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.swin(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # Reshape to (batch_size, num_channels, depth, height, width)
        sequence_output = sequence_output.transpose(1, 2)
        batch_size, num_channels, sequence_length = sequence_output.shape
        depth, height, width = self._feature_grid(sequence_length)
        sequence_output = sequence_output.reshape(batch_size, num_channels, depth, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = (self.config.image_size[0] // self.config.patch_size[0],
                    self.config.image_size[1] // self.config.patch_size[1],
                    self.config.image_size[2] // self.config.patch_size[2])

            bool_masked_pos = bool_masked_pos.reshape(-1, size[0], size[1], size[2])
            # Expand the per-patch mask to voxel resolution and add a channel axis.
            mask = (bool_masked_pos.repeat_interleave(self.config.patch_size[0], 1)
                    .repeat_interleave(self.config.patch_size[1], 2)
                    .repeat_interleave(self.config.patch_size[2], 3)
                    .unsqueeze(1)
                    .contiguous()
                )
            reconstruction_loss = torch.nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            # Average over masked voxels only; the epsilon guards against an all-zero mask.
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return SwinMaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )

In [None]:
config = Swin3dConfig(image_size=(96, 192, 192), depths=[2, 2], num_heads=[3, 6], patch_size=(8, 16, 16), encoder_stride=(16, 32, 32))
model = Swin3dForMaskedImageModeling(config)

In [None]:
model.num_patches

1728

In [None]:
bool_masked_pos = mask_patches(model.num_patches, 0.4)
bool_masked_pos

tensor([[0., 1., 0.,  ..., 1., 0., 0.]])

In [None]:
out = model(torch.randn((1, 1, )+model.config.image_size), bool_masked_pos=bool_masked_pos)

In [None]:
out.loss

tensor(0.8276, grad_fn=<DivBackward0>)

> TODO: Say we want to calculate this at multiple levels. How do we do that?

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()