In [1]:
import random, os, functools, statistics
from typing import Any

import numpy as np
import zarr
import matplotlib.pyplot as plt
import dill, joblib

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import v2
from torch.utils.data import DataLoader

from glio.train2 import *
from glio.train2.cbs_summary import Summary
from glio.visualize import vis_imshow, vis_imshow_grid, Visualizer
from glio.jupyter_tools import show_slices, show_slices_arr, clean_mem
from glio.torch_tools import area_around, one_hot_mask, summary, lr_finder, to_binary, count_parameters, replace_conv, replace_conv_transpose, replace_layers
from glio.python_tools import type_str, CacheRepeatIterator, get_all_files
from glio import nn as gnn
from glio.nn import conv, convt, linear, seq, block
from glio.data import DSToTarget
from glio.helpers import cnn_output_size, tcnn_output_size
from glio.loaders import nifti
from glio.transforms import fToChannels, fToChannelsFirst,fToChannelsLast, z_normalize, norm_to01

from monai.losses.dice import DiceLoss, DiceFocalLoss, GeneralizedDiceFocalLoss
from monai.losses.ssim_loss import SSIMLoss
from monai.metrics import MeanIoU, SurfaceDiceMetric, DiceHelper, compute_iou, compute_dice # type:ignore

from schedulefree import AdamWScheduleFree
from came_pytorch import CAME
from madgrad import MADGRAD

In [29]:
TITLE = "BRaTS2024 histnorm 2D"
PRELOAD = 1.0
AREA_AROUND = 96,96
DEPTH_AROUND = 1

def loader(d: dict[str, Any], wh_around = AREA_AROUND, d_around = DEPTH_AROUND):
    # load center
    center = d["center"]
    # load seg
    seg = d["seg"]
    if isinstance(seg, str): seg = area_around(torch.from_numpy(np.asanyarray(d["obs"][seg][d["slice"]])), center, wh_around)
    if d_around == 0:
        # return 2D array around center and 2D seg array around center
        return (
            area_around(torch.from_numpy(np.asanyarray(d["obs"][d["dim"]][d["prep"]][d["mod"]][d["slice"]])), center, wh_around),
            one_hot_mask(seg, 4)
        )
    else:
        # return (d_around*2+1, h, w) array around center and 2D seg array around center
        return (
            area_around(torch.from_numpy(np.asanyarray(d["obs"][d["dim"]][d["prep"]][d["mod"]][d["slice"]-d_around:d["slice"]+d_around+1])), center, AREA_AROUND),
            one_hot_mask(seg, 4)
        )

def loader_multimod(d: dict[str, Any], wh_around = AREA_AROUND, d_around = DEPTH_AROUND):
    # load center
    center = d["center"]
    # load seg
    seg = d["seg"]
    if isinstance(seg, str): seg = area_around(torch.from_numpy(np.asanyarray(d["obs"][seg][d["slice"]])), center, wh_around)
    # get prep slice
    if d["prep"] is None: prepslice = slice(None)
    else: prepslice = d["prep"]
    # get mod slice
    if d["mod"] is None: modslice = slice(None)
    else: modslice = d["mod"]
    if d_around == 0:
        # return 2D array around center and 2D seg array around center
        return (
            area_around(torch.from_numpy(np.asanyarray(d["obs"][d["dim"]][prepslice, modslice, d["slice"]])).flatten(0, -3), center, wh_around)[:,:wh_around[0],:wh_around[1]],
            one_hot_mask(seg, 4)
        )
    else:
        # return (d_around*2+1, h, w) array around center and 2D seg array around center
        return (
            area_around(torch.from_numpy(np.asanyarray(d["obs"][d["dim"]][prepslice, modslice, d["slice"]-d_around:d["slice"]+d_around+1])).flatten(0, -3), center, AREA_AROUND)[:,:wh_around[0],:wh_around[1]],
            one_hot_mask(seg, 4)
        )

def tfm_input(data):
    return data[0]

def tfm_target(data):
    return data[1]

def add_zipstore_to_ds(
    ds:DSToTarget,
    store:zarr.ZipStore,
    loader=loader,
    tfminit=None,
    tfminp=tfm_input,
    tfmtarg=tfm_target,
    dims = ("top", "side", "front"),
    prep=("hist","nohist"),
    mods=("t1c","t1n","t2f","t2w"),
    mergeprep = False,
    mergemods = False,
    slstart = 0,
    slend = 0,
    wh_around = AREA_AROUND,
    loadseg=False,
    ):
    # convert hist/nohist to index
    prep = [(0 if i.lower()=="nohist" else 1) for i in prep]
    # convert modality to index
    allmods = ("t1c","t1n","t2f","t2w")
    mods = [allmods.index(i.lower()) for i in mods]
    # init group
    arr = zarr.group(store)
    # for all observations
    for obs in arr.values():
        # for top/front/side
        for dim in dims:
            # for hist/nohist
            for p in [None] if mergeprep else prep:
                # for t1c/t1n/t2f/t2w
                for m in [None] if mergemods else mods:
                    centers = obs[f"{dim}_centers"]
                    # for all slices along first dim, saving the center
                    for i,center in enumerate(centers):
                        # check if slice is in range
                        if slstart <= i < (len(centers) - slend):
                            # save segmentation as is doesnt take up much RAM
                            if loadseg: seg = area_around(torch.from_numpy(np.asanyarray(obs[f"{dim}_seg"][i])), center, wh_around).clone()
                            else: seg = f"{dim}_seg"
                            # add sample as dict
                            ds.add_sample((dict(obs=obs, dim=dim, prep=p, mod=m, slice=i, center=torch.from_numpy(np.asanyarray(center)), seg=seg)),
                                        loader, tfminit, tfminp, tfmtarg)

def plot_preds(learner:Learner, batch, softmax = True, unsqueeze = True):
    batch = list(batch)
    if unsqueeze:
        batch[0] = batch[0].unsqueeze(0)
        batch[1] = batch[1].unsqueeze(0)
    preds = learner.inference(batch[0])
    v = Visualizer()
    v.imshow_grid(batch[0][0], mode="bhw", label="вход")
    v.imshow_grid(batch[1][0], mode="bhw", label = "реальная карта")
    if softmax:
        output = torch.stack([preds[0],preds[0],preds[0]], dim=1)
        output[:,0] *=  F.softmax(preds[0],0)
        v.imshow_grid(output, mode="bchw", label="сырой выход")
        v.imshow_grid(to_binary(F.softmax(preds[0], 0)), mode="bhw", label="предсказанная карта")
    else:
        v.imshow_grid(preds[0], mode="bchw", label="сырой выход")
        v.imshow_grid(to_binary(preds[0], 0), mode="bhw", label="предсказанная карта")
    v.show(figsize=(24, 24), nrows=1)



In [3]:
stores = [zarr.ZipStore(i, mode="r") for i in get_all_files(r"E:\dataset\BRaTS2024 2D full norm", extensions="zip")]
len(stores)

24

In [4]:
dstrain = DSToTarget()
dstest = DSToTarget()
print('adding')
for i,store in enumerate(stores[:22]):
    print(i, end='\r')
    add_zipstore_to_ds(dstrain, store, slstart=1, slend=-1, loader=loader_multimod, mergemods=True, mergeprep=True)
for i,store in enumerate(stores[22:]):
    print(i, end='\r')
    add_zipstore_to_ds(dstest, store, slstart=1, slend=-1, loader=loader_multimod, mergemods=True, mergeprep=True)

adding


In [31]:
dstest.set_loader(loader_multimod)

In [32]:
from monai.networks.nets import SegResNetDS # type:ignore

summary(SegResNetDS(2, in_channels=24, out_channels=4, init_filters=16), (8, 24, 96,96))

path                                         module                                       input size               output size              params    buffers   
monai.networks.nets.segresnet_ds.SegResNetDS/encoder/conv_inittorch.nn.modules.conv.Conv2d                 (8, 24, 96, 96)          (8, 16, 96, 96)          3456      0         
monai.networks.nets.segresnet_ds.SegResNetDS/encoder/layers/0/blocks/0/norm1torch.nn.modules.batchnorm.BatchNorm2d       (8, 16, 96, 96)          (8, 16, 96, 96)          32        33        
monai.networks.nets.segresnet_ds.SegResNetDS/encoder/layers/0/blocks/0/act1torch.nn.modules.activation.ReLU             (8, 16, 96, 96)          (8, 16, 96, 96)          0         0         
monai.networks.nets.segresnet_ds.SegResNetDS/encoder/layers/0/blocks/0/conv1torch.nn.modules.conv.Conv2d                 (8, 16, 96, 96)          (8, 16, 96, 96)          2304      0         
monai.networks.nets.segresnet_ds.SegResNetDS/encoder/layers/0/blocks/0/norm2torch.nn.m

In [33]:
MODEL = SegResNetDS(2, in_channels=24, out_channels=4, init_filters=24)
NAME = f"{MODEL.__class__.__name__}"
LR = 1e-3
BATCH_SIZE = 64
N_EPOCHS = 100

# dl_train = DataLoader(ds_train, BATCH_SIZE)
dltrain = DataLoader(dstrain, BATCH_SIZE, shuffle=True)
dltest = DataLoader(dstest, BATCH_SIZE)

OPT = AdamWScheduleFree(MODEL.parameters(), lr=LR, eps=1e-6)
LOSS_FN = DiceFocalLoss(softmax=True)
SCHED = None

clean_mem()
MODEL = gnn.LSUV(MODEL, dltrain, max_iter=10)

LEARNER_NAME = f"{NAME} lr={LR} bs={BATCH_SIZE} loss = {LOSS_FN.__name__ if hasattr(LOSS_FN, '__name__') else type_str(LOSS_FN)} opt={OPT.__class__.__name__} sch={SCHED.__class__.__name__}"
learner = Learner(MODEL, LEARNER_NAME,
                  cbs = (Metric_Loss(), # Log_GradHistorgram(16), Log_SignalHistorgram(16),
                         Log_Time(), Save_Best(TITLE), Save_Last(TITLE), Log_LR(), Summary(), Accelerate("no"),
                         Metric_PredsTargetsFn(DiceLoss(softmax=True), step=4, name="dice loss"), Metric_Accuracy(True, True, True),
                         Metric_PredsTargetsFn(lambda x,y:compute_iou(to_binary(F.softmax(x[:,1:], 1)), y[:,1:]).nanmean(), step=8, name = "iou"),
                         Metric_PredsTargetsFn(lambda x,y:compute_dice(to_binary(F.softmax(x[:,1:], 1)), y[:,1:]).nanmean(), step=8, name="dice coeff"),
                         FastProgressBar(plot=True, step_batch=16, metrics=
                                         ["train loss", "test loss", "train dice coeff", "test dice coeff", "train iou", "test iou"]),
                         ),
                  loss_fn=LOSS_FN,
                  optimizer=OPT,
                  scheduler=SCHED,)
learner.fit(N_EPOCHS, dltrain, dltest)

plt.show()
plot_preds(learner, dstrain[3], softmax=True)
plot_preds(learner, dstrain[4], softmax=True)
plot_preds(learner, dstrain[6], softmax=True)


RuntimeError: stack expects each tensor to be equal size, but got [24, 96, 96] at entry 0 and [16, 96, 96] at entry 41