# Develop segmentor2 with example

In [1]:
import numpy as np
import dask.array as da
#import subprocess
import tempfile
from pathlib import Path
import os
cwd = os.getcwd()
import tempfile
import logging
from types import SimpleNamespace
import tqdm #progress bar in iterations
import pandas as pd

from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

import albumentations as alb
import albumentations.pytorch

import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils

import matplotlib.pyplot as plt

In [2]:
import logging
logging.basicConfig(level=logging.INFO)
import tifffile

Settings

In [3]:
# ---- NN1 (2D segmentation networks) settings ----
data_vol_norm_process = "mean_stdev_3" #standard clipping

# Index of the CUDA device; used to build cuda_str at the end of this cell
cuda_device=0

# Names are matched case-insensitively by substring in the loss/metric setup cells
nn1_loss_criterion='DiceLoss'
nn1_eval_metric='MeanIoU'
nn1_lr=1e-5      # learning rate handed to the AdamW optimizer
nn1_max_lr=3e-2  # peak learning rate handed to the OneCycleLR scheduler

# nn1_epochs = 15
nn1_epochs = 5 # debug

nn1_batch_size = 2
nn1_num_workers = 1

# Default: a single model description. Note that this list is redefined further
# down with one entry per axis before the models are actually created.
nn1_models_class_generator = [{
'class':'smp', #smp: segmentation models pytorch
'arch': 'U_Net',
'encoder_name': 'resnet34',
'encoder_weights': 'imagenet', # TODO: support for using existing models (loading)
'in_nchannels':1,
'nclasses':3,
}]

# Entry k is the index (into the list of NN1 models) of the model used for axis k.
nn1_axes_to_models_indices = [0,1,2] # model0 along z, model1 along y, and model2 along x
# To share a single model across all axes use [0,0,0] instead.

temp_data_outdir = None
cuda_str = f"cuda:{cuda_device}"

In [None]:
#segm2 = lgs2.cMultiAxisRotationsSegmentor2.create_simple_separate_models_per_axis(3)

trainlabels max value is 2, so 3 classes

In [None]:
nclasses =3

setup NN1 models

In [None]:
nn1_dict_gen = {'class':'smp', #smp: segmentation models pytorch
    'arch': 'U_Net',
    'encoder_name': 'resnet34',
    'encoder_weights': 'imagenet', # TODO: support for using existing models (loading)
    'in_nchannels':1, #greyscale
    'nclasses':nclasses,
}


In [None]:
nn1_models_class_generator = [nn1_dict_gen,
    nn1_dict_gen.copy(),
    nn1_dict_gen.copy()
]

In [None]:
nn1_models_class_generator

In [None]:
def create_nn1_ptmodel_from_class_generator(nn1_cls_gen_dict: dict):
    """Build a 2D segmentation model from a description dictionary.

    Args:
        nn1_cls_gen_dict: dictionary with the keys
            'class' (only 'smp', segmentation_models_pytorch, is supported),
            'arch' ('unet'/'u_net', 'manet' or 'fpn', case-insensitive),
            'encoder_name', 'encoder_weights', 'in_nchannels' and 'nclasses'.

    Returns:
        The instantiated model. It is created with ``activation=None``, so it
        outputs raw logits.

    Raises:
        ValueError: if 'class' or 'arch' is not one of the supported values.
    """
    if nn1_cls_gen_dict['class'].lower() != 'smp':
        raise ValueError(f"class {nn1_cls_gen_dict['class']} not supported.")

    # TODO: add other 2D model support, not just SMPs

    # Segmentation models pytorch: unet, AttentionNet (manet) and fpn
    arch = nn1_cls_gen_dict['arch'].lower()
    if arch=="unet" or arch=="u_net":
        NN_class = smp.Unet
    elif arch=="manet":
        # Previously assigned to model0, leaving NN_class undefined below.
        NN_class = smp.MAnet
    elif arch=="fpn":
        NN_class = smp.FPN
    else:
        raise ValueError(f"arch:{arch} not valid.")

    model0 = NN_class(
        encoder_name = nn1_cls_gen_dict['encoder_name'],
        encoder_weights = nn1_cls_gen_dict['encoder_weights'],
        in_channels = nn1_cls_gen_dict['in_nchannels'],
        classes = nn1_cls_gen_dict['nclasses'],
        # No activation: whether one is needed depends on whether the loss
        # function requires logits or not.
        activation = None
        )

    return model0

In [None]:
NN1_models = [ create_nn1_ptmodel_from_class_generator(x).to(f"cuda:{cuda_device}") for x in nn1_models_class_generator]

In [None]:
NN1_models

In [None]:
len(NN1_models)

In [None]:
nn1_axes_to_models_indices = [0,1,2]

In [None]:
idx_models = np.unique(nn1_axes_to_models_indices)
idx_models

# Load data and create dataloaders

load data

In [None]:
data_labels_fn=[
    ("./test_data/TS_0005_crop.tif", "./test_data/TS_0005_ribos_membr_crop.tif"),
]

traindatas=[]
trainlabels=[]

for datafn0, labelfn0 in data_labels_fn:
     #Make sure data and labels are curated in the correct data format
    traindatas.append(tifffile.imread(datafn0))
    trainlabels.append(tifffile.imread(labelfn0)) #In this case labels are already in uint8

In [None]:
print(trainlabels[0].max())

Normalise data to "mean_stdev_3"

In [None]:
# Normalised volumes, one per entry of traindatas (still floating point here;
# the next cell converts them to uint8).
traindata_list0=[]

# Clip data to -3*stdev and +3*stdev and normalises to values between 0 and 1
for d0 in traindatas:
    # Statistics are computed over the whole volume, not per slice
    d0_mean = np.mean(d0)
    d0_std = np.std(d0)

    # A constant volume cannot be standardised (division by zero below)
    if d0_std==0:
        raise ValueError("Error. Stdev of data volume is zero.")
    
    # Standardise, then map the clipped range [-3, 3] linearly onto [0, 1]
    d0_corr = (d0.astype(np.float32) - d0_mean) / d0_std
    d0_corr = (np.clip(d0_corr, -3.0, 3.0) +3.0) / 6.0
    
    # Rescale to [0, 255] so the volume can be stored as 8-bit
    traindata_list0.append(d0_corr*255)

In [None]:
traindata_list = [ t.astype(np.uint8) for t in traindata_list0]

In [None]:
len(traindata_list)

In [None]:
traindata_list[0].shape

In [None]:
traindata_list[0].dtype

In [None]:
# view some slices to ensure the data is properly loaded
import random
# Draw the slice indices from the actual depth of the volume instead of
# assuming it has 256 slices along axis 0.
randomints= np.random.default_rng().permutation(traindata_list[0].shape[0])

fig, axs = plt.subplots(1, 4, figsize=(10,5))
#fig.tight_layout()
for i in range(4):
    ir = randomints[i]
    # Data was rescaled to 0..255, so fix the grey scale to that range
    axs[i].imshow(traindata_list[0][ir,:,:], cmap="gray", vmin=0, vmax=255)
    axs[i].set_axis_off()
plt.tight_layout()

Create datasets and dataloaders for each model

In [None]:
def get_train_augmentations_v0(h,w):
    """Build the albumentations pipeline used to augment training slices.

    The output crop size is (h, w) rounded down to multiples of 32.

    Args:
        h: height of the input slices, in pixels.
        w: width of the input slices, in pixels.

    Returns:
        An ``alb.Compose`` pipeline ending with ``ToTensorV2``; call it as
        ``tfms(image=..., mask=...)``.

    Raises:
        AssertionError: if h or w is smaller than 32, which would give an
            empty crop.
    """

    def floor_to_multiple_of_32(v):
        # Rounds DOWN to a multiple of 32 (e.g. 250 -> 224)
        return (v//32)*32

    img_h, img_w = h,w

    img_h32, img_w32 = floor_to_multiple_of_32(img_h),  floor_to_multiple_of_32(img_w)
    # Both rounded sizes must be checked; the width check used to test the
    # unrounded img_w and therefore let widths below 32 through.
    assert img_h32>0 and img_w32>0

    tfms0 =alb.Compose(
                [
                alb.RandomSizedCrop(
                    min_max_height= (img_h32//2, img_h32),
                    height=img_h32,
                    width=img_w32 ,
                    p=0.5,
                ),
                # Deciding which resizing augmentations to use is difficult
                # without knowing the image sizes, which can differ.

                alb.VerticalFlip(p=0.5),
                alb.RandomRotate90(p=0.5),
                alb.Transpose(p=0.5),
                # At most one geometric distortion per sample
                alb.OneOf(
                    [
                        alb.ElasticTransform(
                            alpha=120, sigma=120 * 0.07, alpha_affine=120 * 0.04, p=0.5
                        ),
                        alb.GridDistortion(p=0.5),
                        alb.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=0.5),
                    ],
                    p=0.5,
                ),
                alb.CLAHE(p=0.5),
                alb.OneOf([alb.RandomBrightnessContrast(p=0.5),alb.RandomGamma(p=0.5)], p=0.5),
                alb.pytorch.ToTensorV2()
                ]
            )
    return tfms0

In [None]:
# TODO: check
class NN1_train_input_dataset_along_axes(Dataset):
    """Dataset of augmented 2D (data, labels) slices cut from 3D volumes.

    Every volume is sliced along each of the requested axes; the dataset
    length is the total number of slices. Each item is passed through the
    augmentation pipeline built by ``get_train_augmentations_v0`` for the
    shape of that slice.

    Args:
        datavols_list: list of 3D data volumes.
        labelsvols_list: list of 3D label volumes, matching datavols_list.
        axes: axes (0, 1 and/or 2) along which the volumes are sliced.
        cuda_device_str: device passed to ``Tensor.to`` for the returned tensors.
    """

    def __init__(self, datavols_list, labelsvols_list, axes=(0,1,2), cuda_device_str=0):
        
        self.cuda_device_str = cuda_device_str
        self.datavols_list = datavols_list
        self.labelsvols_list = labelsvols_list
        # Immutable default plus a private copy, so instances never share
        # (or mutate) the caller's list.
        self.axes = list(axes)

        # Lookup tables: given an idx number, retrieve the volume index, axis,
        # slice number and transform.
        self._idx_to_item=[]
        self._idx_to_ax=[]
        self._idx_to_slicen=[]
        self._idx_to_tfms = []

        for ivol, d0 in enumerate(datavols_list):
            for ax0 in self.axes:
                if ax0 not in (0,1,2):
                    raise ValueError(f"ax0 {ax0} not valid")

                nslices=d0.shape[ax0]

                self._idx_to_item.extend([ivol]*nslices)
                self._idx_to_ax.extend([ax0]*nslices)
                self._idx_to_slicen.extend(range(nslices))

                # Shape of a 2D slice perpendicular to ax0 (volumes are 3D)
                slice_shape = tuple(s for k, s in enumerate(d0.shape) if k != ax0)
                t0 = get_train_augmentations_v0( *slice_shape )
                # The same pipeline object is shared by all slices of this volume/axis
                self._idx_to_tfms.extend([t0]*nslices)

        total_slices = len(self._idx_to_item)

        assert (
            total_slices==len(self._idx_to_ax)
            and total_slices==len(self._idx_to_slicen)
            and total_slices==len(self._idx_to_tfms)
        )

        self.len = total_slices


    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return the augmented (data, labels) tensors for flat index idx.

        data is converted to float and labels to long; both are moved to
        ``self.cuda_device_str``.
        """
        it = self._idx_to_item[idx]
        ax = self._idx_to_ax[idx]
        slicen = self._idx_to_slicen[idx]
        tfms = self._idx_to_tfms[idx]
        
        if ax==0:
            data_slice = self.datavols_list[it][slicen,:,:]
            labels_slice = self.labelsvols_list[it][slicen,:,:]
        elif ax==1:
            data_slice = self.datavols_list[it][:,slicen,:]
            labels_slice = self.labelsvols_list[it][:,slicen,:]
        elif ax==2:
            data_slice = self.datavols_list[it][:,:,slicen]
            labels_slice = self.labelsvols_list[it][:,:,slicen]
        else:
            raise ValueError(f"ax {ax} not valid")

        assert data_slice.shape == labels_slice.shape

        # Apply transforms jointly so image and mask stay aligned
        res =tfms(image=data_slice, mask=labels_slice)

        data = res['image'].to(self.cuda_device_str).float()
        labels = res['mask'].to(self.cuda_device_str).long()

        #return a tuple data, mask
        return data, labels

In [None]:
trainlabels_list = trainlabels

In [None]:
# Test creating a dataset
ds0 = NN1_train_input_dataset_along_axes(
    traindata_list,
    trainlabels_list,
    [1,2], # change axis as desired
    cuda_device
)

In [None]:
len(ds0._idx_to_tfms)

Visualise data and respective labels from the dataset (note: the dataset applies the training augmentations to every item it returns)

In [None]:
nimages=5
fig,axs = plt.subplots(1,nimages,figsize=(15,5))

randomints= np.random.default_rng().permutation(256)

for i in range(nimages):
    r0 = randomints[i]
    datai, labeli = ds0[r0]
    print(f"i:{i}, datai shape:{datai.shape}, type:{datai.dtype}   label shape:{labeli.shape}, type:{labeli.dtype}")
    datai=datai.detach().cpu().numpy()[0,:,:]
    labeli = labeli.detach().cpu().numpy()[:,:]
    axs[i].imshow(datai, cmap="gray")
    axs[i].imshow(labeli,cmap='tab10', alpha=0.5, vmax=10)
    axs[i].set_axis_off()

    if i==nimages-1:
        break

OK

In [None]:
nn1_axes_to_models_indices

In [None]:
#nn1_axes_to_models_indices = [0,1,2]

In [None]:
np.flatnonzero(
        np.array(nn1_axes_to_models_indices) == 2
    ).tolist()

In [None]:
# One (train, test) DataLoader pair per NN1 model, stored at the model's index.
# Models that are not assigned to any axis get None in both lists.
dataloaders_train=[]
dataloaders_test=[]
for i in range(len(NN1_models)):
    # Axes (positions in nn1_axes_to_models_indices) that are mapped to model i
    model_axes= np.flatnonzero(
        np.array(nn1_axes_to_models_indices) == i
    ).tolist()

    dl_train=None
    dl_test=None

    if len(model_axes)>0:

        # Slices of all training volumes along the axes handled by this model
        ds0 = NN1_train_input_dataset_along_axes(
            traindata_list,
            trainlabels_list,
            model_axes,
            cuda_device
        )

        # 80/20 random split. Both subsets wrap the same ds0, so the held-out
        # slices also go through the training augmentations.
        dset1, dset2 = torch.utils.data.random_split(ds0, [0.8,0.2])

        dl_train = DataLoader(dset1, batch_size=nn1_batch_size, shuffle=True)
        dl_test = DataLoader(dset2, batch_size=nn1_batch_size, shuffle=True)

    dataloaders_train.append(dl_train)
    dataloaders_test.append(dl_test)


In [None]:
len(dataloaders_train[0])

In [None]:
len(dataloaders_test[0])

In [None]:
dataloaders_train[0]

visualise some data and labels

In [None]:
nimages=5
fig,axs = plt.subplots(1,nimages,figsize=(15,5))
for i, (datai,labeli) in enumerate(dataloaders_train[2]): # Change index to 0,1,2 for z,y,x
    print(f"i:{i}, datai shape:{datai.shape}, type:{datai.dtype}   label shape:{labeli.shape}, type:{labeli.dtype}")
    datai=datai.detach().cpu().numpy()[0,0,:,:]
    labeli = labeli.detach().cpu().numpy()[0,:,:]
    axs[i].imshow(datai, cmap="gray")
    axs[i].imshow(labeli,cmap='tab10', alpha=0.5, vmax=10)
    axs[i].set_axis_off()

    if i==nimages-1:
        break

looks ok

# setup loss function

In [None]:
nn1_loss_criterion

In [None]:
# Loss function plus the activation (if any) to apply to the model output
# before the loss. The models are built with activation=None, i.e. they
# output raw logits.
nn1_loss_func_and_activ = None
# Both supported losses consume raw logits, so no activation is applied.
activ = None
if "crossentropyloss" in nn1_loss_criterion.lower():
    nn1_loss_func = torch.nn.CrossEntropyLoss().to(cuda_str) # expects logits!
    
    # or can use
    # nn1_loss_func = torch.nn.functional.cross_entropy(pred_logits, target)
    
    # CrossEntropyLoss applies log-softmax internally; passing the logits
    # through a Sigmoid first (as was done before) squashes them into (0,1)
    # and distorts the loss.
    nn1_loss_func_and_activ= {"func":nn1_loss_func, "activ":None}
elif "diceloss" in nn1_loss_criterion.lower():
    nn1_loss_func = smp.losses.DiceLoss(mode='multiclass', from_logits=True).to(cuda_str)
    nn1_loss_func_and_activ= {"func":nn1_loss_func, "activ":None}
else:
    raise ValueError(f"{nn1_loss_criterion} not a valid loss criteria")

In [None]:
nn1_loss_func_and_activ

# setup metric function

In [None]:
# Setup metrics for test data
# Evaluation metric used on the held-out data, selected by case-insensitive
# substring match on nn1_eval_metric. Stays None when nothing matches, in
# which case test_loop reports the loss only.
# NOTE(review): test_loop calls this metric with argmax class-index maps and
# the integer labels; confirm that these metric classes give a meaningful
# multi-class score for such inputs.
nn1_metric_func = None
if "iou" in nn1_eval_metric.lower():
    nn1_metric_func = segmentation_models_pytorch.utils.metrics.IoU()
elif "dice" in nn1_eval_metric.lower() or "fscore" in nn1_eval_metric.lower():
    nn1_metric_func = segmentation_models_pytorch.utils.metrics.Fscore()
elif "accuracy" in nn1_eval_metric.lower():
    nn1_metric_func = segmentation_models_pytorch.utils.metrics.Accuracy()

In [None]:
nn1_metric_func

# Setup training of each model individually

train, test loops

In [None]:
def train_loop(dataloader, model, loss_func_and_activ, optimizer, scaler, scheduler, do_log=True):
    """Run one training epoch over dataloader.

    Args:
        dataloader: iterable of (inputs, targets) batches with a ``dataset`` attribute.
        model: network to train; it is switched to training mode.
        loss_func_and_activ: dict with "func" (the loss) and "activ" (an
            activation applied to the model output before the loss, or None).
        optimizer: optimizer stepped through the gradient scaler.
        scaler: gradient scaler used for the backward pass and optimizer step.
        scheduler: learning-rate scheduler, stepped once per batch.
        do_log: when True, log loss and learning rate every 50 batches.
    """
    criterion = loss_func_and_activ["func"]
    activation = loss_func_and_activ["activ"]

    n_samples = len(dataloader.dataset)

    # Training mode matters for batch normalization and dropout layers
    model.train()

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        # Forward pass, with the optional activation in front of the loss
        outputs = model(inputs)
        if activation is not None:
            outputs = activation(outputs)
        batch_loss = criterion(outputs, targets)

        # Backpropagation and parameter update through the gradient scaler
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # The scheduler advances once per batch
        scheduler.step()

        if not (do_log and batch_idx % 50 == 0):
            continue
        loss_value = batch_loss.item()
        n_seen = (batch_idx + 1) * len(inputs)
        logging.info(f"batch:{batch_idx}  loss: {loss_value:>7f}  [{n_seen:>5d}/{n_samples:>5d}]. lr:{scheduler.get_last_lr()}")

def test_loop(dataloader, model, loss_func_and_activ, metric_fn=None):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    #size = len(dataloader.dataset)
    #num_batches = len(dataloader)

    loss_fn = loss_func_and_activ["func"]
    activ_fn = loss_func_and_activ["activ"]

    test_losses=[]
    test_metrics=[]

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            #X=X_parse(X)
            #y=y_parse(y)
            pred = model(X)

            if activ_fn is not None:
                pred = activ_fn(pred)

            loss = loss_fn(pred, y)

            test_loss = loss.item()
            test_losses.append(test_loss)
            
            if metric_fn is not None:
                pred_argmax = torch.argmax(pred, dim=1)
                metric = metric_fn(pred_argmax, y).item()
                test_metrics.append(metric)
            # #metric
            # correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    avg_loss = np.mean(np.array(test_losses))
    logging.info(f"Avg loss: {avg_loss:>8f}")

    avg_metric=None
    if not metric_fn is None:
        avg_metric = np.mean(np.array(test_metrics))
        logging.info(f"Avg metric: {avg_metric:>8f}")

    return {"avg_loss":avg_loss, "avg_metric":avg_metric, "test_metrics":test_metrics, "test_losses":test_losses}

def train_model(model0, dl_train, dl_test, loss_func_and_activ, optimizer, scaler, scheduler, epochs, metric_fn=None):
    """Train model0 for a number of epochs, evaluating after each one.

    Args:
        model0: network to train.
        dl_train: training dataloader.
        dl_test: evaluation dataloader, or None to skip evaluation.
        loss_func_and_activ: dict with "func" and "activ", as used by
            train_loop and test_loop.
        optimizer, scaler, scheduler: passed on to train_loop.
        epochs: number of passes over dl_train.
        metric_fn: optional metric passed on to test_loop.

    Returns:
        dict with "test_results": the list of test_loop results, one per
        epoch (empty when dl_test is None or epochs is 0).
    """
    logging.info("train_model()")
    test_results=[]
    # Defined up-front so the final log below is safe even when no epoch ran
    test_res=None
    for t in range(epochs):
        logging.info(f"---- Epoch {t+1}/{epochs} ----")
        train_loop(dl_train, model0, loss_func_and_activ, optimizer, scaler, scheduler)

        if dl_test is not None:
            test_res= test_loop(dl_test, model0, loss_func_and_activ, metric_fn=metric_fn)
            test_results.append(test_res)
    logging.info(f"Done!")
    if test_res is not None:
        logging.info(f"Final test loss is : {test_res['avg_loss']}, and metric is: {test_res['avg_metric']}")
    return {"test_results": test_results}


In [None]:
print(len(NN1_models))
print(len(dataloaders_train))
print(len(dataloaders_test))

Train first model

In [None]:
model= NN1_models[0]
dl_train0 = dataloaders_train[0]
dl_test0 = dataloaders_test[0]

In [None]:
#Setup optimizer and scaler
optimizer = torch.optim.AdamW(model.parameters(), lr=nn1_lr)
scaler=torch.cuda.amp.GradScaler()

epochs = nn1_epochs
#epochs = 10

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr= nn1_max_lr,
    steps_per_epoch=len(dl_train0),
    epochs=epochs,
    #pct_start=0.1, #default=0.3
    )

In [None]:
train_model(model, dl_train0, dl_test0, nn1_loss_func_and_activ, optimizer, scaler, scheduler,
            epochs=epochs,
            metric_fn=nn1_metric_func
            )

Prediction of some slices of first model

In [None]:
model= NN1_models[0]
model.eval()
nimages=3
fig,axs = plt.subplots(2,nimages,figsize=(15,10))
for i, (datai,labeli) in enumerate(dataloaders_train[0]):
    print(f"i:{i}, datai shape:{datai.shape}, type:{datai.dtype}   label shape:{labeli.shape}, type:{labeli.dtype}")
    #datai=datai.detach().cpu().numpy()[0,0,:,:]
    labeli = labeli.detach().cpu().numpy()[0,:,:]
    
    X=datai
    pred=model(X)
    pred_argmax = torch.argmax(pred, dim=1)

    pred_np = pred_argmax.detach().cpu().numpy()[0,:,:]
    print(f"i:{i}, pred_max:{pred_np.max()}")
    datai_np = datai.detach().cpu().numpy()[0,0,:,:]
    axs[0,i].imshow(datai_np, cmap="gray")
    axs[0,i].imshow(pred_np,cmap='tab10', alpha=0.5, vmax=10)
    axs[0,i].set_axis_off()

    axs[1,i].imshow(datai_np, cmap="gray")
    axs[1,i].imshow(labeli,cmap='tab10', alpha=0.5, vmax=10)
    axs[1,i].set_axis_off()

    if i==nimages-1:
        break

    #predictions on top, ground truth at the bottom

Not bad

Second model (Y axis) training

In [None]:
model= NN1_models[1]
dl_train0 = dataloaders_train[1]
dl_test0 = dataloaders_test[1]

In [None]:
#Setup optimizer and scaler
optimizer = torch.optim.AdamW(model.parameters(), lr=nn1_lr)
scaler=torch.cuda.amp.GradScaler()

epochs = nn1_epochs
#epochs = 10

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr= nn1_max_lr,
    steps_per_epoch=len(dl_train0),
    epochs=epochs,
    #pct_start=0.1, #default=0.3
    )

In [None]:
train_model(model, dl_train0, dl_test0, nn1_loss_func_and_activ, optimizer, scaler, scheduler,
            epochs=epochs,
            metric_fn=nn1_metric_func
            )

Prediction of some slices of second model

In [None]:
model= NN1_models[1]
model.eval()
nimages=3
fig,axs = plt.subplots(2,nimages,figsize=(15,10))
for i, (datai,labeli) in enumerate(dataloaders_train[1]):
    print(f"i:{i}, datai shape:{datai.shape}, type:{datai.dtype}   label shape:{labeli.shape}, type:{labeli.dtype}")
    #datai=datai.detach().cpu().numpy()[0,0,:,:]
    labeli = labeli.detach().cpu().numpy()[0,:,:]
    
    X=datai
    pred=model(X)
    pred_argmax = torch.argmax(pred, dim=1)

    pred_np = pred_argmax.detach().cpu().numpy()[0,:,:]
    print(f"i:{i}, pred_max:{pred_np.max()}")
    datai_np = datai.detach().cpu().numpy()[0,0,:,:]
    axs[0,i].imshow(datai_np, cmap="gray")
    axs[0,i].imshow(pred_np,cmap='tab10', alpha=0.5, vmax=10)
    axs[0,i].set_axis_off()

    axs[1,i].imshow(datai_np, cmap="gray")
    axs[1,i].imshow(labeli,cmap='tab10', alpha=0.5, vmax=10)
    axs[1,i].set_axis_off()

    if i==nimages-1:
        break

    #predictions on top, ground truth at the bottom

Third model (X axis) training

In [None]:
model= NN1_models[2]
dl_train0 = dataloaders_train[2]
dl_test0 = dataloaders_test[2]

In [None]:
#Setup optimizer and scaler
optimizer = torch.optim.AdamW(model.parameters(), lr=nn1_lr)
scaler=torch.cuda.amp.GradScaler()

epochs = nn1_epochs
#epochs = 10

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr= nn1_max_lr,
    steps_per_epoch=len(dl_train0),
    epochs=epochs,
    #pct_start=0.1, #default=0.3
    )

In [None]:
train_model(model, dl_train0, dl_test0, nn1_loss_func_and_activ, optimizer, scaler, scheduler,
            epochs=epochs,
            metric_fn=nn1_metric_func
            )

Prediction of some slices of third model

In [None]:
model= NN1_models[2]
model.eval()
nimages=3
fig,axs = plt.subplots(2,nimages,figsize=(15,10))
for i, (datai,labeli) in enumerate(dataloaders_train[2]):
    print(f"i:{i}, datai shape:{datai.shape}, type:{datai.dtype}   label shape:{labeli.shape}, type:{labeli.dtype}")
    #datai=datai.detach().cpu().numpy()[0,0,:,:]
    labeli = labeli.detach().cpu().numpy()[0,:,:]
    
    X=datai
    pred=model(X)
    pred_argmax = torch.argmax(pred, dim=1)

    pred_np = pred_argmax.detach().cpu().numpy()[0,:,:]
    print(f"i:{i}, pred_max:{pred_np.max()}")
    datai_np = datai.detach().cpu().numpy()[0,0,:,:]
    axs[0,i].imshow(datai_np, cmap="gray")
    axs[0,i].imshow(pred_np,cmap='tab10', alpha=0.5, vmax=10)
    axs[0,i].set_axis_off()

    axs[1,i].imshow(datai_np, cmap="gray")
    axs[1,i].imshow(labeli,cmap='tab10', alpha=0.5, vmax=10)
    axs[1,i].set_axis_off()

    if i==nimages-1:
        break

    #predictions on top, ground truth at the bottom

# Predict volume(s) using the models

setup temporary folder to store the predicted volumes

In [None]:
# Temporary folder that receives the predicted volumes (h5 files).
# NOTE: a TemporaryDirectory and its contents are removed when this object is
# cleaned up (e.g. when the kernel exits), so file paths recorded from it do
# not survive a restart. Keep tempdir_pred referenced while they are needed.
tempdir_pred= tempfile.TemporaryDirectory()
path_out_results = Path(tempdir_pred.name)
logging.info(f"tempdir_pred_path:{path_out_results}")

In [None]:
import h5py
def _save_pred_data(folder, data, count,axis, rot):
    # Saves predicted data to h5 file in tempdir and return file path in case it is needed
    file_path = f"{folder}/pred_{count}_{axis}_{rot}.h5"

    logging.info(f"Saving data of shape {data.shape} to {file_path}.")
    with h5py.File(file_path, "w") as f:
        f.create_dataset("/data", data=data)

    return file_path

In [None]:
data_to_predict = traindata_list[0] #Only first volume for testing.
# Volumes in traindata_list has already been normalised/clipped

In [None]:
class VolumeSlicerDataset(Dataset):
    """Dataset that serves the 2D slices of one 3D volume along a fixed axis.

    Args:
        datavol: 3D numpy volume to slice.
        axis: slicing axis (0, 1 or 2); the dataset length is datavol.shape[axis].
        per_slice_tfms: optional callable applied to each 2D slice.
        device_str: device the returned tensors are moved to.
    """

    def __init__(self, datavol, axis, per_slice_tfms=None, device_str="cuda:0"):
        assert datavol.ndim==3
        assert axis==0 or axis==1 or axis==2

        self.datavol=datavol
        self.axis=axis
        self.per_slice_tfms=per_slice_tfms
        self.device_str = device_str

    def __len__(self):
        return self.datavol.shape[self.axis]

    def __getitem__(self, idx):
        """Return slice idx as a float tensor of shape (1, height, width)."""
        data_slice=None
        if self.axis==0:
            data_slice = self.datavol[idx,:,:]
        elif self.axis==1:
            data_slice = self.datavol[:,idx,:]
        elif self.axis==2:
            data_slice = self.datavol[:,:,idx]

        res = data_slice
        # Apply transform
        if self.per_slice_tfms is not None:
            res = self.per_slice_tfms(data_slice)

        # torch.from_numpy cannot wrap arrays with negative strides (e.g.
        # views produced by np.rot90 / np.flip), so copy only in that case.
        if isinstance(res, np.ndarray) and any(s < 0 for s in res.strides):
            res = np.ascontiguousarray(res)

        # Convert to tensor, add the channel dimension and send to device
        res_torch = torch.unsqueeze(torch.from_numpy(res), dim=0).float().to(self.device_str)

        return res_torch

In [None]:
# Test
ds0 = VolumeSlicerDataset(data_to_predict, axis=0 , per_slice_tfms=None, device_str=cuda_str)

Visualise test dataset

In [None]:
nimages=5
fig,axs = plt.subplots(1,nimages,figsize=(15,5))

randomints= np.random.default_rng().permutation(256)

for i in range(nimages):
    r0 = randomints[i]
    datai_t = ds0[r0]
    print(f"i:{i}, datai shape:{datai_t.shape}, dtype:{datai_t.dtype}, type:{type(datai_t)}")
    datai=datai_t.detach().cpu().numpy()[0,:,:]
    axs[i].imshow(datai, cmap="gray")
    axs[i].set_axis_off()

    if i==nimages-1:
        break

In [None]:
dl0 = DataLoader(dataset=ds0, batch_size=nn1_batch_size, shuffle=False)

In [None]:
#Use model0
model=NN1_models[0]

In [None]:
model.eval()

SM_func = torch.nn.Softmax(dim=1) # to get probabilities

Run predictions for the whole volume along axes previously specified

In [None]:
preds_list = []
labels_list = []
with torch.no_grad():
    for ibatch, x in enumerate(dl0):
        logging.info(f"ibatch: {ibatch} ")
        X= model(x)
        #logging.info(f"X.shape:{X.shape}")

        pred_probs_slice = SM_func(X)
        #logging.info(f"pred_probs_slice.shape:{pred_probs_slice.shape}")
        #preds_list.append(pred_probs_slice)

        # get labels using argmax
        lbl_slice = torch.argmax(pred_probs_slice, dim=1)
        #labels_list.append(lbl_slice)

        # need to move out from device, otherwise it uses too much RAM

        pred_probs_slice_np = pred_probs_slice.detach().cpu().numpy()
        lbl_slice_np = lbl_slice.detach().cpu().numpy().astype(np.uint8)

        preds_list.append(pred_probs_slice_np)
        labels_list.append(lbl_slice_np)

In [None]:
len(preds_list)

In [None]:
preds_list_conc = np.concatenate(preds_list, axis=0)
preds_list_conc.shape

In [None]:
#preds_z = np.swapaxes(np.concatenate(preds_list, axis=0),0,1)
preds_z = np.transpose(preds_list_conc, axes=(1,0,2,3))

In [None]:
preds_z.shape

In [None]:
labels_pred_z = np.concatenate(labels_list, axis=0)

In [None]:
labels_pred_z.shape

view some z-slices and predictions

In [None]:
nimages=4
fig,axs = plt.subplots(1,nimages,figsize=(15,5))

randomints= np.random.default_rng().permutation(256)

for i in range(nimages):
    r0 = randomints[i]
    datai = data_to_predict[r0,:,:]
    labeli = labels_pred_z[r0,:,:]
    axs[i].imshow(datai, cmap="gray")
    axs[i].set_axis_off()
    axs[i].imshow(labeli , cmap='tab10', alpha=0.5, vmax=10)
    axs[i].set_axis_off()

OK

view probabilities

In [None]:
nimages=4
fig,axs = plt.subplots(1,nimages,figsize=(15,5))

randomints= np.random.default_rng().permutation(256)

for i in range(nimages):
    r0 = randomints[i]
    datai = data_to_predict[r0,:,:]
    predi = preds_z[2,r0,:,:]
    axs[i].imshow(datai, cmap="gray")
    axs[i].set_axis_off()
    axs[i].imshow(predi, cmap='viridis', alpha=0.4)
    axs[i].set_axis_off()

In [None]:
def nn1_predict_slices_along_axis(datavol, axis, device_str):
    """Predict a whole volume slice-by-slice along axis with the NN1 model.

    The model is looked up through ``nn1_axes_to_models_indices[axis]`` in
    ``NN1_models``; slices are batched with ``nn1_batch_size``.

    Args:
        datavol: 3D volume to predict (already normalised).
        axis: slicing axis, 0, 1 or 2.
        device_str: device on which the slices are evaluated.

    Returns:
        (pred_oriented, labels_oriented): softmax class probabilities with the
        class axis first followed by the axes of datavol, and the uint8 argmax
        labels in the orientation of datavol.

    Raises:
        ValueError: if axis is not 0, 1 or 2.
    """
    # How to bring the stacked (slice, [class,] row, col) results back to the
    # orientation of datavol; None means the stack is already oriented.
    orientations = {
        0: ((1,0,2,3), None),
        1: ((1,2,0,3), (1,0,2)),
        2: ((1,2,3,0), (1,2,0)),
    }
    if axis not in orientations:
        raise ValueError(f"axis {axis} not valid")

    ds0 = VolumeSlicerDataset(datavol, axis , per_slice_tfms=None, device_str=device_str)
    dl0 = DataLoader(dataset=ds0, batch_size=nn1_batch_size, shuffle=False)

    # Get correct model
    model_index = nn1_axes_to_models_indices[axis]
    model = NN1_models[model_index]
    logging.info(f"axis:{axis}, use model_index: {model_index}")

    model.eval()
    
    SM_func = torch.nn.Softmax(dim=1)

    preds_list = []
    labels_list = []
    # Inference only: no_grad avoids building autograd graphs for every batch
    with torch.no_grad():
        for ibatch, x in enumerate(dl0):
            # x is (batchsize, 1, height, width)
            X= model(x)
            # X is (batchsize, nclasses, height, width)

            pred_probs_slice = SM_func(X) #Convert to probabilities

            # get labels using argmax
            lbl_slice = torch.argmax(pred_probs_slice, dim=1)

            # need to move away from device, otherwise it uses too much VRAM
            pred_probs_slice_np = pred_probs_slice.detach().cpu().numpy()
            lbl_slice_np = lbl_slice.detach().cpu().numpy().astype(np.uint8)

            preds_list.append(pred_probs_slice_np)
            labels_list.append(lbl_slice_np)

    logging.info("Prediction of all slices complete. Now stacking and getting the right orientation.")
    # stack slices: (nslices, nclasses, height, width) and (nslices, height, width)
    preds_list_conc = np.concatenate(preds_list, axis=0)
    labels_pred_conc = np.concatenate(labels_list, axis=0)

    probs_axes, labels_axes = orientations[axis]
    pred_oriented = np.transpose(preds_list_conc, axes=probs_axes)
    if labels_axes is None:
        labels_oriented = labels_pred_conc # no need to orient
    else:
        labels_oriented = np.transpose(labels_pred_conc, axes=labels_axes)

    #with pred_oriented note that class probability is at the start
    return pred_oriented, labels_oriented

In [None]:
# test
res = nn1_predict_slices_along_axis(data_to_predict, axis=2, device_str=cuda_str)

In [None]:
# import napari
# NV=napari.Viewer()
# NV.add_image(data_to_predict)
# NV.add_labels(res[1])
# NV.add_labels(trainlabels_list[0])

In [None]:
# import napari
# NV=napari.Viewer()
# NV.add_image(data_to_predict)
# NV.add_image(res[0][1,...])

## Several volumes, different rotations and axis and save

In [None]:
# Bookkeeping lists: one entry per saved prediction (filenames are kept as
# reference since the results themselves are stored in files).
pred_data_probs_filenames=[]
pred_data_labels_filenames=[]
pred_sets=[]
pred_planes=[]
pred_rots=[]
pred_ipred=[]
pred_shapes=[]
itag=0   # running index of the prediction
iset=0   # index of the volume being predicted (single volume here)

# For every slicing plane: (tag used in file names, log message,
# axes of the plane in which the volume is rotated, axis to slice along)
_plane_specs = [
    ("YX", "Predicting YX slices, along Z", (1,2), 0),
    ("ZX", "Predicting ZX slices, along Y", (0,2), 1),
    ("ZY", "Predicting ZY slices, along X", (0,1), 2),
]

for krot in range(0, 4): #Around axis rotations
    rot_angle_degrees = krot * 90
    logging.info(f"Volume to be rotated by {rot_angle_degrees} degrees")

    for _plane_tag, _plane_msg, _rot_axes, _slice_axis in _plane_specs:
        logging.info(_plane_msg)

        # Rotate within the plane; np.array() makes a contiguous copy
        data_vol = np.array(np.rot90(data_to_predict, krot, axes=_rot_axes))

        prob0,lab0 = nn1_predict_slices_along_axis(data_vol, axis=_slice_axis, device_str=cuda_str)

        # Invert the rotation before saving. The probabilities carry the
        # class axis at the start, so their spatial axes are shifted by one.
        _prob_rot_axes = (_rot_axes[0]+1, _rot_axes[1]+1)
        pred_probs = np.rot90(prob0, -krot, axes=_prob_rot_axes)
        pred_labels = np.rot90(lab0, -krot, axes=_rot_axes)

        fn = _save_pred_data(path_out_results,pred_probs, iset, _plane_tag, rot_angle_degrees)
        pred_data_probs_filenames.append(fn)
        fn = _save_pred_data(path_out_results,pred_labels, iset, f"{_plane_tag}_labels", rot_angle_degrees)
        pred_data_labels_filenames.append(fn)

        pred_sets.append(iset)
        pred_planes.append(_plane_tag)
        pred_rots.append(rot_angle_degrees)
        pred_ipred.append(itag)
        pred_shapes.append(pred_labels.shape)
        itag+=1

Saved predictions to folder `C:\Users\LUIS-W~1\AppData\Local\Temp\tmp0rbspoqt/ `

Collect information to a pandas table

In [None]:
all_pred_pd = pd.DataFrame({
    'pred_data_probs_filenames': pred_data_probs_filenames,
    'pred_data_labels_filenames': pred_data_labels_filenames,
    'pred_sets':pred_sets,
    'pred_planes':pred_planes,
    'pred_rots':pred_rots,
    'pred_ipred':pred_ipred,
    'pred_shapes': pred_shapes,
})

Save pandas table in case we need to exit before training NN2

In [None]:
all_pred_pd.to_csv("developing_segmentor2_nn1_temp_results.csv")

# NN2 training

Similar to segmentor.py but using own pytorch MLP classifier (Kaggle)

In [None]:
import numpy as np
import dask.array as da
#import subprocess
import tempfile
from pathlib import Path
import os
cwd = os.getcwd()
import tempfile
import logging
from types import SimpleNamespace
import tqdm #progress bar in iterations
import pandas as pd

from torch.utils.data import Dataset, DataLoader, random_split, Subset

import torch
import torch.nn as nn
import albumentations as alb
import albumentations.pytorch

import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils

import matplotlib.pyplot as plt

import logging
logging.basicConfig(level=logging.INFO)
import tifffile

In [None]:
nn2_MLP_models_class_generator = {
    "nn2_hidden_layer_sizes" : "10,10",
    "nn2_activation": 'tanh',
    "nn2_out_nclasses": 3,
    "nn2_in_nchannels": 3*12
}

#nn2_max_iter = 1000
nn2_ntrain = 262144 #Note that this is not a MLPClassifier parameter
nn2_train_epochs = 20
nn2_batch_size = 4096
nn2_lr = 1e-6
nn2_max_lr = 5e-2

In [None]:
# Optional: Load csv file
all_pred_pd = pd.read_csv("developing_segmentor2_nn1_temp_results.csv")
all_pred_pd

In [None]:
all_pred_pd["pred_data_probs_filenames"][0]

Load data

In [None]:
#similar to segmentor.py
import h5py
data_all_np5d=None

logging.debug("Aggregating multiple sets onto a single volume data_all_np5d")

# Size the two leading dimensions from the table instead of hard-coding (1, 12),
# so the cell keeps working with several sets or a different number of predictions.
# NOTE(review): this assumes pred_ipred is a 0-based index of the prediction
# within its set - confirm against the code that fills the table.
nsets = int(all_pred_pd['pred_sets'].max()) + 1
npreds = int(all_pred_pd['pred_ipred'].max()) + 1

# aggregate multiple sets for data
for i,prow in all_pred_pd.iterrows():

    prob_filename = prow['pred_data_probs_filenames']
    with h5py.File(prob_filename,'r') as f:
        data0 = np.array(f["data"])

    # Allocate on the first file read. Testing for None (rather than i==0) does not
    # depend on the dataframe index labels starting at 0.
    if data_all_np5d is None:
        logging.info(f"data0.shape:{data0.shape}")
        all_shape0 = (
            nsets,
            npreds,
            *data0.shape
            )

        data_all_np5d=np.zeros( all_shape0 , dtype=data0.dtype)

    ipred=int(prow['pred_ipred'])
    iset=int(prow['pred_sets'])

    # Each file fills the (iset, ipred) slot with its whole array
    data_all_np5d[iset,ipred, :,:,:, :] = data0

In [None]:
data_all_np5d.shape

Set up the DataLoader and Dataset for training NN2, based on a Kaggle solution.

In [None]:
data_all_np5d.shape

In [None]:
p0 = np.transpose( data_all_np5d , axes=(0,3,4,5,1,2))

In [None]:
p0.shape

In [None]:
p0.shape[:3]

In [None]:
data_flat_for_mlp = p0.reshape( (np.prod(p0.shape[:4]), p0.shape[4]*p0.shape[5]))

In [None]:
data_flat_for_mlp.shape

OK. Note that the MLP must have 12*3 = 36 inputs (12 predictions × 3 classes per voxel).

In [None]:
trainlabels_list_np = np.array(trainlabels_list)
trainlabels_list_np.shape

In [None]:
label_flat_for_mlp = trainlabels_list_np.ravel()
label_flat_for_mlp.shape

In [None]:
X_train= torch.from_numpy(data_flat_for_mlp).float()
y_train= torch.from_numpy(label_flat_for_mlp).long()

In [None]:
X_train.shape

In [None]:
subset_indices = torch.randperm(X_train.shape[0])[:nn2_ntrain]

In [None]:
len(subset_indices)

In [None]:
y_train

Note that TensorDataset will create X,y samples by indexing in the first dimension

In [None]:
X_train_subset = X_train[subset_indices,:].to(cuda_str)
y_train_subset = y_train[subset_indices].to(cuda_str)

In [None]:
subset_dataset = TensorDataset(X_train_subset, y_train_subset)

In [None]:
subset_dataset

In [None]:
len(subset_dataset)

In [None]:
subset_dataset[0]

Create dataloader

In [None]:
nn2_train_loader = DataLoader(subset_dataset, batch_size=nn2_batch_size, shuffle=True)
# no need for random as it has already been randomized ?

In [None]:
nn2_train_loader

## Set up the MLP based on the number of input channels

In [3]:
class MLPClassifier(nn.Module):
    """Fully connected classifier with a selectable hidden-layer activation.

    The layers are ``nn.Linear`` modules stored in ``self.hidden``. The chosen
    activation is applied after every layer except the last, so ``forward``
    returns raw logits (no softmax or sigmoid on the output).

    Args:
        input_size: number of input features per sample.
        hiden_sizes_list: width of each hidden layer, in order.
        output_size: number of output logits.
        activ_str: activation name, matched case-insensitively by substring
            against "tanh", "relu" and "sigm" (tested in that order).

    Raises:
        ValueError: if ``activ_str`` matches none of the supported names.
    """

    def __init__(self, input_size:int, hiden_sizes_list:list, output_size:int, activ_str:str):
        super().__init__()

        # Consecutive width pairs define the Linear layers; the final pair is the output layer
        widths = [input_size, *hiden_sizes_list, output_size]
        self.hidden = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        # First substring match wins, in the same order as the supported names above
        known_activations = (
            ("tanh", nn.functional.tanh),
            ("relu", nn.functional.relu),
            ("sigm", nn.functional.sigmoid),
        )
        requested = activ_str.lower()
        for key, func in known_activations:
            if key in requested:
                self.activ = func
                break
        else:
            raise ValueError(f"activ_str {activ_str} not valid")

    def forward(self, x):
        """Run ``x`` through the hidden layers (with activation) and the output layer (without)."""
        *hidden_layers, output_layer = self.hidden
        for layer in hidden_layers:
            x = self.activ(layer(x))

        return output_layer(x) #returns logits

In [4]:
def create_nn2_ptmodel_from_class_generator(nn2_cls_gen_dict: dict ):
    """Build the NN2 ``MLPClassifier`` described by a settings dictionary.

    Args:
        nn2_cls_gen_dict: dictionary with the keys
            'nn2_hidden_layer_sizes' (comma-separated integers, e.g. "10,10"),
            'nn2_in_nchannels', 'nn2_out_nclasses' and 'nn2_activation'.

    Returns:
        A freshly initialised ``MLPClassifier``.

    Raises:
        ValueError: if 'nn2_hidden_layer_sizes' cannot be parsed as a
            comma-separated list of integers (including the empty string).
    """
    hid_layers = nn2_cls_gen_dict['nn2_hidden_layer_sizes'].split(",")

    # str.split never returns an empty list, so validate by parsing: an empty or
    # malformed entry fails in int() and is reported with the offending setting.
    try:
        hid_layers_num_list = [int(tok) for tok in hid_layers]
    except ValueError as err:
        raise ValueError(
            f"Invalid nn2_hidden_layer_sizes : {nn2_cls_gen_dict['nn2_hidden_layer_sizes']}"
        ) from err
    logging.info(f"hid_layers_num_list: {hid_layers_num_list}")

    model0 = MLPClassifier(
        nn2_cls_gen_dict['nn2_in_nchannels'],
        hid_layers_num_list,
        nn2_cls_gen_dict['nn2_out_nclasses'],
        nn2_cls_gen_dict["nn2_activation"]
        )

    return model0

In [5]:
NN2_model_fusion = create_nn2_ptmodel_from_class_generator(nn2_MLP_models_class_generator)

NameError: name 'nn2_MLP_models_class_generator' is not defined

In [None]:
NN2_model_fusion.to(cuda_str)

In [None]:
#inp_test = torch.from_numpy(np.random.random(size=(4096, 36)).astype(np.float32)) # one batch of 4096 samples, 36 inputs each
inp_test = torch.from_numpy(np.random.random(size=(4096, 36))).float().to(cuda_str) # one batch of 4096 random samples, 36 inputs each (12 predictions x 3 classes)

In [None]:
NN2_model_fusion.eval()

In [None]:
NN2_model_fusion(inp_test)

Setup training

In [None]:
# Set up the optimizer, gradient scaler, LR scheduler and loss used to train NN2
model=NN2_model_fusion
# AdamW starts from nn2_lr; the scheduler below overrides the rate during training
optimizer = torch.optim.AdamW(model.parameters(), lr=nn2_lr)
# Gradient scaler for mixed-precision training.
# NOTE(review): how it is used depends on train_model, which is defined elsewhere in the notebook.
scaler=torch.cuda.amp.GradScaler()

epochs = nn2_train_epochs
#epochs = 10

# One-cycle schedule peaking at nn2_max_lr. Total steps = epochs * batches per epoch,
# so the scheduler is presumably stepped once per batch inside train_model - TODO confirm.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr= nn2_max_lr,
    steps_per_epoch=len(nn2_train_loader),
    epochs=epochs,
    #pct_start=0.1, #default=0.3
    )

# CrossEntropyLoss takes raw logits, which is what MLPClassifier.forward returns,
# so no extra activation is configured here ("activ" is None).
nn2_loss_func_and_activ= {"func": nn.CrossEntropyLoss(), "activ":None}
# activ = torch.nn.Sigmoid() we may need this

In [None]:
train_model(
    model,
    nn2_train_loader,
    None, # use train data as test?
    nn2_loss_func_and_activ,
    optimizer, scaler, scheduler,
    epochs=epochs,
    metric_fn=None
)

# Save the 3 models and the MLP classifier

In [None]:
import datetime
DATE=str(datetime.date.today())
DATE

In [None]:
TIME=f"{datetime.datetime.now().hour:02d}{datetime.datetime.now().minute:02d}"
TIME

In [None]:
fname_stem=f"{DATE}_{TIME}"
fname_stem

In [None]:
model_fn = f"{fname_stem}_model.lgsegm2"
model_fn

In [None]:
# NN1_models_state_dict = []
# for i,m in enumerate(NN1_models):
#     #NN1_models_dict[str(i)] = m.state_dict()
#     NN1_models_state_dict.append( m.state_dict())

In [None]:
NN1_models_state_dict = [ m.state_dict() for m in NN1_models]

In [None]:
NN1_models_state_dict

In [None]:
# Human-readable summary of the training settings, stored alongside the weights.
# Every value is interpolated from the settings variables so the saved text always
# reflects what was actually used (batch size and workers were literal text before).
train_info = f"""
nn1_loss_criterion: {nn1_loss_criterion}
nn1_eval_metric: {nn1_eval_metric}
nn1_lr: {nn1_lr}
nn1_max_lr: {nn1_max_lr}
nn1_epochs: {nn1_epochs}

nn1_batch_size: {nn1_batch_size}
nn1_num_workers: {nn1_num_workers}

nn2_ntrain: {nn2_ntrain}
nn2_train_epochs: {nn2_train_epochs}
nn2_batch_size: {nn2_batch_size}
nn2_lr: {nn2_lr}
nn2_max_lr: {nn2_max_lr}

"""

In [None]:
dict_to_save={
    "nn1_models_class_generator": nn1_models_class_generator,
    "nn1_axes_to_models_indices": nn1_axes_to_models_indices,
    "data_vol_norm_process": data_vol_norm_process,
    "NN1_models_state_dict": NN1_models_state_dict,

    "nn2_MLP_models_class_generator": nn2_MLP_models_class_generator,
    "NN2_model_dict":NN2_model_fusion.state_dict(),

    "train_info": train_info
}

Saves

In [None]:
torch.save(dict_to_save, model_fn)

# NN2 predictions with one volume

In [None]:
data_all_np5d.shape

In [None]:
type(data_all_np5d)

In [None]:
data_4d = data_all_np5d[0]

In [None]:
s=data_4d.shape
s

In [None]:
p0= data_all_np5d[0].reshape( (s[0]*s[1], np.prod(s[2:])) )

In [None]:
p0.shape

In [None]:
data_flat_for_mlp= p0.transpose((1,0))

In [None]:
topred_tc= torch.from_numpy(data_flat_for_mlp).float()

In [None]:
data_tc_ds = TensorDataset(topred_tc)

In [None]:
data_tc_ds[0]

In [None]:
data_tc_batcher = DataLoader(data_tc_ds, batch_size=4096, shuffle=False)

In [None]:
for i,data_batch0 in enumerate(data_tc_batcher):
    #res= torch.squeeze(mlp_model(data_multi_preds_probs_np))
    print(i, data_batch0[0].shape)
    
    if i>5:
        break

In [None]:
len(data_tc_batcher)

In [None]:
from tqdm import tqdm

In [None]:
NN2_model_fusion.to("cpu")
NN2_model_fusion.eval()
res_s=[]
with torch.no_grad():
    logging.info("Beggining NN2 inference of whole volume")
    for data_batch in tqdm(data_tc_batcher):
        #res= torch.squeeze(mlp_model(data_multi_preds_probs_np))
        pred = NN2_model_fusion(data_batch[0])
        pred_argmax = torch.argmax(pred,dim=1)
        res_s.append(pred_argmax)
        #gc.collect()

In [None]:
r0 = torch.concatenate(res_s)

In [None]:
r0.shape

In [None]:
r2 = r0.detach().cpu().numpy().reshape(*s[2:])

In [None]:
r2.shape

In [None]:
r2

View results

In [None]:
labels0 = trainlabels_list[0]
labels0.shape

In [None]:
nimages=4
fig,axs = plt.subplots(2,nimages,figsize=(12,6))

# Draw random slice indices from the actual volume depth rather than a fixed 256
randomints= np.random.default_rng().permutation(data_to_predict.shape[0])[:nimages]
#plt.tight_layout()
# The loop variable is named iz (not r0) so it does not overwrite the tensor r0
# that holds the concatenated NN2 predictions; cells using r0 stay re-runnable.
for i, iz in enumerate(randomints):
    datai = data_to_predict[iz,:,:]
    pred_labeli = r2[iz,:,:]
    gnd_labeli = labels0[iz,:,:]

    # Top row: NN2 prediction overlaid on the data
    axs[0,i].imshow(datai, cmap="gray")
    axs[0,i].imshow(pred_labeli , cmap='tab10', alpha=0.5, vmax=10)
    axs[0,i].set_axis_off()

    # Bottom row: ground-truth labels overlaid on the same slice
    axs[1,i].imshow(datai, cmap="gray")
    axs[1,i].imshow(gnd_labeli , cmap='tab10', alpha=0.5, vmax=10)
    axs[1,i].set_axis_off()
plt.tight_layout()

Quite good results when looking along Z

In [None]:
import napari
NV = napari.Viewer()
NV.add_image(data_to_predict)
NV.add_labels(r2)
NV.add_labels(labels0)

# Load model and run predictions

From restart

In [None]:
import numpy as np
import dask.array as da
#import subprocess
import tempfile
from pathlib import Path
import os
cwd = os.getcwd()
import tempfile
import logging
from types import SimpleNamespace
import tqdm #progress bar in iterations
import pandas as pd

from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

import albumentations as alb
import albumentations.pytorch

import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils

import matplotlib.pyplot as plt

In [None]:
import logging
logging.basicConfig(level=logging.INFO)
import tifffile

Load models

In [None]:
load_model = torch.load("2024-06-15_2054_model.lgsegm2")

In [None]:
load_model

In [None]:
cuda_str = "cuda:0"

In [None]:
def create_nn1_ptmodel_from_class_generator(nn1_cls_gen_dict: dict):
    """Build a 2D segmentation model (NN1) from a settings dictionary.

    Only models from segmentation_models_pytorch ('class' == 'smp') are supported,
    with 'arch' one of "unet"/"u_net", "manet" or "fpn" (case-insensitive).

    Args:
        nn1_cls_gen_dict: dictionary with the keys 'class', 'arch', 'encoder_name',
            'encoder_weights', 'in_nchannels' and 'nclasses'.

    Returns:
        The instantiated model, created with ``activation=None`` (no output activation).

    Raises:
        ValueError: if 'class' or 'arch' is not one of the supported values.
    """
    if nn1_cls_gen_dict['class'].lower()!='smp':
        # TODO: add other 2D model support, not just SMPs
        raise ValueError(f"class {nn1_cls_gen_dict['class']} not supported.")

    #Segmentation models pytorch: unet, AttentionNet (manet) and fpn
    arch = nn1_cls_gen_dict['arch'].lower()
    if arch=="unet" or arch=="u_net":
        NN_class = smp.Unet
    elif arch=="manet":
        # previously assigned to model0, leaving NN_class undefined for this arch
        NN_class = smp.MAnet
    elif arch=="fpn":
        NN_class = smp.FPN
    else:
        raise ValueError(f"arch:{arch} not valid.")

    model0 = NN_class(
        encoder_name = nn1_cls_gen_dict['encoder_name'],
        encoder_weights = nn1_cls_gen_dict['encoder_weights'],
        in_channels = nn1_cls_gen_dict['in_nchannels'],
        classes = nn1_cls_gen_dict['nclasses'],
        # Whether to use an activation depends on whether the loss function needs logits
        activation = None
        )

    return model0

create NN1 models from loaded dict

In [None]:
load_model["NN1_models_state_dict"]

In [None]:
len(load_model["NN1_models_state_dict"])

In [None]:
load_model["nn1_models_class_generator"]

In [None]:
# NN1_models=[]
# for i,w0 in load_model["NN1_models_dict"].items():
#     cg0= load_model["nn1_models_class_generator"][int(i)]
#     m0 = create_nn1_ptmodel_from_class_generator(cg0)
#     m0.load_state_dict(w0)
#     NN1_models.append( m0.to(cuda_str) )

In [None]:
# Rebuild each NN1 model from its saved settings and load the stored weights
NN1_models=[]
for cg0,w0 in zip(load_model["nn1_models_class_generator"], load_model["NN1_models_state_dict"] ):
    # Build from a copy with encoder_weights disabled: the saved state dict supplies
    # all weights, and the loaded settings in load_model are left untouched.
    cg0_no_pretrained = {**cg0, 'encoder_weights': None}
    m0 = create_nn1_ptmodel_from_class_generator(cg0_no_pretrained)
    m0.load_state_dict(w0)
    NN1_models.append( m0)

In [None]:
NN1_models

In [None]:
_ = [ m.to(cuda_str) for m in NN1_models]

NN2 (MLP)

In [None]:
class MLPClassifier(nn.Module):
    """Fully connected classifier with a selectable hidden-layer activation.

    The layers are ``nn.Linear`` modules stored in ``self.hidden``. The chosen
    activation is applied after every layer except the last, so ``forward``
    returns raw logits (no softmax or sigmoid on the output).

    Args:
        input_size: number of input features per sample.
        hiden_sizes_list: width of each hidden layer, in order.
        output_size: number of output logits.
        activ_str: activation name, matched case-insensitively by substring
            against "tanh", "relu" and "sigm" (tested in that order).

    Raises:
        ValueError: if ``activ_str`` matches none of the supported names.
    """

    def __init__(self, input_size:int, hiden_sizes_list:list, output_size:int, activ_str:str):
        super().__init__()

        # Consecutive width pairs define the Linear layers; the final pair is the output layer
        widths = [input_size, *hiden_sizes_list, output_size]
        self.hidden = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        # First substring match wins, in the same order as the supported names above
        known_activations = (
            ("tanh", nn.functional.tanh),
            ("relu", nn.functional.relu),
            ("sigm", nn.functional.sigmoid),
        )
        requested = activ_str.lower()
        for key, func in known_activations:
            if key in requested:
                self.activ = func
                break
        else:
            raise ValueError(f"activ_str {activ_str} not valid")

    def forward(self, x):
        """Run ``x`` through the hidden layers (with activation) and the output layer (without)."""
        *hidden_layers, output_layer = self.hidden
        for layer in hidden_layers:
            x = self.activ(layer(x))

        return output_layer(x) #returns logits

In [None]:
def create_nn2_ptmodel_from_class_generator(nn2_cls_gen_dict: dict ):
    """Build the NN2 ``MLPClassifier`` described by a settings dictionary.

    Args:
        nn2_cls_gen_dict: dictionary with the keys
            'nn2_hidden_layer_sizes' (comma-separated integers, e.g. "10,10"),
            'nn2_in_nchannels', 'nn2_out_nclasses' and 'nn2_activation'.
            If it also carries an 'NN2_model_dict' entry, that state dict is
            loaded into the new model.

    Returns:
        The ``MLPClassifier``, with weights loaded when 'NN2_model_dict' is present.

    Raises:
        ValueError: if 'nn2_hidden_layer_sizes' cannot be parsed as a
            comma-separated list of integers (including the empty string).
    """
    hid_layers = nn2_cls_gen_dict['nn2_hidden_layer_sizes'].split(",")

    # str.split never returns an empty list, so validate by parsing: an empty or
    # malformed entry fails in int() and is reported with the offending setting.
    try:
        hid_layers_num_list = [int(tok) for tok in hid_layers]
    except ValueError as err:
        raise ValueError(
            f"Invalid nn2_hidden_layer_sizes : {nn2_cls_gen_dict['nn2_hidden_layer_sizes']}"
        ) from err
    logging.info(f"hid_layers_num_list: {hid_layers_num_list}")

    model0 = MLPClassifier(
        nn2_cls_gen_dict['nn2_in_nchannels'],
        hid_layers_num_list,
        nn2_cls_gen_dict['nn2_out_nclasses'],
        nn2_cls_gen_dict["nn2_activation"]
        )

    if "NN2_model_dict" in nn2_cls_gen_dict:
        logging.info("NN2: load weights from dict")
        model0.load_state_dict(nn2_cls_gen_dict["NN2_model_dict"])

    return model0

In [None]:
nn2_dict = load_model["nn2_MLP_models_class_generator"] 

In [None]:
nn2_dict

In [None]:
NN2_model_fusion = create_nn2_ptmodel_from_class_generator(load_model["nn2_MLP_models_class_generator"] )

In [None]:
NN2_model_fusion.load_state_dict(load_model["NN2_model_dict"])

In [None]:
NN2_model_fusion

More global settings

In [None]:
nn1_axes_to_models_indices = load_model["nn1_axes_to_models_indices"]
nn1_axes_to_models_indices

In [None]:
device_str = cuda_str

In [None]:
nn1_batch_size=2 # does not load but needed

TODO: Run predictions using these models

Collect functions from header "Predict volume(s) using the models"

In [None]:
import h5py
def _save_pred_data(folder, data, count,axis, rot):
    # Saves predicted data to h5 file in tempdir and return file path in case it is needed
    file_path = f"{folder}/pred_{count}_{axis}_{rot}.h5"

    logging.info(f"Saving data of shape {data.shape} to {file_path}.")
    with h5py.File(file_path, "w") as f:
        f.create_dataset("/data", data=data)

    return file_path

class VolumeSlicerDataset(Dataset):
    """Dataset that serves the 2D slices of a 3D volume along one axis.

    Item ``idx`` is the slice at position ``idx`` of ``axis``, optionally passed
    through ``per_slice_tfms``, then converted to a float tensor with a leading
    singleton dimension and moved to ``device_str``.

    Args:
        datavol: 3D array to slice (``ndim`` must be 3).
        axis: slicing axis, 0, 1 or 2.
        per_slice_tfms: optional callable applied to each slice before conversion.
        device_str: torch device the returned tensors are sent to.
    """

    def __init__(self, datavol, axis, per_slice_tfms=None, device_str="cuda:0"):
        assert datavol.ndim==3
        assert axis in (0, 1, 2)

        self.datavol=datavol
        self.axis=axis
        self.per_slice_tfms=per_slice_tfms
        self.device_str = device_str

    def __len__(self):
        # One item per position along the slicing axis
        return self.datavol.shape[self.axis]

    def __getitem__(self, idx):
        # Index with idx on the slicing axis and a full slice on the other two
        selector = tuple(idx if ax==self.axis else slice(None) for ax in range(3))
        plane = self.datavol[selector]

        if self.per_slice_tfms is not None:
            plane = self.per_slice_tfms(plane)

        # Add the leading singleton dimension, cast to float and send to the device
        plane_t = torch.from_numpy(plane)
        return torch.unsqueeze(plane_t, dim=0).float().to(self.device_str)



In [None]:
def nn1_predict_slices_along_axis_1(datavol, axis):
    """Predict class probabilities and labels for every slice of a volume along ``axis``.

    Uses the notebook-level settings ``device_str``, ``nn1_batch_size``,
    ``nn1_axes_to_models_indices`` and ``NN1_models``: the model for this axis is
    looked up through ``nn1_axes_to_models_indices[axis]``.

    Args:
        datavol: 3D volume to slice.
        axis: slicing axis (0, 1 or 2).

    Returns:
        ``(pred_oriented, labels_oriented)`` as numpy arrays. ``pred_oriented``
        holds softmax probabilities with the class dimension first, followed by
        the volume axes in their original order. ``labels_oriented`` is the uint8
        argmax over classes, in the volume's axis order.
    """
    ds0 = VolumeSlicerDataset(datavol, axis , per_slice_tfms=None, device_str=device_str)
    dl0 = DataLoader(dataset=ds0, batch_size=nn1_batch_size, shuffle=False)

    # Get correct model
    model_index = nn1_axes_to_models_indices[axis]
    model = NN1_models[model_index]
    logging.info(f"axis:{axis}, use model_index: {model_index}")

    model.eval()

    SM_func = torch.nn.Softmax(dim=1)

    preds_list = []
    labels_list = []
    # Inference only: no_grad avoids building an autograd graph for every batch,
    # which would otherwise cost memory on the device for no benefit.
    with torch.no_grad():
        for x in dl0:
            # x is a batch of slices, each with a leading singleton channel dimension
            logits = model(x)

            pred_probs_slice = SM_func(logits) #Convert to probabilities (over dim 1, the class dim)

            # get labels using argmax
            lbl_slice = torch.argmax(pred_probs_slice, dim=1)

            # need to move away from device, otherwise it uses too much VRAM
            preds_list.append(pred_probs_slice.cpu().numpy())
            labels_list.append(lbl_slice.cpu().numpy().astype(np.uint8))

    logging.info("Prediction of all slices complete. Now stacking and getting the right orientation.")
    # stack slices: the slice index becomes the first dimension, the class dim the second
    preds_list_conc = np.concatenate(preds_list, axis=0)
    labels_pred_conc = np.concatenate(labels_list, axis=0)

    # Put the class dim first and move the slice index back to the position of `axis`
    pred_oriented = None
    labels_oriented = None
    if axis==0:
        pred_oriented = np.transpose(preds_list_conc, axes=(1,0,2,3))
        labels_oriented = labels_pred_conc # no need to orient
    elif axis==1:
        pred_oriented = np.transpose(preds_list_conc, axes=(1,2,0,3))
        labels_oriented = np.transpose(labels_pred_conc, axes=(1,0,2))
    elif axis==2:
        pred_oriented = np.transpose(preds_list_conc, axes=(1,2,3,0))
        labels_oriented = np.transpose(labels_pred_conc, axes=(1,2,0))

    #with pred_oriented note that class probability is at the start
    return pred_oriented, labels_oriented

In [None]:
def predict_NN1(data_to_predict_l, path_out_results):
    """Run NN1 on every volume for all combinations of 4 rotations and 3 slicing planes.

    For each volume and each rotation (0, 90, 180, 270 degrees) the volume is rotated
    in the YX, ZX and ZY planes in turn, predicted slice by slice with
    ``nn1_predict_slices_along_axis_1``, rotated back, and saved with
    ``_save_pred_data`` (one file for probabilities, one for labels).

    Args:
        data_to_predict_l: iterable of 3D volumes.
        path_out_results: folder where the prediction files are written.

    Returns:
        A pandas DataFrame with one row per prediction and the columns
        'pred_data_probs_filenames', 'pred_data_labels_filenames', 'pred_sets',
        'pred_planes', 'pred_rots', 'pred_ipred' and 'pred_shapes'.
    """
    # (plane tag, slicing axis, name of the slicing axis, rotation plane in the volume)
    plane_specs = (
        ("YX", 0, "Z", (1,2)),
        ("ZX", 1, "Y", (0,2)),
        ("ZY", 2, "X", (0,1)),
    )

    pred_data_probs_filenames=[] #Will store results in files, and keep the filenames as reference
    pred_data_labels_filenames=[]
    pred_sets=[]
    pred_planes=[]
    pred_rots=[]
    pred_ipred=[]
    pred_shapes=[]
    # NOTE(review): itag keeps counting across sets, so with more than one volume
    # pred_ipred exceeds 11 - confirm this is what the NN2 aggregation expects.
    itag=0

    for iset, data_to_predict in enumerate(data_to_predict_l):
        logging.info(f"Data to predict iset:{iset}")

        for krot in range(0, 4): #Around axis rotations
            rot_angle_degrees = krot * 90
            logging.info(f"Volume to be rotated by {rot_angle_degrees} degrees")

            for plane, axis, along, rot_axes in plane_specs:
                logging.info(f"Predicting {plane} slices, along {along}")
                data_vol = np.array(np.rot90(data_to_predict,krot, axes=rot_axes)) #rotate

                prob0,lab0 = nn1_predict_slices_along_axis_1(data_vol, axis)

                # Invert the rotation before saving. Probabilities carry the class
                # dimension first, so their spatial axes are shifted by one.
                prob_axes = (rot_axes[0]+1, rot_axes[1]+1)
                pred_probs = np.rot90(prob0, -krot, axes=prob_axes)
                pred_labels = np.rot90(lab0, -krot, axes=rot_axes)

                fn = _save_pred_data(path_out_results,pred_probs, iset, plane, rot_angle_degrees)
                pred_data_probs_filenames.append(fn)
                fn = _save_pred_data(path_out_results,pred_labels, iset, f"{plane}_labels", rot_angle_degrees)
                pred_data_labels_filenames.append(fn)

                pred_sets.append(iset)
                pred_planes.append(plane)
                pred_rots.append(rot_angle_degrees)
                pred_ipred.append(itag)
                pred_shapes.append(pred_labels.shape)
                itag+=1

    all_pred_pd = pd.DataFrame({
        'pred_data_probs_filenames': pred_data_probs_filenames,
        'pred_data_labels_filenames': pred_data_labels_filenames,
        'pred_sets':pred_sets,
        'pred_planes':pred_planes,
        'pred_rots':pred_rots,
        'pred_ipred':pred_ipred,
        'pred_shapes': pred_shapes,
    })

    return all_pred_pd

load validation data which will be used to test predictions

In [None]:
# Forward slashes work on Windows as well as POSIX. The previous "test_data\T..."
# literals relied on "\T" being an unrecognised escape sequence (a warning in
# recent Python versions) and only resolved to a valid path on Windows.
val_data = tifffile.imread("test_data/TS_0005_crop_val.tif")
val_labels_gnd = tifffile.imread("test_data/TS_0005_ribos_membr_crop_val.tif")

# The prediction functions take a list of volumes
val_data_l = [val_data]

In [None]:
def correct_data(d0):
    """Normalise a volume to uint8 using mean / standard-deviation clipping.

    The data is standardised (subtract mean, divide by standard deviation),
    clipped to [-3, 3] standard deviations, mapped linearly to [0, 1] and
    finally scaled to 0..255 and truncated to uint8.

    Args:
        d0: array to normalise.

    Returns:
        uint8 array with the same shape as ``d0``.

    Raises:
        ValueError: if the standard deviation of ``d0`` is zero.
    """
    mean_val = np.mean(d0)
    std_val = np.std(d0)

    if std_val==0:
        raise ValueError("Error. Stdev of data volume is zero.")

    zscore = (d0.astype(np.float32) - mean_val) / std_val
    # Keep +-3 sigma and shift/scale that range onto [0, 1]
    clipped = np.clip(zscore, -3.0, 3.0)
    unit_range = (clipped + 3.0) / 6.0

    return (unit_range*255).astype(np.uint8)

In [None]:
val_data_l_corr = [ correct_data(d) for d in val_data_l]

Create tempdir here. If created inside the function it will be deleted after returning

In [None]:
tempdir_pred= tempfile.TemporaryDirectory()
path_out_results = Path(tempdir_pred.name)
logging.info(f"tempdir_pred_path:{path_out_results}")

In [None]:
res_pd = predict_NN1(val_data_l_corr, path_out_results)

In [None]:
res_pd

In [None]:
res_pd["pred_ipred"].max()

In [None]:
from tqdm import tqdm

Predict NN2 from pandas df

In [None]:
def aggregate_data_from_pd(all_pred_pd):
    """Load all NN1 probability files listed in the table into one array.

    Args:
        all_pred_pd: DataFrame with the columns 'pred_data_probs_filenames'
            (HDF5 files holding a "data" dataset), 'pred_sets' and 'pred_ipred'.

    Returns:
        Array of shape ``(max(pred_sets)+1, max(pred_ipred)+1, *file_data_shape)``
        with the dtype of the first file. The data of each row is stored at
        ``[pred_sets, pred_ipred]``; slots not named by any row stay zero.
        Returns None if the table has no rows.
    """
    data_all_np5d=None

    logging.debug("Aggregating multiple sets onto a single volume data_all_np5d")
    # aggregate multiple sets for data
    for _, prow in all_pred_pd.iterrows():

        prob_filename = prow['pred_data_probs_filenames']
        with h5py.File(prob_filename,'r') as f:
            data0 = np.array(f["data"])

        # Allocate on the first row processed. Testing for None instead of an index
        # label of 0 keeps this working for filtered or re-indexed tables.
        if data_all_np5d is None:
            logging.info(f"filename:{prob_filename} , shape:{data0.shape}")
            all_shape0 = (
                int(all_pred_pd["pred_sets"].max())+1,
                int(all_pred_pd["pred_ipred"].max())+1,
                *data0.shape
                )

            data_all_np5d=np.zeros( all_shape0 , dtype=data0.dtype)

        ipred=int(prow['pred_ipred'])
        iset=int(prow['pred_sets'])

        data_all_np5d[iset,ipred, :,:,:, :] = data0

    return data_all_np5d

def NN2_predict_from_pd(all_pred_pd, device_str="cpu"):
    """Fuse the NN1 predictions listed in the table into one label volume per set.

    The probability files are gathered with ``aggregate_data_from_pd``. For each
    set, the two leading dimensions of its array are flattened into one feature
    vector per voxel, which the notebook-level ``NN2_model_fusion`` classifies in
    batches of 4096. The argmax over the model outputs is the voxel label.

    Args:
        all_pred_pd: table of NN1 predictions, as accepted by ``aggregate_data_from_pd``.
        device_str: torch device used for the model and the data.

    Returns:
        List with one numpy label volume per set.
    """
    #Collect all data and put it in a very large array
    stacked_probs = aggregate_data_from_pd(all_pred_pd)
    logging.info(f"data_all_np5d.shape: {stacked_probs.shape}")

    nsets = stacked_probs.shape[0]
    logging.info(f"nsets: {nsets}")

    nn2_preds = []
    for iset, set_probs in enumerate(stacked_probs):
        s = set_probs.shape
        # (features, voxels) -> (voxels, features): one row per voxel for the MLP
        voxel_features = set_probs.reshape( (s[0]*s[1], np.prod(s[2:])) ).transpose((1,0))
        voxel_features_t = torch.from_numpy(voxel_features).float().to(device_str)
        batches = DataLoader(TensorDataset(voxel_features_t), batch_size=4096, shuffle=False)

        NN2_model_fusion.to(device_str)
        NN2_model_fusion.eval()
        with torch.no_grad():
            logging.info("Beggining NN2 inference of whole volume")
            batch_labels = [
                torch.argmax(NN2_model_fusion(batch[0]), dim=1)
                for batch in tqdm(batches)
            ]

        # Back to a volume with the spatial shape of the input data
        flat_labels = torch.cat(batch_labels)
        label_vol = flat_labels.detach().cpu().numpy().reshape(*s[2:])
        logging.info(f"iset:{iset}, nn2 prediction shape:{label_vol.shape}")

        nn2_preds.append(label_vol)

    return nn2_preds

Run the NN2 prediction on the resulting pandas DataFrame

In [None]:
nn2_preds = NN2_predict_from_pd(res_pd)

In [None]:
nn2_preds[0]

In [None]:
import napari
NV=napari.Viewer()
NV.add_image(val_data)
NV.add_labels(val_labels_gnd)
NV.add_labels(nn2_preds[0].astype(np.uint16))

# Test segmentor2 train_nn2

In [1]:
import leopardgecko.segmentor2 as lgs2
import numpy as np

import logging
#logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s — %(name)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s",
        )

In [2]:
lgs2.nn1_train_epochs=2 # debug low number
lgs2.nn2_train_epochs=2


lgs2.nn2_MLP_model_class_generator= lgs2.nn2_MLP_model_class_generator_default

lgs2.nn2_ntrain = 8

lgs2.update_nn2_model_from_generator()

2024-07-01 01:32:12,237 — root — INFO — update_nn2_model_from_generator:979 — update_NN2_model_from_generator()
2024-07-01 01:32:12,237 — root — INFO — create_nn2_ptmodel_from_class_generator:946 — create_nn2_ptmodel_from_class_generator()
2024-07-01 01:32:12,247 — root — INFO — create_nn2_ptmodel_from_class_generator:954 — hid_layers_num_list: [10, 10]


Create pretend input for nn2 train

In [3]:
inp_data = np.random.random( (2, 12, 3, 64,64,64)) # 2 sets, 12 ways, 3 classes

inp_labels = np.random.randint(0,3,size=(2,64,64,64))

In [4]:
lgs2.train_nn2(inp_data, inp_labels)

2024-07-01 01:32:12,554 — root — INFO — train_nn2:1074 — NN2_train()
2024-07-01 01:32:12,555 — root — INFO — train_nn2:1075 — data_all_np5d.shape:(2, 12, 3, 64, 64, 64), len(trainlabels_list): 2
2024-07-01 01:32:12,556 — root — INFO — train_nn2:1099 — Selecting only nn2_ntrain voxel coordinates from data and ground truth for training
2024-07-01 01:32:12,559 — root — INFO — train_nn2:1159 — X_train_subset_t and y_train_subset_t created
2024-07-01 01:32:12,559 — root — INFO — train_nn2:1164 — dataset_X_y_train created
2024-07-01 01:32:12,559 — root — INFO — train_nn2:1168 — Creating test dataset
2024-07-01 01:32:12,565 — root — INFO — train_nn2:1173 — X_test_subset_t and y_test_subset_t created
2024-07-01 01:32:12,565 — root — INFO — train_nn2:1178 — dataset_X_y_test created
2024-07-01 01:32:12,568 — root — INFO — train_nn2:1233 — Beggining training NN2.
2024-07-01 01:32:12,570 — root — INFO — train_model:531 — train_model()
2024-07-01 01:32:12,570 — root — INFO — train_model:534 — ---- 