# Matrix

In [34]:
import pandas as pd
import numpy as np

# Radiosensitivity rankings
sensitivity_rankings = {
    'Background': 4, 
    'Bone': 4, 
    'Obtur': 4,
    'TZ': 0,       
    'CG': 0,        
    'Bladder': 6,
    'SV': 2,        
    'Rectum': 9,   
    'NVB': 10,      
}

labels = list(sensitivity_rankings.keys())
values = list(sensitivity_rankings.values())
n_classes = len(labels)

# Create grids 
sens_of_predicted = np.array([values] * n_classes)
sens_of_actual    = np.array([values] * n_classes).T

# Calculate error gap
sens_matrix = sens_of_predicted - sens_of_actual

# Apply 2x penalty for false negatives, dangerous errors (predicted < actual)
FN_weighting = 2
sens_matrix[sens_matrix < 0] = sens_matrix[sens_matrix < 0] * FN_weighting
sens_matrix = np.abs(sens_matrix)

In [35]:
# OLD CALCS
centroids = {
    'Background':(100, 100, 100),
    'TZ':        (0, 0, 0),       # Centre
    'CG':        (0, -5, 0),      # Touching TZ
    'Bladder':   (0, 10, 0),
    'Obtur':     (30, 0, 0),
    'Bone':      (50, 0, 0),      # Far laterally
    'Rectum':    (0, -15, 0),     # Posterior to prostate
    'SV':        (0, 5, 5),       # Superior/Posterior
    'NVB':       (5, -10, 0),     # Postero-lateral
}

# Euclidean distances
dist_matrix = np.zeros((n_classes, n_classes))
coords = [centroids[label] for label in labels]

for i in range(n_classes):
    for j in range(n_classes):
        # Distance between centroids
        dist = np.linalg.norm(np.array(coords[i]) - np.array(coords[j]))
        dist_matrix[i, j] = dist

# Invert distance for proximity weight
proximity_matrix = 100 / (dist_matrix + 10.0) 

np.fill_diagonal(proximity_matrix, 0)

total_weight_matrix = sens_matrix + proximity_matrix

df = pd.DataFrame(np.round(total_weight_matrix, 1), index=labels, columns=labels)

#Normalize matrix weights, average 1
df = (df * 81 / df.sum().sum())

In [36]:
# Sami calculated these using mean voxel distances from Prostate centres

avg_distances = {
    'Background': 81.75,
    'TZ': 20.27,      
    'CG': 14.44, 
    'Bladder': 45.27,    
    'Obtur': 49.67,      
    'Bone': 81.98,         
    'Rectum': 47.8,    
    'SV': 38.29,           
    'NVB': 29.50,  
}

In [37]:
# Placeholder of distance matrix:

distances = list(avg_distances.values())

distance_matrix = np.ones((9, 9))
for i in range(9):
    distance_matrix[i, :] = distances[i]

In [38]:
distance_matrix

array([[81.75, 81.75, 81.75, 81.75, 81.75, 81.75, 81.75, 81.75, 81.75],
       [20.27, 20.27, 20.27, 20.27, 20.27, 20.27, 20.27, 20.27, 20.27],
       [14.44, 14.44, 14.44, 14.44, 14.44, 14.44, 14.44, 14.44, 14.44],
       [45.27, 45.27, 45.27, 45.27, 45.27, 45.27, 45.27, 45.27, 45.27],
       [49.67, 49.67, 49.67, 49.67, 49.67, 49.67, 49.67, 49.67, 49.67],
       [81.98, 81.98, 81.98, 81.98, 81.98, 81.98, 81.98, 81.98, 81.98],
       [47.8 , 47.8 , 47.8 , 47.8 , 47.8 , 47.8 , 47.8 , 47.8 , 47.8 ],
       [38.29, 38.29, 38.29, 38.29, 38.29, 38.29, 38.29, 38.29, 38.29],
       [29.5 , 29.5 , 29.5 , 29.5 , 29.5 , 29.5 , 29.5 , 29.5 , 29.5 ]])

In [39]:
distance_weighting = 1 / distance_matrix

distance_matrix = (distance_weighting * 81 / distance_weighting.sum())

# Normalize
pd.DataFrame(distance_matrix.round(2), index=labels, columns=labels)

Unnamed: 0,Background,Bone,Obtur,TZ,CG,Bladder,SV,Rectum,NVB
Background,0.41,0.41,0.41,0.41,0.41,0.41,0.41,0.41,0.41
Bone,1.67,1.67,1.67,1.67,1.67,1.67,1.67,1.67,1.67
Obtur,2.34,2.34,2.34,2.34,2.34,2.34,2.34,2.34,2.34
TZ,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75
CG,0.68,0.68,0.68,0.68,0.68,0.68,0.68,0.68,0.68
Bladder,0.41,0.41,0.41,0.41,0.41,0.41,0.41,0.41,0.41
SV,0.71,0.71,0.71,0.71,0.71,0.71,0.71,0.71,0.71
Rectum,0.88,0.88,0.88,0.88,0.88,0.88,0.88,0.88,0.88
NVB,1.15,1.15,1.15,1.15,1.15,1.15,1.15,1.15,1.15


In [40]:
accuracy_matrix = 1 - np.eye(9) 
# Used for asserting that we want segmentations to be accurate in general, to give doctors a better picture of the overall image
# Equivalent to regular dice loss

In [41]:
print(pd.DataFrame(sens_matrix.round(1), index=labels, columns=labels))
print('-' * 50)
print(pd.DataFrame(distance_matrix.round(1), index=labels, columns=labels))
print('-' * 50)
print(pd.DataFrame(accuracy_matrix.round(1), index=labels, columns=labels))

            Background  Bone  Obtur  TZ  CG  Bladder  SV  Rectum  NVB
Background           0     0      0   8   8        2   4       5    6
Bone                 0     0      0   8   8        2   4       5    6
Obtur                0     0      0   8   8        2   4       5    6
TZ                   4     4      4   0   0        6   2       9   10
CG                   4     4      4   0   0        6   2       9   10
Bladder              4     4      4  12  12        0   8       3    4
SV                   2     2      2   4   4        4   0       7    8
Rectum              10    10     10  18  18        6  14       0    1
NVB                 12    12     12  20  20        8  16       2    0
--------------------------------------------------
            Background  Bone  Obtur   TZ   CG  Bladder   SV  Rectum  NVB
Background         0.4   0.4    0.4  0.4  0.4      0.4  0.4     0.4  0.4
Bone               1.7   1.7    1.7  1.7  1.7      1.7  1.7     1.7  1.7
Obtur              2.3   2.3  

In [42]:
sens_matrix.sum()

np.int64(450)

In [None]:
# Final matrix calculation:

# Normalize
sens_matrix = sens_matrix * 72 / sens_matrix.sum() # Now both sens and acc matrices have a total of 72

sens_importance = 3/4 # We say that sensitivity distinguishing is twice as important as regular dice loss

# Calculate overall importance
importance_matrix = (sens_importance * sens_matrix * distance_matrix + (1 - sens_importance) * accuracy_matrix)

# Apply distances
# overall_matrix = importance_matrix * distance_matrix # We multiply both the accuracy and sensitivity matrices by distance, since both concerns are more important closer to the region of interest i.e. prostate

# Normalize a final time
final_matrix = importance_matrix * 72 / importance_matrix.sum()

df = pd.DataFrame(np.round(final_matrix, 2), index=labels, columns=labels)

df

Unnamed: 0,Background,Bone,Obtur,TZ,CG,Bladder,SV,Rectum,NVB
Background,0.0,0.11,0.11,0.51,0.51,0.21,0.31,0.36,0.41
Bone,0.42,0.0,0.42,2.05,2.05,0.83,1.24,1.44,1.65
Obtur,0.6,0.6,0.0,2.88,2.88,1.17,1.74,2.02,2.31
TZ,0.55,0.55,0.55,0.0,0.19,0.74,0.37,1.01,1.1
CG,0.51,0.51,0.51,0.17,0.0,0.67,0.34,0.92,1.0
Bladder,0.31,0.31,0.31,0.71,0.71,0.0,0.51,0.26,0.31
SV,0.35,0.35,0.35,0.53,0.53,0.53,0.0,0.78,0.87
Rectum,1.3,1.3,1.3,2.17,2.17,0.87,1.73,0.0,0.33
NVB,1.97,1.97,1.97,3.09,3.09,1.41,2.53,0.57,0.0


In [199]:
# Check that it's normalized, should be 72
df.sum().sum()

np.float64(72.03999999999998)

## Metrics and functions

# Non-differentiable metric function

In [200]:
import torch

def compute_weighted_dice_score(pred, target, weight_matrix, num_classes=3, epsilon=1e-6):

    pred = pred.view(-1).long()
    target = target.view(-1).long()

    ids = target * num_classes + pred
    
    cm = torch.bincount(ids, minlength=num_classes**2).view(num_classes, num_classes).t().float()

    weighted_cm = cm * weight_matrix

    tp = cm.diagonal()

    fp = weighted_cm.sum(dim=1) # Summing across rows, where we always predict a particular class

    fn = weighted_cm.sum(dim=0) # Summing across columns, where a particular class is always the target

    score = (2 * tp) / (2 * tp + fp + fn + epsilon)

    return score

## Testing

In [201]:
weight_matrix = final_matrix

target = torch.Tensor([[
    [8, 1, 1, 3],
    [5, 1, 1, 6],
    [1, 1, 1, 0],
    [2, 2, 0, 0],
]])

pred = torch.Tensor([[
    [8, 1, 1, 8],
    [5, 1, 1, 5],
    [2, 1, 1, 1],
    [2, 2, 0, 0],
]])


scores = compute_weighted_dice_score(pred, target, weight_matrix, num_classes=9)

print("Confusion Matrix (Internal Calculation):")

print("-" * 30)
for i in range(9):
    print(f"Class {i} Loss: {scores[i]:.4f}")
print("-" * 30)
print(f"Mean Dice Loss: {scores.mean():.4f}")

Confusion Matrix (Internal Calculation):
------------------------------
Class 0 Loss: 0.8763
Class 1 Loss: 0.8984
Class 2 Loss: 0.8346
Class 3 Loss: 0.0000
Class 4 Loss: 0.0000
Class 5 Loss: 0.8010
Class 6 Loss: 0.0000
Class 7 Loss: 0.0000
Class 8 Loss: 0.4106
------------------------------
Mean Dice Loss: 0.4245


  weighted_cm = cm * weight_matrix


## Differentiable matrix

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# This assumes that the target is input not already one-hotted
class WeightedDiceScore(nn.Module):
    def __init__(self, weight_matrix, epsilon=1e-6):
        super().__init__()
        # self.register_buffer("weight_matrix", weight_matrix)
        self.weight_matrix = weight_matrix
        self.epsilon = epsilon

    def forward(self, pred, target):
        """
        Args:
            pred (torch.Tensor): Probabilities (B, C, H, W) - Output of Softmax
            target (torch.Tensor): Ground Truth Indices (B, H, W)
        """
        num_classes = pred.shape[1]
        
        # 1. Convert Target to One-Hot (B, C, H, W)
        # We assume target has shape (B, H, W)
        target_onehot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()

        # 2. Flatten spatial dimensions for efficient matrix multiplication
        # Shapes become: (B, C, N) where N = H*W
        pred_flat = pred.flatten(2) 
        target_flat = target_onehot.flatten(2)
        
        # 3. Compute Soft Confusion Matrix via Einsum
        # Equation: For every batch b, sum over pixels n: pred(c) * target(k)
        # Result: (C, C) matrix where row=Pred, col=Target
        # This effectively sums up the "probability mass" for every pred/target pair
        soft_cm = torch.einsum("bcn, bkn -> ck", pred_flat, target_flat)
        
        # 4. Apply Weights
        # weight_matrix is (C, C). Element-wise multiplication applies penalties.
        # Since diagonal of weight_matrix is 0, TP contributions are zeroed out here.
        weighted_cm = soft_cm * self.weight_matrix

        # 5. Calculate Components
        
        # TP: Diagonal of the original Soft Confusion Matrix (unweighted)
        tp = torch.diagonal(soft_cm)
        
        # FP_weighted: Sum of Weighted Rows (Predicted c, Actual k)
        # sum(dim=1) collapses columns
        fp_weighted = weighted_cm.sum(dim=1)
        
        # FN_weighted: Sum of Weighted Columns (Actual c, Predicted k)
        # sum(dim=0) collapses rows
        fn_weighted = weighted_cm.sum(dim=0)

        # 6. Dice Formula
        numerator = 2 * tp
        denominator = (2 * tp) + fp_weighted + fn_weighted + self.epsilon
        
        scores = numerator / denominator
        
        return scores

## Testing

In [203]:
weight_matrix = torch.Tensor([
    [0, 1, 1],
    [1, 0, 1],
    [1, 1, 0],
])

# Create a differentiable "Model Output" (Logits)
# Batch=1, Classes=3, Height=2, Width=2
# logits = torch.randn(1, 3, 2, 2, requires_grad=True)
a = 1000.0
logits = torch.tensor([[
    [[a, 0.0],
    [0.0, 0.0]],
    [[0.0, a],
    [-100, 0.0]],
    [[0.0, 0.0],
    [0.0, a]],
]], requires_grad=True)

logits2 = torch.zeros((1, 6, 2, 2))

logits = torch.cat([logits, logits2], dim=1)

# Apply Softmax to get probabilities (Required for this loss)
probs = F.softmax(logits, dim=1)

print(logits.shape)
print(probs.shape)

# Target (Indices)
target = torch.tensor([[[0, 1],
                        [1, 2]]]) # Shape (1, 2, 2)

# 2. Initialize Module
dice_calc = WeightedDiceScore(torch.Tensor(final_matrix))

# 3. Forward Pass
scores = dice_calc(probs, target)
print(f"Dice Scores per class: {scores}")

torch.Size([1, 9, 2, 2])
torch.Size([1, 9, 2, 2])
Dice Scores per class: tensor([0.9913, 0.7313, 0.9528, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       grad_fn=<DivBackward0>)
