# In [None]:
from typing import Tuple, Union

import torch
import torch.nn as nn

from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.nets import ViT


class UNETR(nn.Module):
    """
    UNETR based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"

    A ViT encoder runs on 16x16x16 patches of the input volume. Token
    sequences tapped at several depths are reshaped into 3D feature maps
    (see ``proj_feat``) and fused, through skip connections, with a
    convolutional decoder that restores the input resolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        img_size: Tuple[int, int, int],
        feature_size: int = 16,  # author's note: the paper figure is drawn with a feature size of 64
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_heads: int = 12,
        pos_embed: str = "perceptron",
        norm_name: Union[Tuple, str] = "instance",
        conv_block: bool = False,
        res_block: bool = True,
        dropout_rate: float = 0.0,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image.
            feature_size: dimension of network feature size.
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type, either "conv" or "perceptron".
            norm_name: feature normalization type and arguments.
            conv_block: bool argument to determine if convolutional block is used.
            res_block: bool argument to determine if residual block is used.
            dropout_rate: fraction of the input units to drop.

        Raises:
            AssertionError: if ``dropout_rate`` is outside [0, 1] or ``hidden_size``
                is not divisible by ``num_heads``.
            KeyError: if ``pos_embed`` is neither "conv" nor "perceptron".

        Examples::

            # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
            >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')

            # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')

        """

        super().__init__()

        # Argument validation happens before any sub-module is built.
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise AssertionError("hidden size should be divisible by num_heads.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        # ---------------------------------------------------------------------
        # Global configuration (fixed here, following the paper figure)
        # ---------------------------------------------------------------------
        self.num_layers = 12  # 12 transformer layers (z1..z12 in the paper)
        self.patch_size = (16, 16, 16)  # H×W×D is split into (H/16)×(W/16)×(D/16) tokens
        # Token grid size per axis. Floor division: a remainder smaller than
        # one patch is not represented in the grid.
        self.feat_size = (
            img_size[0] // self.patch_size[0],
            img_size[1] // self.patch_size[1],
            img_size[2] // self.patch_size[2],
        )
        self.hidden_size = hidden_size

        # ---------------------------------------------------------------------
        # [Left column in the figure]
        # ViT encoder = patch embedding + positional embedding + 12×(MHSA+MLP).
        # forward() unpacks its result as (x, hidden_states_out):
        #   x                 -> final transformer output
        #   hidden_states_out -> sequence of intermediate states used for skips
        # ---------------------------------------------------------------------
        self.classification = False
        self.vit = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=self.patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=self.num_layers,
            num_heads=num_heads,
            pos_embed=pos_embed,
            classification=self.classification,
            dropout_rate=dropout_rate,
        )

        # ---------------------------------------------------------------------
        # [Top skip path in the figure]
        # Convolutional block applied directly to the input image (stride 1),
        # providing the highest-resolution skip to the decoder.
        # ---------------------------------------------------------------------
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=3,
            in_channels=in_channels,  # channels of the input image (e.g. 4 for multi-modal MRI)
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )

        # ---------------------------------------------------------------------
        # Projection-and-upsample blocks fed by the ViT taps.
        # Tokens are first reshaped to a 3D feature map (proj_feat) and then
        # upsampled so their spatial size matches the decoder stage they join:
        #
        #   encoder2 <= hidden_states_out[3] -> ~ H/2 × W/2 × D/2
        #   encoder3 <= hidden_states_out[6] -> ~ H/4 × W/4 × D/4
        #   encoder4 <= hidden_states_out[9] -> ~ H/8 × W/8 × D/8
        # ---------------------------------------------------------------------
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,  # channels of the ViT token embeddings
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,  # author's note: num_layer counts extra upsample blocks, so 0 keeps only the initial one
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )

        # ---------------------------------------------------------------------
        # [Decoder path in the figure]
        # Each UnetrUpBlock receives the previous decoder feature map and the
        # matching skip (upsample + concatenate + conv blocks in the figure).
        #
        # decoder5 starts from the final transformer output reshaped to a
        # feature map and fuses it with the encoder4 skip.
        # ---------------------------------------------------------------------
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,  # starts from the final ViT output
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )

        # ---------------------------------------------------------------------
        # [Output head in the figure]
        # Maps decoder features (feature_size channels) to out_channels logits.
        # ---------------------------------------------------------------------
        self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)  # type: ignore

    def proj_feat(self, x, hidden_size, feat_size):
        """
        Reshape a token sequence into a channels-first 3D feature map.

        Args:
            x: tensor of shape (B, N, hidden_size) where
                N == feat_size[0] * feat_size[1] * feat_size[2].
            hidden_size: size of the last (embedding) dimension of ``x``.
            feat_size: token grid size along the three spatial axes.

        Returns:
            A contiguous tensor of shape
            (B, hidden_size, feat_size[0], feat_size[1], feat_size[2]).
        """
        # Unflatten the token axis into the 3D grid, keeping channels last ...
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        # ... then move channels to dim 1, the layout the conv blocks consume.
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def load_from(self, weights):
        """
        Copy pretrained ViT encoder parameters from a checkpoint, in place.

        Copies the patch-embedding parameters (position embeddings, class
        token, projection weight and bias), delegates each transformer block
        to its own ``loadFrom`` method, and finally copies the last norm layer.

        Args:
            weights: mapping with a ``"state_dict"`` entry whose keys are
                prefixed with ``module.transformer.``.

        Raises:
            KeyError: if ``"state_dict"`` or one of the expected keys is missing.
        """
        state_dict = weights["state_dict"]
        with torch.no_grad():
            # copy weights from patch embedding
            patch_embedding = self.vit.patch_embedding
            patch_embedding.position_embeddings.copy_(
                state_dict["module.transformer.patch_embedding.position_embeddings_3d"]
            )
            patch_embedding.cls_token.copy_(state_dict["module.transformer.patch_embedding.cls_token"])
            patch_embedding.patch_embeddings[1].weight.copy_(
                state_dict["module.transformer.patch_embedding.patch_embeddings.1.weight"]
            )
            patch_embedding.patch_embeddings[1].bias.copy_(
                state_dict["module.transformer.patch_embedding.patch_embeddings.1.bias"]
            )

            # copy weights from encoding blocks (default: num of blocks: 12)
            # NOTE(review): relies on every block exposing `loadFrom(weights, n_block=...)`;
            # that method is not defined in this file - confirm the transformer block provides it.
            for bname, block in self.vit.blocks.named_children():
                block.loadFrom(weights, n_block=bname)

            # last norm layer of transformer
            self.vit.norm.weight.copy_(state_dict["module.transformer.norm.weight"])
            self.vit.norm.bias.copy_(state_dict["module.transformer.norm.bias"])

    def forward(self, x_in):
        """
        Run the segmentation network.

        Args:
            x_in: input image batch; it is passed both to the ViT encoder and
                to the full-resolution convolutional skip block.

        Returns:
            Segmentation logits produced by the output block.
        """
        # -------------------------------------------------------------
        # ViT forward:
        #   x                 = final transformer output
        #   hidden_states_out = intermediate states, tapped at indices 3, 6, 9
        #
        # NOTE(review): the paper taps z3, z6, z9. If hidden_states_out holds
        # one entry per transformer block starting at index 0, these indices
        # select the outputs of blocks 4, 7 and 10 instead - confirm against
        # the ViT implementation before changing anything (pretrained weights
        # depend on the current indices).
        #
        # Each tapped state is a (B, N, hidden_size) token sequence; proj_feat
        # turns it into a (B, hidden_size, H/16, W/16, D/16) feature map.
        # -------------------------------------------------------------
        x, hidden_states_out = self.vit(x_in)

        # Top skip from the raw image (full resolution)
        enc1 = self.encoder1(x_in)

        # Taps from the transformer, reshaped and projected to CNN features
        x2 = hidden_states_out[3]
        enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))

        x3 = hidden_states_out[6]
        enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))

        x4 = hidden_states_out[9]
        enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))

        # Bridge from the deepest token features to the decoder feature map
        dec4 = self.proj_feat(x, self.hidden_size, self.feat_size)

        # Decoder with skip concatenations (the "C" nodes in the figure)
        dec3 = self.decoder5(dec4, enc4)  # fuse with enc4 (~ H/8 scale)
        dec2 = self.decoder4(dec3, enc3)  # fuse with enc3 (~ H/4 scale)
        dec1 = self.decoder3(dec2, enc2)  # fuse with enc2 (~ H/2 scale)
        out = self.decoder2(dec1, enc1)   # fuse with enc1 (full resolution)

        # Final prediction
        logits = self.out(out)
        return logits
