In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
import nibabel as nib
from torch.utils.data import Dataset
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import scipy.io as sio

model_file = 'BLAHFCN_PATCH_A2C_SH1200_saved_model_split'
dirs = ['H','L','M','N','O']


In [2]:
class SHDataSet(Dataset):
    def __init__(self,X):
        
        self.X = X
        self.to_tensor = transforms.ToTensor()
        
    def __len__(self):
        return np.amax(X.shape)
    
    def __getitem__(self, i):
        
        vec_a = self.X[:,:,:,:,i]
        vec_a = np.transpose(vec_a, (3, 0, 1, 2))
        a = torch.Tensor(vec_a)
    
        return a


In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #self.fc1 = nn.Linear(3 * 3 * 3 * 15, 600)
        self.cn1 = nn.Conv3d(15,128,kernel_size=(3,3,3),stride=1,padding=(1,1,1))
        #print(self.cn1.shape)
        #self.fc2 = nn.Linear(600,300)
        self.bn = nn.BatchNorm3d(128)
        self.fc2 = nn.Linear(128*3*3*3, 300)
        self.fc3 = nn.Linear(300,60)
        self.fc4 = nn.Linear(60,200)
        self.fc5 = nn.Linear(200,15)
        
    def forward(self, x):
        x = F.relu(self.cn1(x))
        #print(x.shape)
        dimensions = x.shape
        x = self.bn(x)
        x = x.view(dimensions[0], -1)
        #x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x
        
net = Net()
print(net)  


Net(
  (cn1): Conv3d(15, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bn): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=3456, out_features=300, bias=True)
  (fc3): Linear(in_features=300, out_features=60, bias=True)
  (fc4): Linear(in_features=60, out_features=200, bias=True)
  (fc5): Linear(in_features=200, out_features=15, bias=True)
)


In [4]:
#model = net.load_state_dict(torch.load(model_file))
#torch.load('my_file.pt', map_location=lambda storage, loc: storage)
model = net.load_state_dict(torch.load(model_file, map_location='cpu'))
#model = net.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))

In [5]:

for i in range(len(dirs)):
    apply_to_file = "/Users/kurtschilling/Data/harmonization/patches/%s_A2C_Patches.mat" % (dirs[i])
    nii_file = "/Users/kurtschilling/Data/harmonization/CDMRI_Challenge_2018/Testing_Data/%s/prisma/st/A2C_norm_dwi_SHfitOrder4_EvenOdd2_SH1200.nii.gz" % (dirs[i])
    save_file = "BLAHFCN_PATCHES_s%s_A2C_SH1200.nii.gz" % (dirs[i])
    print(apply_to_file)
    print(nii_file)
    print(save_file)
    # load and reshape
    mat_contents = sio.loadmat(apply_to_file)

    #print(mat_contents)
    X1 = mat_contents['input1200_1']
    #print(X1.shape)
    X2 = mat_contents['input1200_2']
    #print(X2.shape)
    X3 = mat_contents['input1200_3']
    #print(X3.shape)
    X4 = mat_contents['input1200_4']
    #print(X4.shape)
    X5 = mat_contents['input1200_5']
    #print(X5.shape)
    X6 = mat_contents['input1200_6']
    #print(X6.shape)
    X7 = mat_contents['input1200_7']
    #print(X7.shape)
    X8 = mat_contents['input1200_8']
    #print(X8.shape)
    X9 = mat_contents['input1200_9']
    #print(X8.shape)

    dims = X1.shape
    print(dims)

    X = np.empty((3,3,3,15,dims[4]))
    print(X.shape)

    X[0,0,:,:,:] = X1
    X[0,1,:,:,:] = X2
    X[0,2,:,:,:] = X3
    X[1,0,:,:,:] = X4
    X[1,1,:,:,:] = X5
    X[1,2,:,:,:] = X6
    X[2,0,:,:,:] = X7
    X[2,1,:,:,:] = X8
    X[2,1,:,:,:] = X9

    nifti_path = os.path.normpath(nii_file)
    img = nib.load(nifti_path)
    data = img.get_fdata()
    dims = data.shape
    print(dims)
    
    Y = np.zeros((dims[0],dims[1],dims[2],dims[3]))
    print(Y.shape)
    Y = np.reshape(Y,(dims[0]*dims[1]*dims[2],dims[3]))
    print(Y.shape)
    
    shset = SHDataSet(X)
    print(len(shset))

    from torch.utils.data import DataLoader
    test_loader = DataLoader(shset, batch_size=1)

    batch_idx = 0
    for batch_idx, (data) in enumerate(test_loader):
    
        data = Variable(data).float()
    
        net_out = net(data)
        b = net_out.detach().numpy()
        c = np.reshape(b,(1,15))
        #print(c.shape)
        #print(batch_idx)
        Y[batch_idx,:]=c;
    
    
        if batch_idx % 100000 == 0:
            print(batch_idx)
            #print(b)
            
            
    Y = np.reshape(Y,(dims[0],dims[1],dims[2],dims[3]))
    print(Y.shape)
    
    new_img = nib.Nifti1Image(Y,img.affine,img.header)
    nib.save(new_img,save_file)

/Users/kurtschilling/Data/harmonization/patches/H_A2C_Patches.mat
/Users/kurtschilling/Data/harmonization/CDMRI_Challenge_2018/Testing_Data/H/prisma/st/A2C_norm_dwi_SHfitOrder4_EvenOdd2_SH1200.nii.gz
BLAHFCN_PATCHES_sH_A2C_SH1200.nii.gz
(1, 1, 3, 15, 2276736)
(3, 3, 3, 15, 2276736)
(154, 154, 96, 15)
(154, 154, 96, 15)
(2276736, 15)
2276736
0
100000
200000
300000
400000
500000
600000
700000
800000
900000
1000000
1100000
1200000
1300000
1400000
1500000
1600000
1700000
1800000
1900000
2000000
2100000
2200000
(154, 154, 96, 15)
/Users/kurtschilling/Data/harmonization/patches/L_A2C_Patches.mat
/Users/kurtschilling/Data/harmonization/CDMRI_Challenge_2018/Testing_Data/L/prisma/st/A2C_norm_dwi_SHfitOrder4_EvenOdd2_SH1200.nii.gz
BLAHFCN_PATCHES_sL_A2C_SH1200.nii.gz
(1, 1, 3, 15, 2276736)
(3, 3, 3, 15, 2276736)
(154, 154, 96, 15)
(154, 154, 96, 15)
(2276736, 15)
2276736
0
100000
200000
300000
400000
500000
600000
700000
800000
900000
1000000
1100000
1200000
1300000
1400000
1500000
1600000
17000