## ConvNeXt V2

> ConvNeXt V2 is a pure convolutional model (ConvNet), inspired by the design of Vision Transformers.


- The paper modifies the ConvNeXt architecture so that it can be pre-trained with a masked autoencoder (`MAE`).
- It introduces `Global Response Normalization (GRN)` to enhance inter-channel feature competition.
- The smallest ConvNeXt V2 variant, Atto, has 3.7M parameters and reaches 76.7% top-1 accuracy on ImageNet; the 650M-parameter Huge model reaches 88.9% top-1 accuracy.

In [None]:
#| default_exp convnextv2

In [None]:
#| export 
import torch
import torch.nn as nn
import fastcore.all as fc

from typing import Optional, Union, Tuple
from transformers.models.convnextv2.modeling_convnextv2 import (
    drop_path,
    ConvNextV2DropPath,
    ConvNextV2LayerNorm,
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
)
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
from transformers.modeling_outputs import BackboneOutput
from transformers.utils.backbone_utils import BackboneMixin

## MAE 
- 60% of the 32x32 patches are masked.
- A transformer can simply drop the masked patches, but a convolution slides over the full grid, so ignoring masked features is harder. There are two options:
    - use sparse convolution
    - apply a binary mask before and after each dense convolution (theoretically more computationally expensive)
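
The second option can be sketched as follows. This is a minimal illustration on a 2D toy input, not the paper's implementation: the helper names `random_patch_mask` and `masked_dense_conv`, the image size, and the use of `nn.Conv2d` are all assumptions made for the example.

```python
import torch
import torch.nn as nn

def random_patch_mask(batch, height, width, patch=32, mask_ratio=0.6):
    """Return a (batch, 1, height, width) mask: 1 = visible pixel, 0 = masked."""
    gh, gw = height // patch, width // patch
    num_patches = gh * gw
    num_masked = int(mask_ratio * num_patches)
    # Rank random noise per sample; the lowest-ranked patches are masked.
    noise = torch.rand(batch, num_patches)
    ranks = noise.argsort(dim=1).argsort(dim=1)
    keep = (ranks >= num_masked).float().view(batch, 1, gh, gw)
    # Upsample the patch-level mask to pixel resolution.
    return keep.repeat_interleave(patch, dim=2).repeat_interleave(patch, dim=3)

def masked_dense_conv(conv, x, mask):
    """Zero the masked positions before and after a dense convolution."""
    return conv(x * mask) * mask

torch.manual_seed(0)
x = torch.randn(2, 3, 128, 128)
mask = random_patch_mask(2, 128, 128)
conv = nn.Conv2d(3, 8, kernel_size=7, padding=3)
with torch.no_grad():
    out = masked_dense_conv(conv, x, mask)
print(out.shape, mask.mean().item())
```

Masking the output again matters because the 7x7 kernel would otherwise leak visible pixels into neighbouring masked positions.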

## Drop path

In [None]:
x = torch.randn((4, 64, 64, 64, 128))
x.shape

torch.Size([4, 64, 64, 64, 128])

In [None]:
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
shape

(4, 1, 1, 1, 1)

In [None]:
keep_prob = 0.5
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
random_tensor.shape

torch.Size([4, 1, 1, 1, 1])

In [None]:
out = x*random_tensor
out.shape

torch.Size([4, 64, 64, 64, 128])
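
The cells above only multiply by the per-sample mask. Stochastic depth implementations usually also rescale the surviving samples by `1 / keep_prob`, so the expected value stays the same, and they do nothing at evaluation time. A minimal sketch of that behaviour, where the function name `drop_path_sketch` is hypothetical and not the transformers API:

```python
import torch

def drop_path_sketch(x, drop_prob=0.0, training=False):
    """Drop whole samples with probability drop_prob and rescale the survivors."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    return x * mask / keep_prob

torch.manual_seed(0)
x = torch.ones(1000, 2, 4, 4, 4)
out = drop_path_sketch(x, drop_prob=0.5, training=True)
per_sample = out.flatten(1)
print(out.mean().item())  # close to 1.0 on average
```

Each sample ends up either all zeros or scaled by 2, which is why the batch mean stays near the input mean.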

> Because the mask shape depends only on `x.ndim`, `drop_path` and `ConvNextV2DropPath` can be imported from transformers as-is; they work for inputs of any dimensionality.

## Global Response Normalization 
ConvNeXt trained directly on masked input does not perform on par with transformers. A feature-space analysis points to a potential cause: feature collapse at the MLP layer. This normalization technique tries to increase the contrast and selectivity of channels. It consists of three steps:

- global feature aggregation
- feature normalization
- feature calibration
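
Written out for a channels-last input with C channels, the three steps look roughly like this. The notation is an assumption made here for illustration: X_i is channel i, γ and β are the learnable per-channel weight and bias, and ε is a small constant for numerical stability.

```latex
\begin{aligned}
G(X)_i &= \lVert X_i \rVert_2 && \text{(aggregation: L2 norm over spatial positions)} \\
N(G)_i &= \frac{G(X)_i}{\frac{1}{C}\sum_{j=1}^{C} G(X)_j + \epsilon} && \text{(normalization across channels)} \\
X_i &\leftarrow \gamma_i \, X_i \, N(G)_i + \beta_i + X_i && \text{(calibration with a residual connection)}
\end{aligned}
```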

In [None]:
x = torch.arange(4).view(2, 2).float()
x

tensor([[0., 1.],
        [2., 3.]])

In [None]:
torch.norm(x, p=2, dim=(1, ), keepdim=True)

tensor([[1.0000],
        [3.6056]])

So for an image of shape (N, H, W, C), taking the norm over the spatial dimensions (H, W) gives a tensor of shape (N, 1, 1, C).

In [None]:
x = torch.randn((2, 48, 48, 96))
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
gx.shape

torch.Size([2, 1, 1, 96])

In [None]:
nx = gx/gx.mean(dim=-1, keepdim=True)
nx.shape

torch.Size([2, 1, 1, 96])

In [None]:
#| export 
class ConvNextV2GRN3d(nn.Module):
    """GRN (Global Response Normalization) layer"""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, 1, 1, 1, dim))
        self.bias = nn.Parameter(torch.zeros(1, 1, 1, 1, dim))

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        # Compute and normalize global spatial feature maps - N, H, W, D, C
        global_features = torch.norm(hidden_states, p=2, dim=(1, 2, 3), keepdim=True)
        norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
        hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
        return hidden_states

In [None]:
grn3d = ConvNextV2GRN3d(dim=128)
grn3d

ConvNextV2GRN3d()

In [None]:
x = torch.randn((4, 64, 64, 64, 128))+100
print(x.mean())
with torch.no_grad():
    out = grn3d(x)
    print(out.mean()) # same because weight and bias are initialized to 0

tensor(100.0001)
tensor(100.0001)
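
With zero-initialized weight and bias the layer is an identity, so the effect of GRN only shows up once the weight is nonzero. The sketch below uses a functional version (the name `grn` and the toy values are assumptions for illustration) with weight 1 and bias 0 to show how a channel with a larger norm is amplified more than a weaker one.

```python
import torch

def grn(x, weight, bias, eps=1e-6):
    """Functional GRN for a channels-last input of shape (N, H, W, D, C)."""
    g = torch.norm(x, p=2, dim=(1, 2, 3), keepdim=True)  # (N, 1, 1, 1, C)
    n = g / (g.mean(dim=-1, keepdim=True) + eps)         # relative channel strength
    return weight * (x * n) + bias + x

# Two channels: channel 0 is weak, channel 1 is three times stronger.
x = torch.ones(1, 2, 2, 2, 2)
x[..., 1] = 3.0
out = grn(x, weight=torch.ones(2), bias=torch.zeros(2))
gain = (out / x)[0, 0, 0, 0]
print(gain)  # roughly [1.5, 2.5]
```

The relative strengths are 0.5 and 1.5, so after the residual connection the weak channel is scaled by about 1.5 and the strong one by about 2.5.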


## LayerNorm
In the ConvNeXt block, LayerNorm sits at the position shown below:
![Layernorm in Convnext block](../assets/convnextv2-v1.png)

The original code supports both channels-last and channels-first layouts. The channels-last path of the Hugging Face norm works unchanged for 3D inputs; only the channels-first broadcast needs an extra dimension, so the class below copies it with that one change.

In [None]:
#| export 
# Copied from transformers.models.convnextv2.modeling_convnextv2.ConvNextV2LayerNorm 
class ConvNextV2LayerNorm3d(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, depth, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width,
    depth).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            # Only this line changed: one extra None for the third spatial dimension
            x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
        return x

In [None]:
net=ConvNextV2LayerNorm3d(normalized_shape=128, data_format="channels_first")
net

ConvNextV2LayerNorm3d()

In [None]:
x = torch.randn((4, 128, 64, 64, 64))
out = net(x)
out.shape

torch.Size([4, 128, 64, 64, 64])

In [None]:
net=ConvNextV2LayerNorm3d(normalized_shape=128, data_format="channels_last")
x = torch.randn((4, 64, 64, 64, 128))
out = net(x)
out.shape

torch.Size([4, 64, 64, 64, 128])
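
The two data formats should give the same numbers, only in a different memory layout. A small check, under the assumption that the channels-first branch is the manual mean/variance computation shown above (the standalone function `layer_norm_channels_first` is written here just for the comparison):

```python
import torch
import torch.nn.functional as F

def layer_norm_channels_first(x, weight, bias, eps=1e-6):
    """Normalize over dim 1 of an (N, C, H, W, D) tensor."""
    u = x.mean(1, keepdim=True)
    s = (x - u).pow(2).mean(1, keepdim=True)
    x = (x - u) / torch.sqrt(s + eps)
    return weight[:, None, None, None] * x + bias[:, None, None, None]

torch.manual_seed(0)
x = torch.randn(2, 16, 4, 5, 6)
weight, bias = torch.randn(16), torch.randn(16)

out_first = layer_norm_channels_first(x, weight, bias)
# Move channels to the end, use the built-in layer norm, then move them back.
out_last = F.layer_norm(x.permute(0, 2, 3, 4, 1), (16,), weight, bias, 1e-6)
out_last = out_last.permute(0, 4, 1, 2, 3)
max_diff = (out_first - out_last).abs().max().item()
print(max_diff)
```

The difference should be at floating-point noise level.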

## ConvNextV2Embedding 

In [None]:
layer = nn.Conv3d(3, 40, kernel_size=4, stride=4)
out = layer(torch.randn((2, 3, 96, 192, 192)))
out.shape

torch.Size([2, 40, 24, 48, 48])

In [None]:
#| export 
# Copied from transformers.models.convnextv2.modeling_convnextv2.ConvNextV2Embeddings with 
# ConvNextV2Embeddings -> ConvNextV2Embeddings3d
class ConvNextV2Embeddings3d(nn.Module):
    """This class is comparable to (and inspired by) the SwinEmbeddings class
    found in src/transformers/models/swin/modeling_swin.py.
    """

    def __init__(self, config):
        super().__init__()
        self.patch_embeddings = nn.Conv3d(
            config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
        )
        self.layernorm = ConvNextV2LayerNorm3d(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
        self.num_channels = config.num_channels

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.patch_embeddings(pixel_values)
        embeddings = self.layernorm(embeddings)
        return embeddings

In [None]:
class config:
    num_channels = 3
    hidden_sizes = [40, 80, 160, 320] #Atto depths=[2, 2, 6, 2], dims=[40, 80, 160, 320]
    patch_size = 4

In [None]:
embed = ConvNextV2Embeddings3d(config())
embed

ConvNextV2Embeddings3d(
  (patch_embeddings): Conv3d(3, 40, kernel_size=(4, 4, 4), stride=(4, 4, 4))
  (layernorm): ConvNextV2LayerNorm3d()
)

In [None]:
x = torch.randn((2, 3, 96, 192, 192))
embed_out = embed(x)
embed_out.shape

torch.Size([2, 40, 24, 48, 48])

## ConvNextV2Layer3D 

In [None]:
#| export
# Copied from transformers.models.convnextv2.modeling_convnextv2.ConvNextV2Layer 
class ConvNextV2Layer3d(nn.Module):
    """This corresponds to the `Block` class in the original implementation.

    There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), 1x1 Conv, GELU, 1x1 Conv]; all
    in (N, C, H, W, D) (2) [DwConv, Permute to (N, H, W, D, C), LayerNorm (channels_last), Linear, GELU, Linear];
    Permute back

    The authors used (2) as they find it slightly faster in PyTorch.

    Args:
        config ([`ConvNextV2Config3D`]): Model configuration class.
        dim (`int`): Number of input channels.
        drop_path (`float`): Stochastic depth rate. Default: 0.0.
    """

    def __init__(self, config, dim, drop_path=0):
        super().__init__()
        # depthwise conv
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.layernorm = ConvNextV2LayerNorm3d(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = ACT2FN[config.hidden_act]
        self.grn = ConvNextV2GRN3d(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = ConvNextV2DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        input = hidden_states
        x = self.dwconv(hidden_states)
        # (batch_size, num_channels, height, width, depth) -> (batch_size, height, width, depth, num_channels)
        x = x.permute(0, 2, 3, 4, 1)
        x = self.layernorm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        # (batch_size, height, width, depth, num_channels) -> (batch_size, num_channels, height, width, depth)
        x = x.permute(0, 4, 1, 2, 3)

        x = input + self.drop_path(x)
        return x

In [None]:
class config:
    num_channels = 3
    hidden_sizes = [40, 80, 160, 320] #Atto depths=[2, 2, 6, 2], dims=[40, 80, 160, 320]
    patch_size = 4
    hidden_act = "gelu"

In [None]:
layer = ConvNextV2Layer3d(config, dim=40)
layer

ConvNextV2Layer3d(
  (dwconv): Conv3d(40, 40, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=40)
  (layernorm): ConvNextV2LayerNorm3d()
  (pwconv1): Linear(in_features=40, out_features=160, bias=True)
  (act): GELUActivation()
  (grn): ConvNextV2GRN3d()
  (pwconv2): Linear(in_features=160, out_features=40, bias=True)
  (drop_path): Identity()
)

In [None]:
layer_out = layer(embed_out)
layer_out.shape

torch.Size([2, 40, 24, 48, 48])
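
The layer implements its pointwise convolutions as `nn.Linear` on a channels-last tensor. The sketch below checks that this matches a 1x1x1 `nn.Conv3d` on the channels-first tensor when both share the same parameters; the small `dim` and input size are arbitrary choices for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
linear = nn.Linear(dim, 4 * dim)
conv = nn.Conv3d(dim, 4 * dim, kernel_size=1)

# Share parameters: an (out, in) matrix is an (out, in, 1, 1, 1) kernel.
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(4 * dim, dim, 1, 1, 1))
    conv.bias.copy_(linear.bias)

x = torch.randn(2, dim, 3, 4, 5)  # channels first
with torch.no_grad():
    out_conv = conv(x)
    # Permute to channels last, apply the linear layer, permute back.
    out_linear = linear(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
max_diff = (out_conv - out_linear).abs().max().item()
print(out_conv.shape, max_diff)
```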

## ConvNextV2Stage3D

In [None]:
#| export
# Copied from transformers.models.convnextv2.modeling_convnextv2.ConvNextV2Stage 
class ConvNextV2Stage3d(nn.Module):
    """ConvNeXTV23D stage, consisting of an optional downsampling layer + multiple residual blocks.

    Args:
        config ([`ConvNextV2Config3D`]): Model configuration class.
        in_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        depth (`int`): Number of residual blocks.
        drop_path_rates(`List[float]`): Stochastic depth rates for each layer.
    """

    def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
        super().__init__()

        if in_channels != out_channels or stride > 1:
            self.downsampling_layer = nn.Sequential(
                ConvNextV2LayerNorm3d(in_channels, eps=1e-6, data_format="channels_first"),
                nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
            )
        else:
            self.downsampling_layer = nn.Identity()
        drop_path_rates = drop_path_rates or [0.0] * depth
        self.layers = nn.Sequential(
            *[ConvNextV2Layer3d(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        hidden_states = self.downsampling_layer(hidden_states)
        hidden_states = self.layers(hidden_states)
        return hidden_states

## ConvNextV2Encoder3D

In [None]:
#| export
# Copied from transformers.models.convnextv2.modeling_convnextv2.ConvNextV2Encoder 
class ConvNextV2Encoder3d(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.stages = nn.ModuleList()
        drop_path_rates = [
            x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
        ]
        prev_chs = config.hidden_sizes[0]
        for i in range(config.num_stages):
            out_chs = config.hidden_sizes[i]
            stage = ConvNextV2Stage3d(
                config,
                in_channels=prev_chs,
                out_channels=out_chs,
                stride=2 if i > 0 else 1,
                depth=config.depths[i],
                drop_path_rates=drop_path_rates[i],
            )
            self.stages.append(stage)
            prev_chs = out_chs

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        all_hidden_states = () if output_hidden_states else None

        for i, layer_module in enumerate(self.stages):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            hidden_states = layer_module(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )
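
The encoder gives every block its own stochastic depth rate: the rates grow linearly from 0 to `config.drop_path_rate` over all blocks and are then split per stage. A standalone sketch of that schedule, using the Atto depths noted earlier and an assumed `drop_path_rate` of 0.1:

```python
import torch

depths = [2, 2, 6, 2]   # Atto block counts, as noted earlier in this notebook
drop_path_rate = 0.1    # assumed value, for illustration only

# One rate per block, increasing linearly, then grouped by stage.
rates = torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
per_stage = [r.tolist() for r in rates]
for i, stage_rates in enumerate(per_stage):
    print(i, [round(v, 3) for v in stage_rates])
```

Early blocks are almost never dropped, while the last block is dropped with the full `drop_path_rate`.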



In [None]:
#| export 
# Copied from transformers.models.convnextv2.configuration_convnextv2.ConvNextV2Config
class ConvNextV2Config3d(BackboneConfigMixin, PretrainedConfig):
    model_type = "convnextv23d"

    def __init__(
        self,
        num_channels=3,
        patch_size=4,
        num_stages=2,
        hidden_sizes=None,
        depths=None,
        hidden_act="gelu",
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        drop_path_rate=0.0,
        image_size=224,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_channels = num_channels
        self.patch_size = patch_size
        self.num_stages = num_stages
        self.hidden_sizes = [40, 80] if hidden_sizes is None else hidden_sizes
        self.depths = [3, 3] if depths is None else depths
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.drop_path_rate = drop_path_rate
        self.image_size = image_size
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )

In [None]:
encoder = ConvNextV2Encoder3d(ConvNextV2Config3d())
len(encoder.stages)

2

In [None]:
eout = encoder(embed_out, output_hidden_states=True)

In [None]:
[i.shape for i in eout.hidden_states]

[torch.Size([2, 40, 24, 48, 48]),
 torch.Size([2, 40, 24, 48, 48]),
 torch.Size([2, 80, 12, 24, 24])]

## Get Pretrained Model 

In [None]:
#| export
# Copied from transformers.models.convnextv2.modeling_convnextv2.ConvNextV2PreTrainedModel 
class ConvNextV2PreTrainedModel3d(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ConvNextV2Config3d
    base_model_prefix = "convnextv2_3d"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

## Final model 

In [None]:
#| export 
# Copied from transformers.models.convnextv2.modeling_convnextv2.ConvNextV2Model 
class ConvNextV2Model3d(ConvNextV2PreTrainedModel3d):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = ConvNextV2Embeddings3d(config)
        self.encoder = ConvNextV2Encoder3d(config)

        # final layernorm layer
        self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]

        # global average pooling, (N, C, H, W, D) -> (N, C)
        pooled_output = self.layernorm(last_hidden_state.mean([-3, -2, -1]))

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )

In [None]:
config = ConvNextV2Config3d(image_size=(48, 96, 96))
model = ConvNextV2Model3d(config)

In [None]:
param_count = 0
for name, params in model.named_parameters():
    param_count+= params.shape.numel()
param_count # ~0.35 million params

354960
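
The same count can be written as a one-line sum, and individual layers can be checked by hand. The sketch below does this for a patch-embedding convolution with the sizes used in this notebook; the helper name `count_parameters` is an assumption.

```python
import torch.nn as nn

def count_parameters(module):
    """Total number of elements across all parameters of a module."""
    return sum(p.numel() for p in module.parameters())

# Patch embedding as used above: 3 input channels, 40 filters, 4x4x4 kernel.
patch_embed = nn.Conv3d(3, 40, kernel_size=4, stride=4)
expected = 40 * 3 * 4 * 4 * 4 + 40  # weights + biases
print(count_parameters(patch_embed), expected)
```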

In [None]:
%%time
x = torch.randn((2, 3, 96, 96, 96))
with torch.no_grad():
    out = model(x, output_hidden_states=True)

CPU times: user 1.59 s, sys: 898 ms, total: 2.49 s
Wall time: 656 ms


In [None]:
[i.shape for i in out.hidden_states]

[torch.Size([2, 40, 24, 24, 24]),
 torch.Size([2, 40, 24, 24, 24]),
 torch.Size([2, 80, 12, 12, 12])]

In [None]:
out.last_hidden_state.shape

torch.Size([2, 80, 12, 12, 12])

In [None]:
out.pooler_output.shape

torch.Size([2, 80])

## Defining the backbone itself

In [None]:
#| export 
# Copied from transformers.models.convnext.modeling_convnext.ConvNextBackbone with CONVNEXT->CONVNEXTV2,ConvNext->ConvNextV2,facebook/convnext-tiny-224->facebook/convnextv2-tiny-1k-224
class ConvNextV2Backbone3d(ConvNextV2PreTrainedModel3d, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.embeddings = ConvNextV2Embeddings3d(config)
        self.encoder = ConvNextV2Encoder3d(config)
        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self._out_features, self.channels):
            hidden_states_norms[stage] = ConvNextV2LayerNorm3d(num_channels, data_format="channels_first")
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        embedding_output = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            output_hidden_states=True,
            return_dict=True,
        )

        hidden_states = outputs.hidden_states

        feature_maps = ()
        # we skip the stem
        for idx, (stage, hidden_state) in enumerate(zip(self.stage_names[1:], hidden_states[1:])):
            if stage in self.out_features:
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )

In [None]:
config = ConvNextV2Config3d(image_size=(96, 192, 192), num_channels=1, patch_size=(2, 4, 4), out_features=["stage1", "stage2"])
model = ConvNextV2Backbone3d(config)

In [None]:
x = torch.randn((2, 1, 96, 192, 192))
with torch.no_grad():
    out = model(x, output_hidden_states=True)

In [None]:
[i.shape for i in out.feature_maps]

[torch.Size([2, 40, 48, 48, 48]), torch.Size([2, 80, 24, 24, 24])]
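
These shapes follow from simple arithmetic: the stem divides each spatial dimension by its patch size, and every stage after the first halves it again. A small pure-Python sketch, where the helper `feature_map_shapes` is hypothetical and assumes non-overlapping convolutions with no padding:

```python
def feature_map_shapes(image_size, patch_size, hidden_sizes):
    """(channels, *spatial) after each stage for non-overlapping convolutions."""
    if isinstance(patch_size, int):
        patch_size = (patch_size,) * len(image_size)
    # Stem: kernel == stride == patch size, so each dim is floor-divided.
    spatial = tuple(n // p for n, p in zip(image_size, patch_size))
    shapes = []
    for i, channels in enumerate(hidden_sizes):
        if i > 0:
            # Every stage after the first downsamples by 2 (kernel 2, stride 2).
            spatial = tuple(n // 2 for n in spatial)
        shapes.append((channels, *spatial))
    return shapes

shapes = feature_map_shapes((96, 192, 192), (2, 4, 4), [40, 80])
print(shapes)  # [(40, 48, 48, 48), (80, 24, 24, 24)]
```

The anisotropic patch size `(2, 4, 4)` is what turns the `96 x 192 x 192` volume into an isotropic `48 x 48 x 48` grid.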

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()