In [1]:
from fastai.callback.wandb import *
from fastMONAI.vision_all import *

In [3]:
SEED = 42  # seed handed to fastai's set_seed; reproducible=True as before
set_seed(SEED, reproducible=True)

In [10]:
# Run hyper-parameters. Image variants selectable through "img_path":
#   ImgTr          -> base raw QSM
#   ImgTrMsk       -> masked QSM images
#   ImgTrMskCmb    -> masked with T2* map
#   ImgTrMskCmbMag -> masked QSM with T2* map and magnitude
default_config = SimpleNamespace(**{
    "batch_size": 2,               # best 2
    "img_path": "ImgTrMskCmb",
    "model_channel": "8,16,32",    # one of "4,8,8", "4,8,16", "8,16,32"
    "learning_rate": 0.01,         # 0.1 / 0.01
    "epoch": 10,                   # 300-500 epochs for training from scratch
})

# `config` is the same object as default_config (an alias, not a copy).
config = default_config

In [13]:
# Helpers that collect image / label file paths as plain strings
# (fastMONAI expects string paths, see get_lbl).

def get_nii_files(path, recurse=True, folders=None):
    """Return every '.nii' file found by fastai's get_files under `path`, as str paths."""
    nii_paths = get_files(path, [".nii"], recurse=recurse, folders=folders)
    return list(map(str, nii_paths))

def get_gz_files(path, recurse=True, folders=None):
    """Return every '.gz' file found by fastai's get_files under `path`, as str paths."""
    gz_paths = get_files(path, [".gz"], recurse=recurse, folders=folders)
    return list(map(str, gz_paths))

def get_lbl(imagepath, lbl_dir=None):
    """Map an image file path to the path of its segmentation label.

    The label file name is the image file name with 'qsm-even-echoes' replaced
    by 'segmentation_clean', and a trailing '.nii' extension turned into '.nii.gz'.

    Args:
        imagepath: path (str or Path) of the image file.
        lbl_dir: directory holding the label files. Defaults to the module-level
            ``path_lbl_gm`` (read at call time), so ``get_y=get_lbl`` keeps working.

    Returns:
        The label path as a str, because fastMONAI expects string paths.
    """
    lbl_dir = path_lbl_gm if lbl_dir is None else Path(lbl_dir)
    # Path(...).name instead of split('/') so Windows-style paths also work there.
    seg_name = Path(imagepath).name.replace('qsm-even-echoes', 'segmentation_clean')
    # Only touch the trailing extension: an input that is already '.nii.gz' must
    # not become '.nii.gz.gz', and a '.nii' elsewhere in the name stays as-is.
    if seg_name.endswith('.nii'):
        seg_name += '.gz'
    return str(lbl_dir/seg_name)

# Splitter that routes files to train / valid by subject ID.
def FileSplitter(valid_ids=None):
    """Build a DataBlock splitter that puts validation subjects into the valid set.

    Args:
        valid_ids: iterable of substrings (e.g. subject IDs). An item goes to the
            validation set when any of them occurs in its path string. When None,
            the module-level ``valid`` list is looked up at split time, as before.

    Returns:
        A callable ``splitter(items, **kwargs)`` that delegates to fastai's
        ``FuncSplitter`` (extra keyword arguments are ignored).
    """
    # Materialise once so a generator argument is not exhausted after one item.
    fixed_ids = None if valid_ids is None else list(valid_ids)

    def _func(x):
        ids = valid if fixed_ids is None else fixed_ids
        # str() so pathlib.Path items work as well as str paths.
        return any(s in str(x) for s in ids)

    def _inner(o, **kwargs): return FuncSplitter(_func)(o)
    return _inner

In [14]:
# Subject IDs held out for validation: FileSplitter sends every file whose
# path contains one of these IDs to the validation set.
valid = [
    # new data
    'z0186251',  # gold marker only, and that single marker occupies just 1 layer
    'z0705200',  # calcification fragments; unclear whether better for training or testing
    'z1472355',  # very small, single calcification
]

In [15]:
# dataset definition
path = Path("bidsmonai-data/")
path_lbl_gm = path/"labelsTrGm"
path_im = path/config.img_path

fnames = get_nii_files(path_im)
lbl_names = get_gz_files(path_lbl_gm)

# Fail fast with a clear message when a directory is wrong or an image has no
# label file. Otherwise the error only surfaces later inside the DataLoader.
assert fnames, f"no .nii images found under {path_im}"
assert lbl_names, f"no .gz label files found under {path_lbl_gm}"
missing_lbls = [f for f in fnames if not Path(get_lbl(f)).exists()]
assert not missing_lbls, f"{len(missing_lbls)} image(s) without a label file, e.g. {missing_lbls[:3]}"

# get_gz_files already returns str paths; copy into the list handed to MedDataset.
ll = list(lbl_names)
# Resampling / reorientation suggestions are derived from the label volumes.
med_dataset = MedDataset(img_list=ll, dtype=MedMask)
resample, reorder = med_dataset.suggestion()

In [16]:
# augmentation setup: every sample is padded/cropped to this spatial size
size = [144,144,64]
item_tfms = [
    RandomFlip(axes=("LR",)),  # left-right flips only; AP flips are not applied
    ZNormalization(),
    PadOrCrop(size),
]

# Data block: images are collected from path_im, masks are located via get_lbl,
# and the validation subjects are chosen by FileSplitter.
bids = MedDataBlock(
    blocks=(ImageBlock(cls=MedImage), MedMaskBlock),
    splitter=FileSplitter(),
    get_items=get_nii_files,
    get_y=get_lbl,
    item_tfms=item_tfms,
    reorder=reorder,
    resample=resample,
)
dls = bids.dataloaders(path_im, bs=config.batch_size)

# The model input channel count is the leading dimension of the first training image.
MODEL_INPUT_CHANNELS = dls.train_ds[0][0].shape[0]

In [17]:
from monai.networks.nets import UNet
from monai.losses import DiceLoss

# MONAI Dice loss with sigmoid=True (applied to the model's raw output),
# wrapped in fastMONAI's CustomLoss and passed to the Learner below.
loss_func = CustomLoss(loss_func=DiceLoss(sigmoid=True))

In [20]:
# Preset channel configurations selectable through config.model_channel.
CHANNEL_CONFIGS = {
    "4,8,16": (4,8,16),
    "4,8,8": (4,8,8),
    "8,16,32": (8,16,32)
}


def _parse_channels(spec):
    """Return the UNet channel tuple for a spec such as "8,16,32".

    Presets come from CHANNEL_CONFIGS; any other comma-separated list of
    integers (e.g. "16,32,64") is parsed directly.
    """
    if spec in CHANNEL_CONFIGS:
        return CHANNEL_CONFIGS[spec]
    return tuple(int(c) for c in spec.split(","))


channels = _parse_channels(config.model_channel)

# define model
model = UNet(
    spatial_dims=3,  # current name of MONAI's deprecated `dimensions` argument
    in_channels=MODEL_INPUT_CHANNELS,
    out_channels=1,
    channels=channels,
    # MONAI requires len(strides) == len(channels) - 1; (2, 2) for 3 levels as before
    strides=(2,) * (len(channels) - 1),
    num_res_units=2
)
# NOTE(review): unwraps MONAI's UNet wrapper to its inner `.model` module;
# kept from the original, confirm why the wrapper is dropped.
model = model.model

# create learner
learn=Learner(
    dls,
    model,
    loss_func=loss_func,
    opt_func=ranger,
    metrics=binary_dice_score,
    cbs=[
        # TODO check checkpoint names with W&B.
        # Distinct fnames: SaveModelCallback defaults to fname="model", so with two
        # callbacks the best-loss and best-dice checkpoints overwrote each other.
        # NOTE(review): with every_epoch=False each callback reloads its own best
        # checkpoint after fit, so the final weights presumably come from whichever
        # runs last. Confirm which one is intended.
        SaveModelCallback(
            monitor="valid_loss",
            fname="best_valid_loss",
            every_epoch=False,
            with_opt=True
        ),
        SaveModelCallback(
            monitor="binary_dice_score",
            fname="best_dice",
            every_epoch=False,
            with_opt=True
        ),
        # WandbCallback(log_model=True)
    ]
)

In [21]:
# Train with fastai's flat-then-cosine-annealing schedule, using the
# epoch count and learning rate from config.
learn.fit_flat_cos(n_epoch=config.epoch, lr=config.learning_rate)

epoch,train_loss,valid_loss,binary_dice_score,time
0,0.999928,0.999955,2.5e-05,00:07
1,0.999928,0.999955,2.5e-05,00:04
2,0.999929,0.999955,2.5e-05,00:04
3,0.999929,0.999955,2.5e-05,00:04
4,0.999929,0.999955,2.5e-05,00:04
5,0.999929,0.999948,3.1e-05,00:04
6,0.999927,0.999939,3.5e-05,00:04
7,0.999924,0.99993,4e-05,00:04
8,0.999921,0.999917,4.9e-05,00:04
9,0.999916,0.99991,5.3e-05,00:04




Better model found at epoch 0 with valid_loss value: 0.9999549388885498.
Better model found at epoch 0 with binary_dice_score value: 2.4727080017328262e-05.
Better model found at epoch 2 with valid_loss value: 0.999954879283905.
Better model found at epoch 5 with valid_loss value: 0.999947726726532.
Better model found at epoch 5 with binary_dice_score value: 3.069983722525649e-05.
Better model found at epoch 6 with valid_loss value: 0.9999393820762634.
Better model found at epoch 6 with binary_dice_score value: 3.4578257327666506e-05.
Better model found at epoch 7 with valid_loss value: 0.9999303817749023.
Better model found at epoch 7 with binary_dice_score value: 3.992760684923269e-05.
Better model found at epoch 8 with valid_loss value: 0.9999167323112488.
Better model found at epoch 8 with binary_dice_score value: 4.8511166824027896e-05.
Better model found at epoch 9 with valid_loss value: 0.9999095797538757.
Better model found at epoch 9 with binary_dice_score value: 5.31868608959