In [1]:
import sys


sys.path.append("../src/")

In [2]:
from sennet.custom_modules.metrics.surface_dice_metric_fast import create_table_neighbour_code_to_surface_area
import torch
import numpy as np

In [3]:
device = "cpu"


unfold = torch.nn.Unfold(kernel_size=(2, 2), padding=1)


area = create_table_neighbour_code_to_surface_area((1, 1, 1))[1:]
area = torch.from_numpy(area).to(device)  # torch.float32
area.shape

torch.Size([255])

In [4]:
bases = torch.tensor([
    [c=="1" for c in f"{i:08b}"]
    for i in range(1, 256)
]).float().permute((1, 0)).to(device)
print(f"{bases.shape=}")
print(bases)

bases.shape=torch.Size([8, 255])
tensor([[0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 1., 1.,  ..., 0., 1., 1.],
        [1., 0., 1.,  ..., 1., 0., 1.]])


In [8]:
batch_size = 3
zs = 2
ys = 10
xs = 10

pred = torch.rand((batch_size, zs, ys, xs)).to(device)
labels = (torch.rand((batch_size, zs, ys, xs)) > 0.5).to(device)

In [9]:
unfolded_pred = unfold(pred)
unfolded_labels = unfold(labels.float())

permuted_unfolded_pred = unfolded_pred.permute((0, 2, 1)).unsqueeze(-1)
permuted_unfolded_labels = unfolded_labels.permute((0, 2, 1)).unsqueeze(-1)

print(f"{unfolded_pred.shape=}")
print(f"{unfolded_labels.shape=}")
print(f"{permuted_unfolded_pred.shape=}")
print(f"{permuted_unfolded_labels.shape=}")

unfolded_pred.shape=torch.Size([3, 8, 121])
unfolded_labels.shape=torch.Size([3, 8, 121])
permuted_unfolded_pred.shape=torch.Size([3, 121, 8, 1])
permuted_unfolded_labels.shape=torch.Size([3, 121, 8, 1])


In [7]:
a_mat = bases.tile((3, 121, 1, 1))
b_mat = permuted_unfolded_pred
bases_weights = torch.linalg.lstsq(a_mat, b_mat).solution

In [8]:
print(bases_weights.shape)
print(area.shape)

torch.Size([3, 121, 255, 1])
torch.Size([255])


In [9]:
pred_areas = (area[None, None, :, None] * bases_weights).squeeze(-1).sum(2)
print(pred_areas.shape)

torch.Size([3, 121])


In [10]:
label_cubes_byte = torch.zeros((unfolded_labels.shape[0], unfolded_labels.shape[2]), dtype=torch.int32, device=device)
for k in range(8):
    label_cubes_byte += unfolded_labels[:, k, :].to(torch.int32) << k
label_cubes_byte.shape

torch.Size([3, 121])

In [11]:
label_areas = torch.zeros((label_cubes_byte.shape[0], label_cubes_byte.shape[1]), dtype=torch.float32, device=device)
for b in range(label_cubes_byte.shape[0]):
    label_areas[b, :] = torch.where(label_cubes_byte[b, :] == 0, label_areas[b, :], area[label_cubes_byte[b, :] - 1])
label_areas.shape

torch.Size([3, 121])

### this is from dice loss
```python
def forward(self, input, target):
    # Apply sigmoid to input (predictions)
    input_sigmoid = torch.sigmoid(input).reshape(-1)
    target_flat = target.reshape(-1)
    
    intersection = (input_sigmoid * target_flat).sum()

    return 1 - (
        (2.0 * intersection + self.smooth)
        / (input_sigmoid.sum() + target_flat.sum() + self.smooth)
    )
```

In [16]:
# batch, cube
label_cubes_byte.shape

torch.Size([3, 121])

In [13]:
# batch, point, weights, 1
bases_weights.shape

torch.Size([3, 121, 255, 1])

In [17]:
# batch, point
label_areas.shape

torch.Size([3, 121])

In [50]:
# intersection = use the base weights (255 items), 
#   then go and get the weighted sum of the label areas at each pixel
#   - so you need to know which cube hits the label
#   - then take the area of that guy

intersection = torch.zeros((bases_weights.shape[0], bases_weights.shape[1]), device=device, dtype=torch.float32)
for i in range(bases_weights.shape[0]):
    w = bases_weights[i, ...].squeeze(-1)  # (points, 255)
    cb = label_cubes_byte[i, ...]   # (points, )
    weight_at_label_cube = torch.where(
        cb==0, 
        torch.zeros((w.shape[0], ), dtype=torch.float32, device=device), 
        w[torch.arange(w.shape[0], device=device), cb-1]
    )
    # print(cb.shape, w.shape, weight_at_label_cube.shape)
    intersection[i, ...] = weight_at_label_cube * label_areas[i, ...]

In [52]:
intersection.shape

torch.Size([3, 121])

In [54]:
smooth = 1e-3
numerator = 2 * intersection
denominator = label_areas.sum() + pred_areas.sum()
dice = 1 - ((numerator + smooth) / (denominator + smooth))
print(dice.shape)

torch.Size([3, 121])


In [58]:
pred_areas

tensor([[0.4656, 0.4793, 0.0988, 0.6923, 0.9604, 0.4678, 0.6983, 0.9603, 0.9553,
         1.0143, 0.4355, 0.2285, 0.2338, 0.0616, 0.2554, 0.6058, 0.9157, 0.6282,
         0.3563, 0.5942, 0.8888, 0.5317, 0.3706, 0.8902, 0.6964, 0.1989, 0.5611,
         1.1239, 0.8275, 0.5722, 0.8289, 1.0813, 0.5819, 0.3365, 0.7988, 0.8631,
         0.9425, 0.7175, 0.3492, 0.3452, 0.6108, 0.6704, 0.7991, 0.5676, 0.4991,
         0.6122, 0.5233, 0.7960, 0.5337, 0.2093, 0.5249, 0.9658, 0.5793, 0.1362,
         0.0590, 0.5400, 0.7695, 0.7022, 0.4762, 0.4695, 0.9368, 1.0532, 1.1782,
         0.9123, 0.5671, 0.2505, 0.2096, 0.2541, 0.6507, 0.9839, 0.4153, 0.1836,
         0.5679, 0.6049, 0.5265, 0.5960, 0.2523, 0.4991, 0.7472, 0.7531, 0.9212,
         0.8747, 1.0340, 1.1916, 0.6446, 0.4539, 1.0000, 0.5745, 0.3911, 0.5702,
         0.4769, 0.3069, 0.1472, 0.7005, 0.8500, 0.8355, 1.1126, 1.1300, 0.5653,
         0.0688, 0.4531, 0.6126, 0.4813, 0.3813, 0.6598, 1.1301, 0.9666, 0.6259,
         0.6153, 0.3575, 0.0

In [None]:
unfolded_labels.shape

In [None]:
smooth = 1e-3
numerator = 
denominator = pred_areas.sum() + 

In [49]:
pred_areas.shape

torch.Size([3, 121])

In [None]:
torch.abs(torch.sum(permuted_unfolded_pred))

In [None]:
torch.abs(torch.sum(a_mat))

In [None]:
a = bases
# b = permuted_unfolded_pred[0, 0]
b = a[:, 0]
print(f"{a.shape = }")
print(f"{b.shape = }")
print(a)
print(b)
sol = torch.linalg.lstsq(bases, b)
print(sol.rank)
print(sol.solution.reshape(-1))

In [None]:
torch.linalg.matrix_rank(a)

In [None]:
a_np = a.cpu().numpy()
b_np = b.cpu().numpy()

weights, residuals, rank, singular_vals = np.linalg.lstsq(a_np, b_np, rcond=None)

print(a_np.shape)
print(b_np.shape)
print(b_np)
print(np.sum(np.abs(weights)))
print(residuals)
print(rank)
print(singular_vals)

In [None]:
for i in range(256):
    print(a_np[:, i])