<h1> Inference First </h1>

In [1]:
from TCFile import TCFile
import numpy as np
import torch
import cv2
from torchvision import models
import utils

def process(path, time, model, mini_model, crop_size=(160,160), adaptive_crop=False, overlap=True, stride_proportion=0.5,
            wanted_patches=[], background_vote=False):
    # Load                                                                                    & prep data
    file = TCFile(path, '2DMIP')
    crop_size = crop_size

    slice_2d = utils.resize_tomogram_mip(
        file[time],
        data_resolution=file.data_resolution,
        target_resolution=0.1632,
        mode='mip'
    )

    if adaptive_crop:
        # Adjust the patch size if you want the patches to fit the image nicely
        residual = slice_2d.shape[0] % crop_size[0]
        fit = slice_2d.shape[0] // crop_size[0]
        adapt = residual // fit if fit > 0 else 0
        crop_size = (crop_size[0] + adapt, crop_size[1] + adapt)
        stride = int(crop_size[0]*stride_proportion)

    base_image = slice_2d.copy()
    # base_image = utils.image_normalization(base_image, min=1.33, max=1.4)

    # Normal PyTorch model prep
    transform = models.ResNet101_Weights.IMAGENET1K_V2.transforms()
    mini_transform = models.ResNet50_Weights.IMAGENET1K_V2.transforms()

    model = model.cuda().eval()
    mini_model = mini_model.cuda().eval()

    # We'll store class votes for each pixel here: (H, W, num_classes)
    num_classes = 1+5  # Adjust if you actually have a different number
    label_counts = np.zeros((base_image.shape[0], base_image.shape[1], num_classes), dtype=np.int32)
    patch_layer = np.zeros((base_image.shape[0], base_image.shape[1]), dtype=np.int32)
    
    patch_coord = []
    wanted_patches = wanted_patches
    wanted_probabilities = []
    wanted_images = []
    
    if overlap:
        # Overlapping patches: define a smaller stride
        # e.g. half the patch size in each dimension
        stride_h = stride_w = stride
        patch_num = 0

        for top in range(0, base_image.shape[0] - crop_size[0] + 1, stride_h):
            for left in range(0, base_image.shape[1] - crop_size[1] + 1, stride_w):
                patch_coord.append((top, left, crop_size[0]))
                # Extract patch
                cropped = base_image[top:top+crop_size[0], left:left+crop_size[1]]

                # Clamp intensities (like your code)
                cropped = np.clip(cropped, 1.33, 1.40)
                cropped = utils.image_normalization(cropped, min=1.33, max=1.40)

                # Decide if we run the main model or the mini_model logic
                temp = cropped.copy()
                temp[temp < 45] = 0
                temp[temp > 45] = 1
                proportion = np.count_nonzero(temp == 1) / (np.count_nonzero(temp == 0) + np.count_nonzero(temp == 1))

                if proportion < 0.05:
                    # Use mini_model to check necrosis
                    mini_temp = cv2.equalizeHist(cropped.astype(np.uint8))
                    mini_temp = torch.from_numpy(mini_temp).repeat(3, 1, 1).float()
                    mini_temp = mini_transform(mini_temp).cuda()

                    is_necrosis = torch.max(mini_model(mini_temp.unsqueeze(0)), 1)[1].item()

                    if is_necrosis == 1:
                        # If necrosis, run the main model
                        image_tensor = torch.from_numpy(cropped).repeat(3, 1, 1).float()
                        image_tensor = transform(image_tensor).cuda()
                        output = model(image_tensor.unsqueeze(0)).cpu()
                        _, pred = torch.max(output, 1)
                        class_label = pred.item() + 1  # 0..(num_classes-1)
                    else:
                        class_label = 0  # or some "background" label
                else:
                    # Directly run main model
                    image_tensor = torch.from_numpy(cropped).repeat(3, 1, 1).float()
                    image_tensor = transform(image_tensor).cuda()
                    output = model(image_tensor.unsqueeze(0)).cpu()
                    _, pred = torch.max(output, 1)
                    class_label = pred.item() + 1
                    
                
                if patch_num in wanted_patches:
                    wanted_probabilities.append(output.detach().numpy())
                    wanted_images.append(base_image[top:top+crop_size[0], left:left+crop_size[1]])

                # Add a vote for this class in label_counts
                label_counts[top:top+crop_size[0], left:left+crop_size[1], class_label] += 1
                patch_layer[top:top+crop_size[0], left:left+crop_size[1]] = patch_num
                patch_num += 1
                
    else:
        # Original no-overlap approach
        patches = utils.crop_patch(slice_2d, crop_size=crop_size, overlap=False)
        patch_num = 0
        for patch_index, patch in enumerate(patches):
            top, left = patch[0], patch[1]
            patch_coord.append((top, left, crop_size[0]))
            cropped = base_image[top:top+crop_size[0], left:left+crop_size[1]]

            cropped = np.clip(cropped, 1.33, 1.40)
            cropped = utils.image_normalization(cropped, min=1.33, max=1.40)

            temp = cropped.copy()
            temp[temp < 45] = 0
            temp[temp > 45] = 1
            proportion = np.count_nonzero(temp == 1) / (np.count_nonzero(temp == 0) + np.count_nonzero(temp == 1))

            if proportion < 0.05:
                mini_temp = cv2.equalizeHist(cropped.astype(np.uint8))
                mini_temp = torch.from_numpy(mini_temp).repeat(3, 1, 1).float()
                mini_temp = mini_transform(mini_temp).cuda()
                is_necrosis = torch.max(mini_model(mini_temp.unsqueeze(0)), 1)[1].item()

                if is_necrosis == 1:
                    image_tensor = torch.from_numpy(cropped).repeat(3, 1, 1).float()
                    image_tensor = transform(image_tensor).cuda()
                    output = model(image_tensor.unsqueeze(0)).cpu()
                    _, pred = torch.max(output, 1)
                    class_label = pred.item() + 1
                else:
                    class_label = 0
            else:
                image_tensor = torch.from_numpy(cropped).repeat(3, 1, 1).float()
                image_tensor = transform(image_tensor).cuda()
                output = model(image_tensor.unsqueeze(0)).cpu()
                _, pred = torch.max(output, 1)
                class_label = pred.item() + 1
                
                
            if patch_num in wanted_patches:
                wanted_probabilities.append(output.detach().numpy())
                wanted_images.append(base_image[top:top+crop_size[0], left:left+crop_size[1]])
                
            label_counts[top:top+crop_size[0], left:left+crop_size[1], class_label] += 1
            patch_layer[top:top+crop_size[0], left:left+crop_size[1]] = patch_num
            patch_num += 1
    # Final pixelwise label = most frequent vote + 1 (if you want 1-based classes)
    # label_image = np.argmax(label_counts, axis=-1)
    
    # --- MODIFIED FINAL LABEL ASSIGNMENT ---
    # Final pixelwise label assignment with background priority

    # Calculate the standard argmax first (most frequent vote overall)
    provisional_label_image = np.argmax(label_counts, axis=-1)

    if background_vote:
        # Identify pixels where at least one patch voted for background (class 0)
        # label_counts[:, :, 0] accesses the counts for class 0 for all pixels
        total = np.sum(label_counts, axis=2)
        has_background_vote = 0.15 < label_counts[:, :, 0]/total

        # Initialize the final label image with the provisional result
        label_image = provisional_label_image.copy()

        # Override: wherever a background vote exists, set the final label to 0
        label_image[has_background_vote] = 0
        # --- END OF MODIFIED SECTION ---
    else:
        label_image = provisional_label_image.copy()

    return base_image, label_image, patch_layer, patch_coord, wanted_probabilities, wanted_images


In [2]:
from tqdm import tqdm
import torch
from torchvision import models
import utils
import numpy as np

path = r"C:\rkka_Projects\cell_death_v2\Data\9. A549_FasL(20250410)\250409.170229.A549_FasL_01.025.Group3.B5.T025P03.TCF"
class_num = 5
model_path = r"C:\rkka_Projects\cell_death_v2\trained_models\test_5_classes_22.032991_0.9728_sota.pth"
mini_model_path = r"C:\rkka_Projects\cell_death_v2\trained_models\mini_ai_epoch_9_0.000861_1.0000.pth"
file = TCFile(path, '2DMIP')
# Load Model
model = models.resnet101(pretrained=True)
num_features = model.fc.in_features
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.2),
    torch.nn.Linear(num_features, 5)
)
model.load_state_dict(torch.load(model_path))

# Load mini Model
mini_model = models.resnet50(pretrained=True)
num_features = mini_model.fc.in_features
mini_model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.2),
    torch.nn.Linear(num_features, 2)
)
mini_model.load_state_dict(torch.load(mini_model_path))

# Process
base_stack = []
label_stack = []
probabilities_stack = []

# wanted_patches = [110, 334, 329, 258, 192]
# wanted_patches = [0]

# for i in tqdm(range(len(file))):
for i in tqdm(range(0, 73)):
    # base_image, label_image = process(path, i, model, mini_model, adaptive_crop=True)
    base_image, label_image, patch_layer, patch_coord, wanted_probabilities, wanted_images = process(path, i, model, mini_model, 
                                                                                                     crop_size=(240,240),
                                      adaptive_crop=True, overlap=False, stride_proportion=0.1, wanted_patches=[],
                                      background_vote=True)
    base_stack.append(base_image)
    label_stack.append(label_image)

  has_background_vote = 0.15 < label_counts[:, :, 0]/total
100%|██████████| 73/73 [00:12<00:00,  5.90it/s]


In [3]:
base_stack_array = np.array(base_stack)
label_stack_array = np.array(label_stack)

In [4]:
import napari
viewer = napari.Viewer()

In [5]:
viewer.add_image(base_stack_array)

colors = {1: '#FF9B9B', 2: '#FFD89C', 3: '#B99470', 4: '#3B6790', 5: "#98D8EF"}
labels_layer = viewer.add_labels(label_stack_array.astype(int), colormap=colors)
labels_layer.color_mode = 'direct'

  warn(


<h1> Fluorescence one picture </h1>

In [12]:
from TCFile import TCFile
import numpy as np
import cv2

time = 0
base = base_stack_array[0].copy()
# viewer.add_image(base_stack)

colors = ['Blue', 'Green', 'Red']

ch0_list = []
ch1_list = []
ch2_list = []
for i in range(1):
    file = TCFile(path, '3DFL', channel=i, only_one=True)
    temp = np.max(file[time], axis=0)
    temp = cv2.resize(temp, dsize=base.shape)
    if i==0:
        ch0_list.append(temp)
    elif i==1:
        ch1_list.append(temp)
    elif i==2:
        ch2_list.append(temp)
            
ch0_stack = np.array(ch0_list)    
ch1_stack = np.array(ch1_list)
ch2_stack = np.array(ch2_list)

viewer.add_image(base)
viewer.add_labels(patch_layer)
viewer.add_image(ch0_stack, opacity=0.3, colormap=colors[0])
# viewer.add_image(ch1_stack, opacity=0.3, colormap=colors[1])
# viewer.add_image(ch2_stack, opacity=0.3, colormap=colors[2])

<Image layer 'ch0_stack [1]' at 0x2920443b800>

In [19]:
from PIL import Image

# red = ch2_stack[0]
# green = ch1_stack[0]
blue = ch0_stack[0]

# red = utils.image_normalization(red, min=np.min(red), max=np.max(red))
# green = utils.image_normalization(green, min=np.min(green), max=np.max(green))
blue = utils.image_normalization(blue, min=np.min(blue), max=np.max(blue))

overlay = np.zeros((blue.shape[0], blue.shape[1], 3))
# overlay[:,:,0] = red
# overlay[:,:,1] = green
overlay[:,:,2] = blue
overlay = overlay.astype(np.uint8)
image = Image.fromarray(overlay)
image.save('test.png')

In [15]:
patch_coord[5]

(248, 248, 248)

In [None]:
patch = 192

top, left, crop_size = patch_coord[patch]
temp = overlay[top:top+crop_size, left:left+crop_size, :].copy()
temp = temp.astype(np.uint8)
image = Image.fromarray(temp)
image.save(f'figures/figure4/FL_patch_{patch}.png')