In [17]:
import glob
from pathlib import Path

import nibabel as nib
import numpy as np
import pandas as pd
import torch
from dipy.core.sphere import HemiSphere, Sphere
from torch.utils.data import Dataset, DataLoader

# Star import kept last (as before) so anything it re-exports still wins.
# Names used below that this notebook does not define are presumably provided here:
# MatrixDataset, MatrixFactorizationNet,
# MatrixFactorizationBatchNormalizationNet_{3,4,5}_layer, apply_sparsity, detect_peaks.
from helpers import *


In [18]:
def _load_angles(theta_path, phi_path):
    """Load a (theta, phi) pair of .npy arrays and return both as Python lists.

    The lists are later passed as ``theta=`` / ``phi=`` to dipy's ``HemiSphere``.
    """
    theta_values = list(np.load(theta_path))
    phi_values = list(np.load(phi_path))
    return theta_values, phi_values


# Default direction set plus the `_360` / `_90` variants; the evaluation loop
# selects one of these according to the "od-" field in a checkpoint name.
thetas, phis = _load_angles("synthetic_data/thetas.npy", "synthetic_data/phis.npy")
thetas_360, phis_360 = _load_angles("synthetic_data/thetas_360.npy", "synthetic_data/phis_360.npy")
thetas_90, phis_90 = _load_angles("synthetic_data/thetas_90.npy", "synthetic_data/phis_90.npy")

In [3]:
# Whole mask and white-matter mask, cast to bool so they can be used for
# boolean voxel indexing below.
mask = nib.load('real_data/mask.nii').get_fdata().astype(bool)
white_mask = nib.load('real_data/white_matter_mask.nii').get_fdata().astype(bool)

# HARDI diffusion-weighted images at three SNR levels.
hardi_snr10 = nib.load('real_data/DWIS_hardi-scheme_SNR-10.nii').get_fdata()
hardi_snr20 = nib.load('real_data/DWIS_hardi-scheme_SNR-20.nii').get_fdata()
hardi_snr30 = nib.load('real_data/DWIS_hardi-scheme_SNR-30.nii').get_fdata()


def _mask_and_scale(volume, voxel_mask, target=100.0):
    """Select the voxels in ``voxel_mask`` and rescale them by one global factor.

    The factor is chosen so that element ``[0, 0]`` of the selection (first
    value of the first selected voxel) becomes ``target``; every other value is
    multiplied by the same constant.

    Parameters
    ----------
    volume : np.ndarray
        Image data whose leading dimensions match ``voxel_mask``'s shape
        (presumably 4-D: x, y, z, measurements — confirm).
    voxel_mask : np.ndarray of bool
        Voxel selection.
    target : float, default 100.0
        Value the reference element is scaled to.

    Returns
    -------
    np.ndarray
        The selected voxels, scaled. Boolean indexing already returns a copy,
        so ``volume`` itself is not modified.

    Raises
    ------
    ValueError
        If the mask selects no voxels, or if the reference value is zero
        (the scale factor would be inf and silently poison the array).
    """
    signals = volume[voxel_mask]  # boolean indexing -> new array
    if signals.size == 0:
        raise ValueError("mask selects no voxels; nothing to scale")
    reference = signals[0, 0]
    if reference == 0:
        raise ValueError("reference value signals[0, 0] is zero; scaling would produce inf/nan")
    signals *= target / reference
    return signals


# NOTE(review): the reference value is the first voxel *of each mask*, so the
# `_masked` and `_white_masked` arrays of the same SNR level are generally scaled
# by different constants. If per-voxel normalisation (each voxel by its own first
# measurement, i.e. `signals / signals[:, :1]`) was intended, change the helper — confirm.
hardi_snr10_masked = _mask_and_scale(hardi_snr10, mask)
hardi_snr20_masked = _mask_and_scale(hardi_snr20, mask)
hardi_snr30_masked = _mask_and_scale(hardi_snr30, mask)

hardi_snr10_white_masked = _mask_and_scale(hardi_snr10, white_mask)
hardi_snr20_white_masked = _mask_and_scale(hardi_snr20, white_mask)
hardi_snr30_white_masked = _mask_and_scale(hardi_snr30, white_mask)

# Pickle files can execute code when loaded — only read trusted files.
gt_white_masked = pd.read_pickle("real_data/ground_truth_peaks_white_masked.pkl")

In [20]:
# Held-out validation inputs and their ground-truth peaks, read in that order.
# Pickle files can execute code when loaded — only read trusted files.
validation, validation_peaks = (
    pd.read_pickle(pickle_path)
    for pickle_path in (
        "real_data/validation_set.pkl",
        "real_data/gt_white_masked_validation.pkl",
    )
)

In [21]:
batch_size = 1024

# Inputs: the validation array as a float32 tensor.
test_set_hardi = torch.from_numpy(validation).float()
# Targets: MatrixDataset takes a second tensor, but the inference loop never
# reads it, so an all-zero float32 placeholder of matching length is enough.
F_test_set_hardi = torch.zeros(validation.shape[0], dtype=torch.float32)

test_dataset_hardi = MatrixDataset(test_set_hardi, F_test_set_hardi)
test_loader_hardi = DataLoader(
    test_dataset_hardi,
    batch_size=batch_size,
    # Keep sample order so predictions stay row-aligned with the inputs
    # (presumably also with validation_peaks — confirm).
    shuffle=False,
)

In [22]:
# Evaluate every checkpoint found at trained_model/models/<dataset_name>/<model_name>.pt.
# For each one: rebuild the architecture encoded in the file name, predict
# sparsified and normalised outputs on the validation loader, detect peaks, and
# pickle both the peaks and the predictions.

INPUT_DIM = 65            # input size passed to every network constructor
DEFAULT_OUTPUT_DIM = 180  # used when a checkpoint name has no "od-" field
SPARSITY_THRESHOLDS = [0.1]
PEAK_RESULTS_DIR = Path("trained_model/peak_results")
PREDICTED_ODFS_DIR = Path("trained_model/predicted_odfs")

# Direction set to use for each supported output dimension ("od-" field).
HEMISPHERE_ANGLES = {
    90: (thetas_90, phis_90),
    180: (thetas, phis),
    360: (thetas_360, phis_360),
}


def _parse_model_name(model_name):
    """Split a checkpoint stem like ``od-90_hs-512,256,128_BN-y`` into a dict.

    Fields are separated by ``_``. Only the first ``-`` in each field separates
    key from value, so values such as ``1.2e6`` or ``1e-4`` stay intact.
    A repeated key keeps its last value.

    Raises
    ------
    ValueError
        If a field has no ``-``.
    """
    settings = {}
    for field in model_name.split('_'):
        if '-' not in field:
            raise ValueError(f"malformed field {field!r} in model name {model_name!r}")
        key, value = field.split('-', 1)
        settings[key] = value
    return settings


def _parse_hidden_sizes(settings):
    """Return the ``hs-`` field as a list of ints, or None when it is absent."""
    raw = settings.get('hs')
    return None if raw is None else [int(size) for size in raw.split(',')]


def _build_net(settings, hidden_sizes, output_dim):
    """Instantiate the untrained network described by a parsed checkpoint name.

    ``BN-y`` selects a batch-normalised network whose depth follows the number
    of hidden sizes: 3 -> 4-layer, 2 -> 3-layer, anything else -> 5-layer.
    Any other ``BN`` value selects ``MatrixFactorizationNet``.

    Raises
    ------
    ValueError
        If the name has no ``BN-`` field, or is ``BN-y`` without ``hs-``.
    """
    batch_norm = settings.get('BN')
    if batch_norm is None:
        raise ValueError("checkpoint name has no 'BN-' field; cannot choose an architecture")
    if batch_norm != 'y':
        # Previously hard-coded to 180, which broke od-90/od-360 checkpoints.
        return MatrixFactorizationNet(INPUT_DIM, hidden_sizes, output_dim)
    if hidden_sizes is None:
        raise ValueError("batch-norm checkpoint name has no 'hs-' field")
    if len(hidden_sizes) == 3:
        return MatrixFactorizationBatchNormalizationNet_4_layer(INPUT_DIM, hidden_sizes, output_dim)
    if len(hidden_sizes) == 2:
        return MatrixFactorizationBatchNormalizationNet_3_layer(INPUT_DIM, hidden_sizes, output_dim)
    return MatrixFactorizationBatchNormalizationNet_5_layer(INPUT_DIM, hidden_sizes, output_dim)


def _predict_normalized_odfs(net, loader, sparsity_threshold, output_dim):
    """Run ``net`` over ``loader`` and return sparsified, row-normalised outputs.

    Returns
    -------
    np.ndarray of float64, shape (n_samples, output_dim)

    Raises
    ------
    ValueError
        If the loader is empty or the network's output width is not ``output_dim``.
    """
    net.eval()
    batches = []
    with torch.no_grad():
        for S_batch, _ in loader:
            S_flat = S_batch.to_dense().view(S_batch.size(0), -1)
            F_pred = apply_sparsity(net(S_flat), sparsity_threshold)
            # +1e-8 avoids dividing by zero for a row that sums to 0
            # (e.g. if apply_sparsity zeroed every entry — confirm in helpers).
            batches.append(F_pred / (F_pred.sum(dim=1, keepdim=True) + 1e-8))
    if not batches:
        raise ValueError("loader yielded no batches")
    # torch.cat keeps rows intact; the old flatten+reshape would silently
    # scramble rows if the network width did not match output_dim.
    predictions = torch.cat(batches).cpu().numpy().astype(np.double)
    if predictions.shape[1:] != (output_dim,):
        raise ValueError(f"network produced shape {predictions.shape}, expected (*, {output_dim})")
    return predictions


relative_peak_threshold = 0.1
min_separation_angle = 25
PEAK_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
PREDICTED_ODFS_DIR.mkdir(parents=True, exist_ok=True)

# Sorted so the run order (and printed output) is deterministic across platforms.
trained_files = sorted(glob.glob("trained_model/models/*/*.pt"))

for model in trained_files:
    # Path handles both '/' and the mixed separators glob returns on Windows;
    # the old split('\\') failed on POSIX systems.
    model_path = Path(model)
    model_name = model_path.stem
    dataset_name = model_path.parent.name

    # Parse every field first so architecture choice no longer depends on
    # the order of fields in the file name.
    settings = _parse_model_name(model_name)
    output_dim = int(settings.get('od', DEFAULT_OUTPUT_DIM))
    if output_dim not in HEMISPHERE_ANGLES:
        raise ValueError(f"{model}: unsupported output dimension {output_dim}")
    hemisphere_thetas, hemisphere_phis = HEMISPHERE_ANGLES[output_dim]
    hemisphere = HemiSphere(theta=hemisphere_thetas, phi=hemisphere_phis)

    hidden_sizes = _parse_hidden_sizes(settings)
    net = _build_net(settings, hidden_sizes, output_dim)
    # map_location lets checkpoints saved on GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(model, map_location=torch.device('cpu')))

    for sparsity_threshold in SPARSITY_THRESHOLDS:
        F_pred_np = _predict_normalized_odfs(net, test_loader_hardi, sparsity_threshold, output_dim)
        peak_format, peak_count, min_peak_count = detect_peaks(
            F_pred_np, hemisphere, relative_peak_threshold, min_separation_angle
        )
        print(model_name, peak_count)

        result_name = f"{dataset_name}_{model_name}_sparsity{sparsity_threshold}.pkl"
        pd.to_pickle(peak_format, PEAK_RESULTS_DIR / result_name)
        pd.to_pickle(F_pred_np, PREDICTED_ODFS_DIR / result_name)
            

od-360_ds-1.2e6_hs-512,256,128_BN-y_loss-F_lr-0.01 7.0
od-90_ds-1.2e6_hs-512,256,128_BN-y_loss-SF_lr-0.001 8.0
ds-6e5_hs-512,256,128_BN-y_loss-F_lr-0.001 7.0
ds-6e5_hs-512,256,128_BN-y_loss-F_lr-0.001 8.0
ds-1.2e6_hs-1024,512,256_BN-y_loss-F_lr-0.001 5.0
ds-1.2e6_hs-1024,512,256_BN-y_loss-F_lr-0.01 6.0
ds-1.2e6_hs-1024,512,256_BN-y_loss-SF_lr-0.001_WD-y 9.0
ds-1.2e6_hs-256,512,256_BN-y_loss-F_lr-0.001 7.0
ds-1.2e6_hs-512,256,128_BN-y_loss-F_lr-0.001 5.0
ds-1.2e6_hs-512,256,128_BN-y_loss-SF_lr-0.001 6.0
ds-3e5_hs-512,256,128_BN-y_loss-F_lr-0.001 5.0
ds-3e5_hs-65,128,180_BN-y_loss-F_lr-0.001 6.0
ds-6e5_hs-128,256,256,128_BN-y_loss-F_lr-0.001 6.0
ds-6e5_hs-128,256,512_BN-y_loss-F_lr-0.001 5.0
ds-6e5_hs-512,256,128_BN-y_loss-F_lr-0.001 5.0
ds-6e5_hs-512,256,128_BN-y_loss-F_lr-0.01 5.0
ds-6e5_hs-512,256,128_BN-y_loss-SF_lr-0.001 6.0
ds-6e5_hs-512,256,128_BN-y_loss-SF_lr-0.001_WD-y 7.0
ds-6e5_hs-64,128,256_BN-y_loss-F_lr-0.001 7.0
ds-6e5_hs-256,512,256_BN-y_loss-SF_lr-0.001_WD-n 7.0
ds-6e5_h