Connected to vx (Python 3.9.16)

In [1]:
import torch
import torchvision.transforms as transforms
import tifffile
import stackview
import numpy as np

In [2]:

# Load an image stack, rescale it to the uint8 range [0, 255] and visualize it

image_path = "20170928CH_Exp8_M28.mt2.tif"  # e.g. a stacked OME-TIFF
image_raw = tifffile.imread(image_path)

# Scale by the stack maximum so that the brightest voxel maps to 255.
# An all-zero stack has a maximum of 0, which would turn the division into
# NaNs; fall back to a divisor of 1 so such a stack simply stays at zero.
image_peak = image_raw.max()
if image_peak == 0:
    image_peak = 1
image_uint8 = (image_raw / image_peak * 255).astype(np.uint8)

# float32 tensor holding the rescaled values; built with from_numpy instead of
# the legacy torch.Tensor(...) constructor, and without rebinding the raw array.
image = torch.from_numpy(image_uint8).float()

# Normalization constants: (x - 127.5) / 127.5 maps [0, 255] onto [-1, 1]
mean = 127.5
std = 127.5

stackview.slice(image, continuous_update=True)

VBox(children=(HBox(children=(VBox(children=(ImageWidget(height=256, width=256),)),)), IntSlider(value=16, des…

In [3]:
def center_crop_z(image: torch.Tensor, depth: int = 30) -> torch.Tensor:
    """Center-crop an image stack along its first (z) dimension.

    Args:
        image: Stack whose first dimension indexes the z-slices.
        depth: Number of consecutive slices to keep. Defaults to 30.

    Returns:
        The ``depth`` central slices of ``image``; the remaining dimensions
        are left untouched.

    Raises:
        ValueError: If ``depth`` is not positive, or if the stack has fewer
            than ``depth`` slices.
    """
    if depth <= 0:
        raise ValueError(f"depth must be positive, got {depth}")

    shape_z = image.shape[0]
    if shape_z < depth:
        # Without this guard the start index goes negative and the slice
        # silently returns the wrong (and too few) slices.
        raise ValueError(
            f"cannot crop {depth} slices from a stack with only {shape_z} slices"
        )

    start_z = shape_z // 2 - depth // 2
    return image[start_z:start_z + depth]

In [4]:
# Normalize and reshape/crop the data for input into the model

# Expected input shape for network: 196, 152, 30
CROP_SIZE_YX = (196, 152)  # in-plane crop passed to CenterCrop as (height, width)
CROP_DEPTH = 30  # number of z-slices that center_crop_z keeps

image_transforms = transforms.Compose(
    [
        transforms.Normalize(mean, std),
        transforms.CenterCrop(size=CROP_SIZE_YX),
        # center_crop_z takes the stack as its only required argument, so it
        # can be handed to Lambda directly without a wrapper lambda.
        transforms.Lambda(center_crop_z),
    ]
)
image_normalized = image_transforms(image)

# Fail here rather than inside the model if the stack was too small to yield
# the expected crop (e.g. fewer z-slices than CROP_DEPTH).
assert tuple(image_normalized.shape) == (CROP_DEPTH, *CROP_SIZE_YX), (
    f"unexpected shape after cropping: {tuple(image_normalized.shape)}"
)

stackview.slice(image_normalized, continuous_update=True)

VBox(children=(HBox(children=(VBox(children=(ImageWidget(height=196, width=152),)),)), IntSlider(value=15, des…

In [5]:
# Load model and run inference

# map_location keeps the model on the CPU, matching the CPU input batch and the
# later conversion of the probabilities to NumPy.
model = torch.jit.load("lesion_model.pt", map_location="cpu")
# Put layers such as dropout / batch norm (if the model has any) in inference mode.
model.eval()

# Reverse the axis order of the stack, then prepend batch and channel
# dimensions of size 1.
# NOTE(review): for a (30, 196, 152) stack this yields (1, 1, 152, 196, 30) --
# confirm that this is the axis order the model expects.
batch = image_normalized.permute(2, 1, 0)[None, None, :]

# No gradients are needed for inference; skip building the autograd graph.
with torch.no_grad():
    logits = model(batch)

# Drop the batch dimension and any remaining size-1 dimensions, then turn the
# logits into probabilities along the first remaining dimension.
# NOTE(review): dim 0 is presumably the class dimension -- confirm.
softmax = torch.nn.Softmax(dim=0)
probabilities = softmax(logits["segment_type_task"][0].squeeze())

In [6]:
# Turn the per-class probabilities into binary masks

BACKGROUND_LABEL = 0
LESION_LABEL = 1
THRESHOLD = 0.8

probabilities_numpy = probabilities.detach().numpy()

# One probability volume per class, selected along the first axis
background_probabilities = probabilities_numpy[BACKGROUND_LABEL, ...]
lesion_probabilities = probabilities_numpy[LESION_LABEL, ...]

# A voxel is set to 1 only where its class probability exceeds THRESHOLD;
# the two masks are thresholded independently of each other.
background_mask = (background_probabilities > THRESHOLD).astype(np.int_)
lesion_mask = (lesion_probabilities > THRESHOLD).astype(np.int_)

In [7]:
# Visualize the results

# Zero-pad the masks (21 voxels on each side of the first two axes, 1 on the
# third) and reverse the axis order so that they line up with image_normalized.
# NOTE(review): presumably this undoes the size reduction of the model output
# relative to its input -- confirm against the model.
PAD_WIDTH = ((21, 21), (21, 21), (1, 1))

lesion_image = np.pad(lesion_mask, PAD_WIDTH).transpose(2, 1, 0)
background_image = np.pad(background_mask, PAD_WIDTH).transpose(2, 1, 0)

# Convert the tensor once so the overlay is plain NumPy arithmetic instead of
# relying on implicit conversion between a torch tensor and a NumPy array.
raw_image = image_normalized.detach().cpu().numpy()

stackview.switch(
    {
        "raw+lesion": 0.7 * raw_image + 0.3 * lesion_image,
        "raw": raw_image,
        "lesion": lesion_image,
        "background": background_image,
    },
    zoom_factor=2,
)

VBox(children=(HBox(children=(VBox(children=(ImageWidget(height=392, width=304),)),)), HBox(children=(Button(d…