In [1]:
import os

import torch
from monai.networks.nets import DynUNet
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
task_name = {
    "01": "Task01_BrainTumour",
    "02": "Task02_Heart",
    "03": "Task03_Liver",
    "04": "Task04_Hippocampus",
    "05": "Task05_Prostate",
    "06": "Task06_Lung",
    "07": "Task07_Pancreas",
    "08": "Task08_HepaticVessel",
    "09": "Task09_Spleen",
    "10": "Task10_Colon",
}

patch_size = {
    "01": [96, 96, 96],
    "02": [160, 192, 80],
    "03": [128, 128, 128],
    "04": [40, 56, 40],
    "05": [320, 256, 20],
    "06": [192, 160, 80],
    "07": [224, 224, 40],
    "08": [192, 192, 64],
    "09": [192, 160, 64],
    "10": [192, 160, 56],
}

spacing = {
    "01": [1.0, 1.0, 1.0],
    "02": [1.25, 1.25, 1.37],
    "03": [0.77, 0.77, 1],
    "04": [1.0, 1.0, 1.0],
    "05": [0.62, 0.62, 3.6],
    "06": [0.79, 0.79, 1.24],
    "07": [0.8, 0.8, 2.5],
    "08": [0.8, 0.8, 1.5],
    "09": [0.79, 0.79, 1.6],
    "10": [0.78, 0.78, 3],
}

clip_values = {
    "01": [0, 0],
    "02": [0, 0],
    "03": [-17, 201],
    "04": [0, 0],
    "05": [0, 0],
    "06": [-1024, 325],
    "07": [-96, 215],
    "08": [-3, 243],
    "09": [-41, 176],
    "10": [-30, 165.82],
}

normalize_values = {
    "01": [0, 0],
    "02": [0, 0],
    "03": [99.40, 39.36],
    "04": [0, 0],
    "05": [0, 0],
    "06": [-158.58, 324.7],
    "07": [77.99, 75.4],
    "08": [104.37, 52.62],
    "09": [99.29, 39.47],
    "10": [62.18, 32.65],
}

data_loader_params = {
    "01": {"batch_size": 8},
    "02": {"batch_size": 2},
    "03": {"batch_size": 8},
    "04": {"batch_size": 9},
    "05": {"batch_size": 2},
    "06": {"batch_size": 2},
    "07": {"batch_size": 2},
    "08": {"batch_size": 2},
    "09": {"batch_size": 2},
    "10": {"batch_size": 2},
}

deep_supr_num = {
    "01": 3,
    "02": 3,
    "03": 3,
    "04": 1,
    "05": 4,
    "06": 3,
    "07": 3,
    "08": 3,
    "09": 3,
    "10": 3,
}


In [11]:
def get_kernels_strides(task_id):
    """
    This function is only used for decathlon datasets with the provided patch sizes.
    When refering this method for other tasks, please ensure that the patch size for each spatial dimension should
    be divisible by the product of all strides in the corresponding dimension.
    In addition, the minimal spatial size should have at least one dimension that has twice the size of
    the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised.
    """
    sizes, spacings = patch_size[task_id], spacing[task_id]
    input_size = sizes
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        for idx, (i, j) in enumerate(zip(sizes, stride)):
            if i % j != 0:
                raise ValueError(
                    f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
                )
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)

    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    return kernels, strides

In [12]:
def get_network(task_id):
    # n_class = len(properties["labels"])
    # in_channels = len(properties["modality"])
    kernels, strides = get_kernels_strides(task_id)
    print("Kernels .....")
    print(np.shape(kernels))
    print(len(kernels))
    print(kernels)

    print("Strides .....")
    print(np.shape(strides))
    print(len(strides))
    print(strides)
    raise ValueError

    net = DynUNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=8,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=deep_supr_num[task_id],
    )

In [14]:
net = get_network("01")

Kernels .....
(5, 3)
5
[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
Strides .....
(5, 3)
5
[[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]


ValueError: 

In [16]:
heads = [torch.rand(1)] * 3
print(heads)

[tensor([0.2368]), tensor([0.2368]), tensor([0.2368])]


In [17]:
params = {
    "spatial_dims": 3,
    "in_channels": 32,
    "out_channels": 64,
    "kernel_size": [3,3,3],
    "stride": [2,2,2]
    }

print(params)

{'spatial_dims': 3, 'in_channels': 32, 'out_channels': 64, 'kernel_size': [3, 3, 3], 'stride': [2, 2, 2]}


In [19]:
kargs = (**params)
print(kargs)

SyntaxError: invalid syntax (777555519.py, line 1)