## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"

!python -c "import matplotlib" || pip install -q matplotlib

%matplotlib inline

In [None]:
!pip install mamba_ssm
!pip install causal_conv1d

## Setup imports

In [None]:
from __future__ import annotations
import torch.nn as nn
import torch 
from functools import partial
import glob
import json
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import torch.nn.functional as F 
from torchinfo import summary

In [None]:
import mamba_ssm

In [None]:
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrUpBlock
from monai.data import DataLoader, decollate_batch, Dataset
from monai.config import print_config
from monai.losses import DiceCELoss, DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
    CropForegroundd,
    Resized,
    SpatialPadd,
    CenterSpatialCropd,
    RandAffined,
    RandGaussianNoised,
    RandAdjustContrastd
)

from monai.utils import set_determinism

print_config()

## Setup data directory

In [None]:
# Resolve the working data directory: honour $MONAI_DATA_DIRECTORY (creating
# it when missing), otherwise fall back to a fresh temporary directory.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    os.makedirs(directory, exist_ok=True)
    root_dir = directory

print(root_dir)

## Setup average meter and fold reader

In [None]:
class AverageMeter(object):
    """Track the latest value and the running (sample-weighted) mean of a metric.

    ``val`` may be a Python number or a NumPy array (e.g. per-class scores).
    After the first ``update`` call, ``avg`` is a NumPy array (0-d for scalars).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh ``avg``.

        Args:
            val: latest value (number or array).
            n: number of samples ``val`` averages over; ``n=0`` records nothing
                new and leaves ``avg`` equal to ``sum`` instead of dividing by 0.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # np.where evaluates both branches eagerly, so it cannot guard the
        # division; branch explicitly instead.
        if self.count > 0:
            self.avg = np.asarray(self.sum / self.count)
        else:
            self.avg = np.asarray(self.sum)


def datafold_read(datalist, basedir, fold=0, key="training"):
    """Load a JSON datalist and split its ``key`` entries into train / validation.

    Every non-empty string field of an entry, and every element of a list
    field, is prefixed with ``basedir`` (entries are modified in place).
    Entries whose ``"fold"`` equals ``fold`` go to validation; all others,
    including entries without a ``"fold"`` field, go to training.

    Returns:
        (train_entries, val_entries) as two lists of dicts, in file order.
    """
    with open(datalist) as handle:
        entries = json.load(handle)[key]

    def _prefixed(value):
        if isinstance(value, list):
            return [os.path.join(basedir, item) for item in value]
        if isinstance(value, str) and value:
            return os.path.join(basedir, value)
        return value

    for entry in entries:
        for field, value in entry.items():
            entry[field] = _prefixed(value)

    train_entries, val_entries = [], []
    for entry in entries:
        in_val_fold = "fold" in entry and entry["fold"] == fold
        (val_entries if in_val_fold else train_entries).append(entry)
    return train_entries, val_entries

## Setup dataloader

In [None]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert a BraTS label map into four binary channels.

    Input labels:
        label 1 is NETC
        label 2 is SNFH
        label 3 is ET
        label 4 is RC

    Output channels, stacked on a new leading axis and cast to float:
        0: TC (Tumor core)      = label 1 or 3
        1: WT (Whole tumor)     = label 1, 2 or 3
        2: ET (Enhancing tumor) = label 3
        3: RC                   = label 4
    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            label = d[key]
            tumor_core = torch.logical_or(label == 1, label == 3)
            whole_tumor = torch.logical_or(tumor_core, label == 2)
            enhancing_tumor = label == 3
            rc = label == 4
            d[key] = torch.stack([tumor_core, whole_tumor, enhancing_tumor, rc], dim=0).float()
        return d

In [None]:
def get_loader(
    data_dir="/kaggle/input/BraTS2024_small_dataset",
    json_path="training_data.json",
    max_cases=150,
    num_folds=5,
    val_fold=1,
    spatial_size=(128, 128, 128),
    batch_size=1,
    num_workers=4,
):
    """Write a fold-annotated datalist for the BraTS cases and build the loaders.

    Case directories under ``data_dir`` are taken in sorted order (at most
    ``max_cases`` of them; ``None`` means all) and assigned folds round-robin
    (``index % num_folds``). Within a case directory, the first file in sorted
    order is used as the label and the remaining files as the image channels.
    This assumes the ``*-seg.nii`` file sorts first, which holds for names like
    ``...-seg.nii`` / ``...-t1c.nii``.

    The datalist is written to ``json_path`` and read back from that same path.
    Entries of fold ``val_fold`` form the validation set.

    Returns:
        (train_loader, val_loader, train_ds, val_ds)
    """
    case_dirs = sorted(glob.glob(os.path.join(data_dir, "*")))
    if max_cases is not None:
        case_dirs = case_dirs[:max_cases]

    datalist = {"training": []}
    for index, case_dir in enumerate(case_dirs):
        files = sorted(glob.glob(os.path.join(case_dir, "*")))
        datalist["training"].append(
            {"fold": index % num_folds, "image": files[1:], "label": files[0]}
        )
    with open(json_path, "w") as file:
        json.dump(datalist, file)
    # Read back exactly the file just written. Reading a hard-coded absolute
    # path silently depends on the current working directory.
    train_files, validation_files = datafold_read(datalist=json_path, basedir="", fold=val_fold)

    def preprocessing():
        # Deterministic loading / resampling steps shared by both pipelines.
        return [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            EnsureTyped(keys=["image", "label"]),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
            # Nearest-neighbour for the label keeps its channels binary.
            # Resized's default ("area") would average them into fractions.
            Resized(
                keys=["image", "label"],
                spatial_size=list(spatial_size),
                mode=("area", "nearest"),
            ),
        ]

    # Training Transform
    train_transform = Compose(
        preprocessing()
        + [
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]
    )

    # Validation Transform
    val_transform = Compose(
        preprocessing()
        + [NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True)]
    )

    train_ds = Dataset(data=train_files, transform=train_transform)
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_ds = Dataset(data=validation_files, transform=val_transform)
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader, train_ds, val_ds

## Build the training and validation data loaders

In [None]:
# Build the train/validation loaders; the datasets are kept for inspection.
(train_loader, val_loader, train_ds, val_ds) = get_loader()

## Check data shape and visualize

In [None]:
# Show one axial slice (index 135 of the last axis) of a sample case's T1c
# volume next to its raw label map.
img_add = "/kaggle/input/BraTS2024_small_dataset/BraTS-GLI-02063-105/BraTS-GLI-02063-105-t1c.nii"
label_add = "/kaggle/input/BraTS2024_small_dataset/BraTS-GLI-02063-105/BraTS-GLI-02063-105-seg.nii"

img, label = (nib.load(path).get_fdata() for path in (img_add, label_add))
print(f"image shape: {img.shape}, label shape: {label.shape}")

plt.figure("image", (18, 6))
panels = [("image", img, "gray"), ("label", label, None)]
for position, (title, volume, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    plt.imshow(volume[:, :, 135], cmap=cmap)
plt.show()

## Create SegMamba model
  

In [None]:
from torch.cuda.amp import custom_bwd, custom_fwd
import causal_conv1d_cuda
import selective_scan_cuda
class MambaInnerFnNoOutProj(torch.autograd.Function):
    """Fused Mamba inner block (causal conv1d -> x_proj -> dt_proj -> selective scan).

    Unlike the usual Mamba inner function, no output projection and no final
    "b d l -> b l d" rearrange is applied: ``out_z`` from the scan kernel is
    returned directly.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """
             xz: (batch, dim, seqlen), split in half along dim 1 into x and z.
             B / C: None means input-dependent (sliced out of x_proj's output).
             checkpoint_lvl: 1 recomputes conv1d_out and delta in backward
             instead of saving them; 0 saves them.
        """
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        # No final "b d l -> b l d" rearrange: out_z is returned as produced by the kernel.
        return out_z

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, dim, seqlen) -- forward returns out_z without a final rearrange.
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l = L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        # One entry per forward input (14 including checkpoint_lvl). Trailing
        # None values beyond the inputs actually passed to apply() are accepted
        # by autograd, so this works whether or not checkpoint_lvl is passed.
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias,
                None,  # delta_softplus
                None)  # checkpoint_lvl

def mamba_inner_fn_no_out_proj(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Functional entry point for ``MambaInnerFnNoOutProj``.

    All arguments are forwarded positionally, in the Function's order.
    ``checkpoint_lvl`` is not passed, so the Function's default applies.
    """
    forwarded = (
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus,
    )
    return MambaInnerFnNoOutProj.apply(*forwarded)

In [None]:
# Copyright (c) 2023, Tri Dao, Albert Gu.

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

try:
    # `import causal_conv1d_fn, causal_conv1d_update` looked for *modules* with
    # those names and therefore always fell back to None. Import the functions
    # from the causal_conv1d package instead.
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

class XMamba(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
        nslices=5
    ):
        self.factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.nslices = nslices
        self.silu = nn.SiLU()

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **self.factory_kwargs)

        self.activation = "silu"
        self.act = nn.SiLU()

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **self.factory_kwargs,
        )

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **self.factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **self.factory_kwargs)

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32\
        self.D._no_weight_decay = True

        # bidirection
        A_b = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_b_log = torch.log(A_b)  # Keep A_b_log in fp32
        self.A_b_log = nn.Parameter(A_b_log)
        self.A_b_log._no_weight_decay = True 

        self.conv1d_b = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **self.factory_kwargs,
        )
        # trlocthem 
        self.softplus = nn.Softplus()
        # trlocthem
        self.x_proj_b = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **self.factory_kwargs
        )
        self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **self.factory_kwargs)

        self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D_b._no_weight_decay = True

        # spatial
        A_s = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_s_log = torch.log(A_s)  # Keep A_b_log in fp32
        self.A_s_log = nn.Parameter(A_s_log)
        self.A_s_log._no_weight_decay = True 

        self.conv1d_s = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **self.factory_kwargs,
        )

        self.x_proj_s = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **self.factory_kwargs
        )
        self.dt_proj_s = nn.Linear(self.dt_rank, self.d_inner, bias=True, **self.factory_kwargs)

        self.D_s = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D_s._no_weight_decay = True

        self.device = device

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **self.factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)

        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **self.factory_kwargs)
        self.tanh = nn.Tanh()

    def WMF(self, *o):
        k_weights = nn.Parameter(torch.tensor([1/3, 1/3, 1/3]), requires_grad=True)
        k_weights = torch.softmax(k_weights, dim=0)

        assert len(o) == len(k_weights), "The number of outputs and weights must match."

        O = sum(w * out for w, out in zip(k_weights, o))
        sp = nn.Softplus()
        return sp(O)

    #trlocthem
    def process_direction(
            self,
            x: Tensor,
            conv1d: nn.Conv1d,
            conv1d_weight,
            conv1d_bias,
            x_proj_weight,
            delta_proj_weight,
            A, B=None, C=None, D=None,
            delta_bias=None,
            B_proj_bias=None,
            C_proj_bias=None,
            delta_softplus=True
    ):
        """Scan one direction: ``conv1d`` + ReLU, then the fused Mamba inner kernel.

        Args:
            x: input in the xz layout (presumably (batch, 2 * d_inner, seqlen);
                TODO confirm against callers).
            conv1d: a Conv1d applied to ``x`` before the scan.
            conv1d_weight, conv1d_bias: depthwise causal-conv parameters handed to
                the kernel. ``conv1d_bias`` may be None (``conv_bias=False``).
            x_proj_weight, delta_proj_weight, A, B, C, D, delta_bias,
            B_proj_bias, C_proj_bias, delta_softplus: forwarded to
                ``mamba_inner_fn_no_out_proj``.

        Returns:
            The kernel output, with no output projection applied.
        """
        x = F.relu(conv1d(x))
        dtype = x.dtype
        # Weight tensors are cast to the activation dtype. The bias is optional.
        bias = None if conv1d_bias is None else conv1d_bias.to(dtype=dtype)
        return mamba_inner_fn_no_out_proj(
            x,
            conv1d_weight.to(dtype=dtype),
            bias,
            x_proj_weight.to(dtype=dtype),
            delta_proj_weight.to(dtype=dtype),
            A, B, C, D,
            delta_bias, B_proj_bias, C_proj_bias, delta_softplus,
        )
    #trlocthem

    def gaussian_decay_mask(self , sequence):
        length = sequence.shape[1]
        # Automatically determine center and last index
        center_index = (length + 1) // 2

        ref_index = center_index
        ref_vector = sequence[:, center_index, :]

        # Index-based Gaussian mask
        indices = torch.arange(length, dtype=torch.float32, device=sequence.device)
        sigma1 = torch.abs(indices - ref_index).mean()  # Sigma calculation
        weights1 = torch.exp(-0.5 * ((indices - ref_index) ** 2) / (sigma1 ** 2))
        weights1 /= weights1.sum()
        weights1 = weights1.repeat(sequence.size(0), 1)  # Repeat weights for batch dimension

        # Vector-based Gaussian mask
        distances = torch.norm(sequence - ref_vector.unsqueeze(1), dim=2)
        sigma2 = distances.mean(dim=1, keepdim=True)
        weights2 = torch.exp(-0.5 * (distances / sigma2) ** 2)
        weights2 = weights2 / weights2.sum(dim=1, keepdim=True)

        combined_weights = weights1  * weights2
        combined_weights = combined_weights / combined_weights.sum(dim=1, keepdim=True)
        s_hat_f = sequence * combined_weights.unsqueeze(2)
        return s_hat_f
        
    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D) with D == d_model
        Returns: same shape as hidden_states

        Three branches are fused by ``WMF``:

        * main: xz in sequence order;
        * reversed: xz flipped along the sequence, flipped back before fusion;
        * sliced: xz cut into ``nslices`` chunks and interleaved, de-interleaved
          before fusion. L is therefore expected to be divisible by ``nslices``.

        Within each branch, the first ``(L + 1) // 2`` positions are scanned in
        order and the rest in reverse (see ``_scan_branch``).

        Raises:
            NotImplementedError: when the fused fast path cannot be used
                (``use_fast_path`` is False, or ``inference_params`` is given
                with ``seqlen_offset == 0``).
        """
        batch, seqlen, _ = hidden_states.shape
        self._ensure_aux_layers(hidden_states.device)

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        if not self.use_fast_path or inference_params is not None:
            # Only the fused path is implemented. Falling through would reach
            # `return out` with `out` unbound.
            raise NotImplementedError(
                "XMamba.forward only implements the fused path "
                "(use_fast_path=True and inference_params=None)."
            )

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
        # Split point between the forward-scanned and backward-scanned halves.
        self.center = (seqlen + 1) // 2

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        out = self._scan_branch(xz, A, self.conv1d, self.x_proj, self.dt_proj, self.D)

        A_b = -torch.exp(self.A_b_log.float())
        out_b = self._scan_branch(
            xz.flip([-1]), A_b, self.conv1d_b, self.x_proj_b, self.dt_proj_b, self.D_b
        )

        A_s = -torch.exp(self.A_s_log.float())
        # Interleave: chunk j, offset i -> position i * nslices + j.
        xz_s = torch.stack(xz.chunk(self.nslices, dim=-1), dim=-1).flatten(-2)
        out_s = self._scan_branch(
            xz_s, A_s, self.conv1d_s, self.x_proj_s, self.dt_proj_s, self.D_s
        )
        # Undo the interleaving.
        out_s = (
            out_s.reshape(batch, self.d_inner, seqlen // self.nslices, self.nslices)
            .permute(0, 1, 3, 2)
            .flatten(-2)
        )

        out = self.WMF(out, out_b.flip([-1]), out_s)
        out = rearrange(out, "b d l -> b l d")
        return F.linear(out, self.out_proj.weight, self.out_proj.bias)

    def _ensure_aux_layers(self, device):
        """Create the per-branch helper layers once (if not already present) and reuse them.

        Rebuilding them inside every forward call gave them fresh random weights
        each time. That made the output non-deterministic and kept them from
        being trained. They are sized from d_inner: xz has 2 * d_inner channels
        and each scan returns d_inner channels.
        """
        if getattr(self, "proj2", None) is not None:
            return
        width = 2 * self.d_inner
        dtype = self.factory_kwargs["dtype"]
        self.forward_conv1d = nn.Conv1d(width, width, kernel_size=1, device=device, dtype=dtype)
        self.backward_conv1d = nn.Conv1d(width, width, kernel_size=1, device=device, dtype=dtype)
        self.proj2 = nn.Linear(self.d_inner, self.d_inner, device=device, dtype=dtype)
        self.norm = nn.LayerNorm(self.d_inner, device=device)
        self.adapool = nn.AdaptiveAvgPool1d(self.d_inner)

    def _scan_branch(self, seq, A, conv1d, x_proj, dt_proj, D):
        """Scan one branch and fuse it with a channel-pooled skip path.

        Steps:

        1. ``seq`` is (B, 2 * d_inner, L). Positions ``[:self.center]`` are
           scanned in order; ``[self.center:]`` are scanned reversed.
        2. Both halves are re-weighted by ``gaussian_decay_mask`` + SiLU and
           re-joined in the original order.
        3. The result is projected by ``proj2`` and added to an adaptive-pooled
           copy of ``seq``.
        4. LayerNorm and tanh are applied.

        Returns:
            Tensor of shape (B, d_inner, L).
        """
        def scan(part, mixer):
            return self.process_direction(
                part,
                mixer,
                conv1d.weight,
                conv1d.bias,
                x_proj.weight,
                dt_proj.weight,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                D.float(),
                dt_proj.bias.float(),
                None,  # B_proj_bias
                None,  # C_proj_bias
                True,  # delta_softplus
            )

        head = scan(seq[:, :, : self.center], self.forward_conv1d)  # (B, d_inner, n)
        tail = scan(seq[:, :, self.center :].flip([-1]), self.backward_conv1d)
        head = self.silu(self.gaussian_decay_mask(rearrange(head, "b d n -> b n d")))
        tail = self.silu(self.gaussian_decay_mask(rearrange(tail, "b d n -> b n d")))

        # `tail` is time-reversed. Flip the sequence axis (dim 1 in "b n d") back
        # so positions line up with the skip path. Flipping dim -1 reversed the
        # channels instead.
        out = torch.cat([head, tail.flip([1])], dim=1)
        out = rearrange(self.proj2(out), "b n d -> b d n")
        skip = self.adapool(rearrange(seq, "b c l -> b l c"))  # pool 2*d_inner -> d_inner channels
        out = out + rearrange(skip, "b l c -> b c l")
        out = self.norm(out.permute(0, 2, 1)).permute(0, 2, 1)
        return self.tanh(out)
    
    

    def step(self, hidden_states, conv_state, ssm_state):
        """Decode a single token, updating the recurrent caches in place.

        Only the forward-direction parameters are used here (``in_proj``,
        ``conv1d``, ``x_proj``, ``dt_proj``, ``A_log``, ``D``, ``out_proj``);
        the ``*_s`` slice-branch parameters used in ``forward`` are not touched.

        Args:
            hidden_states: one decoding step; the assert below requires
                ``hidden_states.shape[1] == 1`` (presumably (B, 1, d_model) — TODO confirm).
            conv_state: rolling convolution cache, modified in place.
            ssm_state: selective-scan state cache, modified in place.

        Returns:
            ``(out, conv_state, ssm_state)`` where ``out`` is ``out_proj(y)`` with a
            length-1 axis re-inserted at dim 1.
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        # Split the projection into the SSM branch x and the gating branch z.
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            # Pure-PyTorch fallback: shift the window left by one, append x, then
            # apply the per-channel kernel (weight reshaped "d 1 w -> d w").
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            # Fused kernel path; it receives conv_state and presumably updates it in place — confirm.
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        # Input-dependent step size (dt) and SSM input/output matrices (B, C).
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here: it is added inside the softplus in the fallback
        # below, or passed as dt_bias to selective_state_update.
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B: dA = exp(dt * A), dB = dt * B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            # Recurrence h <- dA * h + dB * x, written back into the cache in place.
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            # Readout y = C h + D x, then gate with act(z).
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Create zero-filled convolution and SSM caches for step-wise decoding.

        ``max_seqlen`` and any extra keyword arguments are accepted for interface
        compatibility but are not used. An explicit ``dtype`` overrides the dtype of
        both caches. Otherwise the conv cache follows ``conv1d.weight`` and the SSM
        cache follows ``dt_proj.weight``. Both caches are placed on
        ``out_proj.weight``'s device.

        Returns:
            ``(conv_state, ssm_state)`` shaped ``(batch_size, d_model * expand, d_conv)``
            and ``(batch_size, d_model * expand, d_state)``.
        """
        target_device = self.out_proj.weight.device
        channels = self.d_model * self.expand

        conv_state = torch.zeros(
            batch_size,
            channels,
            self.d_conv,
            device=target_device,
            dtype=self.conv1d.weight.dtype if dtype is None else dtype,
        )
        ssm_state = torch.zeros(
            batch_size,
            channels,
            self.d_state,
            device=target_device,
            dtype=self.dt_proj.weight.dtype if dtype is None else dtype,
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state

In [None]:
# Copyright (c) MONAI Consortium

from __future__ import annotations
import torch.nn as nn
import torch 
from functools import partial

from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrUpBlock
from mamba_ssm import Mamba
import torch.nn.functional as F 

class LayerNorm(nn.Module):
    r"""LayerNorm supporting ``channels_last`` (default) or ``channels_first`` inputs.

    - ``channels_last`` normalises over the trailing dimension, e.g. inputs of shape
      ``(batch, ..., channels)``, via :func:`torch.nn.functional.layer_norm`.
    - ``channels_first`` normalises over dimension 1, e.g. ``(batch, channels, H, W)``
      or ``(batch, channels, D, H, W)``. The per-channel affine parameters are
      broadcast over however many spatial dimensions follow.

    Args:
        normalized_shape: number of channels (an ``int``).
        eps: value added to the variance for numerical stability.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: for any other ``data_format``.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalise every position over the channel axis (dim 1)
        # using the biased variance, as F.layer_norm does.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        # Reshape (C,) -> (C, 1, ..., 1) so the affine parameters broadcast for any
        # input rank. A fixed [:, None, None, None] only lines up with 5-D inputs.
        affine_shape = (-1,) + (1,) * (x.dim() - 2)
        return self.weight.view(affine_shape) * x + self.bias.view(affine_shape)

class MambaLayer(nn.Module):
    """Residual XMamba token mixer applied to a channels-first feature map.

    The input ``(B, C, *spatial)`` is processed as follows:
    1. Flatten to a token sequence ``(B, N, C)`` with ``N = prod(spatial)``.
    2. Apply ``nn.LayerNorm``.
    3. Pass through :class:`XMamba`.
    4. Reshape back to ``(B, C, *spatial)``, add the input, and apply GELU.

    Args:
        dim: channel count ``C``; must equal ``x.shape[1]`` in :meth:`forward`.
        d_state, d_conv, expand: forwarded to :class:`XMamba`.
        num_slices: forwarded to :class:`XMamba` as ``nslices``.
    """

    def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2, num_slices=None):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.mamba = XMamba(
                d_model=dim,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                nslices=num_slices,
        )
        # Build the parameter-free activation once instead of on every forward call.
        # It has no parameters, so the state_dict layout is unchanged.
        self.act = nn.GELU()

    def forward(self, x):
        B, C = x.shape[:2]
        assert C == self.dim, f"expected {self.dim} channels, got {C}"
        img_dims = x.shape[2:]
        n_tokens = img_dims.numel()
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)  # (B, N, C)
        x_mamba = self.mamba(self.norm(x_flat))
        out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)
        return self.act(out + x)
    
class MlpChannel(nn.Module):
    """Point-wise two-layer MLP for channels-first 3-D feature maps.

    Both layers are ``1x1x1`` :class:`torch.nn.Conv3d` convolutions, so the MLP
    mixes channels independently at every voxel:
    ``hidden_size -> mlp_dim -> GELU -> hidden_size``.
    """

    def __init__(self, hidden_size, mlp_dim):
        super().__init__()
        # Attribute names fc1 / act / fc2 define the state_dict layout.
        self.fc1 = nn.Conv3d(in_channels=hidden_size, out_channels=mlp_dim, kernel_size=1)
        self.act = nn.GELU()
        self.fc2 = nn.Conv3d(in_channels=mlp_dim, out_channels=hidden_size, kernel_size=1)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class GSC(nn.Module):
    """Two ``Conv3d -> InstanceNorm3d -> ReLU`` stages with a residual around the second.

    - ``DW_Conv`` is a 3x3x3 convolution (dense, not grouped/depthwise).
    - ``PW_Conv`` is a 1x1x1 convolution.

    The output is ``PW_Conv(h) + h``, where ``h = DW_Conv(x)``. Channel count and
    spatial size are preserved (stride 1, "same" padding).
    """

    def __init__(self, in_channels):
        super().__init__()

        def conv_norm_relu(kernel_size, padding):
            return nn.Sequential(
                nn.Conv3d(in_channels, in_channels, kernel_size=kernel_size, stride=1, padding=padding),
                nn.InstanceNorm3d(in_channels),
                nn.ReLU(),
            )

        # Attribute names and Sequential indices form the state_dict layout.
        self.DW_Conv = conv_norm_relu(kernel_size=3, padding=1)
        self.PW_Conv = conv_norm_relu(kernel_size=1, padding=0)

    def forward(self, x):
        hidden = self.DW_Conv(x)
        return self.PW_Conv(hidden) + hidden

class MambaEncoder(nn.Module):
    """Four-stage encoder: downsampling, a GSC block and MambaLayers per stage.

    Each stage ``i`` applies, in order:
    1. ``downsample_layers[i]``: a stride-2 7x7x7 stem for ``i == 0``, otherwise
       InstanceNorm + a stride-2 2x2x2 conv.
    2. ``gscs[i]``.
    3. ``stages[i]``, made of ``depths[i]`` :class:`MambaLayer` blocks.

    For every stage index in ``out_indices``, the stage output goes through
    ``norm{i}`` and ``mlps[i]``. These outputs are returned as a tuple.

    Args:
        in_chans: number of input channels.
        depths: number of MambaLayer blocks per stage (4 entries).
        dims: channel width per stage (4 entries).
        drop_path_rate, layer_scale_init_value: accepted for interface
            compatibility; not used by this encoder.
        out_indices: stages whose features are returned.
        num_slices_list: ``num_slices`` passed to the MambaLayers of each stage.
            The default ``(64, 32, 16, 8)`` presumably matches the per-stage
            spatial extent for a 128^3 ROI — TODO confirm before changing the ROI.
    """

    def __init__(self, in_chans=1, depths=(2, 2, 2, 2), dims=(48, 96, 192, 384),
                 drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=(0, 1, 2, 3),
                 num_slices_list=(64, 32, 16, 8)):
        super().__init__()

        # Stem (stride-2 7x7x7 conv) followed by three norm + stride-2 2x2x2 convs.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=7, stride=2, padding=3),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                nn.InstanceNorm3d(dims[i]),
                nn.Conv3d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        self.gscs = nn.ModuleList()
        for i in range(4):
            # Keep construction order (GSC, then MambaLayers) so seeded init is unchanged.
            gsc = GSC(dims[i])
            stage = nn.Sequential(
                *[MambaLayer(dim=dims[i], num_slices=num_slices_list[i]) for _ in range(depths[i])]
            )
            self.stages.append(stage)
            self.gscs.append(gsc)

        self.out_indices = out_indices

        self.mlps = nn.ModuleList()
        for i_layer in range(4):
            # Registered as norm0..norm3 and looked up by name in forward_features.
            self.add_module(f'norm{i_layer}', nn.InstanceNorm3d(dims[i_layer]))
            self.mlps.append(MlpChannel(dims[i_layer], 2 * dims[i_layer]))

    def forward_features(self, x):
        """Run all stages; return the projected outputs of the ``out_indices`` stages."""
        outs = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.gscs[i](x)
            x = self.stages[i](x)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                outs.append(self.mlps[i](norm_layer(x)))

        return tuple(outs)

    def forward(self, x):
        return self.forward_features(x)

class SegMamba(nn.Module):
    """3-D segmentation network: a :class:`MambaEncoder` backbone with a UNETR-style decoder.

    Encoder side:
    - ``vit`` returns four feature maps; ``encoder2..5`` refine them.
    - ``encoder1`` processes the raw input at full resolution.

    Decoder side:
    - ``decoder5..2`` are ``UnetrUpBlock`` s (``upsample_kernel_size=2``). Each
      takes the previous decoder output plus a skip tensor.
    - ``decoder1`` and the ``UnetOutBlock`` head produce ``out_chans`` channels.

    Args:
        in_chans: number of input channels.
        out_chans: number of output channels.
        depths: MambaLayer count per encoder stage.
        feat_size: channel width per encoder stage.
        drop_path_rate, layer_scale_init_value: forwarded to the encoder.
        hidden_size: channel width of the bottleneck (``encoder5``).
        norm_name: normalisation name passed to the MONAI blocks.
        conv_block: accepted for interface compatibility; not used.
        res_block: whether the MONAI blocks use residual units.
        spatial_dims: spatial dimensionality passed to the MONAI blocks.
    """

    def __init__(
        self,
        in_chans=1,
        out_chans=13,
        depths=(2, 2, 2, 2),
        feat_size=(48, 96, 192, 384),
        drop_path_rate=0,
        layer_scale_init_value=1e-6,
        hidden_size: int = 768,
        norm_name = "instance",
        conv_block: bool = True,
        res_block: bool = True,
        spatial_dims=3,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.depths = depths
        self.drop_path_rate = drop_path_rate
        self.feat_size = feat_size
        self.layer_scale_init_value = layer_scale_init_value

        self.spatial_dims = spatial_dims
        self.vit = MambaEncoder(
            in_chans,
            depths=depths,
            dims=feat_size,
            drop_path_rate=drop_path_rate,
            layer_scale_init_value=layer_scale_init_value,
        )

        def basic_block(in_channels, out_channels):
            # Stride-1, kernel-3 UNETR block (residual when res_block is true).
            return UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=res_block,
            )

        def up_block(in_channels, out_channels):
            # UNETR up-block with upsample_kernel_size=2; forward() also passes a skip tensor.
            return UnetrUpBlock(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                upsample_kernel_size=2,
                norm_name=norm_name,
                res_block=res_block,
            )

        # Creation order matches the original so seeded initialisation is unchanged.
        self.encoder1 = basic_block(self.in_chans, self.feat_size[0])
        self.encoder2 = basic_block(self.feat_size[0], self.feat_size[1])
        self.encoder3 = basic_block(self.feat_size[1], self.feat_size[2])
        self.encoder4 = basic_block(self.feat_size[2], self.feat_size[3])
        self.encoder5 = basic_block(self.feat_size[3], self.hidden_size)

        self.decoder5 = up_block(self.hidden_size, self.feat_size[3])
        self.decoder4 = up_block(self.feat_size[3], self.feat_size[2])
        self.decoder3 = up_block(self.feat_size[2], self.feat_size[1])
        self.decoder2 = up_block(self.feat_size[1], self.feat_size[0])
        self.decoder1 = basic_block(self.feat_size[0], self.feat_size[0])
        # The head consumes decoder1's output (feat_size[0] channels). Its input
        # width must follow feat_size[0]; a hard-coded 48 breaks other widths.
        self.out = UnetOutBlock(
            spatial_dims=spatial_dims, in_channels=self.feat_size[0], out_channels=self.out_chans
        )

    def forward(self, x_in):
        outs = self.vit(x_in)
        enc1 = self.encoder1(x_in)
        enc2 = self.encoder2(outs[0])
        enc3 = self.encoder3(outs[1])
        enc4 = self.encoder4(outs[2])
        enc_hidden = self.encoder5(outs[3])
        dec3 = self.decoder5(enc_hidden, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        out = self.decoder1(dec0)
        return self.out(out)

In [None]:
# Number CUDA devices by PCI bus ID. NOTE(review): this only takes effect if it
# is set before CUDA is first initialised in this process — confirm.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID")

## Model, optimizer, loss function and metrics

In [None]:
max_epochs = 120
val_interval = 1  # run validation every `val_interval` epochs
VAL_AMP = True  # read by inference(): run validation under CUDA autocast
roi = (128, 128, 128)  # sliding-window ROI size

device = torch.device("cuda")
model = SegMamba(in_chans=4,
                 out_chans=4,
                 depths=[2, 2, 2, 2],
                 feat_size=[48, 96, 192, 384]).to(device)

# Multi-label Dice on sigmoid outputs. to_onehot_y=False, so labels are presumably
# already multi-channel (one channel per output) — confirm with the label transform.
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
# The scheduler is stepped once per epoch, so T_max must follow max_epochs
# (a hard-coded 120 silently breaks the cosine schedule when max_epochs changes).
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

# Overall mean Dice and per-channel Dice (reduction="mean_batch").
dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

# Turn logits into binary masks: sigmoid, then threshold at 0.5.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# Sliding-window inference helper shared by validation and visualisation.
def inference(input):
    """Run sliding-window inference of the global ``model`` over ``input``.

    Settings:
    - Window size is the global ``roi``, so it stays in sync with the configured ROI.
    - One window per forward pass, with 50% overlap.
    - Runs under CUDA autocast when the global ``VAL_AMP`` is true (checked at call time).

    Returns:
        The stitched model output produced by ``sliding_window_inference``.
    """
    from contextlib import nullcontext

    amp_ctx = torch.amp.autocast(device_type="cuda") if VAL_AMP else nullcontext()
    with amp_ctx:
        return sliding_window_inference(
            inputs=input,
            roi_size=roi,
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )


# use amp to accelerate training
scaler = torch.amp.GradScaler('cuda')
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True

## Fine-tuning (optional: resume from a saved checkpoint)

In [None]:
def load_checkpoint(filename, model, optimizer=None, lr_scheduler=None, map_location="cpu"):
    """Restore model state, and optionally optimizer / LR-scheduler state, from a checkpoint.

    The checkpoint must hold the keys ``epoch``, ``best_acc``, ``state_dict``,
    ``optimizer_state_dict`` and ``lr_scheduler_state_dict``.

    Args:
        filename: path to the checkpoint file.
        model: module that receives ``checkpoint["state_dict"]`` (non-strict load).
        optimizer: optional optimizer to restore.
        lr_scheduler: optional LR scheduler to restore.
        map_location: forwarded to :func:`torch.load`. The default ``"cpu"`` lets a
            checkpoint saved on any GPU be read; ``model.load_state_dict`` then
            copies the values into the model's existing parameters.

    Returns:
        ``(epoch, best_acc)`` as stored in the checkpoint.
    """
    checkpoint = torch.load(filename, map_location=map_location)

    # strict=False tolerates architecture mismatches; report them instead of hiding them.
    incompatible = model.load_state_dict(checkpoint["state_dict"], strict=False)
    if incompatible.missing_keys:
        print(f"Missing keys when loading checkpoint: {incompatible.missing_keys}")
    if incompatible.unexpected_keys:
        print(f"Unexpected keys when loading checkpoint: {incompatible.unexpected_keys}")

    print(f"Checkpoint loaded: Epoch {checkpoint['epoch']}, Best Accuracy: {checkpoint['best_acc']}")

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

    return checkpoint["epoch"], checkpoint["best_acc"]

In [None]:
# checkpoint_file = '/kaggle/input/pretrain/model.pt'

# start_epoch, best_acc = load_checkpoint(
#     checkpoint_file, 
#     model, 
#     optimizer=optimizer, 
#     lr_scheduler=lr_scheduler
# )

## Execute training

In [None]:
def save_checkpoint(optimizer, lr_scheduler, model, epoch, filename="model.pt", best_acc=0, dir_add="./"):
    """Write model, optimizer and LR-scheduler state to ``os.path.join(dir_add, filename)``.

    The saved dict has the keys ``epoch``, ``best_acc``, ``state_dict``,
    ``optimizer_state_dict`` and ``lr_scheduler_state_dict``. A non-empty
    ``dir_add`` is created if it does not exist yet.
    """
    save_dict = {
        "epoch": epoch,
        "best_acc": best_acc,
        "state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
    }

    # Create the output folder so saving into a fresh directory does not raise.
    # An empty dir_add means "current directory", where makedirs("") would fail.
    if dir_add:
        os.makedirs(dir_add, exist_ok=True)
    filename = os.path.join(dir_add, filename)
    torch.save(save_dict, filename)
    print(f"Checkpoint saved at {filename}")

In [None]:
from tqdm import tqdm

In [None]:
best_metric = -1
best_metric_epoch = -1
# Parallel histories of every new best: [best mean Dice, epoch (1-based), seconds since start].
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
metric_values_tc = []
metric_values_wt = []
metric_values_et = []
metric_values_rc = []

total_start = time.time()
for epoch in tqdm(range(max_epochs), desc="Training Epochs"):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training ----
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        # Mixed-precision forward and loss; backward/step go through the GradScaler.
        with torch.amp.autocast(device_type='cuda'):
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()

    lr_scheduler.step()  # one scheduler step per epoch
    # max(step, 1) avoids a ZeroDivisionError if train_loader yields no batches.
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # ---- validation ----
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = inference(val_inputs)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]

                # Accumulate Dice for this batch: overall mean and per channel.
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)

            # Mean Dice over all channels (mean dice).
            metric = dice_metric.aggregate().item()
            metric_values.append(metric)

            # Per-channel Dice. NOTE(review): assumes the label channels are ordered
            # TC, WT, ET, RC by the label transform — confirm.
            metric_batch = dice_metric_batch.aggregate()
            metric_tc = metric_batch[0].item()
            metric_values_tc.append(metric_tc)
            metric_wt = metric_batch[1].item()
            metric_values_wt.append(metric_wt)
            metric_et = metric_batch[2].item()
            metric_values_et.append(metric_et)
            metric_rc = metric_batch[3].item()
            metric_values_rc.append(metric_rc)

            dice_metric.reset()
            dice_metric_batch.reset()

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                save_checkpoint(
                    optimizer,
                    lr_scheduler,
                    model,
                    epoch,
                    best_acc=best_metric,
                )
                print("saved new best metric model")

            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f} \n"
                f"tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f} rc: {metric_rc:.4f} \n"
                f"Best mean dice: {best_metric:.4f} \n"
                f"at epoch: {best_metric_epoch}"
            )
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
total_time = time.time() - total_start

In [None]:
# The training cell already records total_time when the loop finishes.
# Recomputing it here would also count any idle time before this cell runs,
# so only fall back to a fresh measurement if it was never set.
if "total_time" not in globals():
    total_time = time.time() - total_start

In [None]:
# One-line summary of the finished run.
print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}, total time: {total_time}."
)

## Plot the loss and Dice metric

In [None]:
import matplotlib.pyplot as plt

def _plot_curve(values, title, ylabel, color, step=1):
    """Draw one per-epoch curve in the current subplot (x axis = epoch number)."""
    epochs = [step * (k + 1) for k in range(len(values))]
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.plot(epochs, values, color=color)


# Figure 1: average training loss and mean validation Dice per epoch.
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
_plot_curve(epoch_loss_values, "Epoch Average Loss", "Loss", "red")
plt.subplot(1, 2, 2)
_plot_curve(metric_values, "Val Mean Dice", "Dice Score", "green", step=val_interval)
plt.show()

# Figure 2: validation Dice per class (TC, WT, ET, RC).
per_class_curves = [
    ("Val Mean Dice TC", metric_values_tc, "brown"),
    ("Val Mean Dice WT", metric_values_wt, "purple"),
    ("Val Mean Dice ET", metric_values_et, "orange"),
    ("Val Mean Dice RC", metric_values_rc, "blue"),
]
plt.figure("train", (18, 6))
for panel, (title, curve, color) in enumerate(per_class_curves, start=1):
    plt.subplot(1, 4, panel)
    _plot_curve(curve, title, "Dice Score", color, step=val_interval)
plt.show()

In [None]:
import pandas as pd

def _dice_frame(scores):
    """Tabulate validation Dice scores against the epoch at which each was measured."""
    return pd.DataFrame({
        "Epoch": [val_interval * k for k in range(1, len(scores) + 1)],
        "Dice Score": scores,
    })


# Average training loss per epoch (epochs numbered from 1).
df_loss = pd.DataFrame({
    "Epoch": range(1, len(epoch_loss_values) + 1),
    "Loss": epoch_loss_values,
})
df_loss.to_csv("epoch_loss_values.csv", index=False)

# Mean validation Dice.
df_metric = _dice_frame(metric_values)
df_metric.to_csv("metric_values.csv", index=False)

# Per-class validation Dice, one CSV file named after each list.
classes = {
    "metric_values_tc": metric_values_tc,
    "metric_values_wt": metric_values_wt,
    "metric_values_et": metric_values_et,
    "metric_values_rc": metric_values_rc,
}
for class_name, values in classes.items():
    df = _dice_frame(values)
    df.to_csv(f"{class_name}.csv", index=False)


## Create test set dataloader

In [None]:
# Các modal cần xử lý
modalities = ["t1n", "t1c", "t2f", "t2w"]

# Tạo danh sách test_files tự động
test_files = [
    {
        "image": [
            f"/kaggle/input/BraTS2024_small_dataset/BraTS-GLI-02063-105/BraTS-GLI-02063-105-{modality}.nii"
            for modality in modalities
        ],
        "label": "/kaggle/input/BraTS2024_small_dataset/BraTS-GLI-02063-105/BraTS-GLI-02063-105-seg.nii",
    }
]

test_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

test_ds = Dataset(data=test_files, transform=test_transform)

test_loader = DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

## Load the best saved checkpoint and perform inference 



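We load the best checkpoint saved during training and run sliding-window inference on the single test case defined above. The multi-channel prediction is converted into a label map (`seg_out`) so it can be compared with the corresponding ground-truth segmentation.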

In [None]:
# Restore the best weights saved during training. map_location puts them directly
# on the inference device, whichever device they were saved from.
model.load_state_dict(torch.load("model.pt", map_location=device)["state_dict"])
model.to(device)
model.eval()

model_inferer_test = partial(
    sliding_window_inference,
    roi_size=[roi[0], roi[1], roi[2]],
    sw_batch_size=1,
    predictor=model,
    overlap=0.6,
)

# (output channel, label value) pairs, applied in order so later entries win where
# channels overlap. NOTE(review): assumes channels are (TC, WT, ET, RC), as in the
# training metrics, and BraTS label ids 1..4 — confirm against the label transform.
channel_to_label = ((1, 2), (0, 1), (2, 3), (3, 4))

with torch.no_grad():
    for batch_data in test_loader:
        image = batch_data["image"].to(device)
        prob = torch.sigmoid(model_inferer_test(image))
        # Binarise each channel of the first (only) batch element.
        seg = (prob[0].cpu().numpy() > 0.5).astype(np.int8)
        seg_out = np.zeros(seg.shape[1:])
        for channel, label_value in channel_to_label:
            seg_out[seg[channel] == 1] = label_value

## Visualize segmentation output and compare with label

In [None]:
with torch.no_grad():
    # Fetch the validation sample once and reuse it for input, label and display.
    # Indexing val_ds repeatedly may re-run its transforms each time (this depends
    # on how val_ds was built, which is not visible here). That would be wasteful,
    # and with any random transform the displayed image/label would no longer
    # match the prediction.
    val_sample = val_ds[3]
    val_input = val_sample["image"].unsqueeze(0).to(device)
    val_output = inference(val_input)
    val_output = post_trans(val_output[0])

    # The 4 input channels at index 64 of the last axis.
    plt.figure("image", (24, 6))
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.title(f"image channel {i}")
        plt.imshow(val_sample["image"][i, :, :, 64].detach().cpu(), cmap="gray")
    plt.show()
    # The 4 label channels of the same sample.
    plt.figure("label", (18, 6))
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.title(f"label channel {i}")
        plt.imshow(val_sample["label"][i, :, :, 64].detach().cpu())
    plt.show()
    # The 4 predicted (thresholded) channels for this sample.
    plt.figure("output", (18, 6))
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.title(f"output channel {i}")
        plt.imshow(val_output[i, :, :, 64].detach().cpu())
    plt.show()

In [None]:
slice_num = 90  # index along the last axis shown in every panel

case_root = "/kaggle/input/BraTS2024_small_dataset/BraTS-GLI-02063-105/BraTS-GLI-02063-105"
img_add = case_root + "-t1c.nii"
label_add = case_root + "-seg.nii"

img = nib.load(img_add).get_fdata()
label = nib.load(label_add).get_fdata()

# Side by side: raw T1c slice, ground-truth labels, predicted label map (seg_out).
plt.figure("image", (18, 6))
panels = [("image", img, "gray"), ("label", label, None), ("segmentation", seg_out, None)]
for position, (title, volume, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(title)
    plt.imshow(volume[:, :, slice_num], cmap=cmap)
plt.show()

## Save model.pt

In [None]:
# Copy the best checkpoint into /kaggle/working.
# - A byte-level copy avoids deserialising it: torch.load would restore tensors
#   onto the CUDA device they were saved from.
# - The copy is skipped when the file already lives there; a load-and-resave
#   would just rewrite the file onto itself.
src_ckpt = os.path.abspath("model.pt")
dst_ckpt = "/kaggle/working/model.pt"
if src_ckpt != os.path.abspath(dst_ckpt):
    shutil.copyfile(src_ckpt, dst_ckpt)

## Cleanup data directory



Remove the data directory if a temporary one was created.

In [None]:
# root_dir came from tempfile.mkdtemp() only when MONAI_DATA_DIRECTORY was unset;
# a user-provided data directory is never deleted.
used_temp_dir = directory is None
if used_temp_dir:
    shutil.rmtree(root_dir)