In [2]:
import numpy as np
import matplotlib.pyplot as plt
import os
from cellpose import models, io, utils

# Load the input image.
img_path = 'data/example-CXA-wrongfull-comparison.png'
# Fail fast with a clear message; otherwise a missing file only surfaces
# later as an obscure error inside the model call.
if not os.path.isfile(img_path):
    raise FileNotFoundError(f'Input image not found: {img_path}')
img = io.imread(img_path)
# NOTE(review): the image reader may log and return None instead of raising
# on an unreadable file -- guard so the failure is reported here.
if img is None:
    raise ValueError(f'Could not read image: {img_path}')

# Define the model: the pretrained 'cyto' Cellpose model, requesting the GPU.
# NOTE(review): presumably falls back to CPU when no GPU is available --
# confirm for the installed Cellpose version.
model = models.Cellpose(gpu=True, model_type='cyto')

# Segment the image; returns masks, flows, styles and diameter estimate(s).
# diameter=None presumably asks Cellpose to estimate the cell diameter.
# NOTE(review): flow_threshold=None appears to disable Cellpose's flow-error
# quality filter -- confirm this is intentional for this comparison.
masks, flows, styles, diams = model.eval(img, diameter=None, flow_threshold=None)


In [4]:
# Define the results directory and ensure it exists.
# exist_ok=True makes this idempotent on re-run and avoids the
# check-then-create race of os.path.exists + os.makedirs.
results_dir = 'results/cellpose_wrongful'
os.makedirs(results_dir, exist_ok=True)

# Define output filenames, all derived from the input image's stem.
# splitext strips only the final extension; str.replace('.png', '') would
# remove '.png' anywhere in the name and keep any other extension intact.
base_filename = os.path.splitext(os.path.basename(img_path))[0]
original_save_path = os.path.join(results_dir, f'{base_filename}_original.png')
mask_save_path = os.path.join(results_dir, f'{base_filename}_mask.png')
outline_save_path = os.path.join(results_dir, f'{base_filename}_outlines.png')
cellpose_save_path = os.path.join(results_dir, f'{base_filename}_cellpose.png')

# Save the input image unchanged.
io.imsave(original_save_path, img)

# Save the predicted label mask (presumably 0 = background, one integer id
# per cell -- confirm against Cellpose's mask convention).
# PNG stores at most 16 bits per channel, so write the labels as uint16 and
# fail loudly if they would not fit, rather than letting them be truncated.
max_label = int(masks.max(initial=0))
if max_label > np.iinfo(np.uint16).max:
    raise ValueError(
        f'{max_label} labels do not fit in a 16-bit PNG; '
        'save the mask as .tif instead.'
    )
io.imsave(mask_save_path, masks.astype(np.uint16))

# Get and save predicted outlines as green (0, 255, 0) pixels on black.
outlines = utils.masks_to_outlines(masks)
# Build an explicit (H, W, 3) uint8 canvas rather than np.zeros_like(img):
# for a grayscale (2-D) input, zeros_like is 2-D and assigning a 3-element
# colour to its boolean selection raises a shape-mismatch error.
outline_img = np.zeros((*outlines.shape, 3), dtype=np.uint8)
outline_img[outlines] = [0, 255, 0]
io.imsave(outline_save_path, outline_img)

# Save a visualization of the Cellpose result: the input image with the
# label mask blended on top at 50% opacity.
dpi = 100
height, width = img.shape[:2]
# Size the figure in inches so that, at this dpi, it spans roughly one
# output pixel per image pixel.
fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
try:
    # cmap only affects 2-D (grayscale) input and is ignored for RGB(A);
    # without it a grayscale image would be rendered in false colour.
    ax.imshow(img, cmap='gray')
    # Mask out label 0 (presumably background) so only segmented pixels are
    # tinted; otherwise the colormap's lowest colour is blended over the
    # entire background. 'nearest' keeps distinct label ids from being
    # interpolated into each other.
    ax.imshow(np.ma.masked_equal(masks, 0), alpha=0.5, interpolation='nearest')
    ax.set_axis_off()
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    fig.savefig(cellpose_save_path, dpi=dpi, bbox_inches='tight', pad_inches=0)
finally:
    # Release the figure even if saving fails, so re-running the cell does
    # not accumulate open figures.
    plt.close(fig)