<h1> Inference First </h1>

In [None]:
from TCFile import TCFile
import numpy as np
import torch
import cv2
from torchvision import models
import utils

# Refractive-index window every patch is clipped to before normalisation.
RI_CLIP_MIN, RI_CLIP_MAX = 1.33, 1.40
# Normalised-intensity cut-off separating "foreground" from "background" pixels.
FOREGROUND_INTENSITY = 45
# Patches with a smaller foreground fraction are screened by the mini model first.
MIN_FOREGROUND_FRACTION = 0.05


def _foreground_fraction(patch, threshold=FOREGROUND_INTENSITY):
    """Fraction of pixels above ``threshold`` among pixels not equal to it.

    Pixels < threshold count as background, pixels > threshold as foreground,
    and pixels exactly equal to the threshold are ignored. Returns 0.0 when
    no pixel is counted (previously a ZeroDivisionError).
    """
    above = np.count_nonzero(patch > threshold)
    below = np.count_nonzero(patch < threshold)
    counted = above + below
    return above / counted if counted else 0.0


def _classify_patch(patch, model, mini_model, transform, mini_transform, device):
    """Classify one normalised patch.

    Returns ``(class_label, logits)``. ``class_label`` is 0 when the mini model
    rejects a low-foreground patch; otherwise it is the main model's argmax + 1.
    ``logits`` is the main model's CPU output, or ``None`` when the main model
    was not run.
    """
    if _foreground_fraction(patch) < MIN_FOREGROUND_FRACTION:
        # Sparse patch: let the mini model decide whether it is worth classifying.
        equalized = cv2.equalizeHist(patch.astype(np.uint8))
        mini_input = mini_transform(torch.from_numpy(equalized).repeat(3, 1, 1).float()).to(device)
        is_necrosis = torch.max(mini_model(mini_input.unsqueeze(0)), 1)[1].item()
        if is_necrosis != 1:
            return 0, None
    image_tensor = transform(torch.from_numpy(patch).repeat(3, 1, 1).float()).to(device)
    logits = model(image_tensor.unsqueeze(0)).cpu()
    _, pred = torch.max(logits, 1)
    return pred.item() + 1, logits  # shift so that 0 stays reserved for background


def process(path, time, model, mini_model, crop_size=(160, 160), adaptive_crop=False, overlap=True,
            stride_proportion=0.5, wanted_patches=(), background_vote=False,
            background_threshold=0.15, num_classes=1 + 5, device="cuda"):
    """Classify one time point of a TCF 2D MIP patch-by-patch and vote a per-pixel label map.

    Args:
        path: TCF file path, opened with ``TCFile(path, '2DMIP')``.
        time: index of the time point within that file.
        model: main classifier; its argmax + 1 becomes the patch's class label.
        mini_model: gate for low-foreground patches; output index 1 (named
            ``is_necrosis`` here) sends the patch on to ``model``, anything
            else labels it 0.
        crop_size: (height, width) of each patch.
        adaptive_crop: enlarge the crop so whole patches tile the image height.
        overlap: True -> sliding window with stride ``crop_height * stride_proportion``
            (at least 1 px); False -> positions from ``utils.crop_patch(..., overlap=False)``.
        stride_proportion: stride as a fraction of the crop height (overlap mode).
        wanted_patches: patch indices whose logits and raw crops are returned.
        background_vote: if True, force label 0 wherever more than
            ``background_threshold`` of the covering patches voted background.
        background_threshold: background vote share used by ``background_vote``.
        num_classes: size of the vote axis (background + model classes).
        device: torch device both models and inputs are moved to.

    Returns:
        base_image: the resized 2D slice.
        label_image: (H, W) per-pixel majority label (ties go to the lower label).
        patch_layer: (H, W) index of the last patch that covered each pixel.
        patch_coord: list of (top, left, crop_height) per patch.
        wanted_probabilities: main-model logits (numpy, batch dim 1) per wanted
            patch, or None when the mini model rejected that patch.
        wanted_images: raw (un-normalised) crops of the wanted patches.
    """
    tcf = TCFile(path, '2DMIP')
    slice_2d = utils.resize_tomogram_mip(
        tcf[time],
        data_resolution=tcf.data_resolution,
        target_resolution=0.1632,  # target pixel size, in the units utils.resize_tomogram_mip expects
        mode='mip'
    )

    if adaptive_crop:
        # Spread the leftover rows over the patches so they tile the height exactly.
        residual = slice_2d.shape[0] % crop_size[0]
        fit = slice_2d.shape[0] // crop_size[0]
        adapt = residual // fit if fit > 0 else 0
        crop_size = (crop_size[0] + adapt, crop_size[1] + adapt)
    crop_h, crop_w = crop_size

    base_image = slice_2d.copy()
    height, width = base_image.shape[0], base_image.shape[1]

    transform = models.ResNet101_Weights.IMAGENET1K_V2.transforms()
    mini_transform = models.ResNet50_Weights.IMAGENET1K_V2.transforms()
    model = model.to(device).eval()
    mini_model = mini_model.to(device).eval()

    label_counts = np.zeros((height, width, num_classes), dtype=np.int32)  # per-pixel class votes
    patch_layer = np.zeros((height, width), dtype=np.int32)
    patch_coord = []
    wanted = set(wanted_patches)
    wanted_probabilities = []
    wanted_images = []

    if overlap:
        # Always defined now (previously only when adaptive_crop=True); never 0.
        stride = max(1, int(crop_h * stride_proportion))
        positions = [(top, left)
                     for top in range(0, height - crop_h + 1, stride)
                     for left in range(0, width - crop_w + 1, stride)]
    else:
        positions = [(p[0], p[1]) for p in utils.crop_patch(slice_2d, crop_size=crop_size, overlap=False)]

    with torch.no_grad():  # pure inference: no autograd graph needed
        for patch_num, (top, left) in enumerate(positions):
            patch_coord.append((top, left, crop_h))
            raw = base_image[top:top + crop_h, left:left + crop_w]
            cropped = np.clip(raw, RI_CLIP_MIN, RI_CLIP_MAX)
            cropped = utils.image_normalization(cropped, min=RI_CLIP_MIN, max=RI_CLIP_MAX)

            class_label, logits = _classify_patch(cropped, model, mini_model,
                                                  transform, mini_transform, device)

            if patch_num in wanted:
                # None keeps the list aligned with wanted_images when no logits exist.
                wanted_probabilities.append(None if logits is None else logits.detach().numpy())
                wanted_images.append(raw)

            label_counts[top:top + crop_h, left:left + crop_w, class_label] += 1
            patch_layer[top:top + crop_h, left:left + crop_w] = patch_num

    label_image = np.argmax(label_counts, axis=-1)

    if background_vote:
        # Background wins wherever its vote share exceeds the threshold; pixels no
        # patch covered (total == 0) get share 0 instead of a 0/0 NaN warning.
        total = label_counts.sum(axis=2)
        background_share = np.divide(label_counts[:, :, 0], total,
                                     out=np.zeros(total.shape, dtype=np.float64),
                                     where=total > 0)
        label_image[background_share > background_threshold] = 0

    return base_image, label_image, patch_layer, patch_coord, wanted_probabilities, wanted_images


In [None]:
from tqdm import tqdm
import torch
from torchvision import models
import utils
import numpy as np

# Input time-lapse and trained checkpoints (absolute local paths: adjust per machine).
path = r"C:\rkka_Projects\cell_death_v2\Data\9. A549_FasL(20250410)\250409.170229.A549_FasL_01.025.Group3.B5.T025P03.TCF"
class_num = 5  # number of output classes of the main model
model_path = r"C:\rkka_Projects\cell_death_v2\trained_models\test_5_classes_22.032991_0.9728_sota.pth"
mini_model_path = r"C:\rkka_Projects\cell_death_v2\trained_models\mini_ai_epoch_9_0.000861_1.0000.pth"
file = TCFile(path, '2DMIP')

# Main model: ResNet-101 with a dropout + linear head of `class_num` outputs.
# weights=None: strict load_state_dict below overwrites every parameter, so
# downloading ImageNet weights (the deprecated pretrained=True) is wasted work.
model = models.resnet101(weights=None)
num_features = model.fc.in_features
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.2),
    torch.nn.Linear(num_features, class_num)
)
# Load on CPU; process() moves the model to the GPU itself.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Mini model: ResNet-50 with a 2-way head, used by process() as a gate for sparse patches.
mini_model = models.resnet50(weights=None)
num_features = mini_model.fc.in_features
mini_model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.2),
    torch.nn.Linear(num_features, 2)
)
mini_model.load_state_dict(torch.load(mini_model_path, map_location="cpu"))

# Run inference for each time point and collect the per-frame results.
N_TIMEPOINTS = 73  # use len(file) to process the whole series
base_stack = []
label_stack = []
probabilities_stack = []  # per frame: logits of the requested wanted_patches

for i in tqdm(range(N_TIMEPOINTS)):
    base_image, label_image, patch_layer, patch_coord, wanted_probabilities, wanted_images = process(
        path, i, model, mini_model,
        crop_size=(240, 240),
        adaptive_crop=True, overlap=True, stride_proportion=0.1, wanted_patches=[],
        background_vote=True)
    base_stack.append(base_image)
    label_stack.append(label_image)
    probabilities_stack.append(wanted_probabilities)

  has_background_vote = 0.15 < label_counts[:, :, 0]/total
100%|██████████| 73/73 [11:08<00:00,  9.16s/it]


In [None]:
# Turn the per-frame lists into stacked arrays (frame axis first) for display;
# assumes every frame has the same height and width.
base_stack_array, label_stack_array = (np.asarray(frames) for frames in (base_stack, label_stack))

In [None]:
import napari
# Open an interactive napari window; the following cell adds the image and label layers to it.
viewer = napari.Viewer()

In [None]:
# Raw frames as a greyscale image layer.
viewer.add_image(base_stack_array)

# One direct colour per class label 1-5 (label 0 has no entry in the mapping).
# NOTE(review): recent napari releases deprecate `color_mode` in favour of a
# dict `colormap` alone; confirm against the installed version.
colors = dict(zip(range(1, 6), ['#FF9B9B', '#FFD89C', '#B99470', '#3B6790', '#98D8EF']))
labels_layer = viewer.add_labels(label_stack_array.astype(int), colormap=colors)
labels_layer.color_mode = 'direct'

<Image layer 'base_stack_array' at 0x162c6711190>

<h1> Label overlay </h1>

In [None]:
import numpy as np
import cv2
from matplotlib.colors import ListedColormap
import utils
from PIL import Image

from pathlib import Path

time = 0     # frame to export
patch = 110  # patch index of interest (not used in this cell)

# Rebuild the stacks from the inference lists so this cell does not rely on
# hidden state; ri_max is computed after the rebuild (it used to read the stale array).
base_stack_array = np.array(base_stack)
label_stack_array = np.array(label_stack)
ri_max = np.max(base_stack_array[time])

# Greyscale base, normalised over a slightly wider RI window than used for inference.
# NOTE(review): cv2.addWeighted below needs base and overlay to share a dtype; this
# assumes utils.image_normalization returns uint8 — TODO confirm.
base = utils.image_normalization(base_stack_array[time], min=1.33, max=1.42)
label = label_stack_array[time]

# RGB colour per label. NOTE(review): classes 4 and 5 share one colour — confirm
# this merge is intended.
colors = {0: (0, 0, 0), 1: (255, 155, 155), 2: (255, 216, 156),
          3: (185, 148, 112), 4: (152, 216, 239), 5: (152, 216, 239)}

overlay = np.zeros((*label.shape, 3), dtype=np.uint8)
for lbl, color in colors.items():
    overlay[label == lbl] = color

# GRAY2BGR copies the grey channel three times, so channel order follows `colors`
# (RGB), which matches what PIL expects when saving.
alpha = 0.4  # overlay weight
blended = cv2.addWeighted(cv2.cvtColor(base, cv2.COLOR_GRAY2BGR), 1, overlay, alpha, 0)

# Save, creating the output directory if needed.
# time//2 in the filename assumes two frames per hour — TODO confirm the interval.
out_dir = Path('figures/figure5')
out_dir.mkdir(parents=True, exist_ok=True)
image = Image.fromarray(blended)
image.save(out_dir / f'whole_t_{time//2}h.png')