In [1]:
from utils import watershed_from_boundary_distance, dice_coefficient_from_instances
import os 
import numpy as np 
import napari
import torch
from skimage.io import imread
import tifffile
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np
assert torch.cuda.is_available()
from torchmetrics.classification import Dice, MulticlassAccuracy
from model3d import Unet3D

from dataset import (
    BlastoDataset
)

from model3d import Unet3D

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if device is None:
    # You can pass in a device or we will default to using
    # the gpu. Feel free to try training on the cpu to see
    # what sort of performance difference there is
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")


In [2]:
viewer = napari.Viewer()



In [7]:


model1_path = '/localscratch/DL4MIA_2024/BlastoSeg/saved_models/unet3d_model_stepsize1_best.pth'
model2_path = '/localscratch/DL4MIA_2024/BlastoSeg/saved_models/unet3d_model_stepsize2_try2_best.pth'
model5_path = '/localscratch/DL4MIA_2024/BlastoSeg/saved_models/unet3d_model_stepsize5_try2_best.pth'

model_paths = [model1_path, model2_path, model5_path]

In [9]:
for model_path in model_paths:
    model_name = os.path.basename(model_path)
    checkpoint = torch.load(model_path)
    model = Unet3D(n_classes=2)
    model.load_state_dict(checkpoint)
    model.eval()

    val_data = BlastoDataset("/group/dl4miacourse/projects/BlastoSeg/validation")
    val_loader = DataLoader(val_data, batch_size = 1, shuffle=False, num_workers=8)

    val_data_names = sorted([f.split('raw.tif')[0] for f in os.listdir("/group/dl4miacourse/projects/BlastoSeg/validation/raw/") if '.tif' in f])
    print(val_data_names)

    pred_dir = "/group/dl4miacourse/projects/BlastoSeg/presentation_files/predicted_distances"
    if not os.path.exists(pred_dir): 
        os.mkdir(pred_dir)

    step_size = 118
    count = -1
    with torch.no_grad():
        for x_batch, y_batch, _ in val_loader: 
            count += 1

            print(x_batch.shape, y_batch.shape)
            z_slices = x_batch.shape[1]
            num_iterations = int(z_slices / step_size)
            for i in range(0,num_iterations,step_size): 

                start_index = i
                end_index = min(start_index + step_size, x_batch.size(2))

                x = x_batch[:, start_index:end_index, :, :] # [1,2,256,256] BDHW
                x = torch.unsqueeze(x, 0) # [1,1,2,256,256] BCDHW

                y = y_batch[:, start_index:end_index, :, :]
                y = torch.unsqueeze(y, 0)

                prediction = model(x)  # Assuming model expects a batch dimension
        
                if y.dtype != prediction.dtype:
                    y = y.type(prediction.dtype)
                
                viewer.add_image(prediction.numpy())
                tifffile.imwrite(os.path.join(pred_dir, f"{val_data_names[count]}_{model_name}_pred_dsts.tif"), np.array(prediction.numpy(), dtype = np.float32))

['t0004_', 't0044_']
torch.Size([1, 118, 256, 256]) torch.Size([1, 118, 256, 256])
torch.Size([1, 118, 256, 256]) torch.Size([1, 118, 256, 256])
['t0004_', 't0044_']
torch.Size([1, 118, 256, 256]) torch.Size([1, 118, 256, 256])
torch.Size([1, 118, 256, 256]) torch.Size([1, 118, 256, 256])
['t0004_', 't0044_']
torch.Size([1, 118, 256, 256]) torch.Size([1, 118, 256, 256])
torch.Size([1, 118, 256, 256]) torch.Size([1, 118, 256, 256])


torch.Size([1, 118, 256, 256]) torch.Size([1, 118, 256, 256])
torch.Size([1, 118, 256, 256]) torch.Size([1, 118, 256, 256])


In [4]:
# load predictions back in in case they are no longer in memory

pred_dir = "/group/dl4miacourse/projects/BlastoSeg/presentation_files/predicted_distances"
pred_instance_dir = "/group/dl4miacourse/projects/BlastoSeg/presentation_files/predicted_instances"

if not os.path.exists(pred_instance_dir):
    os.mkdir(pred_instance_dir)

prediction_files = [f for f in os.listdir(pred_dir) if '_dst.tif' in f]

threshold = 0.1
min_seed_distance = 10
for f in prediction_files:
    pred = imread(os.path.join(pred_dir, f))
    seeds, seg = watershed_from_boundary_distance(pred, pred>threshold, min_seed_distance = min_seed_distance)
    viewer.add_image(pred)
    viewer.add_image(seeds)
    viewer.add_labels(seg)

    tifffile.imwrite(os.path.join(pred_instance_dir, f), np.array(seg, dtype = np.uint16))
   

In [5]:
# Calculate the dice coefficient for all valid labels (so no ignore index) from the validation data

val_dir = "/group/dl4miacourse/projects/BlastoSeg/validation/gt/"

gt_files = sorted([f for f in os.listdir(val_dir) if '.tif' in f])
pred_instance_files = sorted([f for f in os.listdir(pred_instance_dir) if '.tif' in f])

print(gt_files)
print(pred_instance_files)
assert len(gt_files) == len(pred_instance_files)

gt = [imread(os.path.join(val_dir, f)) for f in gt_files]
pred_inst = [imread(os.path.join(pred_instance_dir, f)) for f in pred_instance_files]

dice_coeffs = []
for gt_img, pred_img in zip(gt, pred_inst): 
    invalid_label_mask = gt_img <= 1
    valid_labels = gt_img.copy()
    valid_labels[invalid_label_mask] = 0
    
    # compute dice coefficient 
    dice = dice_coefficient_from_instances(gt_img, pred_img)

    print('mean dice coefficient between validation labels and predicted labels', np.mean(dice))
    dice_coeffs.append(dice)

print(np.mean(dice_coeffs[0]), np.mean(dice_coeffs[1]))
print('combined average dice coeff', np.mean([np.mean(d) for d in dice_coeffs]))







['t0004_gt.tif', 't0044_gt.tif']
['t0004_pred_stepsize2_inst.tif', 't0044_pred_stepsize2_inst.tif']
looking for label 9
the largest corresponding predicted label is 340
dice coefficient is 0.5165929203539823
looking for label 16
the largest corresponding predicted label is 979
dice coefficient is 0.8121987393400074
looking for label 21
the largest corresponding predicted label is 1017
dice coefficient is 0.8327794439490962
looking for label 22
the largest corresponding predicted label is 1300
dice coefficient is 0.6514120799835086
looking for label 32
the largest corresponding predicted label is 1389
dice coefficient is 0.8331533925753675
looking for label 42
the largest corresponding predicted label is 1532
dice coefficient is 0.8691717171717172
looking for label 48
the largest corresponding predicted label is 1248
dice coefficient is 0.5914260717410323
looking for label 51
the largest corresponding predicted label is 1308
dice coefficient is 0.6216904276985743
looking for label 53
th

In [6]:
print(np.mean(dice_coeffs[0]), np.mean(dice_coeffs[1]))

0.7345840522326083 0.7578012892530416
