In [73]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from dipy.core.sphere import disperse_charges, HemiSphere, Sphere
from dipy.direction import peak_directions
from dipy.core.gradients import gradient_table
from dipy.data import get_sphere
from dipy.sims.voxel import multi_tensor, multi_tensor_odf, single_tensor
from mpl_toolkits.mplot3d import Axes3D
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import nibabel as nib

In [74]:
thetas = list(np.load("synthetic_data/thetas.npy"))
phis = list(np.load("synthetic_data/phis.npy"))

hemisphere = HemiSphere(theta=thetas, phi=phis) # We already dispersed charges when building the hemisphere in dipy_test
sphere = Sphere(xyz=np.vstack((hemisphere.vertices, -hemisphere.vertices)))


def detect_peaks(F, relative_peak_threshold, min_separation_angle):
    peak_format = np.zeros((len(F), 42))
    max_peak_count = 0
    for i, sample in enumerate(F):
        # Duplicate the sample for both hemispheres
        F_sphere = np.hstack((sample, sample)) / 2

        # Find peak directions
        directions, values, indices = peak_directions(F_sphere, sphere, relative_peak_threshold, min_separation_angle)
        directions = sample[indices][:, np.newaxis] * directions # multiplying with fractions
        directions_flattened = directions.flatten()
        peak_format[i][0:len(directions_flattened)] = directions_flattened
        if len(directions_flattened) / 3 > max_peak_count:
            max_peak_count = len(directions_flattened) / 3
    return peak_format, max_peak_count

In [115]:
class MatrixFactorizationNet(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(MatrixFactorizationNet, self).__init__()
        # Define the network layers
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.fc3 = nn.Linear(hidden_sizes[1], hidden_sizes[2])
        self.fc4 = nn.Linear(hidden_sizes[2], output_size)
        self.softplus = nn.Softplus()

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.softplus(self.fc4(x))  # Softplus activation for the output layer
        return x

class MatrixDataset(Dataset):
    def __init__(self, S, F):
        self.S = S
        self.F = F

    def __len__(self):
        return len(self.S)

    def __getitem__(self, idx):
        return self.S[idx], self.F[idx]
    
def apply_sparsity(F, sparsity_threshold=0.9):
    max_value, _ = F.max(dim = 1, keepdim = True)
    F[F < 0.1 * max_value] = 0
    return F


S = np.load("synthetic_data/S.npy")
F = np.load("synthetic_data/F.npy")
print(np.shape(S))
print(np.shape(F))

# Convert them to PyTorch tensors
S = torch.from_numpy(S).float()
# Separate into training and testing set
split = 0.9
num_entries = len(S)
train_len = int(num_entries * split)

S_train = S[:train_len]
F_train = F[:train_len]

S_test = S[train_len:]
F_test = F[train_len:]
batch_size = 1024

criterion = nn.MSELoss()

train_dataset = MatrixDataset(S_train, F_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = MatrixDataset(S_test, F_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

N_test, _ = np.shape(S_test)
N, n_b = np.shape(S_train)
_, n = np.shape(F_train)

input_size = n_b
output_size = n

net = MatrixFactorizationNet(input_size, [512, 256, 128], output_size)
net.load_state_dict(torch.load("trained_model/model_13_12_noluyo.pt", map_location=torch.device('cpu')))

(100000, 65)
(100000, 180)


<All keys matched successfully>

In [77]:
mask = nib.load('real_data/mask.nii').get_fdata().astype(bool)
white_mask = nib.load('real_data/white_matter_mask.nii').get_fdata().astype(bool)

hardi_snr10 = nib.load('real_data/DWIS_hardi-scheme_SNR-10.nii').get_fdata()
hardi_snr20 = nib.load('real_data/DWIS_hardi-scheme_SNR-20.nii').get_fdata()
hardi_snr30 = nib.load('real_data/DWIS_hardi-scheme_SNR-30.nii').get_fdata()

hardi_snr10_masked = hardi_snr10[mask]
hardi_snr20_masked = hardi_snr20[mask]
hardi_snr30_masked = hardi_snr30[mask]
hardi_snr10_masked *= (100/hardi_snr10_masked[0][0])
hardi_snr20_masked *= (100/hardi_snr20_masked[0][0])
hardi_snr30_masked *= (100/hardi_snr30_masked[0][0])

hardi_snr10_white_masked = hardi_snr10[white_mask]
hardi_snr20_white_masked = hardi_snr20[white_mask]
hardi_snr30_white_masked = hardi_snr30[white_mask]
hardi_snr10_white_masked *= (100/hardi_snr10_white_masked[0][0])
hardi_snr20_white_masked *= (100/hardi_snr20_white_masked[0][0])
hardi_snr30_white_masked *= (100/hardi_snr30_white_masked[0][0])

In [63]:
test_set_hardi = torch.from_numpy(hardi_snr30_white_masked).float()
F_test_set_hardi = torch.from_numpy(np.zeros(hardi_snr30_white_masked.shape[0])).float()
test_dataset_hardi = MatrixDataset(test_set_hardi, F_test_set_hardi)
test_loader_hardi = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [116]:
with torch.no_grad():
    total_loss = 0.0
    F_pred_matrix_normalized_list = []

    for S_batch, F_batch in test_loader:
        S_batch_flattened = S_batch.to_dense().view(S_batch.size(0), -1)

        F_batch_pred = net(S_batch_flattened)
        F_batch_pred = apply_sparsity(F_batch_pred)
        F_batch_pred_normalized = F_batch_pred / F_batch_pred.sum(dim=1, keepdim=True)

        loss = criterion(F_batch_pred_normalized, F_batch)

        total_loss += loss.item()
        F_pred_matrix_normalized_list.append(F_batch_pred_normalized)

    # Calculate the average loss over all test samples
    avg_loss = total_loss / len(test_loader)
    print("Average Test Loss:", avg_loss)


Average Test Loss: 0.0003444569720370217


In [109]:
F_pred_np = np.concatenate([tensor.cpu().numpy().flatten() for tensor in F_pred_matrix_normalized_list])
F_pred_np = F_pred_np.reshape((-1,180)).astype(np.double)

In [110]:
F_pred_np[0]

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.09886476, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.     

In [111]:
np.count_nonzero(F_pred_np, axis=1)

array([6, 1, 6, ..., 3, 2, 1], dtype=int64)

In [112]:
indices_top10 = np.argpartition(F_pred_np[2], -10)[-10:]

# Get the top 10 largest elements
print(F_pred_np[2][indices_top10])
print(indices_top10)

indices_top3_true = np.argpartition(F_test[2], -3)[-3:]
# Get the top 3 largest elements
print(F_test[2][indices_top3_true])
print(indices_top3_true)

[0.         0.         0.         0.         0.06632972 0.07354342
 0.08396713 0.39832243 0.08077925 0.29705805]
[ 62  60  61   7   0 154  88  93  86   1]
[0.        0.3446201 0.6553799]
[60  1 93]


In [84]:
relative_peak_threshold = 0.1
min_separation_angle = 25

peak_format, peak_count = detect_peaks(F_pred_np, relative_peak_threshold, min_separation_angle)

In [85]:
peak_format

array([[ 0.2258295 ,  0.03338549,  0.47234027, ...,  0.        ,
         0.        ,  0.        ],
       [-0.1632948 , -0.03411156,  0.98598743, ...,  0.        ,
         0.        ,  0.        ],
       [-0.03947505, -0.00283443,  0.39635136, ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [-0.11682584, -0.15266421,  0.70320939, ...,  0.        ,
         0.        ,  0.        ],
       [-0.65627344,  0.42391103,  0.13080321, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.7085509 ,  0.54002708,  0.45423162, ...,  0.        ,
         0.        ,  0.        ]])

In [86]:
peak_count

11.0

In [72]:
pd.read_pickle("real_data/ground_truth_peaks_white_masked.pkl")

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])