In [2]:
import sys
sys.path.append('../src')
import os
import glob
import json
import torch
from monai.networks.nets import UNet as monai_unet
from monai.networks.nets import AttentionUnet as monai_att
from semantic_bac_segment.models.pytorch_attention import AttentionUNet as base_att
from monai.networks.nets import UNETR as monai_unetr
from semantic_bac_segment.models.pytorch_cnnunet import Unet as atomai_unet
from semantic_bac_segment.segmentator import Segmentator3
from tifffile import imread

# Set the paths to the results folder and the source images
results_folder = '../results/'
source_images_folder = '../data/source_norm/'

# Initialize an empty list to store the prediction stacks
prediction_stacks = []

images=glob.glob(os.path.join(results_folder, '*.tiff'))

In [15]:
import torch
from semantic_bac_segment.utils import get_device, normalize_percentile, empty_gpu_cache
from monai.inferers import SlidingWindowInferer
import gc

class Segmentator4:
    """
    A class representing a segmentation model.

    Attributes:
        model (torch.nn.Module): The segmentation model.
        device (torch.device): The device on which the model is loaded.
    """

    def __init__(self, model_path, model_graph, patch_size, overlap_ratio, half_precision=False):
        """
        Initializes a Segmentator object.

        Args:
            model_path (str): The path to the model.
            model_graph (str): The model graph.
            patch_size (int): The size of the patches.
            overlap_ratio (float): The overlap ratio between patches.
        """
        self.device = get_device()
        self.model = self.get_model(model_path, self.device, model_graph=model_graph)
        self.patch_size = patch_size
        self.overlap_ratio = overlap_ratio
        self.model.eval()
        self.half_precision=half_precision
        if self.half_precision:
            self.model.half()  


    def predict(self, image):
        """
        Predicts the segmentation mask for the given image. It can handle 2D images or a stack of 2D images.

        Args:
            image (numpy.ndarray): The input image or image stack.

        Returns:
            numpy.ndarray: The segmentation mask or stack of segmentation masks.
        """
        
        # Normalize image
        image = normalize_percentile(image)

        # Check if the image is a stack
        if len(image.shape) > 2:
            # Store the original image size and number of slices
            original_size = image.shape[1:]
            num_slices = image.shape[0]

            # Convert the entire stack to a tensor and add batch dimension
            img_tensor = torch.from_numpy(image).unsqueeze(1).to(self.device)  # Shape: [1, C, H, W]
            if self.half_precision:
                img_tensor = img_tensor.half()  # Convert input to half-precision

            # Create SlidingWindowInferer
            inferer = SlidingWindowInferer(roi_size=self.patch_size, sw_batch_size=1, overlap=self.overlap_ratio)

            with torch.no_grad():
                output_mask = inferer(img_tensor, self.model)

            output_mask = output_mask.cpu().numpy()  # Remove batch dimension

            # Split the result back into the original chunks
            #output_mask = np.dsplit(output_mask, num_slices)

        else:
            # Process the image as before
            img_tensor = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).to(self.device)

            if self.half_precision:
                img_tensor = img_tensor.half()  # Convert input to half-precision

            # Create SlidingWindowInferer
            inferer = SlidingWindowInferer(roi_size=self.patch_size, sw_batch_size=350, overlap=self.overlap_ratio)

            with torch.no_grad():
                output_mask = inferer(img_tensor, self.model)

            output_mask = output_mask.cpu().numpy()

            # Free up tensors
#            del img_tensor, image  
#            gc.collect() 
#            empty_gpu_cache(self.device)
        print(output_mask.shape)
        return output_mask
   

    def get_model(self, path, device, model_graph=None):
        """
        Loads a model from the specified path and returns it.

        Args:
            path (str): The path to the model file.
            device (str): The device to load the model onto.
            model_graph (Optional[torch.nn.Module]): An optional pre-initialized model graph.

        Returns:
            torch.nn.Module: The loaded model.

        Raises:
            FileNotFoundError: If the model file is not found at the specified path.
            RuntimeError: If an error occurs while loading the model.
            Exception: If an unexpected error occurs.
        """
        try:
            if model_graph is None:
                model = torch.load(path, map_location=device)
            else:
                model = model_graph
                state_dict = torch.load(path, map_location=device)
                
                # Check if the loaded state dictionary is compatible with the model architecture
                if not set(state_dict.keys()).issubset(set(model.state_dict().keys())):
                    raise ValueError("Loaded state dictionary does not match the model architecture.")
                
                model.load_state_dict(state_dict)
            
            model.to(device)
            torch.compile(model, mode = 'max-autotune')
            
            return model
        
        except FileNotFoundError:
            raise FileNotFoundError(f"Model file not found at path: {path}")
        
        except RuntimeError as e:
            raise RuntimeError(f"Error occurred while loading the model: {str(e)}")
        
        except Exception as e:
            raise Exception(f"Unexpected error occurred: {str(e)}")

In [16]:
models_dict = []

for file in os.listdir(results_folder):
    if file.endswith('_model.pth'):
        model_path = os.path.join(results_folder, file)
        config_path = os.path.join(results_folder, file.replace('_model.pth', '_config.json'))

        # Check if the corresponding config file exists
        if os.path.exists(config_path):
            # Add the model and its configuration to the dictionary
            models_dict.append((model_path, config_path))
        else:
            print(f"Warning: Config file not found for model {file}")

In [17]:
models_dict

[('../results/unet_monai_best_model.pth',
  '../results/unet_monai_best_config.json'),
 ('../results/AttentionUNet-2_best_model.pth',
  '../results/AttentionUNet-2_best_config.json'),
 ('../results/MonaiUnet-1_best_model.pth',
  '../results/MonaiUnet-1_best_config.json'),
 ('../results/MonaiUnet-2_final_model.pth',
  '../results/MonaiUnet-2_final_config.json'),
 ('../results/atomai_unet-8_best_model.pth',
  '../results/atomai_unet-8_best_config.json'),
 ('../results/MonaiUnet-2_best_model.pth',
  '../results/MonaiUnet-2_best_config.json'),
 ('../results/unet_monai_best2_model.pth',
  '../results/unet_monai_best2_config.json'),
 ('../results/unet_model_best-binary2-channel0_model.pth',
  '../results/unet_model_best-binary2-channel0_config.json'),
 ('../results/MonaiUnet-1_final_model.pth',
  '../results/MonaiUnet-1_final_config.json'),
 ('../results/AHNet-8_best_model.pth', '../results/AHNet-8_best_config.json')]

In [18]:
for model_path, config_file in models_dict:
    print(model_path)
    print(config_file)

../results/unet_monai_best_model.pth
../results/unet_monai_best_config.json
../results/AttentionUNet-2_best_model.pth
../results/AttentionUNet-2_best_config.json
../results/MonaiUnet-1_best_model.pth
../results/MonaiUnet-1_best_config.json
../results/MonaiUnet-2_final_model.pth
../results/MonaiUnet-2_final_config.json
../results/atomai_unet-8_best_model.pth
../results/atomai_unet-8_best_config.json
../results/MonaiUnet-2_best_model.pth
../results/MonaiUnet-2_best_config.json
../results/unet_monai_best2_model.pth
../results/unet_monai_best2_config.json
../results/unet_model_best-binary2-channel0_model.pth
../results/unet_model_best-binary2-channel0_config.json
../results/MonaiUnet-1_final_model.pth
../results/MonaiUnet-1_final_config.json
../results/AHNet-8_best_model.pth
../results/AHNet-8_best_config.json


In [19]:
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))


In [23]:
from  tqdm.auto import tqdm 

# Iterate over each pair of model/config in the results folder
for model_path, config_path in tqdm(models_dict):
        print(f'Processing {model_path} with config {config_path}')
        # Load the model configuration
        with open(config_path, 'r') as f:
            config = json.load(f)

        # Determine the net architecture based on the model name
        if 'atomai' in model_path:
            print(f'Using atomai_unet architecture for {model_path}')
            net_architecture = atomai_unet
        elif 'Attention' in model_path:
            print(f'Using base_att architecture for {model_path}')
            net_architecture = base_att
        elif 'AHNet' in model_path:
            print(f'Using monai_unetr architecture for {model_path}')
            net_architecture = monai_unetr
        else:
            print(f'Using monai_unet architecture for {model_path}')
            net_architecture = monai_unet

        # Load the model with the configuration parameters
        #model = net_architecture(**config['model_args'])

        # Initialize the Segmentator3 with the loaded model
        segmentator = Segmentator4(model_path, net_architecture(**config['model_args']),
                                    patch_size=256, overlap_ratio=0.25, half_precision=True)
        
        # Initialize an empty list to store the predictions for the current model
        model_predictions = []

        # Iterate over each source image
        for image_file in os.listdir(source_images_folder):
            if image_file.endswith('.tiff'):
                image_path = os.path.join(source_images_folder, image_file)
                image = imread(image_path)

                # Get predictions
                prediction = segmentator.predict(image)
                prediction=sigmoid(prediction)
                model_predictions.append(prediction)

        prediction_stacks.append(model_predictions)



  0%|          | 0/10 [00:00<?, ?it/s]

Processing ../results/unet_monai_best_model.pth with config ../results/unet_monai_best_config.json
Using monai_unet architecture for ../results/unet_monai_best_model.pth
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)
(1, 3, 2400, 2400)


 10%|█         | 1/10 [00:43<06:31, 43.47s/it]

(1, 3, 2400, 2400)
Processing ../results/AttentionUNet-2_best_model.pth with config ../results/AttentionUNet-2_best_config.json
Using base_att architecture for ../results/AttentionUNet-2_best_model.pth
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)


 20%|██        | 2/10 [02:25<10:23, 77.96s/it]

(1, 1, 2400, 2400)
Processing ../results/MonaiUnet-1_best_model.pth with config ../results/MonaiUnet-1_best_config.json
Using monai_unet architecture for ../results/MonaiUnet-1_best_model.pth
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)


 30%|███       | 3/10 [03:11<07:24, 63.53s/it]

(1, 1, 2400, 2400)
Processing ../results/MonaiUnet-2_final_model.pth with config ../results/MonaiUnet-2_final_config.json
Using monai_unet architecture for ../results/MonaiUnet-2_final_model.pth
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)


 40%|████      | 4/10 [06:15<11:05, 110.99s/it]

(1, 1, 2400, 2400)
Processing ../results/atomai_unet-8_best_model.pth with config ../results/atomai_unet-8_best_config.json
Using atomai_unet architecture for ../results/atomai_unet-8_best_model.pth
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)


 50%|█████     | 5/10 [07:51<08:47, 105.49s/it]

(1, 1, 2400, 2400)
Processing ../results/MonaiUnet-2_best_model.pth with config ../results/MonaiUnet-2_best_config.json
Using monai_unet architecture for ../results/MonaiUnet-2_best_model.pth
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)


 60%|██████    | 6/10 [10:49<08:40, 130.24s/it]

(1, 1, 2400, 2400)
Processing ../results/unet_monai_best2_model.pth with config ../results/unet_monai_best2_config.json
Using monai_unet architecture for ../results/unet_monai_best2_model.pth
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)


 70%|███████   | 7/10 [11:34<05:07, 102.33s/it]

(1, 1, 2400, 2400)
Processing ../results/unet_model_best-binary2-channel0_model.pth with config ../results/unet_model_best-binary2-channel0_config.json
Using monai_unet architecture for ../results/unet_model_best-binary2-channel0_model.pth
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)


 80%|████████  | 8/10 [12:18<02:47, 83.78s/it] 

(1, 1, 2400, 2400)
Processing ../results/MonaiUnet-1_final_model.pth with config ../results/MonaiUnet-1_final_config.json
Using monai_unet architecture for ../results/MonaiUnet-1_final_model.pth
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)
(1, 1, 2400, 2400)


 90%|█████████ | 9/10 [13:04<01:12, 72.07s/it]

(1, 1, 2400, 2400)
Processing ../results/AHNet-8_best_model.pth with config ../results/AHNet-8_best_config.json


 90%|█████████ | 9/10 [13:06<01:27, 87.34s/it]

Using monai_unetr architecture for ../results/AHNet-8_best_model.pth





TypeError: missing a required argument: 'img_size'

In [36]:
arr_list=[]
for i in prediction_stacks:
    arr_list.append(np.array(i))


  return 1 / (1 + np.exp(-x))


In [38]:
arr_list
for i in range(len(arr_list)):
    arr_list[i]=np.sum(arr_list[i], axis=2)

In [39]:
arr_list
for i in range(len(arr_list)):
    print(arr_list[i].shape)

(31, 1, 2400, 2400)
(31, 1, 2400, 2400)
(31, 1, 2400, 2400)
(31, 1, 2400, 2400)
(31, 1, 2400, 2400)
(31, 1, 2400, 2400)
(31, 1, 2400, 2400)
(31, 1, 2400, 2400)
(31, 1, 2400, 2400)
(31, 1, 2400, 2400)
(31, 1, 2400, 2400)


In [40]:
stacked_arrays = np.stack(arr_list, axis=0)

# Take the mean along the new axis (axis=0)
averaged_array = np.mean(stacked_arrays, axis=(0, 2))

In [42]:
averaged_array[0]

array([[0.1594 , 0.1362 , 0.12146, ..., 1.377  , 1.319  , 1.239  ],
       [0.1335 , 0.1042 , 0.1019 , ..., 1.403  , 1.391  , 1.309  ],
       [0.11926, 0.1024 , 0.0995 , ..., 1.412  , 1.407  , 1.332  ],
       ...,
       [0.1129 , 0.0903 , 0.087  , ..., 1.3545 , 1.331  , 1.286  ],
       [0.1129 , 0.1003 , 0.09204, ..., 1.328  , 1.305  , 1.25   ],
       [0.1339 , 0.1148 , 0.1081 , ..., 1.272  , 1.242  , 1.15   ]],
      dtype=float16)

In [43]:
import tifffile
tifffile.imwrite(os.path.join('../data/average_pred.tiff'), averaged_array)


In [49]:
image_titles = []

# Iterate over the image files and extract their names as titles
for image_file in os.listdir(source_images_folder):
    # Extract the image title from the file name
    image_title = os.path.splitext(image_file)[0]
    image_title = os.path.splitext(image_file)[0]
    if '.DS_Store' in image_title:
        continue

    print(image_title)
    image_titles.append(image_title)

# Create a dictionary to store the metadata
imagej_metadata = {
    'Labels': image_titles
}
tifffile.imwrite(os.path.join('../data/average_pred.tiff'), averaged_array.astype(np.float32), metadata=imagej_metadata, imagej=True)


coli_mask_frame_205
mabs_img_13
mabs_img_4
mabs_img_8
coli_mask_frame_1
coli_mask_frame_232
mabs_img_9
coli_mask_frame_101
mabs_img_5
mabs_img_12
mabs_img_19
coli_mask_frame_223
coli_mask_frame_274
mabs_img_15
mabs_img_2
mabs_img_3
mabs_img_14
coli_mask_frame_166
mabs_img_18
coli_mask_frame_10
coli_mask_frame_165
mabs_img_17
mabs_img_16
mabs_img_1
mabs_img_20
coli_mask_frame_109
mabs_img_6
mabs_img_11
coli_mask_frame_56
mabs_img_10
mabs_img_7


In [32]:
import tifffile
sum_array = np.sum(stacked_arrays, axis=(0, 2))
tifffile.imwrite(os.path.join('../data/summed_pred.tiff'), sum_array)

In [51]:
import os
import numpy as np
import tifffile

# ... (Your existing code to generate averaged_array) ...

# Convert averaged_array to a supported data type (e.g., float32)
averaged_array = averaged_array.astype(np.float32)

# Create a list to store the image titles and corresponding slices
image_data = []

# Iterate over the image files and extract their names as titles
for image_file in os.listdir(source_images_folder):
    # Extract the image title from the file name
    image_title = os.path.splitext(image_file)[0]
    if '.DS_Store' in image_title:
        continue
    # Assuming the slices are stored in averaged_array in the same order as the image files
    slice_index = len(image_data)
    
    # Append the image title and slice index to the image_data list
    image_data.append((image_title, slice_index))

# Sort the image_data list based on the image titles
image_data.sort(key=lambda x: x[0])

# Create lists to store the sorted image titles and slices
sorted_titles = [data[0] for data in image_data]
sorted_slices = [averaged_array[data[1]] for data in image_data]

# Stack the sorted slices to create the final sorted array
sorted_array = np.stack(sorted_slices)

# Create a dictionary to store the ImageJ metadata
imagej_metadata = {
    'Labels': sorted_titles
}

# Save the sorted array as a TIFF file with ImageJ metadata
tifffile.imwrite(os.path.join('../data/average_pred_sorted.tiff'), sorted_array, imagej=True, metadata=imagej_metadata)
