## Profiling and improving the Nimbus Model

In [1]:
# import required packages
import warnings
warnings.simplefilter("ignore")
import os
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
from nimbus_inference.nimbus import Nimbus, prep_naming_convention
from nimbus_inference.utils import MultiplexDataset
from alpineer import io_utils
from nimbus_inference import example_dataset
from nimbus_inference.viewer_widget import NimbusViewer
from torch.profiler import profile, ProfilerActivity, record_function

In [4]:
# set up the base directory
base_dir = os.path.normpath("../data/example_dataset")

example_dataset.get_example_dataset(dataset="cluster_pixels", save_dir = base_dir, overwrite_existing = False)
# set up file paths
tiff_dir = os.path.join(base_dir, "image_data")
deepcell_output_dir = os.path.join(base_dir, "segmentation", "deepcell_output")
nimbus_output_dir = os.path.join(base_dir, "nimbus_output")

# Create nimbus output directory
os.makedirs(nimbus_output_dir, exist_ok=True)

# Check if paths exist
io_utils.validate_paths([base_dir, tiff_dir, deepcell_output_dir, nimbus_output_dir])

# define the channels to include
include_channels = [
    "CD3", "CD4", "CD8", "CD14", "CD20", "CD31", "CD45", "CD68", "CD163", "CK17", "Collagen1",
    "ECAD", "Fibronectin", "GLUT1", "HLADR", "IDO", "Ki67", "PD1", "SMA", "Vim"
]

# either get all fovs in the folder...
fov_names = os.listdir(tiff_dir)
# ... or optionally, select a specific set of fovs manually
# fovs = ["fov0", "fov1"]

# make sure to filter paths out that don't lead to FoVs, e.g. .DS_Store files.
fov_names = [fov_name for fov_name in fov_names if not fov_name.startswith(".")] 

# construct paths for fovs
fov_paths = [os.path.join(tiff_dir, fov_name) for fov_name in fov_names]

# Prepare segmentation naming convention that maps a fov_path to the according segmentation label map
segmentation_naming_convention = prep_naming_convention(deepcell_output_dir)

# test segmentation_naming_convention
if os.path.exists(segmentation_naming_convention(fov_paths[0])):
    print("Segmentation data exists for fov 0 and naming convention is correct")
else:
    print("Segmentation data does not exist for fov 0 or naming convention is incorrect")

dataset = MultiplexDataset(
    fov_paths=fov_paths,
    suffix=".tiff", # or .png, .jpg, .jpeg, .tif or .ome.tiff
    include_channels=include_channels,
    segmentation_naming_convention=segmentation_naming_convention,
    output_dir=	nimbus_output_dir,
)

Segmentation data exists for fov 0 and naming convention is correct
All inputs are valid


In [5]:
nimbus = Nimbus(
    dataset=dataset,
    save_predictions=True,
    batch_size=4,
    test_time_aug=True,
    input_shape=[1024,1024],
    device="auto",
    output_dir=nimbus_output_dir,
)

# check if all inputs are valid
nimbus.check_inputs()

Checking for updated model checkpoints on HuggingFace Hub...
Using existing checkpoint: /Users/jrumber/Desktop/Nimbus-Inference/src/nimbus_inference/assets/V1.pt
Loaded weights from /Users/jrumber/Desktop/Nimbus-Inference/src/nimbus_inference/assets/V1.pt
All inputs are valid.


In [None]:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
    with record_function("model_inference"):
        cell_table = nimbus.predict_fovs()

Available GPUs:  0
Predictions will be saved in ../data/example_dataset/nimbus_output
Iterating through fovs will take a while...
Predicting ../data/example_dataset/image_data/fov4...


  0%|          | 0/20 [00:00<?, ?it/s]

No normalization dict found. Preparing normalization dict...
Iterate over fovs...
Predicting ../data/example_dataset/image_data/fov3...


  0%|          | 0/20 [00:00<?, ?it/s]

Predicting ../data/example_dataset/image_data/fov10...


  0%|          | 0/20 [00:00<?, ?it/s]

Predicting ../data/example_dataset/image_data/fov2...


  0%|          | 0/20 [00:00<?, ?it/s]

Predicting ../data/example_dataset/image_data/fov5...


  0%|          | 0/20 [00:00<?, ?it/s]

Predicting ../data/example_dataset/image_data/fov0...


  0%|          | 0/20 [00:00<?, ?it/s]

Predicting ../data/example_dataset/image_data/fov7...


  0%|          | 0/20 [00:00<?, ?it/s]

Predicting ../data/example_dataset/image_data/fov9...


  0%|          | 0/20 [00:00<?, ?it/s]

Predicting ../data/example_dataset/image_data/fov8...


  0%|          | 0/20 [00:00<?, ?it/s]

Predicting ../data/example_dataset/image_data/fov6...


  0%|          | 0/20 [00:00<?, ?it/s]

Predicting ../data/example_dataset/image_data/fov1...


  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     model_inference        18.25%       22.526s       100.00%      123.435s      123.435s             1  
                            aten::to         0.01%       7.177ms        74.29%       91.698s      11.265ms          8140  
                      aten::_to_copy         0.02%      18.601ms        74.28%       91.690s      21.936ms          4180  
                         aten::copy_        74.25%       91.653s        74.26%       91.659s      21.928ms          4180  
                   aten::convolution         0.05%      61.318ms         2.56%        3.162s      74.852us         42240  
                

In [None]:
prof.export_chrome_trace("trace.json")

ERROR:tornado.general:Uncaught exception in ZMQStream callback
Traceback (most recent call last):
  File "/Users/jrumber/.local/lib/python3.11/site-packages/zmq/eventloop/zmqstream.py", line 565, in _log_error
    f.result()
  File "/Users/jrumber/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 302, in dispatch_control
    await self.process_control(msg)
  File "/Users/jrumber/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 308, in process_control
    idents, msg = self.session.feed_identities(msg, copy=False)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jrumber/.local/lib/python3.11/site-packages/jupyter_client/session.py", line 994, in feed_identities
    raise ValueError(msg)
ValueError: DELIM not in msg_list
ERROR:tornado.general:Uncaught exception in ZMQStream callback
Traceback (most recent call last):
  File "/Users/jrumber/.local/lib/python3.11/site-packages/zmq/eventloop/zmqstream.py", line 565, in _log_er